# To change this template, choose Tools | Templates
# and open the template in the editor.

# Module metadata: author name and the date string recorded for this file.
__author__ = "Matej Pechac"
__date__ = "$Feb 22, 2011 12:06:32 PM$"

from pybrain.structure.networks.feedforward import FeedForwardNetwork
from pybrain.structure.modules import SigmoidLayer, LinearLayer, SoftmaxLayer, BiasUnit
from pybrain.structure.connections import FullConnection

class UnfoldedNetwork(FeedForwardNetwork):
    """Feed-forward unrolling of a recurrent network over a sequence.

    The wrapped ``network`` must contain modules named ``'input'``,
    ``'hidden0'``, ``'context'``, ``'output'`` and ``'bias'``; they are looked
    up by those names below.  It is presumably a simple recurrent network in
    which ``context`` feeds ``hidden0`` -- inferred from the module names,
    verify against the caller.

    For a sequence of ``T`` samples the unrolled network contains

    * input modules ``input0`` .. ``input{T-1}`` (``LinearLayer`` with the
      input dimension),
    * ``hidden0`` .. ``hidden{T-1}`` (hidden layer class, context dimension)
      and ``hidden{T}`` (hidden layer class, hidden dimension),
    * an output module ``output`` (output layer class) and a ``bias`` unit,

    wired as ``input{t} -> hidden{t+1}``, ``hidden{t} -> hidden{t+1}``,
    ``bias -> hidden{t+1}`` and ``hidden{T} -> output``.  Each connection is
    seeded with weights copied from ``network`` (see ``_addConnection``).
    Because weights are matched by parameter count, the context-sized
    connections into ``hidden1`` .. ``hidden{T-1}`` only find a source when
    their size matches one in ``network`` -- typically when the context and
    hidden dimensions are equal.
    """

    def __init__(self, network, sequence):
        """Build the unrolled network.

        :param network: recurrent network whose modules and weights are copied.
        :param sequence: iterable of samples; only the number of samples is
            used.  It is consumed exactly once, so a one-shot iterator works.
        """
        FeedForwardNetwork.__init__(self)

        self._network = network
        net = self._network

        # Count the steps once.  Iterating ``sequence`` twice (once for the
        # modules, once for the connections) would leave the network without
        # step connections if ``sequence`` were a one-shot iterator.
        steps = sum(1 for _ in sequence)

        for t in range(steps):
            self.addInputModule(LinearLayer(net['input'].dim, name='input%i' % t))
            # hidden{t}: context-sized copy of the hidden layer.
            self._addContextModule(t)
        # hidden{steps} is the last hidden state; it is the one feeding the output.
        self.addModule(net['hidden0'].__class__(net['hidden0'].dim, name='hidden%i' % steps))
        self.addOutputModule(net['output'].__class__(net['output'].dim, name='output'))
        self.addModule(BiasUnit(name='bias'))

        for t in range(steps):
            nxt = 'hidden%i' % (t + 1)
            self._addConnection('input', 'input%i' % t, nxt, net_output='hidden0')
            self._addConnection('context', 'hidden%i' % t, nxt, net_output='hidden0')
            # The bias belongs to the layer computed at this step, hidden{t+1},
            # just like the input and context connections above.  Targeting
            # hidden{t} instead would bias the initial state hidden0 and leave
            # hidden{steps} -- the layer feeding the output -- without a bias.
            self._addConnection('bias', 'bias', nxt, net_output='hidden0')
        self._addConnection('hidden0', 'hidden%i' % steps, 'output', net_output='output')
        # NOTE: no bias -> output connection is created.  If one is wanted:
        #   self._addConnection('bias', 'bias', 'output', net_output='output')

        self.sortModules()

    def _addContextModule(self, t):
        """Add ``hidden<t>``: a module of the hidden layer's class with the
        context module's dimension."""
        self.addModule(self._network['hidden0'].__class__(self._network['context'].dim, name='hidden%i' % t))
        # NOTE: the context module's output buffer is not copied into the new
        # module (a copy of ``outputbuffer`` was left disabled here).

    def _addConnection(self, net_input, input, output, in_id=None, net_output=None):
        """Connect ``self[input]`` to ``self[output]`` and seed its weights.

        The weights are copied from a connection leaving
        ``self._network[net_input]`` whose parameter count equals that of the
        new ``FullConnection``.  If several qualify, those whose target module
        is named ``net_output`` are preferred; among the remaining candidates
        the last one wins.  If none qualifies, the new connection keeps its
        initial weights and a ``UserWarning`` is issued.

        :param net_input: name of the source module in the wrapped network.
        :param input: name of the source module in this network.
        :param output: name of the target module in this network.
        :param in_id: ignored; kept for backward compatibility.  It used to be
            the index of the new connection in ``self.connections[self[input]]``,
            which is no longer needed because the connection object is used
            directly.
        :param net_output: optional name of the preferred target module in the
            wrapped network, used to disambiguate equally-sized candidates.
        :returns: the newly added ``FullConnection``.
        """
        import warnings

        connection = FullConnection(self[input], self[output], name=input + '_' + output)
        self.addConnection(connection)

        size = len(connection.params)
        # Skip connections exposing no ``params`` (parameter-free connection
        # types) instead of failing on ``len(c.params)``.
        candidates = [c for c in self._network.connections[self._network[net_input]]
                      if getattr(c, 'params', None) is not None and len(c.params) == size]
        if not candidates:
            warnings.warn(
                "UnfoldedNetwork: no connection from %r with %i parameters; "
                "%r keeps its initial weights." % (net_input, size, connection.name),
                stacklevel=2)
            return connection

        if net_output is not None:
            named = [c for c in candidates
                     if getattr(c.outmod, 'name', None) == net_output]
            if named:
                candidates = named

        # Slice assignment copies the values into the connection's own buffer.
        connection.params[:] = candidates[-1].params
        return connection
        