import os
from itertools import cycle
from imblearn.over_sampling import SMOTE
import flwr as fl
import numpy as np
import tensorflow as tf
# from keras.layers import LSTM, Dense
# from keras.models import Sequential
# from keras.layers import Dense
# from keras.layers import Dropout
# from keras.utils import np_utils
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import LabelEncoder
import pandas as pd
import itertools
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report,confusion_matrix
from sklearn.preprocessing import MinMaxScaler
from tensorflow import keras
from tensorflow.keras.layers import Flatten   # to flatten the input data
from tensorflow.keras.layers import Dense     # for the hidden layer
from tensorflow.keras.models import Sequential
from sklearn.preprocessing import label_binarize
from keras.callbacks import EarlyStopping
import scikitplot as skplt
from sklearn import metrics
# Make TensorFlow log less verbose
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


def preprocessing(X, Y):
    """Load the feature and label CSV files and print basic dataset statistics.

    Parameters
    ----------
    X : str or path-like
        Path of the CSV file holding the feature table.
    Y : str or path-like
        Path of the CSV file holding the label table.

    Returns
    -------
    tuple
        ``(features, targets, labels)``: the two DataFrames read from disk and
        the sorted array of distinct values appearing anywhere in ``targets``.
    """
    features = pd.read_csv(X)
    targets = pd.read_csv(Y)

    # NOTE(review): the SMOTE oversampling step is currently disabled, so the
    # "After Smote" counts printed below describe the labels exactly as read.
    print("Label Count Unique After Smote:", targets.value_counts())
    print("Total Training Data shape: ", features.shape)
    print("Total Classes in dataset: ", targets.nunique())

    # np.unique flattens the DataFrame and returns its distinct values sorted.
    labels = np.unique(targets)
    return features, targets, labels

