#include <errno.h>
#include <getopt.h>
#include <inttypes.h>
#include <signal.h>
#include <stdarg.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/utsname.h>
#include <sys/wait.h>
#include <time.h>
#include <unistd.h>
#include <bpf/libbpf.h>
#include "process_exit.h"
#include "process_exit.skel.h"

#define PCOMM_FLAG    0x00001
#define TGID_FLAG     0x00002
#define PID_FLAG      0x00004
#define PPID_FLAG     0x00008
#define UID_FLAG      0x00010
#define AGE_FLAG      0x00020
#define UTIME_FLAG    0x00040
#define STIME_FLAG    0x00080
#define EXIT_FLAG     0x00100
#define EXITSIG_FLAG  0x00200
#define NVCS_FLAG     0x00400
#define NIVCS_FLAG    0x00800
#define CUTIME_FLAG   0x01000
#define CSTIME_FLAG   0x02000
#define INBLOCK_FLAG  0x04000
#define OUBLOCK_FLAG  0x08000
#define CINBLOCK_FLAG 0x10000
#define COUBLOCK_FLAG 0x20000
#define TIME_FLAG     0x40000

/*
 * libbpf log callback: forward messages below LIBBPF_DEBUG level to stderr
 * and silently drop debug-level output. Returns vfprintf's result, or 0 when
 * the message is dropped.
 */
static int libbpf_print_fn(enum libbpf_print_level level,
                           const char* format, va_list args) {
    return (level < LIBBPF_DEBUG) ? vfprintf(stderr, format, args) : 0;
}

/* Program-wide settings and counters; most are filled in by parse_options(). */
static int csv = 0;                        /* non-zero: comma-separated output (-c) */
static FILE* output_file = NULL;           /* destination stream; main() points it at stdout or the -f file */
static uint32_t parameter_flags = 0x7ffff; /* enabled *_FLAG columns; all of them by default */
static int enabled_flag_count = 0;         /* number of enabled columns; drives CSV separators */
static int current_flags_printed = 1;      /* 1-based position of the next CSV field in the current row */
static uint64_t processes_handled = 0;     /* events that passed the -u filter */
static int cumulative = 0;                 /* non-zero: fold child times/blocks into the parent's (-C) */
/*
 * Cleared by the SIGINT/SIGTERM handler to stop the poll loop. sig_atomic_t
 * is the standard type for objects written from a signal handler; the
 * previous __sig_atomic_t is a reserved glibc-internal name.
 */
static volatile sig_atomic_t running = 1;
static uint32_t filtered_uid = 0;          /* -u UID filter; 0 means "no filter" */

/*
 * Parse the comma-separated column list given to -F and OR the matching
 * *_FLAG bits into parameter_flags. Names that match no column are reported
 * on stderr and otherwise ignored. flag_string is tokenized in place with
 * strtok, so it is modified.
 */
static void parse_filter_flags(char* flag_string) {
    static const struct {
        const char* name;
        uint32_t flag;
    } known_columns[] = {
        { "pcomm",    PCOMM_FLAG    },
        { "tgid",     TGID_FLAG     },
        { "pid",      PID_FLAG      },
        { "ppid",     PPID_FLAG     },
        { "uid",      UID_FLAG      },
        { "age",      AGE_FLAG      },
        { "utime",    UTIME_FLAG    },
        { "stime",    STIME_FLAG    },
        { "exit",     EXIT_FLAG     },
        { "exitsig",  EXITSIG_FLAG  },
        { "nvcs",     NVCS_FLAG     },
        { "nivcs",    NIVCS_FLAG    },
        { "cutime",   CUTIME_FLAG   },
        { "cstime",   CSTIME_FLAG   },
        { "inblock",  INBLOCK_FLAG  },
        { "oublock",  OUBLOCK_FLAG  },
        { "cinblock", CINBLOCK_FLAG },
        { "coublock", COUBLOCK_FLAG },
    };
    const size_t column_count = sizeof known_columns / sizeof known_columns[0];

    for (const char* tok = strtok(flag_string, ","); tok != NULL; tok = strtok(NULL, ",")) {
        size_t idx = 0;
        while (idx < column_count && strcmp(tok, known_columns[idx].name) != 0)
            idx++;

        if (idx < column_count)
            parameter_flags |= known_columns[idx].flag;
        else
            fprintf(stderr, "Unknown parameter '%s'\n", tok);
    }
}

