# Step 4 - Build models and make spatial predictions for the state of Oregon.
# The code was written and run in RStudio and assumes the working directory
# was set manually via:
#
# 'RStudio --> Session --> Set Working Directory --> Choose Directory'.


###-----------------------------------------------------------------------------
### Step 0: Environment setup
###-----------------------------------------------------------------------------


# Import necessary packages
library(caret)
library(caretEnsemble)
library(dplyr)
library(e1071)
library(gam)
library(ggplot2)
library(ggpubr)
library(import)
library(mboost)
library(MLmetrics)
library(nnet) 
library(parallel) 
library(pROC)
library(randomForest)
library(sf)
library(terra)
library(tictoc)
library(xgboost)

# Import data ----
# In every table the 0/1 `Status` column is recoded as a factor: "S" where
# Status == 1 and "U" otherwise. factor() sorts the levels, so "S" is the
# first level, which caret's twoClassSummary treats as the event class.

# Training data: the training set stacked with the "unsure" test set. The
# coordinate columns x/y are dropped so they are not used as predictors.
train_data <- rbind(
  readRDS("RDS_objects/train_data.rds"),
  readRDS("RDS_objects/test_data_unsure.rds")
) %>%
  mutate(Status = factor(ifelse(Status == 1, "S", "U"))) %>%
  dplyr::select(-c(x, y))
table(train_data$Status)

# Oregon test data. x/y are kept because they are used for plotting later.
test_data_all <- readRDS("RDS_objects/test_data_OR.rds") %>%
  mutate(Status = factor(ifelse(Status == 1, "S", "U")))

table(test_data_all$Status)

# Prediction grid: every non-NA cell of the Oregon predictor stack as one
# row, with the cell coordinates in columns x and y.
pred_data <- terra::rast("Oregon/x_stack_OR.tif") %>%
  as.data.frame(xy = TRUE, row.names = FALSE, na.rm = TRUE)


###-----------------------------------------------------------------------------
### Step 1: Model training
###-----------------------------------------------------------------------------


# Fix the RNG state so the resampling folds created below are reproducible.
set.seed(3531)

# Training controls shared by every base learner in the caretList below:
#   - folds: 5-fold cross-validation repeated 5 times, built once from the
#     response and passed via `index`, so every model is resampled on the
#     same splits;
#   - class probabilities are computed and summarised with twoClassSummary
#     (ROC / Sens / Spec);
#   - all hold-out predictions and all resample results are kept.
# NOTE(review): allowParallel only has an effect when a foreach parallel
# backend is registered; none is registered in the code shown here.
TrainingParameters <- trainControl(
  method = "repeatedcv",
  index = createMultiFolds(y = train_data$Status, k = 5, times = 5),
  classProbs = TRUE,
  summaryFunction = twoClassSummary,
  savePredictions = TRUE,
  returnResamp = "all",
  allowParallel = TRUE
)

# Train the base learners ----
# caretList() fits one caret model per entry of `methodList`. All models use
# the same data and resampling scheme (TrainingParameters), and their tuning
# parameters are chosen by the "ROC" metric. tic()/toc() report the elapsed
# time of the step.
tic("ensemble_model_list")
ensemble_model_list <- caretList(
  Status ~ .,
  data = train_data,
  trControl = TrainingParameters,
  metric = "ROC",
  methodList = c("nnet", "gamSpline", "knn", "rf", "svmLinear2", "xgbTree")
)
toc()

# Stack the base learners ----
# caretEnsemble() combines the models in ensemble_model_list into a single
# ensemble. It uses its own training controls (class probabilities on,
# twoClassSummary) and the "ROC" metric.
tic("ensemble_model")
ensemble_model <- caretEnsemble(
  ensemble_model_list,
  metric = "ROC",
  trControl = trainControl(
    classProbs = TRUE,
    savePredictions = TRUE,
    summaryFunction = twoClassSummary
  )
)
toc()

# Quick summary of training performance:
summary(ensemble_model)


###-----------------------------------------------------------------------------
### Step 2: Model testing 
###-----------------------------------------------------------------------------


###-------------------------------------
### Model performance summary function
###-------------------------------------


