# To change this template, choose Tools | Templates
# and open the template in the editor.


from pybrain.tools.xml.handling import XMLHandling
from pybrain.datasets import SequentialDataSet, SupervisedDataSet

class DataSet():
    """Build a PyBrain dataset from an XML description.

    The XML root is searched for two children:

    * ``Conversion`` (optional) -- its attributes map a symbol name to a
      sequence whose elements are each converted to ``float``; the resulting
      lists are used to encode symbolic samples.
    * ``Dataset`` -- its attributes are the dataset options (``Task``,
      ``InputLength``, ``OutputLength``, ``SequentialData``, and for the
      ``'Prediction'`` task ``NumericInput`` and ``Row``); its element
      children are the samples, each with an ``Input`` child (and a
      ``Target`` child unless the task is ``'Prediction'``) carrying a
      ``Value`` attribute.

    After :meth:`load`, ``dataset`` holds the PyBrain dataset and ``options``
    the parsed options.
    """

    def __init__(self):
        self.dataset = None
        self.options = {}
        self.handler = None

        # Inputs collected for non-row prediction; paired up after loading.
        self._tempset = []
        # Symbol -> list of floats, filled from the 'Conversion' element.
        self._conversion = {}

    def load(self, filename):
        """Read ``filename`` and populate ``options`` and ``dataset``.

        Raises:
            ValueError: if the file has no ``Dataset`` element.
        """
        self.handler = XMLHandling(filename, False)
        # Start from a clean buffer so that loading a second file does not
        # replay samples buffered by an earlier load.
        self._tempset = []

        node_conversion = self.handler.getChild(self.handler.root, 'Conversion')
        if node_conversion is not None:
            self._parseConversion(self.handler.readAttrDict(node_conversion))

        node_dataset = self.handler.getChild(self.handler.root, 'Dataset')
        if node_dataset is None:
            raise ValueError("no 'Dataset' element found in %r" % (filename,))
        self._parseDataSetOptions(self.handler.readAttrDict(node_dataset))

        indim = self.options['InputLength']
        outdim = self.options['OutputLength']
        if self.options['SequentialData']:
            self.dataset = SequentialDataSet(indim, outdim)
        else:
            self.dataset = SupervisedDataSet(indim, outdim)

        for node_sample in self.handler.getChildrenOf(node_dataset):
            self._parseSample(node_sample)

        # Non-row prediction samples can only be paired with their successor
        # once every sample has been read.
        if self.options['Task'] == 'Prediction' and not self.options['Row']:
            self._postProcessNonRowSamples()

    def splitDataset(self, proportion=0.5):
        """Split the loaded dataset into a training and a test set.

        Samples whose index is below ``len(dataset) * proportion`` go to the
        first set, the rest to the second. Both results are always
        ``SupervisedDataSet`` instances, even if the loaded dataset is
        sequential.

        Returns:
            A ``(training, test)`` tuple.
        """
        n = len(self.dataset) * proportion

        indim = self.options['InputLength']
        outdim = self.options['OutputLength']
        trn = SupervisedDataSet(indim, outdim)
        tst = SupervisedDataSet(indim, outdim)

        for i, s in enumerate(self.dataset):
            subset = trn if i < n else tst
            subset.addSample(s[0], s[1])

        return trn, tst

    def _parseConversion(self, adict):
        """Store each entry of ``adict`` as a list of floats in the table.

        Every element obtained by iterating a value is converted with
        ``float`` (for a string value that means one float per character).
        """
        for name, val in adict.items():
            self._conversion[name] = [float(v) for v in val]

    def _parseDataSetOptions(self, adict):
        """Parse the ``Dataset`` attributes into ``self.options``.

        ``'True'``/``'False'`` become booleans; any other value is converted
        to ``int``, except for the ``Task`` option which is kept as is.
        """
        for name, val in adict.items():
            if val == 'True':
                val = True
            elif val == 'False':
                val = False
            elif name != 'Task':
                val = int(val)
            self.options[name] = val

    def _readValue(self, node_sample, child_name):
        """Return the ``Value`` attribute of the named child, or ``None``.

        ``None`` is returned when the child element or its ``Value``
        attribute is missing; other attributes on the child are ignored.
        """
        node = self.handler.getChild(node_sample, child_name)
        if node is None:
            return None
        return self.handler.readAttrDict(node).get('Value')

    def _parseSample(self, node_sample):
        """Add the sample(s) described by ``node_sample`` to the dataset.

        Only the ``'Classification'`` and ``'Prediction'`` tasks produce
        samples; any other task value is ignored.

        Raises:
            ValueError: if a required ``Value`` attribute is missing.
        """
        task = self.options['Task']

        raw_input = self._readValue(node_sample, 'Input')
        raw_target = None
        if task != 'Prediction':
            raw_target = self._readValue(node_sample, 'Target')

        if task == 'Classification':
            self._addClassificationSample(raw_input, raw_target)
        elif task == 'Prediction':
            self._addPredictionSample(raw_input)

    def _addClassificationSample(self, raw_input, raw_target):
        """Encode an input/target pair and add it element by element."""
        if raw_input is None or raw_target is None:
            raise ValueError(
                "classification sample needs a 'Value' on both 'Input' and 'Target'")

        inputs = self._convert(raw_input)
        targets = self._convert(raw_target)

        if self.options['SequentialData']:
            self.dataset.newSequence()

        # When the input is longer than the target, the leading surplus
        # inputs are given the target 0 and the remainder is paired up.
        # With equal lengths ``pad`` is 0 and everything is simply paired.
        # NOTE(review): a target longer than the input yields a negative
        # offset; the resulting slicing is kept as it always was -- confirm
        # whether that case is meant to be supported.
        pad = len(inputs) - len(targets)
        for vec in inputs[:pad]:
            self.dataset.addSample(vec, 0)
        for vec, tgt in zip(inputs[pad:], targets):
            self.dataset.addSample(vec, tgt)

    def _addPredictionSample(self, raw_input):
        """Add next-step prediction samples built from one input value."""
        if raw_input is None:
            raise ValueError("prediction sample needs a 'Value' on 'Input'")

        if self.options['NumericInput']:
            inp = [float(raw_input)]
        else:
            inp = self._convert(raw_input)

        if self.options['SequentialData']:
            self.dataset.newSequence()

        if self.options['Row']:
            # Each element predicts its successor; the last wraps around to
            # the first.
            count = len(inp)
            for i, w in enumerate(inp):
                self.dataset.addSample(w, inp[(i + 1) % count])
        else:
            # Successor lives in the next sample: defer until all are read.
            self._tempset.append(inp)

    def _postProcessNonRowSamples(self):
        """Pair each buffered input with the next one (last wraps to first).

        Only the first element of every buffered input is used.
        """
        count = len(self._tempset)
        for i, w in enumerate(self._tempset):
            target = self._tempset[(i + 1) % count]
            self.dataset.addSample(w[0], target[0])

    def _convert(self, input):
        """Map every element of ``input`` through the conversion table.

        Raises:
            KeyError: if an element has no entry in the conversion table.
        """
        return [self._conversion[s] for s in input]