import math
# To change this template, choose Tools | Templates
# and open the template in the editor.

# Module metadata (author and creation timestamp).
__author__ = "Matej Pechac"
__date__ = "$Dec 8, 2010 5:29:39 PM$"

from random import shuffle
from pybrain.datasets import SupervisedDataSet
from pybrain.supervised.trainers import Trainer, BackpropTrainer
from main.UnfoldedNetwork import UnfoldedNetwork

class BPTT_Trainer(Trainer):
    """Backpropagation-through-time trainer built on top of PyBrain.

    For every training sequence the recurrent ``self.module`` is unrolled
    into an ``UnfoldedNetwork``.  That feed-forward network is trained on a
    single sample (all inputs of the sequence concatenated, supervised only
    by the target of the last time step) with PyBrain's ``BackpropTrainer``.
    The per-time-step weight copies are then averaged and written back into
    ``self.module`` by ``_apply_changes``.

    NOTE(review): ``_apply_changes`` expects ``self.module`` to contain
    modules named 'input', 'context', 'hidden0' and 'bias', and the unfolded
    network to contain 'input<i>', 'hidden<i>' and 'bias' -- confirm against
    ``UnfoldedNetwork``.
    """

    def __init__(self, module, dataset=None, learningrate=0.01, lrdecay=1.0,
                 momentum=0., verbose=False, batchlearning=False,
                 weightdecay=0.):
        """Store the hyper-parameters and attach ``dataset``.

        ``learningrate``, ``momentum`` and ``weightdecay`` are forwarded to
        the per-sequence ``BackpropTrainer``.  ``lrdecay`` and
        ``batchlearning`` are stored for API compatibility but are not
        currently used by ``train``.
        """
        Trainer.__init__(self, module)
        self.setData(dataset)
        self.learningrate = learningrate
        self.lrdecay = lrdecay
        self.momentum = momentum
        self.verbose = verbose
        self.batchlearning = batchlearning
        self.weightdecay = weightdecay

        self.epoch = 0
        self.totalepochs = 0

    def train(self):
        """Run one epoch of BPTT over all sequences of the dataset.

        Sequences are visited in a freshly shuffled order.  For each one the
        error reported by ``BackpropTrainer.train()`` is divided by the
        sequence length and added to the epoch total.

        Returns:
            The accumulated, length-normalised error of this epoch.
        """
        assert len(self.ds) > 0, "Dataset cannot be empty."

        self.module.resetDerivatives()

        total_error = 0.0

        # Visit the sequences in a random order each epoch.
        sequences = list(self.ds._provideSequences())
        shuffle(sequences)

        for seq in sequences:
            unfolded = UnfoldedNetwork(self.module, seq)

            # The whole sequence becomes ONE sample: every time step's input
            # concatenated, supervised only by the final step's target.
            inputs = []
            for sample in seq:
                inputs.extend(sample[0])
            targets = list(seq[-1][1])

            sample_set = SupervisedDataSet(len(inputs), len(targets))
            sample_set.addSample(inputs, targets)

            trainer = BackpropTrainer(unfolded, sample_set,
                                      learningrate=self.learningrate,
                                      momentum=self.momentum,
                                      verbose=False,
                                      weightdecay=self.weightdecay)
            total_error += trainer.train() / len(seq)

            # Fold the trained, unrolled weights back into the recurrent net.
            self._apply_changes(unfolded, len(seq))

        if self.verbose:
            # A single pre-formatted string prints identically under the
            # Python 2 print statement and the Python 3 print() function.
            print("Epoch: %s Total error: %s" % (self.epoch, total_error))

        # TODO: ``batchlearning`` (and ``lrdecay``) are accepted by __init__
        # but not implemented; every sequence is applied immediately.

        self.epoch += 1
        self.totalepochs += 1

        return total_error

    def _apply_changes(self, network, length):
        """Average the unrolled weights of ``network`` into ``self.module``.

        Args:
            network: the trained ``UnfoldedNetwork`` for one sequence.
            length: number of time steps in that sequence.

        The input->hidden, context (hidden->hidden) and bias->hidden
        connections of time steps ``0 .. length-1`` are averaged.  The
        hidden->output weights are read from the last hidden layer
        ('hidden<length>') and copied unchanged.

        NOTE(review): assumes the first ``length`` outgoing connections of
        the unfolded 'bias' module are the bias->hidden ones, in time order
        -- confirm against ``UnfoldedNetwork``.
        """
        def unfolded_params(name, index=0):
            # Parameter array of the index-th outgoing connection of `name`.
            return network.connections[network[name]][index].params

        def module_params(name):
            # Parameter array of the first outgoing connection of `name` in
            # the recurrent module being trained.
            return self.module.connections[self.module[name]][0].params

        steps = range(length)
        # builtin sum() starts from 0 and allocates new arrays, so the
        # unfolded network's own parameter buffers are never modified in
        # place, and every time step contributes exactly once.
        sum_input = sum(unfolded_params('input%i' % i) for i in steps)
        sum_hidden = sum(unfolded_params('hidden%i' % i) for i in steps)
        sum_bias = sum(unfolded_params('bias', i) for i in steps)
        output_params = unfolded_params('hidden%i' % length)

        module_params('input')[:] = sum_input / length
        module_params('context')[:] = sum_hidden / length
        module_params('hidden0')[:] = output_params
        module_params('bias')[:] = sum_bias / length
        