

# Define the denoising module

import modules.module_util as mutil
import math
import copy
from functools import partial
from collections import OrderedDict
from typing import Optional, Callable

from torch import Tensor
from torch.nn import functional as F
import torch
import torch.nn as nn
from torch.nn import functional as F
from collections import OrderedDict

def _make_divisible(ch, divisor=8, min_ch=None):

    if min_ch is None:
        min_ch = divisor
    new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_ch < 0.9 * ch:
        new_ch += divisor
    return new_ch


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Randomly zero whole samples of ``x`` (stochastic depth) during training.

    One keep/drop decision is drawn per element along dim 0 and broadcast over
    all remaining dims. Kept samples are divided by the keep probability so the
    expected value of the output matches ``x``.

    Args:
        x: Input tensor; dim 0 is treated as the per-sample dimension.
        drop_prob: Probability of dropping a sample. ``None`` is treated like
            ``0.`` (``DropPath`` uses ``None`` as its constructor default).
        training: The input is returned untouched unless this is True.

    Returns:
        ``x`` itself when nothing is dropped, otherwise a new tensor with the
        same shape as ``x``.
    """
    if drop_prob is None or drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob       # retention rate
    # (N, 1, 1, ...) so the mask works for tensors of any rank, not just 2D ConvNets
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    if keep_prob > 0.:
        return x.div(keep_prob) * random_tensor
    # drop_prob == 1: the mask is all zeros; skip the rescale, which would
    # divide by zero and turn the output into NaN (inf * 0).
    return x * random_tensor


class DropPath(nn.Module):
    """Module wrapper around the ``drop_path`` function defined in this file.

    Args:
        drop_prob: Value forwarded unchanged to ``drop_path`` as its drop
            probability on every forward call. Callers are expected to pass a
            float; the default is ``None``.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # `self.training` is toggled by nn.Module.train()/eval().
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)

class ConvBNActivation(nn.Sequential):
    """Bias-free Conv2d followed by a normalization layer and an activation.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Square kernel size; padding is ``(kernel_size - 1) // 2``.
        stride: Convolution stride.
        groups: Convolution groups (set equal to the channel count for a
            depthwise convolution).
        norm_layer: Callable building the norm layer from ``out_planes``;
            ``nn.BatchNorm2d`` when None.
        activation_layer: Zero-argument callable building the activation;
            ``nn.SiLU`` when None.
    """

    def __init__(self,
                 in_planes: int,
                 out_planes: int,
                 kernel_size: int = 3,
                 stride: int = 1,
                 groups: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 activation_layer: Optional[Callable[..., nn.Module]] = None):
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        make_act = nn.SiLU if activation_layer is None else activation_layer

        # The conv carries no bias because a norm layer follows it directly.
        conv = nn.Conv2d(in_planes,
                         out_planes,
                         kernel_size,
                         stride=stride,
                         padding=(kernel_size - 1) // 2,
                         groups=groups,
                         bias=False)

        super().__init__(conv, make_norm(out_planes), make_act())


class SqueezeExcitation(nn.Module):
    """Squeeze-and-excitation channel gate.

    Args:
        input_c: Channel count of the enclosing block's input; the bottleneck
            width is ``input_c // squeeze_factor``.
        expand_c: Channel count of the tensor actually gated (the expanded
            width; the depthwise conv leaves it unchanged).
        squeeze_factor: Reduction factor applied to ``input_c``.
    """

    def __init__(self,
                 input_c: int,
                 expand_c: int,
                 squeeze_factor: int = 4):
        super().__init__()
        squeeze_c = input_c // squeeze_factor
        # 1x1 convolutions on a 1x1 map play the role of fully connected layers.
        self.fc1 = nn.Conv2d(expand_c, squeeze_c, kernel_size=1)
        self.ac1 = nn.SiLU()  # SiLU is also known as Swish
        self.fc2 = nn.Conv2d(squeeze_c, expand_c, kernel_size=1)
        self.ac2 = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        # Squeeze: global average pool, one value per channel.
        pooled = F.adaptive_avg_pool2d(x, output_size=(1, 1))
        # Excite: bottleneck, then a sigmoid gate per channel.
        gate = self.ac2(self.fc2(self.ac1(self.fc1(pooled))))
        return gate * x

