import random

import kglab
import json


def generate_triples(path):
    """Dump the RDF graph at *path* to ``tmp.txt`` and collect hex-named subjects.

    The graph is loaded as RDF/XML and every ``(subject, predicate, object)``
    triple is written to ``tmp.txt`` as one tab-separated line, except triples
    whose predicate is one of the two excluded data properties
    (``section_entropy`` and ``section_name``).

    A subject is recorded as a "data object" when the text between its first
    and second ``#`` parses as a base-16 integer.  The collected subjects are
    also written to ``data_objects.txt`` as a JSON list, and their count is
    printed.

    Args:
        path: Location of the RDF/XML file, passed to ``kg.load_rdf``.

    Returns:
        The set of subject strings that have a hexadecimal fragment.
    """
    kg = kglab.KnowledgeGraph()
    kg.load_rdf(path, format="xml")

    # Predicates whose triples are left out of the dump entirely.
    skipped_predicates = frozenset((
        'http://www.semanticweb.org/peter/ontologies/2020/malware-ontology#section_entropy',
        "http://www.semanticweb.org/peter/ontologies/2020/malware-ontology#section_name",
    ))

    sparql = """SELECT ?subject ?predicate ?object
             WHERE {?subject ?predicate ?object} """

    data_objects = set()
    with open("tmp.txt", "w") as file:
        for sub, pre, obj in kg.query(sparql):
            sub, pre, obj = sub.toPython(), pre.toPython(), obj.toPython()
            if pre in skipped_predicates:
                continue

            try:
                # IndexError: the subject has no "#"; ValueError: the fragment
                # is not a base-16 number.  Either way it is not a data object.
                int(sub.split("#")[1], 16)
            except (IndexError, ValueError):
                pass
            else:
                data_objects.add(sub)

            # Typed literals convert to non-str Python values (int, float, ...),
            # which str.join rejects, so the object is coerced explicitly.
            # NOTE(review): tabs/newlines inside a literal are not escaped and
            # would corrupt the line-based format - confirm none occur.
            print("\t".join((sub, pre, str(obj))), file=file)

    print(len(data_objects))

    with open("data_objects.txt", "w") as file:
        # Sorted so the file content does not depend on set iteration order.
        json.dump(sorted(data_objects), file)
    return data_objects


def filter_triples(data_objects):
    """Split the triples in ``tmp.txt`` into training triples and a test map.

    A random ~20% of *data_objects* is selected as the test set.  Every line
    of ``tmp.txt`` is copied to ``train_triples.txt`` unless its subject is a
    test object and its object text contains ``"#malware"``; such a line is
    withheld from training and the test object is flagged ``True``.  The
    resulting ``{subject: bool}`` map is written to ``test.txt`` as JSON.

    Args:
        data_objects: Iterable of subject strings eligible for the test set.

    Raises:
        ValueError: If a line of ``tmp.txt`` has fewer than three
            tab-separated fields.
    """
    data_objects = list(data_objects)
    # Shuffling an index list (rather than sampling) keeps the sequence of
    # random draws, so a seeded run selects the same test objects as before.
    indices = list(range(len(data_objects)))
    random.shuffle(indices)
    split_at = round(0.8 * len(indices))
    testing_indices = {data_objects[i]: False for i in indices[split_at:]}

    with open("tmp.txt", "r") as file, open("train_triples.txt", "w") as train:
        for line in file:
            # Drop the line terminator; otherwise it stays attached to the
            # object and every written triple is followed by a blank line.
            line = line.rstrip("\n")
            # maxsplit=2 keeps any tab inside the object text in one field.
            sub, pre, obj = line.split("\t", 2)
            if "#" in sub and sub in testing_indices and "#malware" in obj:
                testing_indices[sub] = True
                continue
            train.write(line + "\n")

    with open("test.txt", "w") as file:
        json.dump(testing_indices, file)


def count_positive_and_negative(data_objects):
    """Print how many of *data_objects* do and do not have a malware triple.

    ``tmp.txt`` is scanned line by line (tab-separated subject, predicate,
    object).  An object from *data_objects* counts as positive when at least
    one line has it as subject, the subject contains ``"#"``, and the object
    text contains ``"#malware"``.  All remaining objects count as negative.

    Args:
        data_objects: Iterable of subject strings to classify.
    """
    is_malware = dict.fromkeys(list(data_objects), False)

    with open("tmp.txt", "r") as triples:
        for raw in triples:
            subject, _predicate, target = raw.split("\t")
            if "#" in subject and subject in is_malware and "#malware" in target:
                is_malware[subject] = True

    positives = sum(1 for flagged in is_malware.values() if flagged)
    print("Positive:")
    print(positives)
    print("Negative:")
    print(len(is_malware) - positives)


def data_normalization(path):
    """Run the whole preparation pipeline for the RDF file at *path*.

    The triples are dumped and the data objects collected by
    ``generate_triples``; those objects are then handed first to
    ``filter_triples`` and afterwards to ``count_positive_and_negative``.
    """
    collected_objects = generate_triples(path)
    filter_triples(collected_objects)
    count_positive_and_negative(collected_objects)
