from dataclasses import dataclass
from math import ceil, cos, sin, atan, pi
from tkinter import *
from tkinter import ttk
from typing import List, Optional, Tuple, Dict

from Automata import DFA, FA
from ToolTip import ToolTip

# Radius (px) of a newly created state circle; update_size_callback never
# shrinks a state below it.
DEFAULT_STATE_RADIUS = 20
# Spacing (px) of the grid that new states snap to.
GRID_SIZE = 50
# Gap (px) between a final state's inner circle and its outer ring.
FINAL_WIDTH = 5
# Length (px) of the arrow that marks the initial state.
INITIAL_ARROW_LENGTH = 30
# Perpendicular offset (px) applied to a transition (and/or its label) when a
# transition in the opposite direction between the same two states exists.
AR_MO = 15

def nearest_multiple(n, multiple):
    """Round *n* to the closest multiple of *multiple*.

    A value exactly halfway between two multiples rounds to the lower one,
    e.g. ``nearest_multiple(75, 50) == 50``.
    """
    below = n % multiple          # distance down to the lower multiple
    above = multiple - below      # distance up to the upper multiple
    if above < below:
        return n + above
    return n - below


def arrow_coords(x1, y1, x2, y2, r1, r2):
    """Clip the segment between two circle centres to the circles' rims.

    (x1, y1) and (x2, y2) are the centres, r1 and r2 the radii. Returns
    ``(fx1, fy1, fx2, fy2)``: the point on the first circle's rim facing the
    second centre, and the point on the second circle's rim facing the first.
    """
    run = abs(x1 - x2)
    rise = abs(y1 - y2)
    angle = atan(rise / run) if run != 0 else pi / 2
    cos_a = cos(angle)
    sin_a = sin(angle)
    # +1 when the segment heads towards larger coordinates on that axis,
    # -1 otherwise (including when the coordinates are equal).
    step_x = 1 if x1 < x2 else -1
    step_y = 1 if y1 < y2 else -1
    return (
        x1 + step_x * (cos_a * r1),
        y1 + step_y * (sin_a * r1),
        x2 - step_x * (cos_a * r2),
        y2 - step_y * (sin_a * r2),
    )

def move_arrow(x1, y1, x2, y2, p):
    """Return the ``(dx, dy)`` shift that moves a segment *p* px sideways.

    The shift is perpendicular to the segment (x1, y1)-(x2, y2). The result
    is the same whichever endpoint is given first, and ``dy`` is never
    negative.
    """
    run = abs(x1 - x2)
    rise = abs(y1 - y2)
    angle = pi / 2 if run == 0 else atan(rise / run)
    shift_x = p * sin(angle)
    shift_y = p * sin((pi / 2) - angle)
    # Segments going down-right or up-left need the x shift mirrored to
    # stay perpendicular.
    both_increase = x1 < x2 and y1 < y2
    both_decrease = x1 > x2 and y1 > y2
    if both_increase or both_decrease:
        shift_x = -shift_x
    return shift_x, shift_y


@dataclass
class State:
    """One automaton state as drawn on the diagram canvas.

    ``Arrow`` is defined after this class, so it is referenced through
    string forward references rather than a throwaway placeholder class.
    """

    circle_id: int  # canvas id of the inner state circle
    text: Entry  # Entry widget holding the state's name
    text_id: int  # canvas window item hosting ``text``
    radius_i: int  # radius of the inner circle
    radius_o: int  # outer radius: radius_i, plus FINAL_WIDTH while the state is final
    from_arrows: List["Arrow"]  # straight transitions leaving this state
    to_arrows: List["Arrow"]  # straight transitions entering this state
    loop_arrow: Optional["Arrow"] = None  # self-loop transition, if any
    circle2_id: Optional[int] = None  # canvas id of the outer "final" ring, if final

