

class MultiSetHash:
    """Order-independent hash over a multiset of objects.

    Elements are identified by ``str(obj)``. Every ``add`` call records one
    SHA-256 hex digest derived from the element and its running occurrence
    count; ``digest`` hashes the sorted concatenation of those digests, so the
    result does not depend on insertion order.
    """

    def __init__(self):
        # Maps str(obj) -> number of times that element has been added.
        self.counts = {}
        # One hex digest per add() call, in insertion order.
        self.hashes = []

    def add(self, obj):
        """Add one occurrence of ``obj`` to the multiset."""
        obj_str = str(obj)
        count = self.counts.get(obj_str, 0) + 1
        self.counts[obj_str] = count
        # Length-prefix the element so the (element, count) pair is encoded
        # unambiguously. Plain concatenation maps ("a", 11) and ("a1", 1) to
        # the same string "a11", which lets two different multisets collide.
        encoded = "{}:{}:{}".format(len(obj_str), obj_str, count).encode()
        self.hashes.append(hashlib.sha256(encoded).hexdigest())

    def digest(self):
        """Return the hex SHA-256 digest of the current multiset."""
        # Sorting removes any dependence on the order of add() calls.
        combined = ''.join(sorted(self.hashes))
        return hashlib.sha256(combined.encode()).hexdigest()

def __gen_prime__(rs):
    """Return a prime found by scanning upward from a random 128-bit draw.

    ``rs`` is the gmpy2 random state the starting value is drawn from. The
    start is incremented by one until ``gy.is_prime`` accepts it.
    """
    candidate = gy.mpz_urandomb(rs, 128)
    while True:
        if gy.is_prime(candidate):
            return candidate
        candidate += 1
def KeyGen1():
    """Generate a modulus ``n = p * q`` and the value ``g = n + 3``.

    Returns:
        list: ``[g, n]``.
    """
    # Seed from OS entropy instead of the wall clock: a seed of
    # int(time.time()) can be recovered by anyone who knows roughly when the
    # key was generated, and two calls in the same second yield the same key.
    # NOTE(review): gmpy2's generator is presumably not a CSPRNG; this only
    # removes the guessable seed.
    rs = gy.random_state(int.from_bytes(os.urandom(32), "big"))
    p = __gen_prime__(rs)
    q = __gen_prime__(rs)
    n = p * q
    g = n + 3
    return [g, n]

def xor_encrypt(tips, key):
    """XOR every item of ``tips`` with the repeating ``key``.

    Both arguments are indexed as sequences of integers (e.g. ``bytes``).
    Each XOR result is rendered in decimal and the pieces are concatenated,
    so two '0'/'1' byte strings produce a '0'/'1' text string.
    """
    key_len = len(key)
    # pos % key_len restarts the key from its beginning once it is exhausted.
    return "".join(
        str(item ^ key[pos % key_len]) for pos, item in enumerate(tips)
    )
def xor_decrypt(secret, key):
    """Undo the repeating-key XOR applied to ``secret``.

    ``secret`` and ``key`` are indexed as sequences of integers (e.g.
    ``bytes``). The decimal form of each XOR result is concatenated into the
    returned string.
    """
    key_len = len(key)
    pieces = []
    for offset, value in enumerate(secret):
        # Wrap around to the start of the key when it is shorter than secret.
        key_item = key[offset % key_len]
        pieces.append(str(value ^ key_item))
    return "".join(pieces)

def aes_encrypt(data, key):
    """Encrypt ``data`` with AES in CBC mode and return the result hex-encoded.

    ``data`` is right-padded with zero bytes up to a multiple of the AES block
    size before encryption. The return value is the hex form of the
    ciphertext as ``bytes``.

    NOTE(review): the key is also passed as the IV, and zero padding cannot
    be told apart from trailing zero bytes in the input.
    """
    cipher = AES.new(key, AES.MODE_CBC, key)
    # Number of bytes missing to reach the next block boundary (0 if aligned).
    pad_len = -len(data) % AES.block_size
    data += b'\0' * pad_len
    return binascii.hexlify(cipher.encrypt(data))