# Collect ROC-based performance statistics for one model.
#
# Args:
#   name:  label written to the `Model` column.
#   trues: observed classes, passed to pROC::roc() as the response.
#   preds: predicted scores (in this script, the probability of "S"), passed
#          to pROC::roc() as the predictor.
#   rnd:   number of decimal places used for AUROC and TSS.
#   ...:   further arguments forwarded to pROC::roc() (e.g. `levels`,
#          `direction`).
#
# Returns a one-row data.frame with columns Model, AUROC, TSS and Threshold.
# TSS = sensitivity + specificity - 1, evaluated at the threshold pROC reports
# as "best" (the default best.method, Youden's J = max(sens + spec)).
roc_stats <- function(name = "Model", trues, preds, rnd = 2, ...) {
  # NOTE(review): unless `levels`/`direction` are passed via `...`, pROC picks
  # them automatically (direction = "auto"). That flips the curve of a
  # worse-than-random model, so its AUROC is reported as >= 0.5. Confirm this
  # is acceptable.
  my_roc <- pROC::roc(trues, preds, quiet = TRUE, ...)
  best <- pROC::coords(my_roc, "best",
                       ret = c("threshold", "specificity", "sensitivity"))
  # Several thresholds can tie for the best Youden index (likely with coarse
  # scores such as kNN vote fractions). Keep only the first so the summary
  # stays one row per model.
  best <- best[1, , drop = FALSE]
  data.frame(Model = name,
             AUROC = round(as.numeric(my_roc$auc), rnd),
             TSS = round(best$sensitivity + best$specificity - 1, rnd),
             Threshold = best$threshold)
}


###------------------------------------
### Collect all testing areas summaries
###------------------------------------


# Predict every Oregon test point with each base model. caret's
# predict(type = "prob") returns a data.frame with one column per class;
# column "S" is used below.
# `$models` is spelled out in full: `$model` only worked through partial
# matching of list names.
test_ann_all <- predict(ensemble_model$models$nnet,
                        newdata = test_data_all, type = "prob")
test_gam_all <- predict(ensemble_model$models$gamSpline,
                        newdata = test_data_all, type = "prob")
test_knn_all <- predict(ensemble_model$models$knn,
                        newdata = test_data_all, type = "prob")
test_rf_all <- predict(ensemble_model$models$rf,
                       newdata = test_data_all, type = "prob")
test_svm_all <- predict(ensemble_model$models$svmLinear2,
                        newdata = test_data_all, type = "prob")
test_xgb_all <- predict(ensemble_model$models$xgbTree,
                        newdata = test_data_all, type = "prob")
# NOTE(review): the ensemble prediction is passed to roc_stats() as-is, so it
# is assumed to be a single numeric vector of probabilities. Which class it
# refers to depends on the caretEnsemble version. Confirm it is P("S") like
# the base models, because roc_stats()'s automatic ROC direction would hide
# a flip.
test_ensemble_all <- predict(ensemble_model, newdata = test_data_all,
                             type = "prob")

# Summary table of test performance statistics, one row per model:
test_summaries <- rbind(
  roc_stats("Ensemble", test_data_all$Status, test_ensemble_all),
  roc_stats("ANN", test_data_all$Status, test_ann_all$S),
  roc_stats("GAM", test_data_all$Status, test_gam_all$S),
  roc_stats("KNN", test_data_all$Status, test_knn_all$S),
  roc_stats("RF", test_data_all$Status, test_rf_all$S),
  roc_stats("SVM", test_data_all$Status, test_svm_all$S),
  roc_stats("XGBoost", test_data_all$Status, test_xgb_all$S)
)
test_summaries


###-----------------------------------------------------------------------------
### Step 3: Model statewide spatial predictions 
###-----------------------------------------------------------------------------


# Predict every cell of the statewide prediction grid (`pred_data`, not the
# test data) with each base model, giving class-probability data.frames, and
# with the stacked ensemble.
pred_ann <- predict(ensemble_model$models$nnet,
                    newdata = pred_data, type = "prob")
pred_gam <- predict(ensemble_model$models$gamSpline,
                    newdata = pred_data, type = "prob")
pred_knn <- predict(ensemble_model$models$knn,
                    newdata = pred_data, type = "prob")
