from re import split as resplit
from typing import Set, Dict, Tuple, List, FrozenSet
from copy import deepcopy
from itertools import chain, combinations

from SaveData import SaveNFAToDFA, SaveEpsilonFreeNFA, SaveFA, SaveMinimalisation

# Symbol constants. CAP_SIGMA and DELTA are not referenced in this part of the
# file (presumably used elsewhere for display labels -- verify); EPSILON is the
# letter that marks an ε-transition in arrow labels and in delta keys.
CAP_SIGMA = 'Σ'
DELTA = '𝛿'
EPSILON = 'ε'

from Automata import NFA, DFA, FA, TEST_NFA
from DiagramEditor import DiagramEditor, State, Arrow


def diagram_to_dfa(d: DiagramEditor) -> DFA | Tuple[str, List[str]]:
    """
    Build a DFA from the states and arrows drawn in a diagram editor.

    Arrow labels are split on ',', ';' and ' ' into individual letters.  The
    diagram must have an initial state, at least one final state, uniquely
    named states, no ε-labels, at most one arrow per (state, letter) and a
    transition for every (state, letter) pair.

    :return: the DFA on success, otherwise a tuple
        ``(error_message, offending_items)``.  On success ``d.state_dict``
        (name -> State) and ``d.arrow_dict`` ((from, to) -> Arrow) are replaced.
    """
    if not d.initial:
        return 'no initial state', []
    initial = d.initial[0].text.get()
    if not d.final:
        return 'no final states', []
    finals: Set[str] = {node.text.get() for node in d.final}

    states: Set[str] = set()
    state_dict: Dict[str, State] = {}
    for node in d.states:
        name = node.text.get()
        if name == '':
            return 'no state name', []
        if name in states:
            return 'duplicate state name', [name]
        states.add(name)
        state_dict[name] = node

    alphabet: Set[str] = set()
    delta: Dict[Tuple[str, str], str] = {}
    arrow_dict: Dict[Tuple[str, str], Arrow] = {}
    if len(d.arrows) == 0:
        return 'no transitions', []
    for arrow in d.arrows:
        source = arrow.from_state.text.get()
        target = arrow.to_state.text.get()
        letters = set(resplit("[,; ]", arrow.text.get()))
        letters.discard('')
        if not letters:
            return 'empty transit', [source, target]
        if EPSILON in letters:
            return 'epsilon in dfa', [source, target]
        alphabet |= letters
        for letter in letters:
            if (source, letter) in delta:
                return 'nondeterministic transit in dfa', [source, letter]
            delta[(source, letter)] = target
        arrow_dict[(source, target)] = arrow

    # A DFA needs a transition for every (state, letter) combination.
    for name in states:
        for letter in alphabet:
            if (name, letter) not in delta:
                return 'no transit for tuple ', [name, letter]

    d.state_dict = state_dict
    d.arrow_dict = arrow_dict
    return DFA(states, alphabet, delta, initial, finals, d.name_entry.get())


def diagram_to_nfa(d: DiagramEditor) -> NFA | Tuple[str, List[str]]:
    """
    Build an NFA (ε-transitions allowed) from the states and arrows of a diagram editor.

    Arrow labels are split on ',', ';' and ' ' into letters; several arrows may
    share a (source, letter) pair.  EPSILON may be used as a letter but is left
    out of the resulting alphabet.

    :return: the NFA on success, otherwise a tuple
        ``(error_message, offending_items)``.  On success ``d.state_dict``
        (name -> State) and ``d.arrow_dict`` ((from, to) -> Arrow) are replaced.
    """
    if not d.initial:
        return 'no initial state', []
    initial = d.initial[0].text.get()
    if not d.final:
        return 'no final states', []
    finals: Set[str] = {node.text.get() for node in d.final}

    states: Set[str] = set()
    state_dict: Dict[str, State] = {}
    for node in d.states:
        name = node.text.get()
        if name == '':
            return 'no state name', []
        if name in states:
            return 'duplicate state name', [name]
        states.add(name)
        state_dict[name] = node

    alphabet: Set[str] = set()
    delta: Dict[Tuple[str, str], Set[str]] = {}
    arrow_dict: Dict[Tuple[str, str], Arrow] = {}
    if len(d.arrows) == 0:
        return 'no transitions', []
    for arrow in d.arrows:
        source = arrow.from_state.text.get()
        target = arrow.to_state.text.get()
        letters = set(resplit("[,; ]", arrow.text.get()))
        letters.discard('')
        if not letters:
            return 'empty transit', [source, target]
        alphabet |= letters
        for letter in letters:
            delta.setdefault((source, letter), set()).add(target)
        arrow_dict[(source, target)] = arrow
    # ε labels transitions but is not a letter of the alphabet.
    alphabet.discard(EPSILON)

    d.state_dict = state_dict
    d.arrow_dict = arrow_dict
    return NFA(states, alphabet, delta, initial, finals, d.name_entry.get())


