import tensorflow as tf
import os
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.metrics import classification_report

from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns





# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Every image is resized to this (height, width) by the directory loaders.
img_size = (100, 100)
# Not used by the model code: the network is a binary classifier with a single
# sigmoid output unit (Dense(1, activation='sigmoid')).
num_classes = 1
batch_size = 32
num_epochs = 250

# Root directory that contains the RealTrainingSet, RealValidationSet and
# RealTestSet sub-folders. The Windows path is only the default; set the
# REAL_DATA_DIR environment variable to point the script at another location
# without editing the source.
base_dir = os.environ.get(
    'REAL_DATA_DIR',
    r'C:\Users\MSI\Documents\Research\RealData',
)

train_dir = os.path.join(base_dir, 'RealTrainingSet')
val_dir = os.path.join(base_dir, 'RealValidationSet')
test_dir = os.path.join(base_dir, 'RealTestSet')


# One generator per split. Each only multiplies pixel values by 1/255;
# no augmentation is configured.
train_data_gen = ImageDataGenerator(rescale=1.0 / 255.0)
val_data_gen = ImageDataGenerator(rescale=1.0 / 255.0)
test_data_gen = ImageDataGenerator(rescale=1.0 / 255.0)


def _flow_split(generator, directory, **overrides):
    """Build a grayscale, binary-label directory iterator for one data split.

    The shared settings (target size, batch size, binary labels, grayscale)
    are applied to every split; ``overrides`` are forwarded on top of them to
    ``flow_from_directory`` (used to turn off shuffling for the test split).
    """
    settings = dict(
        target_size=img_size,
        batch_size=batch_size,
        class_mode='binary',
        color_mode='grayscale',
    )
    settings.update(overrides)
    return generator.flow_from_directory(directory, **settings)


# Training and validation iterators keep flow_from_directory's default
# shuffling.
train_data = _flow_split(train_data_gen, train_dir)
val_data = _flow_split(val_data_gen, val_dir)
# The test iterator is not shuffled, so predictions come back in the same
# order as test_data.classes / test_data.labels.
test_data = _flow_split(test_data_gen, test_dir, shuffle=False)



# Feature extractor: five stages, each a 3x3 "same"-padded ReLU convolution
# followed by 2x2 max-pooling, with the filter count doubling per stage.
# Only the first convolution declares the input shape: single-channel
# (grayscale) images of size img_size.
model = Sequential()

for stage, filters in enumerate((8, 16, 32, 64, 128)):
    conv_options = {'padding': "same", 'activation': 'relu'}
    if stage == 0:
        conv_options['input_shape'] = (img_size[0], img_size[1], 1)
    model.add(Conv2D(filters, (3, 3), **conv_options))
    model.add(MaxPooling2D((2, 2)))

model.add(Flatten())

# Classifier head: shrinking ReLU layers ending in one sigmoid unit, which
# gives the predicted probability of label 1 for the binary targets.
for units in (128, 64, 8):
    model.add(Dense(units, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

model.summary()


# Binary cross-entropy pairs with the single sigmoid output and the 0/1
# labels produced by class_mode='binary'.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Save the full model (not just the weights) to disk whenever validation
# accuracy reaches a new maximum. Only the best epoch so far is kept.
checkpoint_settings = dict(
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    save_weights_only=False,
    verbose=1,
)
checkpoint_callback = ModelCheckpoint(filepath='best_model_wtrl.h5', **checkpoint_settings)

# Train with the checkpoint callback. Each epoch makes exactly one full pass
# over the training iterator and the validation iterator.
history = model.fit(
    train_data,
    epochs=num_epochs,
    steps_per_epoch=len(train_data),
    validation_data=val_data,
    validation_steps=len(val_data),
    callbacks=[checkpoint_callback],
)

# Reload the checkpoint with the highest validation accuracy. This may come
# from an earlier epoch than the weights left in `model` at the end of fit().
best_model = tf.keras.models.load_model('best_model_wtrl.h5')

# Evaluate once, over exactly one pass of the (unshuffled) test iterator.
loss, accuracy = best_model.evaluate(test_data, steps=len(test_data))
print(f'Test Loss: {loss:.4f}')
print(f'Test Accuracy: {accuracy:.4f}')

# Class names ordered by their integer label, as assigned by
# flow_from_directory from the test split's sub-folder names. Using these
# keeps the report and heatmap labels aligned with the real label order.
class_names = sorted(test_data.class_indices, key=test_data.class_indices.get)
label_ids = list(range(len(class_names)))

# Run the model on the test set once and reuse the result. test_data was
# created with shuffle=False, so prediction i belongs to test_data.classes[i].
# predict() returns one column for the single sigmoid unit; ravel() flattens
# it to a 1-D vector for scikit-learn.
true_labels = test_data.classes
predicted_probabilities = best_model.predict(test_data, steps=len(test_data)).ravel()
# A probability strictly above 0.5 is classified as label 1.
predicted_labels = (predicted_probabilities > 0.5).astype(int)

# Per-class precision / recall / F1. Passing `labels` keeps the rows fixed
# even if one class never appears in the predictions.
report = classification_report(
    true_labels, predicted_labels, labels=label_ids, target_names=class_names
)
print("Classification Report:")
print(report)

# Names used by the original confusion-matrix code, kept for compatibility.
y_true = true_labels
y_pred = predicted_labels

# Fixing `labels` keeps the matrix square and its rows/columns in class_names
# order even when a class is absent from the test set or the predictions.
cm = confusion_matrix(y_true, y_pred, labels=label_ids)

# Draw the confusion matrix as an annotated heatmap
# (rows = actual class, columns = predicted class).
plt.figure(figsize=(6, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()

# Plot the per-epoch curves recorded by model.fit(): one figure for loss and
# one for accuracy, each comparing the training and validation values.
# matplotlib.pyplot is already imported as plt at the top of the file.
for metric, title, y_label in (
    ('loss', 'Model Loss', 'Loss'),
    ('accuracy', 'Model Accuracies', 'Accuracy'),
):
    # Start a fresh figure each time. With a non-interactive backend,
    # plt.show() does not close the current figure, so without this the
    # curves would be drawn on top of the previous plot.
    plt.figure()
    plt.plot(history.history[metric])
    plt.plot(history.history[f'val_{metric}'])
    plt.title(title)
    plt.ylabel(y_label)
    plt.xlabel('Epoch')
    plt.legend(['train', 'validation'], loc='best')
    plt.show()
