import os
import sys
import json
import argparse
from bisect import bisect_right

import numpy as np
import tqdm
import torch
from torch import optim, nn
from torchvision import transforms
from sklearn.metrics import classification_report

from src.dataloader import CustomDataset
from src.model import load_model

torch.manual_seed(0)


def adjust_lr(optimizer, epoch, args):
    """Multi-step schedule: decay init_lr by 10x at each milestone."""
    cur_lr = args.init_lr * 0.1 ** bisect_right(args.milestones, epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr
    return cur_lr


def train_model(model, criterion, optimizer, args):
    best_acc, best_min_cls_f1 = 0, 0
    for epoch in range(args.epochs):
        print('Epoch {}/{}'.format(epoch, args.epochs))
        print('-' * 10)
        lr = adjust_lr(optimizer, epoch, args)

        for phase in ['train', 'validation']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            preds_all = []
            labels_all = []
            running_loss = 0.0

            for batch_idx, (inputs, labels) in enumerate(tqdm.tqdm(dataloaders[phase])):
                inputs = inputs.to(args.device)
                labels = labels.to(args.device)

                # Gradients are only needed during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)

                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                preds_all.extend(list(preds.cpu().numpy()))
                labels_all.extend(list(labels.data.cpu().numpy()))
                running_loss += loss.item() * inputs.size(0)

            # Binarize: class index 0 is the minority class, everything else counts as majority.
            for i in range(len(preds_all)):
                preds_all[i] = 1 if preds_all[i] > 0 else 0
                labels_all[i] = 1 if labels_all[i] > 0 else 0

            clf_report = classification_report(labels_all, preds_all,
                                               target_names=[args.minority_cls, args.majority_cls],
                                               output_dict=True)
            epoch_acc = clf_report['accuracy']
            min_cls_f1 = clf_report[args.minority_cls]['f1-score']
            print('{} acc: {:.4f}, minority-f1: {:.4f}'.format(phase, epoch_acc, min_cls_f1))

            if phase == 'validation':
                # Keep the checkpoints with the best minority-class F1 and the best accuracy.
                if min_cls_f1 > best_min_cls_f1:
                    best_min_cls_f1 = min_cls_f1
                    os.system('rm -rf weights/best_minority_class_f1*')
                    torch.save(model, f'weights/best_minority_class_f1-{min_cls_f1:.3f}_epoch-{epoch}.pt')
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    os.system('rm -rf weights/best_acc*')
                    torch.save(model, f'weights/best_acc-{epoch_acc:.3f}_epoch-{epoch}.pt')
            else:
                # Log the average per-sample training loss for this epoch.
                epoch_loss = running_loss / len(train_set)
                with open('loss.txt', 'a') as f:
                    f.write(f"{epoch_loss}\n")

        print('Learning rate:', lr)
    return model


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, help='Path to folder containing train and val split')
    parser.add_argument('--img-size', default=224, type=int, help='Model input size')
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--device', type=str, default='0', help='cuda device, i.e. 0,1,2.. or cpu')
    parser.add_argument('--optimizer', type=str, default='sgd', help='Optimizer type (adam/sgd)')
    parser.add_argument('--weight-decay', default=5e-4, type=float, help='weight decay')
    parser.add_argument('--init-lr', default=0.005, type=float, help='learning rate')
    parser.add_argument('--milestones', default=[12, 24, 36], nargs='+', type=int, help='milestones for lr-multistep')
    parser.add_argument('--checkpoint-path', type=str, default=None, help='checkpoint path for finetuning')
    parser.add_argument('--backbone', type=str, default='resnet18', help='Backbone architecture (resnet18/resnet34/resnet50)')
    parser.add_argument('--weighted-clf', action='store_true', help='For imbalanced dataset, uses weighted classification')
    parser.add_argument('--weighted-aug', action='store_true', help='For imbalanced dataset, uses weighted augmentation')
    parser.add_argument('--mcbc', action='store_true', help='For majority clustering balanced classification')
    parser.add_argument('--minority-cls', type=str, help='Name of the minority class')
    parser.add_argument('--majority-cls', type=str, help='Name of the majority class')
    args = parser.parse_args()

    # Setting device
    args.device = torch.device(f"cuda:{args.device}" if args.device.isnumeric() else "cpu")

    # Dataloaders
    args.aug_transforms = transforms.Compose([
        transforms.ColorJitter(brightness=(0.75, 1.25), contrast=0, saturation=0, hue=0),
        transforms.RandomAffine(degrees=10, translate=None, scale=None, shear=None, fill=0, center=None),
        transforms.RandomHorizontalFlip()
    ])
    args.train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(args.img_size, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    args.val_transforms = transforms.Compose([
        transforms.Resize(args.img_size + 32),
        transforms.CenterCrop(args.img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    train_set = CustomDataset(args)
    args.class_map = train_set.class_map
    val_set = CustomDataset(args, train_loader=False)
    print('Number of training samples:', len(train_set))
    print('Number of validation samples:', len(val_set))
    print('Class map:', args.class_map)
    dataloaders = {
        'train': torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4),
        'validation': torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=4)
    }

    keys_list = list(args.class_map.keys())
    if args.weighted_clf:
        print('\nClass weights:')
        for i in range(len(keys_list)):
            print(f' -Class "{keys_list[i]}" with {(train_set.data[1, :] == keys_list[i]).sum()} samples has weight = {train_set.class_weights[i]:.3f}')
    if args.weighted_aug:
        print('\nClass-wise augmentation probabilities:')
        for i in range(len(keys_list)):
            print(f' -Class "{keys_list[i]}" with {(train_set.data[1, :] == keys_list[i]).sum()} samples will be augmented with probability = {train_set.aug_probs[i]:.3f}')

    # Model loading
    print('\ndevice -', args.device)
    model = load_model(args)

    # Training
    os.makedirs('weights', exist_ok=True)
    if args.weighted_clf:
        criterion = nn.CrossEntropyLoss(weight=train_set.class_weights)
    else:
        criterion = nn.CrossEntropyLoss()

    if args.optimizer == 'adam':
        # Adam fine-tunes only the final classification layer.
        optimizer = optim.Adam(model.fc.parameters(), weight_decay=args.weight_decay, amsgrad=True)
    elif args.optimizer == 'sgd':
        # SGD trains the full network; the learning rate is set by adjust_lr every epoch.
        optimizer = optim.SGD(model.parameters(), lr=args.init_lr, momentum=0.9,
                              weight_decay=args.weight_decay, nesterov=True)
    else:
        print('Optimizer not supported!')
        sys.exit(0)

    model_trained = train_model(model, criterion, optimizer, args)
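# ---------------------------------------------------------------------------
# Example invocation (a sketch only; the script filename, data directory, and
# class names are hypothetical and depend on the dataset layout expected by
# src.dataloader.CustomDataset):
#
#   python train.py --data-dir ./dataset --backbone resnet18 --optimizer sgd \
#       --batch-size 64 --epochs 50 --device 0 --weighted-clf \
#       --minority-cls defect --majority-cls good
# ---------------------------------------------------------------------------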