import math

import torch.nn as nn
import torch.nn.functional as F

from .init_utils import weights_init

# Layer configurations keyed by VGG depth, consumed by VGG.make_layers.
# Each integer adds a 3x3 convolution with that many output channels
# (optionally followed by BatchNorm, then a ReLU); each 'M' adds a
# 2x2, stride-2 max pool.
defaultcfg = {
    8: [64, 'M',
        128, 'M',
        256, 'M',
        512, 'M',
        512],
    11: [64, 'M',
         128, 'M',
         256, 256, 'M',
         512, 512, 'M',
         512, 512],
    13: [64, 64, 'M',
         128, 128, 'M',
         256, 256, 'M',
         512, 512, 'M',
         512, 512],
    16: [64, 64, 'M',
         128, 128, 'M',
         256, 256, 256, 'M',
         512, 512, 512, 'M',
         512, 512, 512],
    19: [64, 64, 'M',
         128, 128, 'M',
         256, 256, 256, 256, 'M',
         512, 512, 512, 512, 'M',
         512, 512, 512, 512],
}


# class VGG(nn.Module):

#     def __init__(self,
#                  dataset='cifar10',
#                  depth=19,
#                  init_weights=True,
#                  cfg=None,
#                  affine=True,
#                  batchnorm=True):
#         super(VGG, self).__init__()
#         if cfg is None:
#             cfg = defaultcfg[depth]
#         self._AFFINE = affine
#         self.feature = self.make_layers(cfg, batchnorm)
#         self.dataset = dataset
#         if dataset == 'cifar10' or dataset == 'cinic-10':
#             num_classes = 10
#         elif dataset == 'cifar100':
#             num_classes = 100
#         elif dataset == 'tiny_imagenet':
#             num_classes = 200
#         else:
#             raise NotImplementedError('Unsupported dataset ' + dataset)
#         self.classifier = nn.Linear(cfg[-1], num_classes)
#         if init_weights:
#             self.apply(weights_init)
#         # if pretrained:
#         #     model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))

#     def make_layers(self, cfg, batch_norm=False):
#         layers = []
#         in_channels = 3
#         for v in cfg:
#             if v == 'M':
#                 layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
#             else:
#                 conv2d = nn.Conv2d(
#                     in_channels, v, kernel_size=3, padding=1, bias=False)
#                 if batch_norm:
#                     layers += [
#                         conv2d,
#                         nn.BatchNorm2d(v, affine=self._AFFINE),
#                         nn.ReLU(inplace=True)
#                     ]
#                 else:
#                     layers += [conv2d, nn.ReLU(inplace=True)]
#                 in_channels = v
#         return nn.Sequential(*layers)

#     def forward(self, x):
#         x = self.feature(x)
#         if self.dataset == 'tiny_imagenet':
#             x = nn.AvgPool2d(4)(x)
#         else:
#             x = nn.AvgPool2d(2)(x)
#         x = x.view(x.size(0), -1)
#         y = self.classifier(x)
#         return F.log_softmax(y, dim=1)

#     def _initialize_weights(self):
#         for m in self.modules():
#             if isinstance(m, nn.Conv2d):
#                 n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
#                 m.weight.data.normal_(0, math.sqrt(2. / n))
#                 if m.bias is not None:
#                     m.bias.data.zero_()
#             elif isinstance(m, nn.BatchNorm2d):
#                 if m.weight is not None:
#                     m.weight.data.fill_(1.0)
#                     m.bias.data.zero_()
#             elif isinstance(m, nn.Linear):
#                 m.weight.data.normal_(0, 0.01)
#                 m.bias.data.zero_()
# models/vgg.py

import torch.nn.functional as F  # 确保已导入