/*
 * Print usage information to stdout, using argv[0] as the program name.
 * The -F names listed here must match those accepted by parse_filter_flags.
 */
static void print_help(char* argv[]) {
    fprintf(stdout,
            "Usage: %s [OPTION]...\n"
            "Print information about terminated processes.\n"
            "\n"
            "Options:\n"
            "  -c, --csv                          use csv format\n"
            "  -f FILE, --file FILE               print to file FILE (- for stdout)\n"
            "  -F params..., --filter params...   filter process information based on\n"
            "                                     comma-separated params\n"
            "  -u UID, --uid UID                  show only processes owned by user\n"
            "                                     with uid UID\n"
            "  -C, --cumulative                   show cumulative data for utime,\n"
            "                                     stime, inblock, oublock, filters\n"
            "                                     out cutime, cstime, cinblock,\n"
            "                                     coublock\n"
            "  -h, --help                         print this help message and exit\n"
            "\n"
            "Filter parameters:\n"
            "  pcomm, tgid, pid, ppid, uid, age, utime, stime, exit, exitsig, nvcs,\n"
            "  nivcs, cutime, cstime, inblock, oublock, cinblock, coublock\n",
            argv[0]);
}

/*
 * Parse command-line options into the file-scope settings (csv, cumulative,
 * parameter_flags, filtered_uid) and store the -f argument in *file (left
 * untouched if -f is absent).
 *
 * Exits the process with EXIT_SUCCESS after printing help for -h, and with
 * EXIT_FAILURE after printing help for an unknown option, or with an error
 * message for an invalid -u argument.
 */
static void parse_options(int argc, char* argv[], const char** file) {
    static const struct option long_options[] = {
        { "csv"        ,       no_argument, 0, 'c' },
        { "file"       , required_argument, 0, 'f' },
        { "filter"     , required_argument, 0, 'F' },
        { "cumulative" ,       no_argument, 0, 'C' },
        { "help"       ,       no_argument, 0, 'h' },
        { "uid"        , required_argument, 0, 'u' },
        { 0            ,                 0, 0,   0 },
    };

    int ret = EXIT_FAILURE;
    int opt;
    while ((opt = getopt_long(argc, argv, "cf:hF:Cu:", long_options, NULL)) != -1) {
        switch (opt) {
        case 'c':
            csv = 1;
            break;
        case 'f':
            *file = optarg;
            break;
        case 'F':
            // -F replaces the default column set; for now TIME_FLAG is always kept.
            parameter_flags = TIME_FLAG;
            parse_filter_flags(optarg);
            break;
        case 'C':
            cumulative = 1;
            break;
        case 'u': {
            char* end = NULL;
            errno = 0;
            unsigned long uid = strtoul(optarg, &end, 10);
            // Reject: no digits, trailing garbage, any '-' (strtoul would
            // silently wrap negative input, even after leading spaces), and
            // values that do not fit in a 32-bit uid.
            if (errno != 0 || end == optarg || *end != '\0' ||
                strchr(optarg, '-') != NULL || uid > UINT32_MAX) {
                fprintf(stderr, "Invalid UID after -u.\n");
                exit(EXIT_FAILURE);
            }
            filtered_uid = (uint32_t)uid;
            break;
        }
        case 'h':
            ret = EXIT_SUCCESS;
            /* fallthrough */
        default:
            print_help(argv);
            exit(ret);
        }
    }
}

/*
 * Print one integer column value to output_file.
 *
 * Table mode: `format` is used verbatim and must consume exactly one
 * long long argument (e.g. "%-7lld\t").
 * CSV mode: `format` is ignored. The value is followed by ',' unless this is
 * the last enabled column of the row, in which case the per-row column
 * counter (current_flags_printed) restarts at 1.
 *
 * The value is converted to long long before printing. Passing a __u64 to a
 * "%d" conversion (the previous behaviour) is undefined and truncates large
 * values such as high uids. The signed conversion keeps sentinels like -1,
 * which arrive here as all-ones, printing as "-1".
 */