def epsilon_free_nfa(nfa: NFA, save: SaveEpsilonFreeNFA = None) -> NFA:
    """
    Return an equivalent NFA without ε-transitions; its name gets a trailing "'".

    For every state q and letter a, the new transition set is the union of
    δ(p, a) over all p in the ε-closure of q.  q becomes final when its
    ε-closure contains a final state.  Afterwards every ε-transition is
    deleted and unreachable states are pruned.  ``nfa`` itself is not modified.

    :param save: optional recorder that receives the computed ε-closures.
    """
    closure: Dict[str, Set[str]] = epsilon_closure(nfa)
    if save:
        save.add_closure(closure)
    new_nfa = deepcopy(nfa)
    new_nfa.name += "'"
    # Give every key its own set object.  The loop below grows sets in place,
    # and two keys sharing one set would leak transitions into each other.
    for key in new_nfa.delta:
        new_nfa.delta[key] = set(new_nfa.delta[key])

    # Read from the untouched input so the result does not depend on the order
    # in which states are processed.
    for state, reachable in closure.items():
        for st in reachable:
            for letter in nfa.alphabet:
                if (st, letter) in nfa.delta:
                    new_nfa.delta.setdefault((state, letter), set()).update(nfa.delta[(st, letter)])
            if st in nfa.finals:
                new_nfa.finals.add(state)

    for key in [key for key in new_nfa.delta if key[1] == EPSILON]:
        del new_nfa.delta[key]
    return remove_unreachable_states(new_nfa)


def epsilon_closure(nfa: NFA) -> Dict[str, Set[str]]:
    """
    Compute the ε-closure of every state of ``nfa``.

    The ε-closure of q is the set of states reachable from q using only
    ε-transitions, including q itself.  Chains and cycles of any length are
    followed.

    :return: a dict mapping every state in ``nfa.states`` to its ε-closure.
    """
    # ε-successors per state.  A delta value is normally a set of states (NFA);
    # a single state name is also accepted.
    epsilon_moves: Dict[str, Set[str]] = {}
    for (source, letter), targets in nfa.delta.items():
        if letter == EPSILON:
            epsilon_moves.setdefault(source, set()).update(
                {targets} if isinstance(targets, str) else targets)

    closure: Dict[str, Set[str]] = {}
    for state in nfa.states:
        # {state}, not set(state): the latter would split a name into its characters.
        reached = {state}
        stack = [state]
        while stack:
            for successor in epsilon_moves.get(stack.pop(), ()):
                if successor not in reached:
                    reached.add(successor)
                    stack.append(successor)
        closure[state] = reached
    return closure


def nfa_to_dfa(_nfa: NFA, save: SaveNFAToDFA = None) -> DFA:
    """
    Convert an NFA into a DFA with the subset (powerset) construction.

    If the NFA has ε-transitions they are removed first; when ``save`` is
    given, that step is recorded in its epsilon-free sub-recorder.  Every
    subset of the NFA's states becomes a DFA state named "{s1,s2,...}" with
    the names sorted.  This includes the empty subset "{}" and subsets that
    are unreachable.  A subset is final if it contains an NFA final state.

    :return: a new DFA whose name is the NFA's name with a trailing "'".
    """
    if not has_epsilon(_nfa):
        nfa = _nfa
    else:
        epsilon_save = save.getSaveEpsilonFree() if save else None
        nfa = epsilon_free_nfa(_nfa, epsilon_save)
        if epsilon_save:
            epsilon_save.add_save_data2(SaveFA(fa=nfa))

    def label(subset) -> str:
        # Canonical DFA-state name for a collection of NFA states.
        return '{' + ','.join(sorted(subset)) + '}'

    subsets = set(powerset(nfa.states))
    state_strings: Set[str] = {label(subset) for subset in subsets}
    finals_strings: Set[str] = {label(subset) for subset in subsets if set(subset) & nfa.finals}

    delta_string: Dict[Tuple[str, str], str] = {}
    for subset in subsets:
        source = label(subset)
        for letter in nfa.alphabet:
            targets: Set[str] = set()
            for old_state in subset:
                targets |= nfa.delta.get((old_state, letter), set())
            delta_string[(source, letter)] = label(targets)

    return DFA(state_strings, deepcopy(nfa.alphabet), delta_string, '{' + nfa.initial + '}',
               finals_strings, nfa.name + "'", nfa.sets_name)


