import torch
import torch.nn as nn
import math
from torch.autograd import Variable

# Initialization function
def init_xavier(module):
    """Xavier-uniform initialise a layer's weight (gain 1) and zero its bias.

    Only modules whose exact type is nn.Linear, nn.Conv1d, nn.Conv2d or
    nn.Conv3d are touched; everything else (including subclasses of those
    types) is left unchanged. Intended to be passed to ``Module.apply``.
    """
    if type(module) not in (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d):
        return

    nn.init.xavier_uniform_(module.weight, gain=1)
    # Layers built with bias=False expose `bias` as None.
    if module.bias is not None:
        nn.init.zeros_(module.bias)

def init_relu(module):
    """Orthogonal initialisation scaled for ReLU-like activations, zero bias.

    Applies only to modules whose exact type is nn.Linear or nn.Conv1d/2d/3d;
    any other module is left as is. Intended to be passed to ``Module.apply``.
    """
    supported = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)
    if type(module) in supported:
        # 1.41421 ~= sqrt(2), the gain torch.nn.init.calculate_gain reports for ReLU.
        nn.init.orthogonal_(module.weight, gain=1.41421)
        bias = module.bias
        if bias is not None:
            nn.init.zeros_(bias)

def init_proj2d(module):
    """Initialise a projection layer as (as close as possible to) an identity map.

    * nn.Conv1d / nn.Conv2d / nn.Conv3d: weight filled by ``nn.init.dirac_``
      (groups=1), which preserves as many input channels as possible.
    * nn.Linear: weight filled by ``nn.init.eye_``. ``dirac_`` only accepts
      3-5 dimensional tensors and would raise on a 2-D Linear weight.

    The bias, when present, is zeroed. Other module types (including
    subclasses, since the exact type is checked) are left untouched.
    Intended to be passed to ``Module.apply``.
    """
    kind = type(module)
    if kind is nn.Linear:
        nn.init.eye_(module.weight)
    elif kind in (nn.Conv1d, nn.Conv2d, nn.Conv3d):
        torch.nn.init.dirac_(module.weight, groups=1)
    else:
        return

    if module.bias is not None:
        nn.init.zeros_(module.bias)

# impala_resnet
class DQN_Conv(nn.Module):
    """Single conv stage: Conv2d -> optional MaxPool2d -> optional norm -> activation.

    Args:
        in_hiddens: input channel count.
        hiddens: output channel count.
        ks, stride, padding: forwarded positionally to ``nn.Conv2d``.
        max_pool: if True, insert ``nn.MaxPool2d(3, 2, padding=1)`` after the conv.
        norm: if True, normalise with GroupNorm (32 groups) when ``hiddens`` is
            divisible by 32, otherwise with BatchNorm2d; both use eps=1e-6.
        init: callable applied to every submodule through ``Module.apply``.
        act: activation module appended last. The default instance is created
            once at definition time and shared by every DQN_Conv built without
            an explicit ``act``.
    """
    def __init__(self, in_hiddens, hiddens, ks, stride, padding=0, max_pool=False, norm=True, init=init_relu, act=nn.SiLU()):
        super().__init__()

        convolution = nn.Conv2d(in_hiddens, hiddens, ks, stride, padding)
        pooling = nn.MaxPool2d(3, 2, padding=1) if max_pool else nn.Identity()

        if not norm:
            normalization = nn.Identity()
        elif hiddens % 32 == 0:
            normalization = nn.GroupNorm(32, hiddens, eps=1e-6)
        else:
            # GroupNorm(32, C) needs C to be a multiple of 32.
            normalization = nn.BatchNorm2d(hiddens, eps=1e-6)

        self.conv = nn.Sequential(convolution, pooling, normalization, act)
        self.conv.apply(init)

    def forward(self, X):
        """Apply the conv -> pool -> norm -> activation pipeline to ``X``."""
        out = self.conv(X)
        return out