class InvertedResidualConfig:
    """Hyper-parameters of one inverted-residual (DBConv) block.

    Both ``input_c`` and ``out_c`` are scaled by ``width_coefficient`` and
    passed through ``_make_divisible(..., 8)`` before being stored.

    Args:
        kernel: Depthwise kernel size.
        input_c: Unscaled number of input channels.
        out_c: Unscaled number of output channels.
        expanded_ratio: Multiplier applied to the *scaled* input channels to
            obtain ``expanded_c``.
        stride: Stride of the depthwise convolution.
        use_se: Whether the block contains a squeeze-and-excitation stage.
        drop_rate: Drop probability stored for the block.
        index: Free-form block label; stored as-is.
        width_coefficient: Channel-width multiplier.
    """

    def __init__(self,
                 kernel: int = 3,
                 input_c: int = 32,
                 out_c: int = 32,
                 expanded_ratio: int = 6,
                 stride: int = 1,
                 use_se: bool = True,
                 drop_rate: float = 0.2,
                 index: str = '1a',
                 width_coefficient: float = 1.0):
        scaled_in = self.adjust_channels(input_c, width_coefficient)
        scaled_out = self.adjust_channels(out_c, width_coefficient)

        # Channel layout
        self.input_c = scaled_in
        self.expanded_c = scaled_in * expanded_ratio
        self.out_c = scaled_out

        # Everything else is stored verbatim
        self.kernel = kernel
        self.stride = stride
        self.use_se = use_se
        self.drop_rate = drop_rate
        self.index = index

    @staticmethod
    def adjust_channels(channels: int, width_coefficient: float):
        """Scale ``channels`` by ``width_coefficient`` and round via ``_make_divisible`` with divisor 8."""
        scaled = channels * width_coefficient
        return _make_divisible(scaled, 8)


# DBConv module
class InvertedResidual(nn.Module):
    """Inverted-residual (DBConv) block.

    Stages, in order: optional 1x1 expansion, depthwise convolution, optional
    squeeze-and-excitation, and a 1x1 projection without non-linearity.

    Args:
        cnf: Block configuration (channels, kernel, stride, SE flag, drop rate).
        norm_layer: Callable building the normalization layer used by every
            convolution stage (e.g. ``nn.BatchNorm2d``).

    Raises:
        ValueError: If ``cnf.stride`` is neither 1 nor 2.
    """

    def __init__(self,
                 cnf: InvertedResidualConfig,
                 norm_layer: Callable[..., nn.Module]):
        super().__init__()

        if cnf.stride not in (1, 2):
            raise ValueError("illegal stride value.")

        # The shortcut is only valid when the block keeps both the spatial
        # size (stride 1) and the channel count.
        self.use_res_connect = cnf.stride == 1 and cnf.input_c == cnf.out_c

        act = nn.SiLU  # SiLU is also known as Swish
        stages = OrderedDict()  # insertion order defines execution order

        # Expansion is skipped when it would not change the channel count
        # (i.e. an expansion ratio of 1).
        if cnf.expanded_c != cnf.input_c:
            stages["expand_conv"] = ConvBNActivation(cnf.input_c,
                                                     cnf.expanded_c,
                                                     kernel_size=1,
                                                     norm_layer=norm_layer,
                                                     activation_layer=act)

        # Depthwise convolution: groups == channels.
        stages["dwconv"] = ConvBNActivation(cnf.expanded_c,
                                            cnf.expanded_c,
                                            kernel_size=cnf.kernel,
                                            stride=cnf.stride,
                                            groups=cnf.expanded_c,
                                            norm_layer=norm_layer,
                                            activation_layer=act)

        if cnf.use_se:
            stages["se"] = SqueezeExcitation(cnf.input_c, cnf.expanded_c)

        # Linear projection: nn.Identity stands in for "no activation".
        stages["project_conv"] = ConvBNActivation(cnf.expanded_c,
                                                  cnf.out_c,
                                                  kernel_size=1,
                                                  norm_layer=norm_layer,
                                                  activation_layer=nn.Identity)

        self.block = nn.Sequential(stages)
        self.out_channels = cnf.out_c
        self.is_strided = cnf.stride > 1

        # Stochastic depth only makes sense when a shortcut exists to fall
        # back on; otherwise the "dropout" stage is a no-op.
        wants_drop_path = self.use_res_connect and cnf.drop_rate > 0
        self.dropout = DropPath(cnf.drop_rate) if wants_drop_path else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        out = self.dropout(self.block(x))
        if self.use_res_connect:
            out += x

        return out




