# In [None]:

import os
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import yaml
from typing import Dict, List, Tuple, Optional
import warnings
import time
from collections import defaultdict
from PIL import Image

# Deep Learning Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision import models
import torchvision.transforms.functional as F
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Faster R-CNN specific imports.
# Everything below depends on the torchvision detection API, so a failed
# import is reported with an installation hint and then re-raised instead of
# letting the script run on towards an unrelated-looking NameError.
try:
    import torchvision
    from torchvision.models.detection import fasterrcnn_resnet50_fpn
    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
    from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
    from torchvision.ops import nms
except ImportError as e:
    print(f"Error importing torchvision detection: {e}")
    print("Please install torchvision >= 0.3.0")
    raise
else:
    print("Torchvision detection modules imported successfully")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("Environment Setup Complete!")
print(f"PyTorch version: {torch.__version__}")
print(f"Torchvision version: {torchvision.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name()}")

class TurtleDataset(Dataset):
    """Custom dataset for turtle detection using Fast R-CNN format.

    Every ``*.jpg`` / ``*.png`` file in ``images_path`` that has a label file
    with the same stem in ``labels_path`` is one sample.  Label files use the
    YOLO layout (``class_id x_center y_center width height``, coordinates
    relative to the image size).  ``__getitem__`` returns ``(image, target)``
    where ``target`` holds ``boxes`` in absolute ``x_min, y_min, x_max, y_max``
    form, ``labels`` (class id + 1, because 0 is background), ``area``,
    ``iscrowd`` and ``image_id``.
    """

    def __init__(self, images_path, labels_path, transforms=None, num_classes=30):
        """Index all images that have a matching label file.

        Args:
            images_path: Directory containing the image files.
            labels_path: Directory containing the ``<stem>.txt`` label files.
            transforms: Optional callable applied to the PIL image only.
            num_classes: Stored on the instance; not read by this class.
        """
        self.images_path = Path(images_path)
        self.labels_path = Path(labels_path)
        self.transforms = transforms
        self.num_classes = num_classes

        # Sorted because glob order is arbitrary; a stable order keeps sample
        # indices (and anything derived from them) identical between runs.
        self.image_files = sorted(
            list(self.images_path.glob("*.jpg")) + list(self.images_path.glob("*.png"))
        )

        # Keep only images that have a corresponding label file.
        self.valid_images = [
            img_file
            for img_file in self.image_files
            if (self.labels_path / f"{img_file.stem}.txt").exists()
        ]

    def __len__(self):
        """Return the number of images that have a label file."""
        return len(self.valid_images)

    @staticmethod
    def _parse_yolo_labels(label_path, img_width, img_height):
        """Read one YOLO label file and return ``(boxes, labels)`` as lists.

        Boxes are converted to absolute Pascal VOC corners and clamped to the
        image.  Lines with fewer than five fields and boxes that end up with a
        non-positive width or height are skipped.
        """
        boxes = []
        labels = []

        with open(label_path, 'r') as f:
            for line in f:
                values = line.split()
                if len(values) < 5:
                    continue

                class_id = int(values[0])
                x_center = float(values[1]) * img_width
                y_center = float(values[2]) * img_height
                width = float(values[3]) * img_width
                height = float(values[4]) * img_height

                # Convert to Pascal VOC format (x_min, y_min, x_max, y_max),
                # clamped so a box never extends past the image border.
                x_min = max(x_center - width / 2, 0.0)
                y_min = max(y_center - height / 2, 0.0)
                x_max = min(x_center + width / 2, float(img_width))
                y_max = min(y_center + height / 2, float(img_height))

                # torchvision's detection models reject degenerate boxes.
                if x_max <= x_min or y_max <= y_min:
                    continue

                boxes.append([x_min, y_min, x_max, y_max])
                labels.append(class_id + 1)  # +1 because 0 is background

        return boxes, labels

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, target)``."""
        img_path = self.valid_images[idx]
        label_path = self.labels_path / f"{img_path.stem}.txt"

        # Load image; the context manager releases the file handle right away.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        img_width, img_height = image.size

        # The label file existed at construction time but may be gone by now.
        if label_path.exists():
            boxes, labels = self._parse_yolo_labels(label_path, img_width, img_height)
        else:
            boxes, labels = [], []

        # An image without objects is a negative sample: it gets an empty
        # (0, 4) box tensor rather than a fabricated box, so no made-up
        # ground truth is passed on to the detector.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        # Calculate area
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        # Assume all instances are not crowds
        iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels,
            "area": area,
            "iscrowd": iscrowd,
            "image_id": torch.tensor([idx])
        }

        if self.transforms:
            image = self.transforms(image)
        else:
            # Default transform
            transform = transforms.Compose([
                transforms.ToTensor()
            ])
            image = transform(image)

        return image, target

class TurtleDatasetHandler:
    """Collect image/label pairs below ``data_path`` and split them into folds.

    The expected layout is ``<data_path>/images`` for the image files and
    ``<data_path>/labels`` for same-stem ``.txt`` label files.
    """

    # Same extensions that TurtleDataset accepts, so both views of the data
    # directory agree on which files are samples.
    IMAGE_PATTERNS = ("*.jpg", "*.png")

    def __init__(self, data_path: str, num_classes: int = 30):
        self.data_path = Path(data_path)
        self.num_classes = num_classes
        self.images: List[str] = []
        self.labels: List[str] = []

    def prepare_dataset(self) -> Tuple[List[str], List[str]]:
        """Scan the dataset directory and return ``(images, labels)`` path lists.

        Missing ``images`` / ``labels`` directories are created empty.  The
        lists are rebuilt from scratch on every call and sorted by image path,
        so repeated calls do not accumulate duplicates and the seeded k-fold
        split sees the same order on every run.
        """
        # Assuming YOLO format dataset structure
        images_path = self.data_path / "images"
        labels_path = self.data_path / "labels"

        if not images_path.exists() or not labels_path.exists():
            print("Creating sample dataset structure...")
            self._create_sample_structure()

        image_files = sorted(
            img_file
            for pattern in self.IMAGE_PATTERNS
            for img_file in images_path.glob(pattern)
        )

        images = []
        labels = []
        for img_file in image_files:
            label_file = labels_path / f"{img_file.stem}.txt"
            if label_file.exists():
                images.append(str(img_file))
                labels.append(str(label_file))

        self.images = images
        self.labels = labels

        print(f"Found {len(self.images)} images with labels")
        return self.images, self.labels

    def _create_sample_structure(self) -> None:
        """Create the empty ``images`` and ``labels`` directories."""
        os.makedirs(self.data_path / "images", exist_ok=True)
        os.makedirs(self.data_path / "labels", exist_ok=True)
        print("Sample dataset structure created")

    def create_kfold_splits(self, k_folds: int = 5) -> List[Dict]:
        """Split the prepared image list into ``k_folds`` train/val partitions.

        Returns:
            One dict per fold with the keys ``fold`` (1-based), ``train`` and
            ``val`` (lists of image path strings).
        """
        kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)
        indices = np.arange(len(self.images))

        return [
            {
                'fold': fold + 1,
                'train': [self.images[i] for i in train_idx],
                'val': [self.images[i] for i in val_idx],
            }
            for fold, (train_idx, val_idx) in enumerate(kfold.split(indices))
        ]

def collate_fn(batch):
    """Regroup a batch of samples field by field for the detection DataLoader.

    A list of ``(image, target)`` pairs becomes ``(images, targets)``, each a
    tuple; nothing is stacked.  An empty batch yields an empty tuple.
    """
    per_field = zip(*batch)
    return (*per_field,)

class DirectionalWeightFastRCNN:
    """Train and evaluate a torchvision Faster R-CNN with per-pattern loss weights.

    The wrapper owns the model, its device and the training hyper-parameters.
    ``train_fold`` scales the individual loss terms returned by the model with
    the multipliers from ``calculate_directional_weights``; ``validate_fold``
    scores the model with a simplified precision/recall computation.
    """

    def __init__(self, num_classes=31, directional_weight_map=0.75):  # +1 for background
        """Store the configuration and build the model.

        Args:
            num_classes: Number of predictor outputs, background included.
            directional_weight_map: Stored on the instance; no method of this
                class reads it.
        """
        self.num_classes = num_classes
        self.directional_weight_map = directional_weight_map
        self.model = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.training_results = []

        # Training parameters optimized for turtle detection
        self.training_params = {
            'epochs': 50,
            'batch_size': 4,  # Smaller batch size for Fast R-CNN
            'learning_rate': 0.005,
            'momentum': 0.9,
            'weight_decay': 0.0005,
        }

        self.initialize_model()

    def initialize_model(self):
        """Initialize Fast R-CNN model with custom number of classes"""
        # Load pre-trained Faster R-CNN model
        self.model = fasterrcnn_resnet50_fpn(pretrained=True)

        # Replace the classifier head
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, self.num_classes)

        self.model.to(self.device)
        print(f"Initialized Fast R-CNN with {self.num_classes} classes")

    def calculate_directional_weights(self, pattern_type='plastron'):
        """Return the loss multipliers for ``pattern_type``.

        Unknown pattern types fall back to the neutral ``'standard'`` entry.
        Only the ``'pattern'`` and ``'box'`` multipliers are consumed by
        ``train_fold``; ``'spatial'`` and ``'confidence'`` are returned but
        not used inside this class.
        """
        weight_configs = {
            'plastron': {
                'spatial': 1.2,
                'confidence': 1.0,
                'pattern': 1.3,
                'box': 1.1
            },
            'nasal': {
                'spatial': 1.0,
                'confidence': 1.1,
                'pattern': 1.0,
                'box': 1.0
            },
            'infraorbital': {
                'spatial': 1.1,
                'confidence': 1.0,
                'pattern': 1.1,
                'box': 1.05
            },
            'standard': {
                'spatial': 1.0,
                'confidence': 1.0,
                'pattern': 1.0,
                'box': 1.0
            }
        }

        return weight_configs.get(pattern_type, weight_configs['standard'])

    def _build_loader(self, data_path, selected_images, image_transforms, shuffle):
        """Build a DataLoader over the dataset images listed in ``selected_images``."""
        images_path = Path(data_path) / "images"
        labels_path = Path(data_path) / "labels"

        dataset = TurtleDataset(images_path, labels_path, image_transforms, self.num_classes-1)

        # A set makes the membership test O(1) per image instead of O(n).
        wanted = set(selected_images)
        indices = [i for i, img_path in enumerate(dataset.valid_images)
                   if str(img_path) in wanted]
        subset = torch.utils.data.Subset(dataset, indices)

        return DataLoader(
            subset,
            batch_size=self.training_params['batch_size'],
            # The random sampler used for shuffling rejects an empty dataset,
            # so shuffling is only requested when there is something to shuffle.
            shuffle=shuffle and len(indices) > 0,
            collate_fn=collate_fn,
            num_workers=2
        )

    def create_data_loaders(self, train_images, val_images, data_path):
        """Create training and validation data loaders.

        Args:
            train_images: Image path strings that belong to the training split.
            val_images: Image path strings that belong to the validation split.
            data_path: Dataset root containing ``images`` and ``labels``.

        Returns:
            ``(train_loader, val_loader)``.  Either split may be empty.
        """
        # Data transforms
        train_transforms = transforms.Compose([
            transforms.ColorJitter(brightness=0.2, contrast=0.2, hue=0.1),
            transforms.ToTensor()
        ])

        val_transforms = transforms.Compose([
            transforms.ToTensor()
        ])

        train_loader = self._build_loader(data_path, train_images, train_transforms, shuffle=True)
        val_loader = self._build_loader(data_path, val_images, val_transforms, shuffle=False)

        return train_loader, val_loader

    @staticmethod
    def _combine_losses(loss_dict, weights):
        """Sum the model's loss terms after applying the pattern multipliers.

        torchvision's Faster R-CNN names its losses ``loss_classifier``,
        ``loss_box_reg``, ``loss_objectness`` and ``loss_rpn_box_reg``.  Both
        box regression terms are scaled by ``weights['box']``, the classifier
        term by ``weights['pattern']``; everything else is added unscaled.
        """
        total = 0
        for key, loss in loss_dict.items():
            if 'box_reg' in key:
                total = total + loss * weights['box']
            elif 'class' in key:
                total = total + loss * weights['pattern']
            else:
                total = total + loss
        return total

    def train_fold(self, train_images, val_images, data_path, fold_num, pattern_type='plastron'):
        """Train Fast R-CNN for one fold.

        Args:
            train_images: Image path strings of the training split.
            val_images: Image path strings of the validation split (only used
                to build the loaders; no validation happens here).
            data_path: Dataset root containing ``images`` and ``labels``.
            fold_num: Fold number used in log output and in the result record.
            pattern_type: Key for ``calculate_directional_weights``.

        Returns:
            Dict with ``fold``, ``pattern_type``, ``training_time`` (seconds)
            and ``weights``; the same dict is appended to
            ``self.training_results``.
        """
        print(f"\n{'='*50}")
        print(f"Training Fold {fold_num} with Fast R-CNN")
        print(f"{'='*50}")

        # Apply directional weights
        weights = self.calculate_directional_weights(pattern_type)

        # Create data loaders
        train_loader, val_loader = self.create_data_loaders(train_images, val_images, data_path)

        # Setup optimizer
        optimizer = optim.SGD(
            self.model.parameters(),
            lr=self.training_params['learning_rate'],
            momentum=self.training_params['momentum'],
            weight_decay=self.training_params['weight_decay']
        )

        # Learning rate scheduler
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

        start_time = time.time()

        # Training loop
        self.model.train()
        for epoch in range(self.training_params['epochs']):
            epoch_loss = 0.0
            num_batches = 0

            for images, targets in train_loader:
                if len(images) == 0:
                    continue

                # Move to device
                images = [img.to(self.device) for img in images]
                targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]

                # Forward pass: in training mode the model returns a loss dict
                loss_dict = self.model(images, targets)

                # Apply directional weights to losses
                weighted_losses = self._combine_losses(loss_dict, weights)

                # Backward pass
                optimizer.zero_grad()
                weighted_losses.backward()
                optimizer.step()

                epoch_loss += weighted_losses.item()
                num_batches += 1

            scheduler.step()

            if epoch % 10 == 0:
                avg_loss = epoch_loss / max(num_batches, 1)
                print(f"Epoch {epoch}/{self.training_params['epochs']}, Loss: {avg_loss:.4f}")

        training_time = time.time() - start_time

        # Store results
        fold_results = {
            'fold': fold_num,
            'pattern_type': pattern_type,
            'training_time': training_time,
            'weights': weights
        }

        self.training_results.append(fold_results)

        return fold_results

    def validate_fold(self, val_images, data_path, fold_num):
        """Run the model over the validation split and return its metrics.

        Only the validation loader is built here; no (empty) training loader
        is created as a by-product.
        """
        print(f"\nValidating Fold {fold_num} with Fast R-CNN...")

        val_transforms = transforms.Compose([
            transforms.ToTensor()
        ])
        val_loader = self._build_loader(data_path, val_images, val_transforms, shuffle=False)

        self.model.eval()
        all_predictions = []
        all_targets = []

        with torch.no_grad():
            for images, targets in val_loader:
                if len(images) == 0:
                    continue

                # Move to device
                images = [img.to(self.device) for img in images]

                # Get predictions
                predictions = self.model(images)

                # Keep results on the CPU so device memory is not held for
                # the whole validation set.
                all_predictions.extend(
                    {k: v.cpu() for k, v in pred.items()} for pred in predictions
                )
                all_targets.extend(targets)

        # Calculate metrics
        metrics = self.calculate_metrics(all_predictions, all_targets)

        return metrics

    def calculate_metrics(self, predictions, targets, iou_threshold=0.5, score_threshold=0.5):
        """Calculate simplified detection metrics.

        Predictions scoring above ``score_threshold`` are matched greedily, in
        the order given, against ground-truth boxes of the same label whose
        IoU exceeds ``iou_threshold``.  Each ground-truth box can be matched
        at most once, and targets labelled 0 (background) are ignored.

        Args:
            predictions: Per-image dicts with ``boxes``, ``labels``, ``scores``.
            targets: Per-image dicts with ``boxes`` and ``labels``.
            iou_threshold: Minimum IoU (exclusive) for a match.
            score_threshold: Minimum score (exclusive) for a prediction.

        Returns:
            Dict with ``mAP50``, ``mAP50-95``, ``precision``, ``recall`` and
            ``f1_score``.  The two mAP values are rough stand-ins derived
            from precision, not real average-precision computations.
        """
        total_predictions = 0
        total_targets = 0
        true_positives = 0

        for pred, target in zip(predictions, targets):
            # Filter predictions by confidence threshold
            high_conf_mask = pred['scores'].cpu() > score_threshold
            pred_boxes = pred['boxes'].cpu()[high_conf_mask]
            pred_labels = pred['labels'].cpu()[high_conf_mask].tolist()

            # Background placeholders are not real objects to be found.
            target_labels = target['labels'].cpu()
            foreground_mask = target_labels > 0
            target_boxes = target['boxes'].cpu()[foreground_mask]
            target_labels = target_labels[foreground_mask].tolist()

            total_predictions += len(pred_boxes)
            total_targets += len(target_boxes)

            # Simple IoU-based matching; a target already claimed by an
            # earlier prediction cannot produce another true positive.
            matched_targets = set()
            for pred_box, pred_label in zip(pred_boxes, pred_labels):
                for t_idx, target_label in enumerate(target_labels):
                    if t_idx in matched_targets or pred_label != target_label:
                        continue
                    if self.calculate_iou(pred_box, target_boxes[t_idx]) > iou_threshold:
                        matched_targets.add(t_idx)
                        true_positives += 1
                        break

        # Calculate metrics
        precision = true_positives / max(total_predictions, 1)
        recall = true_positives / max(total_targets, 1)
        f1_score = 2 * precision * recall / max(precision + recall, 1e-8)

        # Simplified mAP calculation
        mAP50 = precision  # Approximation
        mAP50_95 = precision * 0.8  # Approximation

        metrics = {
            'mAP50': mAP50,
            'mAP50-95': mAP50_95,
            'precision': precision,
            'recall': recall,
            'f1_score': f1_score
        }

        return metrics

    def calculate_iou(self, box1, box2):
        """Calculate Intersection over Union of two ``x_min, y_min, x_max, y_max`` boxes.

        Accepts tensors on any device (or plain sequences) and returns a
        Python float; boxes that do not overlap give ``0.0``.
        """
        # Plain floats work regardless of device or autograd state.
        ax1, ay1, ax2, ay2 = torch.as_tensor(box1).detach().cpu().tolist()
        bx1, by1, bx2, by2 = torch.as_tensor(box2).detach().cpu().tolist()

        # Calculate intersection
        x1 = max(ax1, bx1)
        y1 = max(ay1, by1)
        x2 = min(ax2, bx2)
        y2 = min(ay2, by2)

        if x2 <= x1 or y2 <= y1:
            return 0.0

        intersection = (x2 - x1) * (y2 - y1)

        # Calculate union
        area1 = (ax2 - ax1) * (ay2 - ay1)
        area2 = (bx2 - bx1) * (by2 - by1)
        union = area1 + area2 - intersection

        return intersection / max(union, 1e-8)

class PerformanceAnalyzer:
    """Collect per-fold metrics per model type and summarise, compare and plot them."""

    # Metric keys every fold record is expected to carry.
    METRICS = ('mAP50', 'mAP50-95', 'precision', 'recall', 'f1_score')

    def __init__(self):
        # model type -> list of per-fold metric dicts
        self.results = defaultdict(list)
        self.comparison_data = []

    def add_fold_results(self, fold_num, metrics, model_type='FastRCNN'):
        """Record the metrics of one fold under ``model_type``."""
        self.results[model_type].append({
            'fold': fold_num,
            **metrics
        })

    def calculate_statistics(self, model_type='FastRCNN'):
        """Return ``<metric>_mean`` / ``<metric>_std`` over the folds of ``model_type``.

        Returns an empty dict when nothing was recorded for ``model_type``.
        The lookup uses ``get`` so that asking about an unknown model type
        does not add an empty entry to the ``results`` defaultdict.
        """
        folds = self.results.get(model_type)
        if not folds:
            return {}

        stats = {}
        for metric in self.METRICS:
            values = [fold[metric] for fold in folds]
            stats[f'{metric}_mean'] = np.mean(values)
            stats[f'{metric}_std'] = np.std(values)

        return stats

    def compare_models(self, baseline_type='StandardFastRCNN', enhanced_type='DirectionalFastRCNN'):
        """Compare the mean metrics of two model types.

        Returns:
            Dict keyed by metric; each value holds ``baseline``, ``enhanced``,
            ``improvement`` (absolute difference) and ``improvement_pct``
            (relative to the baseline, 0 when the baseline mean is not
            positive).  Missing model types count as 0.
        """
        baseline_stats = self.calculate_statistics(baseline_type)
        enhanced_stats = self.calculate_statistics(enhanced_type)

        comparison = {}
        for metric in self.METRICS:
            baseline_mean = baseline_stats.get(f'{metric}_mean', 0)
            enhanced_mean = enhanced_stats.get(f'{metric}_mean', 0)

            improvement = enhanced_mean - baseline_mean
            improvement_pct = (improvement / baseline_mean * 100) if baseline_mean > 0 else 0

            comparison[metric] = {
                'baseline': baseline_mean,
                'enhanced': enhanced_mean,
                'improvement': improvement,
                'improvement_pct': improvement_pct
            }

        return comparison

    def plot_results(self, save_path='fastrcnn_performance_analysis.png'):
        """Draw one box plot per metric, save the figure to ``save_path`` and show it.

        Model types without any recorded fold are left out; when nothing has
        been recorded at all the method reports that and returns without
        creating a figure.
        """
        model_types = [name for name, folds in self.results.items() if folds]
        if not model_types:
            print("No results to plot.")
            return

        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        axes = axes.flatten()

        colors = ['lightblue', 'lightcoral', 'lightgreen']

        for ax, metric in zip(axes, self.METRICS):
            data_for_plot = [
                [fold[metric] for fold in self.results[model_type]]
                for model_type in model_types
            ]

            # Tick labels are set on the axis afterwards instead of through
            # boxplot's ``labels`` argument, which newer Matplotlib deprecates.
            bp = ax.boxplot(data_for_plot, patch_artist=True)
            ax.set_xticklabels(model_types)

            # Customize colors
            for patch, color in zip(bp['boxes'], colors):
                patch.set_facecolor(color)

            ax.set_title(f'{metric.upper()} Comparison')
            ax.set_ylabel(metric.upper())
            ax.grid(True, alpha=0.3)

        # Remove empty subplots
        for unused_ax in axes[len(self.METRICS):]:
            unused_ax.remove()

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    def generate_report(self):
        """Return a plain-text report: mean and std per model, plus a comparison.

        The comparison section is added when at least two model types are
        present and compares the second recorded type against the first.
        """
        rule = "=" * 80
        plus_minus = "\u00b1"

        lines = ["", rule, "FAST R-CNN TURTLE IDENTIFICATION - PERFORMANCE REPORT", rule]

        model_types = list(self.results.keys())
        for model_type in model_types:
            stats = self.calculate_statistics(model_type)

            lines.append("")
            lines.append(f"{model_type} Results:")
            lines.append("-" * 40)

            for metric in self.METRICS:
                mean_val = stats.get(f'{metric}_mean', 0)
                std_val = stats.get(f'{metric}_std', 0)
                lines.append(f"{metric.upper():<12}: {mean_val:.4f} {plus_minus} {std_val:.4f}")

        # Add comparison if multiple models
        if len(model_types) >= 2:
            comparison = self.compare_models(model_types[0], model_types[1])

            lines.append("")
            lines.append(f"Model Comparison ({model_types[1]} vs {model_types[0]}):")
            lines.append("-" * 50)

            for metric, comp in comparison.items():
                lines.append(
                    f"{metric.upper():<12}: {comp['improvement']:+.4f} ({comp['improvement_pct']:+.2f}%)"
                )

        lines.extend(["", rule, ""])

        return "\n".join(lines)

def run_kfold_experiment(data_path='./turtle_dataset', k_folds=5, num_classes=31,
                         pattern_type='plastron'):
    """Run complete K-fold cross-validation experiment with Fast R-CNN.

    Each fold is trained and validated twice: once with the neutral
    ``'standard'`` loss weights (recorded as ``StandardFastRCNN``) and once
    with the weights of ``pattern_type`` (recorded as ``DirectionalFastRCNN``).

    Args:
        data_path: Dataset root containing ``images`` and ``labels``.
        k_folds: Number of cross-validation folds.
        num_classes: Number of model classes including background.
        pattern_type: Weight profile used for the directional runs.

    Returns:
        The ``PerformanceAnalyzer`` holding one metric record per model type
        and fold.  A fold that raises is logged with its traceback and
        recorded with all-zero metrics.
    """
    # Local import: only needed for reporting failed folds.
    import traceback

    print("Starting Fast R-CNN K-Fold Cross-Validation Experiment")
    print("="*60)

    # Initialize components
    dataset_handler = TurtleDatasetHandler(data_path)
    analyzer = PerformanceAnalyzer()

    # Prepare dataset
    images, labels = dataset_handler.prepare_dataset()

    if len(images) == 0:
        print("No images found! Please ensure your dataset structure is correct.")
        return analyzer

    # Create K-fold splits
    splits = dataset_handler.create_kfold_splits(k_folds)

    # Run experiments for both standard and directional Fast R-CNN
    model_configs = [
        {'name': 'StandardFastRCNN', 'use_directional': False},
        {'name': 'DirectionalFastRCNN', 'use_directional': True}
    ]

    for config in model_configs:
        print(f"\nRunning {config['name']} experiment...")

        # Standard training uses the neutral weight profile.
        fold_pattern = pattern_type if config['use_directional'] else 'standard'

        for split in splits:
            fold_num = split['fold']

            # Initialize a fresh model so folds do not share trained weights
            model = DirectionalWeightFastRCNN(num_classes=num_classes)

            try:
                # Train
                model.train_fold(
                    split['train'], split['val'], data_path, fold_num, fold_pattern
                )

                # Validate
                metrics = model.validate_fold(split['val'], data_path, fold_num)

                # Add to analyzer
                analyzer.add_fold_results(fold_num, metrics, config['name'])

                print(f"Fold {fold_num} completed - mAP50: {metrics['mAP50']:.4f}")

            except Exception as e:
                print(f"Error in fold {fold_num}: {e}")
                # The one-line message alone hides where the fold failed.
                traceback.print_exc()
                # Add default metrics for failed fold
                default_metrics = {
                    'mAP50': 0.0, 'mAP50-95': 0.0, 'precision': 0.0,
                    'recall': 0.0, 'f1_score': 0.0
                }
                analyzer.add_fold_results(fold_num, default_metrics, config['name'])
            finally:
                # Drop the model before the next fold builds a new one so that
                # at most one model's memory is held at a time.
                del model
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    return analyzer

if __name__ == "__main__":
    # Run the experiment
    analyzer = run_kfold_experiment(data_path='./turtle_dataset', k_folds=5)

    # Generate and print report
    report = analyzer.generate_report()
    print(report)

    # Plot results; plotting problems must not prevent saving the CSV below
    try:
        analyzer.plot_results()
    except Exception as e:
        print(f"Error plotting results: {e}")

    # Save detailed results to CSV.  Each row is a copy of the fold record
    # with the model type added, so the analyzer's own records stay unchanged.
    all_results = [
        {**result, 'model_type': model_type}
        for model_type, results in analyzer.results.items()
        for result in results
    ]

    if all_results:
        df_results = pd.DataFrame(all_results)
        df_results.to_csv('turtle_fastrcnn_kfold_results.csv', index=False)
        print("\nResults saved to 'turtle_fastrcnn_kfold_results.csv'")
    else:
        print("\nNo results to save.")