import tensorflow.keras.backend as K
from tensorflow.keras import layers
from tensorflow.keras import models

def squeeze_excite_block(input_tensor, ratio=16):
    """Re-weight the channels of a feature map with a squeeze-and-excite gate.

    The input is global-average-pooled to one value per channel, passed
    through a two-layer Dense bottleneck (ReLU, then sigmoid) and the
    resulting per-channel weights are multiplied back onto the input.

    Args:
        input_tensor: Keras tensor whose last axis holds the channels (the
            channel count is read from ``K.int_shape(input_tensor)[-1]``).
        ratio: Reduction factor of the bottleneck; the hidden Dense layer has
            ``channels // ratio`` units, but never fewer than one.

    Returns:
        The output of ``Multiply`` applied to the input and the gate.

    Raises:
        ValueError: If the channel dimension is unknown or ``ratio`` is not
            positive.
    """
    channels = K.int_shape(input_tensor)[-1]
    if channels is None:
        raise ValueError(
            "squeeze_excite_block requires a statically known channel dimension"
        )
    if ratio <= 0:
        raise ValueError(f"ratio must be positive, got {ratio!r}")

    # With fewer channels than `ratio` the floor division yields 0, which is
    # not a valid Dense width; keep at least one unit in the bottleneck.
    reduced_units = max(1, int(channels // ratio))

    # Squeeze: collapse the spatial axes to a single descriptor per channel,
    # reshaped to (1, 1, channels) so it can broadcast over the feature map.
    gate = layers.GlobalAveragePooling2D()(input_tensor)
    gate = layers.Reshape((1, 1, channels))(gate)

    # Excite: bottleneck that maps the descriptor to per-channel weights in
    # (0, 1) via the final sigmoid.
    gate = layers.Dense(
        reduced_units,
        activation='relu',
        kernel_initializer='he_normal',
        use_bias=False,
    )(gate)
    gate = layers.Dense(
        channels,
        activation='sigmoid',
        kernel_initializer='he_normal',
        use_bias=False,
    )(gate)

    # Scale: apply the per-channel weights to the original feature map.
    return layers.Multiply()([input_tensor, gate])

def MobileNetV3BlockWithSE(input_tensor, filters, kernel_size, expand_ratio, stride, se_ratio):
    """Build one expand / depthwise / squeeze-excite / project block.

    Layer order: 1x1 Conv2D expansion -> BatchNorm -> ReLU, DepthwiseConv2D
    -> BatchNorm -> ReLU, squeeze-and-excite, 1x1 Conv2D projection ->
    BatchNorm (no activation). The input is added back to the result when
    ``stride == 1`` and the input already has ``filters`` channels.

    Args:
        input_tensor: Keras tensor to transform.
        filters: Number of output channels of the projection convolution.
        kernel_size: Kernel size passed to the depthwise convolution.
        expand_ratio: Multiplier applied to ``filters`` to obtain the width of
            the expansion convolution. May be fractional; the product is
            rounded to the nearest integer (minimum 1).
        stride: Stride passed to the depthwise convolution.
        se_ratio: Reduction ratio forwarded to ``squeeze_excite_block``.

    Returns:
        The output tensor of the block.
    """
    # Conv2D needs an integer channel count, but callers pass fractional
    # ratios such as 72 / 16, which make the product a float. Round instead of
    # truncating so float error (e.g. 87.999...) cannot drop a channel.
    expanded_filters = max(1, int(round(filters * expand_ratio)))

    # Expansion: 1x1 convolution that widens the channel dimension.
    x = layers.Conv2D(expanded_filters, 1, padding='same', use_bias=False)(input_tensor)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    # Depthwise convolution; this is the only layer that applies `stride`.
    x = layers.DepthwiseConv2D(kernel_size, strides=stride, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    # Squeeze and Excite block
    x = squeeze_excite_block(x, ratio=se_ratio)

    # Projection: 1x1 convolution down to `filters` channels, left linear
    # (BatchNorm only, no activation).
    x = layers.Conv2D(filters, 1, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)

    # Residual connection, only when the spatial size (stride 1) and the
    # channel count are unchanged so the two tensors can be added.
    if stride == 1 and input_tensor.shape[-1] == filters:
        x = layers.Add()([input_tensor, x])

    return x


def create_mobilenetv3_se(input_shape=(224, 224, 3), num_classes=1000):
    """Build a small MobileNetV3-style classifier with squeeze-and-excite blocks.

    The network is a stride-2 3x3 convolution stem (16 filters, BatchNorm,
    ReLU), three ``MobileNetV3BlockWithSE`` blocks, global average pooling and
    a softmax Dense classification layer.

    Args:
        input_shape: Shape of one input sample, without the batch axis.
        num_classes: Number of units in the final softmax layer.

    Returns:
        A Keras ``Model`` mapping the input tensor to class probabilities.
    """
    input_tensor = layers.Input(shape=input_shape)

    # Stem: the stride-2 convolution halves the spatial resolution.
    x = layers.Conv2D(16, 3, strides=(2, 2), padding='same', use_bias=False)(input_tensor)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    # MobileNetV3 blocks with SE.
    # Arguments: (tensor, filters, kernel_size, expand_ratio, stride, se_ratio).
    # NOTE(review): MobileNetV3BlockWithSE scales its `filters` argument by
    # `expand_ratio`, so these ratios yield 256, 108 and 88 expanded channels;
    # confirm those are the intended widths.
    x = MobileNetV3BlockWithSE(x, 16, 3, 16, 1, 4)
    x = MobileNetV3BlockWithSE(x, 24, 3, 72 / 16, 2, 4)
    x = MobileNetV3BlockWithSE(x, 24, 3, 88 / 24, 1, 4)

    # Classification head: pool each channel to a scalar, then softmax.
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(num_classes, activation='softmax')(x)

    # `models` comes from `from tensorflow.keras import models` at file level.
    return models.Model(inputs=input_tensor, outputs=x)

# Build the network with the default input shape and class count, then print
# its layer-by-layer summary. These statements are at module level, so they
# run on import as well as when the file is executed as a script.
model = create_mobilenetv3_se()
model.summary()