def remove_unreachable_states(fa: DFA | NFA) -> DFA | NFA:
    """
    Return a copy of ``fa`` without the states unreachable from its initial state.

    Reachability follows every transition stored in ``fa.delta``.  This
    includes ε-moves, which are not part of ``fa.alphabet``, so the function
    is safe on an NFA that still has ε-transitions.  A transition target may
    be a single state name (DFA) or an iterable of state names (NFA).

    In the copy, transitions leaving unreachable states are deleted and
    ``states``/``finals`` are intersected with the reachable set.  ``fa``
    itself is not modified.
    """
    successors: Dict[str, Set[str]] = {}
    for (source, _letter), target in fa.delta.items():
        targets = (target,) if isinstance(target, str) else target
        successors.setdefault(source, set()).update(targets)

    reachable: Set[str] = set()
    pending: List[str] = [fa.initial]
    while pending:
        state = pending.pop()
        if state not in reachable:
            reachable.add(state)
            pending.extend(successors.get(state, ()))

    new_fa = deepcopy(fa)
    for key in [key for key in new_fa.delta if key[0] not in reachable]:
        del new_fa.delta[key]
    new_fa.states &= reachable
    new_fa.finals &= reachable
    return new_fa


def has_epsilon(nfa: NFA) -> bool:
    """Tell whether any transition key of ``nfa.delta`` uses the EPSILON letter."""
    return any(key[1] == EPSILON for key in nfa.delta)


def powerset(iterable):
    """
    Lazily yield every subset of ``iterable`` as a tuple.

    Subsets come ordered by size: first the empty tuple, then the singletons,
    and so on.  Within each size they follow ``itertools.combinations`` order.
    """
    items = list(iterable)
    sizes = range(len(items) + 1)
    return chain.from_iterable(map(lambda size: combinations(items, size), sizes))


def _split_block(dfa: DFA, block: Set[str], block_of: Dict[str, int],
                 save: SaveMinimalisation = None) -> List[Set[str]]:
    """Split ``block`` into groups of states that move into the same blocks on every letter."""
    groups: List[Set[str]] = []
    representatives: List[str] = []
    for state in block:
        for group, rep in zip(groups, representatives):
            for letter in dfa.alphabet:
                target = dfa.delta[(state, letter)]
                rep_target = dfa.delta[(rep, letter)]
                if block_of[target] != block_of[rep_target]:
                    if save:
                        save.add_p_data(state, rep, letter, target, rep_target)
                    break
            else:
                # Same target block as this group's representative on every letter.
                group.add(state)
                break
        else:
            # Distinguishable from every existing group: start a new one.
            groups.append({state})
            representatives.append(state)
    return groups


