{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Model evaluation & cross-validation\n", "\n", "````{margin}\n", "```{warning}\n", "These pages are currently under construction and will be updated continuously.\n", "Please visit these pages again in the next few weeks for further information.\n", "```\n", "````\n", "\n", "---------------" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Aim(s) of this section 🎯\n", "\n", "As mentioned in the previous section, it is not sufficient to apply these methods to learn something about the nature of our data; it is always necessary to assess the quality of the implemented model. The goal of this section is to look at ways to estimate the generalization accuracy of a model on future (i.e., unseen, out-of-sample) data.\n", "\n", "In other words, at the end of this section you should:\n", "- 1) know different techniques to evaluate a given model\n", "- 2) understand the basic idea of cross-validation and its different variants\n", "- 3) have an idea of how to assess significance (e.g., via permutation tests)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Outline for this section 📝\n", "\n", "1. Model diagnostics\n", "\n", "2. Cross-validation" ] },
{ "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Prepare data for model\n", "\n", "Let's bring back our example data set (you know the song ...)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "There are 155 samples and 2016 features\n" ] } ], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "# get the data set\n", "data = np.load('MAIN2019_BASC064_subsamp_features.npz')['a']\n", "\n", "# get the participant information (incl. the labels)\n", "info = pd.read_csv('participants.csv')\n", "\n", "\n", "print('There are %s samples and %s features' % (data.shape[0], data.shape[1]))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Now let's look at the labels" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "scrolled": false, "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/html": [ "
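<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>participant_id</th>\n", "      <th>Age</th>\n", "      <th>AgeGroup</th>\n", "      <th>Child_Adult</th>\n", "      <th>Gender</th>\n", "      <th>Handedness</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>sub-pixar123</td>\n", "      <td>27.06</td>\n", "      <td>Adult</td>\n", "      <td>adult</td>\n", "      <td>F</td>\n", "      <td>R</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1</th>\n", "      <td>sub-pixar124</td>\n", "      <td>33.44</td>\n", "      <td>Adult</td>\n", "      <td>adult</td>\n", "      <td>M</td>\n", "      <td>R</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>2</th>\n", "      <td>sub-pixar125</td>\n", "      <td>31.00</td>\n", "      <td>Adult</td>\n", "      <td>adult</td>\n", "      <td>M</td>\n", "      <td>R</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>3</th>\n", "      <td>sub-pixar126</td>\n", "      <td>19.00</td>\n", "      <td>Adult</td>\n", "      <td>adult</td>\n", "      <td>F</td>\n", "      <td>R</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>4</th>\n", "      <td>sub-pixar127</td>\n", "      <td>23.00</td>\n", "      <td>Adult</td>\n", "      <td>adult</td>\n", "      <td>F</td>\n", "      <td>R</td>\n", "    </tr>\n",
"  </tbody>\n", "</table>\n", "</div>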
" ], "text/plain": [ " participant_id Age AgeGroup Child_Adult Gender Handedness\n", "0 sub-pixar123 27.06 Adult adult F R\n", "1 sub-pixar124 33.44 Adult adult M R\n", "2 sub-pixar125 31.00 Adult adult M R\n", "3 sub-pixar126 19.00 Adult adult F R\n", "4 sub-pixar127 23.00 Adult adult F R" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "info.head(n=5)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "We'll set `Age` as the target\n", "- i.e., we'll look at this from a `regression` perspective" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "scrolled": true, "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/plain": [ "count 155.000000\n", "mean 10.555189\n", "std 8.071957\n", "min 3.518138\n", "25% 5.300000\n", "50% 7.680000\n", "75% 10.975000\n", "max 39.000000\n", "Name: Age, dtype: float64" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# set age as target\n", "Y_con = info['Age']\n", "Y_con.describe()" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Next:\n", "- we need to divide our input data (`data`) and target (`Y_con`) into `training` and `test` sets" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "# import necessary python modules\n", "from sklearn.model_selection import train_test_split\n", "\n", "# split the data (by default, 75% training / 25% test)\n", "X_train, X_test, y_train, y_test = train_test_split(data, Y_con, random_state=0)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Now let's look at the size of the datasets" ] }, { "cell_type": "code", 
"execution_count": 5, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "N used for training: 116 | N used for testing: 39\n" ] } ], "source": [ "# print the size of our training and test groups\n", "print('N used for training:', len(X_train),\n", " ' | N used for testing:', len(X_test))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "**Question:** Is that a good distribution? Does it look ok?\n", "\n", "- Why might this be problematic (hint: what do you know about groups (e.g., `Child_Adult`) in the data." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": true, "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAATsUlEQVR4nO3dfbBcBXmA8eclCUkEOgnhTgpJaOLHIBTbUCLlazoWayfaD6JNQcbadAYbHLCDY7Wi/aPa1hmsVgWnilSpsVUBRQpSxAIGrQNCLxjkKxJQMBeBhGgQOhMl8PaPPRcuMbm5CXv23Xv3+c3cye45+/HeM+Th5Ozu2chMJEm9t0/1AJI0qAywJBUxwJJUxABLUhEDLElFplcPMBHLly/Pa665pnoMSdpTMd7KSbEH/Nhjj1WPIEldNykCLElTkQGWpCIGWJKKTIoX4SRNTk899RQjIyNs27atepRWzZo1i4ULFzJjxow9up8BltSakZERDjjgABYvXkzEuG8ImLQyky1btjAyMsKSJUv26L4egpDUmm3btjFv3rwpG1+AiGDevHl7tZdvgCW1airHd9Te/o4GWJKKGGBJPbNg0aFERNd+Fiw6dNzn27p1K5/4xCf2eM7Xve51bN26dS9/y4nzRThJPfPjkY2c+qkbu/Z4l5xx/LjrRwN85plnPm/59u3bmT591/m7+uqruzLf7hhgSVPWOeecw/3338/SpUuZMWMGs2bNYu7cuaxfv557772XFStWsHHjRrZt28bZZ5/N6tWrAVi8eDHDw8M8+eSTvPa1r+XEE0/kxhtvZMGCBVxxxRXMnj27K/N5CELSlHXuuefykpe8hHXr1vGhD32I2267jfPOO497770XgIsuuohbb72V4eFhzj//fLZs2fJLj7FhwwbOOuss7rrrLubMmcNll13WtfncA5Y0MI455pjnvVf3/PPP5/LLLwdg48aNbNiwgXnz5j3vPkuWLGHp0qUAHH300TzwwANdm2dK7wF3+4D/RA76S+pf++2337OXb7jhBq677jpuuukmbr/9do466qidvpd35syZz16eNm0a27dv79
o8U3oPuNsH/GH3B/0l9Y8DDjiAJ554YqfrHn/8cebOncuLXvQi1q9fz3e+850eTzfFAyypvxyycFFXd2IOWbho3PXz5s3jhBNO4Mgjj2T27NnMnz//2XXLly/nggsu4PDDD+ewww7j2GOP7dpcE2WAJfXMQxt/1PPn/MIXvrDT5TNnzuRrX/vaTteNHuc96KCDuPPOO59d/s53vrOrs03pY8CS1M8MsCQVMcCSWpWZ1SO0bm9/RwMsqTWzZs1iy5YtUzrCo+cDnjVr1h7f1xfhJLVm4cKFjIyMsHnz5upRWjX6jRh7ygBLas2MGTP2+FsiBomHICSpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkoq0HuCImBYR342Iq5rrSyLi5oi4LyIuiYh9255BkvpRL/aAzwbuGXP9g8BHM/OlwE+B03swgyT1nVYDHBELgT8APt1cD+Ak4MvNTdYAK9qcQZL6Vdt7wB8D/gZ4prk+D9iamdub6yPAgpZnkKS+1FqAI+IPgU2Zeete3n91RAxHxPDmzZu7PJ0k1WtzD/gE4I8j4gHgYjqHHs4D5kTE9OY2C4GHdnbnzLwwM5dl5rKhoaEWx5SkGq0FODPfk5kLM3Mx8EbgG5n5JmAtsLK52SrgirZmkKR+VvE+4HcD74iI++gcE/5MwQySVG767m/ywmXmDcANzeUfAMf04nklqZ/5SThJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAZakIgZYkooYYEkqYoAlqYgBlqQiBliSirQW4IiYFRG3RMTtEXFXRLy/Wb4kIm6OiPsi4pKI2LetGSSpn7W5B/xz4KTM/E1gKbA8Io4FPgh8NDNfCvwUOL3FGSSpb7UW4Ox4srk6o/lJ4CTgy83yNcCKtmaQpH7W6jHgiJgWEeuATcC1wP3A1szc3txkBFjQ5gyS1K9aDXBmPp2ZS4GFwDHAyyd634hYHRHDETG8efPmtkaUpDI9eRdEZm4F1gLHAXMiYnqzaiHw0C7uc2FmLsvMZUNDQ70YU5J6qs13QQxFxJzm8mzgNcA9dEK8srnZKuCKtmaQpH42ffc32WsHA2siYhqd0F+amVdFxN3AxRHxj8B3gc+0OIMk9a3WApyZ3wOO2snyH9A5HixJA81PwklSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklRkQgGOiBMmskySNHET3QP++ASXSZImaPp4KyPiOOB4YCgi3jFm1a8A09ocTJKmunEDDOwL7N/c7oAxy38GrGxrKEkaBOMGODO/CXwzIj6bmQ/2aCZJGgi72wMeNTMiLgQWj71PZp7UxlCSNAgmGuAvARcAnwaebm8cSRocEw3w9sz8ZKuTSNKAmejb0L4aEWdGxMERceDoT6uTSdIUN9E94FXNn+8asyyBF3d3HEkaHBMKcGYuaXsQSRo0EwpwRPz5zpZn5ue6O44kDY6JHoJ45ZjLs4BXA7cBBliS9tJED0H81djrETEHuLiNgSRpUOzt6Sj/D/C4sCS9ABM9BvxVOu96gM5JeA4HLm1rKEkaBBM9BvzhMZe3Aw9m5kgL80jSwJjQIYjmpDzr6ZwRbS
7wizaHkqRBMNFvxDgFuAX4U+AU4OaI8HSUkvQCTPQQxN8Cr8zMTQARMQRcB3y5rcEkaaqb6Lsg9hmNb2PLHtxXkrQTE90DviYivg58sbl+KnB1OyNJ0mDY3XfCvRSYn5nviog3ACc2q24CPt/2cJI0le1uD/hjwHsAMvMrwFcAIuIVzbo/anE2SZrSdnccd35m3rHjwmbZ4lYmkqQBsbsAzxln3ewuziFJA2d3AR6OiL/ccWFEvAW4tZ2RJGkw7O4Y8NuByyPiTTwX3GXAvsDrW5xLkqa8cQOcmY8Cx0fE7wJHNov/KzO/0fpkkjTFTfR8wGuBtS3PMjnsM52I6OpDHrJwEQ9t/FFXH1NS/5voBzE06pntnPqpG7v6kJeccXxXH0/S5ODHiSWpSGsBjohFEbE2Iu6OiLsi4uxm+YERcW1EbGj+nNvWDJLUz9rcA94O/HVmHgEcC5wVEUcA5wDXZ+bLgOub65I0cFoLcGY+nJm3NZefAO4BFgAnA2uam60BVrQ1gyT1s54cA46IxcBRwM10Pt78cLPqEWB+L2aQpH7TeoAjYn/gMuDtmfmzsesyM3nuyz53vN/qiBiOiOHNmze3PaYk9VyrAY6IGXTi+/nmbGoAj0bEwc36g4FNO7tvZl6Ymcsyc9nQ0FCbY0pSiTbfBRHAZ4B7MvMjY1ZdCaxqLq8CrmhrBknqZ21+EOME4M3AHRGxrln2XuBc4NKIOB14kM6XfErSwGktwJn5bWBXn9l9dVvPK0mThZ+Ek6QiBliSihhgSSpigCWpiAGWpCIGWJKKGGBJKmKAJamIAe4HzffMdfNnwaJDq38rSbvhd8L1A79nThpI7gFLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBVpLcARcVFEbIqIO8csOzAiro2IDc2fc9t6fknqd23uAX8WWL7DsnOA6zPzZcD1zXVJGkitBTgzvwX8ZIfFJwNrmstrgBVtPb8k9bteHwOen5kPN5cfAebv6oYRsToihiNiePPmzb2ZTpJ6qOxFuMxMIMdZf2FmLsvMZUNDQz2cTJJ6o9cBfjQiDgZo/tzU4+eXpL7R6wBfCaxqLq8Crujx80tS32jzbWhfBG4CDouIkYg4HTgXeE1EbAB+r7kuSQNpelsPnJmn7WLVq9t6TkmaTPwknCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcBT1T7TiYiu/ixYdGj1b1VqwaJD3abqqtZOR6liz2zn1E/d2NWHvOSM47v6eJPNj0c2uk3VVe4BS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSEQMsSUUMsCQVMcCSVMQAS1IRAyxJRQywJBUxwJJUxABLUhEDLElFDLCkCfE78brP74STNCF+J173uQcsSUUMsCQVMcCSVMQAS1IRX4TTlLRg0aH8eGRj9Ri7t890IqKrDzltxkyefurnXX3M1rTw+x+ycBEPbfxRVx+zLQZYU9KkecX+me2tzNntxxx93K5r6fefLDwEIUlFDLAkFTHAklTEAEtSEQMsSUUMsKSppXlr22Q4aZBvQ5M0tUyit7a5ByxJRQywJBUxwJJUxABLUhEDLElFDLAkFTHAklTEAEtSkZIAR8TyiPh+RNwXEedUzCBJ1Xoe4IiYBvwL8FrgCOC0iDii13NIUrWKPeBjgPsy8weZ+QvgYuDkgjkkqVRkZm+fMGIlsDwz39JcfzPw25n5th
1utxpY3Vw9DPh+j0Y8CHisR8/VLc7cO5NxbmfujZ3N/FhmLt/VHfr2ZDyZeSFwYa+fNyKGM3NZr5/3hXDm3pmMcztzb+zNzBWHIB4CFo25vrBZJkkDpSLA/wu8LCKWRMS+wBuBKwvmkKRSPT8EkZnbI+JtwNeBacBFmXlXr+cYR88Pe3SBM/fOZJzbmXtjj2fu+YtwkqQOPwknSUUMsCQVMcBjRMQDEXFHRKyLiOHqeXYmIi6KiE0RceeYZQdGxLURsaH5c27ljDvaxczvi4iHmm29LiJeVznjjiJiUUSsjYi7I+KuiDi7Wd6323qcmft9W8+KiFsi4vZm7vc3y5dExM3NKQsuaV607wvjzPzZiPjhmG29dNzH8RjwcyLiAWBZZvbtG8Aj4neAJ4HPZeaRzbJ/An6Smec259aYm5nvrpxzrF3M/D7gycz8cOVsuxIRBwMHZ+ZtEXEAcCuwAvgL+nRbjzPzKfT3tg5gv8x8MiJmAN8GzgbeAXwlMy+OiAuA2zPzk5Wzjhpn5rcCV2XmlyfyOO4BTzKZ+S3gJzssPhlY01xeQ+cvXd/Yxcx9LTMfzszbmstPAPcAC+jjbT3OzH0tO55srs5ofhI4CRgNWb9t613NvEcM8PMl8N8RcWvzUejJYn5mPtxcfgSYXznMHnhbRHyvOUTRN/+U31FELAaOAm5mkmzrHWaGPt/WETEtItYBm4BrgfuBrZm5vbnJCH32P5MdZ87M0W39gWZbfzQiZo73GAb4+U7MzN+ic6a2s5p/Ok8q2TmmNBmOK30SeAmwFHgY+OfSaXYhIvYHLgPenpk/G7uuX7f1Tmbu+22dmU9n5lI6n4w9Bnh57US7t+PMEXEk8B46s78SOBAY9/CUAR4jMx9q/twEXE7nP4TJ4NHm+N/occBNxfPsVmY+2vwH/Azwr/Thtm6O7V0GfD4zv9Is7uttvbOZJ8O2HpWZW4G1wHHAnIgY/bBY356yYMzMy5vDQJmZPwf+jd1sawPciIj9mhcuiIj9gN8H7hz/Xn3jSmBVc3kVcEXhLBMyGrHG6+mzbd28yPIZ4J7M/MiYVX27rXc18yTY1kMRMae5PBt4DZ3j12uBlc3N+m1b72zm9WP+5xx0jlmPu619F0QjIl5MZ68XOh/R/kJmfqBwpJ2KiC8Cr6Jz6rtHgb8D/hO4FDgUeBA4JTP75kWvXcz8Kjr/JE7gAeCMMcdWy0XEicD/AHcAzzSL30vnmGpfbutxZj6N/t7Wv0HnRbZpdHYKL83Mv2/+Tl5M55/y3wX+rNmzLDfOzN8AhoAA1gFvHfNi3S8/jgGWpBoegpCkIgZYkooYYEkqYoAlqYgBlqQiBlhTXkSsiIiMiL7/dJUGiwHWIDiNztmqTqseRBrLAGtKa86LcCJwOp0vgCUi9omIT0TE+uacvldHxMpm3dER8c3mhExf3+FTZFJXGWBNdScD12TmvcCWiDgaeAOwGDgCeDOd8w6Mnkfh48DKzDwauAjou09Dauro+bciSz12GnBec/ni5vp04EvNyWkeiYi1zfrDgCOBazsf5WcanbOHSa0wwJqyIuJAOif1fkVEJJ2gJs+d8+OX7gLclZnH9WhEDTgPQWgqWwn8e2b+WmYuzsxFwA/pfDvHnzTHgufTOTEQwPeBoYh49pBERPx6xeAaDAZYU9lp/PLe7mXAr9L5hoW7gf8AbgMez8xf0In2ByPidjpnszq+Z9Nq4Hg2NA2kiNi/+ULFecAtwAmZ+Uj1XBosHgPWoLqqOaH2vsA/GF9VcA9Ykop4DFiSihhgSSpigCWpiAGWpCIGWJKK/D8WGBmJv4D0egAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFgCAYAAACFYaNMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAVvUlEQVR4nO3df7RdZX3n8fc3JBgxGWKSCzW/JtgKowMt4KUj4LQFSo3WBdKhIIJDOzhhxh+jowuEUerorK5F13QwXe1IyWCKLSwEEVctWEyKqdiFhV5ikPDDEqvVy69cYFCKAxj4zh9np73e3iQ3h3P299x736+1zrrn7L3Pfr73yc0nT56793MiM5EktW9OdQGSNFsZwJJUxACWpCIGsCQVMYAlqcjc6gKmYs2aNXnLLbdUlyFJ3Yjd7ZgWI+DHH3+8ugRJ6rlpEcCSNBMZwJJUxACWpCLT4pdwkmaOH//4x4yOjvLss89Wl9JT8+fPZ8WKFcybN2/K7zGAJbVqdHSUhQsXsnr1aiJ2e4HAtJKZPPHEE4yOjnLIIYdM+X1OQUhq1bPPPsuSJUtmTPgCRARLlizZ51G9ASypdTMpfHfp5nsygCWpiAEsqdTylauIiJ49lq9ctdc2n3rqKT71qU91Ve+6dev40Y9+1NV7J/KXcJJKPTz6fc684vaene+684/b6zG7Avjd7373Pp9/3bp1nHPOORxwwAHdlPcT+hbAEbEBeCuwIzMPH7f9fcB7gBeAmzPzwn7VIEmTueiii/j2t7/NkUceycknn8xBBx3E9ddfz3PPPcdpp53Gxz/+cZ555hnOOOMMRkdHeeGFF7jkkkt47LHHePjhhznhhBNYunQpmzdvfkl19HMEfBXwB8Af79oQEScApwI/l5nPRcRBfWxfkiZ16aWXsm3bNrZu3crGjRu54YYbuPPOO8lMTjnlFG677TbGxsZYtmwZN998MwA/+MEPOPDAA7nsssvYvHkzS5cufcl19G0OODNvA56csPk/A5dm5nPNMTv61b4kTcXGjRvZuHEjRx11FEcffTQPPPAADz74IEcccQSbNm3iwx/+MF/72tc48MADe95227+EOxT4txFxR0R8NSKO2d2BEbE2IkYiYmRsbKyrxiom9yVNL5nJxRdfzNatW9m6dSvbt2/nvPPO49BDD2XLli0cccQRfPSjH+UTn/hEz9tu+5dwc4HFwBuAY4DrI+LVOclHM2fmemA9wPDwcFcf3VwxuS9p8C1cuJCnn34agDe96U1ccsklnH322SxYsICHHnqIefPmsXPnThYvXsw555zDokWLuPLKK3/ivb2Ygmg7gEeBG5vAvTMiXgSWAt0NcSVNe8tWrOzp4GbZipV7PWbJkiUcf/zxHH744bz5zW/mHe94B8ceeywACxYs4Oqrr2b79u1ccMEFzJkzh3nz5nH55ZcDsHbtWtasWcOyZcte8i/hYpLBZ89ExGrgpl1XQUTEfwKWZeZvRcShwK3AqslGwOMNDw/nyMhIN+33fATcz/6SZoP777+f1772tdVl9MVuvrfd3iLXz8vQrgV+CVgaEaPAx4ANwIaI2AY8D5y7t/CVpJmqbwGcmWftZtc5/WpTkqYTb0WW1LqZ+B/fbr4nA1hSq+bPn88TTzwxo0J413rA8+fP36f3uRaEpFatWLGC0dFRur2+f1Dt+kSMfWEAS2rVvHnz9ulTI2YypyAkqYgBLElFDGBJKmIAS1IRA1iSihjAklTEAJakIgawJBUxgCWpiAEsSUUMYEkqYgBLUhEDWJKKGMCSVMQAlqQiBrAkFTGAJamIASxJRQxgSSpiAEtSEQNYkooYwJJUxACWpCJ9C+CI2BAROyJi2yT7PhQRGRFL+9W+JA26fo6ArwLWTNwYESuBX
wG+18e2JWng9S2AM/M24MlJdn0SuBDIfrUtSdNBq3PAEXEq8FBm3t1mu5I0iOa21VBEHAD8NzrTD1M5fi2wFmDVqlV9rEySarQ5Av5p4BDg7oj4LrAC2BIRPzXZwZm5PjOHM3N4aGioxTIlqR2tjYAz8x7goF2vmxAezszH26pBkgZJPy9Duxb4OnBYRIxGxHn9akuSpqO+jYAz86y97F/dr7YlaTrwTjhJKmIAS1IRA1iSihjAklTEAJakIgawJBUxgCWpiAEsSUUMYEkqYgBLUhEDWJKKGMCSVMQAlqQiBrAkFTGAJamIASxJRQxgSSpiAEtSEQNYkooYwJJUxACWpCIGsCQVMYAlqYgBLElFDGBJKmIAS1IRA1iSihjAklSkbwEcERsiYkdEbBu37X9GxAMR8c2I+EJELOpX+5I06Po5Ar4KWDNh2ybg8Mz8WeBvgYv72L4kDbS+BXBm3gY8OWHbxszc2bz8a2BFv9qXpEFXOQf8H4A/393OiFgbESMRMTI2NtZiWZLUjpIAjoiPADuBa3Z3TGauz8zhzBweGhpqrzhJasncthuMiN8A3gqclJnZdvuSNChaDeCIWANcCPxiZv6ozbYladD08zK0a4GvA4dFxGhEnAf8AbAQ2BQRWyPiD/vVviQNur6NgDPzrEk2f7pf7UnSdOOdcJJUxACWpCIGsCQVMYAlqYgBLElFDGBJKmIAS1IRA1iSihjAklTEAJakIgawJBUxgCWpiAEsSUUMYEkqYgBLUhEDWJKKGMCSVMQAlqQiBrAkFTGAJamIASxJRQxgSSpiAEtSEQNYkooYwJJUxACWpCIGsCQVMYAlqUjfAjgiNkTEjojYNm7b4ojYFBEPNl9f2a/2JWnQ9XMEfBWwZsK2i4BbM/M1wK3Na0malfoWwJl5G/DkhM2nAp9pnn8GeFu/2pekQdf2HPDBmflI8/xR4ODdHRgRayNiJCJGxsbG2qlOklpU9ku4zEwg97B/fWYOZ+bw0NBQi5VJUjvaDuDHIuJVAM3XHS23L0kDo+0A/iJwbvP8XOBPW25fkgZGPy9Duxb4OnBYRIxGxHnApcDJEfEg8MvNa0maleb268SZedZudp3UrzYlaTrxTjhJKmIAS1IRA1iSihjAklTEAJakIgawJBUxgCWpiAEsSUUMYEkqYgBLUhEDWJKKTCmAI+L4qWyTJE3dVEfAvz/FbTPbnLlERE8ey1euqv5uJBXb42poEXEscBwwFBEfHLfrXwD79bOwgfTiTs684vaenOq684/ryXkkTV97W45yf2BBc9zCcdt/CJzer6IkaTbYYwBn5leBr0bEVZn59y3VJEmzwlQXZH9ZRKwHVo9/T2ae2I+iJGk2mGoAfw74Q+BK4IX+lSNJs8dUA3hnZl7e10okaZaZ6mVofxYR746IV0XE4l2PvlYmSTPcVEfAuz5K/oJx2xJ4dW/LkaTZY0oBnJmH9LsQSZptphTAEfHvJ9uemX/c23IkafaY6hTEMeOezwdOArYABrAkdWmqUxDvG/86IhYBn+1HQZI0W3S7HOUzgPPCkvQSTHUO+M/oXPUAnUV4Xgtc36+iJGk2mOoc8O+Oe74T+PvMHO220Yj4r8C76IT6PcBvZuaz3Z5PkqajKU1BNIvyPEBnRbRXAs9322BELAf+CzCcmYfTGVG/vdvzSdJ0NdVPxDgDuBP4deAM4I6IeCnLUc4FXh4Rc4EDgIdfwrkkaVqa6hTER4BjMnMHQEQMAX8B3LCvDWbmQxHxu8D3gP8HbMzMjROPi4i1wFqAVav89AhJM89Ur4KYsyt8G0/sw3t/QkS8EjiVzlUUy4BXRMQ5E4/LzPWZOZyZw0NDQ900JUkDbaoj4Fsi4svAtc3rM4EvddnmLwPfycwxgIi4kc7HHl3d5fkkaVra22fC/QxwcGZeEBG/Bryx2fV14Jou2/we8IaIOIDOFMRJwEiX55KkaWtv0wjr6Hz+G5l5Y2Z+MDM/CHyh2bfPMvMOO
nPHW+hcgjYHWN/NuSRpOtvbFMTBmXnPxI2ZeU9ErO620cz8GPCxbt8vSTPB3kbAi/aw7+U9rEOSZp29BfBIRPzHiRsj4l3AXf0pSZJmh71NQXwA+EJEnM0/Be4wsD9wWh/rkqQZb48BnJmPAcdFxAnA4c3mmzPzK32vTJJmuKmuB7wZ2NznWiRpVul2PWBJ0ktkAEtSEQNYkooYwJJUxACWpCIGsCQVMYAlqYgBLElFDGBJKmIAS1IRA1iSihjAklTEAJakIgawJBUxgCWpiAEsSUUMYEkqYgBLUhEDWJKKGMCSVMQAlqQiBrAkFTGAJalISQBHxKKIuCEiHoiI+yPi2Io6JKnS3KJ2fw+4JTNPj4j9gQOK6pCkMq0HcEQcCPwC8BsAmfk88HzbdUhStYopiEOAMeCPIuIbEXFlRLxi4kERsTYiRiJiZGxsrP0qJanPKgJ4LnA0cHlmHgU8A1w08aDMXJ+Zw5k5PDQ01HaNktR3FQE8Coxm5h3N6xvoBLIkzSqtB3BmPgp8PyIOazadBNzXdh2SVK3qKoj3Adc0V0D8HfCbRXVIUpmSAM7MrcBwRduSNCi8E06SihjAklTEAJakIgawJBUxgCWpiAEsSUUMYEkqYgBLUhEDWJKKGMCSVMQAlqQiBrAkFTGAq8yZS0T07LF85arq70jSPqpajlIv7uTMK27v2emuO/+4np1LUjscAUtSEQNYkooYwJJUxACWpCIGsCQVMYAlqYgBLElFDGBJKmIAS1IRA1iSihjAklTEAJakIgawJBUpC+CI2C8ivhERN1XVIEmVKkfA7wfuL2xfkkqVBHBErAB+Fbiyon1JGgRVI+B1wIXAi0XtS1K51gM4It4K7MjMu/Zy3NqIGImIkbGxsZaqk6T2VIyAjwdOiYjvAp8FToyIqycelJnrM3M4M4eHhobarlGS+q71AM7MizNzRWauBt4OfCUzz2m7Dkmq5nXAklSk9FORM/Mvgb+srEGSqjgClqQiBrAkFTGAJamIASxJRQxgSSpiAEtSEQNYkooYwJJUxACWpCIGsCQVMYAlqYgBLElFDGBJKmIAa9pZvnIVEdGzx/KVq6q/Jc1SpctRSt14ePT7nHnF7T0733XnH9ezc0n7whGwJBUxgCWpiAEsSUUMYEkqYgBLUhEDWJKKGMCSVMQAlqQiBrAkFTGAJamIASxJRQxgSSpiAEtSkdYDOCJWRsTmiLgvIu6NiPe3XYMkDYKK5Sh3Ah/KzC0RsRC4KyI2ZeZ9BbVIUpnWR8CZ+UhmbmmePw3cDyxvuw5JqlY6BxwRq4GjgDsm2bc2IkYiYmRsbKz12iSp38oCOCIWAJ8HPpCZP5y4PzPXZ+ZwZg4PDQ21X6Ak9VlJAEfEPDrhe01m3lhRgyRVq7gKIoBPA/dn5mVtty9Jg6JiBHw88E7gxIjY2jzeUlCHJJVq/TK0zPwrINpuV5IGjXfCSVIRA1iSihjAklTEAJakIgawJBUxgCWpiAEsSUUMYEkqYgBLUhEDWJKKGMCSVMQAlqQiBrAkFTGAZ4o5c4mInj2Wr1zV0/KWr1zVs9p6rsd9N3f/+QP9Z9FLvfxzHfSfu37UV/GpyOqHF3dy5hW39+x0151/XM/OBfDw6Pd7Vl+va+tH3w3yn0Uv9fLPFQb75w56X58jYEkqYgBLUhEDWJKKGMCSVMQAlqQiBrAkFTGAJamIASxJRQxgSSpiAEtSEQNYkooYwJJUxACWpCIlARwRayLiWxGxPSIuqqhBkqq1HsARsR/wv4E3A68DzoqI17VdhyRVqxgB/zywPTP/LjOfBz4LnFpQhySVisxst8GI04E1mfmu5vU7gX+Tme+dcNxaYG3z8jDgW30ubSnweJ/b6Ja1dcfaumNt3dldbY9n5prJ3jCwn4iRmeuB9W21FxEjmTncVnv7wtq6Y23dsbbudFNbxRTEQ8DKca9XNNskaVapCOC/AV4TEYdExP7A24EvFtQhSaVan4LIzJ0R8V7gy
8B+wIbMvLftOibR2nRHF6ytO9bWHWvrzj7X1vov4SRJHd4JJ0lFDGBJKmIAAxHx3Yi4JyK2RsRIcS0bImJHRGwbt21xRGyKiAebr68coNr+e0Q81PTd1oh4S1FtKyNic0TcFxH3RsT7m+3lfbeH2sr7LiLmR8SdEXF3U9vHm+2HRMQdzXIB1zW/MB+U2q6KiO+M67cj265tXI37RcQ3IuKm5vU+9ZsB/E9OyMwjB+Aaw6uAiRdtXwTcmpmvAW5tXle4in9eG8Anm747MjO/1HJNu+wEPpSZrwPeALynucV9EPpud7VBfd89B5yYmT8HHAmsiYg3AL/T1PYzwP8Fzhug2gAuGNdvWwtq2+X9wP3jXu9TvxnAAyYzbwOenLD5VOAzzfPPAG9rs6ZddlPbQMjMRzJzS/P8aTp/KZYzAH23h9rKZcc/NC/nNY8ETgRuaLZX9dvuahsIEbEC+FXgyuZ1sI/9ZgB3JLAxIu5qboEeNAdn5iPN80eBgyuLmcR7I+KbzRRFyfTIeBGxGjgKuIMB67sJtcEA9F3z3+itwA5gE/Bt4KnM3NkcMkrRPxgTa8vMXf32202/fTIiXlZRG7AOuBB4sXm9hH3sNwO4442ZeTSdFdreExG/UF3Q7mTnusGBGQUAlwM/Tee/iI8A/6uymIhYAHwe+EBm/nD8vuq+m6S2gei7zHwhM4+kc1fqzwP/qqKOyUysLSIOBy6mU+MxwGLgw23XFRFvBXZk5l0v5TwGMJCZDzVfdwBfoPNDOEgei4hXATRfdxTX848y87HmL8mLwP+hsO8iYh6dgLsmM29sNg9E301W2yD1XVPPU8Bm4FhgUUTsulGrfLmAcbWtaaZ0MjOfA/6Imn47HjglIr5LZ0XHE4HfYx/7bdYHcES8IiIW7noO/Aqwbc/vat0XgXOb5+cCf1pYy0/YFW6N0yjqu2b+7dPA/Zl52bhd5X23u9oGoe8iYigiFjXPXw6cTGeOejNwenNYVb9NVtsD4/5BDTpzrK33W2ZenJkrMnM1neUUvpKZZ7Ov/ZaZs/oBvBq4u3ncC3ykuJ5r6fx39Md05pDOozO3dCvwIPAXwOIBqu1PgHuAb9IJu1cV1fZGOtML3wS2No+3DELf7aG28r4Dfhb4RlPDNuC3mu2vBu4EtgOfA142QLV9pem3bcDVwIKKn7lxdf4ScFM3/eatyJJUZNZPQUhSFQNYkooYwJJUxACWpCIGsCQVMYA1K0TE2yIiI2Jg7vKSDGDNFmcBf9V8lQaCAawZr1mD4Y10bhx5e7NtTkR8KiIeaNYJ/lJEnN7se31EfLVZnOnLE+5Yk3rGANZscCpwS2b+LfBERLwe+DVgNfA64J101j/YtWbD7wOnZ+brgQ3Ab1cUrZmv9U9FlgqcRWehFOgsnHIWnZ/9z2VnIZxHI2Jzs/8w4HBgU2epAfajc/u11HMGsGa0iFhMZ6WqIyIi6QRq0ln1btK3APdm5rEtlahZzCkIzXSnA3+Smf8yM1dn5krgO3Q+2ePfNXPBB9NZUAXgW8BQRPzjlERE/OuKwjXzGcCa6c7in492Pw/8FJ0V3e6js6LWFuAHmfk8ndD+nYi4m87KZce1Vq1mFVdD06wVEQsy8x8iYgmdJQSPz8xHq+vS7OEcsGazm5oFv/cH/ofhq7Y5ApakIs4BS1IRA1iSihjAklTEAJakIgawJBX5/0R2OfbsYuj3AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "sns.displot(y_train,label='train')\n", "plt.legend()\n", "sns.displot(y_test,label='test')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Model fit\n", "\n", "Now let's go ahead and fit the model\n", "- we will use a fairly standard regression model called a Support Vector Regressor (SVR)\n", " - similar to the one we used in the previous section" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/plain": [ "SVR(kernel='linear')" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.svm import SVR\n", "\n", "# define the model (an SVR with a linear kernel)\n", "lin_svr = SVR(kernel='linear')\n", "\n", "# fit the model to the training data\n", "lin_svr.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Model diagnostics" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Now let's look at how well the model predicts the data\n", "- we can use the `score` method to calculate the coefficient of determination (or [R-squared](https://en.wikipedia.org/wiki/Coefficient_of_determination)) of the prediction.\n", " - for this we compare the observed data to the predicted data" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "# predict the training data based on the model\n", "y_pred = lin_svr.predict(X_train) \n", "\n", "# calculate the model's R-squared on the training data\n", "acc = lin_svr.score(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "name": 
"stdout", "output_type": "stream", "text": [ "accuracy (R2) 0.9998468951478122\n" ] } ], "source": [ "# print results\n", "print('accuracy (R2)', acc)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Now let's plot the predicted values against the observed ones" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 0, 'Predicted Age')" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "sns.regplot(y=y_pred, x=y_train, scatter_kws=dict(color='k'))\n", "plt.xlabel('Observed Age')\n", "plt.ylabel('Predicted Age')" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Now that's really cool, eh? **Almost a perfect fit**" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "... which means something is wrong\n", "- what are we missing here?" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- **recall**: we are still evaluating the model on the same training data it was fit on, not on the held-out test set." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "
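Why is a near-perfect training fit suspicious? With far more features (2016) than training samples (116), a linear model can typically fit the training targets almost exactly, whether or not it has learned anything that generalizes. The sketch below illustrates this under stated assumptions: it uses synthetic random data with the same shape as ours, a hypothetical target that depends on only five features plus noise, and a plain minimum-norm least-squares fit (via `numpy.linalg.lstsq`) in place of the SVR, so that it only needs NumPy.

```python
import numpy as np


def r2(y_true, y_pred):
    # coefficient of determination: 1 - SS_res / SS_tot
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


rng = np.random.default_rng(0)
n_samples, n_features = 155, 2016            # same shape as our data set
X = rng.standard_normal((n_samples, n_features))
w_true = np.zeros(n_features)
w_true[:5] = 1.0                             # hypothetical: only 5 informative features
y = X @ w_true + rng.standard_normal(n_samples)

# 116 training / 39 test samples, as in our split
idx = rng.permutation(n_samples)
train, test = idx[:116], idx[116:]

# minimum-norm least-squares fit on the training samples only
w_hat, *_ = np.linalg.lstsq(X[train], y[train], rcond=None)

r2_train = r2(y[train], X[train] @ w_hat)
r2_test = r2(y[test], X[test] @ w_hat)
print(f'train R2 = {r2_train:.3f} | test R2 = {r2_test:.3f}')
```

The training R-squared should come out at (numerically) 1, while the held-out R-squared should be far lower; the exact test value depends on the random seed.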
\"logo\"
\n", "\n", "
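The next section stratifies the split by `AgeGroup`. As a rough illustration of what stratification does, here is a minimal plain-Python sketch that draws the same fraction of samples from every class for the test set. The labels (80 children, 20 adults) and the `stratified_split` helper are hypothetical, and the rounding scheme is a simplification rather than scikit-learn's exact allocation.

```python
import random
from collections import defaultdict


def stratified_split(labels, test_size=0.25, seed=0):
    # group the sample indices by class label
    by_class = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)
    rng = random.Random(seed)
    train_idx, test_idx = [], []
    for label in sorted(by_class):
        idx = by_class[label]
        rng.shuffle(idx)
        # draw the same fraction of every class for the test set
        n_test = round(len(idx) * test_size)
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return sorted(train_idx), sorted(test_idx)


# hypothetical labels: 80 children and 20 adults
labels = ['child'] * 80 + ['adult'] * 20
train_idx, test_idx = stratified_split(labels)
for name, idx in [('train', train_idx), ('test', test_idx)]:
    n_adult = sum(labels[i] == 'adult' for i in idx)
    print(name, len(idx), 'samples, adult fraction =', n_adult / len(idx))
```

Both sets end up with the same adult fraction (0.2) as the full sample, which is what passing `stratify=age_class2` to `train_test_split` aims for.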
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Train/test stratification\n", "\n", "Now let's do this again, but this time we'll add some constraints to the split\n", "- We'll keep the 75/25 ratio between training and test sets\n", "- But now we will try to keep the characteristics of the data set consistent across the training and test sets\n", "- For this we will use something called [stratification](https://en.wikipedia.org/wiki/Stratified_sampling)\n", "- Note that we re-split the previous *training* set (`X_train`, `y_train`); the resulting test set overwrites `X_test` and `y_test`" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "# use `AgeGroup` for stratification\n", "age_class2 = info.loc[y_train.index,'AgeGroup']\n", "\n", "# split the data\n", "X_train2, X_test, y_train2, y_test = train_test_split(\n", " X_train, # x\n", " y_train, # y\n", " test_size = 0.25, # 75%/25% split \n", " shuffle = True, # shuffle dataset before splitting\n", " stratify = age_class2, # keep distribution of age class consistent\n", " # between train & test sets\n", " random_state = 0 # same shuffle each time\n", ")" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Let's re-fit the model on the newly computed (and stratified) training data and evaluate its performance on the (also stratified) test data\n", "- We'll again compute the coefficient of determination (R-squared) to evaluate the model's performance,\n", "- but we'll also look at the [mean absolute error](https://en.wikipedia.org/wiki/Mean_absolute_error) (MAE), i.e., the average of the absolute differences between predictions and actual observations. Unlike the [mean squared error](https://en.wikipedia.org/wiki/Mean_squared_error), MAE does not square the deviations, which makes it more robust to outliers\n", " - it tells us how far off our predictions are from the actual data, in the original units of the target (here, years of age)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "from sklearn.metrics import mean_absolute_error\n", "\n", "# fit the model to the (stratified) training data only\n", "lin_svr.fit(X_train2, y_train2)\n", "\n", "# predict the *test* data based on the model trained on X_train2\n", "y_pred = lin_svr.predict(X_test) \n", "\n", "# calculate R-squared and MAE on the test data\n", "acc = lin_svr.score(X_test, y_test) \n", "mae = mean_absolute_error(y_true=y_test,y_pred=y_pred)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Let's check the results" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy (R2) = 0.6593855081680796\n", "MAE = 3.2059201603105882\n" ] } ], "source": [ "# print results\n", "print('accuracy (R2) = ', acc)\n", "print('MAE = ', mae)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "scrolled": true, "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 0, 'Predicted Age')" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# plot results\n", "sns.regplot(x=y_pred,y=y_test, scatter_kws=dict(color='k'))\n", "plt.ylabel('Observed Age')\n", "plt.xlabel('Predicted Age')" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### [Cross-validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics))\n", "\n", "Not perfect, but it's not bad, as far as predicting with unseen data goes, especially with a training sample of \"only\" 87 subjects.\n", "\n", "- But can we do better?\n", "- One thing we could do is use more of our data for training, while also getting an estimate that depends less on one particular split, by using 10-fold **cross-validation** instead" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "
\"logo\"
\n", " \n", "
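To make the procedure described next concrete, here is a minimal plain-Python sketch of k-fold cross-validation. The helpers `kfold_indices`, `cross_validate`, `fit_mean`, `predict_mean` and `mae` are hypothetical, and the mean-predicting baseline stands in for the SVR. The folds are contiguous and unshuffled, with sizes differing by at most one, which, as far as we understand, mirrors what scikit-learn's `KFold` does when `cross_val_score` is called with `cv=10` on a regressor.

```python
def kfold_indices(n_samples, n_folds):
    # contiguous, unshuffled folds; the first n_samples % n_folds folds get one extra sample
    base, extra = divmod(n_samples, n_folds)
    folds, start = [], 0
    for k in range(n_folds):
        size = base + (1 if k < extra else 0)
        folds.append(list(range(start, start + size)))
        start += size
    return folds


def cross_validate(fit, predict, score, X, y, n_folds=10):
    # train on k-1 folds, evaluate on the held-out fold, repeat for every fold
    scores = []
    for test_idx in kfold_indices(len(y), n_folds):
        held_out = set(test_idx)
        train_idx = [i for i in range(len(y)) if i not in held_out]
        model = fit([X[i] for i in train_idx], [y[i] for i in train_idx])
        y_pred = predict(model, [X[i] for i in test_idx])
        scores.append(score([y[i] for i in test_idx], y_pred))
    return scores


def mae(y_true, y_pred):
    # mean absolute error: average of |observed - predicted|
    return sum(abs(a - b) for a, b in zip(y_true, y_pred)) / len(y_true)


def fit_mean(X_tr, y_tr):
    # hypothetical baseline model: it only remembers the training-set mean
    return sum(y_tr) / len(y_tr)


def predict_mean(model, X_te):
    return [model] * len(X_te)


# toy data with the same number of samples as X_train (116)
X = [[float(v)] for v in range(116)]
y = [float(v) for v in range(116)]
scores = cross_validate(fit_mean, predict_mean, mae, X, y, n_folds=10)
print('per-fold MAE:', [round(s, 2) for s in scores])
print('mean MAE:', sum(scores) / len(scores))
```

Unlike scikit-learn's `neg_mean_absolute_error` scorer, this sketch returns the MAE itself (positive, lower is better).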
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Cross-validation is a technique for obtaining a more reliable estimate of how well a predictive model generalizes, guarding against overly optimistic or split-dependent estimates\n", "- it is particularly useful in cases where the amount of data is limited\n", "- basic idea: you partition the data into a fixed number of folds (k); for each fold, you train the model on the remaining k-1 folds and evaluate it on the held-out fold; finally, you average the k performance estimates" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Let's look at the model's performance across 10 folds" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "# import modules needed for cross-validation\n", "from sklearn.model_selection import cross_val_predict, cross_val_score\n", "\n", "# out-of-fold predictions: each sample is predicted by a model that did not see it during training\n", "y_pred = cross_val_predict(lin_svr, X_train, y_train, cv=10)\n", "# per-fold scores: R-squared (the default score for regressors) and MAE\n", "acc = cross_val_score(lin_svr, X_train, y_train, cv=10)\n", "mae = cross_val_score(lin_svr, X_train, y_train, cv=10, \n", " scoring='neg_mean_absolute_error')\n", "# scikit-learn scorers follow a 'greater is better' convention,\n", "# so error metrics such as the MAE (a positive quantity, lower is better)\n", "# are returned negated; flip the sign to recover the MAE" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "scrolled": true, "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fold 0 -- Acc = 0.419, MAE = 5.271\n", "Fold 1 -- Acc = 0.11, MAE = 3.069\n", "Fold 2 -- Acc = 0.79, MAE = 2.275\n", "Fold 3 -- Acc = 0.809, MAE = 3.061\n", "Fold 4 -- Acc = 0.641, MAE = 3.906\n", "Fold 5 -- Acc = 0.195, MAE = 4.732\n", "Fold 6 -- Acc = 0.684, MAE = 3.974\n", "Fold 7 -- Acc = 0.815, MAE = 2.693\n", "Fold 8 -- Acc = 0.058, MAE = 5.525\n", "Fold 9 -- Acc = 0.698, MAE = 2.571\n" ] } ], "source": [ "# print the results for each fold\n", "for i in range(10):\n", " print(\n", " 'Fold {} -- Acc = {}, MAE = {}'.format(i, np.round(acc[i], 3), np.round(-mae[i], 3))\n", " )" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "For the visually oriented among us" ] }, { "cell_type": "code", "execution_count": 84, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'Metric score [i.e., R-squared 0 to 1]')" ] }, "execution_count": 84, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig = plt.figure(figsize=(8, 6))\n", "plt.plot(acc, label = 'R-squared')\n", "plt.plot(-mae, label = 'MAE')\n", "plt.legend(prop={'size': 12}, loc=9)\n", "plt.xlabel('Fold [0 to 9]')\n", "plt.ylabel('Metric score (R-squared: 0 to 1; MAE: years)')" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "We can also compute the **overall** R-squared and MAE from the pooled out-of-fold predictions (note that this is not exactly the same as averaging the per-fold scores)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "R2: 0.5544107024745264\n", "MAE: 3.7082530497804926\n" ] } ], "source": [ "from sklearn.metrics import r2_score\n", "\n", "overall_acc = r2_score(y_train, y_pred)\n", "overall_mae = mean_absolute_error(y_train, y_pred)\n", "\n", "print('R2:', overall_acc)\n", "print('MAE:', overall_mae)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Now, let's plot the out-of-fold predictions against the observed ages" ] }, { "cell_type": "code", "execution_count": 86, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'Predicted Age')" ] }, "execution_count": 86, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "sns.regplot(x=y_train, y=y_pred, scatter_kws=dict(color='k'))\n", "plt.xlabel('Observed Age')\n", "plt.ylabel('Predicted Age')" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Summary\n", "\n", "Not bad, not bad at all.\n", "\n", "But **most importantly**\n", "- this is a more **reliable estimate** of our model's predictive performance, since every sample is used for testing exactly once instead of relying on a single split." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Peer Herholz (he/him)](https://peerherholz.github.io/) \n", "Research affiliate - [NeuroDataScience lab](https://neurodatascience.github.io/) at [MNI](https://www.mcgill.ca/neuro/)/[McGill](https://www.mcgill.ca/) (& [UNIQUE](https://sites.google.com/view/unique-neuro-ai)) & [Senseable Intelligence Group](https://sensein.group/) at [McGovern Institute for Brain Research](https://mcgovern.mit.edu/)/[MIT](https://www.mit.edu/) \n", "Member - [BIDS](https://bids-specification.readthedocs.io/en/stable/), [ReproNim](https://www.repronim.org/), [Brainhack](https://brainhack.org/), [Neuromod](https://www.cneuromod.ca/), [OHBM SEA-SIG](https://ohbm-environment.org/) \n", "\n", "@peerherholz" ] } ], "metadata": { "anaconda-cloud": {}, "celltoolbar": "Slideshow", "kernelspec": { "display_name": "neuro_ai", "language": "python", "name": "neuro_ai" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" } }, "nbformat": 4, "nbformat_minor": 2 }