import time
import os
import torch

from utils.early_stopping import EarlyStopping
from transformers import AdamW, get_linear_schedule_with_warmup


def save_model(model, path):
    """Serialize ``model`` to ``<path>/best_model.pth``, creating ``path`` if needed.

    ``torch.save`` pickles the whole object (not only a ``state_dict``), so
    loading the file later requires the model's class to be importable.

    Args:
        model: Object to serialize (typically an ``nn.Module``).
        path: Directory to write into; missing parent directories are created.
    """
    # exist_ok=True avoids the race between an exists() check and makedirs().
    os.makedirs(path, exist_ok=True)
    torch.save(model, os.path.join(path, 'best_model.pth'))

def _masked_logits(logits, mask_position):
    """Select, for every example, the logits at that example's mask position.

    Args:
        logits: Model output indexed as ``logits[i, mask_position[i]]``;
            assumed to be (batch, seq_len, vocab) — TODO confirm.
        mask_position: Integer tensor with one position per example;
            presumably shape (batch,) — a (batch, 1) tensor would not index
            correctly here.

    Returns:
        Under the shape assumptions above, a (batch, vocab) tensor.
    """
    # Build the row index on the logits' device so the advanced index does
    # not mix CPU and accelerator tensors.
    rows = torch.arange(logits.size(0), device=logits.device)
    return logits[rows, mask_position]


def _train_one_epoch(model, loader, criterion, optimizer, scheduler, device, epoch, log_every=20):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    num_batches = 0
    for step, batch in enumerate(loader):
        batch = {k: v.to(device) for k, v in batch.items()}
        mask_position = batch.pop('mask_position')
        label = batch.pop('label')

        outputs = model(**batch)
        loss = criterion(_masked_logits(outputs.logits, mask_position), label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The warmup/decay schedule is defined in optimizer steps, so it is
        # advanced once per batch rather than once per epoch.
        scheduler.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        if step % log_every == 0:
            print(f'Epoch {epoch}, Step {step}, Loss: {batch_loss}')
    # Guard against an empty loader instead of dividing by zero.
    return total_loss / num_batches if num_batches else 0.0


def _evaluate(model, loader, device):
    """Return masked-token accuracy of ``model`` over every example in ``loader``.

    Accuracy is counted per example (correct / total), so a short final batch
    is not over-weighted as it would be when averaging per-batch accuracies.
    Returns 0.0 for an empty loader.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            mask_position = batch.pop('mask_position')
            label = batch.pop('label')

            outputs = model(**batch)
            predictions = _masked_logits(outputs.logits, mask_position).argmax(dim=-1)
            correct += (predictions == label).sum().item()
            total += label.numel()
    return correct / total if total else 0.0


def train(args, model, train_loader, val_loader, device):
    """Train ``model`` to predict the token at each example's mask position.

    Each batch must be a dict of tensors containing ``mask_position`` and
    ``label``; every other entry is passed to ``model`` as a keyword argument,
    and the model output must expose ``.logits``. Trains for ``args.epochs``
    epochs with AdamW (``args.lr``) and a linear warmup schedule
    (``args.num_warmup_steps``). After every epoch it evaluates on
    ``val_loader`` and stops when ``EarlyStopping`` signals it.

    Optional ``args.patience`` / ``args.min_delta`` override the
    early-stopping settings (defaults 20 and 0.01).

    Returns:
        The best validation accuracy observed.
    """
    best_val_accuracy = 0.0

    # NOTE(review): validation *accuracy* (higher is better) is fed to
    # EarlyStopping; confirm utils.early_stopping treats its input that way
    # rather than as a loss to minimise.
    early_stopping = EarlyStopping(
        patience=getattr(args, 'patience', 20),
        min_delta=getattr(args, 'min_delta', 0.01),
    )

    total_steps = len(train_loader) * args.epochs

    criterion = torch.nn.CrossEntropyLoss()
    # NOTE(review): transformers.AdamW is deprecated upstream in favour of
    # torch.optim.AdamW (whose eps/weight_decay defaults differ).
    optimizer = AdamW(model.parameters(), lr=args.lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=total_steps,
    )

    for epoch in range(args.epochs):
        avg_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, scheduler, device, epoch
        )
        print(f'Epoch {epoch}, Average Loss: {avg_loss}')

        avg_val_accuracy = _evaluate(model, val_loader, device)
        print(f'========= epoch loss: {avg_loss}, Validation accuracy: {avg_val_accuracy:.4f} ==========')

        if avg_val_accuracy > best_val_accuracy:
            best_val_accuracy = avg_val_accuracy
            # Checkpointing is currently disabled; re-enable to persist the best model:
            # save_model(model, '/workspace/Mymodel/checkpoint/SOTA' + time.strftime("%m-%d_%H.%M", time.localtime()))

        early_stopping(avg_val_accuracy)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    return best_val_accuracy