"""=========================================================================================

            Recurrent Neural Network Model for Spin Upset Scenario  
            Last modified on 12/31/2022

========================================================================================="""

import sys
import numpy as np
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
from timer import Timer
import os
import time


class TimerError(Exception):
    """Raised when the Timer class is misused (started twice or stopped while idle)."""

class Timer:
    """Minimal wall-clock stopwatch based on time.perf_counter().

    Usage: call start(), do work, then stop() — which prints and returns
    the elapsed time in seconds. Misuse raises TimerError.
    """

    def __init__(self):
        # None means the timer is not currently running.
        self._start_time = None

    def start(self):
        """Start a new timer.

        Raises:
            TimerError: if the timer is already running.
        """
        if self._start_time is not None:
            raise TimerError("Timer is running. Use .stop() to stop it")

        self._start_time = time.perf_counter()

    def stop(self):
        """Stop the timer, print and return the elapsed time in seconds.

        Returns:
            float: elapsed seconds since start() was called.

        Raises:
            TimerError: if the timer was not started.
        """
        if self._start_time is None:
            raise TimerError("Timer is not running. Use .start() to start it")

        elapsed_time = time.perf_counter() - self._start_time
        self._start_time = None
        print(f"Elapsed time: {elapsed_time:0.4f} seconds")
        # FIX: return the measurement instead of discarding it (callers that
        # ignored the previous None return are unaffected).
        return elapsed_time


def RNN_Test(outputDim, lossobj=None, lossType='mse'):
    """Build and compile the recurrent surrogate model for the spin scenario.

    Architecture: two stacked SimpleRNN layers (32 units each, full
    per-time-step output), four identical Dense/LeakyReLU/Dropout/BatchNorm
    blocks, and a linear output layer of width ``outputDim``.

    Parameters
    ----------
    outputDim : int
        Number of output channels per time step (e.g. Lat, Long, Alt).
    lossobj : object, optional
        Provider of an ``RMSE`` loss callable; only required when
        ``lossType == 'rmse'``. Defaults to None.
    lossType : str, optional
        Either ``'mse'`` or ``'rmse'``. Defaults to ``'mse'`` so the model
        can be built with just ``RNN_Test(outputDim)``.

    Returns
    -------
    tf.keras.Sequential
        A compiled model mapping (batch, steps, features) to
        (batch, steps, outputDim).

    Raises
    ------
    ValueError
        If ``lossType`` is not one of the supported values.
    """
    import tensorflow as tf
    from tensorflow.keras.layers import (BatchNormalization, Dense, Dropout,
                                         LeakyReLU, SimpleRNN)
    from tensorflow.keras.optimizers import Adam

    model = tf.keras.models.Sequential()
    # return_sequences=True keeps the output for every time step so the
    # network predicts whole trajectories, not just the final state.
    model.add(SimpleRNN(32, return_sequences=True))
    model.add(SimpleRNN(32, return_sequences=True))

    # Four identical fully connected blocks (the original repeated this
    # code verbatim four times). LeakyReLU: alpha=0 is plain ReLU,
    # alpha=1 is the identity.
    for _ in range(4):
        model.add(Dense(70))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.2))
        model.add(BatchNormalization())

    # Linear output layer which gives you e.g. Lat, Long, Alt.
    model.add(Dense(outputDim))

    # Adaptive learning (adaptive gradient descent) with momentum.
    # You may need to tune the learning rate to get accurate results.
    # FIX: `lr=` is deprecated (and removed in recent Keras); the supported
    # keyword is `learning_rate=`.
    opt = Adam(learning_rate=5e-5, beta_1=0.9, beta_2=0.999, amsgrad=False)

    if lossType == 'mse':
        model.compile(loss="mse", optimizer=opt)
    elif lossType == 'rmse':
        model.compile(loss=lossobj.RMSE, optimizer=opt)
    else:
        # FIX: fail fast — the original only printed a message and returned
        # an uncompiled (unusable) model.
        raise ValueError(
            "please enter a valid loss function for spin training: "
            "'mse' or 'rmse'")

    return model

def predict_validate_v5(model,dataSetX_val,DataMeanY,DataStdY,surrogateDt,myTime,scalingType,btchSize=32,groundTruth=[]):
    myTime.start()
    forecast = model.predict(dataSetX_val,batch_size=btchSize)
    myTime.stop()
    TotalTime = forecast.shape[1]*surrogateDt
    if scalingType ==0:
        for i in range (forecast.shape[0]):
            for j in range(forecast.shape[2]):
                #forecast[i,:,j] = (forecast[i,:,j]*1)+0
                forecast[i,:,j] = (forecast[i,:,j]*DataStdY[j])+DataMeanY[j]
    elif scalingType ==1:
        for i in range (forecast.shape[1]):
            for j in range(forecast.shape[2]):
                #forecast[i,:,j] = (forecast[i,:,j]*1)+0
                forecast[:,i,j] = (forecast[:,i,j]*DataStdY[i])+DataMeanY[i]
    time_Temp = np.linspace(0,TotalTime,num=forecast.shape[1])
    time_array = np.zeros((forecast.shape[0],forecast.shape[1]))
    for it in range(forecast.shape[0]):
        time_array[it,:] = time_Temp
    time_array = np.expand_dims(time_array,axis = 2)
    forecast = np.append(forecast, time_array,axis=2)
    #if groundTruth !="":
    groundTruth_plot = np.append(groundTruth, time_array,axis=2)
    #else:
    #groundTruth_plot = groundTruth
    return forecast, groundTruth_plot
   

