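# Backpropagation-through-time trainer for PyBrain networks: every training
# sequence is unfolded into a feed-forward UnfoldedNetwork, trained with plain
# backprop, and the resulting weights are folded back into the original module.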

# Module metadata.
__author__ = "Matej Pechac"
__date__ = "$Dec 8, 2010 5:29:39 PM$"

from random import shuffle
from pybrain.datasets.sequential import SequentialDataSet
from pybrain.supervised.trainers.backprop import BackpropTrainer
from main.UnfoldedNetwork import UnfoldedNetwork

class BPTT_Trainer(BackpropTrainer):
    """Backpropagation-through-time trainer built on PyBrain's BackpropTrainer.

    Each epoch unfolds ``self.module`` over every training sequence (via
    ``UnfoldedNetwork``), trains the unfolded copy with ordinary backprop and
    writes the learned weights back into the first outgoing connection of the
    module's ``input``, ``context`` and ``hidden0`` layers.
    """

    def __init__(self, module, dataset = None, learningrate=0.01, lrdecay=1.0, momentum=0., verbose=False, batchlearning=False, weightdecay=0.):
        """Accept the same arguments as ``BackpropTrainer`` and forward them."""
        BackpropTrainer.__init__(self, module, dataset, learningrate, lrdecay,
                                 momentum, verbose, batchlearning, weightdecay)

    def train(self):
        """Train the module for one epoch with backpropagation through time.

        Every sequence of ``self.ds`` is visited once, in random order. For
        each sequence the module is unfolded into an ``UnfoldedNetwork``, the
        unfolded copy is trained on a single sample (all inputs of the
        sequence concatenated -> the target of the last time step) with a
        plain ``BackpropTrainer``, and the weights are folded back into
        ``self.module`` by ``_apply_changes``.

        Returns the sum of the errors reported by the per-sequence
        ``BackpropTrainer.train`` calls.
        """
        assert len(self.ds) > 0, "Dataset cannot be empty."

        self.module.resetDerivatives()
        total_error = 0

        # Visit the sequences in a fresh random order every epoch.
        sequences = list(self.ds._provideSequences())
        shuffle(sequences)

        for seq in sequences:
            unfolded = UnfoldedNetwork(self.module, seq)
            self.descent.init(unfolded.params)

            inputs, targets = self._flatten_sequence(seq)
            # Only the final time step's target is learned: the unfolded
            # network emits one output after consuming the whole sequence.
            # Slicing by the network's output width (instead of a hard-coded
            # single value) keeps multi-dimensional targets intact.
            final_target = targets[-unfolded.outdim:]

            unfolded_ds = SequentialDataSet(len(inputs), len(final_target))
            unfolded_ds.newSequence()
            unfolded_ds.addSample(inputs, final_target)

            # NOTE(review): the inner trainer uses BackpropTrainer defaults,
            # not this trainer's learningrate/momentum/etc. — confirm intent.
            inner = BackpropTrainer(unfolded, unfolded_ds, verbose = False)
            total_error += inner.train()

            self._apply_changes(unfolded, [inputs, targets])

        if self.verbose:
            print "Total error:", total_error

        self.epoch += 1
        self.totalepochs += 1
        return total_error

    @staticmethod
    def _flatten_sequence(seq):
        """Concatenate the input rows and the target rows of ``seq``.

        Returns ``(inputs, targets)``: two flat lists holding every time
        step's input (respectively target) values laid end to end.
        """
        inputs, targets = [], []
        for sample in seq:
            inputs.extend(sample[0])
            targets.extend(sample[1])
        return inputs, targets

    def _train(self, network, seq, errors, ponderation):
        """Run one gradient step on ``network`` for the sample ``seq``.

        ``seq`` is ``[inputs, targets]`` (or ``[inputs, targets, importance]``).
        The error and ponderation of this sample are added to the running
        totals ``errors`` and ``ponderation``, which are returned. Unless
        batch learning is enabled, the descent step is applied to the network
        immediately and its weights are folded back into ``self.module``.
        """
        network.resetDerivatives()

        e, p = self._calcDerivs(network, seq)
        errors += e
        ponderation += p
        if not self.batchlearning:
            gradient = network.derivs - self.weightdecay * network.params
            new = self.descent(gradient, errors)
            if new is not None:
                network.params[:] = new
            self._apply_changes(network, seq)
            network.resetDerivatives()

        return errors, ponderation

    def _calcDerivs(self, network, sample):
        """Compute the output error for ``sample`` and backpropagate it.

        ``sample`` is ``[inputs, targets]`` or ``[inputs, targets, importance]``.
        Only the last ``network.outdim`` target values (the final time step)
        are compared with the network output. Returns ``(error, ponderation)``:
        half the (importance-weighted) sum of squared output errors, and the
        total weight of the compared target values.
        """
        # Local import: `dot` was previously used here without being imported.
        from numpy import asarray, dot

        network.reset()
        network.activate(sample[0])
        target = asarray(sample[1][-network.outdim:])
        # NOTE(review): the original author was unsure whether
        # outputbuffer[network.offset] is the correct output row — verify.
        outerr = target - network.outputbuffer[network.offset]

        if len(sample) > 2:
            # NOTE(review): assumes sample[2] has the same length as the
            # sliced target — confirm against the caller.
            importance = sample[2]
            error = 0.5 * dot(importance, outerr ** 2)
            ponderation = 0. + sum(importance)
            network.backActivate(outerr * importance)
        else:
            error = 0.5 * sum(outerr ** 2)
            ponderation = 0. + len(target)
            network.backActivate(outerr)

        return error, ponderation

    def _apply_changes(self, network, seq):
        """Fold the weights of an unfolded network back into ``self.module``.

        ``seq[0]`` is the flattened input sequence the network was unfolded
        for; its length divided by ``self.module.indim`` is the number of time
        steps ``n``. The weights of the copies ``input0..input{n-1}`` and
        ``hidden0..hidden{n-1}`` are summed into the module's ``input`` and
        ``context`` connections, and the weights leaving ``hidden{n}`` are
        copied into the module's ``hidden0`` connection.
        """
        steps = len(seq[0]) // self.module.indim

        def weights(name):
            # Parameters of the first outgoing connection of layer `name`.
            return network.connections[network[name]][0].params

        # Copy the step-0 weights so the in-place sums below do not overwrite
        # the unfolded network's own parameter arrays, then add steps 1..n-1
        # (the previous loop re-added step 0 and skipped step n-1).
        # NOTE(review): this sums the per-step weights themselves rather than
        # their updates; if every copy started from the module's weights the
        # result is roughly `steps` times those weights — confirm intent.
        sum_input = weights('input0').copy()
        sum_hidden = weights('hidden0').copy()
        for step in range(1, steps):
            sum_input += weights('input%i' % step)
            sum_hidden += weights('hidden%i' % step)
        sum_output = weights('hidden%i' % steps)

        module = self.module
        module.connections[module['input']][0].params[:] = sum_input
        module.connections[module['context']][0].params[:] = sum_hidden
        module.connections[module['hidden0']][0].params[:] = sum_output

