import torch
import torch.nn as nn
import config as c
from modules.rrdb_denselayer import ResidualDenseBlock_out

# Keyed variant: the second half is also scaled by k[0] and tile-shuffled by k[1].
class INV_block_k(nn.Module):
    """Affine coupling block whose second half is additionally transformed by a key.

    The input is split along dim 1 into a first half of ``in_1 * 4`` channels and a
    second half of ``in_2 * 4`` channels. In the forward direction:

    * the first half is shifted by ``φ(second half)``;
    * the second half is multiplied element-wise by ``k[0]`` and its 4x4 grid of
      spatial tiles is permuted by ``k[1]``;
    * the result is then scaled by ``e(ρ(·))`` and shifted by ``η(·)``, both
      evaluated on the updated first half.

    ``rev=True`` undoes these steps in the opposite order.

    Args:
        subnet_constructor: factory called as ``subnet_constructor(in_ch, out_ch)``
            to build each of the three coupling sub-networks.
        clamp: bound on the scale exponent; for positive ``clamp`` the factor
            returned by ``e`` lies in ``(exp(-clamp), exp(clamp))``.
        in_1: base channel count of the first half (multiplied by 4).
        in_2: base channel count of the second half (multiplied by 4).
    """

    def __init__(self, subnet_constructor=ResidualDenseBlock_out, clamp=c.clamp, in_1=3, in_2=3):
        super().__init__()

        # NOTE(review): the factor 4 presumably reflects one channel per wavelet
        # sub-band of the inputs — confirm against the caller.
        self.split_len1 = in_1 * 4
        self.split_len2 = in_2 * 4
        self.clamp = clamp

        # Attribute names are state_dict keys, and creation order fixes the order
        # of parameters(). Keep both stable for checkpoint/optimizer compatibility.
        len1, len2 = self.split_len1, self.split_len2
        self.r = subnet_constructor(len1, len2)  # ρ: scale pre-activation for the second half
        self.y = subnet_constructor(len1, len2)  # η: shift for the second half
        self.f = subnet_constructor(len2, len1)  # φ: shift for the first half

    def e(self, s):
        """Turn ``s`` into a positive multiplicative scale with a sigmoid-bounded exponent."""
        exponent = self.clamp * 2 * (torch.sigmoid(s) - 0.5)
        return torch.exp(exponent)

    def shuffle(self, x, k, rev=False):
        """Permute a 4x4 grid of spatial tiles of the 4-D tensor ``x``.

        Dims 2 and 3 of ``x`` are each cut into 4 equal slices, so both sizes must
        be divisible by 4. This yields 16 tiles numbered
        ``4 * (dim-2 slice) + (dim-3 slice)``. Output tile ``j`` is input tile
        ``k[j]``. With ``rev=True`` the permutation ``k.argsort()`` is used instead.
        When ``k`` is a permutation of ``0..15`` (assumed — TODO confirm with the
        key generator), this undoes a forward call.
        """
        batch, chans, dim2, dim3 = x.shape
        order = k.argsort() if rev else k

        # Slice dim 2 into 4 slabs and swap the two trailing axes. Then slice the
        # old dim 3 into 4 as well, so axis 2 enumerates all 16 tiles.
        tiles = x.reshape(batch, chans, 4, dim2 // 4, dim3).transpose(3, 4)
        tiles = tiles.reshape(batch, chans, 16, dim3 // 4, dim2 // 4)

        # Reorder the tiles, then invert the reshape/transpose steps above.
        tiles = tiles[:, :, order]
        restored = tiles.reshape(batch, chans, 4, dim3, dim2 // 4).transpose(3, 4)
        return restored.reshape(batch, chans, dim2, dim3)

    def forward(self, x, k, rev=False):
        """Apply the keyed coupling (``rev=False``) or its inverse (``rev=True``).

        Args:
            x: tensor whose dim 1 holds at least ``split_len1 + split_len2``
                channels; only those leading channels are read.
            k: indexable key. ``k[0]`` multiplies the second half element-wise
                before shuffling, so it must be non-zero for the inverse to be
                finite. ``k[1]`` is the tile permutation handed to ``shuffle``.
            rev: if true, run the inverse mapping.

        Returns:
            The two updated halves concatenated along dim 1.
        """
        first = x.narrow(1, 0, self.split_len1)
        second = x.narrow(1, self.split_len1, self.split_len2)
        scale_key, perm_key = k[0], k[1]

        if rev:
            # Undo the forward steps last-to-first: affine update, shuffle, key scale,
            # then the additive update of the first half.
            s1, t1 = self.r(first), self.y(first)
            out2 = (second - t1) / self.e(s1)
            out2 = self.shuffle(out2, perm_key, rev) / scale_key
            out1 = first - self.f(out2)
        else:
            out1 = first + self.f(second)
            keyed = self.shuffle(second * scale_key, perm_key, rev)
            s1, t1 = self.r(out1), self.y(out1)
            out2 = self.e(s1) * keyed + t1

        return torch.cat((out1, out2), 1)

# Un-keyed variant: plain affine coupling without the key scale/shuffle steps.
class INV_block(nn.Module):
    """Invertible affine coupling block (no key).

    The input is split along dim 1 into a first half of ``in_1 * 4`` channels and a
    second half of ``in_2 * 4`` channels. In the forward direction the first half is
    shifted by ``φ(second half)``. The second half is then scaled by ``e(ρ(·))`` and
    shifted by ``η(·)``, both evaluated on the updated first half. ``rev=True``
    applies the exact inverse.

    Args:
        subnet_constructor: factory called as ``subnet_constructor(in_ch, out_ch)``
            to build each of the three coupling sub-networks.
        clamp: bound on the scale exponent; for positive ``clamp`` the factor
            returned by ``e`` lies in ``(exp(-clamp), exp(clamp))``.
        in_1: base channel count of the first half (multiplied by 4).
        in_2: base channel count of the second half (multiplied by 4).
    """

    def __init__(self, subnet_constructor=ResidualDenseBlock_out, clamp=c.clamp, in_1=3, in_2=3):
        super().__init__()

        # NOTE(review): the factor 4 presumably reflects one channel per wavelet
        # sub-band of the inputs — confirm against the caller.
        self.split_len1 = in_1 * 4
        self.split_len2 = in_2 * 4
        self.clamp = clamp

        # Attribute names are state_dict keys, and creation order fixes the order
        # of parameters(). Keep both stable for checkpoint/optimizer compatibility.
        len1, len2 = self.split_len1, self.split_len2
        self.r = subnet_constructor(len1, len2)  # ρ: scale pre-activation for the second half
        self.y = subnet_constructor(len1, len2)  # η: shift for the second half
        self.f = subnet_constructor(len2, len1)  # φ: shift for the first half

    def e(self, s):
        """Turn ``s`` into a positive multiplicative scale with a sigmoid-bounded exponent."""
        exponent = self.clamp * 2 * (torch.sigmoid(s) - 0.5)
        return torch.exp(exponent)

    def forward(self, x, rev=False):
        """Apply the coupling (``rev=False``) or its inverse (``rev=True``).

        Args:
            x: tensor whose dim 1 holds at least ``split_len1 + split_len2``
                channels; only those leading channels are read.
            rev: if true, run the inverse mapping.

        Returns:
            The two updated halves concatenated along dim 1.
        """
        first = x.narrow(1, 0, self.split_len1)
        second = x.narrow(1, self.split_len1, self.split_len2)

        if rev:
            # Recover the second half first; it is needed to undo the shift of the first.
            s1, t1 = self.r(first), self.y(first)
            out2 = (second - t1) / self.e(s1)
            out1 = first - self.f(out2)
        else:
            out1 = first + self.f(second)
            s1, t1 = self.r(out1), self.y(out1)
            out2 = self.e(s1) * second + t1

        return torch.cat((out1, out2), 1)