def datasplit(X, y, test_size=0.30, random_state=0):
    """Split features and labels into train and test partitions.

    Parameters
    ----------
    X : array-like or DataFrame
        Feature table.
    y : array-like or DataFrame
        Labels aligned row-by-row with ``X``.
    test_size : float or int, default 0.30
        Passed straight to ``train_test_split`` (fraction or absolute count).
    random_state : int, default 0
        Seed for the shuffle so the split is reproducible.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)`` as returned by ``train_test_split``.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state)
    # Report the split ratio that actually happened; the previous hard-coded
    # "75%/25%" text disagreed with the 0.30 test fraction.
    n_total = len(X_train) + len(X_test)
    train_pct = 100 * len(X_train) / n_total
    test_pct = 100 * len(X_test) / n_total
    print("Train shape of {:.0f}% data: ".format(train_pct), X_train.shape, y_train.shape)
    print("Test shape of {:.0f}% data: ".format(test_pct), X_test.shape, y_test.shape)
    return X_train, X_test, y_train, y_test

def normalization(X):
    """Min-max scale every column of ``X`` into the [0, 1] range.

    A new ``MinMaxScaler`` is fitted on ``X`` itself and applied to it. The
    overall maximum and minimum of the result are printed as a sanity check.

    Parameters
    ----------
    X : array-like or DataFrame
        Feature table to scale.

    Returns
    -------
    numpy.ndarray
        The scaled values (``fit_transform`` output).
    """
    scaled = MinMaxScaler().fit_transform(X)
    # After scaling the extremes are expected to be (close to) 1 and 0.
    for tag, value in (('Max: ', scaled.max()), ('Min: ', scaled.min())):
        print(tag, value)
    return scaled

def plot_history(hist, client_name, file_name, images_path):
    """Show training-vs-validation accuracy, then loss, for one ``fit`` call.

    Parameters
    ----------
    hist : object
        Presumably the Keras ``History`` returned by ``model.fit``. Its
        ``history`` dict must contain 'accuracy', 'val_accuracy', 'loss' and
        'val_loss' sequences of equal length.
    client_name : str
        Appended to each plot title.
    file_name : str
        Currently unused (saving the figures is disabled).
    images_path : str
        Currently unused (saving the figures is disabled).
    """
    # Fetch every series up front so a missing key fails before anything is drawn.
    series = {key: hist.history[key]
              for key in ('accuracy', 'val_accuracy', 'loss', 'val_loss')}
    epochs = range(1, len(series['accuracy']) + 1)

    def _show_pair(train_key, val_key, train_label, val_label, y_label, title_prefix):
        # Draw one blue training curve and one red validation curve, then display.
        plt.plot(epochs, series[train_key], 'b', label=train_label)
        plt.plot(epochs, series[val_key], 'r', label=val_label)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.title(title_prefix + client_name)
        plt.legend()
        plt.show()

    _show_pair('accuracy', 'val_accuracy', 'Training acc', 'Validation acc',
               'Accuracy', 'Training and validation accuracy of ')
    _show_pair('loss', 'val_loss', 'Training loss', 'Validation loss',
               'Loss', 'Training and validation loss of ')

def binary_roc(y_true, y_probas, client_name):
    """Draw scikit-plot's ROC curve(s) for the given predictions and display them.

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels.
    y_probas : array-like
        Predicted scores per class (presumably ``model.predict`` output — confirm).
    client_name : str
        Appended to the plot title.
    """
    # NOTE(review): scikit-plot deprecates ``plot_roc_curve`` in favour of
    # ``plot_roc``; confirm it still exists in the installed version.
    skplt.metrics.plot_roc_curve(y_true, y_probas)
    heading = 'Receiver operating characteristic of ' + client_name
    legend_opts = dict(loc="lower right", fontsize="xx-small")
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(heading)
    plt.legend(**legend_opts)
    plt.show()

def binary_auc(y_test, y_pred_proba, client_name):
    """Plot the ROC curve of the column-1 scores and show the AUC in the legend.

    Parameters
    ----------
    y_test : array-like
        Binary ground-truth labels.
    y_pred_proba : numpy.ndarray
        2-D score array. Column index 1 is used as the positive-class score.
    client_name : str
        Appended to the plot title.
    """
    positive_scores = y_pred_proba[:, 1]
    fpr, tpr, _ = metrics.roc_curve(y_test, positive_scores)
    # Named ``area`` so the module-level ``auc`` import is not shadowed.
    area = metrics.roc_auc_score(y_test, positive_scores)
    plt.plot(fpr, tpr, label="auc=" + str(area))
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.legend(loc=4)
    plt.title('Area Under the Curve of ' + client_name)
    plt.show()

def multiclass_roc(n_classes, y_test, y_pred, client_name, file_name, images_path):
    """Plot a one-vs-rest ROC curve (with AUC) for each class and display it.

    Parameters
    ----------
    n_classes : int
        Number of leading columns of ``y_test`` / ``y_pred`` to plot.
    y_test : numpy.ndarray
        2-D indicator matrix, one column per class (e.g. ``label_binarize`` output).
    y_pred : numpy.ndarray
        2-D score matrix whose column ``i`` scores the class in column ``i`` of ``y_test``.
    client_name : str
        Appended to the plot title.
    file_name : str
        Currently unused (saving the figure is disabled).
    images_path : str
        Currently unused (saving the figure is disabled).
    """
    line_width = 1
    # Compute every per-class curve first, then draw them all.
    curves = []
    for cls in range(n_classes):
        cls_fpr, cls_tpr, _ = roc_curve(y_test[:, cls], y_pred[:, cls])
        curves.append((cls_fpr, cls_tpr, auc(cls_fpr, cls_tpr)))

    # Colours repeat every three classes.
    palette = cycle(['blue', 'red', 'green'])
    for cls, ((cls_fpr, cls_tpr, area), color) in enumerate(zip(curves, palette)):
        plt.plot(cls_fpr, cls_tpr, color=color, lw=line_width,
                 label=f'ROC curve of class {cls} (area = {area:0.2f})')

    # Chance-level diagonal for reference.
    plt.plot([0, 1], [0, 1], 'k--', lw=line_width)
    plt.xlim([-0.05, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic of ' + client_name)
    plt.legend(loc="lower right", fontsize="xx-small")
    plt.show()


def plot_confusion_matrix(cnf_matrix,class_names,file_name,images_path,client_name, numbers_type='numbers_and_percentage', title='Client2 Confusion matrix', cmap=plt.cm.Blues):
    """Render a confusion matrix as an annotated heat map and display it.

    Parameters
    ----------
    cnf_matrix : numpy.ndarray
        Square integer matrix. Rows are true labels and columns are predicted
        labels (matching the axis labels drawn below).
    class_names : sequence
        Tick labels, one per row/column of ``cnf_matrix``.
    file_name : str
        Currently unused (saving the figure is disabled).
    images_path : str
        Currently unused (saving the figure is disabled).
    client_name : str
        Currently unused.
    numbers_type : str, default 'numbers_and_percentage'
        Cell annotation mode:
        - 'numbers_and_percentage': row percentage followed by the raw count.
        - 'percentage': row fraction with two decimals.
        - anything else: raw count only.
    title : str
        Plot title.
    cmap : matplotlib colormap
        Colormap for the heat map.
    """
    # Row-normalise so each cell is the share of its true class. A row with no
    # samples (a class seen only in the predictions) would divide by zero and
    # print "nan%", so those cells are left at 0 instead.
    counts = cnf_matrix.astype('float')
    row_totals = counts.sum(axis=1, keepdims=True)
    cnf_matrix_normalized = np.divide(counts, row_totals,
                                      out=np.zeros_like(counts),
                                      where=row_totals != 0)

    plt.imshow(cnf_matrix, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)

    # Cells above 80% of the maximum count get white text for contrast.
    thresh = 0.8 * cnf_matrix.max()
    for i, j in itertools.product(range(cnf_matrix.shape[0]), range(cnf_matrix.shape[1])):
        if numbers_type == 'numbers_and_percentage':
            text = '{:.2f}%'.format(100 * cnf_matrix_normalized[i, j]) + '({:2d})'.format(cnf_matrix[i, j])
        elif numbers_type == 'percentage':
            text = format(cnf_matrix_normalized[i, j], '.2f')
        else:
            text = format(cnf_matrix[i, j], 'd')
        plt.text(j, i, text,
                 horizontalalignment="center", verticalalignment='bottom',
                 color="white" if cnf_matrix[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Compute the layout only after the axis labels exist so they are not clipped.
    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    # model.set_weights(parameters) in fit/evaluate requires every client's
    # weight arrays to have identical shapes. The architecture is therefore
    # fixed here rather than inferred from this client's (possibly partial)
    # local data.
    N_FEATURES = 56
    N_CLASSES = 3

    client_name = input("Enter Client Name: ")
    # Paths of this client's feature and label CSV files.
    x = 'X_features.csv'
    y = 'y_features.csv'
    X, y, labels = preprocessing(x, y)
    # NOTE(review): features are used unscaled; normalization() is defined
    # above but not applied.
    X_train, X_test, y_train, y_test = datasplit(X, y)

    # Fail fast with a readable message instead of an opaque Keras shape error.
    if X_train.shape[1] != N_FEATURES:
        raise ValueError("Expected {} feature columns, got {}".format(N_FEATURES, X_train.shape[1]))

    model = Sequential()
    # input layer
    model.add(Dense(256, activation='relu', input_shape=(N_FEATURES,)))
    # hidden layers
    model.add(Dense(128, activation='relu'))
    model.add(Dense(64, activation='relu'))
    model.add(Dense(32, activation='relu'))
    model.add(Dense(28, activation='relu'))
    # output layer: one softmax unit per class
    model.add(Dense(N_CLASSES, activation='softmax'))
    # sparse_categorical_crossentropy expects integer class ids in [0, N_CLASSES).
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

    class FlowerClient(fl.client.NumPyClient):
        """Flower NumPy client that trains and evaluates the shared Keras model locally."""

        def get_parameters(self, config=None):  # type: ignore
            """Return the current model weights.

            ``config`` is optional so the method works whether or not the
            installed Flower version passes it.
            """
            return model.get_weights()

        def fit(self, parameters, config):  # type: ignore
            """Train on the received weights, plot diagnostics and return the update."""
            model.set_weights(parameters)
            # NOTE(review): the test split also serves as validation data, so the
            # validation curves are not an independent estimate.
            hist = model.fit(X_train, y_train, epochs=40, validation_data=(X_test, y_test))
            # 'rnd' is only present if the server's fit config supplies it. A
            # missing key must not abort the round just to build a file name.
            save_name = 'round' + str(config.get('rnd', '')) + client_name
            images_path = 'result_images/'
            print(hist.history.keys())
            plot_history(hist, client_name=client_name, images_path=images_path, file_name=save_name)
            pred = model.predict(X_test)
            print('y_test---------', y_test.shape, 'pred--------', pred.shape)
            y_pred = np.argmax(pred, axis=-1)
            # One indicator column per locally present class. label_binarize
            # collapses the two-class case to a single column, so expand it back.
            y_test1 = label_binarize(y_test, classes=labels)
            if y_test1.shape[1] == 1:
                y_test1 = np.hstack([1 - y_test1, y_test1])
            n_class = y_test1.shape[1]
            # Select the probability columns of exactly those classes, in the same
            # order, so column i of y_test1 and pred_local refer to one class.
            # Assumes labels are integer class ids in [0, N_CLASSES) — TODO confirm.
            pred_local = pred[:, np.asarray(labels, dtype=int)]
            print('y_test1---------', y_test1.shape, 'y_pred', y_pred.shape)
            multiclass_roc(n_class, y_test1, pred_local, client_name=client_name, images_path=images_path, file_name=save_name + '_roc.png')
            print(classification_report(y_test, y_pred))
            cf_matrix = confusion_matrix(y_test, y_pred)
            # confusion_matrix sizes itself from the sorted union of true and
            # predicted labels, so the tick labels must come from the same union.
            cm_labels = np.union1d(np.ravel(y_test), y_pred)
            plot_confusion_matrix(cf_matrix, images_path=images_path, class_names=cm_labels, client_name=client_name, file_name=save_name + '_confusionmatrix.png')
            return model.get_weights(), len(X_train), {}

        def evaluate(self, parameters, config):  # type: ignore
            """Evaluate the received weights on the local test split."""
            model.set_weights(parameters)
            loss, accuracy = model.evaluate(X_test, y_test)
            return loss, len(X_test), {"accuracy": accuracy}

    # Start Flower client
    fl.client.start_numpy_client("localhost:5000", client=FlowerClient())