{ "cells": [ { "cell_type": "markdown", "id": "c6d265a8", "metadata": {}, "source": [ "# ***Heart Disease Prediction Using Machine Learning***\n", "\n", "***Using Python's data science and machine learning libraries, we'll build a machine learning model capable of classifying whether or not a person has heart disease based on their medical attributes.***\n", "\n", "***Our approach will be:*** \n", " * ***1. Defining the problem***\n", " * ***2. Data***\n", " * ***3. Evaluation***\n", " * ***4. Features***\n", " * ***5. Modelling***\n", " * ***6. Experimentation***" ] }, { "cell_type": "markdown", "id": "ba0b70fc", "metadata": {}, "source": [ "## ***What is classification?***\n", "\n", "Classification involves deciding whether a sample is part of one class or another (**binary classification**). If there are more than two class options, it's referred to as **multi-class classification**.\n", "\n", "\n", "## ***What we'll end up with***\n", "\n", "Since we already have a dataset, we'll approach the problem with the following machine learning modelling framework.\n", "\n", "| | \n", "|:--:| \n", "| 6-Step Machine Learning Modelling Framework |\n", "\n", "***More specifically, we'll look at the following topics.***\n", "\n", "* **Exploratory data analysis (EDA)** - the process of going through a dataset and finding out more about it.\n", "* **Model training** - create model(s) to learn to predict a target variable based on other variables.\n", "* **Model evaluation** - evaluating a model's predictions using problem-specific evaluation metrics. 
\n", "* **Model comparison** - comparing several different models to find the best one.\n", "* **Model fine-tuning** - once we've found a good model, how can we improve it?\n", "* **Feature importance** - since we're predicting the presence of heart disease, are some features more important for prediction than others?\n", "* **Cross-validation** - if we do build a good model, can we be sure it will work on unseen data?\n", "* **Reporting what we've found** - if we had to present our work, what would we show someone?\n", "\n", "To work through these topics, we'll use pandas, Matplotlib and NumPy for data analysis, as well as Scikit-Learn for machine learning and modelling tasks.\n", "\n", "| | \n", "|:--:| \n", "| Tools which can be used for each step of the machine learning modelling process. |\n", "\n", "We'll work through each step, and by the end of the notebook we'll have a handful of models, all of which can predict whether or not a person has heart disease from a number of different parameters with reasonable accuracy. \n", "\n", "We'll also be able to describe which parameters are more indicative than others; for example, sex may turn out to be more important than age." ] }, { "cell_type": "markdown", "id": "3442b9ce", "metadata": {}, "source": [ "### ***1. Defining the problem 🤔***\n", "\n", "***Problem statement: given clinical attributes about a patient, can we predict whether or not they have heart disease?***" ] }, { "cell_type": "markdown", "id": "641c8df5", "metadata": {}, "source": [ "### ***2. Data***\n", "\n", "***The original data comes from the Cleveland database in the UCI Machine Learning Repository.***\n", "***https://archive.ics.uci.edu/ml/datasets/Heart+Disease***\n", "\n", "***It is also available on Kaggle.***\n", "***https://www.kaggle.com/ronitf/heart-disease-uci***\n" ] }, { "cell_type": "markdown", "id": "5d8666cb", "metadata": {}, "source": [ "### ***3. 
Evaluation***\n", "\n", "***We'll pursue this project only if we can reach 95% accuracy in classifying whether or not a person has heart disease.***" ] }, { "cell_type": "markdown", "id": "f69fc96f", "metadata": {}, "source": [ "### ***4. Features***\n", "\n", "***Data attributes (information about each column of the data)***\n", "\n", "***1. age - age in years***\n", "\n", "***2. sex - (1 = male; 0 = female)***\n", "\n", "***3. cp - chest pain type***\n", " * ***0: Typical angina: chest pain related to decreased blood supply to the heart***\n", " * ***1: Atypical angina: chest pain not related to the heart***\n", " * ***2: Non-anginal pain: typically esophageal spasms (not heart related)***\n", " * ***3: Asymptomatic: chest pain not showing signs of disease***\n", " \n", "***4. trestbps - resting blood pressure (in mm Hg on admission to the hospital)***\n", " * ***anything above 130-140 is typically cause for concern***\n", " \n", "***5. chol - serum cholesterol in mg/dl*** \n", " * ***serum = LDL + HDL + 0.2 * triglycerides***\n", " * ***above 200 is cause for concern***\n", " \n", "***6. fbs - (fasting blood sugar > 120 mg/dl) (1 = true; 0 = false)*** \n", " * ***above 126 mg/dL signals diabetes***\n", " \n", "***7. restecg - resting electrocardiographic results***\n", " * ***0: Nothing to note***\n", " * ***1: ST-T wave abnormality***\n", " * ***can range from mild symptoms to severe problems***\n", " * ***signals an abnormal heartbeat***\n", " * ***2: Possible or definite left ventricular hypertrophy***\n", " * ***enlarged main pumping chamber of the heart***\n", " \n", "***8. thalach - maximum heart rate achieved*** \n", "\n", "***9. exang - exercise-induced angina (1 = yes; 0 = no)*** \n", "\n", "***10. oldpeak - ST depression induced by exercise relative to rest*** \n", " * ***looks at the stress on the heart during exercise***\n", " * ***an unhealthy heart will show more stress***\n", " \n", "***11. 
slope - the slope of the peak exercise ST segment***\n", " * ***0: Upsloping: better heart rate with exercise (uncommon)***\n", " * ***1: Flat-sloping: minimal change (typical healthy heart)***\n", " * ***2: Downsloping: signs of an unhealthy heart***\n", " \n", "***12. ca - number of major vessels (0-3) colored by fluoroscopy***\n", " * ***a colored vessel means the doctor can see the blood passing through***\n", " * ***the more blood movement the better (no clots)***\n", " \n", "***13. thal - thallium stress result***\n", " * ***1,3: normal***\n", " * ***6: fixed defect: used to be a defect but is OK now***\n", " * ***7: reversible defect: no proper blood movement when exercising***\n", " \n", "***14. target - has disease or not (1 = yes, 0 = no) (the predicted attribute)***\n", "\n", "### ***Getting the tools ready 😎*** \n", "\n", "***We'll be working with pandas, Matplotlib and NumPy for data analysis and manipulation.*** " ] }, { "cell_type": "markdown", "id": "7510550b", "metadata": {}, "source": [ "### ***Preparing the tools***\n", "\n", "***At the start of any project, it's customary to see the required libraries imported in one big chunk, as you can see below.***\n", "\n", "***However, in practice, your projects may import libraries as you go. After you've spent a couple of hours working on your problem, you'll probably want to do some tidying up. This is where you may want to consolidate every library you've used at the top of your notebook (like the cell below).***\n", "\n", "***The libraries you use will differ from project to project. 
But there are a few you'll likely take advantage of in almost every structured data project.*** \n", "\n", "* ***[pandas](https://pandas.pydata.org/) for data analysis.***\n", "* ***[NumPy](https://numpy.org/) for numerical operations.***\n", "* ***[Matplotlib](https://matplotlib.org/)/[seaborn](https://seaborn.pydata.org/) for plotting or data visualization.***\n", "* ***[Scikit-Learn](https://scikit-learn.org/stable/) for machine learning modelling and evaluation.***" ] }, { "cell_type": "code", "execution_count": 196, "id": "26c2fc47", "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 196, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# For setting the notebook theme 💖 (requires the jupyterthemes package)\n", "from jupyterthemes import get_themes\n", "import jupyterthemes as jt\n", "from jupyterthemes.stylefx import set_nb_theme\n", "set_nb_theme('monokai')\n", "# Themes: chesterish, grade3, gruvboxd, gruvboxl, monokai, oceans16, onedork, solarizedd, solarizedl" ] }, { "cell_type": "code", "execution_count": 197, "id": "a547a3bf", "metadata": {}, "outputs": [], "source": [ "# Importing required tools.\n", "\n", "# Regular EDA (exploratory data analysis) and plotting libraries.\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "# To make plots appear inside the notebook.\n", "%matplotlib inline \n", "\n", "# Getting models from Scikit-Learn \n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.ensemble import RandomForestClassifier\n", "\n", "# For splitting the data, tuning hyperparameters and evaluating the models\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.model_selection import cross_val_score\n", "from sklearn.model_selection import RandomizedSearchCV\n", "from sklearn.model_selection import GridSearchCV\n", "from sklearn.metrics import confusion_matrix\n", "from sklearn.metrics import 
classification_report\n", "from sklearn.metrics import precision_score\n", "from sklearn.metrics import recall_score\n", "from sklearn.metrics import f1_score\n", "# Note: plot_roc_curve was deprecated in Scikit-Learn 1.0 and removed in 1.2;\n", "# on newer versions, use RocCurveDisplay.from_estimator instead.\n", "from sklearn.metrics import plot_roc_curve" ] }, { "cell_type": "markdown", "id": "c47fece2", "metadata": {}, "source": [ "### ***Importing Data 💽***\n", "\n", "* ***There are many different ways to store data. The typical way of storing \"tabular data\" (data similar to what you'd see in an Excel file) is in `.csv` format. `.csv` stands for comma-separated values.***\n", "\n", "* ***pandas has a built-in function for reading `.csv` files called `read_csv()`, which takes the path to your `.csv` file. You'll likely use this a lot.***" ] }, { "cell_type": "code", "execution_count": 198, "id": "aed89693", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " age sex cp trestbps chol fbs restecg thalach exang oldpeak \\\n", "0 63 1 3 145 233 1 0 150 0 2.3 \n", "1 37 1 2 130 250 0 1 187 0 3.5 \n", "2 41 0 1 130 204 0 0 172 0 1.4 \n", "3 56 1 1 120 236 0 1 178 0 0.8 \n", "4 57 0 0 120 354 0 1 163 1 0.6 \n", ".. ... ... .. ... ... ... ... ... ... ... \n", "298 57 0 0 140 241 0 1 123 1 0.2 \n", "299 45 1 3 110 264 0 1 132 0 1.2 \n", "300 68 1 0 144 193 1 1 141 0 3.4 \n", "301 57 1 0 130 131 0 1 115 1 1.2 \n", "302 57 0 1 130 236 0 0 174 0 0.0 \n", "\n", " slope ca thal target \n", "0 0 0 1 1 \n", "1 0 0 2 1 \n", "2 2 0 2 1 \n", "3 2 0 2 1 \n", "4 2 0 2 1 \n", ".. ... .. ... ... \n", "298 1 0 3 0 \n", "299 1 0 3 0 \n", "300 1 2 3 0 \n", "301 1 1 3 0 \n", "302 1 1 2 0 \n", "\n", "[303 rows x 14 columns]" ] }, "execution_count": 198, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv('./dataset/heart-disease.csv')\n", "df" ] }, { "cell_type": "code", "execution_count": 199, "id": "bdead887", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(303, 14)" ] }, "execution_count": 199, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.shape # (rows, columns)" ] }, { "cell_type": "markdown", "id": "0aaf8c0c", "metadata": {}, "source": [ "### ***Data Exploration (exploratory data analysis or EDA)***\n", "\n", "***The goal is to extract more information from the data.***\n", "\n", "***We're focusing on:***\n", " * ***1. The question we're trying to solve.***\n", " * ***2. The type of data we have, and how to deal with different types of data (e.g. numerical and non-numerical).***\n", " * ***3. Dealing with missing data.***\n", " * ***4. Identifying outliers.***\n", " * ***5. Adding more features in order to extract more information from our data.***\n" ] }, { "cell_type": "code", "execution_count": 200, "id": "98c38ab3", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " age sex cp trestbps chol fbs restecg thalach exang oldpeak slope \\\n", "0 63 1 3 145 233 1 0 150 0 2.3 0 \n", "1 37 1 2 130 250 0 1 187 0 3.5 0 \n", "2 41 0 1 130 204 0 0 172 0 1.4 2 \n", "3 56 1 1 120 236 0 1 178 0 0.8 2 \n", "4 57 0 0 120 354 0 1 163 1 0.6 2 \n", "\n", " ca thal target \n", "0 0 1 1 \n", "1 0 2 1 \n", "2 0 2 1 \n", "3 0 2 1 \n", "4 0 2 1 " ] }, "execution_count": 200, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": 201, "id": "01f7f0b2", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " age sex cp trestbps chol fbs restecg thalach exang oldpeak \\\n", "298 57 0 0 140 241 0 1 123 1 0.2 \n", "299 45 1 3 110 264 0 1 132 0 1.2 \n", "300 68 1 0 144 193 1 1 141 0 3.4 \n", "301 57 1 0 130 131 0 1 115 1 1.2 \n", "302 57 0 1 130 236 0 0 174 0 0.0 \n", "\n", " slope ca thal target \n", "298 1 0 3 0 \n", "299 1 0 3 0 \n", "300 1 2 3 0 \n", "301 1 1 3 0 \n", "302 1 1 2 0 " ] }, "execution_count": 201, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.tail()" ] }, { "cell_type": "code", "execution_count": 202, "id": "94e595f6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1 165\n", "0 138\n", "Name: target, dtype: int64" ] }, "execution_count": 202, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Number of samples in each class\n", "df['target'].value_counts()" ] }, { "cell_type": "code", "execution_count": 275, "id": "100d8ec4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([0, 1]), [Text(0, 0, '1'), Text(1, 0, '0')])" ] }, "execution_count": 275, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAEACAYAAAC+gnFaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAPzUlEQVR4nO3dW2yTdfzH8U9b2IC24ijLDBHGbAwJEMh0zCiHgQQGEhMPCTHiDBvImCCKTAlquCIhHIegZCFhEPVCQkiMEpGDLsThjWxccLEtke2/LQaSOQ+sDMbK+r/wz/7UDt3hoV33fb8SL/Y8z8q32xPfbX/tM1ckEokIAGCWO9EDAAASixAAgHGEAACMIwQAYBwhAADjCAEAGEcIAMC4EYkeYCD++OOGurv5+IMTAgGf2tpCiR4DiMG56Ry326W0NO999ydlCLq7I4TAQfwsMVRxbsYHLw0BgHGEAACMIwQAYBwhAADjCAEAGEcIAMA4QgAAxiXl5wiSgTfg1Rh3cnQ2Pd2f6BH+U0d3t2603Uj0GMCwRAgekDFut1yJHmIYibjdIgPAg5EcD1kBAA8MIQAA4wgBABhHCADAOEIAAMYRAgAwjhAAgHGEAACMIwQAYBwhAADjCAEAGNfvENTW1mratGm6du1a1PZFixZpypQpMf/9/vvvPcdcvnxZBQUFys7O1pw5c7R37151dXUN/l4AAAasXxeda2hoUHFxscLhcNT2GzduqKWlRZs2bVJubm7UvoceekiS1NTUpJUrVyo7O1v79u3TlStXVFZWplAopK1btw7ybgAABqpPIQiHwzp27Jj27NmjkSNHxuyvr69XJBLRwoULFQwGe72NQ4cOye/36+DBg0pJSVFeXp5GjRqlbdu2qbi4WBkZGYO7JwCAAenTS0PV1dXavXu3ioqKVFpaGrO/trZWqampmjx58n1v48KFC1qwYIFSUlJ6ti1ZskR37txRVVVV/ycHADiiTyEIBoM6d+6c1q9fL4/HE7O/vr5eDz/8sN59913l5OQoOztbGzduVGtrqyTp5s2bunr1qrKysqK+b9y4cfL5fGpsbHTgrgAABqJPIRg/frwCgcB999fV1em3337T448/rvLycm3ZskU///yzXn/9dd26dUvt7e2SJJ/PF/O9Xq9XoVBogOMDAAbLkb9Q9tFHHykSiWjmzJmSpJycHAWDQb366qv6+uuvlZeXJ0lyuWL/ZlckEpG7n3/SMRCIDQqGv2T4k5pwFr/z+HAkBDNmzIjZ9uSTT8rv96uurk7Lli2TpF4f+Xd0dMjv798vu60tpO7uyMCGjRNOYOe1trYnegTEUXq6n9+5Q9xu178+gB70B8o6Ojp04sQJ1dXVRW2PRCLq6upSWlqavF6vMjIy1NTUFHVMW1ubQqFQzNoBACB+Bh2C1NRU7dixQ5988knU9u+//163bt3q+VzB7NmzVVlZqdu3b/ccc/r0aXk8npjPHgAA4mfQIfB4PCopKdHZs2e1bds2/fTTTzp69Kg2b96shQsX6qmnnpIkrV69Wq2trVqzZo0qKyt15MgRbd++XcuXL9eECRMGfUcAAAPjyBpBYWGhfD6fPvvsMx0/flxjx47VK6+8orfeeqvnmGAwqIqKCu3cuVMbNmxQWlqaCgsLo44BAMSfKxKJDO1V114ky2Jx7HukMFARsVhsDYvFznngi8UAgORGCADAOEIAAMYRAgAwjhAAgHGEAACMIwQAYBwhAADjHPlkMYDkkeYdqRFjRiV6jD5Jhqv4hjtu6Y8bXYkeY1AIAWDMiDGjdCV9bqLHGDaCrT9KSR4CXhoCAOMIAQAYRwgAwDhCAADGEQIAMI4QAIBxhAAAjCMEAGAcIQAA4wgBABhHCADAOEIAAMYRAgAwjhAAgHGEAACMIwQAYBwhAADjCAEAGEcIAMA4QgAAxhECADCOEACAcYQAAIwjBABgHCEAAOMIAQAYRwgAwDhCAADGEQIAMI4QAIBxhAAAjCMEAGA
cIQAA4wgBABhHCADAOEIAAMYRAgAwjhAAgHGEAACMIwQAYFy/Q1BbW6tp06bp2rVrUdurqqr08ssva+bMmXr22WdVUVER872XL19WQUGBsrOzNWfOHO3du1ddXV0Dnx4AMGj9CkFDQ4OKi4sVDoejttfU1Gjt2rV67LHHdODAAT3//PPauXOnDh8+3HNMU1OTVq5cqdTUVO3bt09FRUU6cuSItm/f7sw9AQAMyIi+HBQOh3Xs2DHt2bNHI0eOjNm/f/9+TZ06Vbt27ZIkzZs3T+FwWOXl5SooKFBKSooOHTokv9+vgwcPKiUlRXl5eRo1apS2bdum4uJiZWRkOHvPAAB90qdnBNXV1dq9e7eKiopUWloata+zs1MXL17U4sWLo7bn5+fr+vXrqqmpkSRduHBBCxYsUEpKSs8xS5Ys0Z07d1RVVTXY+wEAGKA+hSAYDOrcuXNav369PB5P1L6WlhZ1dXUpKysrantmZqYkqbGxUTdv3tTVq1djjhk3bpx8Pp8aGxsHcx8AAIPQp5eGxo8ff9997e3tkiSfzxe13ev1SpJCodB9j7l7XCgU6tu0AADH9SkE/yYSiUiSXC5Xr/vdbve/HhOJROR29+/NS4FAbFAw/KWn+xM9AtCrZD83Bx0Cv//vH8A/H9Xf/drv9/c8E+jtkX9HR0fPbfRVW1tI3d2RgYwbN8l+YgxFra3tiR5hWODcdN5QPzfdbte/PoAe9AfKJk2aJI/Ho+bm5qjtd7/OysqS1+tVRkaGmpqaoo5pa2tTKBSKWTsAAMTPoEOQmpqqnJwcnTlzpuclIEk6ffq0/H6/pk+fLkmaPXu2Kisrdfv27ahjPB6PcnNzBzsGAGCAHLnERElJiWpqarRx40adP39e+/bt0+HDh1VcXKzRo0dLklavXq3W1latWbNGlZWVPR8mW758uSZMmODEGACAAXAkBE8//bQOHDigK1euaN26dfrmm2/0/vvv64033ug5JhgMqqKiQh0dHdqwYYOOHDmiwsJCffjhh06MAAAYIFfk3tdzkkSyLBb3/j4qDEREQ39BLlmkp/t1JX1uoscYNoKtPw75c/OBLxYDAJIbIQAA4wgBABhHCADAOEIAAMYRAgAwjhAAgHGEAACMIwQAYBwhAADjCAEAGEcIAMA4QgAAxhECADCOEACAcYQAAIwjBABgHCEAAOMIAQAYRwgAwDhCAADGEQIAMI4QAIBxhAAAjCMEAGAcIQAA4wgBABhHCADAOEIAAMYRAgAwjhAAgHGEAACMIwQAYBwhAADjCAEAGEcIAMA4QgAAxhECADCOEACAcYQAAIwjBABgHCEAAOMIAQAYRwgAwDhCAADGEQIAMI4QAIBxhAAAjCMEAGAcIQAA40Y4dUPhcFhPPPGEOjs7o7aPGTNGly5dkiRVVVWprKxMv/zyiwKBgF577TUVFRU5NQIAYAAcC0FjY6M6Ozu1Y8cOTZ48uWe72/33k46amhqtXbtWS5cu1dtvv63q6mrt3LlTkUhEq1atcmoMAEA/ORaCuro6ud1u5efna/To0TH79+/fr6lTp2rXrl2SpHnz5ikcDqu8vFwFBQVKSUlxahQAQD84tkZQW1urSZMm9RqBzs5OXbx4UYsXL47anp+fr+vXr6umpsapMQAA/eRYCOrr65WSkqJVq1YpOztbs2bN0tatWxUKhdTS0qKuri5lZWVFfU9mZqakv19WAgAkhmMhqKurU3Nzs/Ly8nTo0CG9+eabOnnypEpKStTe3i5J8vl8Ud/j9XolSaFQyKkxAAD95NgaQVlZmcaOHaspU6ZIkmbNmqVAIKD33ntPFy5ckCS5XK5ev/fugnJfBQK+/z4Iw056uj/RIwC9SvZz07EQ5ObmxmybP39+1Nf/fOR/92u/v38/xLa2kLq7I/0bMM6S/cQYilpb2xM9wrDAuem8oX5uut2uf30A7chLQ21tbTp+/LhaWlqitt+6dUuSFAgE5PF41NzcHLX/7tf/XDsAAMS
PIyFwuVzaunWrvvjii6jt3377rTwej5555hnl5OTozJkzikT+/5H86dOn5ff7NX36dCfGAAAMgCMvDY0bN04rVqzQ559/Lp/Pp5ycHFVXV6u8vFwrVqxQZmamSkpKVFhYqI0bN+rFF1/UpUuXdPjwYW3atKnXt5wCAOLDFbn3IfogdHV16ejRozpx4oR+/fVXZWRkaPny5Vq9enXPYvDZs2e1f/9+NTY2KiMjQytWrBjQJSaSZY2g96VxDEREQ/912GSRnu7XlfS5iR5j2Ai2/jjkz83/WiNwLATxRAjsIQTOIQTOGg4h4OqjAGAcIQAA4wgBABhHCADAOEIAAMYRAgAwjhAAgHGEAACMIwQAYBwhAADjCAEAGEcIAMA4QgAAxhECADCOEACAcYQAAIwjBABgHCEAAOMIAQAYRwgAwDhCAADGEQIAMI4QAIBxhAAAjCMEAGAcIQAA4wgBABhHCADAOEIAAMYRAgAwjhAAgHGEAACMIwQAYBwhAADjCAEAGEcIAMA4QgAAxhECADCOEACAcYQAAIwjBABgHCEAAOMIAQAYRwgAwDhCAADGEQIAMI4QAIBxhAAAjCMEAGAcIQAA4+IegpMnT2rZsmWaMWOGli5dqq+++ireIwAA7hHXEJw6dUqlpaWaPXu2Pv30U+Xm5mrz5s367rvv4jkGAOAeI+L5j+3du1dLly7VBx98IEmaO3eu/vrrL3388cdasmRJPEcBAPyfuD0jaGlpUXNzsxYvXhy1PT8/Xw0NDWppaYnXKACAe8TtGUFDQ4MkKSsrK2p7ZmamJKmxsVETJ07s02253S5nh3tAMhM9wDCTLL/3ZDBi4iOJHmFYGern5n/NF7cQtLe3S5J8Pl/Udq/XK0kKhUJ9vq20NK9zgz1A/5PoAYaZQMD33wehTzJrjid6hGEl2c/NuL00FIlEJEkul6vX7W4372QFgESI2/99/X6/pNhH/jdu3IjaDwCIr7iF4O7aQHNzc9T2pqamqP0AgPiKWwgyMzP16KOPxnxm4MyZM5o8ebImTJgQr1EAAPeI6+cI1q1bpy1btmjs2LGaP3++fvjhB506dUplZWXxHAMAcA9X5O5qbZx8+eWXqqio0NWrVzVx4kStWbNGL7zwQjxHAADcI+4hAAAMLbxnEwCMIwQAYBwhMK62tlbTpk3TtWvXEj0KIIlL1ScCITCsoaFBxcXFCofDiR4FkMSl6hOFxWKDwuGwjh07pj179mjkyJH6888/df78eT3yCBciQ2ItWrRI06dPj3pL+TvvvKP6+nqdOnUqgZMNbzwjMKi6ulq7d+9WUVGRSktLEz0OIIlL1ScSITAoGAzq3LlzWr9+vTweT6LHAST17VL1eDDi+sliDA3jx49P9AhADCcvVY/+4RkBgCGBS9UnDj9ZAEMCl6pPHEIAYEjgUvWJQwgADAlcqj5xWCwGMGRwqfrEIAQAhoyXXnpJt2/fVkVFhY4fP66JEydqx44deu655xI92rDGJ4sBwDjWCADAOEIAAMYRAgAwjhAAgHGEAACMIwQAYBwhAADjCAEAGEcIAMC4/wW7Li5U3ULKTAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "df['target'].value_counts().plot(kind='bar', color = ['cyan', 'crimson'], grid = True);\n", "plt.xticks(rotation = 0)" ] }, { "cell_type": "code", "execution_count": 204, "id": "ea1b378c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 303 entries, 0 to 302\n", "Data columns (total 14 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 age 303 non-null int64 \n", " 1 sex 303 non-null int64 \n", " 2 cp 303 non-null int64 \n", " 3 trestbps 303 non-null int64 \n", " 4 chol 303 non-null int64 \n", " 5 fbs 303 non-null int64 \n", " 6 restecg 303 non-null int64 \n", " 7 thalach 303 non-null int64 \n", " 8 exang 303 non-null int64 \n", " 9 oldpeak 303 non-null float64\n", " 10 slope 303 non-null int64 \n", " 11 ca 303 non-null int64 \n", " 12 thal 303 non-null int64 \n", " 13 target 303 non-null int64 \n", "dtypes: float64(1), int64(13)\n", "memory usage: 33.3 KB\n" ] } ], "source": [ "df.info()" ] }, { "cell_type": "code", "execution_count": 205, "id": "e7e603d4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "age 0\n", "sex 0\n", "cp 0\n", "trestbps 0\n", "chol 0\n", "fbs 0\n", "restecg 0\n", "thalach 0\n", "exang 0\n", "oldpeak 0\n", "slope 0\n", "ca 0\n", "thal 0\n", "target 0\n", "dtype: int64" ] }, "execution_count": 205, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# No missing values 😁\n", "df.isna().sum()" ] }, { "cell_type": "code", "execution_count": 206, "id": "3f6a8cc0", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " age sex cp trestbps chol fbs \\\n", "count 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 \n", "mean 54.366337 0.683168 0.966997 131.623762 246.264026 0.148515 \n", "std 9.082101 0.466011 1.032052 17.538143 51.830751 0.356198 \n", "min 29.000000 0.000000 0.000000 94.000000 126.000000 0.000000 \n", "25% 47.500000 0.000000 0.000000 120.000000 211.000000 0.000000 \n", "50% 55.000000 1.000000 1.000000 130.000000 240.000000 0.000000 \n", "75% 61.000000 1.000000 2.000000 140.000000 274.500000 0.000000 \n", "max 77.000000 1.000000 3.000000 200.000000 564.000000 1.000000 \n", "\n", " restecg thalach exang oldpeak slope ca \\\n", "count 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 \n", "mean 0.528053 149.646865 0.326733 1.039604 1.399340 0.729373 \n", "std 0.525860 22.905161 0.469794 1.161075 0.616226 1.022606 \n", "min 0.000000 71.000000 0.000000 0.000000 0.000000 0.000000 \n", "25% 0.000000 133.500000 0.000000 0.000000 1.000000 0.000000 \n", "50% 1.000000 153.000000 0.000000 0.800000 1.000000 0.000000 \n", "75% 1.000000 166.000000 1.000000 1.600000 2.000000 1.000000 \n", "max 2.000000 202.000000 1.000000 6.200000 2.000000 4.000000 \n", "\n", " thal target \n", "count 303.000000 303.000000 \n", "mean 2.313531 0.544554 \n", "std 0.612277 0.498835 \n", "min 0.000000 0.000000 \n", "25% 2.000000 0.000000 \n", "50% 2.000000 1.000000 \n", "75% 3.000000 1.000000 \n", "max 3.000000 1.000000 " ] }, "execution_count": 206, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.describe()" ] }, { "cell_type": "markdown", "id": "dda51419", "metadata": {}, "source": [ "### ***Heart Disease frequency according to Sex***\n", "\n", "* If you want to compare two columns to each other, you can use the function `pd.crosstab(column_1, column_2)`. 
\n", "\n", "* This is helpful if you want to start gaining an intuition about how your independent variables interact with your dependent variable.\n", "\n", "* Let's compare our target column with the sex column. \n", "\n", "* Recall from our data dictionary that for the target column, 1 = heart disease present and 0 = no heart disease; for sex, 1 = male and 0 = female." ] }, { "cell_type": "code", "execution_count": 207, "id": "593a5bd2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1 207\n", "0 96\n", "Name: sex, dtype: int64" ] }, "execution_count": 207, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.sex.value_counts()" ] }, { "cell_type": "markdown", "id": "0a3c6244", "metadata": {}, "source": [ "#### *There are 207 males and 96 females in our study.*" ] }, { "cell_type": "code", "execution_count": 208, "id": "5cb9e69f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "sex 0 1\n", "target \n", "0 24 114\n", "1 72 93" ] }, "execution_count": 208, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Compare target column with sex column\n", "pd.crosstab(df.target,df.sex)" ] }, { "cell_type": "markdown", "id": "6ee3e9e9", "metadata": {}, "source": [ "### ***What can we infer from this?***\n", "\n", "***Let's make a simple heuristic.\n", "Since there are 96 women and 72 of them have a positive value for heart disease, based on this one variable we might infer that if the participant is a woman, there's a 75% (72/96) chance she has heart disease.\n", "As for males, there are 207 in total, and 93 of them (about 45%) indicate the presence of heart disease. So we might predict that if the participant is male, he will have heart disease roughly 45% of the time.\n", "Simply averaging these two rates (62.5%) would ignore that there are about twice as many men as women. Weighting by group size, with no other parameters, a person in this dataset has a (72 + 93) / 303 = 165/303, or roughly 54.5%, chance of having heart disease, which matches the 165 positive cases in the `target` column.***" ] }, { "cell_type": "code", "execution_count": 209, "id": "44a7851a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "303" ] }, "execution_count": 209, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(df)" ] }, { "cell_type": "markdown", "id": "e8e30633", "metadata": {}, "source": [ "### ***Creating a plot of the crosstab***\n", "\n", "* You can plot the crosstab by using the `plot()` function and passing it a few parameters, such as `kind` (the type of plot you want), `figsize=(width, height)` (how big you want it to be) and `color=[colour_1, colour_2]` (the different colours you'd like to use).\n", "\n", "\n", "* Different metrics are represented best with different kinds of plots. In our case, a bar graph is great. We'll see more examples later. 
And with a bit of practice, you'll gain an intuition for which plot to use with different variables.\n", "\n", "\n", "* We'll create the plot with `crosstab()` and `plot()`, then add some helpful labels to it with `plt.title()`, `plt.xlabel()` and more.\n", "\n", "\n", "* To add the attributes, you call them on `plt` within the same cell where you create the graph." ] }, { "cell_type": "code", "execution_count": 273, "id": "619b573b", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "pd.crosstab(df.target,df.sex).plot(kind='bar',\n", " figsize=(10,6),\n", " color=['cyan','crimson']);\n", "\n", "plt.title('Heart Disease Frequency for Sex')\n", "plt.xlabel('0 = No Disease, 1 = Disease')\n", "plt.ylabel('Amount')\n", "plt.legend(['Female','Male'])\n", "plt.grid()\n", "plt.xticks(rotation = 0);" ] }, { "cell_type": "markdown", "id": "60332099", "metadata": {}, "source": [ "### ***Age vs. Max Heart Rate (thalach) for Heart Disease***\n", "\n", "* Let's try combining a couple of independent variables, such as `age` and `thalach` (maximum heart rate), and then comparing them to our target variable, `target` (heart disease).\n", "\n", "\n", "* Because there are so many different values for `age` and `thalach`, we'll use a scatter plot." ] }, { "cell_type": "code", "execution_count": 272, "id": "73d397c7", "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Create another figure\n", "plt.figure(figsize=(10,6))\n", "\n", "# Scatter with positive examples\n", "plt.scatter(df.age[df.target==1],\n", " df.thalach[df.target==1],\n", " color='crimson')\n", "\n", "\n", "# Scatter with negative examples\n", "plt.scatter(df.age[df.target==0],\n", " df.thalach[df.target==0],\n", " color='darkcyan')\n", "\n", "# Adding some helpful info\n", "plt.title('Heart Disease in function of Age and Max Heart Rate')\n", "plt.xlabel('Age')\n", "plt.ylabel('Max Heart Rate (thalach)')\n", "plt.legend([\"Disease\", \"No Disease\"]);\n", "plt.grid()" ] }, { "cell_type": "markdown", "id": "d472bf03", "metadata": {}, "source": [ "### ***What can we infer from this?***\n", "\n", "* It seems the younger someone is, the higher their max heart rate (dots are higher on the left of the graph) and the older someone is, the more green dots there are. But this may be because there are more dots all together on the right side of the graph (older participants).\n", "\n", "\n", "* Both of these are observational of course, but this is what we're trying to do, build an understanding of the data.\n", "\n", "\n", "* Let's check the age **distribution**." ] }, { "cell_type": "code", "execution_count": 271, "id": "2818bba0", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Checking the distribution of the age column with a histogram\n", "df.age.plot.hist(color = 'crimson', grid =True);" ] }, { "cell_type": "markdown", "id": "45b550e6", "metadata": {}, "source": [ "***We can see it's a NORMAL DISTRIBUTION.***\n", "\n", "***(https://en.wikipedia.org/wiki/Normal_distribution) but slightly swaying to the right, which reflects in the scatter plot above.***" ] }, { "cell_type": "markdown", "id": "4aabd3de", "metadata": {}, "source": [ "### ***Heart Disease frequency per Chest Pain type***\n", "\n", "* Let's try another independent variable. This time, `cp` (chest pain).\n", "\n", "* We'll use the same process as we did before with `sex`." ] }, { "cell_type": "markdown", "id": "047243e8", "metadata": {}, "source": [ "***cp - chest pain type***\n", " * ***0: Typical angina: chest pain related decrease blood supply to the heart***\n", " * ***1: Atypical angina: chest pain not related to heart***\n", " * ***2: Non-anginal pain: typically esophageal spasms (non heart related)***\n", " * ***3: Asymptomatic: chest pain not showing signs of disease*** " ] }, { "cell_type": "code", "execution_count": 213, "id": "4a7be9a9", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "target 0 1\n", "cp \n", "0 104 39\n", "1 9 41\n", "2 18 69\n", "3 7 16" ] }, "execution_count": 213, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.crosstab(df.cp,df.target)" ] }, { "cell_type": "code", "execution_count": 270, "id": "1a4437a7", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Making the above data more visual\n", "pd.crosstab(df.cp,df.target).plot(kind='bar',\n", " figsize=(10,6),\n", " color=['cyan','crimson'])\n", "\n", "# Labelling\n", "plt.title('Heart Disease Frequency Per Chest Pain Type')\n", "plt.xlabel(\"Chest Pain Type\")\n", "plt.ylabel(\"Amount\")\n", "plt.legend([\"No Disease\", \"Disease\"])\n", "plt.grid()\n", "plt.xticks(rotation=0);" ] }, { "cell_type": "markdown", "id": "9f91effa", "metadata": {}, "source": [ "### ***What can we infer from this?***\n", "\n", "Remember from our data dictionary what the different levels of chest pain are.\n", "\n", "3. cp - chest pain type \n", " * 0: Typical angina: chest pain related to decreased blood supply to the heart\n", " * 1: Atypical angina: chest pain not related to the heart\n", " * 2: Non-anginal pain: typically esophageal spasms (not heart related)\n", " * 3: Asymptomatic: chest pain not showing signs of disease\n", " \n", "It's interesting that atypical angina (value 1) is described as not related to the heart, yet it seems to have a higher ratio of participants with heart disease than without.\n", "\n", "\n", "What does atypical angina even mean?\n", "\n", "At this point, it's important to remember that if your data dictionary doesn't give you enough information, you may want to do further research on your values. This research may come in the form of asking a **subject matter expert** (such as a cardiologist or the person who gave you the data) or Googling to find out more.\n", "\n", "According to PubMed, it seems [even some medical professionals are confused by the term](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2763472/).\n", "\n", "> Today, 23 years later, “atypical chest pain” is still popular in medical circles. Its meaning, however, remains unclear. A few articles have the term in their title, but do not define or discuss it in their text. 
In other articles, the term refers to noncardiac causes of chest pain.\n", "\n", "Although not conclusive, the graph above hints that this confusion of definitions is reflected in the data." ] }, { "cell_type": "markdown", "id": "f41ad260", "metadata": {}, "source": [ "### ***Make a correlation matrix***\n", "\n", "***Finally, we'll compare all of the independent variables in one hit.***\n", "\n", "***Why?***\n", "\n", "* Because this may give an idea of which independent variables may or may not have an impact on our target variable.\n", "\n", "\n", "* We can do this using `df.corr()`, which will create a [**correlation matrix**](https://www.statisticshowto.datasciencecentral.com/correlation-matrix/) for us. In other words, a big table of pairwise correlation coefficients (Pearson by default, each between -1 and 1) telling us how related each variable is to the others." ] }, { "cell_type": "code", "execution_count": 215, "id": "51235612", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " age sex cp trestbps chol fbs \\\n", "age 1.000000 -0.098447 -0.068653 0.279351 0.213678 0.121308 \n", "sex -0.098447 1.000000 -0.049353 -0.056769 -0.197912 0.045032 \n", "cp -0.068653 -0.049353 1.000000 0.047608 -0.076904 0.094444 \n", "trestbps 0.279351 -0.056769 0.047608 1.000000 0.123174 0.177531 \n", "chol 0.213678 -0.197912 -0.076904 0.123174 1.000000 0.013294 \n", "fbs 0.121308 0.045032 0.094444 0.177531 0.013294 1.000000 \n", "restecg -0.116211 -0.058196 0.044421 -0.114103 -0.151040 -0.084189 \n", "thalach -0.398522 -0.044020 0.295762 -0.046698 -0.009940 -0.008567 \n", "exang 0.096801 0.141664 -0.394280 0.067616 0.067023 0.025665 \n", "oldpeak 0.210013 0.096093 -0.149230 0.193216 0.053952 0.005747 \n", "slope -0.168814 -0.030711 0.119717 -0.121475 -0.004038 -0.059894 \n", "ca 0.276326 0.118261 -0.181053 0.101389 0.070511 0.137979 \n", "thal 0.068001 0.210041 -0.161736 0.062210 0.098803 -0.032019 \n", "target -0.225439 -0.280937 0.433798 -0.144931 -0.085239 -0.028046 \n", "\n", " restecg thalach exang oldpeak slope ca \\\n", "age -0.116211 -0.398522 0.096801 0.210013 -0.168814 0.276326 \n", "sex -0.058196 -0.044020 0.141664 0.096093 -0.030711 0.118261 \n", "cp 0.044421 0.295762 -0.394280 -0.149230 0.119717 -0.181053 \n", "trestbps -0.114103 -0.046698 0.067616 0.193216 -0.121475 0.101389 \n", "chol -0.151040 -0.009940 0.067023 0.053952 -0.004038 0.070511 \n", "fbs -0.084189 -0.008567 0.025665 0.005747 -0.059894 0.137979 \n", "restecg 1.000000 0.044123 -0.070733 -0.058770 0.093045 -0.072042 \n", "thalach 0.044123 1.000000 -0.378812 -0.344187 0.386784 -0.213177 \n", "exang -0.070733 -0.378812 1.000000 0.288223 -0.257748 0.115739 \n", "oldpeak -0.058770 -0.344187 0.288223 1.000000 -0.577537 0.222682 \n", "slope 0.093045 0.386784 -0.257748 -0.577537 1.000000 -0.080155 \n", "ca -0.072042 -0.213177 0.115739 0.222682 -0.080155 1.000000 \n", "thal -0.011981 -0.096439 0.206754 0.210244 -0.104764 0.151832 \n", "target 0.137230 0.421741 -0.436757 
-0.430696 0.345877 -0.391724 \n", "\n", " thal target \n", "age 0.068001 -0.225439 \n", "sex 0.210041 -0.280937 \n", "cp -0.161736 0.433798 \n", "trestbps 0.062210 -0.144931 \n", "chol 0.098803 -0.085239 \n", "fbs -0.032019 -0.028046 \n", "restecg -0.011981 0.137230 \n", "thalach -0.096439 0.421741 \n", "exang 0.206754 -0.436757 \n", "oldpeak 0.210244 -0.430696 \n", "slope -0.104764 0.345877 \n", "ca 0.151832 -0.391724 \n", "thal 1.000000 -0.344029 \n", "target -0.344029 1.000000 " ] }, "execution_count": 215, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.corr()" ] }, { "cell_type": "code", "execution_count": 278, "id": "5d8f8608", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Making the correlation matrix more visual\n", "corr_matrix = df.corr()\n", "fig, ax = plt.subplots(figsize = (15, 10))\n", "ax = sns.heatmap(corr_matrix,\n", " annot = True,\n", " linewidths = 0.5,\n", " fmt = \".2f\",\n", " cmap =None )\n", "\n", "# Workaround for matplotlib 3.1.1, which cut off the top and bottom rows of heatmaps\n", "bottom, top = ax.get_ylim()\n", "plt.yticks(rotation = 0)\n", "ax.set_ylim(bottom + 0.5, top - 0.5);" ] }, { "cell_type": "markdown", "id": "1e83c4b0", "metadata": {}, "source": [ "#### ***Note:***\n", "\n", "***A value closer to 1 means a stronger positive correlation (as one variable increases, the other tends to increase) and a value closer to -1 means a stronger negative correlation (as one variable increases, the other tends to decrease). Values near 0 suggest little linear relationship.***" ] }, { "cell_type": "markdown", "id": "68eb0db3", "metadata": {}, "source": [ "### ***Before we model***\n", "\n", "Remember, we do exploratory data analysis (EDA) to start building an intuition of the dataset.\n", "\n", "What have we learned so far? Aside from our baseline estimate using `sex`, the other variables show only moderate relationships with `target` (the strongest correlations in the matrix are around 0.4 in absolute value, for `cp`, `thalach`, `exang` and `oldpeak`).\n", "\n", "So what we'll do next is **model-driven EDA**, meaning we'll use machine learning models to drive our next questions.\n", "\n", "***A few extra things to remember:***\n", "\n", "* Not every EDA will look the same; what we've seen here is an example of what you could do for a structured, tabular dataset.\n", "* You don't necessarily have to make the same plots as we've done here; there are many more ways to visualize data, and I encourage you to explore them.\n", "* We want to quickly find:\n", " * Distributions (`df.column.hist()`)\n", " * Missing values (`df.info()`)\n", " * Outliers" ] }, { 
"cell_type": "markdown", "id": "a235b2c7", "metadata": {}, "source": [ "## ***5. 
Modelling***\n", "\n", "***We've explored the data; now we'll try to use machine learning to predict our target variable based on the 13 independent variables.***\n", "\n", "***Remember our problem?***\n", "\n", "* >Given clinical parameters about a patient, can we predict whether or not they have heart disease?\n", "\n", "That's what we'll be trying to answer.\n", "\n", "***And remember our evaluation metric?***\n", "\n", "* >If we can reach 95% accuracy at predicting whether or not a patient has heart disease during the proof of concept, we'll pursue this project.\n", "\n", "But before we build a model, we have to get our dataset ready." ] }, { "cell_type": "code", "execution_count": 217, "id": "03066a8c", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " age sex cp trestbps chol fbs restecg thalach exang oldpeak slope \\\n", "0 63 1 3 145 233 1 0 150 0 2.3 0 \n", "1 37 1 2 130 250 0 1 187 0 3.5 0 \n", "2 41 0 1 130 204 0 0 172 0 1.4 2 \n", "3 56 1 1 120 236 0 1 178 0 0.8 2 \n", "4 57 0 0 120 354 0 1 163 1 0.6 2 \n", "\n", " ca thal target \n", "0 0 1 1 \n", "1 0 2 1 \n", "2 0 2 1 \n", "3 0 2 1 \n", "4 0 2 1 " ] }, "execution_count": 217, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": 218, "id": "7b755fe5", "metadata": {}, "outputs": [], "source": [ "# Split data into X and y\n", "X = df.drop('target',axis = 1)\n", "y = df[\"target\"]" ] }, { "cell_type": "code", "execution_count": 219, "id": "acf7b5d1", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " age sex cp trestbps chol fbs restecg thalach exang oldpeak \\\n", "0 63 1 3 145 233 1 0 150 0 2.3 \n", "1 37 1 2 130 250 0 1 187 0 3.5 \n", "2 41 0 1 130 204 0 0 172 0 1.4 \n", "3 56 1 1 120 236 0 1 178 0 0.8 \n", "4 57 0 0 120 354 0 1 163 1 0.6 \n", ".. ... ... .. ... ... ... ... ... ... ... \n", "298 57 0 0 140 241 0 1 123 1 0.2 \n", "299 45 1 3 110 264 0 1 132 0 1.2 \n", "300 68 1 0 144 193 1 1 141 0 3.4 \n", "301 57 1 0 130 131 0 1 115 1 1.2 \n", "302 57 0 1 130 236 0 0 174 0 0.0 \n", "\n", " slope ca thal \n", "0 0 0 1 \n", "1 0 0 2 \n", "2 2 0 2 \n", "3 2 0 2 \n", "4 2 0 2 \n", ".. ... .. ... \n", "298 1 0 3 \n", "299 1 0 3 \n", "300 1 2 3 \n", "301 1 1 3 \n", "302 1 1 2 \n", "\n", "[303 rows x 13 columns]" ] }, "execution_count": 219, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X" ] }, { "cell_type": "code", "execution_count": 220, "id": "8b630421", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 1\n", "1 1\n", "2 1\n", "3 1\n", "4 1\n", " ..\n", "298 0\n", "299 0\n", "300 0\n", "301 0\n", "302 0\n", "Name: target, Length: 303, dtype: int64" ] }, "execution_count": 220, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y" ] }, { "cell_type": "markdown", "id": "da7110d3", "metadata": {}, "source": [ "### ***Training and test data split***\n", "\n", "Now comes one of the most important concepts in machine learning, the **training/test split**.\n", "\n", "This is where you'll split your data into a **training set** and a **test set**.\n", "\n", "You use your training set to train your model and your test set to test it.\n", "\n", "The test set must remain separate from your training set.\n", "\n", "#### ***Why not use all the data to train a model?***\n", "\n", "Let's say you wanted to take your model into the hospital and start using it on patients. 
How would you know how well your model performs on a new patient who wasn't included in the original dataset?\n", "\n", "This is where the test set comes in. It's used to mimic taking your model to a real environment as closely as possible.\n", "\n", "And it's why it's important to never let your model learn from the test set; it should only be evaluated on it.\n", "\n", "To split our data into a training and test set, we can use Scikit-Learn's [`train_test_split()`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html) and feed it our independent and dependent variables (`X` & `y`)." ] }, { "cell_type": "code", "execution_count": 221, "id": "2d42dd9a", "metadata": {}, "outputs": [], "source": [ "# Seed NumPy's global random number generator so the split is reproducible\n", "# (train_test_split falls back to it when random_state is not given)\n", "np.random.seed(42)\n", "\n", "# Split into train and test set (80% train, 20% test)\n", "X_train, X_test, y_train, y_test = train_test_split(X,\n", " y,\n", " test_size = 0.2)" ] }, { "cell_type": "markdown", "id": "cd30f051", "metadata": {}, "source": [ "> The `test_size` parameter is used to tell the `train_test_split()` function how much of our data we want in the test set.\n", "\n", "> A rule of thumb is to use 80% of your data to train on and the other 20% to test on. \n", "\n", "> For our problem, a train and test set are enough. But for other problems, you could also use a validation (train/validation/test) set or cross-validation (we'll see this in a second).\n", "\n", "> But again, each problem will differ. The post [How (and why) to create a good validation set](https://www.fast.ai/2017/11/13/validation-sets/) by Rachel Thomas is a good place to go to learn more.\n", "\n", "***Let's look at our training data***." ] }, { "cell_type": "code", "execution_count": 222, "id": "d058e900", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " age sex cp trestbps chol fbs restecg thalach exang oldpeak \\\n", "132 42 1 1 120 295 0 1 162 0 0.0 \n", "202 58 1 0 150 270 0 0 111 1 0.8 \n", "196 46 1 2 150 231 0 1 147 0 3.6 \n", "75 55 0 1 135 250 0 0 161 0 1.4 \n", "176 60 1 0 117 230 1 1 160 1 1.4 \n", ".. ... ... .. ... ... ... ... ... ... ... \n", "188 50 1 2 140 233 0 1 163 0 0.6 \n", "71 51 1 2 94 227 0 1 154 1 0.0 \n", "106 69 1 3 160 234 1 0 131 0 0.1 \n", "270 46 1 0 120 249 0 0 144 0 0.8 \n", "102 63 0 1 140 195 0 1 179 0 0.0 \n", "\n", " slope ca thal \n", "132 2 0 2 \n", "202 2 0 3 \n", "196 1 0 2 \n", "75 1 0 2 \n", "176 2 2 3 \n", ".. ... .. ... \n", "188 1 1 3 \n", "71 2 1 3 \n", "106 1 1 2 \n", "270 2 0 3 \n", "102 2 2 2 \n", "\n", "[242 rows x 13 columns]" ] }, "execution_count": 222, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train" ] }, { "cell_type": "code", "execution_count": 223, "id": "6c4a9e10", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "132 1\n", "202 0\n", "196 0\n", "75 1\n", "176 0\n", " ..\n", "188 0\n", "71 1\n", "106 1\n", "270 0\n", "102 1\n", "Name: target, Length: 242, dtype: int64" ] }, "execution_count": 223, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_train" ] }, { "cell_type": "code", "execution_count": 224, "id": "38d6450b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "242" ] }, "execution_count": 224, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(y_train)" ] }, { "cell_type": "markdown", "id": "27ba569a", "metadata": {}, "source": [ "#### ***After splitting the data into training and test sets, we now need to build a machine learning model.***\n", "#### ***We'll train the model (find the patterns) on the training set.***\n", "#### ***And then we'll evaluate how well those patterns hold up on the test set.***\n", "#### ***We're going to try three machine learning models:***\n", "#### ***1. Logistic Regression***\n", "#### ***2. 
K-Nearest Neighbours Classifier***\n", "#### ***3. Random Forest Classifier***" ] }, { "cell_type": "markdown", "id": "5aa93655", "metadata": {}, "source": [ "### ***Why these?***\n", "\n", "If we look at the [Scikit-Learn algorithm cheat sheet](https://scikit-learn.org/stable/tutorial/machine_learning_map/index.html), we can see we're working on a classification problem and these are the algorithms it suggests (plus a few more).\n", "\n", "| | \n", "|:--:| \n", "| An example path we can take using the Scikit-Learn Machine Learning Map |\n", "\n", "\"Wait, I don't see Logistic Regression and why not use LinearSVC?\"\n", "\n", "Good questions. \n", "\n", "I was confused too when I didn't see Logistic Regression listed, because the Scikit-Learn documentation describes it as [a model for classification](https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression).\n", "\n", "And as for LinearSVC, let's pretend we've tried it, and it doesn't work, so we're following other options in the map.\n", "\n", "For now, knowing each of these algorithms inside and out is not essential.\n", "\n", "Machine learning and data science are iterative practices. These algorithms are tools in your toolbox.\n", "\n", "In the beginning, on your way to becoming a practitioner, it's more important to understand your problem (such as classification versus regression) and then to know what tools you can use to solve it.\n", "\n", "Since our dataset is relatively small, we can experiment to find which algorithm performs best.\n", "\n", "All of the Scikit-Learn models we're using share the same methods: `model.fit(X_train, y_train)` for training a model and `model.score(X_test, y_test)` for scoring it. 
`score()` returns the ratio of correct predictions (1.0 = 100% correct).\n", "\n", "Since the algorithms we've chosen implement the same methods for fitting them to the data as well as evaluating them, let's put them in a dictionary and create a function which fits and scores them." ] }, { "cell_type": "code", "execution_count": 225, "id": "68622ac2", "metadata": {}, "outputs": [], "source": [ "# Creating a model dictionary\n", "# Note: LogisticRegression may raise a ConvergenceWarning on unscaled data;\n", "# increasing max_iter or scaling the features are the usual fixes\n", "models = {\"Logistic Regression\" : LogisticRegression(),\n", "          \"KNN\" : KNeighborsClassifier(),\n", "          \"Random Forest\" : RandomForestClassifier()}\n", "\n", "# Using a function to fit and evaluate the score of models\n", "def fit_and_score(models, X_train, X_test, y_train, y_test):\n", "    \"\"\"\n", "    Fit and evaluate the given machine learning models.\n", "    models : a dictionary of different Scikit-Learn machine learning models\n", "    X_train : training data (no labels)\n", "    X_test : testing data (no labels)\n", "    y_train : training labels\n", "    y_test : testing labels\n", "    \"\"\"\n", "    # Set random seed (RandomForestClassifier uses NumPy's global RNG when random_state is None)\n", "    np.random.seed(42)\n", "    # Make a dictionary to keep model scores\n", "    model_scores = {}\n", "    # Loop through models\n", "    for name, model in models.items():\n", "        # Fit the model to the data\n", "        model.fit(X_train, y_train)\n", "        # Evaluate the model and store its score in model_scores\n", "        model_scores[name] = model.score(X_test, y_test)\n", "    return model_scores" ] }, { "cell_type": "code", "execution_count": 226, "id": "4fa0affc", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\vedant\\coding stuff\\project\\heart-disease-prediction\\env\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:814: 
ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. 
of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n" ] }, { "data": { "text/plain": [ "{'Logistic Regression': 0.8852459016393442,\n", " 'KNN': 0.6885245901639344,\n", " 'Random Forest': 0.8360655737704918}" ] }, "execution_count": 226, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_scores = fit_and_score(models = models,\n", " X_train = X_train,\n", " X_test = X_test,\n", " y_train = y_train,\n", " y_test = y_test)\n", "\n", "model_scores" ] }, { "cell_type": "markdown", "id": "a2f2294d", "metadata": {}, "source": [ "### ***Model comparison***" ] }, { "cell_type": "code", "execution_count": 268, "id": "b1ce7cf5", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model_compare = pd.DataFrame(model_scores, index = [\"accuracy\"])\n", "model_compare.T.plot.bar(color = 'crimson', figsize = (7, 5));\n", "plt.xticks(rotation = 0)\n", "plt.grid()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "832870a4", "metadata": {}, "source": [ "#### ***These are worth paying more attention to when working on a classification problem:***\n", "\n", "#### ***1. Hyperparameter tuning***\n", "#### ***2. Feature importance***\n", "#### ***3. Confusion matrix***\n", "#### ***4. Cross-validation***\n", "#### ***5. Precision***\n", "#### ***6. Recall***\n", "#### ***7. F1 score***\n", "#### ***8. Classification report***\n", "#### ***9. ROC curve***\n", "#### ***10. Area under the curve (AUC)***" ] }, { "cell_type": "markdown", "id": "4472d9bb", "metadata": {}, "source": [ "* **Hyperparameter tuning** - Each model you use has a series of dials you can turn to dictate how it performs. Changing these values may increase or decrease model performance.\n", "\n", "\n", "* **Feature importance** - If there are a large number of features we're using to make predictions, do some have more importance than others? For example, for predicting heart disease, which is more important, sex or age?\n", "\n", "\n", "* [**Confusion matrix**](https://www.dataschool.io/simple-guide-to-confusion-matrix-terminology/) - Compares the predicted values with the true values in a tabular way. If the predictions are 100% correct, all of the non-zero values in the matrix will lie on the diagonal from top left to bottom right.\n", "\n", "\n", "* [**Cross-validation**](https://scikit-learn.org/stable/modules/cross_validation.html) - Splits your dataset into multiple parts (folds), trains your model on all but one fold and tests it on the remaining one, repeating until each fold has been used for testing once, then reports the average performance. 
\n", "\n", "\n", "* [**Precision**](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html#sklearn.metrics.precision_score) - Proportion of true positives over total number of samples. Higher precision leads to less false positives.\n", "\n", "\n", "* [**Recall**](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html#sklearn.metrics.recall_score) - Proportion of true positives over total number of true positives and false negatives. Higher recall leads to less false negatives.\n", "\n", "\n", "* [**F1 score**](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score) - Combines precision and recall into one metric. 1 is best, 0 is worst.\n", "\n", "\n", "* [**Classification report**](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html) - Sklearn has a built-in function called `classification_report()` which returns some of the main classification metrics such as precision, recall and f1-score.\n", "\n", "\n", "* [**ROC Curve**](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_score.html) - [Receiver Operating Characterisitc](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) is a plot of true positive rate versus false positive rate.\n", "\n", "\n", "* [**Area Under Curve (AUC)**](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html) - The area underneath the ROC curve. A perfect model achieves a score of 1.0." 
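] }, { "cell_type": "markdown", "id": "3f7a9c21", "metadata": {}, "source": [ "***A minimal sketch, not part of the original analysis, of how precision, recall and F1 score follow from the four confusion-matrix counts.*** The labels below are hypothetical: they are constructed only so that the counts match the confusion matrix reported later in this notebook (25 true negatives, 4 false positives, 3 false negatives and 29 true positives). The hand-computed values are checked against Scikit-Learn's `precision_score()`, `recall_score()` and `f1_score()`, and should reproduce the class `1` row of the classification report further down (0.88, 0.91, 0.89)." ] }, { "cell_type": "code", "execution_count": null, "id": "5b8e2d46", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score\n", "\n", "# Hypothetical labels, chosen so the counts match the confusion matrix shown later\n", "y_true_demo = np.array([0] * 29 + [1] * 32)\n", "y_pred_demo = np.array([0] * 25 + [1] * 4 + [0] * 3 + [1] * 29)\n", "\n", "# For binary labels, ravel() returns the counts in the order tn, fp, fn, tp\n", "tn, fp, fn, tp = confusion_matrix(y_true_demo, y_pred_demo).ravel()\n", "\n", "precision = tp / (tp + fp)  # of all predicted positives, how many are correct\n", "recall = tp / (tp + fn)  # of all actual positives, how many were found\n", "f1 = 2 * precision * recall / (precision + recall)  # harmonic mean of the two\n", "\n", "# The manual values should agree with Scikit-Learn's implementations\n", "assert np.isclose(precision, precision_score(y_true_demo, y_pred_demo))\n", "assert np.isclose(recall, recall_score(y_true_demo, y_pred_demo))\n", "assert np.isclose(f1, f1_score(y_true_demo, y_pred_demo))\n", "\n", "print(f\"Precision: {precision:.2f}, Recall: {recall:.2f}, F1: {f1:.2f}\")"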
] }, { "cell_type": "markdown", "id": "fb74b8e9", "metadata": {}, "source": [ "## ***Hyperparameter Tuning (by hand)***" ] }, { "cell_type": "code", "execution_count": 228, "id": "7b5a48c3", "metadata": {}, "outputs": [], "source": [ "# Tuning KNN model\n", "\n", "train_scores = []\n", "test_scores = []\n", "\n", "# Create a list of different values for n_neighbors\n", "neighbors = range(1,21)\n", "\n", "# Setup KNN instance\n", "knn = KNeighborsClassifier()\n", "\n", "# Loop through different n_neighbors\n", "for i in neighbors:\n", " knn.set_params(n_neighbors = i)\n", " \n", " # Fit the algorithm\n", " knn.fit(X_train, y_train)\n", " \n", " # Update the training scores list\n", " train_scores.append(knn.score(X_train, y_train))\n", " \n", " # Update the test scores list\n", " test_scores.append(knn.score(X_test, y_test))" ] }, { "cell_type": "code", "execution_count": 229, "id": "21738305", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[1.0,\n", " 0.8099173553719008,\n", " 0.7727272727272727,\n", " 0.743801652892562,\n", " 0.7603305785123967,\n", " 0.7520661157024794,\n", " 0.743801652892562,\n", " 0.7231404958677686,\n", " 0.71900826446281,\n", " 0.6942148760330579,\n", " 0.7272727272727273,\n", " 0.6983471074380165,\n", " 0.6900826446280992,\n", " 0.6942148760330579,\n", " 0.6859504132231405,\n", " 0.6735537190082644,\n", " 0.6859504132231405,\n", " 0.6652892561983471,\n", " 0.6818181818181818,\n", " 0.6694214876033058]" ] }, "execution_count": 229, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_scores" ] }, { "cell_type": "code", "execution_count": 230, "id": "b08d8be1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[0.6229508196721312,\n", " 0.639344262295082,\n", " 0.6557377049180327,\n", " 0.6721311475409836,\n", " 0.6885245901639344,\n", " 0.7213114754098361,\n", " 0.7049180327868853,\n", " 0.6885245901639344,\n", " 0.6885245901639344,\n", " 0.7049180327868853,\n", " 0.7540983606557377,\n", " 
0.7377049180327869,\n", " 0.7377049180327869,\n", " 0.7377049180327869,\n", " 0.6885245901639344,\n", " 0.7213114754098361,\n", " 0.6885245901639344,\n", " 0.6885245901639344,\n", " 0.7049180327868853,\n", " 0.6557377049180327]" ] }, "execution_count": 230, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_scores" ] }, { "cell_type": "code", "execution_count": 267, "id": "4235ff1d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Maximum KNN score on the test data : 75.41%\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(neighbors, train_scores, label = \"Train score\")\n", "plt.plot(neighbors, test_scores, label = \"Test score\")\n", "plt.xticks(np.arange(1,21,2))\n", "plt.xlabel(\"Number of neighbors\")\n", "plt.ylabel(\"Model score\")\n", "plt.legend()\n", "plt.grid()\n", "\n", "print(f\"Maximum KNN score on the test data : {max(test_scores)*100:.2f}%\")" ] }, { "cell_type": "markdown", "id": "9c737a52", "metadata": {}, "source": [ "### ***Hyperparameter tuning with RandomizedSearchCV***\n", "\n", "***We're about to tune :***\n", "* ***1. LogisticRegression()***\n", "* ***2. RandomForestClassifier()***" ] }, { "cell_type": "code", "execution_count": 232, "id": "38fe2646", "metadata": {}, "outputs": [], "source": [ "# Create a hyperparameter grid for LogisticRegression\n", "log_reg_grid = {\"C\" : np.logspace(-4, 4, 20),\n", " \"solver\" : [\"liblinear\"]}\n", "\n", "# Create a hyperparameter grid for RandomForestClassifier\n", "rf_grid = {\"n_estimators\" : np.arange(10, 1000, 50),\n", " \"max_depth\" : [None, 3, 5, 10],\n", " \"min_samples_split\" : np.arange(2,20,2),\n", " \"min_samples_leaf\" : np.arange(1, 20, 2)}" ] }, { "cell_type": "markdown", "id": "f85515b8", "metadata": {}, "source": [ "***We have created a hyperparameter grid for each of our models; now we'll tune them using RandomizedSearchCV.***" ] }, { "cell_type": "code", "execution_count": 233, "id": "f1ee2007", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting 5 folds for each of 20 candidates, totalling 100 fits\n" ] }, { "data": { "text/plain": [ "RandomizedSearchCV(cv=5, estimator=LogisticRegression(), n_iter=20,\n", " param_distributions={'C': array([1.00000000e-04, 2.63665090e-04, 6.95192796e-04, 1.83298071e-03,\n", " 4.83293024e-03, 1.27427499e-02, 3.35981829e-02, 8.85866790e-02,\n", " 2.33572147e-01, 6.15848211e-01, 1.62377674e+00, 4.28133240e+00,\n", " 1.12883789e+01, 
2.97635144e+01, 7.84759970e+01, 2.06913808e+02,\n", " 5.45559478e+02, 1.43844989e+03, 3.79269019e+03, 1.00000000e+04]),\n", " 'solver': ['liblinear']},\n", " verbose=True)" ] }, "execution_count": 233, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Tuning LogisticRegression\n", "\n", "np.random.seed(42)\n", "\n", "# Setup random hyperparameter search for LogisticRegression\n", "rs_log_reg = RandomizedSearchCV(LogisticRegression(),\n", " param_distributions = log_reg_grid,\n", " cv = 5,\n", " n_iter = 20,\n", " verbose = True)\n", "\n", "# Fitting random hyperparameter search model for LogisticRegression\n", "rs_log_reg.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 234, "id": "11abe1ec", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'solver': 'liblinear', 'C': 0.23357214690901212}" ] }, "execution_count": 234, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rs_log_reg.best_params_" ] }, { "cell_type": "code", "execution_count": 235, "id": "747f85c5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8852459016393442" ] }, "execution_count": 235, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rs_log_reg.score(X_test, y_test)" ] }, { "cell_type": "code", "execution_count": 236, "id": "1f164631", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting 5 folds for each of 20 candidates, totalling 100 fits\n" ] }, { "data": { "text/plain": [ "RandomizedSearchCV(cv=5, estimator=RandomForestClassifier(), n_iter=20,\n", " param_distributions={'max_depth': [None, 3, 5, 10],\n", " 'min_samples_leaf': array([ 1, 3, 5, 7, 9, 11, 13, 15, 17, 19]),\n", " 'min_samples_split': array([ 2, 4, 6, 8, 10, 12, 14, 16, 18]),\n", " 'n_estimators': array([ 10, 60, 110, 160, 210, 260, 310, 360, 410, 460, 510, 560, 610,\n", " 660, 710, 760, 810, 860, 910, 960])},\n", " verbose=True)" ] }, "execution_count": 236, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "# Tuning RandomForestClassifier\n", "\n", "np.random.seed(42)\n", "\n", "# Setup random hyperparameter search for RandomForestClassifier\n", "rs_rf = RandomizedSearchCV(RandomForestClassifier(),\n", " param_distributions = rf_grid,\n", " cv = 5,\n", " n_iter = 20,\n", " verbose = True)\n", "\n", "# Fit random hyperparameter search model for RandomForestClassifier()\n", "rs_rf.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 237, "id": "74231ab3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'n_estimators': 210,\n", " 'min_samples_split': 4,\n", " 'min_samples_leaf': 19,\n", " 'max_depth': 3}" ] }, "execution_count": 237, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Find the best hyperparameters\n", "rs_rf.best_params_" ] }, { "cell_type": "code", "execution_count": 238, "id": "38b1e06d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8688524590163934" ] }, "execution_count": 238, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Evaluate the randomized search RandomForestClassifier model\n", "rs_rf.score(X_test, y_test)" ] }, { "cell_type": "code", "execution_count": 239, "id": "9d091892", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'Logistic Regression': 0.8852459016393442,\n", " 'KNN': 0.6885245901639344,\n", " 'Random Forest': 0.8360655737704918}" ] }, "execution_count": 239, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_scores" ] }, { "cell_type": "markdown", "id": "762fab3f", "metadata": {}, "source": [ "### ***Three ways to tune the model :***\n", "#### ***1. By hand***\n", "#### ***2. RandomizedSearchCV***\n", "#### ***3. 
GridSearchCV***" ] }, { "cell_type": "markdown", "id": "bee37f6c", "metadata": {}, "source": [ "### ***Hyperparameter tuning with GridSearchCV***\n", "\n", "***Trying to improve LogisticRegression model with GridSearchCV***" ] }, { "cell_type": "code", "execution_count": 240, "id": "5f49d817", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting 5 folds for each of 30 candidates, totalling 150 fits\n" ] } ], "source": [ "# Different hyperparameters for our LogisticRegression model\n", "log_reg_grid = {\"C\" : np.logspace(-4, 4, 30),\n", " \"solver\" : [\"liblinear\"]}\n", "\n", "# Setup grid hyperparameter search for LogisticRegression\n", "gs_log_reg = GridSearchCV(LogisticRegression(),\n", " param_grid = log_reg_grid,\n", " cv = 5,\n", " verbose = True)\n", "\n", "# Fit grid parameter search model\n", "gs_log_reg.fit(X_train, y_train);" ] }, { "cell_type": "code", "execution_count": 241, "id": "378cabaf", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'C': 0.20433597178569418, 'solver': 'liblinear'}" ] }, "execution_count": 241, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check the best hyperparameters\n", "gs_log_reg.best_params_" ] }, { "cell_type": "code", "execution_count": 242, "id": "36a44d35", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8852459016393442" ] }, "execution_count": 242, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Evaluate the grid search LogisticRegression model\n", "gs_log_reg.score(X_test, y_test)" ] }, { "cell_type": "code", "execution_count": 243, "id": "e2193fd5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'Logistic Regression': 0.8852459016393442,\n", " 'KNN': 0.6885245901639344,\n", " 'Random Forest': 0.8360655737704918}" ] }, "execution_count": 243, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_scores" ] }, { "cell_type": "markdown", "id": "76937b8f", "metadata": {}, "source": [ 
"### ***Evaluating our tuned machine learning classifier beyond accuracy (using cross-validation)***\n", "\n", "* ROC curve and AUC score\n", "* Confusion matrix\n", "* Classification report\n", "* Precision\n", "* Recall\n", "* F1-score\n", "\n", "***To make comparisons and evaluate our trained model, we first need to make predictions.***" ] }, { "cell_type": "code", "execution_count": 244, "id": "c6764ca3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0,\n", " 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0], dtype=int64)" ] }, "execution_count": 244, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Make predictions with the tuned model\n", "y_preds = gs_log_reg.predict(X_test)\n", "y_preds" ] }, { "cell_type": "code", "execution_count": 245, "id": "98b3d249", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "179 0\n", "228 0\n", "111 1\n", "246 0\n", "60 1\n", " ..\n", "249 0\n", "104 1\n", "300 0\n", "193 0\n", "184 0\n", "Name: target, Length: 61, dtype: int64" ] }, "execution_count": 245, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_test" ] }, { "cell_type": "markdown", "id": "9ec9e6b5", "metadata": {}, "source": [ "### ***ROC Curve and AUC Scores***\n", "\n", "***What's a ROC curve?***\n", "\n", "* It's a way of understanding how your model is performing by comparing the true positive rate to the false positive rate.\n", "\n", "In our case...\n", "\n", "> To get an appropriate example in a real-world problem, consider a diagnostic test that seeks to determine whether a person has a certain disease. A false positive in this case occurs when the person tests positive, but does not actually have the disease. A false negative, on the other hand, occurs when the person tests negative, suggesting they are healthy, when they actually do have the disease.\n", "\n", "Scikit-Learn implements a function `plot_roc_curve` which can help us create a ROC curve as well as calculate the area under the curve (AUC) metric. Note that, as the warning in the next cell's output shows, `plot_roc_curve` is deprecated in Scikit-Learn 1.0 and removed in 1.2; newer versions provide the class methods `RocCurveDisplay.from_estimator` and `RocCurveDisplay.from_predictions` instead.\n", "\n", "\n", "Reading the documentation on the [`plot_roc_curve`](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.plot_roc_curve.html) function, we can see it takes `(estimator, X, y)` as inputs, where `estimator` is a fitted machine learning model and `X` and `y` are the data you'd like to test it on.\n", "\n", "\n", "In our case, we'll use the GridSearchCV version of our `LogisticRegression` estimator, `gs_log_reg`, as well as the test data, `X_test` and `y_test`." ] }, { "cell_type": "code", "execution_count": 266, "id": "596f5650", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\vedant\\coding stuff\\project\\heart-disease-prediction\\env\\lib\\site-packages\\sklearn\\utils\\deprecation.py:87: FutureWarning: Function plot_roc_curve is deprecated; Function `plot_roc_curve` is deprecated in 1.0 and will be removed in 1.2. Use one of the class methods: RocCurveDisplay.from_predictions or RocCurveDisplay.from_estimator.\n", " warnings.warn(msg, category=FutureWarning)\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the ROC curve and calculate the AUC metric\n", "plot_roc_curve(gs_log_reg, X_test, y_test, color = 'crimson')\n", "plt.grid()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "2c8c7402", "metadata": {}, "source": [ "### ***Confusion Matrix***\n", "> A confusion matrix is a visual way to show where your model made the right predictions and where it made the wrong predictions (or in other words, got confused).\n", "\n", "Scikit-Learn allows us to create a confusion matrix using [`confusion_matrix()`](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html) and passing it the true labels and predicted labels. In the resulting matrix, rows correspond to the true labels and columns to the predicted labels." ] }, { "cell_type": "code", "execution_count": 247, "id": "be8aef14", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[25 4]\n", " [ 3 29]]\n" ] } ], "source": [ "print(confusion_matrix(y_test, y_preds))" ] }, { "cell_type": "code", "execution_count": 248, "id": "bea3b954", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Making our confusion matrix more visual\n", "sns.set_theme(font_scale = 1.5)\n", "\n", "def plot_conf_mat(y_test, y_preds):\n", " \"\"\"\n", " Plots a confusion matrix using Seaborn's heatmap().\n", " \"\"\"\n", " fig, ax = plt.subplots(figsize=(3, 3))\n", " ax = sns.heatmap(confusion_matrix(y_test, y_preds),\n", " annot=True, # Annotate the boxes\n", " cbar=False)\n", " plt.xlabel(\"Predicted label\") # predictions go on the x-axis\n", " plt.ylabel(\"True label\") # true labels go on the y-axis \n", " \n", "plot_conf_mat(y_test, y_preds)" ] }, { "cell_type": "markdown", "id": "9740236f", "metadata": {}, "source": [ "***We can see the model gets confused (predicts the wrong label) roughly equally across both classes. Reading the matrix with rows as true labels and columns as predicted labels, there are 3 occasions where the model predicted 0 when it should've been 1 (false negatives) and 4 occasions where the model predicted 1 instead of 0 (false positives).***" ] }, { "cell_type": "markdown", "id": "d291030e", "metadata": {}, "source": [ "***After getting a ROC curve, an AUC metric and a confusion matrix, we'll now get a classification report as well as cross-validated precision, recall and f1-score.***" ] }, { "cell_type": "markdown", "id": "82b01b30", "metadata": {}, "source": [ "### ***Classification report***\n", "\n", "> We can make a classification report using [`classification_report()`](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html) and passing it the true labels as well as our model's predicted labels.\n", "\n", "> A classification report also gives us the precision, recall and f1-score of our model for each class."
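] }, { "cell_type": "markdown", "id": "7c1d4e93", "metadata": {}, "source": [ "***A minimal sketch, not part of the original analysis, of how the `macro avg` and `weighted avg` rows of a classification report are computed.*** The macro average is the plain mean of the per-class scores, while the weighted average weights each class by its support (the number of true samples in that class). The labels are hypothetical, constructed only to reproduce the confusion matrix `[[25 4] [3 29]]` shown above, and the report is read back as a dictionary via `classification_report(..., output_dict=True)`." ] }, { "cell_type": "code", "execution_count": null, "id": "9a2f6b58", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.metrics import classification_report\n", "\n", "# Hypothetical labels chosen to reproduce the confusion matrix [[25 4] [3 29]]\n", "y_true_demo = np.array([0] * 29 + [1] * 32)\n", "y_pred_demo = np.array([0] * 25 + [1] * 4 + [0] * 3 + [1] * 29)\n", "\n", "# output_dict=True returns the report as nested dicts instead of a string\n", "report = classification_report(y_true_demo, y_pred_demo, output_dict=True)\n", "\n", "# Per-class F1 scores and supports (keys are the class labels as strings)\n", "f1_per_class = np.array([report[\"0\"][\"f1-score\"], report[\"1\"][\"f1-score\"]])\n", "support = np.array([report[\"0\"][\"support\"], report[\"1\"][\"support\"]])\n", "\n", "# Macro average: plain mean over the classes\n", "macro_f1 = f1_per_class.mean()\n", "# Weighted average: each class weighted by its number of true samples\n", "weighted_f1 = (f1_per_class * support).sum() / support.sum()\n", "\n", "# Both should match the corresponding rows of the report\n", "assert np.isclose(macro_f1, report[\"macro avg\"][\"f1-score\"])\n", "assert np.isclose(weighted_f1, report[\"weighted avg\"][\"f1-score\"])\n", "\n", "print(f\"Macro F1: {macro_f1:.2f}, Weighted F1: {weighted_f1:.2f}\")"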
] }, { "cell_type": "code", "execution_count": 249, "id": "f75f5995", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 0.89 0.86 0.88 29\n", " 1 0.88 0.91 0.89 32\n", "\n", " accuracy 0.89 61\n", " macro avg 0.89 0.88 0.88 61\n", "weighted avg 0.89 0.89 0.89 61\n", "\n" ] } ], "source": [ "print(classification_report(y_test, y_preds))" ] }, { "cell_type": "markdown", "id": "08a55444", "metadata": {}, "source": [ "### ***Calculating evaluation metrics using cross-validation***\n", "\n", "> ***We'll evaluate precision, recall and f1-score of our model using cross-validation and to do so we'll be using `cross_val_score()`.***" ] }, { "cell_type": "code", "execution_count": 250, "id": "56be68a5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'C': 0.20433597178569418, 'solver': 'liblinear'}" ] }, "execution_count": 250, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Checking best hyperparameters\n", "gs_log_reg.best_params_" ] }, { "cell_type": "code", "execution_count": 251, "id": "e53c1140", "metadata": {}, "outputs": [], "source": [ "# Creating a new classifier with best hyperparameters\n", "clf = LogisticRegression(C = 0.20433597178569418,\n", " solver = \"liblinear\")" ] }, { "cell_type": "markdown", "id": "5ac78e70", "metadata": {}, "source": [ "#### ***1. 
Cross-validated : accuracy***" ] }, { "cell_type": "code", "execution_count": 252, "id": "a6945bd8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.81967213, 0.90163934, 0.86885246, 0.88333333, 0.75 ])" ] }, "execution_count": 252, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cv_acc = cross_val_score(clf,\n", " X,\n", " y,\n", " cv = 5,\n", " scoring = \"accuracy\")\n", "cv_acc" ] }, { "cell_type": "code", "execution_count": 253, "id": "d28eed93", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8446994535519124" ] }, "execution_count": 253, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.mean(cv_acc)" ] }, { "cell_type": "markdown", "id": "f585b290", "metadata": {}, "source": [ "#### ***2. Cross-validated : precision***" ] }, { "cell_type": "code", "execution_count": 254, "id": "b0e63e3f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8207936507936507" ] }, "execution_count": 254, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cv_precision = cross_val_score(clf,\n", " X,\n", " y,\n", " cv = 5,\n", " scoring = \"precision\")\n", "\n", "cv_precision = np.mean(cv_precision)\n", "cv_precision" ] }, { "cell_type": "markdown", "id": "a09b312d", "metadata": {}, "source": [ "#### ***3. Cross-validated : recall***" ] }, { "cell_type": "code", "execution_count": 255, "id": "017424ac", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9212121212121213" ] }, "execution_count": 255, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cv_recall = cross_val_score(clf,\n", " X,\n", " y,\n", " cv = 5,\n", " scoring = \"recall\")\n", "\n", "cv_recall = np.mean(cv_recall)\n", "cv_recall" ] }, { "cell_type": "markdown", "id": "79c7d790", "metadata": {}, "source": [ "#### ***4. 
Cross-validated : f1-score***" ] }, { "cell_type": "code", "execution_count": 256, "id": "d4acd737", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8673007976269721" ] }, "execution_count": 256, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cv_f1 = cross_val_score(clf, \n", " X, \n", " y, \n", " cv = 5, \n", " scoring = \"f1\")\n", "\n", "cv_f1 = np.mean(cv_f1)\n", "cv_f1" ] }, { "cell_type": "code", "execution_count": 257, "id": "c28a6b5a", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAEUCAYAAADHgubDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAyg0lEQVR4nO3dfVyN9/8H8NeJbqbjNsxSQik3lZJqpCikpOFoyu3czk2G0Rc129xObLEiWrNYcpOstRZhIXOzud3YxsaUTshtSuUm1fX7w6Pzc5yjc+hguV7Px6PHw/lcn+u63ten49V1fc51zpEIgiCAiIhEQe9VF0BERC8PQ5+ISEQY+kREIsLQJyISEYY+EZGIMPSJiESEof8aKS4uRlxcHGQyGZycnODg4ICAgAAkJiaioqLiVZente+//x42NjZYu3Ztlf2mTp0KW1tb3L59W6vtenl5YcSIEYrHI0aMgJeXl8b1tO2nTnFxMfLz859rXXXmzJkDGxubKvskJyfDxsYGNjY22L1791P7LVq0CDY2Ns99bKWlpbh27ZrGfpX1HDly5Ln2Q7rF0H9NZGVlYdCgQVi+fDlsbGwwY8YMTJ06FYaGhvjkk08wa9Ys1JS3ZPTu3RtGRkbYuXPnU/uUlJRg//79cHd3R8OGDZ9rPxMnTkRYWNjzlqnRn3/+CV9fX5w/f/6F7UOTPXv2PHXZ3r17n3u7ly9fhr+/Pw4dOqSxr7OzM5YtWwZLS8vn3h/pTu1XXQBV34MHDzB58mQUFBRg27ZtaNu2rWLZmDFjMH/+fGzatAn29vYYOXLkK6xUO1KpFF5eXtixYwcuX76M5s2bq/TZs2cP7t+/j3feeee59+Pm5ladMjU6d+4crl+//kL3URUzMzNkZmaivLwctWrVUlr2119/4fLly2jUqNFzbfvSpUu4ePGiVn3Nzc1hbm7+XPsh3eOZ/mtg06ZNyM7ORmhoqFLgV5o9ezbq16+PLVu2vILqnk9lmO/atUvt8h07dij+OJB6PXv2REFBAU6cOKGy7KeffkKLFi1gZWX1CiqjV4mh/xrYvn076tSpAz8/P7XLjYyMsHXrVqSkpCjavLy8MHfuXISFhcHOzg4eHh6Kuefjx49j1KhRcHR0hKOjI0aOHIljx44pbbOwsBBz5sxBjx49YGtri169eiEiIgIPHjxQ9CktLcXixYvRs2dP2Nraonv37pg/fz4KCws1HlO3bt3QsGFDtaF/584dHDx4EH369IGhoSEEQcDmzZsREBAAR0dH2NnZwcfHB7GxsVVOaambqz98+DCCgoLg4OCAXr16YceOHWrX3blzJ4YPHw4nJyfY2trCy8sLy5YtQ2lpKQBg5cqVCA0NBQCMHDlSaT9Xr17FrFmz8Pbbb8POzg4DBgxAamqqyj7+/PNPjBkzBo6OjnB3d0d8fLzGcXucm5sb3njjDbXTOBkZGejdu7fa9TTVl5
ycrLhiDA0NVbzGsHLlStjZ2eGnn36Cm5sbHB0dkZSUpHZOv7S0FCtXroS3tzfs7e3Rp08fxMbGory8XNFn165dGDRoEBwdHeHk5ITRo0er/QNGz4bTOzWcIAg4e/YsOnXqBH19/af2a9mypUrb9u3b0apVK3z00Ue4efMmGjVqhD179mDKlClo0aIFJk2aBABISkrCqFGjEBUVhZ49ewIApk+fjjNnzmDkyJFo2rQpfvvtN8TGxqKgoAALFy4EACxYsABpaWkYOXIkzM3Ncf78eWzcuBE5OTmIi4ur8rj09fXh6+uLzZs3Iy8vD2+99ZZi2e7du/Hw4UPF1cCXX36JmJgYDBw4EIMHD0ZJSQlSUlIQERGBJk2aYODAgVqN5eHDhzF+/Hi0bNkS06dPR35+Pj766CNIJBI0aNBA0S8pKQlz586Fl5cXQkJC8PDhQ/z000/45ptvUKdOHUyZMgW9e/fGjRs3kJiYiIkTJ8LOzg4AcO3aNbz77rsQBAEjRoxA/fr1sWfPHvzvf//D9evXMW7cOADA+fPnMWLECNSrVw+TJ0/Gw4cPER0drRSKmhgZGcHNzQ179uzBnDlzFO0XL17E+fPnsXDhQvzxxx9K62hTn7OzMyZOnIiYmBgEBgbCyclJsX5ZWRnmzp2LsWPHorS0FE5OTvj9999VagsODsbPP/8Mf39/jB49GqdPn0ZERARu3bqF0NBQHD16FB9++CE8PDzw7rvv4t69e0hISMDo0aOxfft2ThdVh0A12q1btwRra2vhww8/fKb1PD09hbZt2wo5OTmKtocPHwoeHh5C9+7dhaKiIkV7YWGh4O7uLri7uwulpaXCzZs3BWtra2Ht2rVK25wzZ47w3nvvKR7b29sL8+fPV+qzYsUKQSaTCcXFxRprPHHihGBtbS2sW7dOqX3MmDGCh4eHUF5eLpSWlgqdOnVSOf6ioiLB1tZWmDBhgtIxDx8+XPF4+PDhgqenp+LxwIEDVY79l19+EaytrZX6+fj4CIGBgUJFRYWirXLs+vXrp2j77rvvBGtra+HXX39VtM2ePVtwcXERrl27plTvjBkzBFtbW+HmzZuCIAjCBx98IDg4OAhXrlxR9Pn3338FW1tbwdrauspxe3y/ycnJgrW1tXDu3DnF8tjYWMHNzU2oqKhQGQNt6/v1118Fa2tr4bvvvlP0iYqKEqytrYWoqKin1iMIgpCZmSlYW1sLa9asUeo3c+ZMoUOHDkJBQYHw6aefCo6Ojkpj/Pfffwve3t5Cenp6lcdPVeP0Tg2np/foV/gsZ4CVWrRogRYtWigenzlzBlevXsWwYcMglUoV7fXq1cPw4cNx7do1/Pnnn6hbty7q1KmDTZs2YdeuXbh79y4AYMmSJVi/fr1ivWbNmmHHjh1ITk7GnTt3ADy6Qvjuu+9gbGyssb5OnTrB3NxcaYonPz8fv/76K/r16wc9PT3o6+vj8OHDWLBggdK6t2/fhlQqVdSmya1bt/DXX3/Bz89P6djffvttlVskU1NTERsbC4lEorR+vXr1qtxfRUUFMjIy0LlzZ9SuXRv5+fmKH29vb5SWluLQoUOoqKjAgQMH0L17d6UrHEtLS3Tr1k2r46nUo0cP1KpVS+kunoyMDPTq1Uup/mepTxNNNWZmZkJPTw/Dhw9Xap89ezZ++OEHSKVSNGvWDCUlJVi0aBEuXLgAALCxscGuXbvg4+Oj7eGTGpzeqeHq168PfX3957oX3MTEROnxpUuXAACtWrVS6du6dWsAwJUrV+Do6IgFCxbg448/xtSpU2FgYAAXFxd4e3tjwIABMDQ0BADMmzcP06dPR2hoKD7++GM4ODigd+/eGDRoEOrWrYvy8nKVuvX19ZWmUvr164eYmBhcu3YNb775Jnbu3ImysjKlu3b09fWRmZmJPXv2IDs7Gzk5OYrXDQQtb1O9fPkyACj9EXz82E+fPq20v2PHjiEtLQ1ZWVmQy+W4desWAKi906jS7du3UVRUhI
yMDGRkZKjtk5eXh4KCAty9e/eptTzLrZYNGzaEk5MT9u7di4kTJ+L69es4deoUpk2b9tz1afLk8+pJly9fhomJidIfVwBo0qQJmjRpAgAYPnw4Dh48iISEBCQkJMDMzAyenp4ICAhQe7MCaY+hX8NJJBI4Ojrizz//RFlZGWrXVv8rXbFiBXJzcxEaGqr4j/XkbXxVBWTlssrXDfz9/eHu7o6MjAzs378fhw8fxsGDB7Fp0yYkJSXBwMAAXbp0wb59+xQ/hw4dUlwNJCcn4+7du4rXCCq5uLhgw4YNisf+/v5Ys2YNdu/ejREjRiA9PV3xxqPKuv73v/8hLS0NTk5OcHR0RGBgIJydnfHee+890zgCUHohutKTb2yLiIhAbGws2rdvDwcHB/Tv3x+Ojo5YuHBhlaFYeTXWp08fBAUFqe3z+Fy1NrVoo2fPnggPD8f169eRkZGBevXqwcXFpdr1PU3l1efTlJeXq1xlPEkqlSIhIQG///47MjIy8PPPP2PDhg3YuHEjli1bBn9/f411kHoM/ddA7969cfToUezYsUPtfev379/Htm3bUF5ernQW/aTKs9SsrCyVZdnZ2QCguOw+e/Ys2rRpg4CAAAQEBKC0tBSff/454uPjcfDgQXTr1g1nz55Fs2bN4OfnBz8/P1RUVGDdunVYtmwZtm/fjsGDB2PdunVK+6lXr57SY0tLS3To0AG7d+9G3759cfz4ccycOVOx/Pjx40hLS8PkyZOVzl7LyspQUFCg9Qt+zZs3h0QiUXvveeUVEPDoLDU2Nhb9+/fHsmXLlPrdvHmzyn00atQIb7zxBsrKytC1a1elZVeuXMGZM2fwxhtvoGHDhpBKpRpr0VavXr2wZMkSxdWQp6en2pMDbeurLlNTUxw+fBglJSVK03x//fUX4uLiMGnSJNSqVQtFRUVwcHCAg4MDQkJC8O+//2LYsGFYt24dQ78aOKf/GggMDETz5s2xdOlSnDt3TmlZeXk55s2bh5s3b2L8+PFV3uHToUMHNGnSBJs3b0ZxcbGivbi4GJs2bUKTJk1ga2uL8+fPY9iwYdi2bZuij4GBAdq3bw/g0RVEQUEBAgMD8dVXXyn66OnpKe5i0dPTg6GhIbp27ar0Y2trq1LXO++8g5MnT+LHH38EAKX/8AUFBQCgcr/51q1bce/ePZSVlVU5dpUaNWoEZ2dnpKamKoX3b7/9hr/++kvxuHLa6Mn97d+/HxcvXlTaX+UZb+XZee3ateHh4YH9+/fj77//Vlo/PDwcwcHBuH37NiQSCXr37o0DBw4o/T4vXbqEzMxMrY7ncWZmZmjbti3S0tJw5MiRp96qqW19wP9fJT7PlUf37t1RUVGBpKQkpfbNmzcjPT0djRs3xqJFizB58mSUlJQolrdu3Rr16tXTeCVBVeOZ/mvA0NAQq1atwpgxYxAQEAB/f3/Y2dmhoKAAO3fuxNmzZ+Hj44PRo0dXuR19fX18/PHHmD59OgYNGoSAgAAAwLZt23D9+nVERUVBT08PHTt2ROfOnbFixQrk5eXBxsYGeXl5SEhIQOvWrdGlSxcYGBjA398fmzZtwr179+Do6IiCggIkJCSgcePG8PX11fr4/Pz8sGzZMkRHR8PFxQVvvvmmYpmjoyOkUimWLFmCK1euoF69ejhy5Ah27NgBQ0NDpdDQZPbs2Rg2bBgGDx6MYcOG4d69e1i/fr3SxzxYWVnB1NQUMTExePDgAZo1a4bTp0/j+++/V9lf5btdN2/ejJs3b8Lf3x8hISE4cuQIhg0bhmHDhsHU1BSZmZnYt28fAgMD0aZNGwDAtGnTkJmZiREjRmDUqFGoVasWNmzYAGNjY8V7AZ5Fr169sGrVKtSpU6fKF1q1ra9yTFJTUyEIgta3xQKP3iPi5uaG8PBwnD9/HnZ2dvjtt9+QkpKC4OBgNGjQAKNHj8b48eMxbNgwxetEGRkZkMvlWLp06TMfPz3mld47RDp19epVITw8XOjbt6/g4O
AgdOzYURg8eLCwbds2pVvfBEH19sXHHT58WBg+fLjQsWNHwcnJSRgzZoxw7NgxpT63b98WFi5cKHh5eQm2traCm5ub8NFHHwnXr19X9Ll3754QGRkpeHt7C3Z2doKLi4swbdo04eLFi898bKNHjxasra2Fbdu2qSw7fvy4EBQUJDg4OAguLi5CYGCgsH37dmHJkiVCu3bthBs3bqg95idvVxQEQTh16pQwcuRIwcHBQfDw8BDWrVsnzJw5U6nfuXPnhDFjxgidO3cWnJychIEDBwobN24Uvv32W8Ha2lr4448/BEEQhNLSUmHatGmCvb294OzsLNy/f18QBEG4ePGiMGPGDMHV1VWws7MT+vbtK6xbt04oKytTqiUrK0uYOHGi0KlTJ+Htt98WIiIihIiIiGe6ZbPSmTNnBGtra+GDDz5Q6qtuDLStb+HChYKjo6Pg4OAg5OTkKG7ZzM3N1VjPvXv3hIiICKFHjx6Cra2t0LdvXyEhIUEoLy9X9Nm3b58QFBQkODs7C/b29sKgQYOEtLS0Ko+dNJMIQg35FC4iIqo2To4REYkIQ5+ISEQY+kREIsLQJyISEYY+EZGIMPSJiETkP//mrNu3S1BR8d++q9TERIpbt4o1dyStcDx1i+OpOzVhLPX0JGjY8OmfYvufD/2KCuE/H/oAakSNNQnHU7c4nrpT08eS0ztERCLC0CciEhGGPhGRiDD0iYhEhKFPRCQiDH0iIhFh6BMRich//j59IjFpaKyP2nWMdL7dJk3q6nR7ZXfv43bJQ51uk14Ohj7Rf0jtOka40MT9VZehkeWNAwBDv0bi9A4RkYgw9ImIRIShT0QkIgx9IiIRYegTEYkIQ5+ISEQY+kREIsLQJyISEYY+EZGIMPSJiESEH8NARK8lfo6Regx9Inot8XOM1OP0DhGRiDD0iYhEhKFPRCQiDH0iIhFh6BMRiQhDn4hIRBj6REQiwtAnIhIRrUM/LS0Nfn5+sLe3h6+vL1JSUqrsn5+fj9DQUHTr1g0uLi6YMGECLl68WM1yiYioOrR6R256ejpCQkIwcuRIuLu7IyMjA7Nnz4aRkRF8fHxU+guCgODgYMjlcvzvf/9DgwYNEBUVhZEjR+LHH39E/fr1dX4g2uJbs3WL40lUs2gV+suXL4evry/CwsIAAO7u7igsLERkZKTa0L948SJOnjyJpUuXYsCAAQAAS0tL9OrVC3v37sXAgQN1dwTPiG/N1i2OJ1HNonF6Jzc3F3K5HN7e3krtffr0QVZWFnJzc1XWefDgAQDA2NhY0VZ5dl9QUFCdeomIqBo0hn5WVhYAoFWrVkrtFhYWAIDs7GyVddq2bQtXV1dER0fjwoULyM/Px6JFi1CnTh306tVLF3UTEdFz0Di9U1RUBACQSqVK7ZVn8cXFxWrXmzdvHsaNG4e+ffsCAAwMDBAdHQ1zc/NnKtDERKq502tK1/PaYsfx1C2Op+68zLHUGPqCIAAAJBKJ2nY9PdWLhQsXLiAoKAgtWrRAWFgYjIyMsHXrVkydOhVr165F586dtS7w1q1iVFQIWvfXpCY9UW/cKHrVJWjE8dQtjqfuiHUs9fQkVZ4sawz9unUfDdyTZ/QlJSVKyx+3fv16AEBcXJxiLt/NzQ1Dhw7FZ599huTkZO2qJyIindI4p185ly+Xy5Xac3JylJY/7sqVK7C0tFS6NVMikcDJyQn//vtvtQomIqLnpzH0LSwsYGZmhp07dyq17969Gy1btoSpqanKOq1atcL58+dRWFio1H7q1Ck0b968miUTEdHz0uo+/eDgYISGhqJ+/fro0aMH9u7di/T0dKxYsQLAo3ffyuVyWFlZQSqVYtSoUUhNTcXYsWPx/vvvw8jICD/88AOOHj2qWIeIiF4+rUJfJpOhtLQUcXFxSEpKgrm5OZYuXaq4MyczMxOhoaGIj4+Hq6srzMzMsHnzZnz++eeYM2cO9PT0YG1tjXXr1qFr164v9ICIiOjptP5i9KCgIA
QFBaldJpPJIJPJlNosLS0RExNTveqIiEin+CmbREQiwtAnIhIRhj4RkYgw9ImIRIShT0QkIgx9IiIRYegTEYkIQ5+ISEQY+kREIsLQJyISEYY+EZGIMPSJiESEoU9EJCIMfSIiEWHoExGJCEOfiEhEGPpERCLC0CciEhGGPhGRiDD0iYhEhKFPRCQiDH0iIhFh6BMRiQhDn4hIRBj6REQiwtAnIhIRhj4RkYgw9ImIRIShT0QkIgx9IiIRYegTEYkIQ5+ISES0Dv20tDT4+fnB3t4evr6+SElJqbJ/RUUF1qxZg549e8Le3h7+/v7Yvn17deslIqJqqK1Np/T0dISEhGDkyJFwd3dHRkYGZs+eDSMjI/j4+Khd57PPPkNiYiJmzJiBtm3bYvv27Zg5cyakUim6d++u04MgIiLtaBX6y5cvh6+vL8LCwgAA7u7uKCwsRGRkpNrQl8vl2LhxIxYsWIB3330XANClSxdcvHgRBw4cYOgTEb0iGkM/NzcXcrkcM2bMUGrv06cP0tPTkZubC3Nzc6VlGRkZMDIywoABA5TaExISql8xERE9N41z+llZWQCAVq1aKbVbWFgAALKzs1XW+eeff9CqVSscPnwY77zzDtq3bw9vb2/s2LFDFzUTEdFz0nimX1RUBACQSqVK7cbGxgCA4uJilXXy8/ORl5eHsLAwTJs2DWZmZkhKSsKHH36IRo0a4e2339a6QBMTqeZOr6kmTeq+6hJeKxxP3eJ46s7LHEuNoS8IAgBAIpGobdfTU71YePjwIfLz8xETEwNPT08Aj+b0s7KysGrVqmcK/Vu3ilFRIWjdX5Oa9ES9caPoVZegEcdTtzieuiPWsdTTk1R5sqxxeqdu3UcD9+QZfUlJidLyxxkbG6NWrVpwc3NTtEkkEnTt2hX//POPdpUTEZHOaQz9yrl8uVyu1J6Tk6O0/HEWFhaoqKhAWVmZUvvDhw9VrhiIiOjl0Rj6FhYWMDMzw86dO5Xad+/ejZYtW8LU1FRlHXd3dwiCgPT0dEVbWVkZDhw4ACcnJx2UTUREz0Or+/SDg4MRGhqK+vXro0ePHti7dy/S09OxYsUKAI9euJXL5bCysoJUKkWXLl3QvXt3LFq0CHfv3kXLli2xadMmXL58GRERES/0gIiI6Om0Cn2ZTIbS0lLExcUhKSkJ5ubmWLp0Kfr27QsAyMzMRGhoKOLj4+Hq6goAiIqKQmRkJGJjY1FYWIj27dsjLi4Otra2L+5oiIioShKh8jac/6gXcffOhSbuOtvei2J548B//u4IgOOpaxxP3RHrWFb77h0iInp9MPSJiESEoU9EJCIMfSIiEWHoExGJCEOfiEhEGPpERCLC0CciEhGGPhGRiDD0iYhEhKFPRCQiDH0iIhFh6BMRiQhDn4hIRBj6REQiwtAnIhIRhj4RkYgw9ImIRIShT0QkIgx9IiIRYegTEYkIQ5+ISEQY+kREIsLQJyISEYY+EZGIMPSJiESEoU9EJCIMfSIiEWHoExGJCEOfiEhEGPpERCLC0CciEhGtQz8tLQ1+fn6wt7eHr68vUlJStN5JXl4enJycsHr16uepkYiIdESr0E9PT0dISAjc3NwQHR0NFxcXzJ49Gzt37tS4riAICAsLQ3FxcbWLJSKi6qmtTafly5fD19cXYWFhAAB3d3cUFhYiMjISPj4+Va67adMmZGVlVb9SIiKqNo1n+rm5uZDL5fD29lZq79OnD7KyspCbm1vlul988QUWLlxY/UqJiKjaNIZ+5Vl6q1atlNotLCwAANnZ2WrXq6iowJw5c+Dr6wsPD4/q1klERDqgcXqnqKgIACCVSpXajY2NAeCpc/XffvstcnNzERMTU60CTUykmju9ppo0qfuqS3itcDx1i+OpOy9zLDWGviAIAACJRKK2XU9P9WIhKysLX375JaKiolC3bvUO5tatYlRUCNXaxuNq0hP1xo2iV12CRhxP3eJ46o5Yx1JPT1LlybLG6Z3K0H7yjL
6kpERpeaXy8nLMmTMHPj4+cHNzQ1lZGcrKygA8mvKp/DcREb18GkO/ci5fLpcrtefk5Cgtr5SXl4dTp04hJSUFHTp0UPwAwMqVKxX/JiKil0/j9I6FhQXMzMywc+dO9O7dW9G+e/dutGzZEqampkr9mzZtim3btqlsJyAgAEOGDMGgQYN0UDYRET0Pre7TDw4ORmhoKOrXr48ePXpg7969SE9Px4oVKwAA+fn5kMvlsLKyglQqhZ2dndrtNG3a9KnLiIjoxdPqHbkymQzz58/HwYMHERwcjKNHj2Lp0qXo27cvACAzMxOBgYH466+/XmixRERUPVqd6QNAUFAQgoKC1C6TyWSQyWRVrv/PP/88W2VERKRz/JRNIiIRYegTEYkIQ5+ISEQY+kREIsLQJyISEYY+EZGIMPSJiESEoU9EJCIMfSIiEWHoExGJCEOfiEhEGPpERCLC0CciEhGGPhGRiDD0iYhEhKFPRCQiDH0iIhFh6BMRiQhDn4hIRBj6REQiwtAnIhIRhj4RkYgw9ImIRIShT0QkIgx9IiIRYegTEYkIQ5+ISEQY+kREIsLQJyISEYY+EZGIMPSJiESEoU9EJCJah35aWhr8/Pxgb28PX19fpKSkVNn/xo0bmDt3Ljw9PeHo6AiZTIb09PTq1ktERNVQW5tO6enpCAkJwciRI+Hu7o6MjAzMnj0bRkZG8PHxUelfWlqKcePGoaioCFOnTkXTpk2xa9cuTJ8+HeXl5ejXr5/OD4SIiDTTKvSXL18OX19fhIWFAQDc3d1RWFiIyMhItaH/888/4++//0ZSUhLs7e0BAG5ubrhy5Qq+/vprhj4R0SuicXonNzcXcrkc3t7eSu19+vRBVlYWcnNzVdYxNjZGYGAg7OzslNpbt24NuVxezZKJiOh5aTzTz8rKAgC0atVKqd3CwgIAkJ2dDXNzc6VlXbp0QZcuXZTaHj58iP3796NNmzbVKpiIiJ6fxtAvKioCAEilUqV2Y2NjAEBxcbFWO/riiy9w8eJFREdHP1OBJiZSzZ1eU02a1H3VJbxWOJ66xfHUnZc5lhpDXxAEAIBEIlHbrqdX9QyRIAj4/PPPsX79eowdOxa9evV6pgJv3SpGRYXwTOtUpSY9UW/cKHrVJWjE8dQtjqfuiHUs9fQkVZ4sawz9unUfDdyTZ/QlJSVKy9UpLS3FnDlzsH37dowdOxazZs3SqmgiInoxNIZ+5Vy+XC6HjY2Noj0nJ0dp+ZOKi4sxYcIEnDx5EmFhYXjvvfd0US8REVWDxrt3LCwsYGZmhp07dyq17969Gy1btoSpqanKOuXl5Zg0aRJOnTqF5cuXM/CJiP4jtLpPPzg4GKGhoahfvz569OiBvXv3Ij09HStWrAAA5OfnQy6Xw8rKClKpFFu2bMHRo0cRGBiIt956C7///rtiWxKJBB07dnwhB0NERFXTKvRlMhlKS0sRFxeHpKQkmJubY+nSpejbty8AIDMzE6GhoYiPj4erqyt27doFAEhMTERiYqLStmrVqoUzZ87o+DCIiEgbWoU+AAQFBSEoKEjtMplMBplMpngcHx9f/cqIiEjn+CmbREQiwtAnIhIRhj4RkYgw9ImIRIShT0QkIgx9IiIRYegTEYkIQ5+ISEQY+kREIsLQJyISEYY+EZGIMPSJiESEoU9EJCIMfSIiEWHoExGJCEOfiEhEGPpERCLC0CciEhGGPhGRiDD0iYhEhKFPRCQiDH0iIhFh6BMRiQhDn4hIRBj6REQiwtAnIhIRhj4RkYgw9ImIRIShT0QkIgx9IiIRYegTEYkIQ5+ISES0Dv20tDT4+fnB3t4evr6+SElJqbJ/SUkJ5s+fDzc3Nzg6OmL8+PG4ePFiNcslIqLq0Cr009PTERISAjc3N0RHR8PFxQWzZ8/Gzp07n7rOhx9+iJ07dyIkJARLly7FtWvXMHLkSBQVFemseCIieja1tem0fPly+Pr6IiwsDADg7u6OwsJCREZGwsfHR6X/8e
PHsX//fnz99dfw8PAAAHTu3Bk9e/bE5s2b8f777+vwEIiISFsaz/Rzc3Mhl8vh7e2t1N6nTx9kZWUhNzdXZZ1Dhw7B2NgYbm5uirZGjRrB2dkZP//8sw7KJiKi56HxTD8rKwsA0KpVK6V2CwsLAEB2djbMzc1V1rGwsECtWrWU2lu0aIH09PRnKlBPT/JM/bVR27yZzrf5IryIY38ROJ66xfHUHTGOpaZtaQz9yjl4qVSq1G5sbAwAKC4uVlmnuLhYpX/lOur6V6VhQ+Nn6q8Ni5NJOt/mi2BiojqG/0UcT93ieOoOx1KVxukdQRAAABKJRG27np7qJiqXqd2hmv5ERPRyaEzgunXrAlA9oy8pKVFa/jipVKpY/uQ66q4AiIjo5dAY+pVz+XK5XKk9JydHafmT6+Tm5qqc8efk5KjtT0REL4fG0LewsICZmZnKPfm7d+9Gy5YtYWpqqrJOt27dcOfOHRw+fFjRlp+fj+PHj6Nr1646KJuIiJ6HVvfpBwcHIzQ0FPXr10ePHj2wd+9epKenY8WKFQAeBbpcLoeVlRWkUimcnZ3h4uKCGTNmICQkBA0aNMDKlStRt25dDBky5IUeEBERPZ1EqOpV18ds2bIFcXFxyMvLg7m5Od5//30MGDAAAJCcnIzQ0FDEx8fD1dUVAFBYWIjw8HBkZGSgoqICTk5OmDNnDlq3bv3CDoaIiKqmdegTEVHNx/sniYhEhKFPRCQir13of/DBB7CxsUFiYuKrLqVGGjFiBGxsbJR+bG1t0bNnT4SHh+P+/fsvpYZRo0a9sP7/VerGvm3btujUqRNkMhl++OGHl15TcnIybGxscPXqVUWNr8NYP82cOXNUfgeP/zx+RyIAnD17Fh06dFCMT02g1d07NUV+fj727dsHa2trJCYmIjAw8FWXVCPZ2dlh7ty5iscPHjzAsWPHEB0djWvXrinu2npRPv30U5V3gOuy/3/Zk2NfUVGBq1ev4ttvv8WsWbPQoEEDdO/e/RVW+Ppr1qwZIiMj1S6zsrJS/DsrKwsTJkxAWVnZyypNJ16r0P/xxx9haGiIkJAQvP/++/jjjz9gZ2f3qsuqcaRSKRwcHJTaXF1dcfXqVWzbtg2hoaFo2rTpC9v/4/+xXkT//zJ1Yw8AHh4e6NKlC5KTkxn6L5iBgYHa30GlsrIyJCYmIiIiAvr6+i+vMB15raZ3kpOT4ebmBnd3dzRt2lRlikcQBKxfvx4+Pj6wt7dHnz59sGHDBqU++/fvR1BQEBwcHODu7o5FixYpPlJi5cqVaN++vcp+bWxssHr1agDAkSNHFNNLPXr0QLdu3XD8+HEAQGJiImQyGRwcHGBvb4+BAwdi165dStvKyspCcHCw4r0OkydPVrwbWiaTYfjw4Sr7Hzx4MKZOnfqco6a99u3bQxAE5OXlwcvLC+Hh4RgxYgQ6deqEJUuWAABu376NuXPnokuXLrC3t8eQIUNw4sQJpe2Ulpbiyy+/hJeXFzp27Ah/f3/s2LFDsfzJKYRDhw5h8ODBcHR0hLOzMyZPnowLFy48tf/9+/cRGRmJPn36wM7ODn379lV5Lnh5eWHVqlUIDw9H165d0bFjR4wdO1bxTvP/GgMDA+jr6yuuaCoqKhATE4NevXrB1tYWPj4+SEpS/XCxlJQUDBgwAB07doSXlxeioqJQXl6uWL5r1y4MGTIEjo6OsLW1ha+vLzZt2vTSjqsmOnHiBL744guMGTMGISEhr7qcZ/bahP7Zs2fx999/o3///tDT00P//v2xfft2pc8MWrZsGZYtWwZvb2/ExMTA398fixcvxsaNGwEA+/btw4QJE9C0aVNERkZi+vTpSE1NVXx5zLNYsWIFwsLCMHPmTNjb2yM+Ph7z58+Ht7c3vvrqK3zxxReoXbs2Zs6ciWvXrgEArl27hsDAQOTm5mLBggUIDw/HpUuXMGrUKNy9exeDBg3C8ePHcenSJcV+srOzcerUKchksmqOoGaVX3dZ+VHaGz
ZsgK2tLSIjI+Hn54cHDx5g1KhRyMzMxIwZMxAVFYX69etj1KhROH36tGI7ISEhWL9+PYKCghATEwNnZ2fMmDED+/btU9lnbm4uJk+eDFtbW6xZswaLFi1SXFaru9tYEASMHz8e3377LYYMGYI1a9aga9eu+PTTTxEdHa3Ud/369cjOzsaSJUuwcOFC/PnnnwgNDdXhiD07QRBQVlam+Hnw4AEuXLiA0NBQlJSUoH///gCAefPmYdWqVRg4cCBiYmLg6emJjz/+WOkkZuPGjZg9ezbs7e0RHR2NUaNG4euvv0ZERAQAYM+ePZg6dSrs7e2xevVqrFy5EmZmZpg/f77S70uMHv8dVP5UPt8sLS2RkZGBKVOmqHx8fE3w2kzvfPfddzAxMVFc+spkMnz99ddITU3F0KFDcefOHcTHx2PUqFGYMWMGAKBr1664evUqjh07hmHDhmHlypWwtbVFVFSUYruCICAuLu6ZPxJ62LBhSl88c+nSJYwbNw4TJ05UtDVv3hwymQwnT56Er68v1q9fj7KyMqxfvx6NGjUC8OhzjMaMGYMzZ87A398fS5cuxY8//ohJkyYBeHQm16RJE7i7uz/fwKlRGTyVbt++jZ9//hlbtmyBj4+PorZmzZph1qxZirPPrVu34p9//kFSUpJiWs3DwwMBAQFYsWIF1q1bh3PnzmHXrl345JNPMGzYMABAly5dIJfLceTIEXh6eirVcvr0ady/fx8TJkzAm2++CQB46623sGfPHrUf4Ld//34cPXpU6VvdunXrhrKyMsTExGDo0KFo2LAhAKBBgwZYvXq14j+uXC7HypUrUVRUpPaDBF+GX3/9FR06dFBqk0gksLGxQWRkJDw9PZGdnY2tW7di1qxZGDNmDIBHx1heXo7IyEgEBATA0NAQ0dHR8PHxwYIFCxR97ty5g0OHDkEQBFy4cAEymUzpD52joyNcXV1x9OhR2Nvbv7wD/w+Ry+UqvwPg0R/aIUOGoHHjxq+gKt15LUK/tLQUP/74I3x8fHD37l0AQOPGjdGhQwckJiZi6NCh+P3331FWVobevXsrrbto0SIAj6YEzpw5gw8//FBpeUBAAAICAp65Jmtra6XHlVcLd+7cQVZWFnJycnDkyBEAwMOHDwE8umzs1KmTIlSBR6H/+Blw7969kZqaikmTJkEQBKSmpsLf31+nZxzqgqdWrVro1asX5s2bp2hr06aN0guov/zyC9588020a9dO6Y+Gp6cnvvrqK5SWliqmep78Paxdu1ZtLR07doShoSECAgLg4+MDDw8PuLq6PjWQjh07Bn19fZVvevP398fmzZtx6tQp9OjRQ7Htx8etWbNHX7hx9+7dVxb69vb2+OSTTwA8uvKLjIxEWVkZVqxYoXg3+6+//gpBEODp6ak0zl5eXvj2229x+vRpNG7cGLdu3VIZ5ylTpmDKlCkAoPja0pKSEmRnZ0Mul+OPP/4A8P/PSTFq1qwZVq1apdLevHnzV1CN7r0Wob93714UFBRgy5Yt2LJli8ryU6dOoaCgAABgYmKidhuFhYUQBEEpcKvjyf3I5XJ88skn+OWXX6Cvr4/WrVujbdu2AP7/+wcKCgoU30j2NDKZDGlpafjjjz9QUlKCK1eu6Hxq5/HgkUgkMDIyQvPmzfHGG28o9XvyGAsKCnD16lW1Z0nAoysGTb+HJ5mZmSEhIQGxsbHYtm0b4uPjUa9ePQwdOhTTp09XuWunsLAQJiYmKt/bUHl2VvmlQABgZGSk1KdynVf5JnVjY2PFVZKdnR0cHBzwzjvvYOzYsfjuu+/QqFEjxRiq+35qALh+/Tpq1370X7uqcc7Pz8enn36KjIwMSCQSWFhYwMnJCcCrHYNXzcDA4LW+AeS1CP3k5GS0bNlScRlbqaysDBMnTsSWLVsUZ375+flo0aKFok9ubi7y8vIUL9Devn1baRvFxcX47bff4ODgAIlEgoqKCqXl6r434EkVFRV4//33YWhoiG3btqFdu3
aoXbs2/v33X6V7r6VSKfLz81XWP3jwICwtLfHWW2+hS5cuaN68OdLT01FSUgI7Ozu0adNGYw3P4vHgeRZ169aFpaUlli5dqnZ5w4YNFWfQ+fn5aNKkiWLZuXPncO/ePXTs2FFlPXt7e6xatUpxpZCYmIiYmBi0b98effr0Uepbr1493Lp1CxUVFUrBf+PGDUUNNUnjxo3xySefYNq0aVi8eDEiIiIUY5iQkKDyhwt49Iey8niffD7dvHkT58+fR6dOnRASEoLs7GysX78ejo6OMDAwwL1799S+IEyvjxr/Qu7169dx8OBB+Pn5wdXVVenHzc0Nnp6eSE9PR8eOHaGvr6/yYuGaNWsQFhYGqVSKtm3bYu/evUrLMzIyMG7cOBQVFUEqlUIQBKU3Yjx5Z4o6t2/fRnZ2NgYPHgw7OzvFWVjll8RXnlU5OTnh5MmTijM5ALh8+TLGjRunmArS09PDgAED8NNPP2Hfvn0YOHDgsw/aC+Ls7IwrV66gadOmsLOzU/zs2bMHGzZsgL6+vuJM8snfw+LFi7F8+XKVbW7YsAFeXl4oLS2FgYEBunTpgoULFwIA8vLyVPq7uLjg4cOH2L17t1J7Wloa9PX1a+Q8tY+PD9zd3ZGWloajR4+ic+fOAB5d1Tw+znl5eYiKisK9e/fQunVrNGjQQOX5nJiYiMmTJwN49Nz18fGBq6srDAwMAPz/c/LJkxt6fdT4M/2UlBSUl5fDz89P7fIBAwZg165d2LFjB4YPH45vvvkGtWvXRufOnXHixAl8//33ihCZOnUqgoODERISgv79++Pq1auIiIjAgAEDYGpqiu7du2PJkiX46KOPMG7cOFy5cgXR0dGK7wt+GhMTEzRv3hzx8fFo2rQppFIpDhw4gPj4eABQvA4xevRo/PDDDxg3bhwmTJgAiUSCVatWoXXr1kpz1AMHDsTq1auhr6+Pfv366WIYdUImkyEhIQGjR49WvPCamZmJdevWYcqUKZBIJGjXrh28vb2xZMkS3L17FzY2NsjIyMDRo0fxzTffqGzz7bffxrJlyxAcHIzhw4ejVq1a2LJlCwwNDVVe9AUevXDs7OyMjz76CFevXkWbNm2wf/9+bNmyBZMmTUK9evVexlDoXFhYGN555x0sWrQI33//Pfr164ewsDDk5uaiXbt2+Pfff7F8+XJ06NBB8R0XU6ZMweLFi9GwYUN4eXnh3LlziI2NxdixY2FoaAh7e3ukpqaiXbt2ePPNN3Hy5EnExsZCIpHg3r17r/iI6UWp8aH//fffo23btrC0tFS73MPDA40aNUJiYiJ++OEHNGrUCFu3bkVsbCwsLCzw2WefKc6We/bsidWrV2PVqlWYPHkyTExMMHjwYAQHBwN49KLq0qVLsWbNGowfPx6WlpZYuHCh4o9GVVavXo3Fixdj1qxZMDAwgJWVFdasWYPPPvsMJ06cwNChQ2FqaoqNGzfi888/x6xZs2BoaIiuXbti1qxZqFOnjmJb5ubmsLKygpWVFerXr6+DUdQNY2NjbNy4EREREQgPD0dJSQnMzc3x8ccfK72/ICIiApGRkYiLi0NhYSEsLS0Vt1Y+qU2bNvjqq6+wcuVKzJgxA+Xl5bC1tUVcXJza1z/09PTw1Vdf4csvv8TatWtRWFiIli1bYt68eQgKCnqhx/8itW7dGiNGjEBcXBw2b96M8PBwxMTEICEhAdeuXUPjxo0REBCg9H6NESNG4I033kBcXBy2bNkCU1NTTJ06FaNHjwYAhIeHY+HChYpp0ZYtW2L+/PlITU3V6gqWaiZ+tHINdOnSJfTu3Rtr166Fm5vbqy6HiGoQhn4NcubMGcW3ltWuXRspKSmvzWfOENHLUeNfyBWTBw8eIC4uDhKJBF988QUDn4ieGc/0iYhEhGf6REQiwtAnIhIRhj4RkYgw9ImIRIShT0QkIgx9IiIR+T9jauydcQr7YgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Creating visualisation for cross-validated metrics\n", "cv_metrics = pd.DataFrame({\"Accuracy\": cv_acc,\n", " \"Precision\": cv_precision,\n", " \"Recall\": cv_recall,\n", " \"F1\": cv_f1})\n", "\n", "# Average each metric across the cross-validation folds before plotting\n", "cv_metrics.mean().plot.bar(title = \"Cross-Validated Metrics\", legend = False, color = 'crimson');\n", "plt.xticks(rotation = 0)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "2b470859", "metadata": {}, "source": [ "### ***Feature Importance***\n", "\n", "***Feature importance tells us which features contributed most to the model's outcomes, and how they contributed.***\n", "\n", "***Remember: the way feature importance is found is different for each machine learning model.***\n", "\n", "> Feature importance is another way of asking, \"which features contributed most to the outcomes of the model?\"\n", "\n", "> Or, for our problem of predicting heart disease from a patient's medical characteristics: which characteristics contribute most to a model predicting whether or not someone has heart disease?\n", "\n", "> Unlike some of the other functions we've seen, this varies by model: because each model finds patterns in data slightly differently, each one also judges the importance of those patterns differently. This means that for each model, there's a slightly different way of finding which features were most important.\n", "\n", "> You can usually find an example in the Scikit-Learn documentation or by searching for something like \"[MODEL TYPE] feature importance\", such as \"random forest feature importance\"." ] }, { "cell_type": "code", "execution_count": 258, "id": "202fe34d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
\n", "
" ], "text/plain": [ " age sex cp trestbps chol fbs restecg thalach exang oldpeak slope \\\n", "0 63 1 3 145 233 1 0 150 0 2.3 0 \n", "1 37 1 2 130 250 0 1 187 0 3.5 0 \n", "2 41 0 1 130 204 0 0 172 0 1.4 2 \n", "3 56 1 1 120 236 0 1 178 0 0.8 2 \n", "4 57 0 0 120 354 0 1 163 1 0.6 2 \n", "\n", " ca thal target \n", "0 0 1 1 \n", "1 0 2 1 \n", "2 0 2 1 \n", "3 0 2 1 \n", "4 0 2 1 " ] }, "execution_count": 258, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": 259, "id": "86c6efb7", "metadata": {}, "outputs": [], "source": [ "# Finding feature importance for LogisticRegression model\n", "\n", "# Fit an instance of LogisticRegression\n", "clf = LogisticRegression(C = 0.20433597178569418,\n", " solver = \"liblinear\")\n", "\n", "clf.fit(X_train, y_train);" ] }, { "cell_type": "code", "execution_count": 260, "id": "5380729c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 0.00316728, -0.86044652, 0.6606704 , -0.01156993, -0.00166375,\n", " 0.04386107, 0.31275848, 0.02459362, -0.60413081, -0.56862803,\n", " 0.45051628, -0.63609898, -0.67663373]])" ] }, "execution_count": 260, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check coef_\n", "clf.coef_" ] }, { "cell_type": "code", "execution_count": 261, "id": "d1afb63e", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
\n", "
" ], "text/plain": [ " age sex cp trestbps chol fbs restecg thalach exang oldpeak slope \\\n", "0 63 1 3 145 233 1 0 150 0 2.3 0 \n", "1 37 1 2 130 250 0 1 187 0 3.5 0 \n", "2 41 0 1 130 204 0 0 172 0 1.4 2 \n", "3 56 1 1 120 236 0 1 178 0 0.8 2 \n", "4 57 0 0 120 354 0 1 163 1 0.6 2 \n", "\n", " ca thal target \n", "0 0 1 1 \n", "1 0 2 1 \n", "2 0 2 1 \n", "3 0 2 1 \n", "4 0 2 1 " ] }, "execution_count": 261, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": 262, "id": "5f867cbc", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'age': 0.0031672806268220445,\n", " 'sex': -0.8604465226286001,\n", " 'cp': 0.6606703996492814,\n", " 'trestbps': -0.011569930743501303,\n", " 'chol': -0.001663745833540806,\n", " 'fbs': 0.043861067871676124,\n", " 'restecg': 0.3127584791782968,\n", " 'thalach': 0.02459361509185037,\n", " 'exang': -0.6041308102637141,\n", " 'oldpeak': -0.5686280255489925,\n", " 'slope': 0.4505162810238786,\n", " 'ca': -0.6360989756865822,\n", " 'thal': -0.67663372723561}" ] }, "execution_count": 262, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Match each feature's coefficient to its column name\n", "feature_dict = dict(zip(df.columns, list(clf.coef_[0])))\n", "feature_dict" ] }, { "cell_type": "code", "execution_count": 263, "id": "9a0d3121", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Visualize feature importance\n", "feature_df = pd.DataFrame(feature_dict, index = [0])\n", "feature_df.T.plot.bar(title = \"Feature Importance\", legend = False, grid = True, color = 'crimson', figsize = (8, 4));" ] }, { "cell_type": "code", "execution_count": 264, "id": "297b16d9", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
\n", "
" ], "text/plain": [ "target 0 1\n", "sex \n", "0 24 72\n", "1 114 93" ] }, "execution_count": 264, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.crosstab(df[\"sex\"], df[\"target\"])" ] }, { "cell_type": "code", "execution_count": 265, "id": "63d5675b", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
\n", "
" ], "text/plain": [ "target 0 1\n", "slope \n", "0 12 9\n", "1 91 49\n", "2 35 107" ] }, "execution_count": 265, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.crosstab(df[\"slope\"], df[\"target\"])" ] }, { "cell_type": "markdown", "id": "98e38ac7", "metadata": {}, "source": [ "***slope - the slope of the peak exercise ST segment***\n", " * ***0: Upsloping: better heart rate with exercise (uncommon)***\n", " * ***1: Flatsloping: minimal change (typical healthy heart)***\n", " * ***2: Downsloping: signs of unhealthy heart***" ] }, { "cell_type": "markdown", "id": "9ab1070b", "metadata": {}, "source": [ "### ***6. Experimentation***\n", "\n", "> Well, we've completed all the metrics.\n", "We're now able to put together a great report containing a confusion matrix, a handful of cross-validated metrics such as precision, recall and F1, and the features that contribute most to the model's decisions." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }