import tensorflow as tf
import os
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Dropout

from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns



# Training schedule: how many passes over the data and how many images per step.
num_epochs = 50
batch_size = 32

# num_classes is not referenced in the visible code; presumably it mirrors the
# single sigmoid output unit of the binary classifier below — TODO confirm.
num_classes = 1

# Target (height, width) every image is resized to before entering the network.
img_size = (100, 100)


# Root directory of the artificial grayscale dataset. The default is the
# original author's Windows path; set the ARTIFICIAL_DATA_DIR environment
# variable to point the script at the data on another machine without editing it.
base_dir = os.environ.get(
    'ARTIFICIAL_DATA_DIR', r'C:\Users\MSI\Documents\Research\ArtificialData'
)

# One sub-directory per split (training / validation / test).
train_dir = os.path.join(base_dir, 'ArtificialTrainingSetGray')
val_dir = os.path.join(base_dir, 'ArtificialValidationSetGray')
test_dir = os.path.join(base_dir, 'ArtificialTestSetGray')

# One rescaling generator per split; pixel values are multiplied by 1/255
# (presumably 8-bit images, giving a [0, 1] range — TODO confirm).
train_data_gen, val_data_gen, test_data_gen = (
    ImageDataGenerator(rescale=1.0 / 255.0) for _ in range(3)
)


def _flow_split(generator, directory, **overrides):
    """Return a batch iterator over *directory* with the shared split settings.

    Every split is read as single-channel (grayscale) images resized to
    ``img_size``, batched by ``batch_size``, with binary labels. Extra keyword
    arguments (e.g. ``shuffle``) are forwarded to ``flow_from_directory``.
    """
    return generator.flow_from_directory(
        directory,
        target_size=img_size,
        batch_size=batch_size,
        color_mode='grayscale',
        class_mode='binary',
        **overrides,
    )


train_data = _flow_split(train_data_gen, train_dir)
val_data = _flow_split(val_data_gen, val_dir)
# The test iterator keeps a fixed file order (shuffle=False) so that model
# predictions can later be aligned with test_data.classes.
test_data = _flow_split(test_data_gen, test_dir, shuffle=False)


model = Sequential()

# Feature extractor: five stages of 3x3 'same'-padded ReLU convolution followed
# by 2x2 max-pooling, with the filter count doubling at each stage. Only the
# first convolution declares the (height, width, 1) grayscale input shape.
for _stage, _filters in enumerate((8, 16, 32, 64, 128)):
    _input_kwargs = (
        {'input_shape': (img_size[0], img_size[1], 1)} if _stage == 0 else {}
    )
    model.add(Conv2D(_filters, (3, 3), padding="same", activation='relu', **_input_kwargs))
    model.add(MaxPooling2D((2, 2)))

# Classifier head: two ReLU dense layers, dropout, then one sigmoid unit that
# outputs the probability of class 1.
model.add(Flatten())
for _units in (128, 64):
    model.add(Dense(_units, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(1, activation='sigmoid'))

model.summary()


# Binary cross-entropy pairs with the single sigmoid output; Adam uses a small
# fixed step size.
learning_rate = 0.0001
model.compile(
    optimizer=Adam(learning_rate=learning_rate),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

# Write the whole model (not just weights) to best_model.h5 each time the
# validation accuracy reaches a new maximum.
checkpoint_callback = ModelCheckpoint(
    'best_model.h5',
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    save_weights_only=False,
    verbose=1,
)

# One epoch = one full pass over each iterator (len() is its batch count).
history = model.fit(
    train_data,
    validation_data=val_data,
    epochs=num_epochs,
    steps_per_epoch=len(train_data),
    validation_steps=len(val_data),
    callbacks=[checkpoint_callback],
)

# Restore the checkpoint written by ModelCheckpoint (the epoch with the best
# val_accuracy) rather than using the final-epoch weights.
best_model = tf.keras.models.load_model('best_model.h5')

# Score the restored model on the held-out test split and report both metrics.
loss, accuracy = best_model.evaluate(x=test_data)
print(f'Test Loss: {loss:.4f}')
print(f'Test Accuracy: {accuracy:.4f}')


# Ground-truth labels in directory order. They line up with the predictions
# below only because test_data was created with shuffle=False.
y_true = test_data.classes

# Threshold the sigmoid scores at 0.5 (strictly greater, like Keras' binary
# accuracy) and flatten the (N, 1) output into a 1-D integer label vector.
y_pred = (best_model.predict(test_data) > 0.5).astype(int).ravel()

# labels=[0, 1] keeps the matrix 2x2 even if one class never occurs in
# y_true or y_pred, so it always matches the two tick labels.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

# Use the real class (sub-directory) names, ordered by their label index.
class_names = sorted(test_data.class_indices, key=test_data.class_indices.get)

# Display the confusion matrix as a heatmap.
plt.figure(figsize=(6, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()

# Plot the train-vs-validation curve for each tracked metric. Each metric gets
# its own fresh figure: on non-interactive backends plt.show() does not close
# the previous figure, so plotting without plt.figure() would draw onto
# whatever figure is still current (e.g. the confusion-matrix heatmap).
for _metric, _title, _ylabel in (
    ('accuracy', 'Model Accuracies', 'Accuracy'),
    ('loss', 'Model Loss', 'Loss'),
):
    plt.figure()
    plt.plot(history.history[_metric])
    plt.plot(history.history['val_' + _metric])
    plt.title(_title)
    plt.ylabel(_ylabel)
    plt.xlabel('Epoch')
    plt.legend(['train', 'validation'], loc='best')
    plt.show()

