import os
from loguru import logger
import torch
from tqdm import trange, tqdm
import numpy as np
import pickle
from utils.utils import write_pickle, load_pickle
from utils.utils import load_lines, write_lines
from processors.trie_tree import Trie
from processors.dataset import NERDataset
import json
from processors.vocab import Vocabulary
from os.path import join
import argparse
from transformers import BertTokenizer


def load_word_embedding(word_embed_path, max_scan_num):
    """Collect the vocabulary of a text-format word-embedding file.

    The first line of the file must be a header of two integers
    (``<num_embed> <word_embed_dim>``). Every following line must be a word
    followed by ``word_embed_dim`` whitespace-separated vector components.
    Only the words are kept; the vector values are discarded.

    Args:
        word_embed_path: Path to the UTF-8 encoded embedding file.
        max_scan_num: Maximum number of word lines to scan (the header line
            is not counted).

    Returns:
        A tuple ``(word_list, word_embed_dim)``: the words in file order and
        the embedding dimension declared in the header.

    Raises:
        ValueError: If no header line was read (empty file, or a negative
            ``max_scan_num``), or if a header field is not an integer.
        AssertionError: If the header or a vector line has an unexpected
            number of fields.
    """
    logger.info('loading word embedding from pretrain')
    word_list = []
    # Stays None until the header line has been parsed.
    word_embed_dim = None

    with open(word_embed_path, 'r', encoding='utf8') as f:
        for idx, line in tqdm(enumerate(f)):
            # Only scan the first max_scan_num word vectors. Line 0 is the
            # header, so word lines are 1..max_scan_num inclusive.
            if idx > max_scan_num:
                break
            items = line.strip().split()

            if idx == 0:
                assert len(items) == 2, \
                    'expected header "<num_embed> <dim>", got: {!r}'.format(line)
                # The count is parsed only to validate the header format.
                _num_embed, word_embed_dim = int(items[0]), int(items[1])
            else:
                assert len(items) == word_embed_dim + 1, \
                    'line {}: expected {} fields, got {}'.format(
                        idx + 1, word_embed_dim + 1, len(items))
                word_list.append(items[0])

    if word_embed_dim is None:
        # Without this check the logging call below would fail with an
        # unrelated UnboundLocalError/TypeError that hides the real cause.
        raise ValueError(
            'no header line read from {!r} (empty file or max_scan_num < 0)'.format(
                word_embed_path))

    logger.info('word_embed_dim:{}'.format(word_embed_dim))
    logger.info('size of word_list:{}'.format(len(word_list)))

    return word_list, word_embed_dim

def build_trie_tree(word_list, save_path):
    """Build a ``Trie`` from ``word_list`` and hand it to ``write_pickle``.

    Args:
        word_list: Iterable of words; each one is inserted into the trie in
            iteration order.
        save_path: Destination passed to ``write_pickle`` (project helper;
            presumably the pickle file path - confirm in ``utils.utils``).

    Returns:
        The populated ``Trie`` instance.
    """
    logger.info('building trie tree')

    trie = Trie()
    for entry in word_list:
        trie.insert(entry)

    # Persist before returning so the caller gets both the file and the object.
    write_pickle(trie, save_path)
    return trie

def get_char2words(trie_tree, text):
    """Map every position of ``text`` to the trie words that cover it.

    For each start position, a window of at most ``trie_tree.max_depth``
    elements is passed to ``trie_tree.enumerateMatch``. Each word it returns
    is appended to the list of every position in
    ``[start, start + len(word))``.

    Args:
        trie_tree: Object exposing ``max_depth`` and ``enumerateMatch(seq)``.
        text: Sliceable sequence (string or token list) to annotate.

    Returns:
        A list with one entry per position of ``text``; each entry is the
        list of matched words covering that position, in discovery order.
    """
    length = len(text)
    covering_words = [[] for _ in range(length)]

    for start in range(length):
        # Limit the lookup window to the deepest trie path to keep matching cheap.
        window = text[start:start + trie_tree.max_depth]

        # Per the original author's note: all words beginning at text[start].
        for matched in trie_tree.enumerateMatch(window):
            stop = start + len(matched)
            for position in range(start, stop):
                covering_words[position].append(matched)

    # TODO: truncate each position's word list to a maximum number of words.
    return covering_words

if __name__ == '__main__':
    # Interactive demo: build the trie from the pretrained word embeddings,
    # then print the per-token word matches for each line typed on stdin.
    tokenizer = BertTokenizer.from_pretrained('../pretrain_model/bert-base-chinese')

    word_embed_path = '../datasets/tencent-ailab-embedding-zh-d200-v0.2.0-s/tencent-ailab-embedding-zh-d200-v0.2.0-s.txt'
    max_scan_num = 1000000
    data_path = '../datasets/weibo'
    save_path = join(data_path, 'trie_tree.pkl')

    word_list, _ = load_word_embedding(word_embed_path, max_scan_num)
    trie_tree = build_trie_tree(word_list, save_path)
    while True:
        try:
            line = input()
        except EOFError:
            # stdin was closed (Ctrl-D or end of piped input): leave the loop
            # cleanly instead of ending the demo with a traceback.
            break
        text = tokenizer.tokenize(line)
        words = get_char2words(trie_tree, text)
        print(words)