import argparse
import os
import subprocess

# Output directories (relative paths, resolved against the current working
# directory) holding the per-oracle result files of the two experiments.
aa_path, ta_path = (
    "DRIVER-HAVOC-ANALYSIS-ACCURACY",  # havoc *analysis* accuracy results
    "DRIVER-HAVOC-TRANSFORM-ACCURACY",  # havoc *transform* accuracy results
)


def process_time(all_lines, driver, myfile):
    """Sum the reported phase timings and append one summary line to *myfile*.

    Example input lines::

        [2025-03-21.08:55:35] Havoc analysis time: 0.00102901s
        [2025-03-21.08:55:35] Target analysis time: 0.00209201s
        [2025-03-21.08:55:35] Havoc analysis accuracy check time: 0.00826301s
        [2025-03-21.08:55:35] Havoc transform accuracy check time: 0.00436901s

    For each label, every line containing ``"<label>:"`` contributes the
    number that follows it (with a trailing ``s`` removed). Any leading log
    prefix is ignored. A label with no matching lines totals ``0``.

    Args:
        all_lines: lines of clam.py output.
        driver: identifier written at the start of the summary line.
        myfile: path of the file the summary line is appended to.
    """
    labels = (
        "Havoc analysis time",
        "Target analysis time",
        "Havoc analysis accuracy check time",
        "Havoc transform accuracy check time",
    )
    totals = []
    for label in labels:
        key = label + ":"
        my_lines = [x for x in all_lines if key in x]
        print(my_lines)
        # Read the token right after "<label>:" instead of a fixed token index.
        # This works whether or not the "[timestamp]" log prefix is present.
        totals.append(
            sum(float(x.split(key, 1)[1].split()[0].rstrip("s")) for x in my_lines)
        )
    havoc, target, an_check, tr_check = totals

    out_line = f"{driver} havoc: {havoc} target: {target} analysis_accuracy: {an_check} transform_accuracy: {tr_check}\n"
    print("PROCESSING_TIME: " + out_line)
    with open(myfile, "a") as f:
        f.write(out_line)

def process_analysis(all_lines, driver, myfile):
    """Sum the havoc-analysis accuracy counters and append them to *myfile*.

    Example input line::

        [2025-03-18.11:48:45] HAVOC_ANALYSIS_ACCURACY: fun: main blocks: 208 numeric_vars: 58 TP: 11512 TN: 526 FP: 4 FN: 22

    Fields are looked up by key after the marker, so any leading log prefix
    is ignored. Appends ``"<driver> <checks> <TP> <TN> <FP> <FN>"``, where
    ``checks`` is the sum of ``blocks * numeric_vars`` over all matching lines.

    Args:
        all_lines: lines of clam.py output.
        driver: identifier written at the start of the summary line.
        myfile: path of the file the summary line is appended to.

    Raises:
        ValueError: if ``checks`` differs from ``TP + TN + FP + FN``.
    """
    marker = "HAVOC_ANALYSIS_ACCURACY:"
    my_lines = [x for x in all_lines if marker in x]
    print(my_lines)

    checks = tp = tn = fp = fn = 0
    for line in my_lines:
        # The text after the marker is a sequence of "key: value" pairs.
        rest = line.split(marker, 1)[1].split()
        fields = dict(zip(rest[0::2], rest[1::2]))
        checks += int(fields["blocks:"]) * int(fields["numeric_vars:"])
        tp += int(fields["TP:"])
        tn += int(fields["TN:"])
        fp += int(fields["FP:"])
        fn += int(fields["FN:"])

    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if checks != tp + tn + fp + fn:
        raise ValueError(
            f"{driver}: inconsistent havoc analysis counts: "
            f"checks={checks} but TP+TN+FP+FN={tp + tn + fp + fn}"
        )
    out_line = f"{driver} {checks} {tp} {tn} {fp} {fn}\n"
    print("HAVOC ANALYSIS: " + out_line)
    with open(myfile, "a") as f:
        f.write(out_line)


def process_unreach(all_lines, driver, myfile):
    """Sum the unreachable-block-end counters and append them to *myfile*.

    Example input line::

        UNREACHABLE_BLOCK_ENDS: fun: cx25840_work_handler num_unreach: 1 TP: 0 TN: 9 FP: 0 FN: 0

    Fields are looked up by key after the marker, so the line is parsed the
    same way with or without a leading log prefix. Appends
    ``"<driver> <TP> <TN> <FP> <FN>"``. ``num_unreach`` is not used.

    Args:
        all_lines: lines of clam.py output.
        driver: identifier written at the start of the summary line.
        myfile: path of the file the summary line is appended to.
    """
    marker = "UNREACHABLE_BLOCK_ENDS:"
    my_lines = [x for x in all_lines if marker in x]
    print(my_lines)

    tp = tn = fp = fn = 0
    for line in my_lines:
        # The text after the marker is a sequence of "key: value" pairs.
        rest = line.split(marker, 1)[1].split()
        fields = dict(zip(rest[0::2], rest[1::2]))
        tp += int(fields["TP:"])
        tn += int(fields["TN:"])
        fp += int(fields["FP:"])
        fn += int(fields["FN:"])

    out_line = f"{driver} {tp} {tn} {fp} {fn}\n"
    print("UNREACHABLE: " + out_line)
    with open(myfile, "a") as f:
        f.write(out_line)


