from simpletransformers.classification import ClassificationModel
from sklearn.metrics import precision_recall_fscore_support
import pandas as pd
import numpy as np
# Bug fix: torch.manual_seed was called below without importing torch,
# which raised NameError before anything ran.
import torch

# Seed both numpy and torch so runs are reproducible.
np.random.seed(555)
torch.manual_seed(555)

def _load_split(path):
    """Read one CSV split and keep only the label/text columns (renamed)."""
    raw = pd.read_csv(path)
    return pd.DataFrame({'label': raw.labels, 'text': raw.text})

# Train / dev / test splits; the source CSVs use a 'labels' column,
# which we expose as 'label' alongside 'text'.
train = _load_split("../yida_mu/tr1.csv")
dev = _load_split("../yida_mu/de1.csv")
test = _load_split("../yida_mu/te1.csv")

#using 'sliding_window = True' to split the input text .The model output for each sub-sequence is averaged into a single output before being sent to the linear classifier.

train_args={
    'sliding_window': True,
    'reprocess_input_data': True,
    'overwrite_output_dir': True,
    'evaluate_during_training': True,
    'save_model_every_epoch': True,
    'train_batch_size': 8,
    'eval_batch_size': 8,
    'best_model_dir': 'outputs/best_model/',
    'max_seq_length': 512,
    'learning_rate': 2e-5,
    'num_train_epochs': 4,
}

# Fine-tune BERT on the training split; dev is used for evaluation during
# training so the best checkpoint lands in train_args['best_model_dir'].
model = ClassificationModel('bert', 'bert-base-uncased', args=train_args)
model.train_model(train, eval_df=dev)

# Reload the best checkpoint found during training before predicting.
model = ClassificationModel('bert', 'outputs/best_model/', args=train_args)

# Bug fix: the held-out frame is named `test` (there was no `df_te`), and
# its label column was renamed to 'label' when the splits were loaded.
predictions, raw_outputs = model.predict(test.text.values)
print(precision_recall_fscore_support(test.label.values, predictions, average='macro'))
