# PyGTK/Glade front-end for creating, training and testing neural networks
# (Elman, Jordan and Echo state models) and plotting their results.

__author__="Matej Pechac"
__date__ ="$Feb 16, 2011 12:08:49 PM$"

import sys
from matplotlib.figure import Figure
from matplotlib.backends.backend_gtk import FigureCanvasGTK

try:
    import pygtk
    pygtk.require("2.0")
except:
    pass
try:
    import gtk
    import gtk.glade
except:
    sys.exit(1)

from pylab import *

class GUI(object):
    """Main PyGTK window of the application, built from ``GUI.glade``.

    ``data`` is the application object.  Button handlers call its
    ``new_neural_network``, ``load_data_set``, ``start_training`` and
    ``start_testing`` methods, and :meth:`render_architecture` draws its
    ``network`` attribute.  Plots are shown as matplotlib canvases in the
    four columns of the ``table2`` widget.
    """

    # Entries of the model combobox; the selected index is forwarded to
    # ``new_neural_network`` as the last element of its argument list.
    NETWORK_MODELS = ['Elman model', 'Jordan model', 'Echo state model']

    def __init__(self, data):
        self.app = data
        # Canvas currently shown in each table2 column (keyed by left column),
        # so a re-render replaces it instead of stacking another widget.
        self._canvases = {}

        self.builder = gtk.Builder()
        self.builder.add_from_file('GUI.glade')

        self.window = self.builder.get_object('window1')

        callback = {
            'on_window1_destroy': gtk.main_quit,
            'on_menuItemNew_activate': self._show_new_network_dialog,
            # NOTE(review): Open / Save / Save As are not implemented yet and
            # currently quit the application exactly like "Quit".
            'on_menuItemOpen_activate': gtk.main_quit,
            'on_menuItemSave_activate': gtk.main_quit,
            'on_menuItemSaveAs_activate': gtk.main_quit,
            'on_menuItemQuit_activate': gtk.main_quit,
            'on_button1_clicked': self._set_start_training,
            'on_button2_clicked': self._set_start_loading,
            'on_button3_clicked': self._set_start_testing,
            'on_bCreate_clicked': self._set_new_network,
            'on_bBack_clicked': self._hide_new_network_dialog,
        }

        self.builder.connect_signals(callback)
        self._set_combobox_liststore()

        self.render_architecture()
        self.render_error_graph()
        self.render_test_graph()
        self.render_output()

        self.window.show()

    def main(self):
        """Enter the GTK main loop (blocks until ``gtk.main_quit``)."""
        gtk.main()

    def get_learning_algorithm(self):
        """Return the 0-based index of the active radiobutton1..radiobutton4.

        Returns None when none of the four radio buttons is active.
        """
        for index in range(4):
            button = self.builder.get_object('radiobutton%i' % (index + 1))
            if button.get_active():
                return index
        return None

    def _show_new_network_dialog(self, widget, data=None):
        """Signal handler: show the "new network" dialog (dialog1)."""
        self.builder.get_object('dialog1').show()

    def _hide_new_network_dialog(self, widget, data=None):
        """Signal handler: hide the "new network" dialog (dialog1)."""
        self.builder.get_object('dialog1').hide()

    # Backward-compatible alias for the original (misspelled) handler name.
    _hide_new_netwrok_dialog = _hide_new_network_dialog

    def _set_new_network(self, widget, data=None):
        """Signal handler: create a network from the dialog's inputs.

        Passes ``[entry1, entry2, entry3, model_index]`` (three entry texts
        plus the active combobox index) to ``self.app.new_neural_network``.
        Then it hides the dialog, enables only the "load" button (button2)
        and redraws the architecture.
        """
        out_data = [
            self.builder.get_object('entry1').get_text(),
            self.builder.get_object('entry2').get_text(),
            self.builder.get_object('entry3').get_text(),
            self.builder.get_object('combobox1').get_active(),
        ]

        self.app.new_neural_network(out_data)

        self._hide_new_network_dialog(widget, data)
        # A fresh network needs data before it can be trained or tested.
        self.builder.get_object('button2').set_sensitive(True)
        self.builder.get_object('button1').set_sensitive(False)
        self.builder.get_object('button3').set_sensitive(False)
        self.render_architecture()

    def _set_combobox_liststore(self):
        """Fill combobox1 with NETWORK_MODELS and attach a text renderer."""
        import gobject

        store = gtk.ListStore(gobject.TYPE_STRING)
        for item in self.NETWORK_MODELS:
            store.append([item])

        combo = self.builder.get_object('combobox1')
        combo.set_model(store)
        cell = gtk.CellRendererText()
        combo.pack_start(cell)
        combo.add_attribute(cell, 'text', 0)

    def _set_start_training(self, widget, data=None):
        """Signal handler: delegate to ``self.app.start_training``."""
        self.app.start_training()

    def _set_start_loading(self, widget, data=None):
        """Signal handler: enable train/test buttons and load the data set."""
        self.builder.get_object('button1').set_sensitive(True)
        self.builder.get_object('button3').set_sensitive(True)
        self.app.load_data_set()

    def _set_start_testing(self, widget, data=None):
        """Signal handler: delegate to ``self.app.start_testing``."""
        self.app.start_testing()

    def _attach_canvas(self, canvas, left, right):
        """Show ``canvas`` in table2 columns ``left``..``right``, rows 0..2.

        If this method previously placed a canvas at the same ``left``
        column, that canvas is removed first so widgets do not accumulate.
        """
        table = self.builder.get_object('table2')
        previous = self._canvases.get(left)
        if previous is not None:
            table.remove(previous)
        canvas.show()
        table.attach(canvas, left, right, 0, 2)
        self._canvases[left] = canvas

    def render_architecture(self, data=None):
        """Draw the neuron layout of ``self.app.network`` in table2 column 0.

        ``data`` is unused; it is kept for signature compatibility.  When
        ``self.app.network`` is None, an empty titled frame is shown.
        """
        figure = Figure(figsize=(4.72, 4.72), dpi=72, facecolor='w')
        figure.suptitle('Network architecture')
        axes = figure.add_subplot(111)

        # Plain drawing area: no ticks, no frame.
        axes.xaxis.set_ticks([])
        axes.yaxis.set_ticks([])
        for side in ('left', 'right', 'bottom', 'top'):
            axes.spines[side].set_color('none')

        if self.app.network is not None:
            layers = self.app.network.network
            ni = layers['input'].indim
            nh0 = layers['hidden0'].indim
            no = layers['output'].indim

            max_dim = max(ni, nh0, no)

            # White corner points fix the plotted extent of the figure.
            axes.plot(0, 0, 'w.')
            axes.plot(max_dim * 2, 3, 'w.')

            # Each row spreads its neurons evenly over a width of max_dim.
            for x in range(ni):
                axes.plot((x + 1) * (float(max_dim) / float(ni + 1)), 0.5, 'g^')
            for x in range(nh0):
                step = float(max_dim) / float(nh0 + 1)
                # Copy of the hidden row drawn next to the inputs;
                # presumably the recurrent context units -- confirm.
                axes.plot((x + 1) * step + max_dim, 0.5, 'b^')
                axes.plot((x + 1) * step, 1.5, 'b^')
            for x in range(no):
                axes.plot((x + 1) * (float(max_dim) / float(no + 1)), 2.5, 'r^')

        self._attach_canvas(FigureCanvasGTK(figure), 0, 1)

    def render_error_graph(self, data=None):
        """Plot the per-epoch error sequence ``data`` in table2 column 1.

        The figure and canvas are kept as ``self.error_figure`` and
        ``self.error_canvas``.  With ``data`` None, only empty labelled axes
        are drawn.
        """
        self.error_figure = Figure(figsize=(4.72, 5), dpi=72)
        self.error_canvas = FigureCanvasGTK(self.error_figure)
        self._attach_canvas(self.error_canvas, 1, 2)

        axis = self.error_figure.add_subplot(111)

        # `is not None` -- `!=` is elementwise (and ambiguous) on arrays.
        if data is not None:
            axis.plot(range(len(data)), data, 'r-')

        axis.set_xlabel('Epochs')
        axis.set_ylabel('Error')
        axis.grid(True)

        self.error_canvas.draw_idle()

    def render_test_graph(self, data=None):
        """Plot the per-input test error sequence ``data`` in table2 column 2.

        The y-range is symmetric around zero at twice the largest magnitude
        in ``data``.  Empty or all-zero data keeps matplotlib's default limits.
        """
        figure = Figure(figsize=(4.72, 5), dpi=72)
        axis = figure.add_subplot(111)
        axis.set_xlabel('Input')
        axis.set_ylabel('Error')
        axis.grid(True)

        if data is not None and len(data) > 0:
            axis.plot(range(len(data)), data, 'r-')
            # Use the largest magnitude: max(data) alone gives an inverted
            # range for all-negative data and an empty one for all zeros.
            peak = max(abs(value) for value in data)
            if peak > 0:
                axis.axis([0, len(data), -2 * peak, 2 * peak])

        self._attach_canvas(FigureCanvasGTK(figure), 2, 3)

    def render_output(self, data1=None, data2=None):
        """Plot two output sequences (red ``data1``, blue ``data2``) in column 3.

        Nothing is plotted unless both sequences are given.
        """
        figure = Figure(figsize=(4.72, 5), dpi=72)
        axis = figure.add_subplot(111)
        axis.set_xlabel('Input')
        axis.set_ylabel('Output')
        axis.grid(False)

        if data1 is not None and data2 is not None:
            axis.plot(range(len(data1)), data1, 'r-',
                      range(len(data2)), data2, 'b-')

        self._attach_canvas(FigureCanvasGTK(figure), 3, 4)

    def _optional_value(self, check_name, adjustment_name):
        """Return the adjustment's value if the check button is active, else -1."""
        # get_active must be *called*; the bound method itself is always truthy.
        if self.builder.get_object(check_name).get_active():
            return self.builder.get_object(adjustment_name).value
        return -1

    def get_learning_rate(self):
        """Return the current value of adjustment1."""
        return self.builder.get_object('adjustment1').value

    def get_weight_decay(self):
        """Return adjustment2's value if checkbutton3 is active, else -1."""
        return self._optional_value('checkbutton3', 'adjustment2')

    def get_momentum(self):
        """Return adjustment3's value if checkbutton4 is active, else -1."""
        return self._optional_value('checkbutton4', 'adjustment3')

    def get_learning_decay(self):
        """Return adjustment4's value if checkbutton5 is active, else -1."""
        return self._optional_value('checkbutton5', 'adjustment4')

    def get_epochs(self):
        """Return the current value of adjustment5."""
        return self.builder.get_object('adjustment5').value