static void print_int_parameter2(__u64 parameter, const char* format) {
    if (csv) {
        if (current_flags_printed < enabled_flag_count) {
            format = "%lld,";
            current_flags_printed++;
        } else {
            format = "%lld";
            current_flags_printed = 1;
        }
    }
    fprintf(output_file, format, (long long)parameter);
}

/* Print an integer column left-aligned in a 7-character, tab-terminated cell. */
static void print_int_parameter(__u64 parameter) {
    print_int_parameter2(parameter, "%-7lld\t");
}

/*
 * Emit a floating-point column with three decimals. Table mode pads it to 7
 * characters and appends a tab. CSV mode appends ',' unless this is the last
 * enabled column of the row, in which case the column counter restarts at 1.
 */
static void print_double_parameter(double parameter) {
    if (!csv) {
        fprintf(output_file, "%-7.3f\t", parameter);
        return;
    }

    if (current_flags_printed >= enabled_flag_count) {
        current_flags_printed = 1;
        fprintf(output_file, "%.3f", parameter);
    } else {
        current_flags_printed++;
        fprintf(output_file, "%.3f,", parameter);
    }
}

/*
 * Emit the process-name column (also used for its "PCOMM" title).
 *
 * Table mode: left-aligned in a 16-character cell.
 * CSV mode: the name is written as-is unless it contains ',', '"', CR or LF.
 * Process names can contain arbitrary characters, so such names are wrapped
 * in double quotes with embedded quotes doubled (RFC 4180). Otherwise a
 * crafted name would shift or split the row. A ',' follows unless this is
 * the last enabled column, in which case the column counter restarts at 1.
 */
static void print_pcomm_parameter(const char* parameter) {
    if (!csv) {
        fprintf(output_file, "%-16s", parameter);
        return;
    }

    if (strpbrk(parameter, ",\"\r\n") == NULL) {
        fputs(parameter, output_file);
    } else {
        fputc('"', output_file);
        for (const char* p = parameter; *p != '\0'; p++) {
            if (*p == '"')
                fputc('"', output_file);  // escape by doubling
            fputc(*p, output_file);
        }
        fputc('"', output_file);
    }

    if (current_flags_printed < enabled_flag_count) {
        fputc(',', output_file);
        current_flags_printed++;
    } else {
        current_flags_printed = 1;
    }
}

/*
 * Emit a plain string column (used for the header titles). Table mode
 * appends a tab. CSV mode appends ',' unless this is the last enabled column
 * of the row, in which case the column counter restarts at 1.
 */
static void print_string_parameter(const char* parameter) {
    if (!csv) {
        fprintf(output_file, "%s\t", parameter);
        return;
    }

    const int more_columns = current_flags_printed < enabled_flag_count;
    if (more_columns)
        current_flags_printed++;
    else
        current_flags_printed = 1;

    fprintf(output_file, more_columns ? "%s," : "%s", parameter);
}

/*
 * Emit a timestamp column as "<sec>.<microseconds>", where `ns` is the
 * nanosecond part (it is divided by 1000 and zero-padded to 6 digits).
 * Table mode appends a tab. CSV mode appends ',' unless this is the last
 * enabled column, in which case the column counter restarts at 1.
 *
 * time_t's width is platform-specific (64-bit on LP64), so passing it to
 * "%d" is undefined behaviour. Both values are widened to long long and
 * printed with %lld.
 */
static void print_time_parameter(time_t sec, time_t ns) {
    const long long whole = (long long)sec;
    const long long usec = (long long)(ns / 1000);
    const char* format = "%lld.%06lld\t";
    if (csv) {
        format = "%lld.%06lld";
        if (current_flags_printed < enabled_flag_count) {
            format = "%lld.%06lld,";
            current_flags_printed++;
        } else {
            current_flags_printed = 1;
        }
    }
    fprintf(output_file, format, whole, usec);
}

