# from keras.layers import Dense
# from keras.models import Sequential
# from tensorflow.keras.layers import Dense,Sequential
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from bitstring import BitArray
import numpy as np
import pickle



# Module-level state shared with Cal_Fitness. Assign these on the module
# before Cal_Fitness is called (Particle construction calls it).
model = None      # object exposing set_weights(...) and predict(...); presumably a Keras model
data = None       # input handed to model.predict
targets = None    # not read by Cal_Fitness; kept so existing callers that set it still work
# Cal_Fitness reads the singular global `target`, so give it a default too.
target = None

# Compared against the freshly computed error inside Cal_Fitness; starts at +inf.
# np.inf is used because the np.infty alias was removed in NumPy 2.0.
last_best = np.inf

# Counter that Cal_Fitness decrements and returns as the fitness value.
start_mse = 1147483648

def Cal_Fitness( weights  ):
    """Load ``weights`` into the global model and return the fitness value.

    The module-level names ``model``, ``data`` and ``target`` must have been
    assigned before this is called; ``target`` is indexed as
    ``target[:, np.newaxis]``, so it has to support that indexing.

    Args:
        weights: Value passed straight to ``model.set_weights``.

    Returns:
        The module-level ``start_mse`` counter after this call, NOT the error
        computed from the predictions.
        NOTE(review): every caller therefore receives the shared counter
        rather than a per-weights error -- confirm this is intended.
    """
    # Only start_mse is rebound here; the other globals are just read.
    global start_mse

    model.set_weights(weights)

    # Despite the name, this is not a mean *squared* error: it is the mean
    # over rows of sqrt(sum of absolute errors in that row).
    prediction_error = abs(model.predict(data) - target[:, np.newaxis])
    per_row = np.sum(prediction_error, axis=1) ** 0.5
    mse = np.mean(per_row)

    # Shrink the shared counter by 0.1% of the error while the error is below
    # last_best and the counter is still positive.
    if mse < last_best and start_mse > 0:
        start_mse -= 0.001 * mse

    return start_mse


class Particle():
    """One PSO particle: a flat vector holding every weight of the network.

    Attributes:
        layers_shape: Sequence of array shapes, one per weight array.
        particle_size: Total number of scalars across all shapes.
        position: Current flat weight vector.
        velocity: Current flat velocity vector.
        fitness: Fitness of ``position`` as returned by ``Cal_Fitness``.
        local_best / local_best_fitness: Best position seen by this particle
            and its fitness.
    """

    def __init__(self, layers_shape=None, dump_dict=None):
        """Create a random particle, or restore one from ``dump_to_dict`` output.

        Args:
            layers_shape: Shapes of the weight arrays; used when ``dump_dict``
                is not given. A new particle is evaluated immediately with
                ``Cal_Fitness``.
            dump_dict: Dictionary produced by ``dump_to_dict``. When given,
                ``layers_shape`` is ignored and nothing is evaluated.

        Raises:
            ValueError: If neither ``layers_shape`` nor ``dump_dict`` is given.
        """
        if dump_dict is None:
            if layers_shape is None:
                raise ValueError(
                    "Particle needs either layers_shape or dump_dict")

            self.layers_shape = layers_shape
            # int(...) because np.prod(()) is the float 1.0, which would make
            # the size (and every later slice bound) a float.
            self.particle_size = sum(
                int(np.prod(shape)) for shape in layers_shape)
            # rand() is in [0, 1), so values fall in [-0.0891, 0.0909).
            self.position = (
                np.random.rand(self.particle_size) * 2 - 0.99) * 0.09
            # Copy so the remembered best never shares memory with position.
            self.local_best = self.position.copy()
            self.velocity = (
                np.random.rand(self.particle_size) * 2 - 0.99) * 0.09
            self.fitness = Cal_Fitness(self.Vec2Mat())
            self.local_best_fitness = self.fitness
        else:
            self.layers_shape = dump_dict["layers_shape"]
            self.particle_size = dump_dict["particle_size"]
            self.position = dump_dict["position"]
            self.local_best = dump_dict["local_best"]
            self.velocity = dump_dict["velocity"]
            self.fitness = dump_dict["fitness"]
            self.local_best_fitness = dump_dict["local_best_fitness"]

    def Vec2Mat(self):
        """Split ``position`` into a list of arrays shaped like ``layers_shape``.

        Returns:
            A list with one array per entry of ``layers_shape``, taken from
            consecutive slices of ``position`` in order.
        """
        weights = []
        offset = 0
        for shape in self.layers_shape:
            count = int(np.prod(shape))  # int: np.prod(()) is a float
            chunk = self.position[offset:offset + count]
            weights.append(np.reshape(chunk, shape))
            offset += count

        return weights

    def dump_to_dict(self):
        """Return the particle state as a plain dict accepted by ``__init__``."""
        return {
            "layers_shape": self.layers_shape,
            "particle_size": self.particle_size,
            "position": self.position,
            "local_best": self.local_best,
            "velocity": self.velocity,
            "fitness": self.fitness,
            "local_best_fitness": self.local_best_fitness,
        }
        
        
        