def process_transform(all_lines, driver, myfile):
    """Sum the havoc-transform accuracy counters and append them to *myfile*.

    Example input line::

        [2025-03-18.11:48:45] HAVOC_TRANSFORM_ACCURACY: fun: main stmts: 177 havoc_checks: 15 TP: 8 TN: 7 FP: 0 FN: 0

    Fields are looked up by key after the marker, so any leading log prefix
    is ignored. Appends ``"<driver> <checks> <TP> <TN> <FP> <FN>"``, where
    ``checks`` is the sum of ``havoc_checks`` over all matching lines.

    Args:
        all_lines: lines of clam.py output.
        driver: identifier written at the start of the summary line.
        myfile: path of the file the summary line is appended to.

    Raises:
        ValueError: if ``checks`` differs from ``TP + TN + FP + FN``.
    """
    marker = "HAVOC_TRANSFORM_ACCURACY:"
    my_lines = [x for x in all_lines if marker in x]
    print(my_lines)

    checks = tp = tn = fp = fn = 0
    for line in my_lines:
        # The text after the marker is a sequence of "key: value" pairs.
        rest = line.split(marker, 1)[1].split()
        fields = dict(zip(rest[0::2], rest[1::2]))
        checks += int(fields["havoc_checks:"])
        tp += int(fields["TP:"])
        tn += int(fields["TN:"])
        fp += int(fields["FP:"])
        fn += int(fields["FN:"])

    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if checks != tp + tn + fp + fn:
        raise ValueError(
            f"{driver}: inconsistent havoc transform counts: "
            f"checks={checks} but TP+TN+FP+FN={tp + tn + fp + fn}"
        )
    out_line = f"{driver} {checks} {tp} {tn} {fp} {fn}\n"
    print("HAVOC_TRANSFORM: " + out_line)
    with open(myfile, "a") as f:
        f.write(out_line)


def process(oracle, drivers_path="../downloaded_benchmarks/drivers/"):
    """
    Run the havoc accuracy experiments for one oracle over every driver file.

    Steps:
    1. Ensure the output directories
         aa) "DRIVER-HAVOC-ANALYSIS-ACCURACY"
         ta) "DRIVER-HAVOC-TRANSFORM-ACCURACY"
       exist, and delete any stale ``<oracle>*.txt`` result files in them.
    2. Create fresh result files for the oracle:
         aa) <oracle>_time.txt, <oracle>.txt and <oracle>_unreach.txt;
         ta) <oracle>.txt.
    3. For every regular file in ``drivers_path`` (in sorted order), run
       ``clam.py`` with the havoc analysis accuracy check enabled. The "int"
       domain is used for non-relational oracles, "pk-ap-pplite" otherwise.
    4. Parse clam.py's stdout and append to the matching result file:
         - the per-driver timings;
         - the havoc analysis TP/TN/FP/FN counts;
         - the unreachable block-end counts;
         - the havoc transformation TP/TN/FP/FN counts.

    Args:
        oracle: one of "def-nonrel", "def-rel", "pos-nonrel", "pos-rel".
        drivers_path: directory containing the driver files to analyze.

    Raises:
        ValueError: if ``oracle`` is not one of the supported values.

    Note:
        The command "clam.py" must be available in the system's PATH.
    """
    valid_oracles = ("def-nonrel", "def-rel", "pos-nonrel", "pos-rel")
    # Explicit check (not ``assert``) so it survives ``python -O``: ``oracle``
    # comes from the command line and is used to build file names.
    if oracle not in valid_oracles:
        raise ValueError(f"unknown oracle {oracle!r}; expected one of {valid_oracles}")

    for out_dir in (aa_path, ta_path):
        os.makedirs(out_dir, exist_ok=True)
        _remove_stale_results(out_dir, oracle)

    domain = "int" if "nonrel" in oracle else "pk-ap-pplite"

    tm_file = os.path.join(aa_path, oracle + "_time.txt")
    aa_file = os.path.join(aa_path, oracle + ".txt")
    ur_file = os.path.join(aa_path, oracle + "_unreach.txt")
    ta_file = os.path.join(ta_path, oracle + ".txt")

    print(f"\n\n### Processing {oracle} oracle ###\n\n")

    for ff in (tm_file, aa_file, ur_file, ta_file):
        with open(ff, "w") as f:
            f.write("\n")

    for driver in sorted(os.listdir(drivers_path)):
        driver_path = os.path.join(drivers_path, driver)
        if not os.path.isfile(driver_path):
            continue
        print(f"### Processing {driver_path} ###\n")

        cmd = ["clam.py", "--inline", f"--crab-havoc-analysis={oracle}", "--crab-havoc-analysis-accuracy-check", f"--crab-dom={domain}", "--crab-print-invariants=false", driver_path]

        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            # Keep going with the remaining drivers, but make the failure
            # visible: a failed run can otherwise end up recorded as zero
            # counts in the result files.
            print(f"WARNING: clam.py exited with code {result.returncode} on {driver_path}")
            print(result.stderr)
        output_lines = result.stdout.splitlines()

        process_time(output_lines, driver_path, tm_file)
        process_analysis(output_lines, driver_path, aa_file)
        process_unreach(output_lines, driver_path, ur_file)
        process_transform(output_lines, driver_path, ta_file)


def _remove_stale_results(directory, oracle):
    """Delete regular files named ``<oracle>*.txt`` in *directory*.

    Pure-Python replacement for ``rm <directory>/<oracle>*.txt``. It does not
    go through a shell and is a silent no-op when nothing matches.
    """
    for name in os.listdir(directory):
        path = os.path.join(directory, name)
        if name.startswith(oracle) and name.endswith(".txt") and os.path.isfile(path):
            os.remove(path)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument(
        "--oracle",
        help="oracle to use for analysis",
        required=True,
        # Reject unsupported values up front with a usage message instead of
        # failing later inside process().
        choices=["def-nonrel", "def-rel", "pos-nonrel", "pos-rel"],
    )
    args = argparser.parse_args()
    process(args.oracle)
