# To change this template, choose Tools | Templates
# and open the template in the editor.

__author__="Matej Pechac"
__date__ ="$Dec 8, 2010 5:29:55 PM$"

import math
from random import shuffle
from pybrain.supervised.trainers.trainer import Trainer
from pybrain.auxiliary.gradientdescent import GradientDescent

class RTRL_Trainer(Trainer):
    """Train a recurrent PyBrain network with a real-time recurrent learning
    (RTRL) style update, adjusting the weights after every sample.

    The network must contain modules named 'input', 'context', 'bias'
    (``inModules``) and 'hidden0', 'output' (``hidoutModules``, the units
    whose incoming weights are trained).  All units are laid out in one
    vector in the order inModules + hidoutModules; ``self.m`` counts the
    inModules units and ``self.n`` the hidoutModules units.

    Internal state:
      W -- n rows x (m + n) columns; W[i][j] is the weight from unit j to
           trained unit i (zero where no connection exists).
      P -- n * (m + n) rows of n entries; P[i * (m + n) + j][k] is the
           sensitivity of trained unit k to W[i][j].
      e -- error vector of length n for the current sample (zeros for the
           hidden0 units).
    """

    def __init__(self, module, dataset = None, learningrate = 1e-4, lrdecay = 1.0, momentum = 0.0, verbose = False, weightdecay = 0.0):
        """Create the trainer and snapshot the network's weights into W.

        :param module: network providing the named modules listed on the class.
        :param dataset: sequential dataset to train on (may be set later).
        :param learningrate: step size applied to every weight update.
        :param lrdecay: stored as ``descent.alphadecay``; this class only
            reads ``descent.alpha``.
        :param momentum: stored as ``descent.momentum``; this class only
            reads ``descent.alpha``.
        :param verbose: when true, ``train`` also prints the change in the
            parameter vector over the epoch.
        :param weightdecay: stored but not applied by this class.
        """
        Trainer.__init__(self, module)

        self.setData(dataset)
        self.verbose = verbose
        self.weightdecay = weightdecay
        self.epoch = 0
        self.totalepochs = 0

        # Only descent.alpha is consulted (in _updateWeightMatrix); the
        # descent object is never stepped because the updated weights are
        # written straight back into the connections by _applyWeightMatrix.
        self.descent = GradientDescent()
        self.descent.alpha = learningrate
        self.descent.momentum = momentum
        self.descent.alphadecay = lrdecay
        self.descent.init(module.params)

        # Unit layout used by W, P and self.units: inModules first, then
        # hidoutModules.
        self.inModules = ['input', 'context', 'bias']
        self.hidoutModules = ['hidden0', 'output']
        self.m = sum(self.module[s].dim for s in self.inModules)
        self.n = sum(self.module[s].dim for s in self.hidoutModules)

        self._initWeightMatrix()

    def train(self):
        """Run one epoch over the dataset and return the weighted mean error.

        For every sample the network is activated, the error is measured,
        P and W are advanced, and W is written back into the network.

        :returns: total error divided by the total ponderation (sum of the
            importance weights, or number of target values, per sample).
        """
        assert len(self.ds) > 0, "Dataset cannot be empty."

        errors = 0
        ponderation = 0

        self.epoch += 1
        self.totalepochs += 1

        # Sensitivities restart from zero at the start of every epoch.
        self._initPartialDerivsMatrix()

        if self.verbose:
            before = self.module.params[:].copy()

        self.units = []
        self.t = 0
        self.module.reset()

        for seq in self.ds._provideSequences():
            for sample in seq:
                self.module.activate(sample[0])
                self._copyOutputs()

                e, p = self._calcError(sample)

                self._updatePartialDerivsMatrix()
                self._updateWeightMatrix()
                self._applyWeightMatrix()
                self.t += 1

                errors += e
                ponderation += p

        result = errors / ponderation
        if self.verbose:
            # Per-parameter change accumulated over this epoch.
            print(before - self.module.params[:])
        print('Epoch: %s Total error: %s' % (self.epoch, result))
        return result

    def _copyOutputs(self):
        """Collect the latest output of every unit into ``self.units``.

        Outputs are concatenated in the order inModules + hidoutModules;
        names for which the network has no module are skipped.
        """
        self.units = []
        for name in self.inModules + self.hidoutModules:
            mod = self.module[name]
            if mod is not None:
                # NOTE(review): offset - 1 is presumably the buffer row written
                # by the most recent activation -- confirm against PyBrain.
                self.units.extend(mod.outputbuffer[mod.offset - 1])

    def _calcError(self, sample):
        """Compute the error of the current network output for one sample.

        Sets ``self.e`` to a length-n vector: zeros for the hidden0 units,
        followed by target - output for each output unit (0 where the target
        is NaN).

        :param sample: (input, target) or (input, target, importance).
        :returns: (error, ponderation) -- half the (importance-weighted) sum
            of squared output errors, and the summed importance or the number
            of targets, as a float.
        """
        from numpy import array, dot

        outmod = self.module['output']
        outdim = outmod.dim
        target = sample[1]

        # NOTE(review): outputs are read at buffer row self.t here, whereas
        # _copyOutputs reads row offset - 1; confirm the two always agree.
        outerrs = []
        for k in range(outdim):
            if math.isnan(target[k]):
                outerrs.append(0)
            else:
                outerrs.append(target[k] - outmod.outputbuffer[self.t][k])

        # Hidden units receive no teacher signal, so their error entries are 0.
        self.e = [0] * (self.n - outdim) + outerrs

        squared = array(outerrs) ** 2
        if len(sample) > 2:
            importance = sample[2]
            # Importance is assumed to have one entry per output (like the
            # target), so it weights the output errors only -- not the
            # leading hidden-unit zeros in self.e.
            error = 0.5 * dot(importance, squared)
            ponderation = 0. + sum(importance)
        else:
            error = 0.5 * sum(squared)
            ponderation = 0. + len(target)
        return error, ponderation

    def _initWeightMatrix(self):
        """Build W from the current parameters of the network's connections.

        Row i is trained unit i (hidden0 units, then output units); the
        columns cover all units in the order inModules + hidoutModules.
        Columns from modules with no connection into the row's module are 0.
        """
        self.W = []
        for s in self.hidoutModules:
            for i in range(self.module[s].dim):
                row = []
                for t in self.inModules + self.hidoutModules:
                    dim = self.module[t].dim
                    conn = self._isConnected(t, s)
                    if conn is None:
                        row.extend([0] * dim)
                    else:
                        # Weights into target unit i sit at params[i*dim + j],
                        # one per source unit j.
                        row.extend(conn.params[i * dim + j] for j in range(dim))
                self.W.append(row)

    def _updateWeightMatrix(self):
        """Apply one update: W[i][j] += alpha * sum_k e[k] * P[i*(m+n)+j][k]."""
        width = self.m + self.n
        alpha = self.descent.alpha
        for i in range(self.n):
            for j in range(width):
                sens = self.P[i * width + j]
                self.W[i][j] += alpha * sum(self.e[k] * sens[k] for k in range(self.n))

    def _applyWeightMatrix(self):
        """Write W back into the parameters of the network's connections.

        This mirrors _initWeightMatrix; columns belonging to module pairs
        without a connection are not written anywhere.
        """
        row_base = 0
        for s in self.hidoutModules:
            sdim = self.module[s].dim
            for i in range(sdim):
                wrow = self.W[row_base + i]
                col_base = 0
                for t in self.inModules + self.hidoutModules:
                    dim = self.module[t].dim
                    conn = self._isConnected(t, s)
                    if conn is not None:
                        for j in range(dim):
                            conn.params[i * dim + j] = wrow[col_base + j]
                    col_base += dim
            row_base += sdim

    def _isConnected(self, layer1 = None, layer2 = None):
        """Return the connection from module `layer1` to module `layer2`.

        Only ``self.module.connections`` for the source module is searched,
        and a module is never reported as connected to itself.  Returns None
        when no matching connection is found.
        """
        # NOTE(review): connections registered as recurrent may be stored
        # outside ``connections`` and would then not be found here -- confirm.
        if layer1 == layer2:
            return None
        for c in self.module.connections[self.module[layer1]]:
            if c.outmod.name == layer2:
                return c
        return None

    def _initPartialDerivsMatrix(self):
        """Reset P to n * (m + n) rows of n zeros."""
        self.P = [[0] * self.n for _ in range(self.n * (self.m + self.n))]

    def _copyMatrix(self, matrix):
        """Return a copy of a list-of-rows matrix with each row duplicated."""
        return [list(row) for row in matrix]

    def _updatePartialDerivsMatrix(self):
        """Advance the sensitivities P by one time step.

        For every weight W[i][j] and trained unit k this evaluates
            P[ij][k] <- g( sum_l W[k][m + l] * P[ij][l] + delta(i, k) * units[j] )
        with g obtained from the owning module's ``backActivate``.  All new
        values are computed from the previous P before P is replaced.
        """
        width = self.m + self.n
        modules = [self.module[s] for s in self.hidoutModules]
        newP = self._copyMatrix(self.P)

        for i in range(self.n):
            for j in range(width):
                ij = i * width + j
                old = self.P[ij]
                k = 0
                for module in modules:
                    for unit in range(module.dim):
                        recurrent = 0
                        for l in range(self.n):
                            recurrent += self.W[k][self.m + l] * old[l]
                        # NOTE(review): assumes backActivate(x)[unit] is the
                        # activation derivative of `unit` times x -- confirm.
                        # It is called once per (i, j, k), so any buffer side
                        # effects it has repeat that many times.
                        newP[ij][k] = module.backActivate(recurrent + self._kroneckerDelta(i, k) * self.units[j])[unit]
                        k += 1
        self.P = newP

    def _kroneckerDelta(self, i, j):
        """Return 1 if i == j, else 0."""
        return 1 if i == j else 0
