'''
Created on Nov 10, 2010

@author: Matej Pechac
'''
import Queue

'''
'''

import pygtk
pygtk.require('2.0')


from pybrain.datasets.sequential import SequentialDataSet
from pybrain.supervised.trainers import BackpropTrainer
from main.BPTT_Trainer import BPTT_Trainer


from ElmanModel import ElmanNetwork
from JordanModel import JordanNetwork


class neuralNetwork():
    '''
    Application-side wrapper around a recurrent PyBrain network.

    It owns the network, the hard-coded training/test data sets and the
    trainer, and reports training progress to the GUI through
    ``app.GUI.render_error_graph``.

    Attributes set by the methods below:
        app        -- application object handed to the constructor
        conversion -- symbol -> 2-element input vector used for encoding
        model      -- model id taken from ``data[3]``
        network    -- ElmanNetwork / JordanNetwork, or None when the
                      requested model is not implemented
        trainer    -- trainer created by the last onTrain call, or None
        trndata    -- training SequentialDataSet (after onLoadDataSet)
        tstdata    -- test SequentialDataSet (after onLoadDataSet)
    '''

    def __init__(self, app, data):
        '''
        Build the network described by ``data``.

        app  -- application object; its ``GUI.render_error_graph`` is
                called while training.
        data -- indexable of four items: ``data[0]``, ``data[1]`` and
                ``data[2]`` are converted with int() and passed, in that
                order, to the network constructor; ``data[3]`` selects the
                model (0 = Elman, 1 = Jordan, 2 = reserved for an echo
                state network that is not implemented yet).
        '''
        self.app = app
        # Two-component encoding of the two input symbols.
        self.conversion = {'a': [0, 1], 'b': [1, 0]}
        self.model = data[3]
        # Always define these so later code sees None instead of a
        # missing attribute when nothing has been built yet.
        self.network = None
        self.trainer = None

        if self.model == 0:
            self.network = ElmanNetwork(int(data[0]), int(data[1]), int(data[2]))
        elif self.model == 1:
            self.network = JordanNetwork(int(data[0]), int(data[1]), int(data[2]))
        else:
            # Model 2 (EchoStateNetwork) is not implemented; like any other
            # unknown id it leaves the instance without a network.
            pass

    def _fillDataSet(self, dataset, words, labels):
        '''
        Append one sequence per word to ``dataset``.

        Every symbol of a word becomes one sample: the input is the
        symbol's vector from ``self.conversion`` and the target is the
        matching digit of the label string, converted to int.
        '''
        for word, label in zip(words, labels):
            dataset.newSequence()
            for symbol, digit in zip(word, label):
                dataset.addSample(self.conversion[symbol], int(digit))

    def onLoadDataSet(self):
        '''
        Create ``self.trndata`` and ``self.tstdata`` from built-in examples.

        Both are SequentialDataSet(2, 1) instances. In the examples below
        the target digit is 1 exactly where an 'a' directly follows a 'b'.
        '''
        test_words = ['ba', 'bb', 'baa']
        test_labels = ['01', '00', '010']
        train_words = ['ababba', 'baba', 'abba', 'baaa', 'babaab']
        train_labels = ['001001', '0101', '0001', '0100', '010100']

        self.trndata = SequentialDataSet(2, 1)
        self.tstdata = SequentialDataSet(2, 1)

        self._fillDataSet(self.trndata, train_words, train_labels)
        self._fillDataSet(self.tstdata, test_words, test_labels)

    def onTrain(self, alg=0, lrnrate=0.01, lrndecay=0., momentum=1.0, weightdecay=0., epochs=30):
        '''
        Train ``self.network`` on ``self.trndata`` and plot the progress.

        alg         -- 0 = BackpropTrainer, 1 = BPTT_Trainer,
                       2 = reserved for RTRL (not implemented)
        lrnrate     -- passed to the trainer as ``learningrate``
        lrndecay    -- passed to the trainer as ``lrdecay``
        momentum    -- passed to the trainer as ``momentum``
        weightdecay -- passed to the trainer as ``weightdecay``
        epochs      -- number of ``trainer.train()`` calls (default 30)

        After every epoch the list of ``(epoch index, train() result)``
        pairs collected so far is handed to
        ``self.app.GUI.render_error_graph``.

        Raises NotImplementedError for alg 2 and ValueError for any other
        unknown alg. Both are raised before the trainer is replaced, so an
        unsupported request can no longer silently reuse a trainer left
        over from an earlier call.
        '''
        if alg == 0:
            trainer_class = BackpropTrainer
        elif alg == 1:
            trainer_class = BPTT_Trainer
        elif alg == 2:
            raise NotImplementedError('RTRL training (alg 2) is not implemented')
        else:
            raise ValueError('unknown training algorithm: %r' % (alg,))

        self.trainer = trainer_class(
            self.network,
            self.trndata,
            learningrate=lrnrate,
            lrdecay=lrndecay,
            momentum=momentum,
            verbose=False,
            weightdecay=weightdecay)

        error_data = []
        for epoch in range(epochs):
            error_data.append((epoch, self.trainer.train()))
            # Redraw after every epoch so the graph grows while training.
            self.app.GUI.render_error_graph(error_data)

    def onTest(self):
        '''
        Print the network's input buffer and evaluate on the test set.

        Requires a prior onLoadDataSet and onTrain call; the evaluation is
        done by ``self.trainer.testOnData(self.tstdata, verbose=True)``.
        '''
        # Single-argument parenthesised form prints the same thing under
        # Python 2 and is also valid Python 3 syntax.
        print(self.network.inputbuffer)

        self.trainer.testOnData(self.tstdata, verbose=True)

