# -*- coding: utf-8 -*-

from sklearn.preprocessing import MinMaxScaler
import pandas as pd
import math
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
tf.compat.v1.disable_v2_behavior()
from scipy.io import savemat,loadmat
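
# WOA-GRU: use the whale optimization algorithm to tune the hyperparameters
# of a two-layer GRU (learning rate, epochs, hidden nodes, batch size, time
# step) for one-step-ahead forecasting of the 'salt' series, then plot the
# optimization results saved from a previous run.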
 
'''
Fitness function: train a two-layer GRU with the hyperparameters encoded in
`pop` and return the MSE on the held-out test set (lower is better).
'''
def fitness(pop, data):
    tf.compat.v1.reset_default_graph()
    tf.compat.v1.set_random_seed(0)

    alpha = pop[0]              # learning rate
    num_epochs = int(pop[1])    # number of training epochs
    hidden_nodes0 = int(pop[2]) # neurons in the first hidden layer
    hidden_nodes = int(pop[3])  # neurons in the second hidden layer
    batch_size = int(pop[4])    # batch size
    steps = int(pop[5])         # time-step (input window) length
    in_, out_ = data_split(data, steps)
    m = -900  # hold out the last 900 samples as the test set
    train_data, test_data = in_[:m], in_[m:]
    train_label, test_label = out_[:m], out_[m:]
    # Normalize to [0, 1]; scalers are fitted on the training data only

    ss_X = MinMaxScaler(feature_range=(0, 1)).fit(train_data)
    ss_Y = MinMaxScaler(feature_range=(0, 1)).fit(train_label)
    P = ss_X.transform(train_data)
    T = ss_Y.transform(train_label)

    Pt = ss_X.transform(test_data)
    Tt = ss_Y.transform(test_label)

    input_features = P.shape[1]
    output_class = T.shape[1]

    # placeholder
    X = tf.compat.v1.placeholder("float", [None, input_features])
    Y = tf.compat.v1.placeholder("float", [None, output_class])
    

    def RNN(x, hidden_nodes0, hidden_nodes, input_features, output_class):
        # Feed each window as a single time step of width `input_features`
        x = tf.reshape(x, [-1, 1, input_features])
        weights = {'out': tf.Variable(tf.compat.v1.random_normal([hidden_nodes, output_class]))}
        biases = {'out': tf.Variable(tf.compat.v1.random_normal([output_class]))}
        gru_cell0 = tf.compat.v1.nn.rnn_cell.GRUCell(hidden_nodes0)
        gru_cell = tf.compat.v1.nn.rnn_cell.GRUCell(hidden_nodes)
        gru_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell([gru_cell0,gru_cell])
        init_state = gru_cell.zero_state(tf.shape(x)[0], dtype=tf.float32)
        outputs, _ = tf.compat.v1.nn.dynamic_rnn(gru_cell, x, dtype=tf.float32, initial_state=init_state)
        output_sequence = tf.matmul(tf.reshape(outputs, [-1, hidden_nodes]), weights['out']) + biases['out']
        return tf.reshape(output_sequence, [-1, output_class])
    
    logits = RNN(X,hidden_nodes0,hidden_nodes,input_features,output_class)
    loss = tf.compat.v1.losses.mean_squared_error(labels=Y, predictions=logits)
    global_step = tf.Variable(0, trainable=False)
    # Decay the learning rate by a factor of 0.99 every `num_epochs` optimizer steps
    learning_rate = tf.compat.v1.train.exponential_decay(
                    alpha,
                    global_step,
                    num_epochs, 0.99,
                    staircase=True)
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate, epsilon=1e-10).minimize(loss, global_step=global_step)
    init = tf.compat.v1.global_variables_initializer()
    
    with tf.compat.v1.Session() as sess:
        sess.run(init)
        N = P.shape[0]
        for epoch in range(num_epochs):
            total_batch = int(math.ceil(N / batch_size))
            indices = np.arange(N)
            np.random.shuffle(indices)
            avg_loss = 0

            for i in range(total_batch):
                rand_index = indices[batch_size*i:batch_size*(i+1)]
                x = P[rand_index]
                y = T[rand_index]
                _, cost = sess.run([optimizer, loss],
                                    feed_dict={X: x, Y: y})
                avg_loss += cost / total_batch


        test_pred = sess.run(logits, feed_dict={X: Pt})
        test_pred = test_pred.reshape(-1, output_class)
    
    F2 = np.mean(np.square(test_pred - Tt))  # MSE on the normalized test set
    return F2
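
# Example call (hypothetical hyperparameter values, for illustration only):
#   mse = fitness([0.005, 50, 64, 64, 64, 10], data)
# i.e. lr = 0.005, 50 epochs, 64/64 hidden nodes, batch size 64, window of 10.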

def boundary(pop, Lb, Ub):
    # Keep each gene inside [Lb, Ub]: out-of-range values are resampled
    # uniformly. Except for the learning rate (index 0), all genes are integers.
    pop=[pop[i] if i==0 else int(pop[i]) for i in range(len(Lb))]
    for i in range(len(Lb)):
        if pop[i]>Ub[i] or pop[i]<Lb[i]:
            if i==0:
                pop[i]=(Ub[i]-Lb[i])*np.random.rand()+Lb[i]
            else:
                pop[i]=np.random.randint(Lb[i],Ub[i])
    return pop
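
# e.g. with the Lb/Ub bounds defined in woa() below,
#   boundary([0.5, 5, 300, 50, 100, 10], Lb, Ub)
# resamples genes 0, 1 and 2 (all out of range) and keeps the rest unchanged.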
 
# Python implementation of the whale optimization algorithm (WOA)
def woa(data):
    '''
        noclus         = search-space dimension (number of hyperparameters)
        max_iterations = number of WOA iterations
        noposs         = population size (number of whales)
    '''
    noclus = 6
    max_iterations = 10
    noposs = 10
    # Bounds: learning rate, training epochs, nodes in the two hidden layers,
    # batch size, time-step length
    Lb = [0.001, 10,  1,   1,   32,  2]
    Ub = [0.01,  100, 200, 200, 256, 50]
    poss_sols = np.zeros((noposs, noclus)) # whale positions
    gbest = np.zeros((noclus,)) # globally best whale positions
    b = 2.0
    # Initialize the population uniformly at random within the bounds
    for i in range(noposs):
        for j in range(noclus):
            if j==0:
                poss_sols[i][j] = (Ub[j]-Lb[j])*np.random.rand()+Lb[j]
            else:
                poss_sols[i][j] = np.random.randint(Lb[j],Ub[j])
    global_fitness = np.inf
    for i in range(noposs):
        cur_par_fitness = fitness(poss_sols[i,:],data)
        if cur_par_fitness < global_fitness:
            global_fitness = cur_par_fitness
            gbest = poss_sols[i].copy()

    trace,trace_pop=[],[]
    for it in range(max_iterations):
        for i in range(noposs):
            a = 2.0 - (2.0*it)/(1.0 * max_iterations)  # decreases linearly from 2 to 0
            r = np.random.random_sample()
            A = 2.0*a*r - a   # coefficient balancing exploration vs. exploitation
            C = 2.0*r
            l = 2.0 * np.random.random_sample() - 1.0  # spiral parameter in [-1, 1]
            p = np.random.random_sample()              # choose encircling vs. spiral update
            
            for j in range(noclus):
                x = poss_sols[i][j]
                if p < 0.5:
                    if abs(A) < 1:
                        # Encircling prey: shrink towards the best whale
                        _x = gbest[j].copy()
                    else:
                        # Search for prey: explore around a random whale
                        rand = np.random.randint(noposs)
                        _x = poss_sols[rand][j]
                    D = abs(C*_x - x)
                    updatedx = _x - A*D
                else:
                    # Spiral (bubble-net) update around the best whale
                    _x = gbest[j].copy()
                    D = abs(_x - x)
                    updatedx = D * math.exp(b*l) * math.cos(2.0 * math.pi * l) + _x

                poss_sols[i][j] = updatedx
            poss_sols[i,:] = boundary(poss_sols[i,:], Lb, Ub)  # boundary check
            fitnessi = fitness(poss_sols[i],data)
            if fitnessi < global_fitness :
                global_fitness = fitnessi
                gbest = poss_sols[i].copy()
        trace.append(global_fitness)
        print("iteration", it+1, "=", global_fitness, [gbest[i] if i==0 else int(gbest[i]) for i in range(len(Lb))])
        trace_pop.append(gbest)
    return gbest, trace,trace_pop
# In[] Data preparation
def data_split(data, steps):
    # Slide a window of length `steps` over the series: each window is an
    # input sample and the value immediately after it is the label.
    in_, out_ = [], []
    samples=len(data)-steps
    for i in range(samples):
        in_.append(data[i:i+steps])
        out_.append(data[i+steps])
    in_=np.array(in_).reshape(len(in_),steps)
    out_=np.array(out_).reshape(len(out_),1)
    return in_,out_
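# e.g. data_split(np.arange(5), 2) -> in_ = [[0,1],[1,2],[2,3]], out_ = [[2],[3],[4]]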
data=pd.read_csv('yantailaizhou8192.csv',engine='python')[['salt']].fillna(0).values
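
# The .mat file loaded below is assumed to come from an earlier optimization
# run; a minimal sketch of that step (hypothetical, matching the keys read
# below) would be:
#   gbest, trace, trace_pop = woa(data)
#   savemat('woa_gru_result.mat', {'result': np.array(trace_pop), 'trace': np.array(trace)})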


mat = loadmat('woa_gru_result.mat')
result = mat['result']
trace = mat['trace'].reshape(-1,)

# In[]
plt.figure()
plt.plot(trace)
plt.title('fitness curve')
plt.xlabel('iteration')
plt.ylabel('fitness value')
plt.savefig('fitness curve.png')
plt.show()


plt.figure()
plt.plot(result[:,0])
plt.title('learning rate optim')
plt.xlabel('iteration')
plt.ylabel('learning rate value')
plt.savefig('lr curve.png')
plt.show()
plt.figure()
plt.plot(result[:,1])
plt.title('training epochs optim')
plt.xlabel('iteration')
plt.ylabel('training epochs value')
plt.savefig('epochs curve.png')
plt.show()


plt.figure()
plt.plot(result[:,2])
plt.title('first hidden nodes optim')
plt.xlabel('iteration')
plt.ylabel('first hidden nodes value')
plt.savefig('first hidden-node curve.png')
plt.show()

plt.figure()
plt.plot(result[:,3])
plt.title('second hidden nodes optim')
plt.xlabel('iteration')
plt.ylabel('second hidden nodes value')
plt.savefig('second hidden-node curve.png')
plt.show()

plt.figure()
plt.plot(result[:,4])
plt.title('batchsize optim')
plt.xlabel('iteration')
plt.ylabel('batchsize value')
plt.savefig('batchsize curve.png')
plt.show()

plt.figure()
plt.plot(result[:,5])
plt.title('steps optim')
plt.xlabel('iteration')
plt.ylabel('steps value')
plt.savefig('steps curve.png')
plt.show()