import tensorflow as tf
import os
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.metrics import classification_report

from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Start fine-tuning from the previously trained network stored in best_model.h5.
pretrained_model = load_model('best_model.h5')

# Freeze the first 8 layers: their weights stay fixed while the rest is fine-tuned.
for frozen_layer in pretrained_model.layers[:8]:
    frozen_layer.trainable = False

# Branch off the output of the 5th-from-last layer; the last four layers of the
# pretrained network are dropped. The original author identified this as the
# flatten layer -- NOTE(review): confirm against the saved architecture.
x = pretrained_model.layers[-5].output

# New classification head: three ReLU layers narrowing 128 -> 64 -> 8, then a
# single sigmoid unit producing the binary-class probability.
head_spec = ((128, 'relu'), (64, 'relu'), (8, 'relu'), (1, 'sigmoid'))
for units, activation in head_spec:
    x = Dense(units, activation=activation)(x)

model = Model(inputs=pretrained_model.input, outputs=x)
model.summary()

# Small learning rate so fine-tuning makes gentle updates to the unfrozen layers.
learning_rate = 0.0001
model.compile(
    optimizer=Adam(learning_rate=learning_rate),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

# Image size for resizing and other hyper-parameters.
# NOTE(review): img_size presumably must match the input size best_model.h5 was trained on -- confirm.
img_size = (100, 100)
num_classes = 1  # not referenced elsewhere; the model head ends in a single sigmoid unit
batch_size = 32
num_epochs = 250

# Root directory where the real data is stored. It can be overridden with the
# REAL_DATA_DIR environment variable so the script also runs outside the
# original author's machine; the default is the original path.
base_dir = os.environ.get('REAL_DATA_DIR', r'C:\Users\MSI\Documents\Research\RealData')

train_dir = os.path.join(base_dir, 'RealTrainingSet')
val_dir = os.path.join(base_dir, 'RealValidationSet')
test_dir = os.path.join(base_dir, 'RealTestSet')

# One preprocessing configuration shared by every split, so train, validation
# and test images can never be preprocessed differently.
data_gen = ImageDataGenerator(rescale=1.0 / 255.0)


def _flow_split(directory, *, shuffle=True):
    """Return a batch iterator over the class sub-folders of *directory*.

    Every split is read with the same settings: resized to ``img_size``, loaded
    as grayscale, ``batch_size`` images per batch, and 0/1 labels
    (``class_mode='binary'``).

    Args:
        directory: Path to a split directory containing one sub-folder per class.
        shuffle: Whether to shuffle image order each epoch (True is the Keras default).
    """
    return data_gen.flow_from_directory(
        directory,
        target_size=img_size,
        batch_size=batch_size,
        color_mode='grayscale',
        class_mode='binary',
        shuffle=shuffle,
    )


train_data = _flow_split(train_dir)
val_data = _flow_split(val_dir)
# shuffle=False keeps the prediction order aligned with test_data.classes/labels,
# which the evaluation code relies on.
test_data = _flow_split(test_dir, shuffle=False)

# Save the model to disk whenever validation accuracy reaches a new maximum,
# so the checkpoint file always holds the best epoch seen so far.
checkpoint_path = "best_model_fine_tuned.h5"
checkpoint_callback = ModelCheckpoint(
    checkpoint_path,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=1,
)

# Fine-tune for num_epochs. Each epoch makes one full pass over the training
# split, then one over the validation split (len(...) = batches per pass).
history = model.fit(
    train_data,
    epochs=num_epochs,
    steps_per_epoch=len(train_data),
    validation_data=val_data,
    validation_steps=len(val_data),
    callbacks=[checkpoint_callback],
)

# Reload the checkpoint written by ModelCheckpoint, i.e. the epoch with the
# highest validation accuracy. checkpoint_path is reused instead of repeating
# the filename, so the two cannot drift apart.
best_model = tf.keras.models.load_model(checkpoint_path)

# Evaluate the fine-tuned model on the held-out test data.
loss, accuracy = best_model.evaluate(test_data, steps=len(test_data))
print(f"Test loss: {loss}, Test accuracy: {accuracy}")

# True labels and predicted probabilities for the test data. test_data was
# created with shuffle=False, so prediction order matches test_data.labels.
true_labels = test_data.labels
predicted_probabilities = best_model.predict(test_data, steps=len(test_data))

# Threshold the sigmoid outputs at 0.5 to get 0/1 labels. The model's final
# Dense(1) layer yields an (N, 1) array; ravel() flattens it to 1-D for sklearn.
predicted_labels = (predicted_probabilities > 0.5).astype(int).ravel()

# Generate the classification report.
report = classification_report(true_labels, predicted_labels)

# Print the classification report.
print("Classification Report:")
print(report)

# Reuse the predictions computed above instead of running a second full
# inference pass. This also guarantees the confusion matrix and the report are
# built from identical predictions.
y_true = test_data.classes
y_pred = predicted_labels

# Compute the confusion matrix.
cm = confusion_matrix(y_true, y_pred)

# Display the confusion matrix as a heatmap.
# NOTE(review): the tick labels assume test_data.class_indices maps
# {'Defected': 0, 'Normal': 1} (flow_from_directory assigns indices in
# alphabetical folder order) -- confirm against the actual folder names.
plt.figure(figsize=(6, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
            xticklabels=['Defected', 'Normal'], yticklabels=['Defected', 'Normal'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()

# matplotlib.pyplot is already imported as plt at the top of the file; the
# redundant mid-file re-import has been removed.


def _plot_history_metric(history_dict, metric, title, ylabel):
    """Plot the training and validation curves for *metric* and show the figure.

    Args:
        history_dict: The ``History.history`` dict returned by ``model.fit``; it
            must contain both ``metric`` and ``'val_' + metric`` keys.
        metric: Base metric name, e.g. ``'loss'`` or ``'accuracy'``.
        title: Figure title.
        ylabel: Y-axis label.
    """
    plt.plot(history_dict[metric])
    plt.plot(history_dict[f'val_{metric}'])
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('Epoch')
    plt.legend(['train', 'validation'], loc='best')
    plt.show()


_plot_history_metric(history.history, 'loss', 'Model Loss', 'Loss')
_plot_history_metric(history.history, 'accuracy', 'Model Accuracies', 'Accuracy')
