import argparse
import os
import subprocess
import sys

# Root directory of the precision experiments: each analysis configuration
# stores its results in a "<oracle>-<domain>" subdirectory, and comparison
# output is written beneath "CMP/".
path: str = "TARGET-PRECISION"

# Abstract domain passed to clam-diff (--dom) when comparing results; the
# polyhedra domain is used so that the reported counts are consistent.
cmp_domain: str = "pk-ap-pplite"

def fatal_error(msg):
    """Report *msg* as a fatal error and terminate with exit status 1.

    Diagnostics are written to stderr so they do not get mixed into the
    comparison rows this script prints on stdout.

    Raises:
        SystemExit: always, with code 1.
    """
    print(f"ERROR : {msg}\n", file=sys.stderr)
    print("Exiting due to fatal error\n", file=sys.stderr)
    sys.exit(1)

# Statistics extracted from each `clam-diff --semdiff` run, in output-column
# order: (column name, substring identifying the output line, index of the
# count among that line's whitespace-separated tokens).
_CLAM_DIFF_STATS = (
    ("EQ", "Number of equals", 4),
    ("LT", "1 being more precise than 2", 8),
    ("GT", "1 being less precise than 2", 8),
    ("UN", "Number of incomparable", 4),
    ("BT1", "Number of bottom in 1:", 5),
    ("BT2", "Number of bottom in 2:", 5),
    ("SZ1", "Number of lincons in 1:", 5),
    ("SZ2", "Number of lincons in 2:", 5),
)


def _read_count(all_lines, marker, index, res1_file, res2_file):
    """Return the integer at token *index* of the unique line containing *marker*.

    Parameters:
        all_lines: clam-diff output split into lines.
        marker: substring that identifies the statistic's line.
        index: position of the count in the line's whitespace-split tokens.
        res1_file, res2_file: the compared files (only used in the error text).

    Exits through fatal_error unless exactly one line contains *marker*,
    e.g. when clam-diff failed and printed no statistics.  (Explicit check
    instead of `assert`, which is stripped under `python -O`.)
    """
    matches = [line for line in all_lines if marker in line]
    if len(matches) != 1:
        fatal_error(f"expected one '{marker}' line in clam-diff output for "
                    f"{res1_file} vs {res2_file}, found {len(matches)}")
    return int(matches[0].split()[index])


def process_driver(path1, path2, out_all_drivers, out_path, driver):
    """Compare the per-function results of one driver between two analyses.

    For every regular file in ``{path1}/{driver}``, runs
    ``clam-diff --dom=<cmp_domain> --semdiff`` against the same-named file in
    ``{path2}/{driver}`` and writes one ``FUNCTION;EQ;...;SZ2`` row to
    ``{out_path}/{driver}.txt``.  The per-driver totals are printed and
    appended as one row to *out_all_drivers*.

    Exits through fatal_error if ``{path1}/{driver}`` is not a directory, a
    counterpart file is missing under *path2*, or the clam-diff output lacks
    an expected statistic line.
    """
    res1_path = f"{path1}/{driver}"
    res2_path = f"{path2}/{driver}"
    if not os.path.isdir(res1_path):
        fatal_error(f"result directory {res1_path} does not exist")

    names = [name for name, _, _ in _CLAM_DIFF_STATS]
    totals = [0] * len(_CLAM_DIFF_STATS)  # driver-level counters

    out_driver_file = f"{out_path}/{driver}.txt"
    with open(out_driver_file, "w") as out:
        out.write("FUNCTION;" + ";".join(names) + "\n")

        for myfile in sorted(os.listdir(res1_path)):
            res1_file = f"{res1_path}/{myfile}"
            res2_file = f"{res2_path}/{myfile}"
            if not os.path.isfile(res1_file):
                continue
            if not os.path.isfile(res2_file):
                fatal_error(f"file {res2_file} does not exists")

            cmd = ["clam-diff", f"--dom={cmp_domain}", "--semdiff",
                   res1_file, res2_file]
            result = subprocess.run(cmd,
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.STDOUT,
                                    text=True)
            all_lines = result.stdout.split("\n")

            # Function-level counters, in the same order as the header columns.
            counts = [_read_count(all_lines, marker, index, res1_file, res2_file)
                      for _, marker, index in _CLAM_DIFF_STATS]
            out.write(";".join([myfile, *map(str, counts)]) + "\n")
            totals = [total + count for total, count in zip(totals, counts)]

    # Append the driver-level totals to the summary file.
    out_line = ";".join([driver, *map(str, totals)]) + "\n"
    print(out_line)
    with open(out_all_drivers, "a") as f:
        f.write(out_line)


