"""VDPhage -- a tool for CRISPR-resistant phage therapy design.
This file provides the code to compute VDPhage and local BLAST baseline.
This file is meant to be run as a standalone script.
"""


import os
import sys
import json
import vicinity
import argparse
import itertools
import numpy as np
import Levenshtein
import pandas as pd
import os.path as op
from tqdm import tqdm
from Bio import SeqIO
from Bio import SearchIO
from sklearn.preprocessing import normalize


def get_max(x):
    """
    Report the alignment span, query and hit of the row with the largest span.

    Parameters:
        x (DataFrame): Table holding at least the 'aln_span', 'query' and 'hit' columns.

    Returns:
        tuple: ``(aln_span, query, hit)`` read from the row that comes first
               once the table is sorted by 'aln_span' in descending order.
    """
    ranked = x.sort_values(by="aln_span", ascending=False)
    best = ranked.iloc[0]
    return (best["aln_span"], best["query"], best["hit"])

def load_blast(fn):
    """
    Parse a BLAST XML result file and collect the HSPs of every hit of every query.

    Parameters:
        fn (str): The path to the BLAST XML result file.

    Returns:
        list: One entry per query record in the file. Each entry is a list with
              one element per hit of that query, and each of those elements is
              the value of the hit's ``hsps`` attribute.
    """
    per_query = []
    # The file is opened in binary mode and handed to SearchIO as-is.
    with open(fn, "rb") as handle:
        for query_record in SearchIO.parse(handle, "blast-xml"):
            per_query.append([hit.hsps for hit in query_record.hits])
    return per_query

def make_vicinity(vectors, items):
    """
    Build a Vicinity index over the given vectors using the USEARCH backend and the cosine metric.

    Parameters:
        vectors (numpy.ndarray): The vectors to index (presumably one row per item).
        items (list): The item identifiers that accompany the vectors.

    Returns:
        vicinity.Vicinity: The object returned by ``Vicinity.from_vectors_and_items``.
    """
    index = vicinity.Vicinity.from_vectors_and_items(
        vectors=vectors,
        items=items,
        backend_type=vicinity.Backend.USEARCH,
        metric=vicinity.Metric.COSINE,
    )
    return index

def load_vectors(fn):
    """
    Read an array with ``np.load`` and L1-normalize it along axis 1.

    Parameters:
        fn (str): The path to the file holding the array (loaded with ``np.load``).

    Returns:
        numpy.ndarray: The loaded array after ``normalize(..., axis=1, norm="l1")``,
                       i.e. the normalization is applied per row, not per column.
    """
    raw = np.load(fn)
    return normalize(raw, norm="l1", axis=1)

def load_items(fn):
    """
    Read a text file into a list with one entry per line.

    Parameters:
        fn (str): The path to the file listing one item per line.

    Returns:
        list: The lines of the file, in order, with every newline character
              removed. Blank lines are kept as empty strings.
    """
    with open(fn, "r") as handle:
        return [line.replace("\n", "") for line in handle]

def load_protospacers(protospacers):
    """
    Parse a protospacer FASTA file into a sequence lookup and a list of phage names.

    The record ID is the header text up to the first space. The phage name is the
    second ``|``-separated field of that ID (e.g. ``ps1|phageA`` -> ``phageA``).

    Parameters:
        protospacers (str): The path to the file containing protospacer sequences in FASTA format.

    Returns:
        tuple: A tuple containing two elements - a dictionary mapping protospacer IDs to their
               sequences, and a list with one phage name per record (duplicates are kept).

    Raises:
        ValueError: If a record ID has no ``|``-separated phage field.
    """
    phages = []
    P = {}
    with open(protospacers, "r") as ih:
        content = ih.read()
    for record in content.split(">"):
        # Skip the text before the first '>' and any whitespace-only chunk.
        if not record.strip():
            continue
        lines = record.split("\n")
        record_id = lines[0].split(" ")[0]
        id_fields = record_id.split("|")
        if len(id_fields) < 2:
            raise ValueError(
                "Protospacer ID {!r} in {} has no '|'-separated phage field".format(
                    record_id, protospacers
                )
            )
        # A FASTA sequence may be wrapped over several lines: join all of them
        # instead of keeping only the first one.
        P[record_id] = "".join(line.strip() for line in lines[1:])
        phages.append(id_fields[1])
    return (P, phages)