class Paillier(object):
    """Paillier cryptosystem on gmpy2 integers.

    ``pubKey`` holds ``[n, g]`` and ``priKey`` holds ``[lmd, mu]``. Both can be
    passed to the constructor or filled in by ``__key_gen__``.
    """

    # Blinding value used by encipher_N. With rr == 1 the factor rr**n is 1,
    # so encipher_N is deterministic (equal inputs give equal ciphertexts).
    rr = 1

    def __init__(self, pubKey=None, priKey=None):
        self.pubKey = pubKey
        self.priKey = priKey

    def __gen_prime__(self, rs):
        """Return a prime found by scanning upward from a random 128-bit draw."""
        p = gy.mpz_urandomb(rs, 128)
        while not gy.is_prime(p):
            p += 1
        return p

    def __L__(self, x, n):
        """Return ``(x - 1) / n`` computed with ``gy.div``."""
        # this step is essential, directly using "/" causes bugs
        # due to the floating representation in python
        res = gy.div((x - 1), n)
        return res

    def __key_gen__(self):
        """Generate a key pair and store it in ``pubKey`` / ``priKey``."""
        # Create the random state once, from OS entropy. Re-seeding from
        # int(time.time()) inside the retry loop would regenerate the very
        # same p and q until the clock ticks over, and makes the key
        # guessable from its creation time.
        rs = gy.random_state(int.from_bytes(os.urandom(32), "big"))
        while True:
            p = self.__gen_prime__(rs)
            q = self.__gen_prime__(rs)
            n = p * q
            lmd = (p - 1) * (q - 1)
            # originally, lmd(lambda) is the least common multiple.
            # However, if using p,q of equivalent length, then lmd = (p-1)*(q-1)
            if gy.gcd(n, lmd) == 1:
                # This property is assured if both primes are of equal length
                break
        g = n + 1
        mu = gy.invert(lmd, n)
        # Originally,
        # g would be a random number smaller than n^2,
        # and mu = (L(g^lambda mod n^2))^(-1) mod n
        # Since q, p are of equivalent length, step can be simplified.

        self.pubKey = [n, g]
        self.priKey = [lmd, mu]
        return

    def decipher(self, ciphertext):
        """Decrypt ``ciphertext`` and return the plaintext via ``libnum.n2s``.

        Prints the recovered integer and its ``len`` before converting.
        """
        n, g = self.pubKey
        lmd, mu = self.priKey
        m = self.__L__(gy.powmod(ciphertext, lmd, n ** 2), n) * mu % n
        print("raw message:", m)
        print("raw length:",len(m))
        plaintext = libnum.n2s(int(m))
        return plaintext

    def encipher(self, plaintext):
        """Encrypt ``plaintext`` (converted with ``libnum.s2n``) with fresh randomness."""
        m = libnum.s2n(plaintext)
        n, g = self.pubKey
        n_sq = n ** 2
        # Seed from OS entropy: a clock-derived seed gives every encryption
        # made within the same second the same r.
        rs = gy.random_state(int.from_bytes(os.urandom(32), "big"))
        r = gy.mpz_random(rs, n_sq)
        # r must be coprime to n.
        while gy.gcd(n, r) != 1:
            r += 1
        ciphertext = gy.powmod(g, m, n_sq) * gy.powmod(r, n, n_sq) % n_sq
        return ciphertext

    def encipher_N(self, plaintext):
        """Encrypt the integer ``plaintext`` using the fixed value ``self.rr``.

        No per-call randomness is drawn, so the result depends only on the
        key, ``self.rr`` and ``plaintext``.
        """
        m = plaintext
        n, g = self.pubKey
        n_sq = n ** 2
        ciphertext = gy.powmod(g, m, n_sq) * gy.powmod(self.rr, n, n_sq) % n_sq
        return ciphertext

def KeyGen(keyLength):
    """Return a list of three independent random keys of ``keyLength`` bytes.

    Each key is drawn from ``os.urandom``.
    """
    return [os.urandom(keyLength) for _ in range(3)]

