# -*- coding: utf-8 -*-
"""
Created on Tue Dec 21 08:57:37 2021

@author: wxw
"""
import math
import numpy as np
from scipy.special import comb
from matplotlib import pyplot as plt
from sympy import *

# Symbolic parameters of the exposure model. m, N, k are substituted with
# numbers later via evalf(subs=...) (see n_otp1 / plot_f_by_x); r is the
# variable pe is differentiated with respect to.
m,N,k,r = symbols('m N k r')
# Auxiliary unknown of characteristic_equation below.
x = symbols('x')
# p_e (p_expose): exposure probability as a function of r.
# Same formula as the numeric helper f(m, N, k, r).
pe = (1-(1-r/m)**N)**k
# First, second and third derivatives of pe with respect to r.
# These are simplified symbolically at import time.
dpedr = simplify(diff(pe,r))
dpedr2 = simplify(diff(pe,r,2))
dpedr3 = simplify(diff(pe,r,3))
# dpedr4 = simplify(diff(pe,r,4))
# curvature_pe = abs(dpedr2)/(1+dpedr**2)**1.5
# defence_point = solve(dpedr3,r)[1]
# Closed-form points on the r axis. The quantity raised to 1/N in each is one
# root x of characteristic_equation, given by the quadratic formula:
# discriminant = N**2*(N-1)*(k-1)*(5*N*k-7-N-k).
# defence_point takes the "+" root and attack_point the "-" root. Each root is
# mapped back through x = (1 - r/m)**N, i.e. r = m*(1 - x**(1/N)).
# NOTE(review): presumably these come from dpedr3 = 0 (see the commented
# solve above) -- confirm against the derivation.
defence_point = m*(1-(((N-1)*(3*N*k-4-N)+N*((N-1)*(k-1)*(5*N*k-7-N-k))**0.5)/(2*(N*k-1)*(N*k-2)))**(1/N))
# Defence point expressed as a fraction of m.
defence_point_rate = defence_point/m
attack_point = m*(1-(((N-1)*(3*N*k-4-N)-N*((N-1)*(k-1)*(5*N*k-7-N-k))**0.5)/(2*(N*k-1)*(N*k-2)))**(1/N))
# Attack point expressed as a fraction of m.
attack_point_rate = attack_point/m
# Quadratic in x whose two roots appear inside defence_point / attack_point.
characteristic_equation= (N*k-1)*(N*k-2)*x**2 - (N-1)*(3*N*k-N-4)*x + (N-1)*(N-2)
# attack_point = solve(|dpedr2|-|dpedr|,r)[1]

def w(n,α,r):
    """Binomial probability of exactly ``r`` successes in ``n`` trials.

    Each trial succeeds with probability ``α``; the coefficient comes from
    ``scipy.special.comb`` (floating-point, not exact integers).
    """
    ways = comb(n, r)
    p_hit = α**r
    p_miss = (1 - α)**(n - r)
    return ways * p_hit * p_miss

def f(m,N,k,r):
    """Numeric exposure probability ``(1 - (1 - r/m)**N)**k``.

    Same formula as the symbolic ``pe``; ``r`` may be a non-integer.
    """
    untouched = (1 - r/m)**N
    return (1 - untouched)**k

def comb_gama(n,r):
    """Binomial coefficient C(n, r) generalized to real arguments via gamma.

    Computes ``gamma(n+1) / gamma(r+1) / gamma(n-r+1)``, so ``n`` and ``r``
    need not be integers (ps11 evaluates it at fractional points).

    Returns 0 when ``r`` lies outside ``[0, n]``. The binomial weight has no
    mass there. Without this guard, a negative integer ``r`` would hit a pole
    of math.gamma and raise ValueError.

    If math.gamma overflows (n around 171 and above), the value is recomputed
    in log space with math.lgamma. An OverflowError then only propagates when
    the coefficient itself exceeds the float range.
    """
    if r < 0 or r > n:
        return 0
    # The expression is symmetric in r and n - r; evaluate with the smaller one.
    if r > n - r:
        r = n - r
    try:
        return math.gamma(n+1)/math.gamma(r+1)/math.gamma(n-r+1)
    except OverflowError:
        # Here 0 <= r <= n, so every lgamma argument is >= 1 and the gamma
        # values are positive: exp(lgamma(...)) recovers them without sign loss.
        return math.exp(math.lgamma(n+1) - math.lgamma(r+1) - math.lgamma(n-r+1))

