import pandas as pd
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
import pandas as pd
import statsmodels.api as sm
import plotly.express as px
from itertools import combinations
from scipy.spatial import distance
from scipy.stats import wilcoxon
from scipy import stats
from plotly.subplots import make_subplots

def detection_rate():
    """Compare per-taxon detection probabilities between the two ``Molecule`` groups.

    Reads the vertebrate TaXon table and its metadata table (hard-coded paths),
    groups samples by the ``Molecule`` metadata column and, for every species,
    computes the percentage of samples of each group in which that species has
    a non-zero read count. The result is

    * written to ``test_vertebrates.xlsx`` (one row per species),
    * plotted as two horizontal bar charts plus a group-vs-group scatter with a
      1:1 line (``vertebrates.pdf`` / ``vertebrates.html``), and
    * compared with a paired Wilcoxon signed-rank test and a Spearman rank
      correlation, whose results are printed.

    The first group (x axis) is the ``Molecule`` value that appears first in the
    metadata table. Exactly two distinct values are required; otherwise a
    message is printed and nothing is written. Returns ``None``.
    """
    from collections import Counter

    taxon_table_xlsx = '/Users/tillmacher/Desktop/TTT_projects/Projects/Lippe_eRNA_tele02_vertebrates/TaXon_tables/Lippe_eRNA_tele02_taxon_table_cons_NCsub_vertebrates_normalized.xlsx'
    metadata_table = '/Users/tillmacher/Desktop/TTT_projects/Projects/Lippe_eRNA_tele02_vertebrates/Meta_data_table/Lippe_eRNA_tele02_taxon_table_cons_NCsub_vertebrates_normalized_metadata.xlsx'

    metadata = 'Molecule'
    taxonomic_level = 'Species'
    taxonomic_level_2 = 'Class'

    taxon_table_df = pd.read_excel(taxon_table_xlsx).fillna('')
    metadata_df = pd.read_excel(metadata_table).fillna('')
    taxon_table_df = taxon_table_df.sort_values([taxonomic_level_2, taxonomic_level], ascending=[False, False])

    # Unique, non-empty taxa in table order (sorted by higher taxon, then taxon, descending).
    all_taxa = list(dict.fromkeys(t for t in taxon_table_df[taxonomic_level].tolist() if t != ''))

    ## sort samples according to metadata
    # Dict order = first appearance in the metadata table. This keeps the x/y
    # group assignment reproducible; set() ordering of strings changes between
    # interpreter runs because of hash randomization.
    samples_dict = {}
    for sample, group in metadata_df[['Samples', metadata]].values.tolist():
        samples_dict.setdefault(group, []).append(sample)
    available_metadata = list(samples_dict)

    if len(available_metadata) != 2:
        print('detection_rate: expected 2 distinct "{}" values, found {}: {}'.format(
            metadata, len(available_metadata), available_metadata))
        return

    def detection_probabilities(samples):
        """Return, per taxon in ``all_taxa``, the % of *samples* with non-zero reads (2 dp)."""
        n_detections = Counter()
        for sample in samples:
            # A set, so a taxon counts at most once per sample even if several
            # table rows share the same taxon name.
            n_detections.update({
                taxon
                for reads, taxon in taxon_table_df[[sample, taxonomic_level]].values.tolist()
                if reads != 0 and taxon != ''
            })
        return [
            round(n_detections[taxon] / len(samples) * 100, 2) if n_detections[taxon] else 0
            for taxon in all_taxa
        ]

    group_x, group_y = available_metadata
    x_values = detection_probabilities(samples_dict[group_x])
    y_values = detection_probabilities(samples_dict[group_y])

    ## marker symbol per taxon according to taxonomic level 2
    # The higher taxon is taken from the first table row of each taxon.
    first_rows = taxon_table_df.drop_duplicates(subset=taxonomic_level).set_index(taxonomic_level)[taxonomic_level_2]
    higher_taxon_dict = {taxon: first_rows.loc[taxon] for taxon in all_taxa}
    # Deterministic symbol index per higher taxon, in order of first appearance.
    markers_dict = {t2: i for i, t2 in enumerate(dict.fromkeys(higher_taxon_dict.values()))}
    marker_list = [markers_dict[higher_taxon_dict[taxon]] for taxon in all_taxa]

    ## write df
    df_out = pd.DataFrame()
    df_out[taxonomic_level_2] = list(higher_taxon_dict.values())
    df_out[taxonomic_level] = all_taxa
    df_out[group_x] = x_values
    df_out[group_y] = y_values
    df_out.to_excel('/Users/tillmacher/Desktop/Paper/eRNA_eDNA_Lippe/test_vertebrates.xlsx', index=False)

    ## label each taxon with its running number (also used as scatter-point text); italicise species
    label_template = '<i>{}</i> ({})' if taxonomic_level == 'Species' else '{} ({})'
    labels = [label_template.format(taxon, n) for n, taxon in enumerate(all_taxa, start=1)]

    fig = make_subplots(rows=1, cols=3, column_widths=[0.05, 0.05, 0.6], vertical_spacing=0.05, horizontal_spacing=0.05)

    fig.add_trace(go.Bar(x=x_values, y=labels, orientation='h', marker_color='blue'), row=1, col=1)
    fig.update_yaxes(tickmode='linear', showticklabels=True, showgrid=False, row=1, col=1)
    fig.add_trace(go.Bar(x=y_values, y=labels, orientation='h', marker_color='red'), row=1, col=2)
    fig.update_yaxes(tickmode='linear', showticklabels=False, showgrid=False, row=1, col=2)

    fig.add_trace(go.Scatter(x=x_values, y=y_values, text=list(range(1, len(labels) + 1)), textposition='top center',
                             marker_color='navy', marker_symbol=marker_list, mode='markers+text', marker_size=20), row=1, col=3)
    # 1:1 reference line
    fig.add_trace(go.Scatter(x=[-5, 105], y=[-5, 105], text=[-5, 105], mode='lines'), row=1, col=3)
    fig.update_xaxes(range=[-5, 105], dtick=10, title='detection probability {}'.format(group_x), row=1, col=3)
    fig.update_yaxes(range=[-5, 105], dtick=10, title='detection probability {}'.format(group_y), row=1, col=3)

    width = 1800
    height = 1200
    font_size = 20
    template = 'simple_white'
    fig.update_layout(width=width, height=height, template=template, font_size=font_size, yaxis_nticks=5, title_font_size=font_size, showlegend=False)
    fig.write_image('/Users/tillmacher/Desktop/Paper/eRNA_eDNA_Lippe/vertebrates.pdf')
    fig.write_html('/Users/tillmacher/Desktop/Paper/eRNA_eDNA_Lippe/vertebrates.html')

    ## statistical tests (paired per taxon); printed so the results are not lost
    w_statistic, w_p_value = wilcoxon(x_values, y_values)
    rho, rho_p_value = stats.spearmanr(x_values, y_values)
    print('Wilcoxon signed-rank {} vs {}: W={}, p={}'.format(group_x, group_y, w_statistic, round(w_p_value, 4)))
    print('Spearman {} vs {}: rho={}, p={}'.format(group_x, group_y, round(rho, 2), round(rho_p_value, 4)))