def EDBSetup(file,k1,k3, pai,g,n):
    """Build the encrypted database and publish it through module globals.

    Resets and fills the globals ``T`` (document label -> AES ciphertext),
    ``B`` (document label -> SHA-256 of the ciphertext), ``Tb`` (keyword
    token -> XOR-masked bitmap), ``StateW`` (keyword -> state value),
    ``Loc`` (document label + location token -> ``pai.encipher_N`` of the
    document id) and ``ttr``. Prints the elapsed time.

    Args:
        file: Not read; the input paths below are hard-coded.
        k1: HMAC key for the location tokens.
        k3: HMAC key for the keyword tokens.
        pai: Object providing ``encipher_N``.
        g: Unused.
        n: Unused.
    """
    global Tb, T, B, StateW, Loc, ttr
    Tb = {}
    T = {}
    B = {}
    StateW = {}
    Loc = {}
    ttr = 1
    time_start = time.time()

    # Encrypt every document under 'docs_old', keyed by the hash of the
    # file name up to its first dot.
    for doc in sorted(os.listdir('docs_old')):
        v = doc.split('.')[0]
        li = hashlib.sha256(str(v).encode("utf8")).hexdigest()
        with open('docs_old/' + doc, 'rb') as fs:
            fs_msg = fs.read()
        # Pad with ASCII '0' characters up to a multiple of 16 bytes.
        x = len(fs_msg) % 16
        fs_pad = fs_msg + b'0' * (16 - x) if x != 0 else fs_msg
        ci = aes_encrypt(fs_pad, "0123456789abcdef".encode("utf8"))
        T[li] = ci
        B[li] = hashlib.sha256(ci).hexdigest()

    # Keyword index: each line is "<keyword> <bitmap>".
    with open("20_3137.txt") as keyword_file:
        for line in keyword_file:
            wArray = line.split()
            if not wArray:
                # Tolerate blank lines (e.g. a trailing newline).
                continue
            keyword = wArray[0]
            uwi = hmac.new(k3, keyword.encode("utf-8"), hashlib.sha256).digest()
            sti = get_random(128)
            tw = hashlib.sha256(uwi + str(sti).encode("utf-8")).hexdigest()
            # The token's binary expansion is the XOR mask for the bitmap.
            tw_b = bin(int(tw, 16))[2:]
            Tb[tw] = xor_encrypt(wArray[1].encode("utf-8"), tw_b.encode("utf-8"))
            StateW[keyword] = sti

    # Location index: each line starts with "<keyword> <doc id>"; any
    # further columns are not used here.
    with open("blockchain_phrase_location/loc_20_3137.txt") as loc_file:
        for line1 in loc_file:
            loc_array = line1.split()
            if not loc_array:
                continue
            loc_keyword = loc_array[0]
            loc_id = loc_array[1]
            loc_li = hashlib.sha256(loc_id.encode("utf8")).hexdigest()
            loc_sti = StateW[loc_keyword]
            loc_uwi = hmac.new(k1, loc_keyword.encode("utf-8"), hashlib.sha256).digest()
            loc_tw = hashlib.sha256(loc_uwi + str(loc_sti).encode("utf-8")).hexdigest()
            Loc[loc_li + loc_tw] = pai.encipher_N(int(loc_id))
            # Remember the Loc key of document '1' for keyword 'all'.
            if loc_id == '1' and loc_keyword == 'all':
                ttr = loc_li + loc_tw

    timeTouse = time.time() - time_start
    print("EDBSetup Time:", timeTouse)