def w_gama(n,α,r):
    """Binomial weight like ``w``, but using the gamma-based ``comb_gama``.

    This allows non-integer ``n`` / ``r``.
    """
    p_miss = 1 - α
    coefficient = comb_gama(n, r)
    return coefficient * α**r * p_miss**(n - r)

def ps1(m,N,k,α,n):
    """Type-I success probability: sum over r = 1..n of ``w(n, α, r) * f(m, N, k, r)``.

    The r = 0 term is omitted; f(m, N, k, 0) is 0 for k > 0.
    """
    return sum(w(n, α, hits) * f(m, N, k, hits) for hits in range(1, n + 1))

def ps1_gama(m,N,k,α,n):
    """Same sum as ps1, but weighted with the gamma-based ``w_gama``.

    ``n`` may be non-integer. The summation index runs over 1..int(n), while
    the weight still receives the original ``n``.
    """
    top = int(n)
    return sum(w_gama(n, α, hits) * f(m, N, k, hits) for hits in range(1, top + 1))

def compare(m,N,k,α,n):
    """Return ``(approximation, exact)`` for the type-I success probability.

    The approximation evaluates f at the binomial mean ``n*α``. The exact
    value is the full binomial-weighted sum from ps1.
    """
    mean_hits = n*α
    approx = f(m, N, k, mean_hits)
    exact = ps1(m, N, k, α, n)
    return approx, exact

def ps11(m,N,k,α,n,dr):
    """Approximate the type-I success probability using terms near the mode.

    Sums ``w_gama(n, α, r) * f(m, N, k, r)`` at ``rmax = α*(n+1)`` and at
    ``rmax ± i`` for i = 1..dr, i.e. 2*dr + 1 points in total.

    Points outside the binomial support ``[0, n]`` contribute 0.
    """
    # Named ``term`` (not ``f``): binding a local ``f`` would make the inner
    # call resolve to this local function instead of the module-level f().
    def term(r):
        # Skip points outside [0, n]: they carry no binomial mass, and a
        # negative integer r would otherwise hit a pole of math.gamma.
        if r < 0 or r > n:
            return 0.0
        return w_gama(n, α, r) * f(m, N, k, r)

    rmax = α*(n+1)
    ans = term(rmax)
    for i in range(1, dr+1):
        ans += term(rmax+i) + term(rmax-i)
    return ans


def n_otp1(m0,N0,k0,α0):
    """Candidate optimal n for the type-I attack.

    Computes the symbolic ``attack_point`` evaluated at (m0, N0, k0), divided
    by ``α0``, and caps the result at ``m0``. The result may be a SymPy
    number.
    """
    substitutions = {m: m0, N: N0, k: k0}
    attack_r = attack_point.evalf(subs=substitutions)
    n_attack = attack_r / α0
    return min(n_attack, m0)

def ps1_opt(m,N,k,α):
    """Type-I success probability at the n suggested by n_otp1 (rounded to an int)."""
    best_n = n_otp1(m, N, k, α)
    return ps1(m, N, k, α, int(round(best_n)))

def ps2(m,N,k,α,n):
    """Type-II success probability when ``n`` nodes are attacked.

    Returns ``f(m, N, k, n) * α**n``: the exposure probability with all ``n``
    attacked nodes counted, weighted by ``α**n``. This is presumably the
    chance that all n attempts succeed, each with probability α as in w().
    """
    # Evaluate f at r = n. A bare ``r`` here is not a local name; it would
    # resolve to a module-level object instead of a number.
    return f(m,N,k,n)*α**n

def n_otp2(m,N,k,α):
    """Closed-form optimal number of attacked nodes for the type-II attack.

    Returns ``round(k / (k*(N-1)/m - ln α))``, capped at ``m`` (n_otp1 caps
    its result the same way). The closed form itself is unbounded: e.g. N=1
    with α close to 1 gives n >> m. That would put r/m > 1 in the exposure
    formula f, outside the model's domain.

    Raises ValueError (from math.log) for α <= 0. Raises ZeroDivisionError
    when the denominator is zero (e.g. N=1 and α=1).
    """
    n_opt = round(k/(k*(N-1)/m-math.log(α)))
    return min(n_opt, m)

def ps2_opt(m,N,k,α):
    """Type-II success probability evaluated at the n suggested by n_otp2."""
    return ps2(m, N, k, α, n_otp2(m, N, k, α))