# ---------------------------------------------------------------------------
# Top-level driver: trainingMode==0 runs inference with a saved model,
# trainingMode==1 trains a new RNN surrogate on spin-upset trajectories.
# NOTE(review): many names below (readInput, dataSets, predictTraj,
# trainValidate_split, NN_model, gtm_dt, surrogateDt, btch_sz, outDir,
# test_model, stat_All, groundTruthTest, storeStat_flag, restart_Training,
# restart_model, startStep, trainingEpcohs, saveFreq, modelName) are not
# defined or imported in this file — presumably supplied by companion
# modules or a surrounding script; confirm before running standalone.
# ---------------------------------------------------------------------------
t = Timer() # Instantiate an object from my customized Timer class
# 0 = inference with a previously saved model; 1 = train from scratch/restart.
trainingMode =1

if trainingMode==0:
    # ---- Inference path: read, scale, down-sample, predict, dump CSVs ----
    # Trajectory input file is given on the command line.
    allTraj = readInput.read_input_spin(sys.argv[1])
    datasetX, datasetY = dataSets.getDataSets_spin(allTraj)
    # Scale inputs only; Y stays unscaled so predictions can be compared
    # directly against it after un-scaling inside predict_validate_v5.
    datasetX_scaled = dataSets.scaleDataSets(datasetX,0)
    TotalTime = readInput.getTime(datasetX.shape[1],gtm_dt)
    # Down-sample from the simulator time step (gtm_dt) to the surrogate's.
    datasetX_scaledF = dataSets.skipDt(datasetX_scaled,gtm_dt,surrogateDt,TotalTime)
    datasetY_unscaledF = dataSets.skipDt(datasetY,gtm_dt,surrogateDt,TotalTime)
    model = load_model(test_model)
    DataMeanY,DataStdY = dataSets.readMean_std(1,stat_All)
    # NOTE(review): this calls predictTraj.predict_validate_v5 (an external
    # module), not the local function above; the argument lists differ —
    # verify the external signature.
    predictedData, groundTruth_plot= predictTraj.predict_validate_v5(model,datasetX_scaledF,DataMeanY,DataStdY,surrogateDt,btch_sz,datasetY_unscaledF)
    # One CSV per trajectory; optionally also dump the ground truth.
    for i in range(predictedData.shape[0]):
        fnameP = "RNN_PredictedProjectionV6_GTM"+"_"+str(i)+".csv"
        if groundTruthTest ==1:
            fnameO = "groundTruth_plot"+"_"+str(i)+".csv"
            np.savetxt(os.path.join(outDir,fnameO), groundTruth_plot[i,:,:], delimiter=",")
        np.savetxt(os.path.join(outDir,fnameP), predictedData[i,:,:], delimiter=",")
elif trainingMode ==1:
    # ---- Training path ----
    allTraj = readInput.read_input_spin(sys.argv[1])
    datasetX, datasetY = dataSets.getDataSets_spin(allTraj)
    print(datasetX.shape,datasetY.shape)
    # Optionally persist the output mean/std so inference can un-scale later.
    if storeStat_flag ==1:
        if stat_All ==0:
            dataSets.storeMean_std(datasetY,1)
        elif stat_All ==1:
            dataSets.storeMean_std_All(datasetY,1)
    datasetX_scaled = dataSets.scaleDataSets(datasetX, 0)
    datasetY_scaled = dataSets.scaleDataSets(datasetY,0)
    TotalTime = readInput.getTime(datasetX.shape[1],gtm_dt)
    print("Total Sim Time",TotalTime)
    # Down-sample both X and Y to the surrogate time step.
    datasetXScaled_F = dataSets.skipDt(datasetX_scaled,gtm_dt,surrogateDt,TotalTime)
    datasetYScaled_F = dataSets.skipDt(datasetY_scaled,gtm_dt,surrogateDt,TotalTime)
    trainX,trainY,testX,testY = trainValidate_split.train_split_curves(datasetXScaled_F, datasetYScaled_F)
    print("train and test data size",trainX.shape,trainY.shape,testX.shape,testY.shape)
    # Per-epoch train/validation loss log (tab-separated: epoch, train, val).
    loss_out = open(os.path.join(outDir,"modelLoss_spin.txt"),"w")

    # Either resume from a checkpoint or build a fresh model.
    if restart_Training ==1:
        model =  load_model(restart_model)
        startS = startStep
    elif restart_Training == 0:     
        # NOTE(review): called with a single argument, while the local
        # RNN_Test above takes (outputDim, lossobj, lossType) — presumably
        # NN_model.RNN_Test has defaults; confirm.
        model =  NN_model.RNN_Test(trainY.shape[2])
        startS = 0

    bat_per_epo = int(np.ceil(trainX.shape[0] / btch_sz))
    t.start() 
    # (typo preserved from caller config: trainingEpcohs)
    for i in range(startS,trainingEpcohs):
        for j in range(bat_per_epo):
            # NOTE(review): getSamples appears to draw random batches, so an
            # "epoch" here is bat_per_epo random draws, not a full sweep.
            xtrain, ytrain = dataSets.getSamples(trainX, trainY, btch_sz)
            trainLoss = model.train_on_batch(xtrain, ytrain,reset_metrics=False)
        valLoss = model.evaluate(testX,testY,verbose=1)
        '''temp = objloss.getLoss()
        print(temp.shape)
        print(temp)'''
        # NOTE(review): trainLoss is the accumulated metric from the last
        # train_on_batch (reset_metrics=False); dividing by bat_per_epo is
        # meant to average it over the epoch — verify against Keras metric
        # accumulation semantics.
        loss_out.write("%d\t%E\t%E\n"%(i+1, trainLoss/bat_per_epo,valLoss))
        loss_out.flush()
        # Periodic checkpointing every saveFreq epochs.
        if (i+1) % saveFreq ==0:
            saveName = modelName+"_" + str(i+1) + ".h5"
            model.save(os.path.join(outDir,saveName))
    loss_out.close()
    t.stop()

