import torch
import torch.nn as nn
import torch.optim as optim
from utils.options import args
import utils.common as utils

import os
import time
import math
from data import imagenet
from importlib import import_module
from model.resnet import Bottleneck_class

from utils.common import *

import thop
from thop import profile

# Run-time context shared by the whole script: the compute device, the
# checkpoint helper, a timestamped log file under args.job_dir, and the loss.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{args.gpus[0]}")
else:
    device = 'cpu'
checkpoint = utils.checkpoint(args)
now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
logger = utils.get_logger(os.path.join(args.job_dir, f'logger{now}.log'))

# 'Softmax' (cross-entropy) is the only criterion this script accepts.
if args.criterion != 'Softmax':
    raise ValueError('invalid criterion : {:}'.format(args.criterion))
criterion = nn.CrossEntropyLoss().cuda()

# Data: build the dataset wrapper once and take both loaders from it.
print('==> Preparing data..')
data_tmp = imagenet.Data(args)
train_loader, val_loader = data_tmp.trainLoader, data_tmp.testLoader

# Architecture: build the reference network described by args.cfg.
# Only 'resnet' is supported. Any other value is rejected with a real
# exception: `raise('...')` with a bare string is itself a TypeError in
# Python 3 and would hide the intended message.
if args.arch == 'resnet':
    origin_model = import_module(f'model.{args.arch}').resnet(args.cfg).to(device)
else:
    raise ValueError('arch not exist!')

# Calculate FLOPs of origin model: thop's profile() is run once on a single
# 1x3x224x224 random input and its two return values are kept as
# oriflops / oriparams.
# NOTE: `input` shadows the builtin at module scope; the name is kept because
# other code may read it.
input = torch.randn(1, 3, 224, 224).to(device)
oriflops, oriparams = profile(origin_model, inputs=(input, ))


def get_layerwise(model):
    """Return the preserved fraction of every masked module in `model`.

    A module is included when it exposes a ``mask`` attribute; the fraction
    is computed from its ``fixed_mask`` tensor as ``sum(fixed_mask) / numel``.
    The list follows the module traversal order of `model`, is written to the
    script logger, and is then returned.

    NOTE(review): the membership test looks at ``mask`` while the value is
    read from ``fixed_mask`` -- presumably every masked module defines both;
    confirm against the model definition.
    """
    ratios = []
    with torch.no_grad():
        for module in model.modules():
            if not hasattr(module, 'mask'):
                continue
            kept = module.fixed_mask
            fraction = torch.sum(kept) / kept.numel()
            ratios.append(float(fraction.cpu().numpy()))
    logger.info(ratios)
    return ratios

def get_prune_rate(model):
    """Estimate the fraction of FLOPs removed by the current masks.

    Two fresh networks are built from ``args.cfg`` -- one with the default
    layer configuration and one with ``layer_cfg`` set to the per-layer
    preserved ratios of `model` -- and both are profiled with thop on the
    same 1x3x224x224 random input.

    Args:
        model: network whose masked modules are inspected by get_layerwise().

    Returns:
        tuple: ``(1 - flops_prune / flops, layerwise)`` where ``layerwise`` is
        the list returned by get_layerwise().
    """
    layerwise = get_layerwise(model)

    # Put the probe on the same device as the models below. A bare .cuda()
    # targets the current CUDA device, which differs from `device` whenever
    # args.gpus[0] != 0, and it fails outright on CPU-only hosts.
    dummy_input = torch.randn(1, 3, 224, 224).to(device)
    build_resnet = import_module(f'model.{args.arch}').resnet

    # Baseline cost of the unpruned architecture.
    dense_model = build_resnet(args.cfg).to(device)
    flops, _ = profile(dense_model, inputs=(dummy_input, ))

    # Cost of the architecture shrunk according to the preserved ratios.
    pruned_model = build_resnet(args.cfg, layer_cfg=layerwise).to(device)
    flops_prune, _ = profile(pruned_model, inputs=(dummy_input, ))

    return 1 - flops_prune / flops, layerwise

def freeze_mask(model):
    """Turn off gradient tracking for the ``mask`` of every masked module.

    Every sub-module of `model` (including `model` itself) that has a
    ``mask`` attribute gets ``mask.requires_grad_(False)``; all other
    parameters are left untouched. Returns None.
    """
    masked_modules = (mod for mod in model.modules() if hasattr(mod, 'mask'))
    for mod in masked_modules:
        mod.mask.requires_grad_(False)



