import os

import numpy as np
from sklearn.metrics import confusion_matrix, recall_score, precision_score
from sklearn.model_selection import train_test_split
from tensorflow.python.keras import Input, optimizers
from tensorflow.python.keras.layers import Convolution2D, MaxPooling2D, concatenate, Dense, Flatten
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.optimizer_v2.adam import Adam as Adam
from tensorflow.python.keras.utils.vis_utils import plot_model
from sklearn.utils import shuffle
from dcm_reader.patient import Patient
from datetime import datetime
from matplotlib import pyplot as plt
from tensorflow.python.keras.models import load_model


def create_simple_model(input_size=512, learning_rate=0.0001):
    """Build and compile a two-branch CNN binary classifier.

    Both branches read the same single-channel ``input_size`` x ``input_size``
    image:

    * branch 1: 3x3 conv (16 filters) -> 3x3 max-pool -> 3x3 conv (64 filters)
    * branch 2: 9x9 conv (64 filters) -> 3x3 max-pool

    With Keras' default 'valid' padding both branches end at the same spatial
    size, ``(input_size - 8) // 3`` (168 for the default 512). Their feature maps
    can therefore be concatenated along the channel axis.

    Args:
        input_size: Height and width of the square, single-channel input.
            Defaults to 512, the original hard-coded size.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        A compiled Keras ``Model`` with a single sigmoid output. It is compiled
        with binary cross-entropy loss and reports accuracy.

    Raises:
        ValueError: if ``input_size`` is below 11. Below that size, branch 1
            has no spatial extent left after its last convolution.
    """
    if input_size < 11:
        raise ValueError("input_size must be at least 11, got {}".format(input_size))

    inputs1 = Input(shape=(input_size, input_size, 1))

    # Branch 1: small receptive field, two convolution stages.
    conv1 = Convolution2D(filters=16, kernel_size=(3, 3), activation='relu')(inputs1)
    pool1 = MaxPooling2D(pool_size=(3, 3))(conv1)
    conv2 = Convolution2D(filters=64, kernel_size=(3, 3), activation='relu')(pool1)

    # Branch 2: a single large (9x9) receptive field.
    conv1b = Convolution2D(filters=64, kernel_size=(9, 9), activation='relu')(inputs1)
    pool2 = MaxPooling2D(pool_size=(3, 3))(conv1b)

    # Stack both branches on the last (channel) axis: 64 + 64 = 128 channels.
    merged = concatenate([conv2, pool2])

    # Dense on a 4-D tensor only transforms the last axis, so these two layers
    # act as per-pixel projections over channels.
    # NOTE(review): confirm this is intended; the more common layout applies
    # Flatten before the fully-connected layers.
    fc1 = Dense(100, activation='relu')(merged)
    fc2 = Dense(50, activation='relu')(fc1)
    flatten_layer = Flatten()(fc2)
    output = Dense(1, activation='sigmoid', name="output_layer")(flatten_layer)
    model = Model(inputs=inputs1, outputs=output)

    optimizer = Adam(learning_rate=learning_rate)
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return model


# Uncomment to hide all CUDA GPUs from TensorFlow and run on the CPU. It only
# takes effect if it is set before TensorFlow initialises its GPU devices.
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Dataset roots, resolved relative to the current working directory.
POSITIVE_DATA_FOLDER = os.path.join("..", "data", "meningeomy")
NEGATIVE_DATA_FOLDER = os.path.join("..", "data", "mozgy", "mozgy_1")

# Sub-folders of the positive dataset, and patient directories of the negative
# dataset. os.listdir returns entries in an arbitrary, filesystem-dependent
# order. Sorting keeps the patient loading order the same on every machine,
# and with it any seeded shuffle or split of the resulting examples.
MENINGEOMS_DIRECTORY = sorted(os.listdir(POSITIVE_DATA_FOLDER))
NEGATIVE_PATIENTS_DIRECTORY = sorted(os.listdir(NEGATIVE_DATA_FOLDER))

positive_patients = []
is_positive = True

# Patient directory names to skip while loading.
# NOTE(review): presumably these scans are corrupt or unusable; record the
# reason for each exclusion.
ERRONEOUS_PATIENTS = ["Mozog_ID500_MozogCT_Anonymized_1.3.6.1.4.1.20468.2.152.0.1.105290.436261",
                      "Mozog_ID500_MozogCT_Anonymized_1.3.6.1.4.1.20468.2.152.0.1.99697.195841",
                      "Mozog_ID500_MozogCT_Anonymized_1.3.6.1.4.1.20468.2.152.0.1.111609.829613",
                      "Mozog_ID500_MozogCCT_Anonymized_1.3.6.1.4.1.20468.2.152.0.1.108152.598040",
                      "Mozog_ID500_MozogCT_Anonymized_1.3.6.1.4.1.20468.2.152.0.1.136409.283386"
                      ]

