"""Evaluate a trained minority-vs-majority image classifier on the test split.

Class index 0 is the minority class; every other class index (for example the
majority clusters used with ``--mcbc``) is folded into a single "majority"
label before the metrics are computed.
"""
import os
import sys
import json
import glob
import argparse

import tqdm
import numpy as np
import seaborn as sns
import scikitplot as skplt
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix, ConfusionMatrixDisplay
import torch
from torch import optim, nn
from torchvision import transforms

from src.dataloader import CustomDataset
from src.model import load_model


def evaluate(model, args, dataloader=None):
    """Run *model* over the test data and print binary classification metrics.

    Args:
        model: the classifier; expected to already be in eval mode and on
            ``args.device``.
        args: namespace providing ``device`` (a ``torch.device``) and
            ``class_map`` (dict mapping class name -> integer class index).
        dataloader: iterable of ``(inputs, labels)`` batches. When omitted,
            falls back to the module-level ``dataloaders['test']`` for
            backward compatibility with the original script.
    """
    if dataloader is None:
        dataloader = dataloaders['test']

    preds_all, labels_all, scores_all = [], [], []
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for inputs, labels in tqdm.tqdm(dataloader):
            logits = model(inputs.to(args.device))
            _, preds = torch.max(logits, 1)
            probs = torch.softmax(logits.float(), dim=1)
            preds_all.append(preds.cpu().numpy())
            labels_all.append(labels.cpu().numpy())
            # Majority score = total probability mass on every class except 0.
            # Using this (rather than the max raw majority logit) gives a
            # score that is consistent with the minority logit, which is what
            # ROC AUC needs to rank samples correctly.
            scores_all.append((1.0 - probs[:, 0]).cpu().numpy())

    if not labels_all:
        print('No test samples; nothing to evaluate.')
        return

    # Fold classes 1..K-1 into a single "majority" label 1; class 0 stays 0.
    labels_bin = (np.concatenate(labels_all) != 0).astype(int)
    preds_bin = (np.concatenate(preds_all) != 0).astype(int)
    scores = np.concatenate(scores_all)

    # Look the display names up by index instead of peeking at the first two keys.
    index_to_name = {idx: name for name, idx in args.class_map.items()}
    minority_cls = index_to_name.get(0, 'minority')
    majority_cls = index_to_name.get(1, 'majority')

    # labels=[0, 1] keeps the report valid even if only one class occurs.
    print(classification_report(labels_bin, preds_bin, labels=[0, 1],
                                target_names=[minority_cls, majority_cls], digits=3))

    if np.unique(labels_bin).size < 2:
        print('roc_auc_score undefined: test labels contain a single class')
    else:
        print('roc_auc_score', roc_auc_score(labels_bin, scores))


def _parse_args():
    """Build and parse the command-line interface."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str)
    parser.add_argument('--cls-json', type=str)
    parser.add_argument("--img-size", default=224, type=int)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument('--device', type=str, default='0', help='cuda device, i.e. 0,1,2.. or cpu')
    parser.add_argument('--checkpoint-path', type=str, default=None,
                        help='directory containing the best_minority* checkpoint')
    parser.add_argument('--mcbc', action='store_true', help='For majority clustering balanced classification')
    return parser.parse_args()


def _resolve_checkpoint(checkpoint_dir):
    """Return the (first, sorted) ``best_minority*`` file in *checkpoint_dir*, or exit."""
    if checkpoint_dir is None:
        sys.exit('--checkpoint-path is required: a directory containing a best_minority* checkpoint')
    # glob order is filesystem-dependent; sort for a deterministic choice.
    matches = sorted(glob.glob(os.path.join(checkpoint_dir, 'best_minority*')))
    if not matches:
        sys.exit(f'No best_minority* checkpoint found in {checkpoint_dir}')
    return matches[0]


def _load_checkpoint(path, device):
    """Load a fully pickled model from *path* onto *device* and set eval mode."""
    import inspect  # local: only needed for the torch-version check below

    # map_location lets a GPU-saved checkpoint load on CPU or another GPU.
    load_kwargs = {'map_location': device}
    # The checkpoint is a whole pickled nn.Module (arbitrary code on unpickle),
    # so only load trusted files. torch>=2.6 defaults weights_only=True, which
    # cannot load such checkpoints; pass False where the keyword exists.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load(path, **load_kwargs)
    model.to(device)  # inputs are sent to this device in evaluate()
    model.eval()
    return model


def main():
    """Script entry point: build the test loader, load the model, evaluate."""
    args = _parse_args()

    with open(args.cls_json, 'r', encoding='utf-8') as f:
        args.class_map = json.load(f)
    print('Class map:', args.class_map)

    # Dataloaders: resize the shorter side to img_size + 32, then center-crop.
    args.val_transforms = transforms.Compose([
        transforms.Resize(args.img_size + 32),
        transforms.CenterCrop(args.img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    test_set = CustomDataset(args, train_loader=False, test_split_eval=True)
    print('Number of test samples:', len(test_set))
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size,
                                              shuffle=False, num_workers=4)

    # Model loading (device is resolved after dataset creation, as before).
    args.device = torch.device(f"cuda:{args.device}" if args.device.isnumeric() else "cpu")
    print('device -', args.device)
    ckpt_pth = _resolve_checkpoint(args.checkpoint_path)
    print('Checkpoint path:', ckpt_pth)
    model = _load_checkpoint(ckpt_pth, args.device)

    evaluate(model, args, dataloader=test_loader)


if __name__ == '__main__':
    main()