import json
import os

import fasttext
import numpy as np
from datasets import load_dataset
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize
from transformers import DefaultDataCollator, TrainingArguments

from model import Ours
from mytrainer import MyTrainer

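# Pin training to a single GPU (device index 5 on this machine).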
os.environ["CUDA_VISIBLE_DEVICES"] = '5'
#os.environ["MASTER_ADDR"] = "localhost"
#os.environ["MASTER_PORT"] = "9994" 
#os.environ["RANK"] = "0"
#os.environ["LOCAL_RANK"] = "0"
#os.environ["WORLD_SIZE"] = "1"


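# Fit a TF-IDF vectorizer on the raw training corpus so that every token gets an
# IDF weight; the custom analyzer strips the trailing newline and splits on single
# spaces, and must match the tokenization used in tokenize_function below.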
with open('train_data.txt') as f:
    corpus = f.readlines()

vectorizer = TfidfVectorizer(analyzer=lambda x: x.rstrip('\n').split(' '), norm='l2')
X = vectorizer.fit_transform(corpus)

vocabs = vectorizer.get_feature_names_out().tolist()
idf = vectorizer.idf_.tolist()

idf_dict = {vocab:idf_value for vocab, idf_value in zip(vocabs, idf)}

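# Pre-trained English fastText vectors (300-dimensional, cc.en.300.bin).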
ft = fasttext.load_model('word_vectors/cc.en.300.bin')

LEN = 128
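# Map each document to a fixed-size (LEN x 300) matrix of IDF-weighted,
# L2-normalized fastText word vectors: shorter documents are zero-padded on the
# left, longer documents are truncated to the first LEN tokens.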
def tokenize_function(examples):
    texts = examples['text']
    reps = examples['doc_topic']
    word_embeds = []
    labels = []
    for text, label in zip(texts, reps):
        # Tokenize exactly as the TF-IDF vectorizer did so every token has an IDF weight.
        tokens = text.rstrip('\n').split(' ')

        # IDF-weighted fastText vectors, L2-normalized per token. This assumes every
        # token was seen when fitting on train_data.txt; an unseen token raises KeyError.
        word_embed = [ft.get_word_vector(token) * idf_dict[token] for token in tokens]
        word_embed = np.asarray(word_embed)
        word_embed = normalize(word_embed)

        l = len(word_embed)

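        # Pre-pad with zero rows up to LEN tokens, or truncate to the first LEN tokens.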
        if l < LEN:
            pad_size = LEN - l
            word_embed = np.pad(word_embed, ((pad_size, 0), (0, 0)), 'constant')
        else:
            word_embed = word_embed[:LEN,:]
    
        word_embeds.append(word_embed)
        labels.append(np.asarray(label))

    return {
        'word_embeds': word_embeds,
        'labels':labels,
    }


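# Load the dataset through the local loading script and precompute the embedding
# matrix and 'doc_topic' target for every example.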
mydataset = load_dataset('dataset.py')

tokenized_dataset = mydataset.map(tokenize_function,
                                 batched=True,
                                 num_proc=4,
                                 remove_columns=['text', 'doc_topic'])

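# The default collator stacks the precomputed arrays into PyTorch tensors.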
data_collator = DefaultDataCollator(return_tensors='pt')

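# remove_unused_columns=False keeps all dataset columns ('word_embeds', 'labels')
# instead of letting the Trainer drop those it cannot match to the model's
# forward signature.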
training_args = TrainingArguments(
    output_dir='/mnt/',
    #evaluation_strategy='epoch',
    save_strategy='epoch',
    learning_rate=1e-3,
    per_device_train_batch_size=128,
    #per_device_eval_batch_size=64,
    #local_rank=int(os.environ.get('LOCAL_RANK', -1)),
    #deepspeed="ds_config_zero3.json",
    remove_unused_columns=False,
    dataloader_num_workers=16,
    num_train_epochs=24,
    weight_decay=0.00,
    save_total_limit=1,
    fp16=False,
)

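# Model hyperparameters are positional; see model.Ours for their meaning.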
model = Ours(64, 5, 128, 100)

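# MyTrainer is a custom Trainer subclass (mytrainer.py), presumably overriding
# the loss computation for this model.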
trainer = MyTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    #eval_dataset=tokenized_dataset['validation'],
    data_collator=data_collator,
)

trainer.train()  # pass resume_from_checkpoint=True to resume from the latest checkpoint in output_dir
#trainer.evaluate()