pred_rf <- predict(ensemble_model$models$rf,
                   newdata = pred_data, type = "prob")
pred_svm <- predict(ensemble_model$models$svmLinear2,
                    newdata = pred_data, type = "prob")
pred_xgb <- predict(ensemble_model$models$xgbTree,
                    newdata = pred_data, type = "prob")
# NOTE(review): assumed to be a single numeric vector (see the test-set
# section); confirm which class it is the probability of.
pred_ensemble <- predict(ensemble_model, newdata = pred_data, type = "prob")

# Combine into one data frame: one row per grid cell, holding the cell
# coordinates and each model's predicted probability of "S". Columns are
# selected and named explicitly rather than by position.
pred_all <- data.frame(x = pred_data$x,
                       y = pred_data$y,
                       Ensemble = pred_ensemble,
                       ANN = pred_ann$S,
                       GAM = pred_gam$S,
                       KNN = pred_knn$S,
                       RF = pred_rf$S,
                       SVM = pred_svm$S,
                       XGBoost = pred_xgb$S)

# Plot setup ----

# viridis palette option shared by every continuous fill scale below.
viridis_option <- "magma"

# Presence points (Status == "S") from the Oregon test set, shown in the
# first panel of the figure.
occs <- filter(test_data_all, Status == "S")

# Oregon outline drawn on every map: the GADM level-1 file for the USA,
# converted to sf and reduced to the feature whose NAME_1 is "Oregon".
# NOTE(review): the "_pk" file is presumably a packed terra SpatVector;
# confirm st_as_sf() accepts the object readRDS() returns here.
border <- st_as_sf(readRDS("Geodata/gadm/gadm41_USA_1_pk.rds"))
border <- filter(border, NAME_1 == "Oregon")


###-------------------------------------
### Pretty plots
###-------------------------------------


# Occurrences: Oregon filled in antiquewhite with the test-set presence
# points on top. The border layer was previously added twice; the first,
# unfilled copy was fully hidden by the second.
gg_occs <- ggplot() +
  geom_sf(data = border, color = "black", fill = "antiquewhite") +
  geom_point(data = occs,
             aes(x = x, y = y, col = factor(Status)),
             shape = 1,
             alpha = 0.45) +
  guides(color = guide_legend(
    override.aes = list(size = 5, alpha = 1, lwd = 2)
  )) +
  theme_void() +
  scale_color_discrete(name = "Occurrence", labels = c("")) +
  # legend.title.align is deprecated since ggplot2 3.5.0; hjust on the
  # legend.title element has the same effect.
  theme(legend.title = element_text(hjust = 0.5))

# Build a map of one model's predicted probability of "S".
#   column:      name of the pred_all column mapped to fill; also used as
#                the legend title.
#   title_hjust: horizontal justification of the legend title.
# The border is drawn after the raster so the outline is not painted over.
plot_prediction <- function(column, title_hjust = 0.5) {
  ggplot() +
    geom_raster(data = pred_all,
                aes(x = x, y = y, fill = .data[[column]])) +
    geom_sf(data = border, color = "black", fill = NA) +
    scale_fill_viridis_c(option = viridis_option, name = column) +
    theme_void() +
    theme(legend.title = element_text(hjust = title_hjust))
}

gg_ensemble <- plot_prediction("Ensemble")
gg_ann <- plot_prediction("ANN")
gg_gam <- plot_prediction("GAM")
gg_knn <- plot_prediction("KNN")
gg_rf <- plot_prediction("RF")
gg_svm <- plot_prediction("SVM")
# NOTE(review): the original used 0.45 here and 0.5 everywhere else. It is
# kept as-is; confirm whether this is intentional.
gg_xgboost <- plot_prediction("XGBoost", title_hjust = 0.45)

# Combine all plots into a single 2 x 4 figure with legends at the bottom:
ggarrange(gg_occs, gg_ensemble, gg_ann, gg_gam,
          gg_knn, gg_rf, gg_svm, gg_xgboost,
          ncol = 4,
          nrow = 2, legend = "bottom")


###-----------------------------------------------------------------------------
### EOC
###-----------------------------------------------------------------------------
