import threading
import argparse
import traceback

import torch
import json

from pykeen.triples import TriplesFactory
from pykeen.pipeline import pipeline

from data_filter import data_normalization

# True when a CUDA-capable GPU is visible; controls where scores are computed in test().
cuda = torch.cuda.is_available()


def train(triples_factory, epochs):
    """Train a TransE model with PyKEEN and save the results to disk.

    The whole factory is used for training (ratio ``[1]``), while a
    separate 99/1 split of the same factory supplies the held-out
    testing triples. All pipeline artifacts are written to
    ``./test_unstratified_transe``.

    :param triples_factory: PyKEEN ``TriplesFactory`` holding all triples.
    :param epochs: number of training epochs for the pipeline.
    """
    # NOTE(review): the 1% testing portion is drawn independently of the
    # training split, so it overlaps the training data — confirm intended.
    _, testing = triples_factory.split([.99, .01])
    training, *_ = triples_factory.split([1, ])

    pipeline_result = pipeline(
        training=training,
        testing=testing,
        model='TransE',
        stopper='early',
        epochs=epochs,
    )
    pipeline_result.save_to_directory('./test_unstratified_transe')


def test(triples_factory, thread_number, batch_size):
    """Evaluate the trained TransE model on labeled test entities.

    Loads the trained model from ``./test_unstratified_transe`` and a JSON
    mapping of entity IRI -> truthy/falsy malware label from ``test.txt``.
    ``thread_number`` worker threads pull batches of entities from a shared
    queue, score the ``(entity, rdf:type, ?)`` query, and count, for each
    cutoff k in 1..10, how many labeled-malware entities miss the top-k
    ("positive") and how many labeled-benign entities hit it ("negative").
    The merged counts are printed as JSON.

    :param triples_factory: PyKEEN ``TriplesFactory`` with entity/relation ids.
    :param thread_number: number of worker threads to spawn.
    :param batch_size: entities each worker takes from the queue per round.
    """
    malware = "http://www.semanticweb.org/peter/ontologies/2020/malware-ontology#malware"
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
    model = torch.load('./test_unstratified_transe/trained_model.pkl')
    with open("test.txt", "r") as file:
        testing_indices = json.load(file)

    entities = list(testing_indices.keys())
    entity_lock = threading.Lock()

    result = []
    result_lock = threading.Lock()

    # Loop-invariant: the rdf:type relation id is identical for every entity,
    # so look it up once instead of once per scored entity.
    relation_id = torch.as_tensor(
        triples_factory.relation_to_id["http://www.w3.org/1999/02/22-rdf-syntax-ns#type"])

    def take_batch(count):
        """Pop up to ``count`` entities from the shared queue.

        Caller must hold ``entity_lock``. Returns an empty list when the
        queue is drained. (The original used a generator whose ``return
        None`` was dead code — a generator function never returns None.)
        """
        batch = []
        while count > 0 and entities:
            batch.append(entities.pop())
            count -= 1
        return batch

    def run():
        """Worker: score entity batches and accumulate per-thread counters."""
        misclassified = {"positive": {x: 0 for x in range(1, 11)},
                         "negative": {x: 0 for x in range(1, 11)}}
        while True:
            # `with` guarantees the lock is released even if an exception
            # escapes (the original acquire/release pair could leak it).
            with entity_lock:
                batch = take_batch(batch_size)

            if not batch:
                # Queue drained: publish this worker's counts and exit.
                with result_lock:
                    result.append(misclassified)
                break

            for entity in batch:
                subject_id = torch.as_tensor(triples_factory.entity_to_id[entity])
                pom = torch.LongTensor([[subject_id, relation_id], ])
                if cuda:
                    # Move the query tensor to the GPU to match the model's
                    # device; the original passed a CPU tensor, which raises
                    # a device-mismatch error for a CUDA-resident model.
                    score = model.score_t(pom.cuda()).cpu().detach().numpy().tolist()[0]
                else:
                    score = model.score_t(pom).detach().numpy().tolist()[0]

                # Rank all candidate tail entities by descending score.
                ranked = [ent for _, ent in sorted(
                    zip(score, triples_factory.entity_to_id.keys()),
                    key=lambda pair: pair[0], reverse=True)]

                for i in range(10):
                    if malware in ranked[:i + 1]:
                        # malware predicted within top-(i+1) but labeled benign
                        if not testing_indices[entity]:
                            misclassified["negative"][i + 1] += 1
                    else:
                        # malware missed in top-(i+1) but labeled malware
                        if testing_indices[entity]:
                            misclassified["positive"][i + 1] += 1

    threads = [threading.Thread(target=run) for _ in range(thread_number)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    # Merge the per-thread counters into one summary.
    misclassified = {"positive": {x: 0 for x in range(1, 11)},
                     "negative": {x: 0 for x in range(1, 11)}}
    for partial in result:
        for x in range(1, 11):
            misclassified["positive"][x] += partial["positive"][x]
            misclassified["negative"][x] += partial["negative"][x]
    print(f"Total number of examples: {len(testing_indices)}")
    print("Misclassified: ")
    print(json.dumps(misclassified, indent=4))


def main():
    """Parse command-line options and dispatch normalization, training, testing.

    Flags may be combined; normalization runs first, then training, then
    testing. Training and testing both require ``train_triples.txt`` to
    exist (produced by ``--normalize-data``).
    """
    parser = argparse.ArgumentParser(description='Main script for TransE embedding ... still in development ' +
                                                 '- multithread version')
    parser.add_argument("--normalize-data", action='store_true', required=False, help="Option for normalizing data for main script")
    parser.add_argument("--train", action='store_true', required=False, help="Option for training model on normalized data")
    parser.add_argument("--test", action='store_true', required=False, help="Option for testing model on normalized data")
    parser.add_argument("--path", required=False, type=str, help="Path to data for normalization")
    parser.add_argument("--epochs", required=False, type=int, help="Number of epochs for training")
    parser.add_argument("--batch-size", required=False, type=int, help="Batch size for testing data")
    parser.add_argument("--thread-number", required=False, type=int, help="Number of threads used for model evaluation")

    args = parser.parse_args()

    if args.normalize_data:
        if not args.path:
            print("Please enter path to data for normalization")
            return  # nothing else can run without normalized data
        data_normalization(args.path)

    if args.train:
        tf = _load_triples()
        if tf is None:
            return
        # Fall back to 9 epochs when the flag is absent (None) or 0.
        train(tf, args.epochs or 9)

    if args.test:
        tf = _load_triples()
        if tf is None:
            return
        try:
            test(triples_factory=tf,
                 batch_size=args.batch_size or 100,
                 thread_number=args.thread_number or 100)
        except Exception:
            # Narrowed from a bare `except:` so Ctrl-C / SystemExit still
            # propagate; print the stack so the real cause isn't hidden.
            print("Please run data normalization and model training before testing model")
            traceback.print_exc()


def _load_triples():
    """Load the normalized triples factory from ``./train_triples.txt``.

    Returns the ``TriplesFactory`` on success, or ``None`` after printing a
    hint when the file is missing/unreadable (shared by --train and --test).
    """
    try:
        return TriplesFactory.from_path(".//train_triples.txt")
    except Exception:
        print("Please run normalization on your data first or provide normalized data in file train_triples.txt")
        return None


if __name__ == '__main__':
    try:
        main()
    except Exception:
        # Top-level boundary: report failure and dump the full stack.
        # (The original bound the exception to an unused `ex` variable.)
        print("Script failed with following exception:")
        traceback.print_exc()


