from skorch.callbacks import LRScheduler
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from skorch.callbacks import Checkpoint
from skorch import NeuralNetClassifier
from skorch.helper import predefined_split
from model import ConvTransNet
from dataset import get_data
from utils import evaluate

def _build_net(val_ds, checkpoint):
    """Return an (uninitialized) skorch classifier wrapping ConvTransNet.

    Args:
        val_ds: Dataset used as the fixed validation split during training.
        checkpoint: skorch ``Checkpoint`` callback attached to the net.
    """
    # Decay the learning rate by 10x every 7 epochs.
    lrscheduler = LRScheduler(policy='StepLR', step_size=7, gamma=0.1)

    # "L" and "D" architecture settings passed to ConvTransNet; presumably
    # blocks per stage and channel width per stage -- confirm in model.py.
    num_blocks = [2, 2, 3]        # L
    channels = [64, 96, 192]      # D

    # Fall back to CPU automatically instead of crashing when CUDA is absent.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    return NeuralNetClassifier(
        ConvTransNet((224, 224), 3, num_blocks, channels, num_classes=5),
        criterion=nn.CrossEntropyLoss,
        lr=0.001,
        batch_size=32,
        max_epochs=20,
        optimizer=optim.SGD,
        optimizer__momentum=0.9,
        iterator_train__shuffle=True,
        iterator_train__num_workers=2,
        iterator_valid__num_workers=2,
        # Use the provided validation set instead of skorch's internal split.
        train_split=predefined_split(val_ds),
        callbacks=[lrscheduler, checkpoint],
        device=device,
    )


def _report(net, test_ds):
    """Save test-set class probabilities and print accuracy/precision/recall/F1.

    Args:
        net: A fitted or parameter-loaded skorch classifier.
        test_ds: Test dataset; must expose a ``targets`` attribute.
    """
    output = net.predict_proba(test_ds)
    np.save('probs_update.npy', output)
    # Predicted class = column with the highest probability.
    prediction = np.argmax(output, axis=1)
    # evaluate() returns an 8-tuple; only positions 2-5 are reported here.
    _, _, accuracy, precision, recall, f1, _, _ = evaluate(test_ds.targets, prediction)
    print("Test Accuracy: ", accuracy)
    print("Test Precision: ", precision)
    print("Test Recall: ", recall)
    print("Test F1: ", f1)


def main(train=True):
    """Train ConvTransNet and report metrics on the test split.

    Loads the train/validation/test splits via ``get_data``. Fits a skorch
    ``NeuralNetClassifier`` (SGD with momentum and a StepLR schedule),
    checkpointing the epoch with the best validation accuracy to
    ``best_model.pt``. Then restores those best parameters, saves the test-set
    class probabilities to ``probs_update.npy`` and prints the test metrics.

    Args:
        train: If True (default), train and checkpoint as described above.
            If False, skip training and load parameters previously saved to
            ``best_model.pt`` before evaluating.
    """
    # Load dataset
    train_ds, val_ds, test_ds = get_data()

    checkpoint = Checkpoint(
        f_params='best_model.pt', monitor='valid_acc_best')
    net = _build_net(val_ds, checkpoint)

    if train:
        net.fit(train_ds, y=torch.tensor(train_ds.targets))
        # fit() leaves the LAST epoch's weights in the net; the Checkpoint
        # callback saved the best-validation-accuracy epoch, so restore it
        # before evaluating on the test set.
        net.load_params(checkpoint=checkpoint)
    else:
        net.initialize()
        net.load_params(f_params='best_model.pt')

    _report(net, test_ds)


if __name__ == "__main__":
    main()
