# Echo State Network (ESN) readout trainer for PyBrain: fits the output
# weights of the hidden layer in closed form using a pseudo-inverse.

__author__="matej"
__date__ ="$Mar 29, 2011 12:34:06 PM$"

from pybrain.supervised.trainers.trainer import Trainer
from pybrain.datasets import SupervisedDataSet

class ESN_Trainer(Trainer):
    """Train the linear readout of an Echo State Network (ESN).

    Only the weights of the first outgoing connection of the module named
    ``'hidden0'`` are fitted; no other weight in the network is touched.
    Fitting is a single closed-form least-squares solve (Moore-Penrose
    pseudo-inverse), not iterative gradient descent.
    """

    def __init__(self, module, dataset = None):
        """Create a trainer for *module*, optionally bound to *dataset*.

        :param module: network to train; it must contain a module named
            ``'hidden0'`` (presumably the reservoir).
        :param dataset: optional ``SupervisedDataSet``, stored via ``setData``.
        """
        Trainer.__init__(self, module)
        self.setData(dataset)

    def train(self, washout = 0.3):
        """Fit the readout weights of ``'hidden0'`` by linear least squares.

        The leading ``washout`` fraction of the dataset is fed through the
        network only to move the hidden state away from its starting value.
        Those samples are not used for fitting. The remaining samples are
        then fed through, and their hidden-layer outputs are regressed onto
        their targets.

        :param washout: fraction (0..1) of the dataset, taken from its start,
            that is used only for warm-up. Defaults to 0.3, the value that
            was previously hard-coded.
        """
        from numpy import dot
        from numpy.linalg import pinv

        warmup, fit = self._splitDataset(washout)

        # Warm-up pass: only the side effect on the hidden state matters.
        for sample in warmup:
            self.module.activate(sample[0])

        self._initDataStructures(fit)

        # Least-squares solution of M . Wout = T. The result has shape
        # (hidden dim, number of outputs). An explicit dot() is used instead
        # of relying on numpy.matrix '*' semantics.
        Wout = dot(pinv(self.M), self.T)

        self._applyWeights(Wout)

    def _initDataStructures(self, data):
        """Collect hidden states and targets of *data* into ``self.M`` / ``self.T``.

        Each sample's input is fed through the module. Row ``i`` of
        ``self.M`` gets the row of the ``'hidden0'`` output buffer at
        ``offset - 1``, and row ``i`` of ``self.T`` gets the sample's target.

        :param data: a ``SupervisedDataSet``.
        """
        from numpy import zeros

        hidden = self.module['hidden0']
        n = len(data)
        self.M = zeros((n, hidden.dim))
        # One column per target dimension, so multi-output datasets work too.
        self.T = zeros((n, data.outdim))
        for i, sample in enumerate(data):
            self.module.activate(sample[0])
            # offset - 1 is presumably the row written by the activate() call
            # just above (activate advances the buffer offset) -- confirm.
            self.M[i, :] = hidden.outputbuffer[hidden.offset - 1]
            self.T[i, :] = sample[1]

    def _applyWeights(self, weights):
        """Write fitted readout *weights* into the first outgoing connection of ``'hidden0'``.

        :param weights: array-like of shape (hidden dim, number of outputs).
        """
        from numpy import asarray

        conn = self.module.connections[self.module['hidden0']][0]
        # Assumes a FullConnection whose params hold an (outdim, indim)
        # matrix flattened row-major -- TODO confirm for other connection types.
        # Transposing then flattening puts (output o, hidden unit h) at index
        # o * hidden_dim + h. For a single output this is exactly weights[:, 0],
        # which is what was written before.
        flat = asarray(weights).T.ravel()
        # Assign in place, element range only, as the original per-index writes
        # did; rebinding conn.params could detach it from the network's storage.
        conn.params[:len(flat)] = flat

    def _inverseFunction(self, func, x):
        """Apply the inverse of the activation type *func* to *x*.

        :param func: a layer class; only ``LinearLayer`` and ``TanhLayer``
            are recognised, compared by equality.
        :param x: value or array to invert.
        :returns: ``x`` for ``LinearLayer``, ``arctanh(x)`` for ``TanhLayer``,
            and ``None`` for any other *func*.
        """
        from pybrain.structure.modules import LinearLayer, TanhLayer
        from numpy import arctanh

        if func == LinearLayer:
            return x
        if func == TanhLayer:
            return arctanh(x)
        # NOTE(review): unrecognised layer types fall through and return None.
        return None

    def _splitDataset(self, proportion = 0.5):
        """Split ``self.ds``, in order, into a leading and a trailing part.

        :param proportion: fraction of samples, taken from the start, that
            go into the first returned dataset.
        :returns: ``(head, tail)``, two ``SupervisedDataSet`` objects. ``head``
            holds the first ``ceil(len(self.ds) * proportion)`` samples and
            ``tail`` holds the rest. Sample order is preserved.
        """
        cutoff = len(self.ds) * proportion

        head = SupervisedDataSet(self.ds.indim, self.ds.outdim)
        tail = SupervisedDataSet(self.ds.indim, self.ds.outdim)

        for i, s in enumerate(self.ds):
            # Integer i < float cutoff means the head gets ceil(cutoff) samples.
            target = head if i < cutoff else tail
            target.addSample(s[0], s[1])

        return head, tail