def normal_dis_and_glm():
    """Test Δ detection probabilities for normality and regress them on Δ specimen numbers.

    Reads ``Lippe_efishing.xlsx``. y is the Δ eRNA-eDNA detection probability
    (%), x is the Δ downstream-upstream specimen numbers (%).

    Steps:

    * Print Shapiro-Wilk and Kolmogorov-Smirnov normality tests for y.
    * Show a histogram and a normal Q-Q plot of y.
    * Fit a Gaussian GLM with identity link of y on x (i.e. ordinary linear
      regression) and print its summary.
    * Write a scatter plot of the data with the fitted line to ``glm.pdf``.

    Returns ``None``.
    """
    y_column = 'Δ eRNA-eDNA detection probability (%)'
    x_column = 'Δ downstream-upstream specimen numbers (%)'

    ####################################################################################################################
    df = pd.read_excel('/Users/tillmacher/Desktop/Paper/eRNA_eDNA_Lippe/5_stuff/Lippe_efishing.xlsx')
    y_values = df[y_column].values
    x_values = df[x_column].values

    ####################################################################################################################
    # Shapiro-Wilk test
    shapiro_stat, shapiro_p_value = stats.shapiro(y_values)
    print(f"Shapiro-Wilk Test - Statistic: {shapiro_stat}, p-value: {shapiro_p_value}")

    # Kolmogorov-Smirnov test against a normal with the sample's own mean and SD.
    # kstest(y, 'norm') alone compares to the *standard* normal N(0, 1), which
    # percentage-scale data can never match. Estimating the parameters from the
    # same sample makes this p-value conservative (cf. the Lilliefors test).
    ks_stat, ks_p_value = stats.kstest(y_values, 'norm', args=(np.mean(y_values), np.std(y_values, ddof=1)))
    print(f"Kolmogorov-Smirnov Test - Statistic: {ks_stat}, p-value: {ks_p_value}")

    # Histogram and Q-Q plot
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.hist(y_values, bins=10, density=True, alpha=0.7)
    plt.title("Histogram")

    plt.subplot(1, 2, 2)
    stats.probplot(y_values, dist="norm", plot=plt)
    plt.title("Q-Q Plot")

    plt.tight_layout()
    plt.show()

    ####################################################################################################################
    # Linear model: Gaussian family with its default identity link (== OLS).
    glm_model = sm.GLM(y_values, sm.add_constant(x_values), family=sm.families.Gaussian())
    result = glm_model.fit()
    print(result.summary())

    # Use the coefficients of the fit above (add_constant prepends the
    # intercept column, so params are [const, slope]).
    const_coeff, x1_coeff = result.params
    predicted_y = const_coeff + x1_coeff * np.asarray(x_values)

    ####################################################################################################################
    ## data points and fitted line
    order = np.argsort(x_values)  # draw the fitted line once, left to right

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=np.asarray(x_values)[order], y=predicted_y[order], mode='lines', marker_color='Red', name='GLM'))
    fig.add_trace(go.Scatter(x=x_values, y=y_values, mode='markers', marker_color='Navy', marker_size=15, name='Data Points'))

    width = 900
    height = 900
    font_size = 20
    template = 'simple_white'
    # Axis titles are taken from the data columns so the direction of each
    # difference (which term is subtracted) matches the plotted values.
    fig.update_xaxes(range=(-105, 105), dtick=10, title=x_column)
    fig.update_yaxes(range=(-105, 105), dtick=10, title=y_column)
    fig.update_layout(width=width, height=height, template=template, font_size=font_size, yaxis_nticks=5, title_font_size=font_size, showlegend=False)
    fig.write_image('/Users/tillmacher/Desktop/Paper/eRNA_eDNA_Lippe/5_stuff/glm.pdf')