def TokenGen(k3,pai):
    """Build the search tokens for the hard-coded query keywords.

    For each keyword the token is ``SHA256(HMAC(k3, keyword) || str(state))``
    in hex, where ``state`` comes from the global ``StateW``. Every keyword
    after the first also carries ``pai.encipher_N(d)``, where ``d`` is its
    position in the query, appended after a space. Prints the elapsed time.

    Args:
        k3: HMAC key (bytes).
        pai: Object providing ``encipher_N``.

    Returns:
        list[str]: One token per query keyword, in query order.

    Raises:
        KeyError: If a query keyword has no entry in ``StateW``.
    """
    global StateW
    time_start = time.time()
    # A tuple, not a set: set iteration order for strings can differ between
    # interpreter runs, which would change which keyword gets position 0.
    queryKeyword = ('all', 'show')
    lToken = []
    for d, words in enumerate(queryKeyword):
        st = StateW[words]
        uwi = hmac.new(k3, words.encode("utf-8"), hashlib.sha256).digest()
        lwi = hashlib.sha256(uwi + str(st).encode("utf-8")).hexdigest()
        if d > 0:
            lToken.append(lwi + " " + str(pai.encipher_N(d)))
        else:
            lToken.append(lwi)
    timeTouse = time.time() - time_start
    print("Token Time:", timeTouse)
    return lToken

def Search(k1,k3,pai):
    """Run a timed search over the globals built by EDBSetup.

    Builds tokens for ten hard-coded keywords, combines their bitmaps from
    ``Tb``, checks ``Loc`` entries for the candidate ids and folds the hashes
    of up to ten ciphertexts from ``T``. Prints a counter and the elapsed
    time. Nothing is returned; ``lresult`` and ``acc`` stay local.

    Args:
        k1: HMAC key for the location tokens.
        k3: HMAC key for the keyword tokens.
        pai: Object providing ``encipher_N``.
    """
    global StateW
    global  ttr
    lToken=[]
    Loc_ltw=[]
    d=0
    # ori_keywords = {'all', 'show', 'text', 'trade', 'affect', 'go', 'data', 'note', 'folder', 'fix'}
    # NOTE(review): this is a set, so the iteration order (and therefore which
    # keyword gets d == 0) is not fixed by the code.
    queryKeyword={'all', 'show', 'text', 'trade', 'affect', 'go', 'data', 'note', 'folder', 'fix'}
    for words in queryKeyword:
        st = StateW[words]
        # Keyword token: SHA256(HMAC(k3, word) || str(state)).
        uwi = hmac.new(k3, words.encode("utf-8"), hashlib.sha256).digest()
        lwi = hashlib.sha256(uwi + str(st).encode("utf-8")).hexdigest()
        # Location token: same construction with k1.
        loc_uwi = hmac.new(k1, words.encode("utf-8"), hashlib.sha256).digest()
        loc_tw = hashlib.sha256(loc_uwi + str(st).encode("utf-8")).hexdigest()
        Loc_ltw.append(loc_tw)
        # Every keyword after the first carries encipher_N(position).
        if d > 0:
            ed = pai.encipher_N(d)
            lToken.append(lwi + " " + str(ed))
        else:
            lToken.append(lwi)
        d = d + 1
    # Token construction above is not included in the measured time.
    time_start = time.time()
    global Tb
    global B
    global T
    lbwi = []; led = []; lid=[]; lresult=[]
    vb_old = 1; cnt= 1;index = 1
    for tk in lToken:
        tk_array = tk.split()
        vb_xor = Tb[tk_array[0]]
        lbwi.append(tk_array[0])
        # Unmask the stored bitmap with the token's binary expansion.
        vb = xor_encrypt(vb_xor.encode("utf-8"), bin(int(tk_array[0], 16))[2:].encode("utf-8"))
        # Intersect the bitmaps of all keywords with a bitwise AND.
        if cnt == 1:
            vb_old = int(vb, 2)
        else:
            vb_old = vb_old & int(vb, 2)
        if len(tk_array) > 1:
            led.append(tk_array[1])
        cnt+=1
    b = bin(vb_old)[2:]
    # index advances only when a bit is '1', so lid becomes
    # 1..(number of set bits) rather than the positions of the set bits.
    # NOTE(review): confirm whether index was meant to advance on every bit.
    for bit in b:
        if bit =='1':
            lid.append(index)
            index += 1
    print(index)
    for id in lid:
        # loc_li and tw are computed but not used below.
        loc_li = hashlib.sha256(str(id).encode("utf8")).hexdigest()

        stand = 1; flag = 0
        for wordIndex in range(len(lbwi)):
            tw = Loc_ltw[wordIndex]
            # Lij is always the entry stored under the global key ttr.
            if wordIndex ==0:
               Lij = Loc[ttr]
            if wordIndex ==0:
                stand = Lij
            else :
                ed = led[wordIndex-1]
                if Lij == stand * gy.mpz(ed):
                    flag += 1
        # An id is accepted when the comparison held for every later keyword.
        if flag == (len(lbwi)-1):
            lresult.append(id)
    g_file = 0
    acc = 0;acc1 =0
    # Fold the hashes of at most the first ten ciphertexts in T together.
    for ci in T:
        hashi = hashlib.sha256(T[ci]).hexdigest()
        if g_file==0:
           acc1 = bin(int(hashi,16))[2:]
        else:
           acc = xor_encrypt(acc1.encode('utf-8'), bin(int(hashi,16))[2:].encode('utf-8'))
           acc1 = acc
        g_file = g_file + 1
        if g_file == 10:
            break
    time_end = time.time()
    timeTouse = time_end - time_start
    print("search Times(s):  ",timeTouse)
    return


