package sk.shanki.dbsuite.mapper.evolution_algo;


import sk.shanki.dbsuite.mapper.Mapping;
import sk.shanki.dbsuite.mapper.db.*;

import javax.swing.*;
import java.awt.*;
import java.util.*;
import java.util.List;
import java.util.stream.Collectors;

/**
 * Evolutionary (genetic) search for a row-to-row mapping between a reference database
 * ({@code correctSolution}) and a submitted database ({@code studentSolution}).
 *
 * <p>Each {@link Individual} wraps one candidate mapping: for every table index, a map from rows
 * of one table to rows of the other. Fitness is delegated to {@code Individual.fitness}; lower
 * values are kept by selection, and a fitness of exactly {@code 0} ends the search early.
 */
public class EvolutionAlgo {
    /** Maximum number of generations evolved before the best candidate is reported. */
    private static final int EPOCHS = 20;
    /**
     * Size of the initial population. Each round adds one child per parent and then keeps the
     * better half, so the population size stays constant.
     */
    private static final int NUMBER_IN_GENERATION = 20;
    /** Rate passed to {@code Individual.mutation} for every offspring. */
    private static final float MUTATION_RATE = 0.1f;

    private final Database correctSolution;
    private final Database studentSolution;

    /** Best fitness (truncated to int) of each generation, displayed by {@link #showGraph()}. */
    private final ArrayList<Integer> graph;

    /**
     * @param correctSolution the reference database
     * @param studentSolution the database to be matched against the reference
     * @throws NullPointerException if either argument is {@code null}
     */
    public EvolutionAlgo(Database correctSolution, Database studentSolution) {
        this.correctSolution = Objects.requireNonNull(correctSolution, "correctSolution");
        this.studentSolution = Objects.requireNonNull(studentSolution, "studentSolution");
        this.graph = new ArrayList<>();
    }

    /**
     * Runs the evolutionary search and prints the resulting {@link Mapping} to standard output.
     *
     * <p>If some individual reaches fitness {@code 0}, the mapping of that individual is printed
     * and the search stops. Otherwise, after {@link #EPOCHS} generations, the mapping of the best
     * individual is printed, followed by its penalisation.
     */
    public void solve() {
        ArrayList<Individual> population = new ArrayList<>(NUMBER_IN_GENERATION);
        for (int i = 0; i < NUMBER_IN_GENERATION; i++) {
            // The first individual keeps the identity row order as a deterministic baseline;
            // all others use a shuffled row order.
            population.add(new Individual(generateRandomIndividual(i != 0)));
        }

        for (int epoch = 1; epoch <= EPOCHS; epoch++) {
            Optional<Individual> perfect = findPerfect(population);
            if (perfect.isPresent()) {
                // Report the individual that actually reached fitness 0, not simply the first one.
                System.out.println(getMapping(perfect.get()));
                return;
            }
            // Next generation = previous generation + offspring of every pair, then keep the best half.
            population = selectNextGeneration(breed(population));
        }

        Mapping m = getMapping(population.get(0));
        System.out.println(m);
        System.out.println(m.getPenalisation());
    }

    /** Returns the first individual (in population order) whose fitness is exactly 0, if any. */
    private Optional<Individual> findPerfect(List<Individual> population) {
        for (Individual candidate : population) {
            if (candidate.fitness(correctSolution, studentSolution) == 0) {
                return Optional.of(candidate);
            }
        }
        return Optional.empty();
    }

    /**
     * Returns all parents followed by their offspring.
     *
     * <p>Parents are paired as (0,1), (2,3), ...; each pair produces two children by crossing
     * over in both directions, and each child is mutated with {@link #MUTATION_RATE}. With an odd
     * number of parents, the last one is carried over without offspring instead of causing an
     * {@code IndexOutOfBoundsException}.
     */
    private ArrayList<Individual> breed(List<Individual> parents) {
        ArrayList<Individual> next = new ArrayList<>(parents.size() * 2);
        next.addAll(parents);
        for (int i = 0; i + 1 < parents.size(); i += 2) {
            Individual first = parents.get(i);
            Individual second = parents.get(i + 1);

            Individual child1 = first.cross(second);
            Individual child2 = second.cross(first);

            child1.mutation(MUTATION_RATE);
            child2.mutation(MUTATION_RATE);

            next.add(child1);
            next.add(child2);
        }
        return next;
    }