def jaccard_distances():
    """Compare within-group dissimilarity and species richness of eRNA vs eDNA samples.

    Reads the invertebrate TaXon table and its metadata table (hard-coded
    paths) and groups samples by ``Molecule``.

    For the eRNA and eDNA groups it then:

    * computes the presence/absence Jaccard distance of every sample pair within
      the group;
    * plots each distance against the time difference between the two samples
      (number after ``_`` in the sample name, times 5 min), with a
      least-squares line and Spearman's rho in the title. Pairs that include one
      of the hard-coded outlier samples are drawn grey.

    Finally it compares eRNA vs eDNA pairwise distances, and per-sample species
    richness, with paired Wilcoxon signed-rank tests, one box plot each. Values
    are paired by position; this assumes both groups list their samples in
    corresponding order (TODO confirm).

    All figures are written as PDFs. Returns ``None``.
    """
    taxon_table_xlsx = '/Users/tillmacher/Desktop/TTT_projects/Projects/Lippe_eRNA_tele02_vertebrates/TaXon_tables/Lippe_eRNA_fwh_taxon_table_cons_NCsub_invertebrates_normalized.xlsx'
    # Metadata (Samples / Molecule columns) belonging to the taxon table above.
    metadata_table = '/Users/tillmacher/Desktop/TTT_projects/Projects/Lippe_eRNA_tele02_vertebrates/Meta_data_table/Lippe_eRNA_fwh_taxon_table_cons_NCsub_invertebrates_normalized_metadata.xlsx'

    metadata = 'Molecule'
    taxonomic_level = 'Species'
    taxonomic_level_2 = 'Class'

    taxon_table_df = pd.read_excel(taxon_table_xlsx).fillna('')
    metadata_df = pd.read_excel(metadata_table).fillna('')
    taxon_table_df = taxon_table_df.sort_values([taxonomic_level_2, taxonomic_level], ascending=[False, False])

    ## sort samples according to metadata (group order = first appearance)
    samples_dict = {}
    for sample, group in metadata_df[['Samples', metadata]].values.tolist():
        samples_dict.setdefault(group, []).append(sample)

    def present_species(sample):
        """Return the set of non-empty species names with a non-zero read count in *sample*."""
        return {species for species, reads in taxon_table_df[['Species', sample]].values.tolist()
                if species != '' and reads != 0}

    def jaccard(sample_1, sample_2):
        """Return the presence/absence Jaccard distance between two samples."""
        species_1, species_2 = present_species(sample_1), present_species(sample_2)
        union = sorted(species_1 | species_2)
        return distance.jaccard([int(s in species_1) for s in union], [int(s in species_2) for s in union])

    def paired_boxplot(rna_values, dna_values, path, **yaxis_kwargs):
        """Box-plot eRNA vs eDNA values with a paired Wilcoxon test in the title; write to *path*."""
        statistic, p_value = wilcoxon(rna_values, dna_values)
        # scipy's two-sided statistic is the smaller signed-rank sum (W), not a z score.
        title = 'W={} p={}'.format(statistic, round(p_value, 4))
        fig = go.Figure()
        fig.add_trace(go.Box(y=rna_values, name='eRNA', boxpoints='all', jitter=0.3, pointpos=-1.8, marker_color='Red'))
        fig.add_trace(go.Box(y=dna_values, name='eDNA', boxpoints='all', jitter=0.3, pointpos=-1.8, marker_color='Blue'))
        fig.update_yaxes(**yaxis_kwargs)
        fig.update_layout(width=600, height=600, template='simple_white', font_size=20, yaxis_nticks=5, title_font_size=20, showlegend=False, title=title)
        fig.write_image(path)

    ## pairwise Jaccard distances within each molecule group
    pairwise = {}
    for molecule, samples in samples_dict.items():
        pairs = list(combinations(samples, 2))
        pairwise[molecule] = ([jaccard(s1, s2) for s1, s2 in pairs], pairs)

    ## Spearman correlation of Jaccard distance with sampling time difference
    outlier_samples = {'eRNA': ('LRNA_5', 'LRNA_12'), 'eDNA': ('LDNA_5', 'LDNA_12')}
    for molecule, color in zip(['eRNA', 'eDNA'], ['Red', 'Blue']):
        distances, pairs = pairwise[molecule]
        y_values = np.array(distances)
        # time difference of a pair: |index_1 - index_2| * 5 min, index parsed from '<prefix>_<index>'
        x_values = np.array([abs(int(s1.split('_')[1]) - int(s2.split('_')[1])) * 5 for s1, s2 in pairs])
        # pairs containing an outlier sample are highlighted in grey
        colors = ['Grey' if any(s in pair for s in outlier_samples[molecule]) else color for pair in pairs]

        rho, p_value = stats.spearmanr(x_values, y_values)
        title = '{} rho={} p={}'.format(molecule, round(rho, 2), round(p_value, 4))

        fig = go.Figure()
        fig.add_trace(go.Scatter(y=y_values, x=x_values, mode='markers', marker_color=colors))
        trend = np.poly1d(np.polyfit(x_values, y_values, 1))
        x_line = np.sort(x_values)  # draw the trend line once, left to right
        fig.add_trace(go.Scatter(y=trend(x_line), x=x_line, mode='lines', marker_color='Grey'))

        fig.update_yaxes(range=(0, 1), dtick=0.2, title='Jaccard distance')
        fig.update_xaxes(title='Time difference (min)', dtick=5)
        fig.update_layout(width=600, height=600, template='simple_white', font_size=20, yaxis_nticks=5, title_font_size=20, showlegend=False, title=title)
        fig.write_image('/Users/tillmacher/Desktop/Paper/eRNA_eDNA_Lippe/5_stuff/jaccard_distance_scatter_{}.pdf'.format(molecule))

    ## eRNA vs eDNA pairwise Jaccard distances
    paired_boxplot(np.array(pairwise['eRNA'][0]), np.array(pairwise['eDNA'][0]),
                   '/Users/tillmacher/Desktop/Paper/eRNA_eDNA_Lippe/5_stuff/jaccard_distance_boxplot.pdf',
                   range=(0, 1), dtick=0.2, title='Jaccard distance')

    ## eRNA vs eDNA species richness per sample
    richness = {molecule: [len(present_species(s)) for s in samples] for molecule, samples in samples_dict.items()}
    paired_boxplot(richness['eRNA'], richness['eDNA'],
                   '/Users/tillmacher/Desktop/Paper/eRNA_eDNA_Lippe/5_stuff/jaccard_distance_boxplot_richness.pdf',
                   title='Species richness', rangemode='tozero')