# Train function
def train(epoch, train_loader, model, criterion, optimizer):
    """Run one training epoch and return the running averages.

    For every batch the learning rate is refreshed through
    adjust_learning_rate(), the loss is back-propagated and the optimizer
    takes one step. Progress is logged about ten times per epoch.

    Args:
        epoch: index of the current epoch (used for the LR schedule and logs).
        train_loader: iterable yielding ``(images, targets)`` batches.
        model: network to optimise; switched to train mode here.
        criterion: loss called as ``criterion(logits, targets)``.
        optimizer: optimizer whose param groups are updated in place.

    Returns:
        tuple: ``(average loss, average top-1, average top-5)`` over the
        batches processed, as accumulated by the AverageMeter instances.
    """
    batch_time = utils.AverageMeter('Time', ':6.3f')
    data_time = utils.AverageMeter('Data', ':6.3f')
    losses = utils.AverageMeter('Loss', ':.4e')
    top1 = utils.AverageMeter('Acc@1', ':6.2f')
    top5 = utils.AverageMeter('Acc@5', ':6.2f')

    model.train()
    end = time.time()

    num_iter = len(train_loader)

    # Log roughly ten times per epoch. Clamp to 1 so a loader with fewer than
    # ten batches does not make the modulo below divide by zero.
    print_freq = max(1, num_iter // 10)

    for batch_idx, (images, targets) in enumerate(train_loader):
        # Debug runs only process the first six batches.
        if args.debug and batch_idx > 5:
            break

        # Move the batch to the device the model lives on; a bare .cuda()
        # would pick the current CUDA device, which may not be args.gpus[0].
        images = images.to(device)
        targets = targets.to(device)
        data_time.update(time.time() - end)

        adjust_learning_rate(optimizer, epoch, batch_idx, num_iter)

        # compute output
        logits = model(images)
        loss = criterion(logits, targets)

        # measure accuracy and record loss
        prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
        n = images.size(0)
        losses.update(loss.item(), n)  # accumulated loss
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % print_freq == 0 and batch_idx != 0:
            logger.info(
                'Epoch[{0}]({1}/{2}): '
                'Loss {loss.avg:.4f} '
                'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f}'.format(
                    epoch, batch_idx, num_iter, loss=losses,
                    top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg

# Validate function
def validate(val_loader, model, criterion, args):
    """Evaluate `model` on `val_loader` without tracking gradients.

    Args:
        val_loader: iterable yielding ``(images, targets)`` batches.
        model: network to evaluate; switched to eval mode here.
        criterion: loss called as ``criterion(logits, targets)``.
        args: run options; only ``args.debug`` is read (debug runs stop after
            the first six batches).

    Returns:
        tuple: ``(average loss, average top-1, average top-5)`` over the
        batches processed. The accuracies are accumulated as Python floats,
        matching train(), so the averages are not device tensors.
    """
    batch_time = utils.AverageMeter('Time', ':6.3f')
    losses = utils.AverageMeter('Loss', ':.4e')
    top1 = utils.AverageMeter('Acc@1', ':6.2f')
    top5 = utils.AverageMeter('Acc@5', ':6.2f')

    model.eval()
    with torch.no_grad():
        end = time.time()
        for batch_idx, (images, targets) in enumerate(val_loader):
            if args.debug and batch_idx > 5:
                break

            # Same device as the model; a bare .cuda() would use the current
            # CUDA device, which may differ from args.gpus[0].
            images = images.to(device)
            targets = targets.to(device)

            # compute output
            logits = model(images)
            loss = criterion(logits, targets)

            # measure accuracy and record loss
            pred1, pred5 = utils.accuracy(logits, targets, topk=(1, 5))
            n = images.size(0)
            losses.update(loss.item(), n)
            # .item() keeps the meters on plain floats (as train() does)
            # instead of accumulating tensors that later leak into the
            # saved checkpoint state.
            top1.update(pred1.item(), n)
            top5.update(pred5.item(), n)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

        logger.info(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
                    .format(top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg

def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """Set the learning rate of every param group for the given epoch.

    The schedule is selected by ``args.lr_type`` and depends only on `epoch`,
    so every step inside one epoch gets the same rate:

    * ``'step'``  -- ``args.lr * 0.1 ** (epoch // 30)``, with one additional
      factor of 0.1 from epoch 80 onwards.
    * ``'cos'``   -- ``0.5 * args.lr * (1 + cos(pi * (epoch - 5) /
      (args.num_epochs - 5)))``; the curve peaks at epoch 5 and the
      expression divides by zero when ``args.num_epochs == 5``.
    * ``'exp'``   -- ``args.lr * 0.96 ** epoch``.
    * ``'fixed'`` -- ``args.lr``.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: current epoch index.
        step: batch index inside the epoch; the rate is printed when it is 0.
        len_epoch: number of batches per epoch (currently unused).

    Raises:
        NotImplementedError: if ``args.lr_type`` is not one of the above.
    """
    if args.lr_type == 'step':
        factor = epoch // 30
        if epoch >= 80:
            factor = factor + 1
        lr = args.lr * (0.1 ** factor)
    elif args.lr_type == 'cos':  # cos without warm-up
        lr = 0.5 * args.lr * (1 + math.cos(math.pi * (epoch - 5) / (args.num_epochs - 5)))
    elif args.lr_type == 'exp':
        # Use a dedicated name for the decay period: assigning to `step`
        # here overwrote the batch index and silenced the print below.
        decay_period = 1
        decay = 0.96
        lr = args.lr * (decay ** (epoch // decay_period))
    elif args.lr_type == 'fixed':
        lr = args.lr
    else:
        raise NotImplementedError

    # Report the rate once per epoch, on its first batch.
    if step == 0:
        print('current learning rate:{0}'.format(lr))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def main():
    """Build the model, then train, measure pruning and validate each epoch.

    Per epoch: one pass of train(), a FLOPs-based prune-rate estimate via
    get_prune_rate() (masks are frozen once it exceeds
    ``args.pruning_rate``), one pass of validate(), and a checkpoint save.
    A checkpoint is flagged as best when its top-5 accuracy improves.

    Raises:
        NotImplementedError: if ``args.resume`` is set; this script contains
            no checkpoint-loading code.
        ValueError: if ``args.arch`` is not ``'resnet'``.
    """
    start_epoch = 0
    best_top1_acc = 0.0
    best_top5_acc = 0.0

    print('==> Building Model..')
    if args.resume:
        # Nothing here restores a model or optimizer, so continuing would
        # only fail later with an unbound `model`. Fail early and clearly.
        raise NotImplementedError('resuming from a checkpoint is not supported by this script')
    if args.arch != 'resnet':
        # Raising a bare string is a TypeError; use a real exception.
        raise ValueError('arch not exist!')
    model = import_module(f'model.{args.arch}').resnet_class(args.cfg).to(device)
    print("Model Construction Down!")

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    if len(args.gpus) != 1:
        model = nn.DataParallel(model, device_ids=args.gpus)

    # fine-tuning
    for epoch in range(start_epoch, args.num_epochs):
        train_obj, train_top1_acc, train_top5_acc = train(epoch, train_loader, model, criterion, optimizer)
        prune_rate, layer_cfg = get_prune_rate(model)
        # Stop updating the masks once the FLOPs reduction target is exceeded.
        if prune_rate > args.pruning_rate:
            freeze_mask(model)
        valid_obj, test_top1_acc, test_top5_acc = validate(val_loader, model, criterion, args)

        is_best = best_top5_acc < test_top5_acc
        best_top1_acc = max(best_top1_acc, test_top1_acc)
        best_top5_acc = max(best_top5_acc, test_top5_acc)

        # Unwrap exactly when the model was wrapped above, so the saved keys
        # never carry the DataParallel 'module.' prefix.
        if isinstance(model, nn.DataParallel):
            model_state_dict = model.module.state_dict()
        else:
            model_state_dict = model.state_dict()

        state = {
            'state_dict': model_state_dict,
            'best_top1_acc': best_top1_acc,
            'best_top5_acc': best_top5_acc,
            'optimizer': optimizer.state_dict(),
            'epoch': epoch + 1,
            'arch': args.cfg,
            'cfg': layer_cfg
        }
        checkpoint.save_model(state, epoch + 1, is_best)

    logger.info('Best Top-1 accuracy: {:.3f} Top-5 accuracy: {:.3f}'.format(float(best_top1_acc), float(best_top5_acc)))

# Run the training loop only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()
