# Jordan-style recurrent network built on PyBrain's RecurrentNetwork: the
# output activations are fed back into the network through a linear
# 'context' layer.

# Module metadata.
__author__ = "matej"
__date__ = "$Dec 9, 2010 1:36:43 PM$"

from pybrain.structure.networks import RecurrentNetwork
from pybrain.structure.modules import BiasUnit, SigmoidLayer, LinearLayer, SoftmaxLayer
from pybrain.structure.connections import FullConnection, IdentityConnection
# NetworkError is defined next to buildNetwork in PyBrain's shortcuts module.
from pybrain.tools.shortcuts import NetworkError

class JordanNetwork(RecurrentNetwork):
    """Recurrent network whose output is fed back via a 'context' layer.

    The network is wired like a layered feed-forward net
    ``input -> hidden0 -> ... -> hiddenK -> output``. In addition, the output
    activations are copied into a linear ``context`` layer through a recurrent
    ``IdentityConnection``. The context layer is fully connected into the first
    hidden layer, or into the output layer when there are no hidden layers.
    """

    def __init__(self, *layers, **options):
        """Build and sort the network.

        :arg layers: layer sizes. The first is the input size, the last is the
            output size, and any sizes in between become hidden layers named
            'hidden0', 'hidden1', ...
        :key bias: add a BiasUnit named 'bias' connected to every hidden
            layer (default True).
        :key hiddenclass: module class used for hidden layers
            (default SigmoidLayer).
        :key outclass: module class used for the output layer
            (default LinearLayer).
        :key outputbias: also connect the bias unit to the output layer; only
            has an effect when ``bias`` is True (default True).
        :key peepholes, recurrent, fast: accepted but not used by this class.
        :raises NetworkError: on an unknown option or fewer than two layer
            sizes.
        """
        # Defaults. Only 'bias', 'hiddenclass', 'outclass' and 'outputbias'
        # influence construction. The remaining keys are accepted so that
        # callers passing them do not get an "unknown option" error.
        opt = {'bias': True,
               'hiddenclass': SigmoidLayer,
               'outclass': LinearLayer,
               'outputbias': True,
               'peepholes': False,
               'recurrent': True,
               'fast': False,
               }
        for key, value in options.items():
            if key not in opt:
                raise NetworkError('JordanNetwork unknown option: %s' % key)
            opt[key] = value

        if len(layers) < 2:
            raise NetworkError('JordanNetwork needs 2 arguments for input and output layers at least.')

        RecurrentNetwork.__init__(self)
        # Linear input layer and output layer of type 'outclass'.
        self.addInputModule(LinearLayer(layers[0], name='input'))
        self.addOutputModule(opt['outclass'](layers[-1], name='output'))

        if opt['bias']:
            # Bias unit, optionally connected to the output layer.
            self.addModule(BiasUnit(name='bias'))
            if opt['outputbias']:
                self.addConnection(FullConnection(self['bias'], self['output']))

        # Arbitrary number of hidden layers of type 'hiddenclass'.
        for i, num in enumerate(layers[1:-1]):
            layername = 'hidden%i' % i
            self.addModule(opt['hiddenclass'](num, name=layername))
            if opt['bias']:
                self.addConnection(FullConnection(self['bias'], self[layername]))

        # Context layer: same width as the output; it receives the output
        # activations through the recurrent identity connection added below.
        self.addModule(LinearLayer(self['output'].dim, name='context'))

        # Connections between consecutive hidden layers. The insertion order
        # matches the previous implementation so the parameter layout is
        # unchanged.
        for i in range(len(layers) - 3):
            self.addConnection(FullConnection(self['hidden%i' % i], self['hidden%i' % (i + 1)]))

        if len(layers) == 2:
            # Flat network: input feeds the output directly. There is no
            # 'hidden0' module, so the context layer must feed the output.
            self.addConnection(FullConnection(self['input'], self['output']))
            context_target = 'output'
        else:
            # Input -> first hidden layer, last hidden layer -> output.
            self.addConnection(FullConnection(self['input'], self['hidden0']))
            self.addConnection(FullConnection(self['hidden%i' % (len(layers) - 3)], self['output']))
            context_target = 'hidden0'

        self.addConnection(FullConnection(self['context'], self[context_target]))
        # Recurrent feedback: output activations are copied into the context.
        self.addRecurrentConnection(IdentityConnection(self['output'], self['context']))

        self.sortModules()