def richness_delta():
    """Plot eRNA and eDNA species richness and their pairwise Jaccard distance over time.

    Reads the invertebrate TaXon table and its metadata table (hard-coded
    paths) and groups samples by ``Molecule``. The i-th eRNA sample is paired
    with the i-th eDNA sample and assigned time point ``i * 5`` min. This
    assumes both groups are listed in the same sampling order in the metadata
    table (TODO confirm). If the groups differ in size, only the common prefix
    is paired and a message is printed.

    For each pair it records both species richnesses and the presence/absence
    Jaccard distance between the two samples. It writes a dual-axis plot to
    ``species_richness.pdf``; the title shows the standard deviation of each
    richness series. It also prints the min, max and mean of the three series.
    Returns ``None``.
    """
    taxon_table_xlsx = '/Users/tillmacher/Desktop/TTT_projects/Projects/Lippe_eRNA_tele02_vertebrates/TaXon_tables/Lippe_eRNA_fwh_taxon_table_cons_NCsub_invertebrates_normalized.xlsx'
    metadata_table = '/Users/tillmacher/Desktop/TTT_projects/Projects/Lippe_eRNA_tele02_vertebrates/Meta_data_table/Lippe_eRNA_fwh_taxon_table_cons_NCsub_invertebrates_normalized_metadata.xlsx'

    metadata = 'Molecule'
    taxonomic_level = 'Species'
    taxonomic_level_2 = 'Class'

    taxon_table_df = pd.read_excel(taxon_table_xlsx).fillna('')
    metadata_df = pd.read_excel(metadata_table).fillna('')
    taxon_table_df = taxon_table_df.sort_values([taxonomic_level_2, taxonomic_level], ascending=[False, False])

    ## sort samples according to metadata (group order = first appearance)
    samples_dict = {}
    for sample, group in metadata_df[['Samples', metadata]].values.tolist():
        samples_dict.setdefault(group, []).append(sample)

    def present_species(sample):
        """Return the set of non-empty species names with a non-zero read count in *sample*."""
        return {species for species, reads in taxon_table_df[['Species', sample]].values.tolist()
                if species != '' and reads != 0}

    rna_samples, dna_samples = samples_dict['eRNA'], samples_dict['eDNA']
    if len(rna_samples) != len(dna_samples):
        print('richness_delta: {} eRNA vs {} eDNA samples; only the first {} of each are paired'.format(
            len(rna_samples), len(dna_samples), min(len(rna_samples), len(dna_samples))))

    x_values = []   # time point (min)
    y1_values = []  # eRNA species richness
    y2_values = []  # eDNA species richness
    y3_values = []  # Jaccard distance between the paired eRNA and eDNA samples
    for time_index, (rna_sample, dna_sample) in enumerate(zip(rna_samples, dna_samples)):
        rna_species = present_species(rna_sample)
        dna_species = present_species(dna_sample)
        union = sorted(rna_species | dna_species)
        x_values.append(time_index * 5)
        y1_values.append(len(rna_species))
        y2_values.append(len(dna_species))
        y3_values.append(distance.jaccard([int(s in rna_species) for s in union],
                                          [int(s in dna_species) for s in union]))

    # population standard deviation (np.std default ddof=0) of each richness series
    title = 'eRNA±{} eDNA±{}'.format(round(np.std(y1_values), 2), round(np.std(y2_values), 2))

    # Figure with a secondary y-axis for the Jaccard distance
    fig = make_subplots(specs=[[{"secondary_y": True}]])
    fig.add_trace(go.Scatter(x=x_values, y=y1_values, mode='markers+lines', marker_color='Red'), secondary_y=False)
    fig.add_trace(go.Scatter(x=x_values, y=y2_values, mode='markers+lines', marker_color='Blue'), secondary_y=False)
    fig.add_trace(go.Scatter(x=x_values, y=y3_values, mode='markers+lines', opacity=0.4, marker_color='Grey'), secondary_y=True)

    fig.update_yaxes(rangemode='tozero', dtick=10, title='Species richness', secondary_y=False)
    fig.update_yaxes(title="Pairwise Jaccard distance", range=(0, 1), secondary_y=True)
    fig.update_xaxes(range=(-1, 96), dtick=5, title='Time point (min)')
    fig.update_layout(width=600, height=600, template='simple_white', font_size=20, yaxis_nticks=5, title_font_size=20, showlegend=False, title=title)
    fig.write_image('/Users/tillmacher/Desktop/Paper/eRNA_eDNA_Lippe/5_stuff/species_richness.pdf')

    print(min(y1_values), max(y1_values), np.mean(y1_values))
    print(min(y2_values), max(y2_values), np.mean(y2_values))
    print(min(y3_values), max(y3_values), np.mean(y3_values))


