'''
Created on Nov 10, 2010

@author: Matej Pechac
'''
'''
'''

import pygtk
pygtk.require('2.0')

from pybrain.supervised.trainers import BackpropTrainer
from main.BPTT_Trainer import BPTT_Trainer
from main.RTRL_Trainer import RTRL_Trainer
from main.ESN_Trainer import ESN_Trainer
from main.DataSet import DataSet


from main.ElmanModel import ElmanNetwork
from main.JordanModel import JordanNetwork
from main.EchoStateNetwork import EchoStateNetwork


class neuralNetwork(object):
    '''Ties one recurrent network model to a dataset, a trainer and the GUI.

    ``app`` must expose ``app.GUI`` with the ``render_error_graph``,
    ``render_test_graph`` and ``render_output`` methods called below.
    '''

    def __init__(self, app, data):
        '''Build the network model selected by ``data``.

        :param app: application object, stored as ``self.app`` and later used
            for ``app.GUI`` rendering calls.
        :param data: indexable with at least four items. ``data[0]``,
            ``data[1]`` and ``data[2]`` are converted with ``int`` and passed
            positionally to the model constructor (presumably layer sizes --
            confirm against the model classes). ``data[3]`` selects the model:
            0 = ElmanNetwork, 1 = JordanNetwork, 2 = EchoStateNetwork.
        :raises ValueError: if ``data[3]`` is not one of the ids above.
        '''
        self.app = app
        self.model = data[3]
        sizes = (int(data[0]), int(data[1]), int(data[2]))

        if self.model == 0:
            self.network = ElmanNetwork(*sizes)
        elif self.model == 1:
            self.network = JordanNetwork(*sizes)
        elif self.model == 2:
            self.network = EchoStateNetwork(*sizes)
        else:
            # Fail here rather than leaving ``self.network`` unset, which
            # previously surfaced much later as an AttributeError.
            raise ValueError('unknown network model id: %r' % (self.model,))

    def onLoadDataSet(self, path='datasets/sinus.xml', proportion=0.5):
        '''Load a dataset and split it into training and test parts.

        :param path: file handed to ``DataSet.load``. The default is the path
            that used to be hard-coded here.
        :param proportion: value handed to ``DataSet.splitDataset``
            (presumably the fraction placed in the first, training, part --
            confirm in DataSet). The default 0.5 matches the old behaviour.

        Sets ``self.dataset``, ``self.trndata`` and ``self.tstdata``, then
        prints the lengths of the two parts.
        '''
        self.dataset = DataSet()
        self.dataset.load(path)
        self.trndata, self.tstdata = self.dataset.splitDataset(proportion)

        # Single-argument print: identical output on Python 2 and 3.
        print('%d %d' % (len(self.trndata), len(self.tstdata)))

    def onTrain(self, alg=0, epochs=30, lrnrate=0.01, lrndecay=1.0, momentum=0., weightdecay=0.):
        '''Create a trainer for ``self.trndata`` and run it for ``epochs`` epochs.

        :param alg: trainer selector -- 0 = BackpropTrainer, 1 = BPTT_Trainer,
            2 = RTRL_Trainer, 3 = ESN_Trainer.
        :param epochs: number of ``train()`` calls (converted with ``int``).
        :param lrnrate: learning rate (not passed to ESN_Trainer).
        :param lrndecay: learning-rate decay, passed only to RTRL_Trainer.
        :param momentum: momentum (not passed to ESN_Trainer).
        :param weightdecay: weight decay (not passed to ESN_Trainer).
        :raises ValueError: if ``alg`` is not one of the ids above.

        The value returned by each ``train()`` call (presumably the epoch
        error) is collected and handed to ``self.app.GUI.render_error_graph``.
        '''
        if alg == 0:
            self.trainer = BackpropTrainer(self.network, self.trndata, learningrate=lrnrate, momentum=momentum, verbose=True, weightdecay=weightdecay)
        elif alg == 1:
            self.trainer = BPTT_Trainer(self.network, self.trndata, learningrate=lrnrate, momentum=momentum, verbose=True, weightdecay=weightdecay)
        elif alg == 2:
            self.trainer = RTRL_Trainer(self.network, self.trndata, learningrate=lrnrate, lrdecay=lrndecay, momentum=momentum, verbose=False, weightdecay=weightdecay)
        elif alg == 3:
            self.trainer = ESN_Trainer(self.network, self.trndata)
        else:
            # An unknown id used to silently reuse the trainer from a previous
            # call (or raise AttributeError on the first call).
            raise ValueError('unknown training algorithm id: %r' % (alg,))

        error_data = [self.trainer.train() for _ in range(int(epochs))]
        self.app.GUI.render_error_graph(error_data)

    def onTest(self, verbose=False):
        '''Evaluate the network on ``self.tstdata`` and render the results.

        Each sequence is scored with ``dataset._evaluateSequence`` from a
        freshly reset network. The network is then reset again and re-run over
        the same sequence to collect raw outputs and targets.

        Rendering:
        - Per-sequence errors divided by their importance go to
          ``self.app.GUI.render_test_graph``.
        - If ``self.dataset.options['NumericOutput']`` is true, the raw
          outputs and targets go to ``self.app.GUI.render_output``.

        :param verbose: if true, print the per-sequence, average, max and
            median errors.
        :returns: total error divided by total importance.
        :raises RuntimeError: if no test data is loaded (see onLoadDataSet).
        '''
        dataset = getattr(self, 'tstdata', None)
        if dataset is None:
            # Previously this fell through to an UnboundLocalError on ``dataset``.
            raise RuntimeError('no test data loaded; call onLoadDataSet() first')
        dataset.reset()

        if verbose:
            print('\nTesting on data:')

        errors = []
        importances = []
        ponderatedErrors = []

        output = []
        target = []

        for seq in dataset._provideSequences():
            self.network.reset()
            # Third argument is always True here. NOTE(review): presumably a
            # verbosity flag -- confirm whether it should follow ``verbose``.
            e, i = dataset._evaluateSequence(self.network.activate, seq, True)

            # Replay the sequence from a fresh state to collect raw values for plotting.
            self.network.reset()
            for x, t in seq:
                target.extend(t)
                output.extend(self.network.activate(x))

            importances.append(i)
            errors.append(e)
            ponderatedErrors.append(e / i)
        if verbose:
            print('All errors: %s' % (ponderatedErrors,))
        assert sum(importances) > 0
        avgErr = sum(errors) / sum(importances)
        if verbose:
            # ``//`` keeps the index integral under true division as well.
            # For an even count this picks the upper of the two middle values.
            median = sorted(ponderatedErrors)[len(errors) // 2]
            print('Average error: %s' % (avgErr,))
            # The old ``print (a, b, ...)`` statement printed a tuple repr on Python 2.
            print('Max error: %s Median error: %s' % (max(ponderatedErrors), median))
        self.app.GUI.render_test_graph(ponderatedErrors)

        if self.dataset.options['NumericOutput']:
            self.app.GUI.render_output(output, target)

        return avgErr

    def _convertWord(self, word):
        '''Map every symbol of ``word`` through ``self.conversion``.

        :returns: list of ``self.conversion[s]`` for each ``s`` in ``word``.
        :raises KeyError: for a symbol missing from ``self.conversion``.

        NOTE(review): ``self.conversion`` is never assigned in this class.
        Presumably it is set externally before this is called -- confirm.
        '''
        return [self.conversion[s] for s in word]