def minimalisation(dfa: DFA, save: SaveMinimalisation = None) -> DFA:
    """
    Return a minimal DFA equivalent to ``dfa`` (Moore's partition refinement).

    States are first split into finals and non-finals.  In each round, every
    block is split into groups of states that move into the same block on
    every letter.  Refinement stops only after a full round in which no block
    was split.  A block that is stable in one round can still need splitting
    later, once the blocks it leads into are refined, so stable blocks stay
    in play.

    ``dfa.delta`` must be defined for every (state, letter) pair of
    ``dfa.states`` x ``dfa.alphabet``.  Unreachable states are not removed
    here.  Each new state is named by joining the sorted names of the old
    states it merges with ", ".

    :param save: optional recorder of the steps: first partition, each examined
        block, the distinguishing letter for split states, and final partition.
    :return: a new DFA whose name is ``dfa.name`` with a trailing "'".
    """
    finals = {state for state in dfa.states if state in dfa.finals}
    non_finals = set(dfa.states) - finals
    # Skip empty blocks.  An empty final block (e.g. after pruning) would
    # otherwise turn into a nameless state.
    partitions: List[Set[str]] = [part for part in (finals, non_finals) if part]
    if save:
        save.change_delta(dfa.sets_name[2])
        save.add_first_partition(partitions)

    changed = True
    while changed:
        block_of: Dict[str, int] = {state: index
                                    for index, part in enumerate(partitions)
                                    for state in part}
        changed = False
        refined: List[Set[str]] = []
        for part in partitions:
            if save:
                save.add_partition(part)
            groups = _split_block(dfa, part, block_of, save)
            if len(groups) > 1:
                changed = True
                if save:
                    save.add_new_partition(groups)
            elif save:
                save.unchanged()
            refined.extend(groups)
        partitions = refined
    if save:
        save.add_final_partition(partitions)

    new_name_of: Dict[str, str] = {}   # old state -> merged state name
    member_of: Dict[str, str] = {}     # merged state name -> one old member
    new_states: Set[str] = set()
    new_finals: Set[str] = set()
    for part in partitions:
        members = sorted(part)
        name = ", ".join(members)
        new_states.add(name)
        member_of[name] = members[-1]
        for old in members:
            new_name_of[old] = name
        # All members of a block agree on finality, so any member decides it.
        if members[-1] in dfa.finals:
            new_finals.add(name)

    new_delta: Dict[Tuple[str, str], str] = {
        (name, letter): new_name_of[dfa.delta[(member, letter)]]
        for name, member in member_of.items()
        for letter in dfa.alphabet
    }
    return DFA(new_states, deepcopy(dfa.alphabet), new_delta, new_name_of[dfa.initial], new_finals,
               dfa.name + "'", dfa.sets_name)


def concatenate_regex(left:str, middle:str, right:str) -> str:
    """
    Concatenate three regex fragments: ``left``, ``middle`` and ``right``.

    A ``'!'`` marks a fragment that contains a top-level union (see
    ``add_regex``).  Such an outer fragment is wrapped in parentheses and its
    markers are removed, so the union does not bind across the concatenation.
    ε fragments, the middle one included, are dropped because ε is the
    identity of concatenation.  ε is returned only when nothing else remains.
    """
    if middle == EPSILON:
        middle = ''  # ε (e.g. an ε self-loop, ε* = ε) contributes nothing
    if left == EPSILON:
        left = ''
    elif '!' in left:
        left = '(' + left.replace('!', '') + ')'
    if right == EPSILON and (left != '' or middle != ''):
        right = ''
    elif '!' in right:
        right = '(' + right.replace('!', '') + ')'
    return left + middle + right

def _is_atomic_regex(r: str) -> bool:
    """
    Tell whether a '*' appended to ``r`` would star the whole of ``r``.

    This holds for a single symbol, or for one fully parenthesised group such
    as ``(ab)``.  It does not hold for ``ab`` or ``(a)(b)``.
    """
    if len(r) == 1:
        return True
    if len(r) < 2 or r[0] != '(' or r[-1] != ')':
        return False
    depth = 0
    for index, char in enumerate(r):
        if char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
            if depth == 0 and index != len(r) - 1:
                return False  # the leading '(' closes before the end
    return depth == 0


def regex_is_plus(r: str) -> bool:
    '''
    Check whether ``r`` has the form "RR*" or "R*R", both of which equal "R+".

    R is ``r[:len(r)//2]``; ``add_regex`` relies on that.  R must be atomic
    (a single symbol or one parenthesised group).  Otherwise the star binds
    only to R's last symbol: e.g. "abab*" is a·b·a·b*, not (ab)+.
    '''
    length = len(r)
    if length == 1 or length % 2 == 0:
        return False
    half = length // 2
    head = r[:half]
    if not _is_atomic_regex(head):
        return False
    # "RR*": R, then R again, then the closing star.
    if r[-1] == '*' and r[half:-1] == head:
        return True
    # "R*R": R, a star right after it, then R again.
    return r[half] == '*' and r[half + 1:] == head


