{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"PyTorch_Sargassum_Detection.ipynb","provenance":[],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.8"}},"cells":[{"cell_type":"markdown","metadata":{"id":"ilg8UZ3N5w38"},"source":["# AlexNet for Sargassum Classification Pictures by using PyTorch\n","Author: Javier Arellano-Verdejo \n","Contact: javier.arellano@ecosur.mx & javier_arellano_verdejo@hotmail.com \n","Personal webpage: http://www.ecosur.mx/ecoconsulta/personal/persona.php?id=641&nombre=Javier%20Arellano%20Verdejo \n","Date: September/2020 \n","Version: 1.21.0208 \n","\n","---"]},{"cell_type":"markdown","metadata":{"id":"6Vi3UCes8L1r"},"source":["## Import libraries"]},{"cell_type":"code","metadata":{"id":"Ks8Gx8kcZ312"},"source":["# for general use\n","import matplotlib.pyplot as plt\n","import numpy as np\n","import os, shutil\n","import zipfile\n","import itertools\n","\n","# for PyTorch\n","import torch\n","import torch.nn.functional as F\n","\n","from torchvision import datasets, transforms, models\n","from torch import nn\n","\n","# for confusion matrix computing and classification report\n","from sklearn.metrics import confusion_matrix\n","from sklearn.metrics import classification_report\n","\n","# for web request\n","from urllib.request import urlretrieve"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HvcAzBo67-HO"},"source":["## Get device handler"]},{"cell_type":"code","metadata":{"id":"kdYDE4VoJ0c3"},"source":["device = torch.device('cuda:0' if torch.cuda.is_available() else \"cpu\")\n","\n","# If you get an out-of-memory error on the GPU use with CPU\n","#device = torch.device(\"cpu\")\n","print(device)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"s3MdMufHHuHD"},"source":["\n","## Download the dataset and unzip it to the local folder\n","IMPORTANT: It is also possible to manually download the dataset from the following link: \n","https://doi.org/10.6084/m9.figshare.13256174.v5"]},{"cell_type":"code","metadata":{"id":"Rdc91teCHtDY"},"source":["# figshare URL\n","download_url = \"https://ndownloader.figshare.com/files\"\n","\n","# figshare project id\n","project_id = \"26298643\"\n","\n","# dataset filename\n","fname = \"./sargassum_dataset.zip\"\n","\n","# path where the images will be decompressed\n","base_dir = '/content'\n","\n","# download the dataset\n","try:\n"," urlretrieve(download_url + \"/\" + project_id, fname)\n","\n"," # unzip the dataset\n"," zip_ref = zipfile.ZipFile(fname, 'r')\n"," zip_ref.extractall(base_dir)\n"," zip_ref.close()\n","\n","except:\n"," print(\"The file cannot be downloaded, try manually downloading the file from \" +\n"," \"https://doi.org/10.6084/m9.figshare.13256174.v5 and unzip it to the \" + \n"," \"directory {}\".format(base_dir))\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"5tP83tNp-hHa"},"source":["## Processes and loads the test and training dataset using augmented data"]},{"cell_type":"code","metadata":{"id":"nD7wVGp8aA4-"},"source":["transform_train = transforms.Compose([transforms.Resize((224,224)),\n"," transforms.RandomHorizontalFlip(),\n"," transforms.RandomRotation(10),\n"," transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),\n"," transforms.ColorJitter(brightness=1.0, contrast=1.0, saturation=1.0),\n"," transforms.ToTensor(),\n"," transforms.Normalize((0.5,),(0.5,))\n"," ])\n","\n","transform = transforms.Compose([transforms.Resize((224,224)),\n"," transforms.ToTensor(),\n"," transforms.Normalize((0.5,),(0.5,))\n"," ])\n","\n","training_dataset = datasets.ImageFolder(root=base_dir + '/sargassum_dataset/train', \n"," transform=transform_train)\n","\n","\n","validation_dataset = datasets.ImageFolder(root=base_dir + '/sargassum_dataset/val', \n"," transform=transform)\n","\n","\n","training_loader = torch.utils.data.DataLoader(dataset=training_dataset,\n"," batch_size=100,\n"," shuffle=True)\n","\n","validation_loader = torch.utils.data.DataLoader(dataset=validation_dataset,\n"," batch_size=100,\n"," shuffle=False)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"EhE9rekz-41F"},"source":["## Shows the size of the test and training data set as well as the classes"]},{"cell_type":"code","metadata":{"id":"dE9JcnBSKQp-"},"source":["print('Training dataset size: {}'.format(len(training_dataset)))\n","print('Validation dataset size: {}'.format(len(validation_dataset)))\n","classes = ['Without Sargassum','With Sargassum']\n","print('Classes: {}'.format(classes))"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"qN1q_K60_iIX"},"source":["## Support functions"]},{"cell_type":"code","metadata":{"id":"ELK4uT5M1U5E"},"source":["# Converts a tensor into an image\n","def im_convert(tensor):\n"," image = tensor.cpu().clone().detach().numpy()\n"," # the tensor has the following shape [3,224,224] \n"," # however for matplot to visualize it the shape \n"," # must be [224,224,3]. The transpose function \n"," # helps to relocate the order of the columns\n"," image = image.transpose(1,2,0)\n"," # I unnormalize the image x_new = x * std = mean \n"," # in our case the mean and the std are equal to \n"," # 0.5 the previous thing is because when we \n"," # normalized we made the interval go from -1.0 to 1.0\n"," image = image * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5))\n"," image = image.clip(0, 1)\n"," \n"," return image\n","\n","# Prints the confusion matrix\n","def plot_confusion_matrix(cm,\n"," classes,\n"," normalize=False,\n"," title='Confusion matrix',\n"," cmap=plt.cm.Blues):\n"," if normalize:\n"," cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n"," print(\"Normalize confusion matrix\")\n"," else:\n"," print(\"Confusion matrix, without normalization\")\n","\n"," print(cm)\n","\n"," fig, ax = plt.subplots()\n"," plt.imshow(cm, interpolation='nearest', cmap=cmap)\n"," plt.title(title)\n"," plt.colorbar()\n"," tick_marks = np.arange(len(classes))\n"," plt.xticks(tick_marks, classes, rotation=45)\n"," plt.yticks(tick_marks, classes)\n","\n"," fmt = '.2f' if normalize else 'd'\n"," thresh = cm.max() / 2\n","\n"," for i,j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n"," plt.text(j, i, format(cm[i, j], fmt),\n"," horizontalalignment='center',\n"," color='white' if cm[i,j] > thresh else 'black')\n"," \n"," plt.tight_layout()\n"," plt.ylabel('True label')\n"," plt.xlabel('Predicted label')\n"," plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"_FFe7RNyATOa"},"source":["## Show some images"]},{"cell_type":"code","metadata":{"id":"G1rGcNHW8Gsl"},"source":["dataiter = iter(training_loader)\n","images, labels = dataiter.next()\n","\n","fig = plt.figure(figsize=(25, 4))\n","\n","for idx in np.arange(20):\n"," ax = fig.add_subplot(2, 10, idx+1)\n"," plt.imshow(im_convert(images[idx]))\n"," ax.set_title(labels[idx].item())\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HvIOwuLjHbMt"},"source":["## Load and adapt the AlexNet model "]},{"cell_type":"code","metadata":{"id":"jT6s1CTo8RAP"},"source":["model = models.alexnet(pretrained=True)\n","print(model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ft2rLW3hySeD"},"source":["## The layers we don't want to retrain are locked up"]},{"cell_type":"code","metadata":{"id":"qbwLZQHJM7C-"},"source":["for param in model.features.parameters():\n"," param.requires_grad = False"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"S12423qfyqz3"},"source":["## The last layer is modified so that it only classifies two classes instead of 1000"]},{"cell_type":"code","metadata":{"id":"WL48d4tHy06f"},"source":["n_inputs = model.classifier[6].in_features\n","last_layer = nn.Linear(n_inputs, len(classes))\n","model.classifier[6] = last_layer\n","model.to(device)\n","\n","print(model)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nC2pI3yHBCvn"},"source":["## The loss function and the optimizer are defined"]},{"cell_type":"code","metadata":{"id":"GC_amGlkM739"},"source":["# definition of the error function. CrossEntropyLoss is the \n","# best option for the classification of multiple classes\n","criterion = nn.CrossEntropyLoss()\n","\n","# Definition of the method of optimisation\n","optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"WAB9KaCcN6Dc"},"source":["## Train the Neural Network"]},{"cell_type":"code","metadata":{"id":"dL2GSHhVNzSd"},"source":["epochs = 40\n","running_loss_history = []\n","running_corrects_history = []\n","\n","val_running_loss_history = []\n","val_running_corrects_history = []\n","\n","for e in range(epochs):\n"," running_loss = 0.0\n"," running_corrects = 0.0\n","\n"," val_running_loss = 0.0\n"," val_running_corrects = 0.0\n","\n"," for inputs, labels in training_loader:\n"," # We use the GPU\n"," inputs = inputs.to(device)\n"," labels = labels.to(device)\n","\n"," outputs = model.forward(inputs)\n"," loss = criterion(outputs, labels)\n","\n"," optimizer.zero_grad()\n"," loss.backward()\n"," optimizer.step()\n","\n"," _, pred = torch.max(outputs, 1)\n"," running_loss += loss.item()\n"," running_corrects += torch.sum(pred == labels.data)\n","\n"," else:\n"," # saves memory by not calculating the gradient\n"," with torch.no_grad():\n"," for val_inputs, val_labels in validation_loader:\n"," # We use the GPU\n"," val_inputs = val_inputs.to(device)\n"," val_labels = val_labels.to(device)\n"," val_outputs = model.forward(val_inputs)\n"," val_loss = criterion(val_outputs, val_labels) \n","\n"," _, val_pred = torch.max(val_outputs, 1)\n"," val_running_loss += val_loss.item()\n"," val_running_corrects += torch.sum(val_pred == val_labels.data)\n","\n"," epoch_loss = running_loss/len(training_loader.dataset)\n"," epoch_acc = running_corrects.float() / len(training_loader.dataset)\n"," running_loss_history.append(epoch_loss)\n"," running_corrects_history.append(epoch_acc)\n","\n"," val_epoch_loss = val_running_loss/len(validation_loader.dataset)\n"," val_epoch_acc = val_running_corrects.float() / len(validation_loader.dataset)\n"," val_running_loss_history.append(val_epoch_loss)\n"," val_running_corrects_history.append(val_epoch_acc)\n","\n"," print('Epoch: ', (e+1))\n"," print('Training loss: {:.4f}, acc: {:.4f} '.format(epoch_loss, epoch_acc.item()))\n"," print('Validation loss: {:.4f}, acc: {:.4f} '.format(val_epoch_loss, val_epoch_acc.item()))\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"NX-GO2YAKJ0r"},"source":["## Feeds the Neural Network with the validation data and store the result"]},{"cell_type":"code","metadata":{"id":"-Lu_ZmyPoYUo"},"source":["val_running_loss = 0.0\n","val_running_corrects = 0.0\n","\n","all_preds = torch.tensor([])\n","all_targets = torch.tensor([])\n","\n","with torch.no_grad():\n"," for val_inputs, val_labels in validation_loader:\n"," # Usamos la GPU\n"," val_inputs = val_inputs.to(device)\n"," val_labels = val_labels.to(device)\n","\n"," val_outputs = model.forward(val_inputs) \n"," val_loss = criterion(val_outputs, val_labels) \n","\n"," _, val_pred = torch.max(val_outputs, 1)\n"," val_running_loss += val_loss.item()\n"," val_running_corrects += torch.sum(val_pred == val_labels.data)\n","\n"," all_preds = torch.cat(\n"," (all_preds, val_outputs.to('cpu'))\n"," ,dim=0\n"," )\n"," all_targets = torch.cat(\n"," (all_targets, val_labels.to('cpu').type('torch.FloatTensor') )\n"," ,dim=0\n"," )\n","\n","print('Validation dataset loss: {}'.format(val_running_loss))"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1iYLm8IgKtFQ"},"source":["## Displays the confusion matrix"]},{"cell_type":"code","metadata":{"id":"lf4I44W3rX5G"},"source":["cm = confusion_matrix(all_targets, all_preds.argmax(dim=1))\n","plot_confusion_matrix(cm, list(range(2)))"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"e9Rv0JJCK6dv"},"source":["## Prints the classification report"]},{"cell_type":"code","metadata":{"id":"xI784UMxnfWT"},"source":["print(classification_report(all_targets, all_preds.argmax(dim=1)))"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"G_SnBEwHLDW6"},"source":["## Plots the loss data"]},{"cell_type":"code","metadata":{"id":"s_ereKlWpNEw"},"source":["fig, ax = plt.subplots()\n","plt.plot(running_loss_history, label='training loss')\n","plt.plot(val_running_loss_history, label='validation loss')\n","plt.legend()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Mp6AT5cDLMhh"},"source":["## Plots the Neural Network Accuracy"]},{"cell_type":"code","metadata":{"id":"XX6L47k6qmDy"},"source":["fig, ax = plt.subplots()\n","plt.plot(running_corrects_history, label='training accuracy')\n","plt.plot(val_running_corrects_history, label='validation accuracy')\n","plt.legend()"],"execution_count":null,"outputs":[]}]}