def Search_verify(k1, k3, pai):
    """Run a timed search for two keywords and hash the resulting ids.

    Builds tokens for the hard-coded keywords, intersects their bitmaps from
    ``Tb`` and feeds the candidate ids into a ``MultiSetHash``. Prints the
    token count, a counter and the elapsed time. Nothing is returned.

    Args:
        k1: HMAC key for the location tokens.
        k3: HMAC key for the keyword tokens.
        pai: Object providing ``encipher_N``.
    """
    global StateW
    global ttr
    lToken = []
    Loc_ltw = []
    d = 0
    # ori_keywords = {'all', 'show', 'text', 'trade', 'affect', 'go', 'data', 'note', 'folder', 'fix'}
    # NOTE(review): this is a set, so the iteration order (and therefore which
    # keyword gets d == 0) is not fixed by the code.
    queryKeyword = {'all', 'show'}
    for words in queryKeyword:
        st = StateW[words]
        # Keyword token: SHA256(HMAC(k3, word) || str(state)).
        uwi = hmac.new(k3, words.encode("utf-8"), hashlib.sha256).digest()
        lwi = hashlib.sha256(uwi + str(st).encode("utf-8")).hexdigest()
        # Location token: same construction with k1 (collected, not used below).
        loc_uwi = hmac.new(k1, words.encode("utf-8"), hashlib.sha256).digest()
        loc_tw = hashlib.sha256(loc_uwi + str(st).encode("utf-8")).hexdigest()
        Loc_ltw.append(loc_tw)
        # Every keyword after the first carries encipher_N(position).
        if d > 0:
            ed = pai.encipher_N(d)
            lToken.append(lwi + " " + str(ed))
        else:
            lToken.append(lwi)
        d = d + 1
    # Token construction above is not included in the measured time.
    time_start = time.time()
    global Tb
    global B
    global T
    lbwi = [];
    led = [];
    lid = [];
    lresult = []
    vb_old = 1;
    cnt = 1;
    index = 1
    print(len(lToken))
    for tk in lToken:
        tk_array = tk.split()
        vb_xor = Tb[tk_array[0]]
        lbwi.append(tk_array[0])
        # Unmask the stored bitmap with the token's binary expansion.
        vb = xor_encrypt(vb_xor.encode("utf-8"), bin(int(tk_array[0], 16))[2:].encode("utf-8"))
        # Intersect the bitmaps of all keywords with a bitwise AND.
        if cnt == 1:
            vb_old = int(vb, 2)
        else:
            vb_old = vb_old & int(vb, 2)
        if len(tk_array) > 1:
            led.append(tk_array[1])
        cnt += 1
    b = bin(vb_old)[2:]
    # index advances only when a bit is '1', so lid becomes
    # 1..(number of set bits) rather than the positions of the set bits.
    # NOTE(review): confirm whether index was meant to advance on every bit.
    for bit in b:
        if bit == '1':
            lid.append(index)
            index += 1
    print(index)
    g_file = 0
    hasher1 = MultiSetHash()
    for id in lid:
        #loc_li = hashlib.sha256(str(id).encode("utf8")).hexdigest()
        hasher1.add(id)
        g_file = g_file +1
        # The digest is only computed once 500 ids have been added; with
        # fewer ids digest1 is never assigned.
        if g_file == 500:
            digest1 = hasher1.digest()
            break
    time_end = time.time()
    timeTouse = time_end - time_start
    print("Verify Times(s):  ", timeTouse)
    return
