import socket
import json
from game_state import GameState
#from bot import fight
import sys
from keras.models import load_model
from bot import Bot
from command import Command
import numpy as np
from buttons import Buttons
from collections import defaultdict
import PSO

def connect(port):
    """Wait for the game to connect on localhost and return the client socket.

    Binds a TCP listening socket to 127.0.0.1:``port``, blocks until one
    client connects, and returns the accepted connection.

    Args:
        port: TCP port number to listen on.

    Returns:
        The connected client socket returned by ``accept()``.
    """
    # The listening socket is only needed until the game has connected.
    # Using it as a context manager closes it afterwards so its descriptor
    # and the bound port are not leaked; the accepted socket stays open.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
        server_socket.bind(("127.0.0.1", port))
        server_socket.listen(5)
        client_socket, _ = server_socket.accept()
    print ("Connected to game!")
    return client_socket

def send(client_socket, command):
    """Send the updated command to Bizhawk so the game reacts to it.

    The command is converted with its ``object_to_dict()`` method,
    serialized as JSON, encoded to bytes and written with ``sendall``.

    Args:
        client_socket: Connected socket-like object providing ``sendall``.
        command: Object providing ``object_to_dict()``.
    """
    serialized = json.dumps(command.object_to_dict())
    client_socket.sendall(serialized.encode())

def receive(client_socket):
    """Receive one game-state message and return it as a ``GameState``.

    Reads up to 4096 bytes from the socket in a single ``recv`` call,
    decodes them as JSON and wraps the resulting dict in ``GameState``.

    Args:
        client_socket: Connected socket to read from.

    Returns:
        A ``GameState`` built from the decoded JSON dict.
    """
    # NOTE(review): a single recv assumes the whole JSON message arrives in
    # one read of at most 4096 bytes -- confirm against the sender.
    raw_bytes = client_socket.recv(4096)
    return GameState(json.loads(raw_bytes.decode()))

def get_train_example( current_game_state ,   previous_game_state  ):
    """Build one (features, reward) training sample from two game states.

    The reward is ``3 * damage_done - damage_received`` plus the reduction in
    horizontal distance between the players since the previous state,
    adjusted by the fight result of the current state: -1 while the fight is
    "NOT_OVER", +200 when "P1" won and -200 when "P2" won.

    Args:
        current_game_state: Latest game state (players, timer, flags).
        previous_game_state: Earlier state; only both players' ``health``
            and ``x_coord`` are read from it.

    Returns:
        A tuple ``(features, [reward])`` where ``features`` is a list of 28
        ints: player 1's ten pressed buttons, four round-level values, eight
        player 1 values and six player 2 values.

    Raises:
        KeyError: If ``fight_result`` is not "NOT_OVER", "P1" or "P2".
    """
    fight_result_int = {"NOT_OVER": 0, "P1": 1, "P2": 2}

    # Move ids not listed here are encoded as 0 (defaultdict(int) default).
    move_id_dict = defaultdict(
        int, {0: 1, 1: 2, 131073: 3, 1025: 4, 33686532: 5, 33685508: 6})

    fight_result = current_game_state.fight_result
    result_code = fight_result_int[fight_result]

    p1 = current_game_state.player1
    p2 = current_game_state.player2
    prev_p1 = previous_game_state.player1
    prev_p2 = previous_game_state.player2

    damage_done = int(prev_p2.health) - int(p2.health)
    damage_received = int(prev_p1.health) - int(p1.health)

    # Positive when the players moved closer together since the last state.
    previous_distance = abs(int(prev_p1.x_coord) - int(prev_p2.x_coord))
    current_distance = abs(int(p1.x_coord) - int(p2.x_coord))
    distance_reward = previous_distance - current_distance

    reward = (3 * damage_done - damage_received) + distance_reward

    # Compare the result string itself. The previous code compared the
    # integer code against these strings, so none of the branches could
    # ever run and the penalty/bonus was silently dropped.
    if fight_result == "NOT_OVER":
        reward -= 1
    elif fight_result == "P1":
        reward += 200
    elif fight_result == "P2":
        reward -= 200

    buttons = p1.player_buttons
    current_state_array = [
        # Buttons currently pressed by player 1.
        int(buttons.up),
        int(buttons.down),
        int(buttons.right),
        int(buttons.left),
        int(buttons.Y),
        int(buttons.B),
        int(buttons.X),
        int(buttons.A),
        int(buttons.L),
        int(buttons.R),

        # Round-level information.
        int(current_game_state.timer),
        result_code,
        int(current_game_state.has_round_started),
        int(current_game_state.is_round_over),

        # Player 1 (health is deliberately not part of the features).
        int(p1.x_coord),
        int(p1.y_coord),
        int(p1.is_jumping),
        int(p1.is_crouching),
        int(p1.is_player_in_move),
        move_id_dict[int(p1.move_id)],
        int(buttons.select),
        int(buttons.start),

        # Player 2 (health is deliberately not part of the features).
        int(p2.x_coord),
        int(p2.y_coord),
        int(p2.is_jumping),
        int(p2.is_crouching),
        int(p2.is_player_in_move),
        move_id_dict[int(p2.move_id)],
    ]

    return (current_state_array, [reward])

