{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": { "provenance": [] },
  "kernelspec": { "name": "python3", "display_name": "Python 3" },
  "language_info": { "name": "python" }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Lung and colon histopathology classification\n",
    "\n",
    "Fine-tunes ImageNet-pretrained **InceptionResNetV2** and **EfficientNetB0** on CLAHE-enhanced images from the Kaggle `lung-and-colon-cancer-histopathological-images` dataset, evaluates both on a held-out test split, and explains EfficientNetB0 predictions with LIME and SHAP.\n",
    "\n",
    "The notebook is meant to be run top to bottom on a fresh kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "B_eWK4ueFOCO" },
   "outputs": [],
   "source": [
    "# Explainability packages are not part of the default Colab image.\n",
    "# TODO: pin the versions once the target environment is fixed.\n",
    "%pip install -q lime shap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "LGNzfnDkbxEY" },
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import itertools\n",
    "import os\n",
    "import shutil\n",
    "import time\n",
    "\n",
    "# Third-party\n",
    "import cv2\n",
    "import lime\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import shap\n",
    "from IPython.display import HTML, display\n",
    "from lime import lime_image\n",
    "from matplotlib.pyplot import imshow\n",
    "from PIL import Image\n",
    "from skimage import io\n",
    "from skimage.segmentation import mark_boundaries\n",
    "from sklearn.metrics import classification_report, confusion_matrix\n",
    "from sklearn.model_selection import train_test_split\n",
    "from tqdm import tqdm\n",
    "\n",
    "# TensorFlow / Keras\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import backend as K\n",
    "from tensorflow.keras import regularizers\n",
    "from tensorflow.keras.callbacks import EarlyStopping\n",
    "from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D, Dense,\n",
    "                                     Dropout, Flatten, MaxPooling2D)\n",
    "from tensorflow.keras.metrics import categorical_crossentropy\n",
    "from tensorflow.keras.models import Model, Sequential, load_model\n",
    "from tensorflow.keras.optimizers import Adam, Adamax\n",
    "from tensorflow.keras.preprocessing import image\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
    "\n",
    "sns.set_style('darkgrid')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One seed reused for sampling, splitting, shuffling and dropout\n",
    "SEED = 123\n",
    "np.random.seed(SEED)\n",
    "tf.random.set_seed(SEED)\n",
    "\n",
    "DATA_ROOT = '../input/lung-and-colon-cancer-histopathological-images/lung_colon_image_set'\n",
    "colon_dir = os.path.join(DATA_ROOT, 'colon_image_sets')\n",
    "lung_dir = os.path.join(DATA_ROOT, 'lung_image_sets')\n",
    "\n",
    "# CLAHE settings shared by the training pipeline and the walkthrough figures\n",
    "CLAHE_CLIP_LIMIT = 3.0\n",
    "CLAHE_TILE_GRID = (8, 8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the image index\n",
    "\n",
    "Each sub-folder of the colon and lung directories is one class; the folder name is the label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "9ZBdJh5Db_a7" },
   "outputs": [],
   "source": [
    "def collect_image_paths(root_dirs):\n",
    "    \"\"\"Return a DataFrame with one row per file found in the class folders of ``root_dirs``.\n",
    "\n",
    "    Every sub-directory of a root is treated as a class and its name is used as the label.\n",
    "    Listings are sorted so the row order (and therefore the seeded sampling below)\n",
    "    does not depend on the file system's directory order.\n",
    "\n",
    "    Returns a frame with the string columns ``filepaths`` and ``labels``.\n",
    "    \"\"\"\n",
    "    records = []\n",
    "    for root in root_dirs:\n",
    "        for klass in sorted(os.listdir(root)):\n",
    "            classpath = os.path.join(root, klass)\n",
    "            if not os.path.isdir(classpath):\n",
    "                continue  # ignore stray files next to the class folders\n",
    "            for fname in sorted(os.listdir(classpath)):\n",
    "                records.append({'filepaths': os.path.join(classpath, fname), 'labels': klass})\n",
    "    return pd.DataFrame(records, columns=['filepaths', 'labels'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_images_df = collect_image_paths([colon_dir, lung_dir])\n",
    "assert not all_images_df.empty, f'no images found under {DATA_ROOT}'\n",
    "display(all_images_df.head())\n",
    "all_images_df['labels'].value_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Balance classes and split into train / test / validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "uJVBDvkqcKno" },
   "outputs": [],
   "source": [
    "sample_size = 5000  # maximum number of images kept per class\n",
    "\n",
    "# Cap at the class size so a class with fewer images is kept whole instead of raising\n",
    "sample_list = [\n",
    "    group.sample(min(sample_size, len(group)), replace=False, random_state=SEED, axis=0)\n",
    "    for _, group in all_images_df.groupby('labels', sort=False)\n",
    "]\n",
    "df = pd.concat(sample_list, axis=0).reset_index(drop=True)\n",
    "len(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "gd_v74v9cSO-" },
   "outputs": [],
   "source": [
    "train_split = .8\n",
    "test_split = .1\n",
    "# Share of the held-out part that becomes the test set; the remainder is validation\n",
    "dummy_split = test_split / (1 - train_split)\n",
    "train_df, dummy_df = train_test_split(df, train_size=train_split, shuffle=True, random_state=SEED)\n",
    "test_df, valid_df = train_test_split(dummy_df, train_size=dummy_split, shuffle=True, random_state=SEED)\n",
    "assert len(train_df) + len(test_df) + len(valid_df) == len(df)\n",
    "print('train_df length: ', len(train_df), ' _test_df length: ', len(test_df), ' valid_df length: ', len(valid_df))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocessing and data generators\n",
    "\n",
    "Images are contrast-enhanced with CLAHE on the lightness channel of the LAB colour space. `scalar` additionally scales pixels to `[-1, 1]` (InceptionResNetV2); `clahe_255` leaves them in `[0, 255]` (EfficientNetB0, which rescales internally)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "W2XIkL5EcUwR" },
   "outputs": [],
   "source": [
    "height = 128\n",
    "width = 128\n",
    "channels = 3\n",
    "batch_size = 128\n",
    "img_shape = (height, width, channels)\n",
    "img_size = (height, width)\n",
    "\n",
    "# Largest batch size <= 80 that divides the test set exactly, so that\n",
    "# test_batch_size * test_steps visits every test image exactly once.\n",
    "length = len(test_df)\n",
    "test_batch_size = max(length // n for n in range(1, length + 1) if length % n == 0 and length / n <= 80)\n",
    "test_steps = length // test_batch_size\n",
    "print('test batch size: ', test_batch_size, ' test steps: ', test_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_clahe(img):\n",
    "    \"\"\"Contrast-enhance an RGB image with CLAHE on its LAB lightness channel.\n",
    "\n",
    "    img: array of shape (H, W, 3) in RGB order with values in [0, 255]; any numeric dtype.\n",
    "    Returns a uint8 RGB array of the same shape.\n",
    "    \"\"\"\n",
    "    # OpenCV's histogram routines only accept 8-bit input, while Keras hands the\n",
    "    # preprocessing function a float32 array, so convert explicitly first.\n",
    "    rgb = np.clip(np.asarray(img), 0, 255).astype(np.uint8)\n",
    "    l, a, b = cv2.split(cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB))\n",
    "    clahe = cv2.createCLAHE(clipLimit=CLAHE_CLIP_LIMIT, tileGridSize=CLAHE_TILE_GRID)\n",
    "    enhanced_lab = cv2.merge((clahe.apply(l), a, b))\n",
    "    # Convert back to RGB (not BGR) so the channel order matches the loaded image\n",
    "    return cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)\n",
    "\n",
    "\n",
    "def scalar(img):\n",
    "    \"\"\"Keras preprocessing hook: CLAHE, then scale pixels to [-1, 1].\"\"\"\n",
    "    return apply_clahe(img).astype(np.float32) / 127.5 - 1\n",
    "\n",
    "\n",
    "def clahe_255(img):\n",
    "    \"\"\"Keras preprocessing hook: CLAHE only, pixels stay in [0, 255].\"\"\"\n",
    "    return apply_clahe(img).astype(np.float32)\n",
    "\n",
    "\n",
    "def make_generators(preprocessing_function):\n",
    "    \"\"\"Return (train, test, valid) iterators over the three splits.\n",
    "\n",
    "    All three apply ``preprocessing_function`` to every image. Only the training\n",
    "    iterator shuffles; test and validation keep the DataFrame order so that\n",
    "    predictions line up with ``iterator.labels``.\n",
    "    \"\"\"\n",
    "    gen = ImageDataGenerator(preprocessing_function=preprocessing_function)\n",
    "    common = dict(x_col='filepaths', y_col='labels', target_size=img_size,\n",
    "                  class_mode='categorical', color_mode='rgb')\n",
    "    train = gen.flow_from_dataframe(train_df, shuffle=True, seed=SEED, batch_size=batch_size, **common)\n",
    "    test = gen.flow_from_dataframe(test_df, shuffle=False, batch_size=test_batch_size, **common)\n",
    "    valid = gen.flow_from_dataframe(valid_df, shuffle=False, batch_size=batch_size, **common)\n",
    "    return train, test, valid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_gen, test_gen, valid_gen = make_generators(scalar)\n",
    "classes = list(train_gen.class_indices.keys())\n",
    "class_count = len(classes)\n",
    "classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "uKSuvdnBcWyo" },
   "outputs": [],
   "source": [
    "def myFunc(image):\n",
    "    \"\"\"Apply the training-time CLAHE enhancement to an RGB image and return a PIL image.\n",
    "\n",
    "    image: anything ``np.array`` turns into an (H, W, 3) RGB array, e.g. a PIL image.\n",
    "    \"\"\"\n",
    "    return Image.fromarray(apply_clahe(np.array(image)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sample images and preprocessing walkthrough"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "X5Cz894ocjrW" },
   "outputs": [],
   "source": [
    "SAMPLES = [\n",
    "    os.path.join(lung_dir, 'lung_scc', 'lungscc1.jpeg'),\n",
    "    os.path.join(lung_dir, 'lung_n', 'lungn1.jpeg'),\n",
    "    os.path.join(lung_dir, 'lung_aca', 'lungaca1.jpeg'),\n",
    "    os.path.join(colon_dir, 'colon_n', 'colonn1.jpeg'),\n",
    "    os.path.join(colon_dir, 'colon_aca', 'colonca1.jpeg'),\n",
    "]\n",
    "\n",
    "\n",
    "def class_of(path):\n",
    "    \"\"\"Return the class label of an image path, i.e. the name of its parent folder.\"\"\"\n",
    "    return os.path.basename(os.path.dirname(path))\n",
    "\n",
    "\n",
    "fig, axes = plt.subplots(1, len(SAMPLES), figsize=(22, 10))\n",
    "for ax, path in zip(axes, SAMPLES):\n",
    "    ax.imshow(image.load_img(path))\n",
    "    ax.set_title(class_of(path))\n",
    "    ax.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_image(img, caption):\n",
    "    \"\"\"Display ``img`` without axes under the title ``caption``.\"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(15, 8))\n",
    "    ax.imshow(img)\n",
    "    ax.set_title(caption)\n",
    "    ax.axis('off')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def show_histogram(values, caption):\n",
    "    \"\"\"Plot a 100-bin histogram of ``values`` over the 8-bit range [0, 255].\"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(15, 8))\n",
    "    ax.hist(values.ravel(), bins=100, range=(0, 255))\n",
    "    ax.set_xlabel(caption)\n",
    "    ax.set_ylabel('Count')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "9jLmV8occnhb" },
   "outputs": [],
   "source": [
    "sample_path = os.path.join(lung_dir, 'lung_n', 'lungn1.jpeg')\n",
    "sample_bgr = cv2.imread(sample_path)\n",
    "assert sample_bgr is not None, f'could not read {sample_path}'\n",
    "# cv2.imread returns BGR; convert once so plots and colour conversions see RGB\n",
    "sample_rgb = cv2.cvtColor(sample_bgr, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "show_image(sample_rgb, class_of(sample_path))\n",
    "\n",
    "# Overall and per-channel intensity histograms\n",
    "fig, ax = plt.subplots(figsize=(15, 8))\n",
    "ax.hist(sample_rgb.ravel(), bins=256, color='orange', label='Total')\n",
    "for channel, (colour, label) in enumerate([('red', 'red_Channel'), ('green', 'Green_Channel'), ('blue', 'Blue_Channel')]):\n",
    "    ax.hist(sample_rgb[:, :, channel].ravel(), bins=256, color=colour, alpha=0.5, label=label)\n",
    "ax.set_xlabel('Intensity Value')\n",
    "ax.set_ylabel('Count')\n",
    "ax.legend()\n",
    "plt.show()\n",
    "\n",
    "show_histogram(sample_rgb, 'image histogram values')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LAB conversion and plain histogram equalisation of the lightness channel\n",
    "lab_img = cv2.cvtColor(sample_rgb, cv2.COLOR_RGB2LAB)\n",
    "show_image(lab_img, 'image after converting from RGB to LAB (LAB channels rendered as RGB)')\n",
    "\n",
    "l, a, b = cv2.split(lab_img)\n",
    "show_histogram(l, 'L channel histogram values')\n",
    "\n",
    "equ = cv2.equalizeHist(l)\n",
    "show_histogram(equ, 'L channel histogram values after equalizeHist')\n",
    "\n",
    "hist_eq_img = cv2.cvtColor(cv2.merge((equ, a, b)), cv2.COLOR_LAB2RGB)\n",
    "show_histogram(hist_eq_img, 'image histogram values after equalizeHist of L')\n",
    "show_image(hist_eq_img, 'image after equalizeHist of L')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CLAHE with the same settings the training pipeline uses (see apply_clahe)\n",
    "clahe = cv2.createCLAHE(clipLimit=CLAHE_CLIP_LIMIT, tileGridSize=CLAHE_TILE_GRID)\n",
    "clahe_l = clahe.apply(l)\n",
    "show_histogram(clahe_l, 'L channel histogram values after CLAHE')\n",
    "\n",
    "CLAHE_img = cv2.cvtColor(cv2.merge((clahe_l, a, b)), cv2.COLOR_LAB2RGB)\n",
    "show_histogram(CLAHE_img, 'image histogram values after CLAHE')\n",
    "show_image(CLAHE_img, 'image after CLAHE')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helpers for building, training and evaluating the models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "BUfy-n-nF4Hg" },
   "outputs": [],
   "source": [
    "def print_in_color(txt_msg, fore_tupple, back_tupple,):\n",
    "    \"\"\"Print ``txt_msg`` with 24-bit ANSI colours, then reset the colours.\n",
    "\n",
    "    fore_tupple: (r, g, b) foreground colour.\n",
    "    back_tupple: (r, g, b) background colour.\n",
    "    \"\"\"\n",
    "    rf, gf, bf = fore_tupple\n",
    "    rb, gb, bb = back_tupple\n",
    "    mat = f'\\33[38;2;{rf};{gf};{bf};48;2;{rb};{gb};{bb}m'\n",
    "    # Concatenate instead of str.format so braces inside txt_msg are printed verbatim\n",
    "    print(mat + txt_msg, flush=True)\n",
    "    print('\\33[0m', flush=True)  # restore the default colours\n",
    "    return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_classifier(base_factory):\n",
    "    \"\"\"Return a compiled classifier: ImageNet backbone plus a small regularised head.\n",
    "\n",
    "    base_factory: a ``tf.keras.applications`` constructor, e.g.\n",
    "    ``tf.keras.applications.InceptionResNetV2``. It is called with\n",
    "    ``include_top=False`` and global max pooling, so its output is a feature vector.\n",
    "    \"\"\"\n",
    "    base_model = base_factory(include_top=False, weights='imagenet', input_shape=img_shape, pooling='max')\n",
    "    x = base_model.output\n",
    "    x = keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001)(x)\n",
    "    x = Dense(256, kernel_regularizer=regularizers.l2(0.016), activity_regularizer=regularizers.l1(0.006),\n",
    "              bias_regularizer=regularizers.l1(0.006), activation='relu')(x)\n",
    "    x = Dropout(rate=.45, seed=SEED)(x)\n",
    "    output = Dense(class_count, activation='softmax')(x)\n",
    "    classifier = Model(inputs=base_model.input, outputs=output)\n",
    "    # `learning_rate` replaces the deprecated `lr` keyword\n",
    "    classifier.compile(Adamax(learning_rate=.001), loss='categorical_crossentropy', metrics=['accuracy'])\n",
    "    return classifier\n",
    "\n",
    "\n",
    "def train_and_save(classifier, train_data, valid_data, epochs, save_path):\n",
    "    \"\"\"Fit ``classifier`` with early stopping on val_loss, save it, and return the History.\n",
    "\n",
    "    The best-epoch weights are restored before saving, so the file on disk is the\n",
    "    model with the lowest validation loss rather than the last epoch.\n",
    "    \"\"\"\n",
    "    early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)\n",
    "    history = classifier.fit(train_data, validation_data=valid_data, epochs=epochs, callbacks=[early_stopping])\n",
    "    classifier.save(save_path)\n",
    "    print('Saved model to disk')\n",
    "    return history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_history(history):\n",
    "    \"\"\"Plot training vs. validation accuracy and loss per epoch from a Keras History.\"\"\"\n",
    "    fig, (ax_acc, ax_loss) = plt.subplots(2, 1, figsize=(10, 10))\n",
    "\n",
    "    ax_acc.plot(history.history['accuracy'], label='Training Accuracy', color='r')\n",
    "    ax_acc.plot(history.history['val_accuracy'], label='Validation Accuracy', color='b')\n",
    "    ax_acc.tick_params(labelsize=14)\n",
    "    ax_acc.legend(loc='lower right', fontsize=13)\n",
    "    ax_acc.set_ylabel('Accuracy', fontsize=16, weight='bold')\n",
    "    ax_acc.set_title('Training & Validation Acc.', fontsize=16, weight='bold')\n",
    "\n",
    "    ax_loss.plot(history.history['loss'], label='Training Loss', color='r')\n",
    "    ax_loss.plot(history.history['val_loss'], label='Validation Loss', color='b')\n",
    "    ax_loss.tick_params(labelsize=14)\n",
    "    ax_loss.legend(loc='upper right', fontsize=13)\n",
    "    ax_loss.set_ylabel('Cross Entropy', fontsize=16, weight='bold')\n",
    "    ax_loss.set_title('Training & Validation Loss', fontsize=15, weight='bold')\n",
    "    ax_loss.set_xlabel('Epoch', fontsize=15, weight='bold')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def test_report(classifier, test_data):\n",
    "    \"\"\"Print the per-class report for ``test_data`` and return the predicted class indices.\n",
    "\n",
    "    ``test_data`` must be an unshuffled iterator so predictions align with ``test_data.labels``.\n",
    "    \"\"\"\n",
    "    predicted = np.argmax(classifier.predict(test_data), axis=1)\n",
    "    print(classification_report(test_data.labels, predicted,\n",
    "                                labels=list(range(class_count)), target_names=classes))\n",
    "    return predicted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "mOSTrPRFfBiG" },
   "outputs": [],
   "source": [
    "def plot_confusion_matrix(cm, classes,\n",
    "                          normalize=True,\n",
    "                          title='Confusion matrix',\n",
    "                          cmap=plt.cm.Greys):\n",
    "    \"\"\"Draw ``cm`` as an annotated heat map on the current matplotlib figure.\n",
    "\n",
    "    cm: square array of counts; rows are true classes, columns are predictions.\n",
    "    classes: tick labels, in the same order as the rows of ``cm``.\n",
    "    normalize: if True, each row is shown as fractions of that row's total.\n",
    "    title, cmap: figure title and colour map.\n",
    "    \"\"\"\n",
    "    if normalize:\n",
    "        row_totals = cm.sum(axis=1, keepdims=True)\n",
    "        # Rows without any sample stay at zero instead of turning into NaN\n",
    "        cm = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)\n",
    "        print('Normalized confusion matrix')\n",
    "    else:\n",
    "        print('Confusion matrix, without normalization')\n",
    "\n",
    "    plt.imshow(cm, interpolation='nearest', cmap=cmap)\n",
    "    plt.title(title, weight='bold', fontsize=16)\n",
    "    tick_marks = np.arange(len(classes))\n",
    "    plt.xticks(tick_marks, classes, fontsize=14)\n",
    "    plt.yticks(tick_marks, classes, fontsize=14)\n",
    "\n",
    "    fmt = '.2f' if normalize else 'd'\n",
    "    thresh = cm.max() / 2.\n",
    "    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n",
    "        # White text on dark cells, black text on light cells\n",
    "        plt.text(j, i, format(cm[i, j], fmt),\n",
    "                 horizontalalignment='center', fontsize=16, weight='bold',\n",
    "                 color='white' if cm[i, j] > thresh else 'black')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.ylabel('True label', fontsize=20, weight='bold')\n",
    "    plt.xlabel('Predicted label', fontsize=16, weight='bold')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## InceptionResNetV2\n",
    "\n",
    "Trained on the `[-1, 1]` generators built with `scalar`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "57iG1HuWcvb6" },
   "outputs": [],
   "source": [
    "model_name = 'InceptionResNetV2'\n",
    "model = build_classifier(tf.keras.applications.InceptionResNetV2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "CGJbepX1c9Ni" },
   "outputs": [],
   "source": [
    "history = train_and_save(model, train_gen, valid_gen, epochs=20, save_path='InceptionResNetV2.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "t5bhoBObd5oq" },
   "outputs": [],
   "source": [
    "plot_history(history)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "ozzMZDqeerH9" },
   "outputs": [],
   "source": [
    "y_pred = test_report(model, test_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw counts (not normalised), with class names taken from the generator\n",
    "cnf_matrix = confusion_matrix(test_gen.labels, y_pred, labels=list(range(class_count)))\n",
    "np.set_printoptions(precision=2)\n",
    "\n",
    "plt.figure(figsize=(10, 10))\n",
    "plot_confusion_matrix(cnf_matrix, classes=classes, normalize=False, title='Confusion Matrix')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## EfficientNetB0\n",
    "\n",
    "The Keras EfficientNet models contain their own rescaling layer and expect pixels in `[0, 255]`, so this model gets its own generators built with `clahe_255`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "i6vpum3UDmu4" },
   "outputs": [],
   "source": [
    "effnet_train_gen, effnet_test_gen, effnet_valid_gen = make_generators(clahe_255)\n",
    "effnet_model = build_classifier(tf.keras.applications.EfficientNetB0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "O9p_4HB1DK8Z" },
   "outputs": [],
   "source": [
    "effnet_history = train_and_save(effnet_model, effnet_train_gen, effnet_valid_gen,\n",
    "                                epochs=30, save_path='EfficientNetB0.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "yI93jrjQDtjf" },
   "outputs": [],
   "source": [
    "effnet_y_pred = test_report(effnet_model, effnet_test_gen)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Explaining EfficientNetB0 predictions\n",
    "\n",
    "Both LIME and SHAP are run on the first test batch. Its pixels are in `[0, 255]`, so images are divided by 255 for display."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "B0BQcERMFVxY" },
   "outputs": [],
   "source": [
    "# First (unshuffled) test batch: images in [0, 255] and one-hot labels\n",
    "x_batch, y_batch = effnet_test_gen[0]\n",
    "assert len(x_batch) > 15, 'the LIME cells below index samples 1 and 15 of this batch'\n",
    "\n",
    "lime_explainer = lime_image.LimeImageExplainer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "WpbbveojFiaS" },
   "outputs": [],
   "source": [
    "def explain_with_lime(classifier, sample, out_path):\n",
    "    \"\"\"Explain the top prediction of ``classifier`` for one image with LIME.\n",
    "\n",
    "    sample: (H, W, 3) array with pixels in [0, 255], exactly as fed to ``classifier``.\n",
    "    out_path: file the three-panel figure is saved to.\n",
    "\n",
    "    Panels: input image | only the regions supporting the top class |\n",
    "    whole image with the ten most influential regions (for and against) outlined.\n",
    "    Returns the LIME explanation object.\n",
    "    \"\"\"\n",
    "    explanation = lime_explainer.explain_instance(\n",
    "        sample.astype('double'),\n",
    "        lambda batch: classifier.predict(batch, verbose=0),  # avoid one progress bar per LIME batch\n",
    "        top_labels=5, hide_color=0, num_samples=10000)\n",
    "    top_label = explanation.top_labels[0]\n",
    "    support_img, support_mask = explanation.get_image_and_mask(\n",
    "        top_label, positive_only=True, num_features=10000, hide_rest=True)\n",
    "    mixed_img, mixed_mask = explanation.get_image_and_mask(\n",
    "        top_label, positive_only=False, num_features=10, hide_rest=False)\n",
    "\n",
    "    panels = (sample / 255,\n",
    "              mark_boundaries(support_img / 255, support_mask),\n",
    "              mark_boundaries(mixed_img / 255, mixed_mask))\n",
    "    fig, axes = plt.subplots(1, 3, figsize=(15, 15))\n",
    "    for ax, panel in zip(axes, panels):\n",
    "        ax.imshow(panel)\n",
    "        ax.axis('off')\n",
    "    fig.savefig(out_path)\n",
    "    plt.show()\n",
    "    return explanation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "vwmMeQS4FqwD" },
   "outputs": [],
   "source": [
    "explanation_15 = explain_with_lime(effnet_model, x_batch[15], 'mask_default_15.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "sKVZ_mmqJzyv" },
   "outputs": [],
   "source": [
    "explanation_1 = explain_with_lime(effnet_model, x_batch[1], 'mask_default.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "7c-TskyvJ1Wz" },
   "outputs": [],
   "source": [
    "masker = shap.maskers.Image('blur(28,28)', x_batch[0].shape)\n",
    "shap_explainer = shap.Explainer(effnet_model, masker, output_names=classes)\n",
    "shap_explainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "zpbiqLDzK2mx" },
   "outputs": [],
   "source": [
    "# Explain the first eight test images, keeping the five highest-scoring classes per image\n",
    "shap_values = shap_explainer(x_batch[0:8], outputs=shap.Explanation.argsort.flip[:5])\n",
    "shap_values.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "V1iwE6dzLVPw" },
   "outputs": [],
   "source": [
    "print('Actual Labels : {}'.format(np.argmax(y_batch[0:8], axis=1)))\n",
    "probs = effnet_model.predict(x_batch[0:8])\n",
    "print('Predicted Labels : {}'.format(np.argmax(probs, axis=1)))\n",
    "print('Probabilities : {}'.format(np.max(probs, axis=1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "34yczKROLoqj" },
   "outputs": [],
   "source": [
    "shap.image_plot(shap_values)"
   ]
  }
 ]
}