class Swarm():
    """A group of particles sharing one global best position.

    Attributes:
        swarm_size: Number of particles requested for this swarm.
        layers_shape: Shapes of the weight arrays each particle encodes.
        swarm: List of ``Particle`` objects.
        C1, C2: Coefficients of the local-best and global-best pull terms.
        global_best / global_best_fitness: Best position found so far and its
            fitness (lower is better).
        fitness_array: Fitness values appended by ``Search`` on improvement.
    """

    def __init__(self, layers_shape=None, swarm_size=30, dump_dict=None):
        """Build a fresh swarm, or restore one from ``dump_to_dict`` output.

        Args:
            layers_shape: Shapes forwarded to every new ``Particle``.
            swarm_size: Number of particles to create.
            dump_dict: Dictionary produced by ``dump_to_dict``; when given the
                other arguments are ignored.
        """
        if dump_dict is None:
            self.swarm_size = swarm_size
            self.layers_shape = layers_shape
            self.swarm = self.Make_Swarm(self.layers_shape, self.swarm_size)
            self.C1 = 2
            self.C2 = 2

            self.global_best, self.global_best_fitness = (
                self.Find_Global_Best())
            self.fitness_array = []
        else:
            self.swarm_size = dump_dict["swarm_size"]
            self.swarm = [
                Particle(dump_dict=p_dict) for p_dict in dump_dict["swarm"]]
            self.layers_shape = dump_dict["layers_shape"]
            self.C1 = dump_dict["C1"]
            self.C2 = dump_dict["C2"]
            self.global_best = dump_dict["global_best"]
            self.global_best_fitness = dump_dict["global_best_fitness"]
            self.fitness_array = dump_dict["fitness_array"]

    def dump_to_dict(self):
        """Return the swarm state (particles included) as a plain dict."""
        return {
            "swarm_size": self.swarm_size,
            "swarm": [particle.dump_to_dict() for particle in self.swarm],
            "layers_shape": self.layers_shape,
            "C1": self.C1,
            "C2": self.C2,
            "global_best": self.global_best,
            "global_best_fitness": self.global_best_fitness,
            "fitness_array": self.fitness_array,
        }

    def Make_Swarm(self, layers_shape, size=30):
        """Return a list of ``size`` freshly initialised particles."""
        return [Particle(layers_shape=layers_shape) for _ in range(size)]

    def Find_Global_Best(self):
        """Evaluate every particle and return ``(position, fitness)`` of the best.

        Each particle is re-evaluated through ``Cal_Fitness`` instead of
        reusing ``particle.fitness``; this is kept so the number and order of
        fitness evaluations stay exactly as before.
        """
        fitness_values = [
            Cal_Fitness(particle.Vec2Mat()) for particle in self.swarm]
        best_index = np.argmin(fitness_values)

        return self.swarm[best_index].position, fitness_values[best_index]

    def _advance_particles(self):
        """Move every particle once and refresh its personal best."""
        for particle in self.swarm:
            R1 = np.random.rand(particle.particle_size)
            R2 = np.random.rand(particle.particle_size)

            # The new velocity is built only from the pulls towards the local
            # and global bests; the previous velocity is not carried over.
            particle.velocity = (
                self.C1 * R1 * (particle.local_best - particle.position)
                + self.C2 * R2 * (self.global_best - particle.position))
            particle.position = particle.position + particle.velocity

            particle.fitness = Cal_Fitness(particle.Vec2Mat())

            if particle.fitness < particle.local_best_fitness:
                particle.local_best_fitness = particle.fitness
                particle.local_best = particle.position

    def _refresh_global_best(self):
        """Adopt the swarm's current best if it beats the stored one.

        Returns:
            True when ``global_best`` was replaced, False otherwise.
        """
        candidate_position, candidate_fitness = self.Find_Global_Best()
        if candidate_fitness < self.global_best_fitness:
            self.global_best_fitness = candidate_fitness
            self.global_best = candidate_position
            return True
        return False

    def Find_Best_Weights(self, epoch=2000, error=0.1):
        """Iterate until ``epoch`` rounds ran or the best fitness is <= ``error``.

        Args:
            epoch: Maximum number of rounds.
            error: Stop as soon as ``global_best_fitness`` is not above this.

        Returns:
            List of ``global_best_fitness`` values recorded at the start of
            each round.
        """
        iteration = 0
        fitness_array = []

        while iteration < epoch and self.global_best_fitness > error:
            if iteration % 2 == 0:
                print("Iteration:", iteration,
                      "   Error:", self.global_best_fitness)

            fitness_array.append(self.global_best_fitness)

            self._advance_particles()
            self._refresh_global_best()

            iteration += 1

        return fitness_array

    def Search(self, epoch=2):
        """Run ``epoch`` update rounds and return the best fitness found.

        On every improvement of the global best, the module-level
        ``last_best`` is set to the new fitness and the value is appended to
        ``self.fitness_array``.
        """
        global last_best

        while epoch > 0:
            self._advance_particles()

            if self._refresh_global_best():
                last_best = self.global_best_fitness
                self.fitness_array.append(self.global_best_fitness)

            epoch -= 1

        return self.global_best_fitness
        