@dataclass
class Arrow:
    """A transition between two states as drawn on the diagram canvas."""
    # Canvas id of the loop arc for a self-loop; add_arrow stores -1 for a
    # straight arrow, which has no separate line item.
    line_id: int
    # Canvas id of the line item drawn with the arrowhead: the whole line for
    # a straight arrow, only a short arrowhead segment for a self-loop.
    arrow_id: int
    from_state: State  # source state
    to_state: State  # target state (the same object as from_state for a self-loop)
    text: Entry  # Entry widget holding the transition label
    text_id: int  # canvas window item hosting ``text``
    # Label placement hint; every constructor call in this file passes
    # 'center'. NOTE(review): nothing here reads it — confirm it is needed.
    text_pos: str
    # NOTE(review): never set or read in this file; purpose unclear.
    loop: Optional[str] = None


class DiagramEditor:
    """Tk widget for drawing and editing a finite-automaton state diagram.

    The editor owns a toolbar (automaton name entry plus mode buttons) and a
    canvas. Each toolbar button arms a click handler on the canvas or on the
    states; the handler removes those bindings again once it has acted.
    """

    def __init__(self, master_frame, row, col, icons_dict: Dict[str, PhotoImage]):
        """Build the toolbar and canvas inside *master_frame* at grid (row, col).

        icons_dict must contain, for every toolbar button name, an icon under
        that name and a hover icon under ``name + "_active"``.
        """
        self.states: List[State] = []
        self.arrows: List[Arrow] = []
        # (initial state, canvas id of the arrow marking it)
        self.initial: Optional[Tuple[State, int]] = None
        self.final: List[State] = []
        self.icons_dict = icons_dict
        # Lookups filled by create_diagram: state name -> State and
        # (source name, target name) -> Arrow.
        self.state_dict: Dict[str, State] = {}
        self.arrow_dict: Dict[Tuple[str, str], Arrow] = {}

        self.frame = frame = ttk.Frame(master_frame)
        self.frame.grid(row=row, column=col, sticky='nwes')

        self.canvas = Canvas(frame, width=500, height=500, highlightthickness=0)
        self.canvas.grid(row=1, column=0, sticky='nwes')

        self.button_frame = bf = ttk.Frame(frame)
        self.button_frame.grid(row=0, column=0, sticky='we')
        self.name_label = Label(bf, text='Názov: ')
        self.name_label.pack(side='left')
        self.name_entry = ttk.Entry(bf, width=5)
        self.name_entry.insert(0, 'A')
        self.name_entry.pack(side='left')

        self.buttons: Dict[str, ttk.Button] = {}
        self.tool_tips: Dict[str, Tuple[ToolTip, str]] = {}

        # (icon/key name, click handler, tooltip text), in toolbar order.
        button_specs = (
            ('add_state', self.add_state_callback, 'Pridaj stav'),
            ('add_arrow', self.add_arrow_callback, 'Pridaj prechod'),
            ('remove_state', self.remove_state_callback, 'Odstráň stav'),
            ('remove_arrow', self.remove_arrow_callback, 'Odstráň prechod'),
            ('set_initial', self.set_initial_callback, 'Označ počiatočný stav'),
            ('set_final', self.set_final_callback, 'Označ akceptačný stav'),
        )
        for name, command, tip in button_specs:
            button = ttk.Button(bf, command=command)
            button.pack(side='left')
            self.buttons[name] = button
            self.tool_tips[name] = (ToolTip(button), tip)
            button.bind('<Enter>', lambda _, name=name: self.button_enter(name))
            button.bind('<Leave>', lambda _, name=name: self.button_leave(name))
            # Start in the idle look: plain icon, tooltip hidden.
            self.button_leave(name)

    def button_enter(self, name):
        """Hover-in: show the button's active icon and its tooltip."""
        self.buttons[name].configure(image=self.icons_dict[name + "_active"])
        tool_tip, text = self.tool_tips[name]
        tool_tip.show_tip(text)

    def button_leave(self, name):
        """Hover-out: restore the button's plain icon and hide its tooltip."""
        self.buttons[name].configure(image=self.icons_dict[name])
        self.tool_tips[name][0].hide_tip()

    def center(self, state: State) -> Tuple[float, float]:
        """Return the canvas coordinates of the centre of *state*'s circle."""
        left, top, *_ = self.canvas.coords(state.circle_id)
        return left + state.radius_i, top + state.radius_i

    # ----- click-binding helpers -------------------------------------------

    def _bind_state_clicks(self, handler):
        """Make a left click on any state (circle or name entry) call handler(state)."""
        for state in self.states:
            on_click = lambda _event, s=state: handler(s)
            self.canvas.tag_bind(state.circle_id, '<1>', on_click)
            state.text.bind('<1>', on_click)

    def _unbind_state_clicks(self):
        """Remove the left-click handlers installed by _bind_state_clicks."""
        for state in self.states:
            self.canvas.tag_unbind(state.circle_id, '<1>')
            state.text.unbind('<1>')

    @staticmethod
    def _arrow_items(arrow: Arrow) -> List[int]:
        """Canvas items making up *arrow*: its loop arc (if any) and its arrowhead line."""
        if arrow.line_id < 0:
            return [arrow.arrow_id]
        return [arrow.line_id, arrow.arrow_id]

    # ----- states -----------------------------------------------------------

    def add_state_callback(self):
        """Arm the canvas so the next left click creates a state there."""
        self.canvas.bind('<1>', self.add_state)

    def add_state(self, event):
        """Create a state at the grid point nearest to (event.x, event.y).

        *event* only needs ``x``/``y`` attributes (create_diagram passes a
        Coords). The new state's name entry receives keyboard focus.
        """
        self.canvas.unbind('<1>')

        x = nearest_multiple(event.x, GRID_SIZE)
        y = nearest_multiple(event.y, GRID_SIZE)
        radius = DEFAULT_STATE_RADIUS
        circle = self.canvas.create_oval(x - radius, y - radius, x + radius, y + radius, fill='white')
        text = Entry(self.canvas, background='white', width=-1, justify='center', relief='flat')
        text_id = self.canvas.create_window(x, y, window=text)
        state = State(circle, text, text_id, radius, radius, [], [])
        self.states.append(state)

        # Dragging the circle or the name entry moves the state; releasing
        # the button snaps it to the grid.
        self.canvas.tag_bind(circle, '<B1-Motion>', lambda e: self.move_state(e, state))
        self.canvas.tag_bind(circle, '<ButtonRelease-1>', lambda e: self.move_state(e, state, True))
        text.bind('<B1-Motion>', lambda e: self.move_state(e, state, False, True))
        text.bind('<ButtonRelease-1>', lambda e: self.move_state(e, state, True, True))
        # Grow/shrink the circle to fit the typed name.
        text.bind('<KeyRelease>', lambda _: self.update_size_callback(state))
        text.focus_set()

    def update_size_callback(self, state):
        """Resize *state*'s circle so its name entry fits with a small margin.

        The circle grows when the entry gets wider than the circle, and
        shrinks back (never below DEFAULT_STATE_RADIUS) when it gets narrower.
        """
        padding = 5
        entry_width = state.text.winfo_reqwidth()
        state_width = state.radius_i * 2
        expand = entry_width > state_width - padding
        shrink = state.radius_i > DEFAULT_STATE_RADIUS and entry_width < state_width - padding
        if expand or shrink:
            new_radius = max(ceil((entry_width + padding) / 2), DEFAULT_STATE_RADIUS)
            self.update_size(state, new_radius)

    def update_size(self, state, radius):
        """Give *state* inner radius *radius*, keeping its centre fixed.

        The final ring, attached arrows, self-loop and initial marker are
        redrawn to match the new size.
        """
        # The centre must be read before radius_i changes: center() derives
        # it from the circle's bounding box and radius_i.
        x, y = self.center(state)
        state.radius_i = radius
        state.radius_o = radius + FINAL_WIDTH if state.circle2_id is not None else radius
        self._place_state(state, x, y)

    def move_state(self, event, state: State, grid: bool = False, on_text: bool = False):
        """Drag handler: move *state* and everything attached to it to the pointer.

        event: Tk event (anything with ``x``/``y``). When *on_text* is True
            the event came from the state's name entry, whose coordinates are
            relative to that entry, so the entry's canvas offset is added.
        grid: snap the new centre to the nearest GRID_SIZE multiple.
        """
        x = event.x
        y = event.y
        if on_text:
            x += event.widget.winfo_x()
            y += event.widget.winfo_y()
        if grid:
            x = nearest_multiple(x, GRID_SIZE)
            y = nearest_multiple(y, GRID_SIZE)
        self._place_state(state, x, y)

    def move_arrows(self, state):
        """Re-lay every item attached to *state* around its current centre.

        Call after the state's radii or final status change, so that arrows,
        the self-loop, the final ring and the initial marker follow.
        """
        x, y = self.center(state)
        self._place_state(state, x, y)

    def _place_state(self, state: State, x, y):
        """Draw *state* centred at (x, y) and re-lay everything attached to it.

        This covers straight arrows in both directions, the self-loop, the
        final ring and the initial-state marker.
        """
        r = state.radius_i
        r2 = state.radius_o
        self.canvas.coords(state.circle_id, x - r, y - r, x + r, y + r)
        self.canvas.coords(state.text_id, x, y)

        # Arrows are clipped to the OUTER radius of both states, exactly as
        # add_arrow draws them, so they stop at a final state's ring.
        for a in state.from_arrows:
            tx, ty = self.center(a.to_state)
            self._place_arrow(a, x, y, r2, tx, ty, a.to_state.radius_o)
        for a in state.to_arrows:
            sx, sy = self.center(a.from_state)
            clipped = self._place_arrow(a, sx, sy, a.from_state.radius_o, x, y, r2)
            if self.exists(state, a.from_state):
                # A transition also runs the other way: push this one aside
                # so the pair does not overlap.
                shift = move_arrow(*clipped, AR_MO)
                self.canvas.move(a.arrow_id, *shift)
                self.canvas.move(a.text_id, *shift)

        if state.loop_arrow is not None:
            a = state.loop_arrow
            dr = DEFAULT_STATE_RADIUS
            self.canvas.coords(a.line_id, x + r2 - dr // 4, y - dr // 2, x + r2 + dr + dr // 4, y + dr // 2)
            self.canvas.coords(a.arrow_id, x + r2 - dr / 20 + 6, y + dr / 3 + 3, x + r2 - dr / 20, y + dr / 3)
            self.canvas.coords(a.text_id, x + r2 + dr + dr // 2, y)
        if state.circle2_id is not None:
            self.canvas.coords(state.circle2_id, x - r2, y - r2, x + r2, y + r2)
        if self.initial is not None and self.initial[0] is state:
            self.canvas.coords(self.initial[1], x - r2 - INITIAL_ARROW_LENGTH, y, x - r2, y)

    def _place_arrow(self, arrow: Arrow, x1, y1, r1, x2, y2, r2):
        """Draw straight *arrow* from centre (x1, y1) to centre (x2, y2).

        The line is clipped to rims of radii r1/r2, and its label is placed
        at the midpoint. Returns the clipped endpoints (fx1, fy1, fx2, fy2).
        """
        clipped = arrow_coords(x1, y1, x2, y2, r1, r2)
        fx1, fy1, fx2, fy2 = clipped
        self.canvas.coords(arrow.arrow_id, fx1, fy1, fx2, fy2)
        self.canvas.coords(arrow.text_id, (fx1 + fx2) / 2, (fy1 + fy2) / 2)
        return clipped

    # ----- transitions ------------------------------------------------------

    def add_arrow_callback(self):
        """Arm the states: the next clicked state becomes the source of a new transition."""
        self._bind_state_clicks(self.add_arrow_first_state)

    def add_arrow_first_state(self, state):
        """Highlight *state* as the source, then arm the states to pick the target."""
        self.canvas.itemconfig(state.circle_id, width='5', outline='red')
        self._unbind_state_clicks()
        self._bind_state_clicks(lambda target: self.add_arrow(state, target))

    def add_arrow(self, s1: State, s2: State):
        """Create the transition s1 -> s2, or a self-loop when s1 is s2.

        The new label entry gets keyboard focus. If the transition already
        exists, the click is ignored and s1 stays selected, so another target
        can still be picked.
        """
        if self.exists(s1, s2):
            return
        self.canvas.itemconfig(s1.circle_id, width='1', outline='black')
        self._unbind_state_clicks()
        if s1 is s2:
            self.add_loop_arrow(s1)
            return

        x1, y1 = self.center(s1)
        x2, y2 = self.center(s2)
        line = self.canvas.create_line(*arrow_coords(x1, y1, x2, y2, s1.radius_o, s2.radius_o), arrow='last')
        text = Entry(self.canvas, background=self.canvas['bg'], width=3, justify='center', relief='flat')
        text.focus_set()
        text_id = self.canvas.create_window((x1 + x2) / 2, (y1 + y2) / 2, window=text)

        if self.exists(s2, s1):
            # Offset the line and its label together by AR_MO, the same
            # amount _place_state uses, so the layout doesn't jump on the
            # first drag.
            shift = move_arrow(x1, y1, x2, y2, AR_MO)
            self.canvas.move(line, *shift)
            self.canvas.move(text_id, *shift)

        arrow = Arrow(-1, line, s1, s2, text, text_id, 'center')
        self.arrows.append(arrow)
        s1.from_arrows.append(arrow)
        s2.to_arrows.append(arrow)

    def add_loop_arrow(self, state):
        """Draw a self-loop to the right of *state* with an empty, focused label entry."""
        x, y = self.center(state)
        dr = DEFAULT_STATE_RADIUS
        r = state.radius_o
        # 270-degree open arc just outside the state's right edge.
        arc = self.canvas.create_arc(x + r - dr // 4, y - dr // 2, x + r + dr + dr // 4, y + dr // 2, style='arc',
                                     start=225, extent=270)
        # Short segment that exists only for its arrowhead, pointing back
        # towards the state.
        head = self.canvas.create_line(x + r - dr / 20 + 6, y + dr / 3 + 3, x + r - dr / 20, y + dr / 3, arrow='last')
        text = Entry(self.canvas, background=self.canvas['bg'], width=-1, justify='center', relief='flat')
        text.focus_set()
        text_id = self.canvas.create_window(x + r + dr + dr // 2, y, window=text, anchor='w')
        arrow = Arrow(arc, head, state, state, text, text_id, 'center')
        self.arrows.append(arrow)
        state.loop_arrow = arrow

    def exists(self, s1, s2):
        """Return True if a transition s1 -> s2 (a self-loop when s1 is s2) is drawn."""
        if s1 is s2 and s1.loop_arrow is not None:
            return True
        # Every straight arrow sits in its source's from_arrows and its
        # target's to_arrows, so checking the target is enough.
        return any(a.to_state is s2 for a in s1.from_arrows)

    def remove_arrow_callback(self):
        """Arm every transition: hovering thickens it, and clicking removes it."""
        for a in self.arrows:
            on_click = lambda _event, ar=a: self.remove_arrow(ar)
            # Bind the loop arc as well, since it is highlighted too.
            for item in self._arrow_items(a):
                self.canvas.itemconfig(item, activewidth=5)
                self.canvas.tag_bind(item, '<1>', on_click)

    def remove_arrow(self, arrow: Arrow):
        """Disarm transition removal and delete *arrow* with its label widget."""
        for a in self.arrows:
            for item in self._arrow_items(a):
                self.canvas.itemconfig(item, activewidth=1)
                self.canvas.tag_unbind(item, '<1>')

        if arrow.from_state is arrow.to_state:
            arrow.from_state.loop_arrow = None
        else:
            arrow.from_state.from_arrows.remove(arrow)
            arrow.to_state.to_arrows.remove(arrow)
        for item in self._arrow_items(arrow):
            self.canvas.delete(item)
        self.canvas.delete(arrow.text_id)
        # Deleting the canvas window item does not destroy the Entry itself.
        arrow.text.destroy()

        self.arrows.remove(arrow)
        for key in [k for k, a in self.arrow_dict.items() if a is arrow]:
            del self.arrow_dict[key]

    # ----- removing states / initial & final marks -------------------------

    def remove_state_callback(self):
        """Arm the states: the next clicked state is removed."""
        self._bind_state_clicks(self.remove_state)

    def remove_state(self, state: State):
        """Disarm state clicks and delete *state* with all of its transitions."""
        self._unbind_state_clicks()

        # Copy first: remove_arrow mutates these lists.
        for arrow in [*state.from_arrows, *state.to_arrows]:
            self.remove_arrow(arrow)
        if state.loop_arrow is not None:
            self.remove_arrow(state.loop_arrow)

        self.canvas.delete(state.circle_id)
        self.canvas.delete(state.text_id)
        # Deleting the canvas window item does not destroy the Entry itself.
        state.text.destroy()
        if state in self.final:
            self.final.remove(state)
            self.canvas.delete(state.circle2_id)
        if self.initial is not None and self.initial[0] is state:
            self.canvas.delete(self.initial[1])
            self.initial = None

        self.states.remove(state)
        for name in [n for n, s in self.state_dict.items() if s is state]:
            del self.state_dict[name]

    def set_initial_callback(self):
        """Arm the states: the next clicked state becomes the initial state."""
        self._bind_state_clicks(self.set_initial)

    def set_initial(self, state):
        """Mark *state* as initial with an incoming arrow from its left.

        Any previous initial marker is removed.
        """
        self._unbind_state_clicks()

        if self.initial is not None:
            if self.initial[0] is state:
                return
            self.canvas.delete(self.initial[1])

        x, y = self.center(state)
        x -= state.radius_o
        marker = self.canvas.create_line(x - INITIAL_ARROW_LENGTH, y, x, y, arrow='last')
        self.initial = (state, marker)

    def set_final_callback(self):
        """Arm the states: the next clicked state toggles its final (accepting) mark."""
        self._bind_state_clicks(self.set_final)

    def set_final(self, state):
        """Toggle *state*'s final mark, an outer ring FINAL_WIDTH px outside its circle."""
        self._unbind_state_clicks()

        w = FINAL_WIDTH
        if state in self.final:
            self.final.remove(state)
            self.canvas.delete(state.circle2_id)
            state.circle2_id = None
            state.radius_o = state.radius_i
        else:
            self.final.append(state)
            state.radius_o = state.radius_i + w
            x1, y1, x2, y2 = self.canvas.coords(state.circle_id)
            state.circle2_id = self.canvas.create_oval(x1 - w, y1 - w, x2 + w, y2 + w)
        # radius_o changed: re-lay arrows, the self-loop (arc included) and
        # the initial marker around the new outer edge.
        self.move_arrows(state)

    # ----- loading an automaton --------------------------------------------

    def create_diagram(self, fa: FA):
        """Replace the current drawing with a diagram of *fa*.

        States are laid out left to right from (150, 50), 100 px apart,
        wrapping to a new row 100 px lower once x exceeds 500. Each key of
        ``fa.delta`` is used as (source name, symbol). For a DFA the value is
        one target name; for any other FA it is iterated as several targets.
        Symbols sharing a source and target are joined into one label.
        """
        while self.states:
            self.remove_state(self.states[0])
        self.state_dict.clear()
        self.arrow_dict.clear()
        self.name_entry.delete(0, 'end')
        self.name_entry.insert(0, fa.name)

        x = 150
        y = 50
        for name in fa.states:
            self.add_state(Coords(x, y))
            state = self.states[-1]
            self.state_dict[name] = state
            state.text.insert(0, name)
            if fa.initial == name:
                self.set_initial(state)
            if name in fa.finals:
                self.set_final(state)
            self.update_size_callback(state)
            x += 100
            if x > 500:
                x = 150
                y += 100

        # Exact type check (not isinstance): only a DFA's delta maps to a
        # single target state.
        single_target = type(fa) is DFA
        for key in fa.delta:
            result = fa.delta[key]
            targets = [result] if single_target else result
            for target in targets:
                self._add_labelled_arrow(key[0], target, key[1])

    def _add_labelled_arrow(self, source, target, symbol):
        """Draw source -> target labelled *symbol*, or append *symbol* to an existing label."""
        s1 = self.state_dict[source]
        s2 = self.state_dict[target]
        if self.exists(s1, s2):
            self.arrow_dict[(source, target)].text.insert('end', ", " + symbol)
        else:
            self.add_arrow(s1, s2)
            arrow = self.arrows[-1]
            self.arrow_dict[(source, target)] = arrow
            arrow.text.insert(0, symbol)


@dataclass
class Coords:
    """Minimal stand-in for a Tk event that carries only ``x`` and ``y``.

    create_diagram passes it to DiagramEditor.add_state, which reads only
    ``event.x`` and ``event.y``.
    """
    x: int  # canvas x coordinate (px)
    y: int  # canvas y coordinate (px)