static void handle_event(void* ctx, int cpu, void* data, unsigned int data_sz) {
    struct data_t* event = (struct data_t*)data;
    (void)ctx;
    (void)cpu;
    (void)data_sz;

    if (filtered_uid != 0 && filtered_uid != event->uid) {
        return;
    }

    if (cumulative) {
        event->utime += event->cutime;
        event->stime += event->cstime;
        event->inblock += event->cinblock;
        event->oublock += event->coublock;
    }

    processes_handled++;

    struct timespec tm;
    if (clock_gettime(CLOCK_REALTIME, &tm) == -1) {
        perror("clock_gettime");
        exit(EXIT_FAILURE);
    }

    if (parameter_flags & TIME_FLAG    ) print_time_parameter(tm.tv_sec, tm.tv_nsec);
    if (parameter_flags & PCOMM_FLAG   ) print_pcomm_parameter(event->task);
    if (parameter_flags & TGID_FLAG    ) print_int_parameter(event->tgid);
    if (parameter_flags & PID_FLAG     ) print_int_parameter(event->pid);
    if (parameter_flags & PPID_FLAG    ) print_int_parameter(event->ppid);
    if (parameter_flags & UID_FLAG     ) print_int_parameter(event->uid);
    if (parameter_flags & AGE_FLAG     ) print_double_parameter((event->exit_time - event->start_time) / 1e9);
    if (parameter_flags & UTIME_FLAG   ) print_double_parameter(event->utime / 1e9);
    if (parameter_flags & STIME_FLAG   ) print_double_parameter(event->stime / 1e9);
    if (parameter_flags & EXIT_FLAG    ) print_int_parameter(WIFEXITED(event->exit_code) ? WEXITSTATUS(event->exit_code) : -1);
    if (parameter_flags & EXITSIG_FLAG ) print_int_parameter(WIFSIGNALED(event->exit_code) ? WTERMSIG(event->exit_code) : -1);
    if (parameter_flags & NVCS_FLAG    ) print_int_parameter(event->nvcsw);
    if (parameter_flags & NIVCS_FLAG   ) print_int_parameter(event->nivcsw);
    if (parameter_flags & CUTIME_FLAG  ) print_double_parameter(event->cutime / 1e9);
    if (parameter_flags & CSTIME_FLAG  ) print_double_parameter(event->cstime / 1e9);
    if (parameter_flags & INBLOCK_FLAG ) print_int_parameter(event->inblock);
    if (parameter_flags & OUBLOCK_FLAG ) print_int_parameter(event->oublock);
    if (parameter_flags & CINBLOCK_FLAG) print_int_parameter(event->cinblock);
    if (parameter_flags & COUBLOCK_FLAG) print_int_parameter(event->coublock);
    fprintf(output_file, "\n");

    fflush(output_file);
}

/*
 * perf buffer lost-sample callback. The notice goes to stderr so that it
 * never lands in the event data stream (which is stdout by default and may
 * be CSV). It includes how many samples were lost and on which CPU.
 */
static void lost_event(void* ctx, int cpu, long long unsigned cnt) {
    (void)ctx;

    fprintf(stderr, "lost %llu events on CPU %d\n", cnt, cpu);
}

static void print_header(void) {
    time_t start_time = time(NULL);
    fprintf(output_file, "start timestamp: %s", ctime(&start_time));

    struct utsname sysinfo;
    if (uname(&sysinfo) == -1) {
        perror("uname");
        exit(EXIT_FAILURE);
    }

    fprintf(output_file,
            "system info: %s %s %s %s %s\n",
            sysinfo.sysname,
            sysinfo.nodename,
            sysinfo.release,
            sysinfo.version,
            sysinfo.machine);

    if (parameter_flags & TIME_FLAG    ) print_string_parameter("TIME");
    if (parameter_flags & PCOMM_FLAG   ) print_pcomm_parameter("PCOMM");
    if (parameter_flags & TGID_FLAG    ) print_string_parameter("TGID");
    if (parameter_flags & PID_FLAG     ) print_string_parameter("PID");
    if (parameter_flags & PPID_FLAG    ) print_string_parameter("PPID");
    if (parameter_flags & UID_FLAG     ) print_string_parameter("UID");
    if (parameter_flags & AGE_FLAG     ) print_string_parameter("age");
    if (parameter_flags & UTIME_FLAG   ) print_string_parameter("utime");
    if (parameter_flags & STIME_FLAG   ) print_string_parameter("stime");
    if (parameter_flags & EXIT_FLAG    ) print_string_parameter("exit");
    if (parameter_flags & EXITSIG_FLAG ) print_string_parameter("exitsig");
    if (parameter_flags & NVCS_FLAG    ) print_string_parameter("nvcsw");
    if (parameter_flags & NIVCS_FLAG   ) print_string_parameter("nivcsw");
    if (parameter_flags & CUTIME_FLAG  ) print_string_parameter("cutime");
    if (parameter_flags & CSTIME_FLAG  ) print_string_parameter("cstime");
    if (parameter_flags & INBLOCK_FLAG ) print_string_parameter("inblk");
    if (parameter_flags & OUBLOCK_FLAG ) print_string_parameter("oublk");
    if (parameter_flags & CINBLOCK_FLAG) print_string_parameter("cinblk");
    if (parameter_flags & COUBLOCK_FLAG) print_string_parameter("coublk");
    fprintf(output_file, "\n");

    fflush(output_file);
}