def load_spacers(spacers):
    """
    Parse a spacer FASTA file into a dictionary.

    The key of each entry is the complete header line (without the leading ``>``),
    including any description after the first space.

    Parameters:
        spacers (str): The path to the file containing spacer sequences in FASTA format.

    Returns:
        dict: A dictionary mapping spacer header lines to their sequences, in file order.
    """
    S = {}
    with open(spacers, "r") as ih:
        content = ih.read()
    for record in content.split(">"):
        # Skip the text before the first '>' and any whitespace-only chunk, so a
        # leading blank line does not create a spurious empty entry.
        if not record.strip():
            continue
        lines = record.split("\n")
        # A FASTA sequence may be wrapped over several lines: join all of them
        # instead of keeping only the first one.
        S[lines[0]] = "".join(line.strip() for line in lines[1:])
    return S


class VicinityWrapper():
    """Pair a nearest-neighbour index with the spacer and protospacer sequences used to rank phages."""

    def __init__(self, protospacers, spacers, V):
        """
        Load the sequence files and keep a reference to the index.

        Parameters:
            protospacers (str): Path to the protospacer FASTA file, read with `load_protospacers()`.
            spacers (str): Path to the spacer FASTA file, read with `load_spacers()`.
            V: The index object; `query()` calls ``V.query(qv, k=k)`` on it.

        Attributes set:
            db: The index `V`.
            spacers (dict): Spacer header -> sequence.
            protospacers (dict): Protospacer ID -> sequence.
            phages (numpy.ndarray): Sorted unique phage names found in the protospacer file.
        """
        self.db = V
        self.spacers = load_spacers(spacers)
        P, phages = load_protospacers(protospacers)
        self.protospacers = P
        self.phages = np.unique(phages)

    def query(self, qv, k):
        """
        Query the index with the spacer vectors and summarise the closest hit per phage.

        For every phage the row with the smallest distance is kept, together with
        the spacer, protospacer, query sequence and hit sequence of that same row.
        Phages that never appear among the neighbours get a placeholder row
        ("NONE" fields) whose distance is the largest distance observed.

        Parameters:
            qv: Query vectors handed to ``self.db.query``. The i-th result is
                attributed to the i-th entry of ``self.spacers``
                (NOTE(review): assumes the vectors were produced in the same
                order as the spacers appear in the FASTA file - confirm).
            k: Number of neighbours requested per query vector.

        Returns:
            DataFrame: One row per phage with the columns 'spacer', 'protospacer',
                       'distance', 'query', 'hit', 'phage' and 'levenshtein'
                       (edit distance between 'query' and 'hit'), sorted by
                       'distance' in descending order.
        """
        columns = ["spacer", "protospacer", "distance", "query", "hit", "phage"]
        result = self.db.query(qv, k=k)
        frames = []
        for i, (name, sequence) in enumerate(self.spacers.items()):
            # Each hit is unpacked as an (item, distance) pair.
            hits = result[i]
            items = [str(item) for item, _ in hits]
            frames.append(
                pd.DataFrame(
                    {
                        "spacer": [name] * len(items),
                        "protospacer": items,
                        # Keep distances numeric. Stacking the pairs into one
                        # array would turn them into strings and make every
                        # later min/max/sort lexicographic.
                        "distance": [float(dist) for _, dist in hits],
                        "query": [sequence] * len(items),
                        "hit": [self.protospacers[item.split(" ")[0]] for item in items],
                    }
                )
            )
        df = pd.concat(frames)
        df["phage"] = df["protospacer"].apply(lambda x: x.split("|")[1].strip())
        absent_phages = np.setdiff1d(self.phages, df["phage"].unique())
        if absent_phages.shape[0]:
            add = pd.DataFrame(
                {
                    "spacer": ["NONE"] * absent_phages.shape[0],
                    "protospacer": ["NONE"] * absent_phages.shape[0],
                    "distance": [df["distance"].max()] * absent_phages.shape[0],
                    "phage": absent_phages,
                    "query": ["NONE"] * absent_phages.shape[0],
                    "hit": ["NONE"] * absent_phages.shape[0],
                }
            )
            df = pd.concat([df, add])
        print(df.shape)
        # idxmin needs unique row labels; the concatenated frames repeat theirs.
        df = df.reset_index(drop=True)
        # Select the whole closest row per phage so that spacer, protospacer,
        # query and hit all describe the same match (a column-wise min would
        # mix values from unrelated rows).
        best_rows = df.groupby("phage")["distance"].idxmin()
        summary_df = df.loc[best_rows, columns].reset_index(drop=True)
        summary_df = summary_df.sort_values("distance", ascending=False)
        summary_df["levenshtein"] = [
            Levenshtein.distance(q, h) for q, h in zip(summary_df["query"], summary_df["hit"])
        ]
        return summary_df