    /**
     * Orders candidates by ascending fitness and returns the first half.
     *
     * <p>The sort is stable, so equal fitness values keep their original relative order. The best
     * fitness of this round is also appended to {@link #graph}.
     */
    private ArrayList<Individual> selectNextGeneration(List<Individual> candidates) {
        int size = candidates.size();

        // Evaluate each candidate once and sort indices by the cached value, rather than looking
        // fitness up with indexOf inside the comparator (O(n) per comparison).
        double[] fitness = new double[size];
        List<Integer> order = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            fitness[i] = candidates.get(i).fitness(correctSolution, studentSolution);
            order.add(i);
        }
        order.sort(Comparator.comparingDouble(idx -> fitness[idx]));

        int keep = size / 2;
        ArrayList<Individual> survivors = new ArrayList<>(keep);
        for (int i = 0; i < keep; i++) {
            survivors.add(candidates.get(order.get(i)));
        }

        if (!order.isEmpty()) {
            graph.add((int) fitness[order.get(0)]);
        }
        return survivors;
    }

    /**
     * Builds the raw genome for one {@link Individual}: one row map per table index.
     *
     * <p>For each table index, rows of the larger of the two tables become keys (so every key is
     * a row of that table). Each key is paired with a row of the smaller table, or with
     * {@code null} once the smaller table is exhausted.
     *
     * @param random when {@code true}, the pairing order is shuffled; when {@code false}, row
     *     {@code i} is paired with row {@code i}
     */
    private List<Map<Row, Row>> generateRandomIndividual(boolean random) {
        List<Map<Row, Row>> genome = new ArrayList<>();

        // NOTE(review): iterates up to the larger table count and calls get() on both databases;
        // if the databases have different table counts this presumably fails in Database.get —
        // confirm the intended behavior.
        int tableCount = Math.max(correctSolution.size(), studentSolution.size());
        for (int table = 0; table < tableCount; table++) {
            Table correctTable = correctSolution.get(table);
            Table studentTable = studentSolution.get(table);

            boolean studentIsLarger = correctTable.size() < studentTable.size();
            Table keyTable = studentIsLarger ? studentTable : correctTable;
            Table valueTable = studentIsLarger ? correctTable : studentTable;
            int rowCount = keyTable.size();

            List<Integer> valueOrder = new ArrayList<>(rowCount);
            for (int i = 0; i < rowCount; i++) {
                valueOrder.add(i);
            }
            if (random) {
                Collections.shuffle(valueOrder);
            }

            // NOTE(review): rows are HashMap keys, so this relies on Row's equals/hashCode.
            Map<Row, Row> tableMapping = new HashMap<>();
            for (int rowId = 0; rowId < rowCount; rowId++) {
                int valueId = valueOrder.get(rowId);
                Row valueRow = valueId < valueTable.size() ? valueTable.get(valueId) : null;
                tableMapping.put(keyTable.get(rowId), valueRow);
            }
            genome.add(tableMapping);
        }
        return genome;
    }

    /**
     * Opens a window showing {@link #graph} (best fitness per generation) as a histogram.
     *
     * <p>Closing the window exits the JVM. The UI is created on the Swing event dispatch thread
     * from a snapshot of the current data.
     */
    private void showGraph() {
        ArrayList<Integer> snapshot = new ArrayList<>(graph);
        SwingUtilities.invokeLater(() -> {
            HistogramPanel panel = new HistogramPanel(snapshot);
            for (int i = 0; i < panel.graf.size(); i++) {
                panel.addHistogramColumn("", panel.graf.get(i), Color.RED);
            }
            panel.layoutHistogram();

            JFrame frame = new JFrame("Histogram Panel");
            frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
            frame.add(panel);
            frame.setLocationByPlatform(true);
            frame.pack();
            frame.setVisible(true);
        });
    }

    /**
     * Applies {@code individual}'s mapping to the two databases held by this instance and wraps
     * the result in a {@link Mapping}.
     *
     * <p>The {@link Mapping} uses {@code AlphaFormulas.numberOfColumnsFormula()} and
     * {@code RowDistanceFormulas.getCUSTOM()}.
     */
    private Mapping getMapping(Individual individual) {
        individual.applyMapping(this.correctSolution, this.studentSolution);
        return new Mapping(this.correctSolution, this.studentSolution,
                AlphaFormulas.numberOfColumnsFormula(), RowDistanceFormulas.getCUSTOM());
    }
}