class PSO():
    """Multi-swarm PSO driver: several sub-swarms searched and re-grouped.

    Attributes:
        sub_swarms: List of ``Swarm`` objects.
        total_sub_swarms: Number of sub-swarms requested.
        each_swarm_size: Number of particles per sub-swarm.
    """

    def __init__(self, layers_shape=None, total_sub_swarms=10,
                 each_swarm_size=40, dump_file=None):
        """Create new sub-swarms, or restore the state written by ``save``.

        Args:
            layers_shape: Shapes of the weight arrays, forwarded to ``Swarm``.
            total_sub_swarms: How many sub-swarms to create.
            each_swarm_size: Particles per sub-swarm.
            dump_file: Path of a file written by ``save``; when given the
                other arguments are ignored.
        """
        if dump_file is None:
            self.sub_swarms = [
                Swarm(layers_shape=layers_shape, swarm_size=each_swarm_size)
                for _ in range(total_sub_swarms)
            ]
            self.total_sub_swarms = total_sub_swarms
            self.each_swarm_size = each_swarm_size
            print("We have made ", len(self.sub_swarms), " Swarms")
        else:
            # SECURITY: pickle.load can execute arbitrary code; only load
            # dump files that come from a trusted source.
            with open(dump_file, "rb") as f:
                dump_dict = pickle.load(f)

            self.total_sub_swarms = dump_dict["total_sub_swarms"]
            self.each_swarm_size = dump_dict["each_swarm_size"]
            self.sub_swarms = [
                Swarm(dump_dict=s_dict) for s_dict in dump_dict["sub_swarms"]]

    def save(self, file_name):
        """Pickle the complete optimiser state to ``file_name``."""
        dump_dict = {
            "total_sub_swarms": self.total_sub_swarms,
            "each_swarm_size": self.each_swarm_size,
            "sub_swarms": [swarm.dump_to_dict() for swarm in self.sub_swarms],
        }

        with open(file_name, "wb") as f:
            pickle.dump(dump_dict, f)

    def Merge_Sub_Swarms(self):
        """Pool all particles and hand them back out by distance.

        Sub-swarms are sorted by ``global_best_fitness`` (best first). Each
        one in turn takes the ``swarm_size`` remaining particles whose
        positions are closest (Euclidean distance) to its ``global_best``.
        """
        all_particles = []
        for swarm in self.sub_swarms:
            all_particles.extend(swarm.swarm)

        self.sub_swarms = sorted(
            self.sub_swarms, key=lambda swarm: swarm.global_best_fitness)

        if not all_particles:
            # Nothing to redistribute.
            return

        all_particles_positions = np.array(
            [particle.position for particle in all_particles])
        # Explicit object array so the particles are stored as opaque items
        # and can be selected/removed by index arrays.
        particle_pool = np.empty(len(all_particles), dtype=object)
        particle_pool[:] = all_particles

        print(all_particles_positions.shape, all_particles_positions[0].shape)
        for swarm in self.sub_swarms:
            distances = np.sum(
                (all_particles_positions - swarm.global_best) ** 2,
                axis=1) ** 0.5
            nearest = np.argsort(distances)[:swarm.swarm_size]
            swarm.swarm = list(particle_pool[nearest])
            # Remove the assigned particles so later swarms cannot reuse them.
            particle_pool = np.delete(particle_pool, nearest)
            all_particles_positions = np.delete(
                all_particles_positions, nearest, axis=0)

    def fit(self, epoch=1000, error=0.1):
        """Run ``epoch`` rounds of per-swarm search followed by a merge.

        Args:
            epoch: Number of rounds to run.
            error: Accepted for interface compatibility; not used here.

        Returns:
            List with the lowest fitness reported by any sub-swarm in each
            round.
        """
        iteration = 0
        fitness_array = []

        while iteration < epoch:
            swarm_best_fitness = [swarm.Search() for swarm in self.sub_swarms]

            self.Merge_Sub_Swarms()

            round_best = min(swarm_best_fitness)
            fitness_array.append(round_best)
            print("Iteration:", iteration, "   Error:", round_best)
            iteration += 1

        return fitness_array

    def Vec2Mat(self, position, layers_shape):
        """Split the flat ``position`` into arrays shaped like ``layers_shape``."""
        weights = []
        offset = 0
        for shape in layers_shape:
            count = int(np.prod(shape))  # int: np.prod(()) is the float 1.0
            weights.append(
                np.reshape(position[offset:offset + count], shape))
            offset += count

        return weights

    def get_best_weigts(self):
        """Return the best global position of any sub-swarm as weight arrays.

        Raises:
            ValueError: If no sub-swarm has a fitness below infinity
                (including the case of no sub-swarms at all).
        """
        # np.inf instead of np.infty: the alias was removed in NumPy 2.0.
        best_fitness = np.inf
        best_swarm = None
        for swarm in self.sub_swarms:
            if swarm.global_best_fitness < best_fitness:
                best_fitness = swarm.global_best_fitness
                best_swarm = swarm

        if best_swarm is None:
            raise ValueError("no sub-swarm with a finite best fitness")

        return self.Vec2Mat(best_swarm.global_best, best_swarm.layers_shape)

    # Correctly spelled alias; the original name is kept for existing callers.
    get_best_weights = get_best_weigts