class CueBlock(nn.Module):
    """Generate an input-conditioned prompt map from a learned prompt bank.

    The input is pooled to one value per channel, mapped to ``prompt_len``
    softmax weights, and those weights blend the learned prompt components.
    The blended prompt is resized to the input's spatial size and passed
    through a 3x3 convolution.

    Args:
        prompt_dim: Channel count of each prompt component and of the output.
        prompt_len: Number of learned prompt components.
        prompt_size: Height and width of each stored prompt component.
        lin_dim: Input features of the weighting layer; must equal the channel
            count of the tensor passed to ``forward``.
    """

    def __init__(self, prompt_dim=3, prompt_len=5, prompt_size=96, lin_dim=3):
        super(CueBlock, self).__init__()
        self.prompt_param = nn.Parameter(torch.rand(1, prompt_len, prompt_dim, prompt_size, prompt_size))
        self.linear_layer = nn.Linear(lin_dim, prompt_len)
        self.conv3x3 = nn.Conv2d(prompt_dim, prompt_dim, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        """Return a prompt of shape ``(B, prompt_dim, H, W)`` for ``x`` of shape ``(B, C, H, W)``."""
        # Unpacking four values keeps the requirement that `x` is 4-D.
        _, _, H, W = x.shape
        emb = x.mean(dim=(-2, -1))                                 # (B, C)
        prompt_weights = F.softmax(self.linear_layer(emb), dim=1)  # (B, prompt_len)
        # (B, L, 1, 1, 1) * (1, L, D, S, S) broadcasts to (B, L, D, S, S);
        # the prompt bank does not need to be copied per batch element first.
        prompt = prompt_weights[:, :, None, None, None] * self.prompt_param
        prompt = torch.sum(prompt, dim=1)                          # (B, D, S, S)
        prompt = F.interpolate(prompt, (H, W), mode="bilinear")
        prompt = self.conv3x3(prompt)

        return prompt

class ResidualDenseBlock_out(nn.Module):
    """Densely connected conv block with a DBConv stage and an additive prompt.

    Five 3x3 convolutions are chained with dense concatenation of earlier
    feature maps. The first feature map additionally passes through an
    inverted-residual (DBConv) block, whose output feeds only ``conv2``;
    ``conv3``-``conv5`` consume the pre-DBConv feature map. A prompt generated
    from the ``conv5`` output is added to that output.

    Args:
        input: Number of input channels.
        output: Number of output channels.
        bias: Whether the five convolutions use a bias term.
        prompt_dim, prompt_len, prompt_size, lin_dim: Forwarded to ``CueBlock``.
            The prompt block runs on the ``conv5`` output, so ``lin_dim`` has
            to equal ``output`` and ``prompt_dim`` has to be addable to
            ``output`` channels.
    """

    def __init__(self, input, output, bias=True, prompt_dim=3, prompt_len=5, prompt_size=96, lin_dim=3):
        super().__init__()
        growth = 32  # channel width of every intermediate feature map

        self.conv1 = nn.Conv2d(input, growth, kernel_size=3, stride=1, padding=1, bias=bias)

        dbconv_cfg = InvertedResidualConfig(kernel=3, input_c=32, out_c=32, expanded_ratio=6, stride=1,
                                            use_se=True, drop_rate=0.2, index='1a', width_coefficient=1.0)
        self.dbconv = InvertedResidual(cnf=dbconv_cfg, norm_layer=nn.BatchNorm2d)

        # Each later conv sees the block input plus all earlier feature maps.
        self.conv2 = nn.Conv2d(input + growth, growth, kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv3 = nn.Conv2d(input + 2 * growth, growth, kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv4 = nn.Conv2d(input + 3 * growth, growth, kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv5 = nn.Conv2d(input + 4 * growth, output, kernel_size=3, stride=1, padding=1, bias=bias)
        self.lrelu = nn.LeakyReLU(inplace=True)

        # Prompt generation block
        self.prompt_block = CueBlock(prompt_dim, prompt_len, prompt_size, lin_dim)

        # Initialization of the last conv.
        # NOTE(review): presumably scales conv5's initial weights by 0 -- confirm in modules.module_util.
        mutil.initialize_weights([self.conv5], 0.)

    def forward(self, x):
        feat1 = self.lrelu(self.conv1(x))

        # DBConv refinement of the first feature map (consumed by conv2 only).
        refined = self.dbconv(feat1)

        feat2 = self.lrelu(self.conv2(torch.cat((x, refined), 1)))
        feat3 = self.lrelu(self.conv3(torch.cat((x, feat1, feat2), 1)))
        feat4 = self.lrelu(self.conv4(torch.cat((x, feat1, feat2, feat3), 1)))
        fused = self.conv5(torch.cat((x, feat1, feat2, feat3, feat4), 1))

        # Add the generated prompt to the dense-block output.
        return fused + self.prompt_block(fused)