def get_predict_example( act , current_game_state ):
    """Build the model input for a candidate action in the current state.

    The layout is: ten candidate button flags (buttons missing from ``act``
    count as not pressed), four round-level values, eight player 1 values
    and six player 2 values -- 28 ints in total.

    Args:
        act: Mapping of button name to pressed flag for the candidate action.
        current_game_state: Game state the action would be taken in.

    Returns:
        A list of 28 ints.
    """
    result_codes = {"NOT_OVER": 0, "P1": 1, "P2": 2}
    known_moves = {0: 1, 1: 2, 131073: 3, 1025: 4, 33686532: 5, 33685508: 6}
    pressed = dict(act)

    def encode_move(raw_move_id):
        # Move ids that are not listed are encoded as 0.
        return known_moves.get(int(raw_move_id), 0)

    state = current_game_state
    me = state.player1
    foe = state.player2

    button_order = ("up", "down", "right", "left", "Y", "B", "X", "A", "L", "R")
    features = [int(pressed.get(name, False)) for name in button_order]

    # Round-level information.
    features.extend([
        int(state.timer),
        result_codes[state.fight_result],
        int(state.has_round_started),
        int(state.is_round_over),
    ])

    # Player 1: position, posture, current move, select/start buttons.
    features.extend([
        int(me.x_coord),
        int(me.y_coord),
        int(me.is_jumping),
        int(me.is_crouching),
        int(me.is_player_in_move),
        encode_move(me.move_id),
        int(me.player_buttons.select),
        int(me.player_buttons.start),
    ])

    # Player 2: position, posture, current move.
    features.extend([
        int(foe.x_coord),
        int(foe.y_coord),
        int(foe.is_jumping),
        int(foe.is_crouching),
        int(foe.is_player_in_move),
        encode_move(foe.move_id),
    ])
    return features
        
def get_all_possible_Action(action):
    """Return every pressed/released combination of the given buttons.

    Args:
        action: Sequence of button names, e.g. ``["up", "Y"]``.

    Returns:
        A list of ``2 ** len(action)`` dicts mapping each button name to a
        bool. Combinations are ordered like a binary counter with the first
        button as the most significant bit, starting from all-False. An
        empty ``action`` yields ``[{}]`` (the single "press nothing" case).
    """
    # Function-scope import keeps this block self-contained.
    from itertools import product

    # The previous single-button case built {name: True, name: False}; the
    # duplicate key collapsed to {name: False}, so a lone button could never
    # be pressed. Enumerating with product() covers any number of buttons.
    return [dict(zip(action, flags))
            for flags in product((False, True), repeat=len(action))]
    
def update_button(act):
    """Create a ``Buttons`` object reflecting the given action.

    Args:
        act: Mapping of button name to pressed flag. Buttons missing from
            the mapping are set to False.

    Returns:
        A new ``Buttons`` instance with up/down/right/left/Y/B/X/A/L/R set.
    """
    requested = dict(act)
    buttn = Buttons()
    for name in ("up", "down", "right", "left", "Y", "B", "X", "A", "L", "R"):
        setattr(buttn, name, requested.get(name, False))
    return buttn

