import os
import pydicom as pd
from pydicom import FileDataset

from dcm_reader.exceptions import LineWithWrongIDsException
from dcm_reader.measurement import PatientMeningeomaMeasurements
from dcm_reader.patient_metadata import Metadata
import matplotlib.pyplot as plt
from dcm_reader.utils import *
import random
from dcm_reader.augmentation import *


# import augmentation as aug


class Patient:
    """One patient's CT study loaded from a directory.

    On construction every entry of ``patient_directory_path`` is inspected:

    * ``*.dcm`` files are read with pydicom and kept in ``slices`` only when
      their ``SeriesDescription`` equals ``SERIES_DESC_VALUE``;
    * a ``*.txt`` file is parsed into ``measurements``;
    * any other entry is parsed into ``metadata``.
    """

    # DICOM keyword / exact value (note the double spaces) selecting the series
    # that is kept, and the keyword used to locate neighbouring slices.
    SERIES_DESC_KEY = "SeriesDescription"
    SERIES_DESC_VALUE = "Head  3.0  MPR"
    INSTANCE_NUMBER = "InstanceNumber"

    def __init__(self, patient_directory_path, is_positive=False):
        """Store the configuration and eagerly load the directory contents.

        Args:
            patient_directory_path: Directory holding the patient's files.
            is_positive: Whether this patient is a positive case; only positive
                patients yield examples from ``get_positive_train_examples``.
        """
        self.patient_directory_path = patient_directory_path
        self.results_filename = ""
        self.metadata_filename = ""
        self.slices = []          # pydicom datasets of the selected series
        self.metadata = None      # set by _load_files from the metadata entry
        self.measurements = None  # set by _load_files from a parsable .txt file
        self.is_positive = is_positive
        self._load_files()

    @staticmethod
    def is_dicom_file(file_name):
        """Return True if *file_name* ends with ``.dcm`` (case-insensitive)."""
        return file_name.lower().endswith(".dcm")

    @staticmethod
    def is_txt_file(file_name):
        """Return True if *file_name* ends with ``.txt`` (case-insensitive)."""
        return file_name.lower().endswith(".txt")

    @staticmethod
    def is_json_file(file_name):
        """Return True if *file_name* ends with ``.json`` (case-insensitive)."""
        return file_name.lower().endswith(".json")

    def _load_files(self):
        """Populate ``slices``, ``measurements`` and ``metadata`` from the directory.

        Entries are visited in sorted order: ``os.listdir`` returns them in an
        arbitrary, filesystem-dependent order, which would otherwise make the
        slice order (and any seeded shuffling of it downstream) non-reproducible.
        """
        for item in sorted(os.listdir(self.patient_directory_path)):
            item_path = os.path.join(self.patient_directory_path, item)
            if Patient.is_dicom_file(item):
                dicom_slice = pd.dcmread(item_path, force=True)
                # Keep only slices that belong to the series we work with.
                if dicom_slice.get(Patient.SERIES_DESC_KEY) == Patient.SERIES_DESC_VALUE:
                    self.slices.append(dicom_slice)
            elif Patient.is_txt_file(item):
                self.results_filename = item_path
                try:
                    self.measurements = PatientMeningeomaMeasurements.create(self.results_filename)
                except LineWithWrongIDsException:
                    # Best effort: report the patient and keep the previous
                    # value of self.measurements (None unless set earlier).
                    print("Line with wrong IDs:" + self.patient_directory_path)
            else:
                # NOTE(review): every other entry (including subdirectories or
                # stray files) is handed to Metadata.create; is_json_file exists
                # but is not used to filter here — confirm this is intended.
                self.metadata_filename = item_path
                self.metadata = Metadata.create(self.metadata_filename)

    def compare_n_lines_and_n_meningeomas(self):
        """Check that the results file has two lines per meningeoma in the metadata.

        Raises:
            ValueError: If ``len(measurements.lines) / 2`` differs from
                ``metadata.meningeoma_count``. ValueError subclasses Exception,
                so existing ``except Exception`` handlers still catch it.
        """
        if len(self.measurements.lines) / 2 != self.metadata.meningeoma_count:
            raise ValueError("Number of measurements and meningeomas is not same! ")

    def _find_slice_by_SOPInstanceUID(self, key):
        """Return the first slice whose ``SOPInstanceUID`` equals *key*, else None."""
        return next((s for s in self.slices if s.SOPInstanceUID == key), None)

    def _find_slice_by_InstanceNumber(self, instance_number):
        """Return the first slice whose InstanceNumber equals *instance_number*, else None."""
        return next(
            (s for s in self.slices if s.get(Patient.INSTANCE_NUMBER) == instance_number),
            None,
        )

    @staticmethod
    def _measurement_line_points(measurement):
        """Return the unscaled end points of both measurement lines.

        Order: (line1 point0, line1 point1, line2 point0, line2 point1).
        """
        return (
            measurement.line1.get_point0(),
            measurement.line1.get_point1(),
            measurement.line2.get_point0(),
            measurement.line2.get_point1(),
        )

    def plot_meningeomas(self, plot="only_ct"):
        """Show one figure per measurement, optionally with an overlay.

        Args:
            plot: Overlay mode drawn on top of the CT image:
                ``"only_ct"`` – no overlay;
                ``"dots"`` – the four end points of the two measurement lines;
                ``"lines"`` – the two measurement lines;
                ``"bounding_box"`` – the box from ``get_corner_points()``.
                Any other value behaves like ``"only_ct"``.
        """
        for measurement in self.measurements.measurements:
            dicom_slice: FileDataset = self._find_slice_by_SOPInstanceUID(measurement.SOPInstanceUID)
            pixel_array = dicom_slice.pixel_array
            (height, width) = pixel_array.shape

            plot_ct_image(plt, dicom_slice, pixel_array)

            if plot == "dots":
                scaled = [scale_point(p, width, height)
                          for p in Patient._measurement_line_points(measurement)]
                for point in scaled:
                    plot_point(plt, point)

            elif plot == "lines":
                line1_P0, line1_P1, line2_P0, line2_P1 = [
                    scale_point(p, width, height)
                    for p in Patient._measurement_line_points(measurement)
                ]
                plot_line(plt, line1_P0, line1_P1, color="blue", linewidth=0.8)
                plot_line(plt, line2_P0, line2_P1, color="blue", linewidth=0.8)

            elif plot == "bounding_box":
                (left_upper_corner, right_upper_corner, right_down_corner,
                 left_down_corner) = measurement.get_corner_points()
                corners = [
                    scale_point(corner, width, height)
                    for corner in (left_upper_corner, right_upper_corner,
                                   right_down_corner, left_down_corner)
                ]
                # Closed outline: LU->RU, RU->RD, RD->LD, LD->LU.
                for start, end in zip(corners, corners[1:] + corners[:1]):
                    plot_line(plt, start, end, color="blue", linewidth=0.8)

            # Shared by every mode ("only_ct" used to return after the first
            # measurement, silently skipping the remaining ones).
            plt.show()

    def get_positive_train_examples(self, new_size=512, augmentation=True) -> "tuple[list, list]":
        """Build label-1 examples from each measured slice and its neighbours.

        For every measurement, the annotated slice, then the slice with the next
        InstanceNumber, then the one with the previous InstanceNumber (each only
        if present) are resized via ``resize_image(pixel_array, new_size)``.
        When *augmentation* is truthy, each resized image is followed by its
        three flipped variants (see ``_vertical_horizontal_flips``).

        Returns:
            ``(X_train, Y_train)``; both lists are empty when the patient is not
            positive.
        """
        X_train = []
        Y_train = []
        if not self.is_positive:
            return X_train, Y_train

        for measurement in self.measurements.measurements:
            annotated = self._find_slice_by_SOPInstanceUID(measurement.SOPInstanceUID)
            instance_number = annotated.get(Patient.INSTANCE_NUMBER)
            candidates = (
                annotated,
                self._find_slice_by_InstanceNumber(instance_number + 1),
                self._find_slice_by_InstanceNumber(instance_number - 1),
            )
            for candidate in candidates:
                if candidate is None:  # neighbour missing at the volume edge
                    continue
                resized_pxl_array = resize_image(candidate.pixel_array, new_size)
                X_train.append(resized_pxl_array)
                Y_train.append(1)
                if augmentation:
                    self._vertical_horizontal_flips(X_train, Y_train, resized_pxl_array)

        return X_train, Y_train

    @staticmethod
    def _vertical_horizontal_flips(X_train, Y_train, pixel_array):
        """Append the vertical, horizontal and combined flips of *pixel_array*, each labelled 1."""
        vertically_flipped = vertical_flip(pixel_array)
        augmented_images = (
            vertically_flipped,
            horizontal_flip(pixel_array),
            horizontal_flip(vertically_flipped),  # both flips combined
        )
        for augmented in augmented_images:
            X_train.append(augmented)
            Y_train.append(1)

    def get_negative_train_examples(self, new_size=512) -> "tuple[list, list]":
        """Return every slice, resized and in random order, labelled 0.

        Uses the module-level ``random`` generator, so results are reproducible
        under ``random.seed``.

        Returns:
            ``(X_train, Y_train)`` with ``Y_train`` all zeros.
        """
        shuffled = list(self.slices)
        # random.shuffle's permutation depends only on the length and the RNG
        # state, so this matches shuffling a list of indices.
        random.shuffle(shuffled)
        X_train = [resize_image(s.pixel_array, new_size) for s in shuffled]
        Y_train = [0] * len(X_train)
        return X_train, Y_train