/* SIGINT/SIGTERM handler: clear the running flag checked by the main poll loop. */
static void catch_interrupt(int signo) {
    (void)signo;
    running = 0;
}

int main(int argc, char* argv[]) {

    output_file = stdout;
    const char* file_name = NULL;
    parse_options(argc, argv, &file_name);
    if ((enabled_flag_count = __builtin_popcount(parameter_flags)) == 0) {
        fprintf(stderr, "No valid parameters in filter\n");
        exit(EXIT_FAILURE);
    }

    if (cumulative) {
        // pri kumulativnom vypise nahradime hodnoty suctom
        // rodicovskych a dcerskych hodnot, preto nebudeme vypisovat
        // cutime, cstime, atd.
        parameter_flags &= ~CUTIME_FLAG;
        parameter_flags &= ~CSTIME_FLAG;
        parameter_flags &= ~CINBLOCK_FLAG;
        parameter_flags &= ~COUBLOCK_FLAG;
        enabled_flag_count -= 4;
    }

    if (getuid() != 0) {
        fprintf(stderr, "You should run this program with root privileges (sudo).\n");
        exit(EXIT_FAILURE);
    }

    libbpf_set_print(libbpf_print_fn);

    struct process_exit_bpf* skel = process_exit_bpf__open_and_load();
    if (!skel) {
        fprintf(stderr, "Failed to open BPF object file\n");
        return 1;
    }

    int err = process_exit_bpf__attach(skel);
    if (err) {
        fprintf(stderr, "Failed to attach BPF skeleton: %d\n", err);
        process_exit_bpf__destroy(skel);
        return 1;
    }

#ifndef OLD_LIBBPF
    struct perf_buffer* pb = perf_buffer__new(bpf_map__fd(skel->maps.output),
                                              8,
                                              handle_event,
                                              lost_event,
                                              NULL,
                                              NULL);
#else
    const struct perf_buffer_opts opts = {
        handle_event,
        lost_event,
        NULL
    };
    struct perf_buffer* pb = perf_buffer__new(bpf_map__fd(skel->maps.output),
                                              8,
                                              &opts);
#endif
    if (!pb) {
        err = -1;
        fprintf(stderr, "Failed to create ring buffer\n");
        process_exit_bpf__destroy(skel);
        return -err;
    }

    if (signal(SIGINT, catch_interrupt) == SIG_ERR || signal(SIGTERM, catch_interrupt) == SIG_ERR) {
        fprintf(stderr, "Can't set signal handler: %s\n", strerror(errno));
        err = 1;
        running = 0;
    }

    int writing_to_file = 0;
    if (file_name != NULL && strcmp(file_name, "-") != 0) {
        output_file = fopen(file_name, "w+");
        writing_to_file = 1;
    }

    print_header();
    while (running) {
        err = perf_buffer__poll(pb, 1000);
        if (err == -EINTR) {
            err = 0;
            break;
        }
        if (err < 0) {
            printf("Error polling perf buffer: %d\n", err);
            break;
        }
    }

    printf("\nHandled %lu processes.\n", processes_handled);
    time_t end_time = time(NULL);
    fprintf(output_file, "end timestamp: %s", ctime(&end_time));
    fflush(output_file);

    if (writing_to_file) {
        fclose(output_file);
    }

    perf_buffer__free(pb);
    process_exit_bpf__destroy(skel);

    return -err;
}