def g(express, para):
    """Numerically evaluate a SymPy expression with ``para`` substituted.

    ``para`` maps symbols to values; it is passed straight to ``evalf(subs=...)``.
    """
    evaluated = express.evalf(subs=para)
    return evaluated

def plot_f_by_x(
        func_list,x_list,para,x_symbol,normalization = False,
        label_list=None,xlabel=None,ylabel=None,
        title=None,save=False,savename='1.svg'
        ):
    """Evaluate SymPy expressions while sweeping one symbol, then plot them.

    Parameters
    ----------
    func_list : expressions to evaluate (one curve each).
    x_list : values substituted for ``x_symbol``; also used as the x-axis.
    para : dict of fixed symbol values. It is NOT modified.
    x_symbol : the symbol swept over ``x_list``.
    normalization : if True, min-max rescale each curve to [0, 1].
    label_list : optional legend labels, indexed in parallel with func_list.
        The legend is only drawn when this is given.
    xlabel, ylabel, title : passed to matplotlib.
    save : if True, save the figure as SVG to ``savename`` before showing.

    Returns
    -------
    ``(x_list, y_list)`` where y_list holds one numpy array per expression.
    """
    y_list = []
    for func in func_list:
        # Build a fresh substitution dict for every point so the caller's
        # ``para`` is never mutated (x_symbol overrides any entry in para).
        y = np.array([g(func, {**para, x_symbol: xi}) for xi in x_list])
        if normalization:
            y = (y-min(y))/(max(y)-min(y))
        y_list.append(y)
    if label_list:
        for i, y in enumerate(y_list):
            plt.plot(x_list, y, label=label_list[i])
        plt.legend()
    else:
        for y in y_list:
            plt.plot(x_list, y)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    if save:
        plt.savefig(savename,format='svg')
    plt.show()
    return (x_list,y_list)

# #1
# m0 = 100
# N0 = 10
# k0 = 10
# para = {m:m0,N:N0,k:k0}
# plot_f_by_x([pe,dpedr,dpedr2],range(m0+1),para,r,
#             label_list=['pe','|dpedr|','|dpedr2|'],xlabel='r',ylabel='p',
#             title=str(para),normalization=True,save=True)

# #2
# k0 = 10
# para = {k:k0}
# plot_f_by_x([defence_point_rate,attack_point_rate],range(2,100),para,N,
#             label_list=['ω_defence','ω_attack'],xlabel='N',ylabel='ω',
#             title=str(para),save=True)
# N0 = 10
# para = {N:N0}
# plot_f_by_x([defence_point_rate,attack_point_rate],range(2,100),para,k,
#             label_list=['ω_defence','ω_attack'],xlabel='k',ylabel='ω',
#             title=str(para),save=True)

# plt.plot(x,y)    
# plt.xlabel('')
# plt.ylabel('')
# plt.title('')
# plt.show()

# x = range(1,101)
# y1 = [ps1(100,10,10,0.1,n)[0] for n in x]
# y2 = [ps1(100,10,10,0.3,n)[0] for n in x]
# y3 = [ps1(100,10,10,0.5,n)[0] for n in x]
# y4 = [ps1(100,10,10,0.7,n)[0] for n in x]
# y5 = [ps1(100,10,10,0.9,n)[0] for n in x]
# plt.plot(x,y1,label = 'α=0.1') 
# plt.plot(x,y2,label = 'α=0.3')   
# plt.plot(x,y3,label = 'α=0.5')
# plt.plot(x,y4,label = 'α=0.7')
# plt.plot(x,y5,label = 'α=0.9')
# plt.legend()
# plt.xlabel('n')
# plt.ylabel('p_success_I')
# plt.title('{m:100,N:10,k:10}')
# plt.savefig('1.svg',format='svg')
# plt.show()

# x = range(1,101)
# y1 = [ps11(100,10,10,0.1,n,5) for n in x]
# y2 = [ps11(100,10,10,0.3,n,5) for n in x]
# y3 = [ps11(100,10,10,0.5,n,5) for n in x]
# y4 = [ps11(100,10,10,0.7,n,5) for n in x]
# y5 = [ps11(100,10,10,0.9,n,5) for n in x]
# plt.plot(x,y1,label = 'α=0.1') 
# plt.plot(x,y2,label = 'α=0.3')   
# plt.plot(x,y3,label = 'α=0.5')
# plt.plot(x,y4,label = 'α=0.7')
# plt.plot(x,y5,label = 'α=0.9')
# plt.legend()
# plt.xlabel('n')
# plt.ylabel('p_success_I_appro')
# plt.title('{m:100,N:10,k:10}')
# plt.savefig('1.svg',format='svg')
# plt.show()