# Build a Patient for every patient directory inside each sub-folder of the
# positive (meningioma) dataset, skipping directories on the exclusion list.
# The listing is sorted because os.listdir order is arbitrary; a stable order
# keeps the downstream shuffle and train/test split reproducible.
for meningeoms_folder in MENINGEOMS_DIRECTORY:
    path = os.path.join(POSITIVE_DATA_FOLDER, meningeoms_folder)
    for patient_directory in sorted(os.listdir(path)):
        if patient_directory in ERRONEOUS_PATIENTS:
            continue
        patient_directory_path = os.path.join(path, patient_directory)
        positive_patients.append(Patient(patient_directory_path, is_positive=is_positive))

# Collect the positive training examples (inputs and labels) from every
# patient. new_size=512 presumably matches the 512x512 model input -- confirm
# against Patient.get_positive_train_examples.
X_data_positive, Y_data_positive = [], []
for patient in positive_patients:
    x, y = patient.get_positive_train_examples(new_size=512)
    X_data_positive.extend(x)
    Y_data_positive.extend(y)

# Build a Patient for every directory of the negative dataset.
negative_patients = []
is_positive = False
for patient_directory_name in NEGATIVE_PATIENTS_DIRECTORY:
    # Apply the same exclusion list as the positive loader, so a directory
    # flagged as erroneous is skipped whichever dataset it appears in.
    if patient_directory_name in ERRONEOUS_PATIENTS:
        continue
    patient_directory_path = os.path.join(NEGATIVE_DATA_FOLDER, patient_directory_name)
    negative_patients.append(Patient(patient_directory_path, is_positive=is_positive))

# Collect the negative training examples (inputs and labels) from every
# patient. new_size=512 presumably matches the 512x512 model input -- confirm
# against Patient.get_negative_train_examples.
X_data_negative, Y_data_negative = [], []
for patient in negative_patients:
    x, y = patient.get_negative_train_examples(new_size=512)
    X_data_negative.extend(x)
    Y_data_negative.extend(y)

# Merge both classes into one dataset. Labels become a column vector of shape
# (n_samples, 1).
X_data = np.array(X_data_positive + X_data_negative)
Y_data = np.array(Y_data_positive + Y_data_negative).reshape((-1, 1))

# Seeded so the train/test split is reproducible between runs. Unseeded, every
# run drew a different test set, so a model saved by one run and reloaded by a
# later one was partly evaluated on examples it had been trained on.
# NOTE(review): a model trained before this seed was added used a different
# split; retrain it for the test metrics to be meaningful.
X_data, Y_data = shuffle(X_data, Y_data, random_state=42)

# Number of positive / negative examples.
print("Pocet pozit = {}".format(len(X_data_positive)))
print("Pocet negat = {}".format(len(X_data_negative)))

# NOTE(review): this split is per example, not per patient. If a patient
# contributes several examples, the same patient can land in both train and
# test. Consider a patient-level split (e.g. sklearn GroupShuffleSplit).
X_train, X_test, y_train, y_test = train_test_split(X_data, Y_data, test_size=0.25, random_state=42)

# Training is disabled; uncomment these lines to retrain and overwrite the
# saved model instead of loading it:
# my_model = create_simple_model()
# history = my_model.fit(X_train, y_train, epochs=13, batch_size=1)
# my_model.save("weigths/model1.h5")

# Reuse the previously trained and saved model.
my_model = load_model("weigths/model1.h5")

# Only the test split is used from here on.
del X_train, y_train

# evaluate() returns [loss, <metrics...>]; index 1 is assumed to be accuracy,
# the metric the model was compiled with in create_simple_model.
test_score = my_model.evaluate(X_test, y_test)
test_loss = test_score[0]
test_acc = test_score[1]
print("\n\ntest loss: {} | test acc: {}".format(test_loss, test_acc))

# Sigmoid outputs in [0, 1]; binarize at 0.5 (>= 0.5 -> class 1).
predictions = my_model.predict(X_test)
predict = (np.asarray(predictions) >= 0.5).astype(int)

# Fix the label order so the matrix is always 2x2 ([[TN, FP], [FN, TP]]),
# even when y_test or the predictions contain only one class. Without
# `labels`, sklearn returns a 1x1 matrix and the unpacking below fails.
# Assumes 0/1 labels, as recall_score/precision_score (pos_label=1) also do.
confusion_mat = confusion_matrix(y_test, predict, labels=[0, 1])
recall = recall_score(y_test, predict)  # TP / (TP + FN)
precision = precision_score(y_test, predict)  # TP / (TP + FP)

print("Recall: {}".format(recall))  # sensitivity
print("Precision: {}".format(precision))

print(confusion_mat)
# Rows are true labels and columns are predicted labels, so the row-major
# ravel gives TN, FP, FN, TP.
tn, fp, fn, tp = confusion_mat.ravel()
print("True negatives: {}".format(tn))
print("False negatives: {}".format(fn))
print("True positive: {}".format(tp))
print("False positive: {}".format(fp))
