import fastai
from fastai import *
from fastai.text import *
from fastai.callbacks import *
import pandas as pd
import numpy as np
from functools import partial
from sklearn import metrics
# Seed every RNG the script may draw from so a run can be reproduced.
# The original note lists seeds 555, 666 and 777 -- presumably the seeds used
# for repeated runs; change SEED to switch between them.
import random

SEED = 555
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # per PyTorch docs, seeds the RNG on all devices
###
# Load the train/dev/test CSV splits, keeping only a `label` column (renamed
# from the CSV's `labels`) followed by `text`. The from_df calls below pass no
# column arguments, so they presumably rely on this label-then-text order.
def _load_split(csv_path):
    """Read one CSV split from *csv_path* as a two-column (label, text) DataFrame."""
    raw = pd.read_csv(csv_path)
    return pd.DataFrame({'label': raw.labels, 'text': raw.text})


train, dev, test = map(
    _load_split,
    ("../yida_mu/tr1.csv", "../yida_mu/de1.csv", "../yida_mu/te1.csv"),
)
###
# Language-model data: train split for training, dev split for validation.
data_lm = TextLMDataBunch.from_df(path="", train_df=train, valid_df=dev)
# The classifier reuses the LM vocabulary so token ids match the fine-tuned
# encoder that is saved and reloaded later (save_encoder / load_encoder).
lm_vocab = data_lm.train_ds.vocab
data_clas = TextClasDataBunch.from_df(
    path="",
    train_df=train,
    valid_df=dev,
    vocab=lm_vocab,
    bs=64,
)
##
# Fine-tune the pretrained AWD-LSTM language model on the task text.
learn = language_model_learner(data_lm, AWD_LSTM, drop_mult=0.5, pretrained=True)
# First cycle before unfreezing -- presumably trains only the last layer group,
# since the pretrained learner starts frozen (fastai v1 behavior; confirm).
learn.fit_one_cycle(cyc_len=1, max_lr=1e-2)
# Then train all layers for 10 epochs, checkpointing the best epoch to 'best_lm'.
learn.unfreeze()
best_lm_ckpt = SaveModelCallback(learn, name="best_lm")
learn.fit_one_cycle(cyc_len=10, max_lr=1e-3, callbacks=[best_lm_ckpt])
# Restore the best checkpoint and export its encoder for the classifier.
learn.load('best_lm')
learn.save_encoder('ft_encs')
###
def _suggest_lr(learner):
    """Run the LR finder on *learner* and return its minimum-gradient suggestion.

    In fastai v1, ``Recorder.plot(suggestion=True)`` stores the suggestion in
    ``recorder.min_grad_lr``. When it cannot compute one, it only prints a
    message and leaves the attribute untouched, so a value from an earlier
    search would be silently reused. The attribute is therefore cleared first,
    and a missing suggestion is reported as an error.

    Raises:
        RuntimeError: if the LR finder produced no suggestion.
    """
    learner.lr_find()
    learner.recorder.min_grad_lr = None  # drop any stale suggestion
    learner.recorder.plot(suggestion=True)
    lr = learner.recorder.min_grad_lr
    if lr is None:
        raise RuntimeError("lr_find produced no learning-rate suggestion")
    return lr


# Classifier on top of the fine-tuned LM encoder.
learn1 = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5)
learn1.load_encoder('ft_encs')
# One cycle at a fixed LR before any explicit unfreezing.
learn1.fit_one_cycle(1, 1e-2)
# NOTE(review): the learner is presumably still frozen to the last layer group
# here, so freeze_to(-1) looks like a no-op and the next cycle retrains only the
# head; standard ULMFiT gradual unfreezing would use -2 then -3. Kept as-is to
# preserve experimental results -- confirm intent.
learn1.freeze_to(-1)
best_clf_lr = _suggest_lr(learn1)
learn1.fit_one_cycle(1, best_clf_lr)
learn1.freeze_to(-2)
best_clf_lr = _suggest_lr(learn1)
learn1.fit_one_cycle(1, best_clf_lr)
# Train all layers, keeping the best epoch in the 'best_lmsss' checkpoint.
learn1.unfreeze()
best_clf_lr = _suggest_lr(learn1)
learn1.fit_one_cycle(5, best_clf_lr, callbacks=[SaveModelCallback(learn1, name="best_lmsss")])
learn1.load('best_lmsss')
# Score the classifier on the held-out test split, one text at a time.
# predict() returns (category, class index, probabilities). The index is a
# position in the learner's class list, not the label itself, so it is mapped
# back through that list before comparison with test.label. Without this,
# labels that are not exactly 0..k-1 (strings, 1-based ids, ...) would be
# scored against the wrong classes.
clf_classes = learn1.data.classes
predictions = [clf_classes[int(learn1.predict(text)[1])] for text in test.text.values]
print('p_r_f1', metrics.precision_recall_fscore_support(test.label.values, predictions, average='macro'))
