{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "cQlg2FWWk4ZQ", "outputId": "4c861533-b430-4fd5-9328-ee3d7efafa12" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting pytorch-tabnet\n", " Downloading https://files.pythonhosted.org/packages/94/e5/2a808d611a5d44e3c997c0d07362c04a56c70002208e00aec9eee3d923b5/pytorch_tabnet-3.1.1-py3-none-any.whl\n", "Requirement already satisfied: torch<2.0,>=1.2 in /usr/local/lib/python3.7/dist-packages (from pytorch-tabnet) (1.9.0+cu102)\n", "Requirement already satisfied: numpy<2.0,>=1.17 in /usr/local/lib/python3.7/dist-packages (from pytorch-tabnet) (1.19.5)\n", "Requirement already satisfied: tqdm<5.0,>=4.36 in /usr/local/lib/python3.7/dist-packages (from pytorch-tabnet) (4.41.1)\n", "Requirement already satisfied: scipy>1.4 in /usr/local/lib/python3.7/dist-packages (from pytorch-tabnet) (1.4.1)\n", "Requirement already satisfied: scikit_learn>0.21 in /usr/local/lib/python3.7/dist-packages (from pytorch-tabnet) (0.22.2.post1)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch<2.0,>=1.2->pytorch-tabnet) (3.7.4.3)\n", "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit_learn>0.21->pytorch-tabnet) (1.0.1)\n", "Installing collected packages: pytorch-tabnet\n", "Successfully installed pytorch-tabnet-3.1.1\n" ] } ], "source": [ "%pip install pytorch-tabnet" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Jj1kswhSk7C_", "outputId": "ccbfbf40-898c-477a-9a20-acecaf5626d9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: imbalanced-learn in /usr/local/lib/python3.7/dist-packages (0.4.3)\n", "Requirement already satisfied: scikit-learn>=0.20 in /usr/local/lib/python3.7/dist-packages (from 
imbalanced-learn) (0.22.2.post1)\n", "Requirement already satisfied: numpy>=1.8.2 in /usr/local/lib/python3.7/dist-packages (from imbalanced-learn) (1.19.5)\n", "Requirement already satisfied: scipy>=0.13.3 in /usr/local/lib/python3.7/dist-packages (from imbalanced-learn) (1.4.1)\n", "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=0.20->imbalanced-learn) (1.0.1)\n" ] } ], "source": [ "%pip install imbalanced-learn" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "xImT45fqk7o1", "outputId": "2e05e48f-ac0f-47a4-cb7e-28019776a6b1" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.7/dist-packages/sklearn/externals/six.py:31: FutureWarning: The module is deprecated in version 0.21 and will be removed in version 0.23 since we've dropped support for Python 2.7. Please rely on the official version of six (https://pypi.org/project/six/).\n", " \"(https://pypi.org/project/six/).\", FutureWarning)\n", "/usr/local/lib/python3.7/dist-packages/sklearn/utils/deprecation.py:144: FutureWarning: The sklearn.neighbors.base module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.neighbors. 
Anything that cannot be imported from sklearn.neighbors is now part of the private API.\n", " warnings.warn(message, FutureWarning)\n" ] } ], "source": [
"# All notebook imports in one place (stdlib, then third-party).\n",
"from collections import Counter\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"import imblearn\n",
"from imblearn.over_sampling import ADASYN, SMOTE\n",
"from pytorch_tabnet.tab_model import TabNetClassifier\n",
"from sklearn.decomposition import FastICA, PCA, FactorAnalysis\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.manifold import TSNE\n",
"from sklearn.metrics import roc_auc_score, recall_score, f1_score, accuracy_score, confusion_matrix, precision_score, log_loss, roc_curve\n",
"from sklearn.model_selection import StratifiedKFold, train_test_split\n",
"\n",
"pd.set_option('display.max_columns', None)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 226 }, "id": "bpzovsyWk_A_", "outputId": "bd928cce-cd3c-40c7-f332-258682ac2504" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AgeGenderethnicityraceFeverCoughSOBFatigueSputumMyalgiaDiarrheaNausea.or.vomitingSore.throatRunny.nose.nasal.congestionLoss.of.smellLoss.of.tasteHeadahceChest.discomfort..Chest.painAsymptomaticsmoking_historyhypertensionhxdiabeteshxasthmahxcoronaryheartdiseasehxcopdhxheartfailurehxcarcinomahximmunosuppressionhxckdhxALTCRPD.dimerFerritinHRLDHLymphocytesSpO2ProcalcitoninRRSystolic.BPTemperatureTroponindeath
0230371110101100010100000000000231.741529.110913030.7970.062010237.20.010
1500110100011100000000110000000215.6150153.21081309.2940.081612737.30.010
2290171110010000000100000000000101.262542.48813315.1970.071910336.90.010
35712111100000000000000000000003412.2206373.96713315.4980.051815337.40.010
42602110000000000000000000000005114.3273107.710014718.3960.431412037.00.010
\n", "
" ], "text/plain": [ " Age Gender ethnicity race Fever Cough SOB Fatigue Sputum Myalgia \\\n", "0 23 0 3 7 1 1 1 0 1 0 \n", "1 50 0 1 1 0 1 0 0 0 1 \n", "2 29 0 1 7 1 1 1 0 0 1 \n", "3 57 1 2 1 1 1 1 0 0 0 \n", "4 26 0 2 1 1 0 0 0 0 0 \n", "\n", " Diarrhea Nausea.or.vomiting Sore.throat Runny.nose.nasal.congestion \\\n", "0 1 1 0 0 \n", "1 1 1 0 0 \n", "2 0 0 0 0 \n", "3 0 0 0 0 \n", "4 0 0 0 0 \n", "\n", " Loss.of.smell Loss.of.taste Headahce Chest.discomfort..Chest.pain \\\n", "0 0 1 0 1 \n", "1 0 0 0 0 \n", "2 0 0 0 1 \n", "3 0 0 0 0 \n", "4 0 0 0 0 \n", "\n", " Asymptomatic smoking_history hypertensionhx diabeteshx asthmahx \\\n", "0 0 0 0 0 0 \n", "1 0 0 1 1 0 \n", "2 0 0 0 0 0 \n", "3 0 0 0 0 0 \n", "4 0 0 0 0 0 \n", "\n", " coronaryheartdiseasehx copdhx heartfailurehx carcinomahx \\\n", "0 0 0 0 0 \n", "1 0 0 0 0 \n", "2 0 0 0 0 \n", "3 0 0 0 0 \n", "4 0 0 0 0 \n", "\n", " immunosuppressionhx ckdhx ALT CRP D.dimer Ferritin HR LDH \\\n", "0 0 0 23 1.7 415 29.1 109 130 \n", "1 0 0 21 5.6 150 153.2 108 130 \n", "2 0 0 10 1.2 625 42.4 88 133 \n", "3 0 0 34 12.2 206 373.9 67 133 \n", "4 0 0 51 14.3 273 107.7 100 147 \n", "\n", " Lymphocytes SpO2 Procalcitonin RR Systolic.BP Temperature Troponin \\\n", "0 30.7 97 0.06 20 102 37.2 0.01 \n", "1 9.2 94 0.08 16 127 37.3 0.01 \n", "2 15.1 97 0.07 19 103 36.9 0.01 \n", "3 15.4 98 0.05 18 153 37.4 0.01 \n", "4 18.3 96 0.43 14 120 37.0 0.01 \n", "\n", " death \n", "0 0 \n", "1 0 \n", "2 0 \n", "3 0 \n", "4 0 " ] }, "execution_count": 4, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [
"# Load the dataset; its 'death' column is the prediction target split off into y below.\n",
"DATA_PATH = 'DeadMICE.csv'\n",
"data_death = pd.read_csv(DATA_PATH)\n",
"data_death.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 226 }, "id": "6PzetCkOlCtd", "outputId": "159d3c18-7d2d-4e9c-f542-b14a9e865151" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AgeGenderethnicityraceFeverCoughSOBFatigueSputumMyalgiaDiarrheaNausea_or_vomitingSore_throatRunny_nose_nasal_congestionLoss_of_smellLoss_of_tasteHeadahceChest_discomfort__Chest_painAsymptomaticsmoking_historyhypertensionhxdiabeteshxasthmahxcoronaryheartdiseasehxcopdhxheartfailurehxcarcinomahximmunosuppressionhxckdhxALTCRPD_dimerFerritinHRLDHLymphocytesSpO2ProcalcitoninRRSystolic_BPTemperatureTroponindeath
0230371110101100010100000000000231.741529.110913030.7970.062010237.20.010
1500110100011100000000110000000215.6150153.21081309.2940.081612737.30.010
2290171110010000000100000000000101.262542.48813315.1970.071910336.90.010
35712111100000000000000000000003412.2206373.96713315.4980.051815337.40.010
42602110000000000000000000000005114.3273107.710014718.3960.431412037.00.010
\n", "
" ], "text/plain": [ " Age Gender ethnicity race Fever Cough SOB Fatigue Sputum Myalgia \\\n", "0 23 0 3 7 1 1 1 0 1 0 \n", "1 50 0 1 1 0 1 0 0 0 1 \n", "2 29 0 1 7 1 1 1 0 0 1 \n", "3 57 1 2 1 1 1 1 0 0 0 \n", "4 26 0 2 1 1 0 0 0 0 0 \n", "\n", " Diarrhea Nausea_or_vomiting Sore_throat Runny_nose_nasal_congestion \\\n", "0 1 1 0 0 \n", "1 1 1 0 0 \n", "2 0 0 0 0 \n", "3 0 0 0 0 \n", "4 0 0 0 0 \n", "\n", " Loss_of_smell Loss_of_taste Headahce Chest_discomfort__Chest_pain \\\n", "0 0 1 0 1 \n", "1 0 0 0 0 \n", "2 0 0 0 1 \n", "3 0 0 0 0 \n", "4 0 0 0 0 \n", "\n", " Asymptomatic smoking_history hypertensionhx diabeteshx asthmahx \\\n", "0 0 0 0 0 0 \n", "1 0 0 1 1 0 \n", "2 0 0 0 0 0 \n", "3 0 0 0 0 0 \n", "4 0 0 0 0 0 \n", "\n", " coronaryheartdiseasehx copdhx heartfailurehx carcinomahx \\\n", "0 0 0 0 0 \n", "1 0 0 0 0 \n", "2 0 0 0 0 \n", "3 0 0 0 0 \n", "4 0 0 0 0 \n", "\n", " immunosuppressionhx ckdhx ALT CRP D_dimer Ferritin HR LDH \\\n", "0 0 0 23 1.7 415 29.1 109 130 \n", "1 0 0 21 5.6 150 153.2 108 130 \n", "2 0 0 10 1.2 625 42.4 88 133 \n", "3 0 0 34 12.2 206 373.9 67 133 \n", "4 0 0 51 14.3 273 107.7 100 147 \n", "\n", " Lymphocytes SpO2 Procalcitonin RR Systolic_BP Temperature Troponin \\\n", "0 30.7 97 0.06 20 102 37.2 0.01 \n", "1 9.2 94 0.08 16 127 37.3 0.01 \n", "2 15.1 97 0.07 19 103 36.9 0.01 \n", "3 15.4 98 0.05 18 153 37.4 0.01 \n", "4 18.3 96 0.43 14 120 37.0 0.01 \n", "\n", " death \n", "0 0 \n", "1 0 \n", "2 0 \n", "3 0 \n", "4 0 " ] }, "execution_count": 5, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [
"# Make column names identifier-friendly by replacing every literal '.' with '_'.\n",
"# regex=False is explicit because '.' is a regex wildcard and the default of `regex`\n",
"# (and its single-character special case) has changed across pandas versions.\n",
"data_death.columns = data_death.columns.str.replace('.', '_', regex=False)\n",
"data_death = data_death.rename(columns={'Gender__female_0__male1_':'Gender'})\n",
"data_death.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "At3ph4eklEFt" }, "outputs": [], "source": [
"# Features are every column except the 'death' target.\n",
"X = data_death.drop(columns='death')\n",
"y = data_death['death']" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fSCSgMGllJp1", "outputId": "59942903-4c4c-4a2c-d846-579f8e85baa4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Resampled dataset shape Counter({0: 878, 1: 864})\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.7/dist-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function safe_indexing is deprecated; safe_indexing is deprecated in version 0.22 and will be removed in version 0.24.\n", " warnings.warn(msg, category=FutureWarning)\n" ] } ], "source": [
"# Oversample the minority class with ADASYN and report the new class counts.\n",
"# NOTE(review): resampling runs on the full dataset before any train/test split, so synthetic\n",
"# rows derived from later test rows can leak into training; consider resampling the training split only.\n",
"ada = ADASYN(random_state=0, n_neighbors=20)\n",
"X_res, y_res = ada.fit_resample(X, y)\n",
"print('Resampled dataset shape {}'.format(Counter(y_res)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "onAK4SoqN77t", "outputId": "19a68958-ff47-43f7-b1aa-4e4c2c5b0d0b" }, "outputs": [ { "data": { "text/plain": [ "(1742, 42)" ] }, "execution_count": 8, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [
"# Project the resampled features with FastICA; n_components is left at its default,\n",
"# so every input feature yields one component (see the shape below).\n",
"transformer = FastICA(random_state=0)\n",
"X_transformed = transformer.fit_transform(X_res)\n",
"X_transformed.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2_OnuupTlL3Y" }, "outputs": [], "source": [
"# 90/10 split of the ICA features (a later cell re-splits with random_state=10 before training).\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X_transformed, y_res, test_size=0.1, random_state=0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-JEFbrBLiXn9", "outputId": "da6616fe-6857-45aa-a9cf-9ca7e1381e0b" }, "outputs": [ { "data": { "text/plain": [ "96" ] }, "execution_count": 10, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [
"# Number of class-0 rows in the test split.\n",
"len(y_test[y_test == 0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 226 }, "id": 
"BTN0tftIFuGc", "outputId": "a64df591-2329-4cb9-f1ff-0b86642aa8ec" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AgeGenderethnicityraceFeverCoughSOBFatigueSputumMyalgiaDiarrheaNausea_or_vomitingSore_throatRunny_nose_nasal_congestionLoss_of_smellLoss_of_tasteHeadahceChest_discomfort__Chest_painAsymptomaticsmoking_historyhypertensionhxdiabeteshxasthmahxcoronaryheartdiseasehxcopdhxheartfailurehxcarcinomahximmunosuppressionhxckdhxALTCRPD_dimerFerritinHRLDHLymphocytesSpO2ProcalcitoninRRSystolic_BPTemperatureTroponindeath
0230371110101100010100000000000231.741529.110913030.7970.062010237.20.010
1500110100011100000000110000000215.6150153.21081309.2940.081612737.30.010
2290171110010000000100000000000101.262542.48813315.1970.071910336.90.010
35712111100000000000000000000003412.2206373.96713315.4980.051815337.40.010
42602110000000000000000000000005114.3273107.710014718.3960.431412037.00.010
\n", "
" ], "text/plain": [ " Age Gender ethnicity race Fever Cough SOB Fatigue Sputum Myalgia \\\n", "0 23 0 3 7 1 1 1 0 1 0 \n", "1 50 0 1 1 0 1 0 0 0 1 \n", "2 29 0 1 7 1 1 1 0 0 1 \n", "3 57 1 2 1 1 1 1 0 0 0 \n", "4 26 0 2 1 1 0 0 0 0 0 \n", "\n", " Diarrhea Nausea_or_vomiting Sore_throat Runny_nose_nasal_congestion \\\n", "0 1 1 0 0 \n", "1 1 1 0 0 \n", "2 0 0 0 0 \n", "3 0 0 0 0 \n", "4 0 0 0 0 \n", "\n", " Loss_of_smell Loss_of_taste Headahce Chest_discomfort__Chest_pain \\\n", "0 0 1 0 1 \n", "1 0 0 0 0 \n", "2 0 0 0 1 \n", "3 0 0 0 0 \n", "4 0 0 0 0 \n", "\n", " Asymptomatic smoking_history hypertensionhx diabeteshx asthmahx \\\n", "0 0 0 0 0 0 \n", "1 0 0 1 1 0 \n", "2 0 0 0 0 0 \n", "3 0 0 0 0 0 \n", "4 0 0 0 0 0 \n", "\n", " coronaryheartdiseasehx copdhx heartfailurehx carcinomahx \\\n", "0 0 0 0 0 \n", "1 0 0 0 0 \n", "2 0 0 0 0 \n", "3 0 0 0 0 \n", "4 0 0 0 0 \n", "\n", " immunosuppressionhx ckdhx ALT CRP D_dimer Ferritin HR LDH \\\n", "0 0 0 23 1.7 415 29.1 109 130 \n", "1 0 0 21 5.6 150 153.2 108 130 \n", "2 0 0 10 1.2 625 42.4 88 133 \n", "3 0 0 34 12.2 206 373.9 67 133 \n", "4 0 0 51 14.3 273 107.7 100 147 \n", "\n", " Lymphocytes SpO2 Procalcitonin RR Systolic_BP Temperature Troponin \\\n", "0 30.7 97 0.06 20 102 37.2 0.01 \n", "1 9.2 94 0.08 16 127 37.3 0.01 \n", "2 15.1 97 0.07 19 103 36.9 0.01 \n", "3 15.4 98 0.05 18 153 37.4 0.01 \n", "4 18.3 96 0.43 14 120 37.0 0.01 \n", "\n", " death \n", "0 0 \n", "1 0 \n", "2 0 \n", "3 0 \n", "4 0 " ] }, "execution_count": 11, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "data_death.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"from pytorch_tabnet.tab_model import TabNetClassifier\n",
"\n",
"# Hold out 10% of the resampled, ICA-transformed data; the evaluation cells below use X_test.\n",
"X_train, X_test, y_train, y_test = train_test_split(X_transformed, y_res,\n",
"                                                    test_size=0.1, random_state=10)\n",
"\n",
"# Class ratio of each stratified fold of the raw (not resampled) data, for reference only.\n",
"target = data_death['death']\n",
"skf = StratifiedKFold(n_splits=5)\n",
"for fold_no, (train_index, test_index) in enumerate(skf.split(data_death, target), start=1):\n",
"    test = data_death.loc[test_index, :]\n",
"    print('Fold', str(fold_no), 'Class Ratio:', sum(test['death']) / len(test['death']))\n",
"\n",
"\n",
"def train_model(train, test, fold_no):\n",
"    \"\"\"Fit TabNet on one cross-validation fold of the training split and print its metrics.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    train, test : array-like of int\n",
"        Positional indices into ``X_train``/``y_train`` selecting the fold's fitting rows\n",
"        and validation rows. The validation rows drive early stopping and the printed metrics.\n",
"    fold_no : int\n",
"        1-based fold number, used only in the printed report.\n",
"\n",
"    Returns\n",
"    -------\n",
"    TabNetClassifier\n",
"        The fitted model; its ``history`` holds the 'train_auc', 'test_auc' and 'loss' curves.\n",
"    \"\"\"\n",
"    # np.asarray so positional indexing also works if y_res came back as a pandas Series.\n",
"    features = np.asarray(X_train)\n",
"    labels = np.asarray(y_train)\n",
"    X_fit, y_fit = features[train], labels[train]\n",
"    X_val, y_val = features[test], labels[test]\n",
"\n",
"    clf = TabNetClassifier(n_d=64, n_shared=2, mask_type='sparsemax', momentum=0.3, n_steps=3,\n",
"                           n_independent=2, lambda_sparse=0.003, gamma=2.8)\n",
"    clf.fit(\n",
"        X_train=X_fit, y_train=y_fit,\n",
"        eval_set=[(X_fit, y_fit), (X_val, y_val)],\n",
"        eval_name=['train', 'test'],\n",
"        eval_metric=['auc'], max_epochs=150, batch_size=128, patience=60,\n",
"    )\n",
"\n",
"    predictions = clf.predict(X_val)\n",
"    # ROC AUC is computed from class-1 scores, not from hard 0/1 predictions.\n",
"    scores = clf.predict_proba(X_val)[:, 1]\n",
"    print('Fold', str(fold_no))\n",
"    print('Roc score is', roc_auc_score(y_val, scores))\n",
"    print('f1 score is', f1_score(y_val, predictions))\n",
"    print('accuracy score is', accuracy_score(y_val, predictions))\n",
"    print('recall score is', recall_score(y_val, predictions))\n",
"    print('precision score is', precision_score(y_val, predictions))\n",
"    return clf\n",
"\n",
"\n",
"# Cross-validate on the training split only, so X_test stays unseen until the cells below.\n",
"# `clf`, used by the later cells, ends up as the model from the last fold.\n",
"for fold_no, (train_index, test_index) in enumerate(skf.split(X_train, y_train), start=1):\n",
"    clf = train_model(train_index, test_index, fold_no)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tab_pred = clf.predict(X_test)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GPmQo0pGIowW" }, "outputs": [], "source": [
"# plot auc\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from google.colab import files\n",
"\n",
"# Set the size before plotting so it applies to this figure.\n",
"plt.rcParams[\"figure.figsize\"]=(10, 10)\n",
"plt.plot(clf.history['train_auc'], label='Training AUC')\n",
"plt.plot(clf.history['test_auc'], label='Validation AUC')\n",
"plt.title('ROC AUC for training and validation')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('ROC AUC')\n",
"plt.legend()\n",
"plt.savefig('accdead.pdf', transparent=True, dpi=300)\n",
"files.download(\"accdead.pdf\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-aeYbmb5IozP" }, "outputs": [], "source": [
"# plot losses\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams[\"figure.figsize\"]=(10, 10)\n",
"plt.plot(clf.history['loss'], label='Training loss')\n",
"plt.title('Model Loss')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Loss')\n",
"plt.legend()\n",
"plt.savefig('lossdead.pdf', transparent=True, dpi=300)\n",
"files.download(\"lossdead.pdf\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "R8ia2m_AIo14" }, "outputs": [], "source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from google.colab import files\n",
"\n",
"# NOTE(review): Rf is not used by any later cell in this notebook.\n",
"Rf = RandomForestClassifier(random_state=0)\n",
"Rf.fit(X_train, y_train)\n",
"\n",
"\n",
"def plot_feature_importance(importance, names, model_type):\n",
"    \"\"\"Draw a horizontal bar chart of feature importances, largest first.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    importance : array-like of float\n",
"        One importance value per feature.\n",
"    names : array-like of str\n",
"        Feature labels, aligned element-wise with ``importance``.\n",
"    model_type : str\n",
"        Model name shown in the chart title.\n",
"\n",
"    Returns\n",
"    -------\n",
"    matplotlib.axes.Axes\n",
"        The axes holding the bar chart.\n",
"    \"\"\"\n",
"    fi_df = pd.DataFrame({\n",
"        'feature_names': np.array(names),\n",
"        'feature_importance': np.array(importance),\n",
"    }).sort_values(by='feature_importance', ascending=False)\n",
"\n",
"    fig, ax = plt.subplots(figsize=(11, 11))\n",
"    sns.barplot(x=fi_df['feature_importance'], y=fi_df['feature_names'], ax=ax)\n",
"    ax.set_title(f'{model_type} FEATURE IMPORTANCE')\n",
"    ax.set_xlabel('FEATURE IMPORTANCE')\n",
"    ax.set_ylabel('FEATURES')\n",
"    return ax" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# The model is trained on FastICA components of the resampled features, not on the raw\n",
"# columns, so each bar is labelled with its component index rather than a column name of X.\n",
"component_names = [f'IC{i + 1}' for i in range(len(clf.feature_importances_))]\n",
"plot_feature_importance(clf.feature_importances_, component_names, 'Tabnet');" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"rPOycy7wwqew" }, "outputs": [], "source": [ "clf_probs = clf.predict_proba(X_test)\n", "clf_probs = clf_probs[:, 1]\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4zA5NNjRxqwA" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "from sklearn.metrics import precision_recall_curve\n", "from sklearn.metrics import auc" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kwgewoc2wtkR" }, "outputs": [], "source": [ "clf_precision, clf_recall, _ = precision_recall_curve(y_test, clf_probs)\n", "clf_f1, clf_auc = f1_score(y_test, tab_pred), auc(clf_recall, clf_precision)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "13qjBp4twvCw" }, "outputs": [], "source": [
"# Share of class-1 rows in the test split: the precision a no-skill classifier would reach.\n",
"Adasyn_icu = len(y_test[y_test==1]) / len(y_test)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N66BO_3EwwiX" }, "outputs": [], "source": [
"plt.plot(clf_recall, clf_precision, marker='.', label=f'TabNet (PR AUC = {clf_auc:.3f})')\n",
"plt.plot([0, 1], [Adasyn_icu, Adasyn_icu], linestyle='--', label='No skill')\n",
"plt.xlabel('Recall')\n",
"plt.ylabel('Precision')\n",
"plt.title('Precision-Recall Curve for Proposed Model')\n",
"# show the legend (both curves carry labels)\n",
"plt.legend()\n",
"plt.savefig('precalldead.pdf', transparent=True, dpi=300)\n",
"files.download(\"precalldead.pdf\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6Et7E10nh9pF" }, "outputs": [], "source": [ "explain_matrix, masks = clf.explain(X_test)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wi4bohTJh9s9" }, "outputs": [], "source": [ "from google.colab import files" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G5n8c7a7h9vN" }, "outputs": [], "source": [
"import matplotlib.pyplot as plt\n",
"# `masks` from clf.explain is indexed 0..len(masks)-1 (one per TabNet step, per the\n",
"# pytorch-tabnet docs); plot every step for the first 20 test rows.\n",
"n_masks = len(masks)\n",
"fig, axes = plt.subplots(n_masks, 1, figsize=(15, 5 * n_masks), squeeze=False)\n",
"for step in range(n_masks):\n",
"    ax = axes[step, 0]\n",
"    ax.imshow(masks[step][:20])\n",
"    ax.set_title(f'masks{step}')\n",
"fig.suptitle('Feature Importance Masks')\n",
"plt.savefig('masksdead.pdf', transparent=True, dpi=300)\n",
"files.download(\"masksdead.pdf\") " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BmaHziAKtmzg" }, "outputs": [], "source": [
"from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay\n",
"# Pin the row/column order to labels [0, 1] so display_labels line up: 0 = no death, 1 = death.\n",
"cm = confusion_matrix(y_test, tab_pred, labels=[0, 1])\n",
"cmd = ConfusionMatrixDisplay(cm, display_labels=['No Death', 'Death'])\n",
"cmd.plot()\n",
"plt.savefig('cmdead.pdf', transparent=True, dpi=300)\n",
"files.download(\"cmdead.pdf\") " ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "deaths.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 1 }