def Verify(self, Res):
    """Fold the SHA-256 hex digests of the documents in ``Res`` into one value.

    The first digest is taken as is; each later digest is combined with the
    running value through ``self.xor_encrypt``. Returns ``0`` when ``Res`` is
    empty.
    """
    verify_n = 0
    for position, doc in enumerate(Res, start=1):
        doc_hash = hashlib.sha256(doc).hexdigest()
        if position == 1:
            # Nothing to combine with yet.
            verify_n = doc_hash
            continue
        verify_n = self.xor_encrypt(verify_n.encode("utf-8"), doc_hash.encode("utf-8"))
    return verify_n

def Update(self, k1,k2):
    """Add one document and refresh the bitmaps of five keywords.

    Encrypts ``docs/5000.`` into the globals ``T`` and ``B`` under the label
    derived from id 5001. For each update keyword it then derives a new state
    and token and stores the XOR-masked bitmap in ``Tb`` and the new state in
    ``StateW``. Prints the elapsed time.

    The bitmaps are read from the module-level names ``bwi_all``,
    ``bwi_show``, ``bwi_text``, ``bwi_trade`` and ``bwi_affect``, which are
    defined outside this function.

    Args:
        self: Object providing ``aes_encrypt``.
        k1: Unused.
        k2: Unused.
    """
    global Tb, T, B, StateW
    g_file = 5001
    time_start = time.time()
    li = hashlib.sha256(str(g_file).encode("utf8")).hexdigest()
    with open('docs/5000.', 'rb') as fs:
        fs_msg = fs.read()
    # Pad with ASCII '0' characters up to a multiple of 16 bytes.
    x = len(fs_msg) % 16
    fs_pad = fs_msg + b'0' * (16 - x) if x != 0 else fs_msg
    ci = self.aes_encrypt(fs_pad, "0123456789abcdef".encode("utf8"))
    T[li] = ci
    B[li] = hashlib.sha256(ci).hexdigest()

    # Pair every keyword with its own bitmap explicitly. Iterating a set of
    # keywords next to a positional list of bitmaps can hand a keyword the
    # bitmap of a different one, because set order is not defined.
    keyword_bitmaps = (
        ('all', bwi_all),
        ('show', bwi_show),
        ('text', bwi_text),
        ('trade', bwi_trade),
        ('affect', bwi_affect),
    )
    for keyword, bitmap in keyword_bitmaps:
        uwi = hashlib.sha256(keyword.encode("utf-8")).hexdigest()
        st0 = StateW.get(keyword)
        if st0 is None:
            # First time this keyword is seen: start from a fresh state.
            st0 = get_random(128)
        # The new state is the hash of the previous one.
        sti = hashlib.sha256(str(st0).encode("utf-8")).hexdigest()
        tw = hashlib.sha256(uwi.encode("utf-8") + str(sti).encode("utf-8")).hexdigest()
        # The token's binary expansion is the XOR mask for the bitmap.
        tw_b = bin(int(tw, 16))[2:]
        Tb[tw] = xor_encrypt(bitmap.encode("utf-8"), tw_b.encode("utf-8"))
        StateW[keyword] = sti
    timeTouse = time.time() - time_start
    print("Update Time:", timeTouse)