class Residual_Block(nn.Module):
    """Two 3x3 convolutions with a skip connection.

    The main path is
    conv3x3(stride) -> norm -> ``act`` -> conv3x3 -> norm -> ``out_act``,
    and the output is ``conv(X) + proj(X)``. ``proj`` is the identity unless
    the block downsamples (``stride > 1``) or changes the channel count, in
    which case it is a strided 1x1 conv initialised with ``init_proj2d``.

    Args:
        in_channels: channels of the input.
        channels: channels produced by both convs (and by the block).
        stride: stride of the first conv and of the projection.
        act: activation after the first conv.
        out_act: activation after the second conv, applied before the residual
            sum. When it is an ``nn.Identity`` (linear output), the second conv
            is initialised with ``init_xavier`` instead of ``init``.
        norm: if True, use GroupNorm (32 groups) when ``channels`` is divisible
            by 32, otherwise BatchNorm2d (eps=1e-6); if False, no normalisation.
        init: initialiser applied (via ``Module.apply``) to the first conv
            stage, and to the second unless ``out_act`` is an ``nn.Identity``.
    """
    def __init__(self, in_channels, channels, stride=1, act=nn.SiLU(), out_act=nn.SiLU(), norm=True, init=init_relu):
        super().__init__()

        def make_norm():
            # GroupNorm(32, C) requires C to be a multiple of 32; otherwise
            # fall back to BatchNorm2d.
            if not norm:
                return nn.Identity()
            if channels % 32 == 0:
                return nn.GroupNorm(32, channels, eps=1e-6)
            return nn.BatchNorm2d(channels, eps=1e-6)

        conv1 = nn.Sequential(nn.Conv2d(in_channels, channels, kernel_size=3, padding=1,
                                        stride=stride),
                              make_norm(),
                              act)
        conv2 = nn.Sequential(nn.Conv2d(channels, channels, kernel_size=3, padding=1),
                              make_norm(),
                              out_act)

        conv1.apply(init)
        # nn.Module defines no __eq__, so comparing against a freshly built
        # nn.Identity() is an identity test that is always "not equal". Check
        # the type so a linear (Identity) output really gets Xavier init.
        conv2.apply(init_xavier if isinstance(out_act, nn.Identity) else init)

        self.conv = nn.Sequential(conv1, conv2)

        # Skip path: identity unless the spatial size or channel count changes.
        self.proj = nn.Identity()
        if stride > 1 or in_channels != channels:
            self.proj = nn.Conv2d(in_channels, channels, kernel_size=1,
                                  stride=stride)

        self.proj.apply(init_proj2d)

    def forward(self, X):
        """Return ``conv(X) + proj(X)``."""
        Y = self.conv(X)
        Y = Y + self.proj(X)
        return Y

class IMPALA_Resnet(nn.Module):
    """IMPALA-style convolutional trunk made of three stages.

    Each stage (see ``get_block``) is a ``DQN_Conv`` (3x3 conv, stride 1,
    padding 1, followed by a 3x3/stride-2 max-pool, norm and ``act``) and two
    ``Residual_Block``s. Stage widths are 16, 32 and 32 channels, each
    multiplied by ``scale_width``.

    Args:
        first_channels: channel count of the input tensor; set it to match
            the input data.
        scale_width: integer multiplier applied to every stage width.
        norm: forwarded to every conv / residual block.
        init: weight initialiser forwarded to every conv / residual block.
        act: activation module used throughout. The same instance is passed to
            every block.
    """
    def __init__(self, first_channels=3, scale_width=1, norm=True, init=init_relu, act=nn.SiLU()):
        super().__init__()
        self.norm = norm
        self.init = init
        # When `act` is an nn.Module (the default), this assignment also
        # registers it as a submodule named "act".
        self.act = act

        self.cnn = nn.Sequential(self.get_block(first_channels, 16*scale_width),
                                 self.get_block(16*scale_width, 32*scale_width),
                                 self.get_block(32*scale_width, 32*scale_width, last_relu=True))

    def get_block(self, in_hiddens, out_hiddens, last_relu=False):
        """Build one stage: ``DQN_Conv`` followed by two ``Residual_Block``s.

        All inner activations use ``self.act``. The second residual block ends
        with ``self.act`` when ``last_relu`` is True and with ``nn.Identity()``
        (a linear output) otherwise.
        """
        blocks = nn.Sequential(
            DQN_Conv(in_hiddens, out_hiddens, 3, 1, 1, max_pool=True, act=self.act, norm=self.norm, init=self.init),
            # Pass out_act explicitly so a caller-chosen `act` is honoured here
            # too; otherwise Residual_Block's own default (nn.SiLU()) would apply.
            Residual_Block(out_hiddens, out_hiddens, norm=self.norm, act=self.act, init=self.init, out_act=self.act),
            Residual_Block(out_hiddens, out_hiddens, norm=self.norm, act=self.act, init=self.init,
                           out_act=self.act if last_relu else nn.Identity()),
        )

        return blocks

    def forward(self, X):
        """Run ``X`` through the three stages.

        ``X`` must be a valid ``nn.Conv2d`` input with ``first_channels``
        channels.
        """
        return self.cnn(X)