def print_and_run(cmd):
    """
    Print the given command and execute it using the operating system's shell.

    A non-zero status returned by ``os.system`` is reported on stderr instead of
    being dropped silently. Execution continues either way, so callers keep
    their previous control flow.

    Parameters:
        cmd (str): The command to be executed.

    Returns:
        None
    """
    print(cmd)
    status = os.system(cmd)
    if status != 0:
        print(
            "WARNING: command returned non-zero status {}: {}".format(status, cmd),
            file=sys.stderr,
        )

def count_kmers(config, inputfile, k):
    """
    Fill in the k-mer counting command template from the configuration and run it.

    The template stored under ``config["COUNT KMERS"]`` has its placeholders
    substituted one after another, in this order: ``INPUTFILE``, ``KKKK``,
    ``OUTFILE``, ``OUTIDS``.

    Parameters:
        config (dict): Configuration with the keys "COUNT KMERS", "TEMPDIR",
                       "PROTOSPACER KMER COUNTS" and "PROTOSPACER KMER COUNT IDS".
        inputfile (str): The path to the input sequence file.
        k (int): The length of the k-mer to be counted (inserted via ``str(k)``).

    Returns:
        None
    """
    command = config["COUNT KMERS"]
    command = command.replace("INPUTFILE", inputfile)
    command = command.replace("KKKK", str(k))
    # Both output files live in the temporary directory.
    counts_path = op.join(config["TEMPDIR"], config["PROTOSPACER KMER COUNTS"])
    command = command.replace("OUTFILE", counts_path)
    ids_path = op.join(config["TEMPDIR"], config["PROTOSPACER KMER COUNT IDS"])
    command = command.replace("OUTIDS", ids_path)
    print_and_run(command)

def build_baseline(args):
    """
    Build the BLAST baseline database for protospacer-spacer comparison.

    The function reads the configuration file, creates the temporary and output
    directories if needed, copies the protospacers file into the output
    directory, runs the configured ``MAKEBLASTDB`` command on the copy, and
    writes ``meta.json`` describing the build.

    Parameters:
        args (argparse.Namespace): Command line arguments; `config`, `output`
                                   and `protospacers` are read.

    Returns:
        None
    """
    # Local import: only this function needs shutil.
    import shutil

    with open(args.config, "r") as ih:
        config = json.load(ih)
    os.makedirs(config["TEMPDIR"], exist_ok=True)
    os.makedirs(args.output, exist_ok=True)
    source = op.abspath(args.protospacers)
    # Copy in-process rather than through an unquoted `cp` shell command, which
    # breaks on paths with spaces or shell metacharacters and is not portable.
    print("Copying {} to {}".format(source, args.output))
    shutil.copy(source, args.output)
    onlyfile = op.split(args.protospacers)[-1]
    print_and_run(config["MAKEBLASTDB"].replace("INPUTFILE", op.join(args.output, onlyfile)))
    meta = {
        "protospacers": source, "method": "BLAST"
    }
    with open(op.join(args.output, "meta.json"), "w") as oh:
        oh.write(json.dumps(meta))

def build_vdphage(args):
    """
    Build the VDPhage index from a protospacer file and store it with its metadata.

    Steps: read the configuration, make sure the temporary directory exists,
    count k-mers of the protospacers, load the resulting vectors and IDs, build
    the index, save it to ``args.output`` and write ``meta.json`` inside it.

    Args:
        args (argparse.Namespace): Command line arguments; `config`,
                                   `protospacers`, `k` and `output` are read.

    Returns:
        None. The function saves the model and writes metadata to a file.
    """
    with open(args.config, "r") as config_handle:
        config = json.load(config_handle)

    temp_dir = config["TEMPDIR"]
    if not op.exists(temp_dir):
        os.makedirs(temp_dir)

    count_kmers(config, args.protospacers, args.k)

    counts_path = op.join(config["TEMPDIR"], config["PROTOSPACER KMER COUNTS"])
    vectors = load_vectors(counts_path)
    ids_path = op.join(config["TEMPDIR"], config["PROTOSPACER KMER COUNT IDS"])
    items = load_items(ids_path)

    index = make_vicinity(vectors, items)
    index.save(args.output)

    # Metadata lets the query step find the protospacers and reuse the same k.
    meta = {
        "protospacers": op.abspath(args.protospacers),
        "k": args.k,
        "method": "VDPhage",
    }
    meta_path = op.join(args.output, "meta.json")
    with open(meta_path, "w") as meta_handle:
        meta_handle.write(json.dumps(meta))