def add_regex(left: str, right: str) -> str:
    """
    Build the union of two regex fragments.

    Simplifications:
      * R + R = R
      * ε + RR* = ε + R*R = R*   (in either operand order)
    Otherwise the result is ``left + "!+" + right``.  The '!' marks a
    top-level union, which ``concatenate_regex`` parenthesises later.
    """
    if left == right:
        return left
    for eps_side, other in ((left, right), (right, left)):
        if eps_side == EPSILON and regex_is_plus(other):
            return other[:len(other) // 2] + '*'
    return left + "!+" + right


def update_reverse_delta(reverse_delta: Dict[Tuple[str, str], str], key: Tuple[str, str], string: str):
    """
    Add ``string`` as a label of the edge ``key`` in ``reverse_delta`` (mutated in place).

    If the edge already has a label, the two labels are combined with ``add_regex``.
    """
    reverse_delta[key] = add_regex(reverse_delta[key], string) if key in reverse_delta else string

def regular_expression(_fa: FA) -> str:
    """
    Convert a finite automaton (DFA or NFA) into an equivalent regular expression.

    Uses state elimination:
      1. Add a fresh start node with an ε-edge to the initial state, and a
         fresh end node with ε-edges from every final state.
      2. Eliminate the original states one at a time, in sorted order so the
         output is deterministic.  Each path in -> state -> out is replaced
         by a direct edge labelled in·loop*·out.
      3. The label left on the start -> end edge is the result.

    Internally, ``'!'`` marks a top-level union inside a label (see
    ``add_regex``); the markers are stripped from the returned string.
    The automaton is not modified.  Returns ``'∅'`` when no final state can
    be reached, i.e. the language is empty.
    """
    # Fresh names for the helper nodes so they cannot clash with user states.
    start, end = 'init', 'final'
    while start in _fa.states:
        start += "'"
    while end in _fa.states:
        end += "'"

    # Maps (from_state, to_state) to the label of the edge between them.
    reverse_delta: Dict[Tuple[str, str], str] = {}
    for (source, letter), target in _fa.delta.items():
        if isinstance(target, (set, frozenset)):
            for single in target:
                update_reverse_delta(reverse_delta, (source, single), letter)
        else:
            update_reverse_delta(reverse_delta, (source, target), letter)
    # These edges go straight into reverse_delta rather than into delta, so a
    # final state's own ε-transitions are not overwritten.
    update_reverse_delta(reverse_delta, (start, _fa.initial), EPSILON)
    for state in _fa.finals:
        update_reverse_delta(reverse_delta, (state, end), EPSILON)

    for state in sorted(_fa.states):
        loop = ''
        incoming: List[Tuple[str, str]] = []
        outgoing: List[Tuple[str, str]] = []
        for (source, target), label in list(reverse_delta.items()):
            if source != state and target != state:
                continue
            del reverse_delta[(source, target)]
            if source == target:
                loop = label
            elif source == state:
                outgoing.append((target, label))
            else:
                incoming.append((source, label))

        if loop == EPSILON:
            loop = ''  # ε* = ε adds nothing to a concatenation
        elif loop:
            if len(loop) > 1:
                loop = '(' + loop.replace('!', '') + ')'
            loop += '*'
        for in_state, in_label in incoming:
            for out_state, out_label in outgoing:
                update_reverse_delta(reverse_delta, (in_state, out_state),
                                     concatenate_regex(in_label, loop, out_label))

    result = reverse_delta.get((start, end))
    if result is None:
        return '∅'  # no path from the initial state to any final state
    return result.replace('!', '')


if __name__ == '__main__':
    # Demo: run the conversion pipeline on the bundled example automaton and
    # print every intermediate automaton plus the regex of each DFA.
    print(epsilon_free_nfa(TEST_NFA))
    demo_dfa = remove_unreachable_states(nfa_to_dfa(TEST_NFA))
    print(demo_dfa)
    print(regular_expression(demo_dfa))
    demo_min_dfa = minimalisation(demo_dfa)
    print(demo_min_dfa)
    print(regular_expression(demo_min_dfa))
