{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "TvQvDlAfM4Sr", "outputId": "93516842-f87e-4ab3-93f7-90d041d6faff" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting pytorch-tabnet\n", " Downloading https://files.pythonhosted.org/packages/94/e5/2a808d611a5d44e3c997c0d07362c04a56c70002208e00aec9eee3d923b5/pytorch_tabnet-3.1.1-py3-none-any.whl\n", "Requirement already satisfied: torch<2.0,>=1.2 in /usr/local/lib/python3.7/dist-packages (from pytorch-tabnet) (1.9.0+cu102)\n", "Requirement already satisfied: numpy<2.0,>=1.17 in /usr/local/lib/python3.7/dist-packages (from pytorch-tabnet) (1.19.5)\n", "Requirement already satisfied: tqdm<5.0,>=4.36 in /usr/local/lib/python3.7/dist-packages (from pytorch-tabnet) (4.41.1)\n", "Requirement already satisfied: scipy>1.4 in /usr/local/lib/python3.7/dist-packages (from pytorch-tabnet) (1.4.1)\n", "Requirement already satisfied: scikit_learn>0.21 in /usr/local/lib/python3.7/dist-packages (from pytorch-tabnet) (0.22.2.post1)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch<2.0,>=1.2->pytorch-tabnet) (3.7.4.3)\n", "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit_learn>0.21->pytorch-tabnet) (1.0.1)\n", "Installing collected packages: pytorch-tabnet\n", "Successfully installed pytorch-tabnet-3.1.1\n" ] } ], "source": [ "pip install pytorch-tabnet" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0fLPTFDANJvy", "outputId": "9bd6bc9e-a04b-49cf-eb85-296384ac62d5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: imbalanced-learn in /usr/local/lib/python3.7/dist-packages (0.4.3)\n", "Requirement already satisfied: scikit-learn>=0.20 in /usr/local/lib/python3.7/dist-packages (from 
imbalanced-learn) (0.22.2.post1)\n", "Requirement already satisfied: scipy>=0.13.3 in /usr/local/lib/python3.7/dist-packages (from imbalanced-learn) (1.4.1)\n", "Requirement already satisfied: numpy>=1.8.2 in /usr/local/lib/python3.7/dist-packages (from imbalanced-learn) (1.19.5)\n", "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=0.20->imbalanced-learn) (1.0.1)\n" ] } ], "source": [ " pip install imbalanced-learn" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "TtSu1A6q7jJJ", "outputId": "a664b1d0-5861-4f79-a6e6-95df2e4a543c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting tsne\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/4c/ea/f4deb02eb49bbc7006624398d3909bbb43cd93efd58e66b74320b5530edd/tsne-0.3.1.tar.gz (547kB)\n", "\r", "\u001b[K |▋ | 10kB 6.8MB/s eta 0:00:01\r", "\u001b[K |█▏ | 20kB 11.2MB/s eta 0:00:01\r", "\u001b[K |█▉ | 30kB 15.0MB/s eta 0:00:01\r", "\u001b[K |██▍ | 40kB 18.1MB/s eta 0:00:01\r", "\u001b[K |███ | 51kB 20.6MB/s eta 0:00:01\r", "\u001b[K |███▋ | 61kB 22.9MB/s eta 0:00:01\r", "\u001b[K |████▏ | 71kB 22.1MB/s eta 0:00:01\r", "\u001b[K |████▉ | 81kB 22.9MB/s eta 0:00:01\r", "\u001b[K |█████▍ | 92kB 23.9MB/s eta 0:00:01\r", "\u001b[K |██████ | 102kB 24.7MB/s eta 0:00:01\r", "\u001b[K |██████▋ | 112kB 24.7MB/s eta 0:00:01\r", "\u001b[K |███████▏ | 122kB 24.7MB/s eta 0:00:01\r", "\u001b[K |███████▉ | 133kB 24.7MB/s eta 0:00:01\r", "\u001b[K |████████▍ | 143kB 24.7MB/s eta 0:00:01\r", "\u001b[K |█████████ | 153kB 24.7MB/s eta 0:00:01\r", "\u001b[K |█████████▋ | 163kB 24.7MB/s eta 0:00:01\r", "\u001b[K |██████████▏ | 174kB 24.7MB/s eta 0:00:01\r", "\u001b[K |██████████▊ | 184kB 24.7MB/s eta 0:00:01\r", "\u001b[K |███████████▍ | 194kB 24.7MB/s eta 0:00:01\r", "\u001b[K |████████████ | 204kB 24.7MB/s eta 0:00:01\r", "\u001b[K |████████████▋ | 215kB 24.7MB/s eta 
0:00:01\r", "\u001b[K |█████████████▏ | 225kB 24.7MB/s eta 0:00:01\r", "\u001b[K |█████████████▊ | 235kB 24.7MB/s eta 0:00:01\r", "\u001b[K |██████████████▍ | 245kB 24.7MB/s eta 0:00:01\r", "\u001b[K |███████████████ | 256kB 24.7MB/s eta 0:00:01\r", "\u001b[K |███████████████▋ | 266kB 24.7MB/s eta 0:00:01\r", "\u001b[K |████████████████▏ | 276kB 24.7MB/s eta 0:00:01\r", "\u001b[K |████████████████▊ | 286kB 24.7MB/s eta 0:00:01\r", "\u001b[K |█████████████████▍ | 296kB 24.7MB/s eta 0:00:01\r", "\u001b[K |██████████████████ | 307kB 24.7MB/s eta 0:00:01\r", "\u001b[K |██████████████████▌ | 317kB 24.7MB/s eta 0:00:01\r", "\u001b[K |███████████████████▏ | 327kB 24.7MB/s eta 0:00:01\r", "\u001b[K |███████████████████▊ | 337kB 24.7MB/s eta 0:00:01\r", "\u001b[K |████████████████████▍ | 348kB 24.7MB/s eta 0:00:01\r", "\u001b[K |█████████████████████ | 358kB 24.7MB/s eta 0:00:01\r", "\u001b[K |█████████████████████▌ | 368kB 24.7MB/s eta 0:00:01\r", "\u001b[K |██████████████████████▏ | 378kB 24.7MB/s eta 0:00:01\r", "\u001b[K |██████████████████████▊ | 389kB 24.7MB/s eta 0:00:01\r", "\u001b[K |███████████████████████▍ | 399kB 24.7MB/s eta 0:00:01\r", "\u001b[K |████████████████████████ | 409kB 24.7MB/s eta 0:00:01\r", "\u001b[K |████████████████████████▌ | 419kB 24.7MB/s eta 0:00:01\r", "\u001b[K |█████████████████████████▏ | 430kB 24.7MB/s eta 0:00:01\r", "\u001b[K |█████████████████████████▊ | 440kB 24.7MB/s eta 0:00:01\r", "\u001b[K |██████████████████████████▎ | 450kB 24.7MB/s eta 0:00:01\r", "\u001b[K |███████████████████████████ | 460kB 24.7MB/s eta 0:00:01\r", "\u001b[K |███████████████████████████▌ | 471kB 24.7MB/s eta 0:00:01\r", "\u001b[K |████████████████████████████▏ | 481kB 24.7MB/s eta 0:00:01\r", "\u001b[K |████████████████████████████▊ | 491kB 24.7MB/s eta 0:00:01\r", "\u001b[K |█████████████████████████████▎ | 501kB 24.7MB/s eta 0:00:01\r", "\u001b[K |██████████████████████████████ | 512kB 24.7MB/s eta 0:00:01\r", "\u001b[K |██████████████████████████████▌ | 
522kB 24.7MB/s eta 0:00:01\r", "\u001b[K |███████████████████████████████▏| 532kB 24.7MB/s eta 0:00:01\r", "\u001b[K |███████████████████████████████▊| 542kB 24.7MB/s eta 0:00:01\r", "\u001b[K |████████████████████████████████| 552kB 24.7MB/s \n", "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", " Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n", "Requirement already satisfied: numpy>=1.7.1 in /usr/local/lib/python3.7/dist-packages (from tsne) (1.19.5)\n", "Requirement already satisfied: scipy>=0.12.0 in /usr/local/lib/python3.7/dist-packages (from tsne) (1.4.1)\n", "Requirement already satisfied: cython>=0.19.1 in /usr/local/lib/python3.7/dist-packages (from tsne) (0.29.23)\n", "Building wheels for collected packages: tsne\n", " Building wheel for tsne (PEP 517) ... \u001b[?25l\u001b[?25hdone\n", " Created wheel for tsne: filename=tsne-0.3.1-cp37-cp37m-linux_x86_64.whl size=260482 sha256=e979c270ac532f6a8424de16caa61797ab75017d9683c488fe286210dfd976d7\n", " Stored in directory: /root/.cache/pip/wheels/3e/d6/fc/58392f18ea8fc4c74e20185d2faeee87a1c1924a182606c6cd\n", "Successfully built tsne\n", "Installing collected packages: tsne\n", "Successfully installed tsne-0.3.1\n" ] } ], "source": [ "pip install tsne" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ZdQJ5v4bNNbB", "outputId": "cea7db10-6176-4dae-a954-5f4f61dcbade" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.7/dist-packages/sklearn/externals/six.py:31: FutureWarning: The module is deprecated in version 0.21 and will be removed in version 0.23 since we've dropped support for Python 2.7. 
Please rely on the official version of six (https://pypi.org/project/six/).\n", " \"(https://pypi.org/project/six/).\", FutureWarning)\n", "/usr/local/lib/python3.7/dist-packages/sklearn/utils/deprecation.py:144: FutureWarning: The sklearn.neighbors.base module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.neighbors. Anything that cannot be imported from sklearn.neighbors is now part of the private API.\n", " warnings.warn(message, FutureWarning)\n" ] } ], "source": [ "import pandas as pd\n", "#from tabnet import TabNet, TabNetClassifier\n", "#from tabnet import TabNet, TabNetClassifier\n", "import tensorflow as tf\n", "from tensorflow.keras.models import Model\n", "from tensorflow.keras.layers import Input, Dense\n", "import imblearn\n", "from imblearn.over_sampling import ADASYN, SMOTE\n", "from imblearn.under_sampling import RandomUnderSampler\n", "from collections import Counter\n", "#import umap\n", "import random \n", "from sklearn.model_selection import StratifiedKFold\n", "from sklearn.model_selection import cross_val_score\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.manifold import TSNE\n", "from sklearn.decomposition import FastICA, PCA, FactorAnalysis\n", "from sklearn.metrics import roc_auc_score, recall_score, f1_score, classification_report, accuracy_score,roc_curve, confusion_matrix, auc,precision_score, log_loss\n", "pd.set_option('display.max_columns', None)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 226 }, "id": "LFMmD6WGNPEh", "outputId": "cde67d9b-46d9-490f-fecf-64a798d3813c" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AgeGender..female.0..male1.ethnicityraceFeverCoughSOBFatigueSputumMyalgiaDiarrheaNausea.VomitingSore.throatRunny.nose.Nasal.congestionLoss.of.smellLoss.of.tasteHeadahceChest.discomfort..chest.painAsymptomaticsmoking_historyhypertensionhxdiabeteshxasthmahxcoronaryheartdiseasehxcopdhxheartfailurehxcarcinomahximmunosuppressionhxckdhxALTCRPD.dimerFerritinHRLDHLymphocyteSpO2ProcalcitoninRRSystolic.BPTemperatureTroponinICU.or.not
0230210000000000000010000000000312.2857203.911734540.6960.28188937.00.011
12513711011100000011000000000003711.7183587.11324836.8980.642810739.50.011
22812111100011110001000000000001822.0389932.112456210.4881.811712339.10.011
328117111000000000000000000000018310.53702068.012370613.4940.922012038.90.011
43012111010000000000000000000002186.73492141.010168817.9980.371513037.10.011
\n", "
" ], "text/plain": [ " Age Gender..female.0..male1. ethnicity race Fever Cough SOB Fatigue \\\n", "0 23 0 2 1 0 0 0 0 \n", "1 25 1 3 7 1 1 0 1 \n", "2 28 1 2 1 1 1 1 0 \n", "3 28 1 1 7 1 1 1 0 \n", "4 30 1 2 1 1 1 0 1 \n", "\n", " Sputum Myalgia Diarrhea Nausea.Vomiting Sore.throat \\\n", "0 0 0 0 0 0 \n", "1 1 1 0 0 0 \n", "2 0 0 1 1 1 \n", "3 0 0 0 0 0 \n", "4 0 0 0 0 0 \n", "\n", " Runny.nose.Nasal.congestion Loss.of.smell Loss.of.taste Headahce \\\n", "0 0 0 0 0 \n", "1 0 0 0 1 \n", "2 1 0 0 0 \n", "3 0 0 0 0 \n", "4 0 0 0 0 \n", "\n", " Chest.discomfort..chest.pain Asymptomatic smoking_history \\\n", "0 0 1 0 \n", "1 1 0 0 \n", "2 1 0 0 \n", "3 0 0 0 \n", "4 0 0 0 \n", "\n", " hypertensionhx diabeteshx asthmahx coronaryheartdiseasehx copdhx \\\n", "0 0 0 0 0 0 \n", "1 0 0 0 0 0 \n", "2 0 0 0 0 0 \n", "3 0 0 0 0 0 \n", "4 0 0 0 0 0 \n", "\n", " heartfailurehx carcinomahx immunosuppressionhx ckdhx ALT CRP \\\n", "0 0 0 0 0 31 2.2 \n", "1 0 0 0 0 37 11.7 \n", "2 0 0 0 0 18 22.0 \n", "3 0 0 0 0 183 10.5 \n", "4 0 0 0 0 218 6.7 \n", "\n", " D.dimer Ferritin HR LDH Lymphocyte SpO2 Procalcitonin RR \\\n", "0 857 203.9 117 345 40.6 96 0.28 18 \n", "1 183 587.1 132 483 6.8 98 0.64 28 \n", "2 389 932.1 124 562 10.4 88 1.81 17 \n", "3 370 2068.0 123 706 13.4 94 0.92 20 \n", "4 349 2141.0 101 688 17.9 98 0.37 15 \n", "\n", " Systolic.BP Temperature Troponin ICU.or.not \n", "0 89 37.0 0.01 1 \n", "1 107 39.5 0.01 1 \n", "2 123 39.1 0.01 1 \n", "3 120 38.9 0.01 1 \n", "4 130 37.1 0.01 1 " ] }, "execution_count": 5, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "data_icu = pd.read_csv('ICUMICE2754.csv')\n", "data_icu.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 226 }, "id": "oUKQ3JWMNYCC", "outputId": "3e50789e-101d-4d1c-f2cb-9c5f8312aa2e" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AgeGenderethnicityraceFeverCoughSOBFatigueSputumMyalgiaDiarrheaNausea_VomitingSore_throatRunny_nose_Nasal_congestionLoss_of_smellLoss_of_tasteHeadahceChest_discomfort__chest_painAsymptomaticsmoking_historyhypertensionhxdiabeteshxasthmahxcoronaryheartdiseasehxcopdhxheartfailurehxcarcinomahximmunosuppressionhxckdhxALTCRPD_dimerFerritinHRLDHLymphocyteSpO2ProcalcitoninRRSystolic_BPTemperatureTroponinICU_or_not
0230210000000000000010000000000312.2857203.911734540.6960.28188937.00.011
12513711011100000011000000000003711.7183587.11324836.8980.642810739.50.011
22812111100011110001000000000001822.0389932.112456210.4881.811712339.10.011
328117111000000000000000000000018310.53702068.012370613.4940.922012038.90.011
43012111010000000000000000000002186.73492141.010168817.9980.371513037.10.011
\n", "
" ], "text/plain": [ " Age Gender ethnicity race Fever Cough SOB Fatigue Sputum Myalgia \\\n", "0 23 0 2 1 0 0 0 0 0 0 \n", "1 25 1 3 7 1 1 0 1 1 1 \n", "2 28 1 2 1 1 1 1 0 0 0 \n", "3 28 1 1 7 1 1 1 0 0 0 \n", "4 30 1 2 1 1 1 0 1 0 0 \n", "\n", " Diarrhea Nausea_Vomiting Sore_throat Runny_nose_Nasal_congestion \\\n", "0 0 0 0 0 \n", "1 0 0 0 0 \n", "2 1 1 1 1 \n", "3 0 0 0 0 \n", "4 0 0 0 0 \n", "\n", " Loss_of_smell Loss_of_taste Headahce Chest_discomfort__chest_pain \\\n", "0 0 0 0 0 \n", "1 0 0 1 1 \n", "2 0 0 0 1 \n", "3 0 0 0 0 \n", "4 0 0 0 0 \n", "\n", " Asymptomatic smoking_history hypertensionhx diabeteshx asthmahx \\\n", "0 1 0 0 0 0 \n", "1 0 0 0 0 0 \n", "2 0 0 0 0 0 \n", "3 0 0 0 0 0 \n", "4 0 0 0 0 0 \n", "\n", " coronaryheartdiseasehx copdhx heartfailurehx carcinomahx \\\n", "0 0 0 0 0 \n", "1 0 0 0 0 \n", "2 0 0 0 0 \n", "3 0 0 0 0 \n", "4 0 0 0 0 \n", "\n", " immunosuppressionhx ckdhx ALT CRP D_dimer Ferritin HR LDH \\\n", "0 0 0 31 2.2 857 203.9 117 345 \n", "1 0 0 37 11.7 183 587.1 132 483 \n", "2 0 0 18 22.0 389 932.1 124 562 \n", "3 0 0 183 10.5 370 2068.0 123 706 \n", "4 0 0 218 6.7 349 2141.0 101 688 \n", "\n", " Lymphocyte SpO2 Procalcitonin RR Systolic_BP Temperature Troponin \\\n", "0 40.6 96 0.28 18 89 37.0 0.01 \n", "1 6.8 98 0.64 28 107 39.5 0.01 \n", "2 10.4 88 1.81 17 123 39.1 0.01 \n", "3 13.4 94 0.92 20 120 38.9 0.01 \n", "4 17.9 98 0.37 15 130 37.1 0.01 \n", "\n", " ICU_or_not \n", "0 1 \n", "1 1 \n", "2 1 \n", "3 1 \n", "4 1 " ] }, "execution_count": 6, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "data_icu.columns = data_icu.columns.str.replace('.','_')\n", "data_icu = data_icu.rename(columns={'Gender__female_0__male1_':'Gender'})\n", "data_icu.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jVJ62zq1RdEa" }, "outputs": [], "source": [ "X = data_icu.drop(columns='ICU_or_not')\n", "y= data_icu['ICU_or_not']" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "5iLr69ziPxqy", "outputId": "17191d97-b7f7-4158-8edb-9408fd054b47" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Resampled dataset shape Counter({1: 836, 0: 835})\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.7/dist-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", " warnings.warn(msg, category=FutureWarning)\n" ] } ], "source": [ "ada = ADASYN(random_state=0, n_neighbors=20)\n", "X_res, y_res = ada.fit_resample(X, y)\n", "print('Resampled dataset shape {}'.format(Counter(y_res)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0ild212ijoEG", "outputId": "aea4893c-461f-4e16-c655-043e2f2aaca0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Resampled dataset shape Counter({1: 836, 0: 835})\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.7/dist-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", " warnings.warn(msg, category=FutureWarning)\n" ] } ], "source": [ "smote = SMOTE(random_state=0)\n", "X_resm, y_resm = ada.fit_resample(X, y)\n", "print('Resampled dataset shape {}'.format(Counter(y_resm)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "xlBN3hsNI3Zh", "outputId": "97ee2b34-a0bf-49aa-90cd-97db84a7b67c" }, "outputs": [ { "data": { "text/plain": [ "(1671, 42)" ] }, "execution_count": 10, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "transformer = FastICA(random_state=0)\n", "X_transformed = 
transformer.fit_transform(X_res)\n", "X_transformed.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-6fRKosCkAH4", "outputId": "f39af7a7-df42-4e88-8ca9-98cef56ab6de" }, "outputs": [ { "data": { "text/plain": [ "(1671, 42)" ] }, "execution_count": 11, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "transformer = PCA(random_state=0)\n", "X_transformed0 = transformer.fit_transform(X_res)\n", "X_transformed0.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yRHyXaoHNu_L" }, "outputs": [], "source": [ "from pytorch_tabnet.tab_model import TabNetClassifier" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "import numpy as np\n",
 "\n",
 "# Hold-out split of the ICA-transformed, ADASYN-resampled data.\n",
 "X_train, X_test, y_train, y_test = train_test_split(\n",
 "    X_transformed, y_res, test_size=0.1, random_state=80)\n",
 "# TabNet expects NumPy arrays; newer imbalanced-learn versions may return pandas objects.\n",
 "X_train, X_test = np.asarray(X_train), np.asarray(X_test)\n",
 "y_train, y_test = np.asarray(y_train), np.asarray(y_test)\n",
 "\n",
 "target = data_icu['ICU_or_not']\n",
 "skf = StratifiedKFold(n_splits=5)\n",
 "\n",
 "# Sanity check: report the positive-class ratio of every stratified fold.\n",
 "for fold_no, (train_index, test_index) in enumerate(skf.split(data_icu, target), start=1):\n",
 "    # split() yields positional indices, so rows are selected with .iloc.\n",
 "    train = data_icu.iloc[train_index]\n",
 "    test = data_icu.iloc[test_index]\n",
 "    print('Fold',str(fold_no),'Class Ratio:',sum(test['ICU_or_not'])/len(test['ICU_or_not']))\n",
 "\n",
 "\n",
 "def train_model(train, test, fold_no):\n",
 "    \"\"\"Fit a TabNetClassifier on the notebook-level split and print hold-out metrics.\n",
 "\n",
 "    The model is trained on the global X_train / y_train and evaluated on the\n",
 "    global X_test / y_test.\n",
 "\n",
 "    NOTE(review): `train`, `test` and `fold_no` are accepted but never used, so\n",
 "    every call made by the fold loop below fits on the same data. Confirm whether\n",
 "    per-fold training was intended.\n",
 "\n",
 "    Returns\n",
 "    -------\n",
 "    TabNetClassifier\n",
 "        The fitted classifier, so that later cells can reuse it.\n",
 "    \"\"\"\n",
 "    clf = TabNetClassifier(n_d=64, n_shared=2, mask_type='entmax', momentum=0.3,\n",
 "                           n_steps=3, n_independent=2, lambda_sparse=0.003, gamma=2.7)\n",
 "    clf.fit(\n",
 "        X_train=X_train, y_train=y_train,\n",
 "        eval_set=[(X_train, y_train), (X_test, y_test)],\n",
 "        eval_name=['train', 'test'],\n",
 "        eval_metric=['auc'], max_epochs=200, batch_size=128, patience=60\n",
 "    )\n",
 "    predictions = clf.predict(X_test)\n",
 "    # ROC AUC is a ranking metric: score it on class-1 probabilities, not hard labels.\n",
 "    probabilities = clf.predict_proba(X_test)[:, 1]\n",
 "    print('Roc score is', roc_auc_score(y_test, probabilities))\n",
 "    print('f1 score is', f1_score(y_test, predictions))\n",
 "    print('accuracy score is', accuracy_score(y_test, predictions))\n",
 "    print('recall score is', recall_score(y_test, predictions))\n",
 "    print('precision score is', precision_score(y_test, predictions))\n",
 "    return clf\n",
 "\n",
 "\n",
 "for fold_no, (train_index, test_index) in enumerate(skf.split(data_icu, target), start=1):\n",
 "    train = data_icu.iloc[train_index]\n",
 "    test = data_icu.iloc[test_index]\n",
 "    # Keep the most recently fitted model at notebook scope: later cells read `clf`.\n",
 "    clf = train_model(train, test, fold_no)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tab_pred = clf.predict(X_test)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 313 }, "id": "oqbEgJJNFNi-", "outputId": "157c3fb2-4b78-46f2-a7f8-902a2f8684e5" }, "outputs": [ { "data": { "application/javascript": [ "\n", " async function download(id, filename, size) {\n", " if (!google.colab.kernel.accessAllowed) {\n", " return;\n", " }\n", " const div = document.createElement('div');\n", " const label = document.createElement('label');\n", " label.textContent = `Downloading \"${filename}\": `;\n", " div.appendChild(label);\n", " const progress = document.createElement('progress');\n", " progress.max = size;\n", " div.appendChild(progress);\n", " document.body.appendChild(div);\n", "\n", " const buffers = [];\n", " let downloaded = 0;\n", "\n", " const channel = await google.colab.kernel.comms.open(id);\n", " // Send a message to notify the kernel that we're ready.\n", " channel.send({})\n", "\n", " for await (const message of channel.messages) {\n", " // Send a message to notify the kernel that we're ready.\n", " channel.send({})\n", " if (message.buffers) {\n", " for (const buffer of message.buffers) {\n", " buffers.push(buffer);\n", " downloaded += buffer.byteLength;\n", " progress.value = downloaded;\n", " }\n", " }\n", " }\n", " const blob = new Blob(buffers, {type: 'application/binary'});\n", " const a = document.createElement('a');\n", " 
a.href = window.URL.createObjectURL(blob);\n", " a.download = filename;\n", " div.appendChild(a);\n", " a.click();\n", " div.remove();\n", " }\n", " " ], "text/plain": [ "" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/javascript": [ "download(\"download_897462b8-3bb2-438a-a17d-2d874801eb43\", \"trvalicu.pdf\", 16269)" ], "text/plain": [ "" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "# plot auc\n", "import matplotlib.pyplot as plt\n", "from google.colab import files\n", "\n", "plt.plot(clf.history['train_auc'], label='Training accuracy')\n", "plt.plot(clf.history['test_auc'],label='Validation accuracy')\n", "plt.rcParams[\"figure.figsize\"]=(8, 7)\n", "plt.title('Accuracy score for training and testing')\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Accuracy')\n", "plt.legend()\n", "plt.savefig('trvalicu.pdf', transparent=True,dphi=300)\n", "files.download(\"trvalicu.pdf\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mIvc97zgB_66" }, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from google.colab import files\n", "\n", "def plot_feature_importance(importance,names,model_type):\n", "\n", "\n", " #Create arrays from feature importance and feature names\n", " feature_importance = np.array(importance)\n", " feature_names = np.array(names)\n", "\n", " #Create a DataFrame using a Dictionary\n", " data_icu={'feature_names':feature_names,'feature_importance':feature_importance}\n", " fi_df = pd.DataFrame(data_icu)\n", "\n", " #Sort the DataFrame in order decreasing feature importance\n", " fi_df.sort_values(by=['feature_importance'], ascending=False,inplace=True,)\n", "\n", " #Define size of bar plot\n", " plt.figure(figsize=(11,11))\n", " #Plot Searborn bar chart\n", " sns.barplot(x=fi_df['feature_importance'], y=fi_df['feature_names'])\n", " #Add chart labels\n", " plt.title('FEATURE IMPORTANCE')\n", " plt.xlabel('FEATURE IMPORTANCE')\n", " plt.ylabel('FEATURES')\n", " plt.savefig('featimpicu.pdf', transparent=True,dphi=300)\n", " files.download(\"featimpicu.pdf\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 676 }, "id": "2baUV6GVCKqM", "outputId": "09c7c5fa-2395-4636-b09e-1bfd4c11bf9e" }, 
"outputs": [ { "data": { "application/javascript": [ "\n", " async function download(id, filename, size) {\n", " if (!google.colab.kernel.accessAllowed) {\n", " return;\n", " }\n", " const div = document.createElement('div');\n", " const label = document.createElement('label');\n", " label.textContent = `Downloading \"${filename}\": `;\n", " div.appendChild(label);\n", " const progress = document.createElement('progress');\n", " progress.max = size;\n", " div.appendChild(progress);\n", " document.body.appendChild(div);\n", "\n", " const buffers = [];\n", " let downloaded = 0;\n", "\n", " const channel = await google.colab.kernel.comms.open(id);\n", " // Send a message to notify the kernel that we're ready.\n", " channel.send({})\n", "\n", " for await (const message of channel.messages) {\n", " // Send a message to notify the kernel that we're ready.\n", " channel.send({})\n", " if (message.buffers) {\n", " for (const buffer of message.buffers) {\n", " buffers.push(buffer);\n", " downloaded += buffer.byteLength;\n", " progress.value = downloaded;\n", " }\n", " }\n", " }\n", " const blob = new Blob(buffers, {type: 'application/binary'});\n", " const a = document.createElement('a');\n", " a.href = window.URL.createObjectURL(blob);\n", " a.download = filename;\n", " div.appendChild(a);\n", " a.click();\n", " div.remove();\n", " }\n", " " ], "text/plain": [ "" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/javascript": [ "download(\"download_1182c1d9-b9fe-4d1b-bb44-9e7e000b34ee\", \"featimpicu.pdf\", 20683)" ], "text/plain": [ "" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "plot_feature_importance(clf.feature_importances_, X.columns,'TabNet')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 475 }, "id": "vKVWHmhnQdcK", "outputId": "6bcc53d5-0e3d-48d5-eba3-3433a11a3d3e" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "No handles with labels found to put in legend.\n" ] }, { "data": { "application/javascript": [ "\n", " async function download(id, filename, size) {\n", " if (!google.colab.kernel.accessAllowed) {\n", " return;\n", " }\n", " const div = document.createElement('div');\n", " const label = document.createElement('label');\n", " label.textContent = `Downloading \"${filename}\": `;\n", " div.appendChild(label);\n", " const progress = document.createElement('progress');\n", " progress.max = size;\n", " div.appendChild(progress);\n", " document.body.appendChild(div);\n", "\n", " const buffers = [];\n", " let downloaded = 0;\n", "\n", " const channel = await google.colab.kernel.comms.open(id);\n", " // Send a message to notify the kernel that we're ready.\n", " channel.send({})\n", "\n", " for await (const message of channel.messages) {\n", " // Send a message to notify the kernel that we're ready.\n", " channel.send({})\n", " if (message.buffers) {\n", " for (const buffer of message.buffers) {\n", " buffers.push(buffer);\n", " downloaded += buffer.byteLength;\n", " progress.value = downloaded;\n", " }\n", " }\n", " }\n", " const blob = new Blob(buffers, {type: 'application/binary'});\n", " const a = document.createElement('a');\n", " a.href = window.URL.createObjectURL(blob);\n", " a.download = filename;\n", " div.appendChild(a);\n", " a.click();\n", " div.remove();\n", " }\n", " " ], "text/plain": [ "" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "application/javascript": [ 
"download(\"download_35498819-6642-466f-beb1-7fabbcd50853\", \"iculoss.pdf\", 11913)" ], "text/plain": [ "" ] }, "metadata": { "tags": [] }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "# plot losses\n", "import matplotlib.pyplot as plt\n", "plt.plot(clf.history['loss'])\n", "plt.rcParams[\"figure.figsize\"]=(10, 10)\n", "plt.title('Model Loss')\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Loss')\n", "plt.legend();\n", "plt.savefig('iculoss.pdf', dpi=300)\n", "files.download(\"iculoss.pdf\")\n", "\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "background_save": true }, "id": "IFHkqDLGSTar" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "from sklearn.metrics import precision_recall_curve\n", "from sklearn.metrics import auc\n", "import seaborn as sns" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "clf_probs = clf.predict_proba(X_test)\n", "clf_probs = clf_probs[:, 1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "background_save": true }, "id": "ASjv-DaKg4_D" }, "outputs": [], "source": [ "clf_precision, clf_recall, _ = precision_recall_curve(y_test, clf_probs)\n", "clf_f1, clf_auc = f1_score(y_test, tab_pred), auc(clf_recall, clf_precision)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "28nML_luhHng" }, "outputs": [], "source": [ "print('TabNet: f1=%.3f auc=%.3f' % (clf_f1, clf_auc))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "asjPHeRZhVrx" }, "outputs": [], "source": [ "plt.plot(clf_recall, clf_precision, marker='.',)\n", "plt.xlabel('Recall')\n", "plt.ylabel('Precision')\n", "plt.title('Precision-Recall Curve for Proposed Model')\n", "plt.legend()\n", "plt.savefig('precallicu.pdf', transparent=True,dphi=300)\n", "files.download(\"precallicu.pdf\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CTBJWK-c6eto" }, "outputs": [], "source": [ "explain_matrix, masks = clf.explain(X_test)" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": { "id": "g-V0uFfXZK4P" }, "outputs": [], "source": [ "from google.colab import files" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "743Uml8oFI6N" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "fig, axis = plt.subplots(2,1, figsize=(15,15))\n", "for i in range(2):\n", " axis[i].imshow(masks[i][:20])\n", " axis[i].set_title(f'masks{i}')\n", " fig.suptitle('Feature Importance Masks')\n", "plt.savefig('masksicu.pdf', transparent=True,dphi=300)\n", "files.download(\"masksicu.pdf\") " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DsQ7ioIPo_Bn" }, "outputs": [], "source": [ "cm = confusion_matrix(y_test, tab_pred)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "55Dzi5snqpYi" }, "outputs": [], "source": [ "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay\n", "\n", "cmd = ConfusionMatrixDisplay(cm, display_labels=['ICU','No ICU'])\n", "cmd.plot()\n", "plt.savefig('confmatrixicu.pdf', transparent=True,dphi=300)\n", "files.download(\"confmatrixicu.pdf\") " ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "ICU.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 1 }