def process(oracle1, domain1, oracle2, domain2,
            drivers_path="../downloaded_benchmarks/drivers/"):
    """Compare the results of two analysis configurations, driver by driver.

    Each configuration is an (oracle, domain) pair whose results live in
    ``{path}/{oracle}-{domain}``.  Output goes to
    ``{path}/CMP/{oracle1}-{domain1}-vs-{oracle2}-{domain2}``: one
    ``<driver>.txt`` per driver (written by process_driver) plus
    ``all_drivers.txt`` with one totals row per driver.  Existing
    non-hidden files in that output directory are deleted first.

    Parameters:
        oracle1, oracle2: havoc oracle of the 1st / 2nd analysis.
        domain1, domain2: target domain of the 1st / 2nd analysis.
        drivers_path: directory whose regular files name the drivers to
            compare (defaults to the historical hard-coded location).

    Exits through fatal_error on an unsupported oracle/domain or a missing
    result directory.
    """
    oracles = ("none", "def-nonrel", "def-rel", "pos-nonrel", "pos-rel")
    domains = ("int", "pk-ap-pplite")
    # Explicit checks rather than `assert`: asserts are stripped under
    # `python -O`, and these values are spliced into filesystem paths below.
    for label, value, allowed in (("oracle1", oracle1, oracles),
                                  ("domain1", domain1, domains),
                                  ("oracle2", oracle2, oracles),
                                  ("domain2", domain2, domains)):
        if value not in allowed:
            fatal_error(f"invalid {label} '{value}', expected one of: "
                        f"{', '.join(allowed)}")

    path1 = f"{path}/{oracle1}-{domain1}"
    if not os.path.exists(path1):
        fatal_error(f"1st target path {path1} does not exists")
    path2 = f"{path}/{oracle2}-{domain2}"
    if not os.path.exists(path2):
        fatal_error(f"2nd target path {path2} does not exists")

    out_path = f"{path}/CMP/{oracle1}-{domain1}-vs-{oracle2}-{domain2}"
    os.makedirs(out_path, exist_ok=True)
    # Clear stale output from a previous run.  Like a shell `rm out_path/*`,
    # hidden entries and subdirectories are left alone, but no shell is
    # spawned and an already-empty directory produces no error noise.
    for name in os.listdir(out_path):
        if name.startswith("."):
            continue
        entry = os.path.join(out_path, name)
        if os.path.islink(entry) or not os.path.isdir(entry):
            os.remove(entry)

    out_all_drivers = f"{out_path}/all_drivers.txt"
    with open(out_all_drivers, "w") as f:
        f.write("DRIVER;EQ;LT;GT;UN;BT1;BT2;SZ1;SZ2\n")

    print("\nComparing target analysis results:")
    print(f"  1st target: {path1}")
    print(f"  2nd target: {path2}")
    print(f"  output file: {out_all_drivers}\n")

    # Only regular files in drivers_path name drivers; anything else is skipped.
    drivers = [name for name in sorted(os.listdir(drivers_path))
               if os.path.isfile(os.path.join(drivers_path, name))]
    for number, driver in enumerate(drivers, start=1):
        print(f"{number}: comparing results for {driver}")
        process_driver(path1, path2, out_all_drivers, out_path, driver)


if __name__ == "__main__":
    # Same values process() accepts; declaring them as argparse choices
    # rejects a typo with a usage message before any work is done.
    oracle_choices = ["none", "def-nonrel", "def-rel", "pos-nonrel", "pos-rel"]
    domain_choices = ["int", "pk-ap-pplite"]

    argparser = argparse.ArgumentParser(
        description="Compare the results of two analysis configurations "
                    "using clam-diff.")
    argparser.add_argument("--oracle1", help="1st analysis havoc oracle",
                           required=True, choices=oracle_choices)
    argparser.add_argument("--domain1", help="1st analysis target domain",
                           required=True, choices=domain_choices)
    argparser.add_argument("--oracle2", help="2nd analysis havoc oracle",
                           required=True, choices=oracle_choices)
    argparser.add_argument("--domain2", help="2nd analysis target domain",
                           required=True, choices=domain_choices)
    args = argparser.parse_args()
    process(args.oracle1, args.domain1, args.oracle2, args.domain2)