# x = range(1,101)
# y_list = []
# for i in range(10):
#     y_list.append([ps11(100,10,10,0.9,n,i) for n in x])
# for i in range(10):
#     plt.plot(x,y_list[i],label = 'dr='+str(i)) 
# plt.legend()
# plt.xlabel('n')
# plt.ylabel('p_success_I_appro')
# plt.title('{m:100,N:10,k:10,α:0.9}')
# plt.savefig('1.svg',format='svg')
# plt.show()

# x = np.linspace(0.001,1,1000)
# y = [ps1_opt(100,10,10,α) for α in x]
# plt.plot(x,y)    
# plt.xlabel('α')
# plt.ylabel('p_success_I_otp')
# plt.title('{m:100,N:10,k:10}')
# plt.savefig('1.svg',format='svg')
# plt.show()

# x = range(1,101)
# y = [ps1_opt(100,N,10,0.2) for N in x]
# plt.plot(x,y)    
# plt.xlabel('N')
# plt.ylabel('p_success_I_otp')
# plt.title('{m:100,k:10,α:0.2}')
# plt.savefig('1.svg',format='svg')
# plt.show()

# x = range(10,201)
# y = [ps1_opt(m,10,10,0.2) for m in x]
# plt.plot(x,y)    
# plt.xlabel('m')
# plt.ylabel('p_success_I_otp')
# plt.title('{N:10,k:10,α:0.2}')
# plt.savefig('1.svg',format='svg')
# plt.show()

# x = range(1,101)
# y = [ps1_opt(100,10,k,0.2) for k in x]
# plt.plot(x,y)    
# plt.xlabel('k')
# plt.ylabel('p_success_I_otp')
# plt.title('{m:100,N:10,α:0.2}')
# plt.savefig('1.svg',format='svg')
# plt.show()

# x = range(1,101)
# y = [ps1(100,N,10,0.2,100)[0] for N in x]
# plt.plot(x,y)    
# plt.xlabel('N')
# plt.ylabel('p_success_I_otp')
# plt.title('{m:100,k:10,α:0.2}')
# plt.show()

# x = range(1,101)
# y = [ps2_opt(100,N,10,0.1) for N in x]
# plt.plot(x,y)    
# plt.xlabel('N')
# plt.ylabel('p_success_II_otp')
# plt.title('{m:100,k:10,α:0.1}')
# plt.show()

# x = range(1,101)
# y = [ps2_opt(100,10,k,0.1) for k in x]
# plt.plot(x,y)    
# plt.xlabel('k')
# plt.ylabel('p_success_II_otp')
# plt.title('{m:100,N:10,α:0.1}')
# plt.show()

# x = np.linspace(0.001,0.999,1000)
# y = [ps2_opt(100,10,10,α) for α in x]
# plt.plot(x,y)    
# plt.xlabel('α')
# plt.ylabel('p_success_II_otp')
# plt.title('{m:100,N:10,K:10}')
# plt.show()

# x = range(10,201)
# y = [ps2_opt(m,10,10,0.1) for m in x]
# plt.plot(x,y)    
# plt.xlabel('m')
# plt.ylabel('p_success_II_otp')
# plt.title('{N:10,K:10,α:0.1}')
# plt.show()

# x = range(1,101)
# y_list = []
# for i in range(10):
#     y_list.append([ps11(100,10,10,0.9,n,i) for n in x])
# for i in range(10):
#     plt.plot(x,y_list[i],label = 'dr='+str(i)) 
# plt.legend()
# plt.xlabel('n')
# plt.ylabel('p_success_I_appro')
# plt.title('{m:100,N:10,k:10,α:0.9}')
# plt.savefig('1.svg',format='svg')
# plt.show()

# x = range(1,101)
# for i in range(1,10):
#     y1 = [compare(100,10,10,0.1*i,n)[0] for n in x]
#     y2 = [compare(100,10,10,0.1*i,n)[1] for n in x]
#     plt.plot(x,y1,label = 'appro_p')
#     plt.plot(x,y2,label = 'real_p')
#     # plt.legend()
# plt.xlabel('n')
# plt.ylabel('p')
# plt.title('{m:100,N:10,k:10}')
# # plt.savefig('1.svg',format='svg')
# plt.show()
