import torch.optim
import torch.nn as nn
import modules.rrdb_denselayer
from modules.hinet import Hinet, Hinet_key
import config as c
import modules.rrdb_deno


class DERIS_key(nn.Module):
    """Key-conditioned wrapper around ``Hinet_key``.

    ``forward`` delegates entirely to ``self.inbs``. ``pre_enhance`` and
    ``post_enhance`` are constructed (and therefore registered as submodules,
    so their weights appear in ``state_dict()``), but ``forward`` does not use
    them.
    """

    def __init__(self, in_1=3, in_2=3):
        """Build the wrapped network.

        Args:
            in_1, in_2: forwarded unchanged to ``Hinet_key`` (presumably the
                channel counts of its two inputs -- TODO confirm).
        """
        super(DERIS_key, self).__init__()
        self.inbs = Hinet_key(in_1=in_1, in_2=in_2)
        # NOTE(review): unused by forward(); removing them would change this
        # module's state_dict keys, so saved checkpoints of the whole model
        # would no longer load strictly.
        self.pre_enhance = modules.rrdb_deno.ResidualDenseBlock_out(3, 3)
        self.post_enhance = modules.rrdb_deno.ResidualDenseBlock_out(3, 3)

    def load_hinet(self, path, map_location=None):
        """Load weights for the inner ``Hinet_key`` (``self.inbs``) from a file.

        Args:
            path: checkpoint path. The file must hold a dict whose ``'net'``
                entry is a state dict for ``self.inbs``. ``torch.load`` may
                unpickle objects, so only load trusted checkpoints.
            map_location: forwarded to ``torch.load``, e.g. ``'cpu'`` to load
                a GPU-saved checkpoint on a machine without CUDA. ``None``
                (the default) restores tensors to the device they were saved on.

        Raises:
            KeyError: if the checkpoint has no ``'net'`` entry.
            RuntimeError: from ``load_state_dict`` (strict) on missing or
                unexpected keys.
        """
        state_dicts = torch.load(path, map_location=map_location)
        # Pass the state dict straight through: copying it into a plain dict
        # would drop the ``_metadata`` attribute that load_state_dict uses for
        # versioned loading. (An earlier revision filtered out keys containing
        # 'tmp_var' here; reinstate that if such checkpoints reappear.)
        self.inbs.load_state_dict(state_dicts['net'])

    def forward(self, x, k, rev=False):
        """Run ``self.inbs`` on ``x`` with key ``k``.

        When ``rev`` is truthy, ``rev=True`` is passed through (presumably the
        reverse pass of the invertible network); otherwise ``self.inbs`` is
        called without a ``rev`` argument.
        """
        if rev:
            return self.inbs(x, k, rev=True)
        return self.inbs(x, k)


class DERIS(nn.Module):
    """Wrapper around ``Hinet``.

    ``forward`` delegates entirely to ``self.inbs``. ``pre_enhance`` and
    ``post_enhance`` are constructed (and therefore registered as submodules,
    so their weights appear in ``state_dict()``), but ``forward`` does not use
    them.
    """

    def __init__(self, in_1=3, in_2=3):
        """Build the wrapped network.

        Args:
            in_1, in_2: forwarded unchanged to ``Hinet`` (presumably the
                channel counts of its two inputs -- TODO confirm).
        """
        super(DERIS, self).__init__()
        self.inbs = Hinet(in_1=in_1, in_2=in_2)
        # NOTE(review): unused by forward(); removing them would change this
        # module's state_dict keys, so saved checkpoints of the whole model
        # would no longer load strictly.
        self.pre_enhance = modules.rrdb_deno.ResidualDenseBlock_out(3, 3)
        self.post_enhance = modules.rrdb_deno.ResidualDenseBlock_out(3, 3)

    def load_hinet(self, path, map_location=None):
        """Load weights for the inner ``Hinet`` (``self.inbs``) from a file.

        Args:
            path: checkpoint path. The file must hold a dict whose ``'net'``
                entry is a state dict for ``self.inbs``. ``torch.load`` may
                unpickle objects, so only load trusted checkpoints.
            map_location: forwarded to ``torch.load``, e.g. ``'cpu'`` to load
                a GPU-saved checkpoint on a machine without CUDA. ``None``
                (the default) restores tensors to the device they were saved on.

        Raises:
            KeyError: if the checkpoint has no ``'net'`` entry.
            RuntimeError: from ``load_state_dict`` (strict) on missing or
                unexpected keys.
        """
        state_dicts = torch.load(path, map_location=map_location)
        # Pass the state dict straight through: copying it into a plain dict
        # would drop the ``_metadata`` attribute that load_state_dict uses for
        # versioned loading. (An earlier revision filtered out keys containing
        # 'tmp_var' here; reinstate that if such checkpoints reappear.)
        self.inbs.load_state_dict(state_dicts['net'])

    def forward(self, x, rev=False):
        """Run ``self.inbs`` on ``x``.

        When ``rev`` is truthy, ``rev=True`` is passed through (presumably the
        reverse pass of the invertible network); otherwise ``self.inbs`` is
        called without a ``rev`` argument.
        """
        if rev:
            return self.inbs(x, rev=True)
        return self.inbs(x)


def init_model(mod, scale=None):
    """Re-initialise the trainable parameters of ``mod`` in place.

    Every parameter with ``requires_grad`` whose name does not contain
    ``'alpha'`` is set to ``scale * N(0, 1)`` noise. Parameters whose owning
    submodule is named ``conv5`` are then zeroed. Frozen parameters and
    ``alpha`` parameters are left untouched.

    Args:
        mod: the ``nn.Module`` to initialise.
        scale: standard-deviation multiplier for the noise. Defaults to
            ``c.init_scale`` from the project config.
    """
    if scale is None:
        scale = c.init_scale
    with torch.no_grad():
        for key, param in mod.named_parameters():
            if not param.requires_grad or 'alpha' in key:
                continue
            # Draw on the CPU generator (as the previous ``randn(...).cuda()``
            # did) so seeded runs reproduce the same values, then write in place
            # on the parameter's own device/dtype instead of assuming CUDA.
            param.copy_(scale * torch.randn(param.shape).to(param.device))
            split = key.split('.')
            # Noise is still drawn above for conv5 so the RNG stream for later
            # parameters is unchanged. A root-level parameter has no owner
            # segment, hence the length guard.
            if len(split) >= 2 and split[-2] == 'conv5':
                param.zero_()