def main():
    """Play one round against the game, choosing actions with the model.

    ``sys.argv[1]`` selects the player: '1' listens on port 9999 and '2' on
    port 10000. Each loop iteration receives a game state and sends back a
    command. Every 20th iteration (while the round has started and the fight
    is not over) a training sample is published to the ``PSO`` module and a
    new action is chosen, alternating between:

    * the best of 30 random candidate actions as scored by the model loaded
      with the swarm's best weights, and
    * a uniformly random action.

    When the round is over the model and the swarm are saved to disk.

    Raises:
        SystemExit: If the player argument is missing or not '1' / '2'.
    """
    # Validate the argument up front; previously a missing or unknown value
    # surfaced as an IndexError / unbound ``client_socket``.
    ports_by_player = {'1': 9999, '2': 10000}
    if len(sys.argv) < 2 or sys.argv[1] not in ports_by_player:
        raise SystemExit("Usage: python %s <1|2>" % sys.argv[0])
    client_socket = connect(ports_by_player[sys.argv[1]])

    try:
        dummy_actions_list = [["left"],["right"],["up" , "Y"],["up" , "B"],["up" , "X"],["up" , "L"],["up" , "R"],["up" , "A"],
                              ["down" , "Y"],["down" , "B"],["down" , "X"],["down" , "L"],["down" , "R"],["down" , "A"],
                              ["left" , "Y"],["left" , "B"],["left" , "X"],["left" , "L"],["left" , "R"],["left" , "A"],
                              ["right" , "Y"],["right" , "B"],["right" , "X"],["right" , "L"],["right" , "R"],["right" , "A"],
                              ["up","left","Y"],["up","left","B"],["up","left","X"],["up","left","L"],["up","left","R"],["up","left","A"],
                              ["up","right","Y"],["up","right","B"],["up","right","X"],["up","right","L"],["up","right","R"],["up","right","A"],
                              ["down","left","Y"],["down","left","B"],["down","left","X"],["down","left","L"],["down","left","R"],["down","left","A"],
                              ["down","right","Y"],["down","right","B"],["down","right","X"],["down","right","L"],["down","right","R"],["down","right","A"]]

        # The candidate actions never change, so expand them once instead of
        # on every training frame.
        filled_actions = []
        for dummy_action in dummy_actions_list:
            filled_actions += get_all_possible_Action(dummy_action)

        current_game_state = None
        # NOTE(review): ``bot`` is not used below; kept in case constructing
        # Bot has side effects -- confirm and remove if not.
        bot = Bot()
        model = load_model("afd1.h5")
        PSO.model = model
        # Only needed when creating a brand-new swarm, e.g.
        # PSO.PSO(layers_shape=layers_shape, total_sub_swarms=10, each_swarm_size=40)
        layers_shape = [lw.shape for lw in model.get_weights()]
        first_game = 1
        pso_model = PSO.PSO(dump_file="pso_weights.pkl")

        previous_current_state = None
        start_train = 0
        turn = 0
        bot_command = Command()
        buttn = Buttons()
        while (current_game_state is None) or (not current_game_state.is_round_over):

            current_game_state = receive(client_socket)
            if previous_current_state is None:
                previous_current_state = current_game_state
            elif (current_game_state.has_round_started==True) and (current_game_state.fight_result=="NOT_OVER") and start_train%20==0:
                start_train = 0
                x_example , reward = get_train_example( current_game_state, previous_current_state )

                x_example = np.array(x_example)
                print( x_example[np.newaxis,:].shape , np.array(reward).shape )
                # Publish the newest sample through the PSO module globals.
                PSO.data = x_example[np.newaxis,:]
                PSO.target = np.array(reward)
                if first_game:
                    first_game -= 1
                    # Reload the swarm once, after the first sample has been
                    # published above.
                    pso_model = PSO.PSO(dump_file="pso_weights.pkl")

                previous_current_state = current_game_state

                if turn%2==0:
                    turn = 0

                    # The best weights do not change while scoring the
                    # candidates, so load them into the model once.
                    model.set_weights( pso_model.get_best_weigts() )

                    action_reward = []
                    act_indx = []
                    # randint's upper bound is exclusive, which gives the
                    # same range as the deprecated
                    # random_integers(0, len(filled_actions) - 1).
                    for candidate in np.random.randint(0, len(filled_actions), size=(30,)):
                        act_indx.append(candidate)
                        X_test = get_predict_example( filled_actions[candidate] , current_game_state )
                        action_reward.append(model.predict( np.array(X_test)[np.newaxis,:]  ))
                        print( action_reward[-1] )

                    # Press the buttons of the candidate with the highest
                    # predicted reward.
                    buttn = update_button( filled_actions[ act_indx[ np.argmax( action_reward )] ])
                else:
                    # Every other decision is a uniformly random action.
                    act = np.random.randint(0, len(filled_actions) )
                    buttn = update_button( filled_actions[ act ] )

            turn += 1
            start_train += 1

            # Frames without a new decision resend the previous buttons.
            bot_command.player_buttons = buttn

            send(client_socket, bot_command)

        model.save("afd1.h5")
        pso_model.save( "pso_weights.pkl")
    finally:
        # Release the connection even if the loop fails.
        client_socket.close()
        
        
       
# Run the bot only when this file is executed as a script, not when imported.
if __name__ == '__main__':
   main()