class VGG(nn.Module):
    """VGG-style convolutional classifier built from a layer config list.

    Integer entries of ``cfg`` add a 3x3 convolution (padding 1, no bias)
    with that many output channels, optionally followed by ``BatchNorm2d``,
    then a ReLU; each ``'M'`` entry adds a 2x2, stride-2 max pool. The
    feature maps are globally average-pooled and fed to one linear layer.

    Args:
        dataset: Dataset name, used to choose ``num_classes`` when that is
            not given explicitly. Supported: ``'cifar10'`` and ``'cinic-10'``
            (10 classes), ``'cifar100'`` (100), ``'tiny_imagenet'`` (200),
            ``'mini_imagenet'`` (100).
        depth: Key into ``defaultcfg``; only used when ``cfg`` is ``None``.
        num_classes: Explicit number of output classes; takes precedence
            over ``dataset``.
        init_weights: If true, apply ``weights_init`` to every submodule.
        cfg: Custom layer config in the format above; defaults to
            ``defaultcfg[depth]``. It may end with ``'M'``.
        affine: Passed as ``affine`` to every ``BatchNorm2d``.
        batchnorm: Whether to insert ``BatchNorm2d`` after each convolution.

    Raises:
        NotImplementedError: If ``num_classes`` is ``None`` and ``dataset``
            is not one of the supported names.
    """

    # Output class count for each dataset name accepted by ``__init__``.
    _DATASET_NUM_CLASSES = {
        'cifar10': 10,
        'cinic-10': 10,
        'cifar100': 100,
        'tiny_imagenet': 200,
        'mini_imagenet': 100,
    }

    def __init__(self,
                 dataset='cifar10',
                 depth=19,
                 num_classes=None,  # overrides the dataset-derived class count
                 init_weights=True,
                 cfg=None,
                 affine=True,
                 batchnorm=True):
        super().__init__()
        if cfg is None:
            cfg = defaultcfg[depth]
        self._AFFINE = affine  # read by make_layers for BatchNorm2d
        self.dataset = dataset
        self.feature = self.make_layers(cfg, batchnorm)

        if num_classes is not None:
            self.num_classes = num_classes
        elif dataset in self._DATASET_NUM_CLASSES:
            self.num_classes = self._DATASET_NUM_CLASSES[dataset]
        else:
            raise NotImplementedError(f'Unsupported dataset {dataset}')

        # Global pooling removes H and W, so the classifier consumes the
        # channel count of the last conv. Looking past trailing 'M' entries
        # lets a custom cfg end with a pool (cfg[-1] would be 'M'); with no
        # conv at all, the 3 input channels pass straight through.
        conv_channels = [v for v in cfg if v != 'M']
        in_features = conv_channels[-1] if conv_channels else 3
        self.classifier = nn.Linear(in_features, self.num_classes)

        if init_weights:
            self.apply(weights_init)

    def make_layers(self, cfg, batch_norm=False):
        """Build the convolutional feature extractor described by ``cfg``.

        As a side effect, (re)sets ``self.layer_indices`` to the 1-based
        positions of every max-pool layer inside the returned
        ``nn.Sequential``. ``forward`` uses these positions to decide which
        intermediate outputs to collect.

        Args:
            cfg: Layer config; integers are conv output channels and ``'M'``
                is a 2x2, stride-2 max pool.
            batch_norm: Whether to add ``BatchNorm2d`` after each conv.

        Returns:
            An ``nn.Sequential`` holding the layers in ``cfg`` order.
        """
        layers = []
        in_channels = 3  # first conv expects 3 input channels (presumably RGB)
        self.layer_indices = []
        for v in cfg:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                # len(layers) == 0-based index of this pool + 1, matching
                # the `idx + 1` test in forward.
                self.layer_indices.append(len(layers))
                continue
            layers.append(
                nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False))
            if batch_norm:
                layers.append(nn.BatchNorm2d(v, affine=self._AFFINE))
            layers.append(nn.ReLU(inplace=True))
            in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x, is_feat=False, preact=False):
        """Run the network on a batch.

        Args:
            x: Input batch. Presumably (N, 3, H, W), since the first conv has
                3 input channels; TODO confirm against callers.
            is_feat: If true, also return the intermediate feature maps.
            preact: Accepted for interface compatibility; ignored here.

        Returns:
            The logits of shape (N, num_classes). When ``is_feat`` is true,
            returns ``(features, logits)`` instead, where ``features`` lists
            the output of each max-pool layer in order.
        """
        taps = set(self.layer_indices)
        features = []
        for idx, layer in enumerate(self.feature):
            x = layer(x)
            if idx + 1 in taps:
                features.append(x)

        # Global average pooling makes the classifier independent of H and W.
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)
        y = self.classifier(x)
        return (features, y) if is_feat else y

    def _initialize_weights(self):
        """Alternative initializer (He-normal convs, N(0, 0.01) linears).

        Not called by ``__init__``, which applies ``weights_init`` instead.
        Conv weights ~ N(0, sqrt(2 / fan_in)), conv/linear biases are zeroed,
        and affine BatchNorm layers get weight 1 and bias 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                # affine=False BatchNorm has no weight/bias to initialize.
                if m.weight is not None:
                    m.weight.data.fill_(1.0)
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()




def vgg11(**kwargs):
    """Build a ``VGG`` with ``depth=11``; other keyword arguments pass through."""
    model = VGG(depth=11, **kwargs)
    return model

def vgg13(**kwargs):
    """Build a ``VGG`` with ``depth=13``; other keyword arguments pass through."""
    model = VGG(depth=13, **kwargs)
    return model

def vgg16(**kwargs):
    """Build a ``VGG`` with ``depth=16``; other keyword arguments pass through."""
    model = VGG(depth=16, **kwargs)
    return model

def vgg19(**kwargs):
    """Build a ``VGG`` with ``depth=19``; other keyword arguments pass through."""
    model = VGG(depth=19, **kwargs)
    return model

def vgg8(**kwargs):
    """Build a ``VGG`` with ``depth=8``; other keyword arguments pass through."""
    model = VGG(depth=8, **kwargs)
    return model