def query_vdphage(args, meta):
    """
    Query a saved VDPhage index with a set of spacers and write the phage ranking.

    Steps: read the configuration, make sure the temporary directory exists,
    count k-mers of the spacers with the k stored in `meta`, load the query
    vectors, load the index from ``args.database``, run the query and save the
    resulting table as a tab-separated file at ``args.output``.

    Args:
        args (argparse.Namespace): Command line arguments; `config`, `spacers`,
                                   `database`, `n` and `output` are read.
        meta (dict): Build metadata; the "k" and "protospacers" entries are used.

    Returns:
        None. The function saves query results to a tab-separated file.
    """
    with open(args.config, "r") as config_handle:
        config = json.load(config_handle)

    temp_dir = config["TEMPDIR"]
    if not op.exists(temp_dir):
        os.makedirs(temp_dir)

    # The spacer counts are written to the same temporary files the build step uses.
    count_kmers(config, args.spacers, meta["k"])
    counts_path = op.join(config["TEMPDIR"], config["PROTOSPACER KMER COUNTS"])
    query_vectors = load_vectors(counts_path)
    ids_path = op.join(config["TEMPDIR"], config["PROTOSPACER KMER COUNT IDS"])
    # The IDs are not used below; the read is kept so a missing IDs file still fails here.
    load_items(ids_path)

    index = vicinity.Vicinity.load(args.database)
    wrapper = VicinityWrapper(meta["protospacers"], args.spacers, index)
    ranking = wrapper.query(query_vectors, k=int(args.n))
    ranking.to_csv(args.output, sep="\t")

def parse_baseline(flat_result, spacers, protospacers):
    """
    Summarise BLASTN HSPs into one row per phage.

    For every phage the HSP with the largest alignment span is kept (via
    `get_max`), together with the spacer and protospacer sequences of that HSP.

    Parameters:
        flat_result (list): A flat list of HSP objects exposing `hit_id`,
                            `query_id` and `aln_span`. May be empty.
        spacers (dict): A dictionary mapping spacer IDs to their sequences.
        protospacers (dict): A dictionary mapping protospacer IDs to their sequences.
                             The phage name is the second ``|``-separated field of the ID.

    Returns:
        summary_df (DataFrame): Columns 'phage', 'max aln_span', 'query', 'hit'
        and 'levenshtein', sorted by 'max aln_span' in ascending order. When
        `flat_result` is empty the table has these columns and no rows.
    """
    # Passing the column names at construction keeps the frame well-formed even
    # when there are no hits (assigning .columns afterwards fails on an empty frame).
    result_df = pd.DataFrame(
        [[hsp.hit_id, hsp.query_id, hsp.aln_span] for hsp in flat_result],
        columns=["protospacer", "spacer", "aln_span"],
    )
    result_df["query"] = result_df["spacer"].apply(lambda x: spacers[x])
    result_df["hit"] = result_df["protospacer"].apply(lambda x: protospacers[x])
    result_df["phage"] = result_df["protospacer"].apply(lambda x: x.split("|")[1])
    print(result_df.shape)
    summary_df = pd.DataFrame(
        [[phage, *get_max(group)] for phage, group in result_df.groupby("phage")],
        columns=["phage", "max aln_span", "query", "hit"],
    )
    summary_df = summary_df.sort_values(by="max aln_span")
    summary_df["levenshtein"] = [
        Levenshtein.distance(q, h) for q, h in zip(summary_df["query"], summary_df["hit"])
    ]
    return summary_df
    
