import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPooling2D, GlobalAveragePooling2D, Reshape, Dense, Multiply
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# self-attention mechanism layer
class SelfAttention(tf.keras.layers.Layer):
    """Self-attention over the spatial positions of a feature map.

    Three 1x1 convolutions (theta, phi, g) project the input to ``filters``
    channels. Every spatial position attends to every other position through a
    softmax over theta @ phi^T. The attended g-features are reshaped back to the
    input's spatial layout. A final 1x1 convolution (o) then projects them back
    to the input's channel count.

    Args:
        filters: Channel count of the intermediate theta/phi/g projections.
        **kwargs: Standard Keras ``Layer`` arguments (e.g. ``name``, ``dtype``).
    """

    def __init__(self, filters, **kwargs):
        # Forward Layer kwargs so the layer can be named and deserialized
        # (get_config() output is fed back through this constructor).
        super(SelfAttention, self).__init__(**kwargs)
        self.filters = filters

    def build(self, input_shape):
        """Create the projection sub-layers once the input channel count is known."""
        # All projections are 1x1, stride-1, bias-free convolutions.
        conv_kwargs = dict(kernel_size=1, strides=(1, 1), padding='same', use_bias=False)
        self.theta = Conv2D(self.filters, **conv_kwargs)
        self.phi = Conv2D(self.filters, **conv_kwargs)
        self.g = Conv2D(self.filters, **conv_kwargs)
        # Output projection restores the input's channel count.
        self.o = Conv2D(input_shape[-1], **conv_kwargs)
        super(SelfAttention, self).build(input_shape)

    def call(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Feature map consumed by 1x1 ``Conv2D`` layers; presumably
                (batch, height, width, channels), channels-last — confirm
                against the Keras image_data_format in use.

        Returns:
            Tensor with the same leading (batch/spatial) dimensions as ``x``
            and ``input_shape[-1]`` channels (from the ``o`` projection).
        """
        theta = self.theta(x)
        phi = self.phi(x)
        g = self.g(x)

        # Flatten the spatial grid: (batch, positions, filters).
        batch = tf.shape(x)[0]
        theta = tf.reshape(theta, [batch, -1, self.filters])
        phi = tf.reshape(phi, [batch, -1, self.filters])
        g = tf.reshape(g, [batch, -1, self.filters])

        # (batch, positions, positions) attention weights; softmax over the
        # last axis makes each row sum to 1.
        # NOTE(review): logits are not scaled by 1/sqrt(filters); with a large
        # `filters` the softmax may saturate — consider scaling if unstable.
        beta = tf.nn.softmax(tf.matmul(theta, phi, transpose_b=True))
        o = tf.matmul(beta, g)

        # Restore the spatial layout. The attended tensor has `filters`
        # channels, not the input's, so only the leading dims come from `x`.
        # Reshaping to tf.shape(x) fails whenever filters != input channels.
        out_shape = tf.concat([tf.shape(x)[:-1], [self.filters]], axis=0)
        o = tf.reshape(o, out_shape)

        return self.o(o)

    def get_config(self):
        """Return the layer config, including ``filters``, for serialization."""
        config = super(SelfAttention, self).get_config()
        config.update({'filters': self.filters})
        return config

# EfficientNet B7 backbone
def EfficientNetB7(input_shape=(224, 224, 3), weights='imagenet'):
    """Return a Keras EfficientNetB7 feature extractor without its classification head.

    Args:
        input_shape: Input image shape (height, width, channels).
        weights: Forwarded to ``tf.keras.applications.EfficientNetB7``.
            Use 'imagenet' (default, ImageNet-pretrained), None for random
            initialization, or a path to a weights file. Pretrained 'imagenet'
            weights expect 3-channel input.

    Returns:
        A ``tf.keras.Model`` whose output is the backbone's final feature map
        (built with ``include_top=False``).
    """
    return tf.keras.applications.EfficientNetB7(
        input_shape=input_shape,
        include_top=False,
        weights=weights,
    )

# Build the model
def build_model(input_shape=(224, 224, 3), num_classes=None):
    """Build an EfficientNetB7 classifier with a self-attention gate on its features.

    Pipeline:
        EfficientNetB7 backbone
        -> SelfAttention over the backbone feature map
        -> element-wise product with the original features
        -> global average pooling
        -> Dense(256, relu)
        -> Dense(num_classes, softmax).

    Args:
        input_shape: Input image shape passed to the backbone.
        num_classes: Size of the softmax output. Defaults to the module-level
            ``NUM_CLASSES`` constant, which must then be defined before calling.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping the backbone's input to class
        probabilities.
    """
    if num_classes is None:
        # Backward-compatible fallback: earlier versions always read this global.
        num_classes = NUM_CLASSES

    backbone = EfficientNetB7(input_shape=input_shape)
    features = backbone.output

    # The attention output has the same channel count as the features, so it
    # can be used as a multiplicative gate on them.
    attention = SelfAttention(filters=int(features.shape[-1]))(features)
    x = Multiply()([features, attention])

    x = GlobalAveragePooling2D()(x)
    x = Dense(256, activation='relu')(x)
    output = Dense(num_classes, activation='softmax')(x)

    return Model(inputs=backbone.input, outputs=output)

# Adjust input_shape to match your MRI image dimensions.
model = build_model(input_shape=(256, 256, 3))

# 'categorical_crossentropy' expects one-hot encoded labels; integer class
# labels would need 'sparse_categorical_crossentropy' instead.
model.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(),
    metrics=['accuracy'],
)

# NOTE(review): NUM_CLASSES (read by build_model), NUM_EPOCHS, train_dataset
# and val_dataset are not defined in the visible source — they must exist
# before this point or the script raises NameError.
model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=NUM_EPOCHS,
)