def query_baseline(args, meta):
    """
    Query a previously built BLAST baseline database with spacer sequences and
    write a per-phage summary.

    The function reads the configuration from a JSON file, creates the temporary
    directory if it doesn't exist, runs the configured ``BLASTN`` command with
    the spacers against the baseline database, parses the results, and saves the
    summary as a tab-separated file at ``args.output``.

    Parameters:
        args (argparse.Namespace): Command line arguments; `config`, `database`,
                                   `spacers` and `output` are read.
        meta (dict): Metadata about the baseline build; the file name of the
                     "protospacers" entry locates the database inside `args.database`.

    Returns:
        None
    """
    with open(args.config, "r") as ih:
        config = json.load(ih)
    if not op.exists(config["TEMPDIR"]):
        os.makedirs(config["TEMPDIR"])
    outfile = op.join(config["TEMPDIR"], config["BLAST OUT"])
    dbpath = op.join(args.database, op.split(meta["protospacers"])[1])
    cmd = config["BLASTN"].replace(
        "INPUTFILE", args.spacers
    ).replace("OUTFILE", outfile).replace("DATABASE", dbpath)
    print_and_run(cmd)
    result = load_blast(outfile)
    # result is nested query -> hit -> HSP list. Flatten both levels in linear
    # time (repeated list concatenation with sum() is quadratic).
    flat_result = list(
        itertools.chain.from_iterable(itertools.chain.from_iterable(result))
    )
    spacers = load_spacers(args.spacers)
    # Only the sequence lookup is needed here; the phage list is ignored.
    protospacers, _ = load_protospacers(dbpath)
    summary_df = parse_baseline(flat_result, spacers, protospacers)
    summary_df.to_csv(args.output, sep="\t")

def build_command(args):
    """
    Dispatch the ``build`` subcommand to the builder selected by ``args.method``.

    Parameters:
        args (argparse.Namespace): Command line arguments; `method` selects the
                                   builder and the whole namespace is passed on to it.

    Returns:
        None

    Raises:
        ValueError: If ``args.method`` is neither "BLAST" nor "VDPhage".
    """
    if args.method == "BLAST":
        build_baseline(args)
    elif args.method == "VDPhage":
        build_vdphage(args)
    else:
        # Fail loudly instead of finishing without having built anything.
        raise ValueError(
            "Unknown build method {!r}; expected 'BLAST' or 'VDPhage'".format(args.method)
        )

def query_command(args):
    """
    Dispatch the ``query`` subcommand based on the method recorded in the database metadata.

    The metadata is read from ``meta.json`` inside ``args.database``.

    Args:
        args (argparse.Namespace): Command line arguments; `database` locates
                                   the metadata and the whole namespace is
                                   passed on to the selected query function.

    Returns:
        None. The function calls other functions to handle query based on method type.

    Raises:
        ValueError: If the recorded method is neither "VDPhage" nor "BLAST".
    """
    with open(op.join(args.database, "meta.json"), "r") as ih:
        meta = json.load(ih)
    if meta["method"] == "VDPhage":
        query_vdphage(args, meta)
    elif meta["method"] == "BLAST":
        query_baseline(args, meta)
    else:
        # Fail loudly instead of finishing without having written any output.
        raise ValueError(
            "Unknown method {!r} in {}; expected 'VDPhage' or 'BLAST'".format(
                meta["method"], op.join(args.database, "meta.json")
            )
        )


if __name__ == "__main__":
    # Command line entry point: `build` creates a database, `query` ranks phages against it.
    parser = argparse.ArgumentParser(description="VDPhage -- a CLI tool for CRISPR-resistant phage cocktail generation")
    subparsers = parser.add_subparsers(dest='command', help='Available commands')
    # Build subparser
    # Options that every build path reads are marked required so that a missing
    # one is reported by argparse instead of failing later on a None path.
    build_parser = subparsers.add_parser('build', help='Build the database')
    build_parser.add_argument('--protospacers', '-p', help='protospacers.fasta', required=True)
    # type=int makes argparse reject non-numeric values up front.
    build_parser.add_argument('--k', '-k', help='size of the k-mer', type=int, default=4)
    build_parser.add_argument(
        '--output', '-o', help='output database file (VDPhage) or directory (BLAST)', required=True
    )
    build_parser.add_argument(
        '--method', '-m', help='choose the method for database construction', choices=["BLAST", "VDPhage"], default="VDPhage"
    )
    build_parser.add_argument('--config', '-c', help='config.json', required=True)
    build_parser.set_defaults(func=build_command)
    # Query subparser
    query_parser = subparsers.add_parser('query', help='Query the database')
    query_parser.add_argument('--database', '-db', help='database.vd', required=True)
    query_parser.add_argument('--spacers', '-s', help='spacers.fasta', required=True)
    query_parser.add_argument(
        '--n', '-n', help='number of records to extract for each spacer', type=int, default=5000
    )
    query_parser.add_argument('--output', '-o', help='phage_ranking.tsv', required=True)
    query_parser.add_argument('--config', '-c', help='config.json', required=True)
    query_parser.set_defaults(func=query_command)
    # Parse arguments
    args = parser.parse_args()
    # Call the appropriate function based on the subcommand
    if args.command:
        args.func(args)
    else:
        parser.print_help()
