Add documentation for classifier training
This commit is contained in:
parent b3abe1977b
commit d7dd335971
@@ -1,5 +1,29 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "93e87d64",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Table of contents\n",
|
||||
"1. [Introduction](#introduction)\n",
|
||||
"2. [Specify data transforms](#transforms)\n",
|
||||
"3. [Load the dataset](#load)\n",
|
||||
"4. [Perform classifications](#train)\n",
|
||||
"5. [Evaluate with Grad-CAM](#eval)\n",
|
||||
"6. [Evaluate train metrics](#trainmetrics)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "42857e8d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Introduction <a name=\"introduction\"></a>\n",
|
||||
"\n",
"This notebook loads the plant dataset, consisting of 452 healthy and 452 stressed images, from disk and trains a resnet50 classifier on it. Training metrics from the original run are loaded from disk, and plots visualizing the classifier's performance are created."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
@@ -7,8 +31,6 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from __future__ import print_function, division\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.optim as optim\n",
|
||||
@@ -16,13 +38,11 @@
|
||||
"import torch.backends.cudnn as cudnn\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import seaborn as sns\n",
|
||||
"import torchvision\n",
|
||||
"from sklearn.model_selection import KFold\n",
|
||||
"import fiftyone as fo\n",
|
||||
"import fiftyone.brain as fob\n",
|
||||
"from torchvision import datasets, models, transforms\n",
|
||||
"from torchvision import datasets, transforms\n",
|
||||
"from torchvision.models import resnet50, ResNet50_Weights\n",
|
||||
"from torch.utils.data import Dataset, DataLoader,TensorDataset,random_split,SubsetRandomSampler\n",
|
||||
"from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import time\n",
|
||||
"import os\n",
|
||||
@@ -30,6 +50,18 @@
|
||||
"import random"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "455d2362",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Specify data transforms <a name=\"transforms\"></a>\n",
|
||||
"\n",
"Before we load the dataset, we define two types of data transforms: one for the training set and one for the validation set. The training set is augmented with a random resized crop to the image size the resnet50 model expects and is also randomly flipped horizontally. Each image is then converted into a tensor and normalized with the ImageNet mean and standard deviation for each channel.\n",
"\n",
"The validation set should not contain horizontally flipped or randomly cropped images, so its images are only resized and cropped to fit the model's expected inputs. Normalization still has to be applied."
|
||||
]
|
||||
},
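The transform code itself sits in a collapsed hunk below. As a rough sketch of what the two pipelines described above could look like, assuming the usual 224×224 input size and the standard ImageNet statistics (the same mean/std values that appear in the notebook's old `imshow` helper):

```python
from torchvision import transforms

# Standard ImageNet statistics (assumed here; they also appear in the imshow helper).
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

data_transforms = {
    # Training: random resized crop + random horizontal flip for augmentation.
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]),
    # Validation: deterministic resize + crop only, then the same normalization.
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]),
}
```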
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
@@ -57,23 +89,13 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "a779e636",
|
||||
"cell_type": "markdown",
|
||||
"id": "1e8b5b7d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def imshow(inp, title=None):\n",
|
||||
" \"\"\"Imshow for Tensor.\"\"\"\n",
|
||||
" inp = inp.numpy().transpose((1, 2, 0))\n",
|
||||
" mean = np.array([0.485, 0.456, 0.406])\n",
|
||||
" std = np.array([0.229, 0.224, 0.225])\n",
|
||||
" inp = std * inp + mean\n",
|
||||
" inp = np.clip(inp, 0, 1)\n",
|
||||
" plt.imshow(inp)\n",
|
||||
" if title is not None:\n",
|
||||
" plt.title(title)\n",
|
||||
" plt.pause(0.001) # pause a bit so that plots are updated"
|
||||
"## Load the dataset <a name=\"load\"></a>\n",
|
||||
"\n",
"The directory `plantsdata` contains two folders named `healthy` and `wilted`. This directory is loaded into an `ImageFolder` dataset and then split with `random_split()` into a 90/10 train/val split. The respective transforms are assigned to the two subsets and data loaders are created."
|
||||
]
|
||||
},
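A minimal sketch of this loading step. The folder name, the 90/10 split and the batch size come from the cells below; the `TransformedSubset` helper is a hypothetical way of attaching a different transform to each `random_split()` subset, and the hidden code may do this differently:

```python
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets, transforms

class TransformedSubset(Dataset):
    """Hypothetical helper: applies a transform to a Subset returned by random_split()."""
    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        image, label = self.subset[idx]
        return self.transform(image), label

# Minimal transforms for the sketch; the notebook uses the augmented/normalized
# pipelines from the "Specify data transforms" section.
train_tf = transforms.Compose([transforms.RandomResizedCrop(224), transforms.ToTensor()])
val_tf = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()])

data_dir = 'plantsdata'                      # contains the 'healthy' and 'wilted' folders
dataset = datasets.ImageFolder(data_dir)

# 90/10 train/val split, as in the cell below.
train_split, val_split = random_split(dataset, [0.9, 0.1])
train_dataset = TransformedSubset(train_split, train_tf)
val_dataset = TransformedSubset(val_split, val_tf)

batch_size = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
```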
|
||||
{
|
||||
@@ -90,7 +112,7 @@
|
||||
"data_dir = 'plantsdata'\n",
|
||||
"dataset = datasets.ImageFolder(os.path.join(data_dir))\n",
|
||||
"\n",
|
||||
"# 80/20 split\n",
|
||||
"# 90/10 split\n",
|
||||
"train_dataset, val_dataset = random_split(dataset, [0.9, 0.1])\n",
|
||||
"\n",
|
||||
"dataset_size = len(dataset)\n",
|
||||
@@ -113,6 +135,18 @@
|
||||
"batch_size = 4"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "41937228",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Perform classifications <a name=\"train\"></a>\n",
|
||||
"\n",
"This function takes a model, a loss function (`criterion`), an optimizer (SGD in this case) and a scheduler which lowers the learning rate every 7 epochs to aid in finding minima. All of these parameters can be customized to use a different loss function, optimizer, and so on. Training runs for `num_epochs` epochs, and the model weights from the epoch with the best validation accuracy are always kept. As soon as training is finished, the model with the best validation accuracy is returned.\n",
"\n",
"Metrics recorded during training (`train_acc`, `train_loss`, `val_acc`, `val_loss`) are saved to a dictionary called `history`. This dictionary has to be initialized beforehand with empty lists for each of the four metrics."
|
||||
]
|
||||
},
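The body of `train_model()` is hidden in the collapsed hunk below. A condensed sketch of such a loop, assuming module-level `dataloaders`, `dataset_sizes`, `device` and `history` objects like the ones this notebook appears to use, could look like this:

```python
import copy
import torch

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Sketch: train/validate for num_epochs and keep the best-validation-accuracy weights."""
    best_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:        # assumed dict of DataLoaders
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()

                # Only track gradients during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]   # assumed dict of split sizes
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            history[phase + '_loss'].append(epoch_loss)
            history[phase + '_acc'].append(epoch_acc)

            # Remember the weights whenever validation accuracy improves.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_wts = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_wts)
    return model
```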
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
@@ -193,237 +227,26 @@
|
||||
" return model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1a1e438a",
|
||||
"metadata": {},
|
||||
"source": [
"Here we create the empty history dictionary, load the pre-trained resnet50 from `torchvision.models` and replace its final fully connected layer with a linear layer matching the number of classes we have (2). The loss function is set to `CrossEntropyLoss()` and the optimizer is SGD with a learning rate of 0.001. The scheduler decays the learning rate by a factor of 0.1 every 7 epochs."
|
||||
]
|
||||
},
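For reference, the pieces of this setup that are split across the collapsed hunk below amount to roughly the following. `momentum=0.9` is an assumption; the exact SGD arguments sit in the hidden lines:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision.models import resnet50, ResNet50_Weights

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

model_ft = resnet50(weights=ResNet50_Weights.DEFAULT)   # pre-trained backbone
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)                    # 2 classes: healthy / wilted
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()
# momentum=0.9 is an assumption; the exact arguments are in the collapsed hunk.
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay the learning rate by a factor of 0.1 every 7 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=50)
```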
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "be15db0c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 0/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.5582 Acc: 0.6978\n",
|
||||
"val Loss: 0.5307 Acc: 0.7667\n",
|
||||
"\n",
|
||||
"Epoch 1/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.3797 Acc: 0.8268\n",
|
||||
"val Loss: 0.4450 Acc: 0.8222\n",
|
||||
"\n",
|
||||
"Epoch 2/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.2861 Acc: 0.8845\n",
|
||||
"val Loss: 0.3655 Acc: 0.8333\n",
|
||||
"\n",
|
||||
"Epoch 3/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.2418 Acc: 0.8894\n",
|
||||
"val Loss: 0.3731 Acc: 0.8333\n",
|
||||
"\n",
|
||||
"Epoch 4/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.1718 Acc: 0.9361\n",
|
||||
"val Loss: 0.3951 Acc: 0.8667\n",
|
||||
"\n",
|
||||
"Epoch 5/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.1248 Acc: 0.9595\n",
|
||||
"val Loss: 0.3524 Acc: 0.8333\n",
|
||||
"\n",
|
||||
"Epoch 6/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0979 Acc: 0.9668\n",
|
||||
"val Loss: 0.2959 Acc: 0.8556\n",
|
||||
"\n",
|
||||
"Epoch 7/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0772 Acc: 0.9828\n",
|
||||
"val Loss: 0.3463 Acc: 0.8444\n",
|
||||
"\n",
|
||||
"Epoch 8/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0639 Acc: 0.9791\n",
|
||||
"val Loss: 0.2962 Acc: 0.8667\n",
|
||||
"\n",
|
||||
"Epoch 9/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0633 Acc: 0.9791\n",
|
||||
"val Loss: 0.3204 Acc: 0.8444\n",
|
||||
"\n",
|
||||
"Epoch 10/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0621 Acc: 0.9791\n",
|
||||
"val Loss: 0.2782 Acc: 0.8667\n",
|
||||
"\n",
|
||||
"Epoch 11/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0621 Acc: 0.9803\n",
|
||||
"val Loss: 0.3503 Acc: 0.8667\n",
|
||||
"\n",
|
||||
"Epoch 12/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0585 Acc: 0.9828\n",
|
||||
"val Loss: 0.2996 Acc: 0.8667\n",
|
||||
"\n",
|
||||
"Epoch 13/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0574 Acc: 0.9828\n",
|
||||
"val Loss: 0.2762 Acc: 0.8778\n",
|
||||
"\n",
|
||||
"Epoch 14/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0370 Acc: 0.9939\n",
|
||||
"val Loss: 0.2943 Acc: 0.8778\n",
|
||||
"\n",
|
||||
"Epoch 15/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0406 Acc: 0.9914\n",
|
||||
"val Loss: 0.3017 Acc: 0.8556\n",
|
||||
"\n",
|
||||
"Epoch 16/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0477 Acc: 0.9840\n",
|
||||
"val Loss: 0.3728 Acc: 0.8556\n",
|
||||
"\n",
|
||||
"Epoch 17/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0431 Acc: 0.9877\n",
|
||||
"val Loss: 0.2876 Acc: 0.8778\n",
|
||||
"\n",
|
||||
"Epoch 18/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0478 Acc: 0.9889\n",
|
||||
"val Loss: 0.2877 Acc: 0.8778\n",
|
||||
"\n",
|
||||
"Epoch 19/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0452 Acc: 0.9914\n",
|
||||
"val Loss: 0.3009 Acc: 0.9000\n",
|
||||
"\n",
|
||||
"Epoch 20/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0417 Acc: 0.9914\n",
|
||||
"val Loss: 0.3057 Acc: 0.8667\n",
|
||||
"\n",
|
||||
"Epoch 21/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0389 Acc: 0.9926\n",
|
||||
"val Loss: 0.3701 Acc: 0.8667\n",
|
||||
"\n",
|
||||
"Epoch 22/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0469 Acc: 0.9840\n",
|
||||
"val Loss: 0.2676 Acc: 0.9111\n",
|
||||
"\n",
|
||||
"Epoch 23/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0496 Acc: 0.9889\n",
|
||||
"val Loss: 0.3054 Acc: 0.8667\n",
|
||||
"\n",
|
||||
"Epoch 24/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0511 Acc: 0.9902\n",
|
||||
"val Loss: 0.3788 Acc: 0.8778\n",
|
||||
"\n",
|
||||
"Epoch 25/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0408 Acc: 0.9914\n",
|
||||
"val Loss: 0.3294 Acc: 0.8556\n",
|
||||
"\n",
|
||||
"Epoch 26/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0433 Acc: 0.9926\n",
|
||||
"val Loss: 0.3117 Acc: 0.8556\n",
|
||||
"\n",
|
||||
"Epoch 27/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0404 Acc: 0.9926\n",
|
||||
"val Loss: 0.3222 Acc: 0.8667\n",
|
||||
"\n",
|
||||
"Epoch 28/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0465 Acc: 0.9877\n",
|
||||
"val Loss: 0.3083 Acc: 0.8556\n",
|
||||
"\n",
|
||||
"Epoch 29/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0523 Acc: 0.9853\n",
|
||||
"val Loss: 0.2602 Acc: 0.8889\n",
|
||||
"\n",
|
||||
"Epoch 30/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0347 Acc: 0.9951\n",
|
||||
"val Loss: 0.3026 Acc: 0.8667\n",
|
||||
"\n",
|
||||
"Epoch 31/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0554 Acc: 0.9828\n",
|
||||
"val Loss: 0.3376 Acc: 0.8556\n",
|
||||
"\n",
|
||||
"Epoch 32/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0522 Acc: 0.9877\n",
|
||||
"val Loss: 0.3147 Acc: 0.8667\n",
|
||||
"\n",
|
||||
"Epoch 33/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0508 Acc: 0.9914\n",
|
||||
"val Loss: 0.2795 Acc: 0.8667\n",
|
||||
"\n",
|
||||
"Epoch 34/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0362 Acc: 0.9926\n",
|
||||
"val Loss: 0.2940 Acc: 0.8667\n",
|
||||
"\n",
|
||||
"Epoch 35/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0421 Acc: 0.9877\n",
|
||||
"val Loss: 0.3091 Acc: 0.8667\n",
|
||||
"\n",
|
||||
"Epoch 36/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0389 Acc: 0.9975\n",
|
||||
"val Loss: 0.2861 Acc: 0.8667\n",
|
||||
"\n",
|
||||
"Epoch 37/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0392 Acc: 0.9939\n",
|
||||
"val Loss: 0.3131 Acc: 0.8889\n",
|
||||
"\n",
|
||||
"Epoch 38/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0420 Acc: 0.9877\n",
|
||||
"val Loss: 0.2981 Acc: 0.8556\n",
|
||||
"\n",
|
||||
"Epoch 39/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0440 Acc: 0.9902\n",
|
||||
"val Loss: 0.3115 Acc: 0.8889\n",
|
||||
"\n",
|
||||
"Epoch 40/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0436 Acc: 0.9877\n",
|
||||
"val Loss: 0.2774 Acc: 0.8556\n",
|
||||
"\n",
|
||||
"Epoch 41/99\n",
|
||||
"----------\n",
|
||||
"train Loss: 0.0440 Acc: 0.9889\n",
|
||||
"val Loss: 0.3586 Acc: 0.8556\n",
|
||||
"\n",
|
||||
"Epoch 42/99\n",
|
||||
"----------\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"history = {'train_loss': [], 'val_loss': [], 'train_acc':[], 'val_acc':[]}\n",
|
||||
"model_ft = resnet50(weights=ResNet50_Weights.DEFAULT)\n",
|
||||
"num_ftrs = model_ft.fc.in_features\n",
|
||||
"# Here the size of each output sample is set to 2.\n",
|
||||
"# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).\n",
|
||||
"\n",
|
||||
"# Add linear layer with number of classes\n",
|
||||
"model_ft.fc = nn.Linear(num_ftrs, 2)\n",
|
||||
"\n",
|
||||
"model_ft = model_ft.to(device)\n",
|
||||
@@ -437,7 +260,15 @@
|
||||
"exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)\n",
|
||||
"\n",
|
||||
"model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,\n",
|
||||
" num_epochs=100)"
|
||||
" num_epochs=50)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9af07909",
|
||||
"metadata": {},
|
||||
"source": [
"Some metrics in the `history` dict are still tensors that live on the GPU. The following cell creates a second dictionary, `history_new`, which contains only the plain scalar values."
|
||||
]
|
||||
},
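The conversion itself is in a collapsed hunk; one way to do it, as a sketch assuming the accuracy entries are 0-dimensional tensors and the losses are plain floats, is:

```python
import torch

history_new = {}
for key, values in history.items():
    # Unwrap any tensor entries (the accuracies live on the GPU) into plain Python floats.
    history_new[key] = [v.item() if isinstance(v, torch.Tensor) else float(v) for v in values]
```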
|
||||
{
|
||||
@@ -458,456 +289,25 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 67,
|
||||
"execution_count": null,
|
||||
"id": "f4177ed1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>train_loss</th>\n",
|
||||
" <th>val_loss</th>\n",
|
||||
" <th>train_acc</th>\n",
|
||||
" <th>val_acc</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>0.576820</td>\n",
|
||||
" <td>0.430790</td>\n",
|
||||
" <td>0.685306</td>\n",
|
||||
" <td>0.859259</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>0.390554</td>\n",
|
||||
" <td>0.330808</td>\n",
|
||||
" <td>0.834850</td>\n",
|
||||
" <td>0.851852</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>0.276992</td>\n",
|
||||
" <td>0.311592</td>\n",
|
||||
" <td>0.892068</td>\n",
|
||||
" <td>0.888889</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>0.199594</td>\n",
|
||||
" <td>0.279608</td>\n",
|
||||
" <td>0.915475</td>\n",
|
||||
" <td>0.903704</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>0.152077</td>\n",
|
||||
" <td>0.272182</td>\n",
|
||||
" <td>0.941482</td>\n",
|
||||
" <td>0.888889</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5</th>\n",
|
||||
" <td>0.139206</td>\n",
|
||||
" <td>0.324749</td>\n",
|
||||
" <td>0.945384</td>\n",
|
||||
" <td>0.881481</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>6</th>\n",
|
||||
" <td>0.089517</td>\n",
|
||||
" <td>0.260757</td>\n",
|
||||
" <td>0.976593</td>\n",
|
||||
" <td>0.911111</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>7</th>\n",
|
||||
" <td>0.062301</td>\n",
|
||||
" <td>0.231712</td>\n",
|
||||
" <td>0.984395</td>\n",
|
||||
" <td>0.911111</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>8</th>\n",
|
||||
" <td>0.065700</td>\n",
|
||||
" <td>0.212535</td>\n",
|
||||
" <td>0.985696</td>\n",
|
||||
" <td>0.903704</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>9</th>\n",
|
||||
" <td>0.059114</td>\n",
|
||||
" <td>0.253683</td>\n",
|
||||
" <td>0.985696</td>\n",
|
||||
" <td>0.874074</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>10</th>\n",
|
||||
" <td>0.055415</td>\n",
|
||||
" <td>0.253353</td>\n",
|
||||
" <td>0.986996</td>\n",
|
||||
" <td>0.911111</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>11</th>\n",
|
||||
" <td>0.045581</td>\n",
|
||||
" <td>0.250287</td>\n",
|
||||
" <td>0.990897</td>\n",
|
||||
" <td>0.896296</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>12</th>\n",
|
||||
" <td>0.043951</td>\n",
|
||||
" <td>0.223316</td>\n",
|
||||
" <td>0.993498</td>\n",
|
||||
" <td>0.918519</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>13</th>\n",
|
||||
" <td>0.057628</td>\n",
|
||||
" <td>0.252697</td>\n",
|
||||
" <td>0.986996</td>\n",
|
||||
" <td>0.911111</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>14</th>\n",
|
||||
" <td>0.041168</td>\n",
|
||||
" <td>0.260735</td>\n",
|
||||
" <td>0.993498</td>\n",
|
||||
" <td>0.881481</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>15</th>\n",
|
||||
" <td>0.038926</td>\n",
|
||||
" <td>0.248857</td>\n",
|
||||
" <td>0.990897</td>\n",
|
||||
" <td>0.903704</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>16</th>\n",
|
||||
" <td>0.047686</td>\n",
|
||||
" <td>0.236497</td>\n",
|
||||
" <td>0.988296</td>\n",
|
||||
" <td>0.896296</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>17</th>\n",
|
||||
" <td>0.044452</td>\n",
|
||||
" <td>0.246685</td>\n",
|
||||
" <td>0.985696</td>\n",
|
||||
" <td>0.911111</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>18</th>\n",
|
||||
" <td>0.042139</td>\n",
|
||||
" <td>0.239525</td>\n",
|
||||
" <td>0.992198</td>\n",
|
||||
" <td>0.911111</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>19</th>\n",
|
||||
" <td>0.034043</td>\n",
|
||||
" <td>0.258453</td>\n",
|
||||
" <td>0.993498</td>\n",
|
||||
" <td>0.903704</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>20</th>\n",
|
||||
" <td>0.055712</td>\n",
|
||||
" <td>0.233888</td>\n",
|
||||
" <td>0.984395</td>\n",
|
||||
" <td>0.925926</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>21</th>\n",
|
||||
" <td>0.036809</td>\n",
|
||||
" <td>0.245363</td>\n",
|
||||
" <td>0.993498</td>\n",
|
||||
" <td>0.888889</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>22</th>\n",
|
||||
" <td>0.036963</td>\n",
|
||||
" <td>0.255232</td>\n",
|
||||
" <td>0.994798</td>\n",
|
||||
" <td>0.896296</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>23</th>\n",
|
||||
" <td>0.039264</td>\n",
|
||||
" <td>0.248907</td>\n",
|
||||
" <td>0.992198</td>\n",
|
||||
" <td>0.903704</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>24</th>\n",
|
||||
" <td>0.040695</td>\n",
|
||||
" <td>0.263689</td>\n",
|
||||
" <td>0.993498</td>\n",
|
||||
" <td>0.881481</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>25</th>\n",
|
||||
" <td>0.046714</td>\n",
|
||||
" <td>0.277939</td>\n",
|
||||
" <td>0.988296</td>\n",
|
||||
" <td>0.881481</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>26</th>\n",
|
||||
" <td>0.044117</td>\n",
|
||||
" <td>0.304999</td>\n",
|
||||
" <td>0.988296</td>\n",
|
||||
" <td>0.896296</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>27</th>\n",
|
||||
" <td>0.042990</td>\n",
|
||||
" <td>0.236786</td>\n",
|
||||
" <td>0.990897</td>\n",
|
||||
" <td>0.940741</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>28</th>\n",
|
||||
" <td>0.037008</td>\n",
|
||||
" <td>0.293796</td>\n",
|
||||
" <td>0.993498</td>\n",
|
||||
" <td>0.888889</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>29</th>\n",
|
||||
" <td>0.033433</td>\n",
|
||||
" <td>0.242197</td>\n",
|
||||
" <td>0.993498</td>\n",
|
||||
" <td>0.918519</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>30</th>\n",
|
||||
" <td>0.038702</td>\n",
|
||||
" <td>0.231981</td>\n",
|
||||
" <td>0.992198</td>\n",
|
||||
" <td>0.896296</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>31</th>\n",
|
||||
" <td>0.044133</td>\n",
|
||||
" <td>0.225656</td>\n",
|
||||
" <td>0.992198</td>\n",
|
||||
" <td>0.911111</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>32</th>\n",
|
||||
" <td>0.045103</td>\n",
|
||||
" <td>0.235305</td>\n",
|
||||
" <td>0.988296</td>\n",
|
||||
" <td>0.925926</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>33</th>\n",
|
||||
" <td>0.048956</td>\n",
|
||||
" <td>0.257209</td>\n",
|
||||
" <td>0.988296</td>\n",
|
||||
" <td>0.896296</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>34</th>\n",
|
||||
" <td>0.048990</td>\n",
|
||||
" <td>0.254565</td>\n",
|
||||
" <td>0.990897</td>\n",
|
||||
" <td>0.896296</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>35</th>\n",
|
||||
" <td>0.050165</td>\n",
|
||||
" <td>0.271521</td>\n",
|
||||
" <td>0.986996</td>\n",
|
||||
" <td>0.881481</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>36</th>\n",
|
||||
" <td>0.048851</td>\n",
|
||||
" <td>0.223436</td>\n",
|
||||
" <td>0.984395</td>\n",
|
||||
" <td>0.918519</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>37</th>\n",
|
||||
" <td>0.037578</td>\n",
|
||||
" <td>0.233252</td>\n",
|
||||
" <td>0.992198</td>\n",
|
||||
" <td>0.896296</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>38</th>\n",
|
||||
" <td>0.056870</td>\n",
|
||||
" <td>0.238889</td>\n",
|
||||
" <td>0.980494</td>\n",
|
||||
" <td>0.911111</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>39</th>\n",
|
||||
" <td>0.030565</td>\n",
|
||||
" <td>0.292963</td>\n",
|
||||
" <td>0.993498</td>\n",
|
||||
" <td>0.896296</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>40</th>\n",
|
||||
" <td>0.040223</td>\n",
|
||||
" <td>0.228972</td>\n",
|
||||
" <td>0.992198</td>\n",
|
||||
" <td>0.903704</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>41</th>\n",
|
||||
" <td>0.041459</td>\n",
|
||||
" <td>0.246075</td>\n",
|
||||
" <td>0.990897</td>\n",
|
||||
" <td>0.903704</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>42</th>\n",
|
||||
" <td>0.037677</td>\n",
|
||||
" <td>0.263450</td>\n",
|
||||
" <td>0.990897</td>\n",
|
||||
" <td>0.903704</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>43</th>\n",
|
||||
" <td>0.047407</td>\n",
|
||||
" <td>0.273632</td>\n",
|
||||
" <td>0.992198</td>\n",
|
||||
" <td>0.888889</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>44</th>\n",
|
||||
" <td>0.047997</td>\n",
|
||||
" <td>0.241980</td>\n",
|
||||
" <td>0.985696</td>\n",
|
||||
" <td>0.911111</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>45</th>\n",
|
||||
" <td>0.043608</td>\n",
|
||||
" <td>0.230462</td>\n",
|
||||
" <td>0.989597</td>\n",
|
||||
" <td>0.911111</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>46</th>\n",
|
||||
" <td>0.048454</td>\n",
|
||||
" <td>0.312999</td>\n",
|
||||
" <td>0.985696</td>\n",
|
||||
" <td>0.866667</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>47</th>\n",
|
||||
" <td>0.057154</td>\n",
|
||||
" <td>0.274626</td>\n",
|
||||
" <td>0.980494</td>\n",
|
||||
" <td>0.903704</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>48</th>\n",
|
||||
" <td>0.050744</td>\n",
|
||||
" <td>0.231532</td>\n",
|
||||
" <td>0.988296</td>\n",
|
||||
" <td>0.903704</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>49</th>\n",
|
||||
" <td>0.049543</td>\n",
|
||||
" <td>0.229090</td>\n",
|
||||
" <td>0.990897</td>\n",
|
||||
" <td>0.903704</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" train_loss val_loss train_acc val_acc\n",
|
||||
"0 0.576820 0.430790 0.685306 0.859259\n",
|
||||
"1 0.390554 0.330808 0.834850 0.851852\n",
|
||||
"2 0.276992 0.311592 0.892068 0.888889\n",
|
||||
"3 0.199594 0.279608 0.915475 0.903704\n",
|
||||
"4 0.152077 0.272182 0.941482 0.888889\n",
|
||||
"5 0.139206 0.324749 0.945384 0.881481\n",
|
||||
"6 0.089517 0.260757 0.976593 0.911111\n",
|
||||
"7 0.062301 0.231712 0.984395 0.911111\n",
|
||||
"8 0.065700 0.212535 0.985696 0.903704\n",
|
||||
"9 0.059114 0.253683 0.985696 0.874074\n",
|
||||
"10 0.055415 0.253353 0.986996 0.911111\n",
|
||||
"11 0.045581 0.250287 0.990897 0.896296\n",
|
||||
"12 0.043951 0.223316 0.993498 0.918519\n",
|
||||
"13 0.057628 0.252697 0.986996 0.911111\n",
|
||||
"14 0.041168 0.260735 0.993498 0.881481\n",
|
||||
"15 0.038926 0.248857 0.990897 0.903704\n",
|
||||
"16 0.047686 0.236497 0.988296 0.896296\n",
|
||||
"17 0.044452 0.246685 0.985696 0.911111\n",
|
||||
"18 0.042139 0.239525 0.992198 0.911111\n",
|
||||
"19 0.034043 0.258453 0.993498 0.903704\n",
|
||||
"20 0.055712 0.233888 0.984395 0.925926\n",
|
||||
"21 0.036809 0.245363 0.993498 0.888889\n",
|
||||
"22 0.036963 0.255232 0.994798 0.896296\n",
|
||||
"23 0.039264 0.248907 0.992198 0.903704\n",
|
||||
"24 0.040695 0.263689 0.993498 0.881481\n",
|
||||
"25 0.046714 0.277939 0.988296 0.881481\n",
|
||||
"26 0.044117 0.304999 0.988296 0.896296\n",
|
||||
"27 0.042990 0.236786 0.990897 0.940741\n",
|
||||
"28 0.037008 0.293796 0.993498 0.888889\n",
|
||||
"29 0.033433 0.242197 0.993498 0.918519\n",
|
||||
"30 0.038702 0.231981 0.992198 0.896296\n",
|
||||
"31 0.044133 0.225656 0.992198 0.911111\n",
|
||||
"32 0.045103 0.235305 0.988296 0.925926\n",
|
||||
"33 0.048956 0.257209 0.988296 0.896296\n",
|
||||
"34 0.048990 0.254565 0.990897 0.896296\n",
|
||||
"35 0.050165 0.271521 0.986996 0.881481\n",
|
||||
"36 0.048851 0.223436 0.984395 0.918519\n",
|
||||
"37 0.037578 0.233252 0.992198 0.896296\n",
|
||||
"38 0.056870 0.238889 0.980494 0.911111\n",
|
||||
"39 0.030565 0.292963 0.993498 0.896296\n",
|
||||
"40 0.040223 0.228972 0.992198 0.903704\n",
|
||||
"41 0.041459 0.246075 0.990897 0.903704\n",
|
||||
"42 0.037677 0.263450 0.990897 0.903704\n",
|
||||
"43 0.047407 0.273632 0.992198 0.888889\n",
|
||||
"44 0.047997 0.241980 0.985696 0.911111\n",
|
||||
"45 0.043608 0.230462 0.989597 0.911111\n",
|
||||
"46 0.048454 0.312999 0.985696 0.866667\n",
|
||||
"47 0.057154 0.274626 0.980494 0.903704\n",
|
||||
"48 0.050744 0.231532 0.988296 0.903704\n",
|
||||
"49 0.049543 0.229090 0.990897 0.903704"
|
||||
]
|
||||
},
|
||||
"execution_count": 67,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = pd.DataFrame(history_new)\n",
|
||||
"df"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4df67cc5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Evaluate with Grad-CAM <a name=\"eval\"></a>\n",
|
||||
"\n",
"In this section we load the originally trained model (the one which was trained earlier and tested on the Jetson Nano) and run inference on a previously unseen test image which contains a healthy and a stressed plant."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
@@ -931,6 +331,14 @@
|
||||
"import PIL"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "33956afd",
|
||||
"metadata": {},
|
||||
"source": [
"We load the image, transform it with the same normalization as before, run inference and pass the result through a softmax layer to get probabilities for each class. Our test image is classified as 30% healthy and 70% stressed."
|
||||
]
|
||||
},
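A sketch of that inference step. The image path is hypothetical, and the `model_ft` name is carried over from the earlier cells; only the final softmax line is taken verbatim from the code shown below:

```python
import PIL.Image
import torch
from torchvision import transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Same resizing and ImageNet normalization as the validation transform.
eval_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = PIL.Image.open('test-plants.jpg').convert('RGB')   # hypothetical test image path
input_tensor = eval_tf(img).unsqueeze(0).to(device)      # add a batch dimension

model_ft.eval()
with torch.no_grad():
    out = model_ft(input_tensor)

# Percent probabilities per class (index 0 = healthy, index 1 = stressed).
torch.nn.functional.softmax(out, dim=1)[0] * 100
```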
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
@@ -973,6 +381,14 @@
|
||||
"torch.nn.functional.softmax(out, dim=1)[0] * 100\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "96c117d1",
|
||||
"metadata": {},
|
||||
"source": [
"We can then use the image and the prediction with Grad-CAM to evaluate which regions contributed most to either class. The CAM object is created with the resnet's last convolutional layer as the target layer. In `targets` we specify which class we want to generate CAMs for; 0 means healthy and 1 means stressed. The visualization image is then stored in a variable so that we can plot it later."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 47,
|
||||
@@ -995,21 +411,20 @@
|
||||
"targets = [ClassifierOutputTarget(0)]\n",
|
||||
"\n",
|
||||
"# You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.\n",
|
||||
"grayscale_cam = cam(input_tensor=input_tensor, targets=targets, aug_smooth=True)\n",
|
||||
"rgb_cam = cam(input_tensor=input_tensor, targets=targets, aug_smooth=True)\n",
|
||||
"\n",
|
||||
"# In this example grayscale_cam has only one image in the batch:\n",
|
||||
"grayscale_cam = grayscale_cam[0, :]\n",
|
||||
"visualization_healthy = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)\n",
|
||||
"# In this example rgb_cam has only one image in the batch:\n",
|
||||
"rgb_cam = rgb_cam[0, :]\n",
|
||||
"visualization_healthy = show_cam_on_image(rgb_img, rgb_cam, use_rgb=True)\n",
|
||||
"\n",
|
||||
"# Specify target for CAM (0 = healthy, 1 = stressed)\n",
|
||||
"targets = [ClassifierOutputTarget(1)]\n",
|
||||
"\n",
|
||||
"# You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.\n",
|
||||
"grayscale_cam = cam(input_tensor=input_tensor, targets=targets, aug_smooth=True)\n",
|
||||
"rgb_cam = cam(input_tensor=input_tensor, targets=targets, aug_smooth=True)\n",
|
||||
"\n",
|
||||
"# In this example grayscale_cam has only one image in the batch:\n",
|
||||
"grayscale_cam = grayscale_cam[0, :]\n",
|
||||
"visualization_stressed = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)"
|
||||
"rgb_cam = rgb_cam[0, :]\n",
|
||||
"visualization_stressed = show_cam_on_image(rgb_img, rgb_cam, use_rgb=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1066,6 +481,7 @@
|
||||
"ax[0][1].imshow(visualization_healthy)\n",
|
||||
"ax[1][0].imshow(rgb_img)\n",
|
||||
"ax[1][1].imshow(visualization_stressed)\n",
|
||||
"# Omit pixel values for the axis\n",
|
||||
"ax[0][0].axis('off')\n",
|
||||
"ax[0][1].axis('off') \n",
|
||||
"ax[1][0].axis('off') \n",
|
||||
@@ -1074,6 +490,16 @@
|
||||
"fig.savefig(fig_save_dir + 'classifier-cam.pdf', format='pdf', bbox_inches='tight')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4c521849",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Evaluate train metrics <a name=\"trainmetrics\"></a>\n",
|
||||
"\n",
"We define the style of the plots, this time with a grid, and specify the directory to save figures to. Metrics from the original run are loaded from the CSV file, and the accuracy during training and validation is plotted. The second plot visualizes the loss during training for the train and validation sets."
|
||||
]
|
||||
},
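The plotting code is collapsed below. A sketch under the assumption that the original metrics were written to a CSV with the four `history` columns; the CSV filename and `fig_save_dir` value are hypothetical, while the final `savefig` call matches the visible line:

```python
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style('whitegrid')                        # plot style with a grid
fig_save_dir = 'figures/'                         # assumed output directory
os.makedirs(fig_save_dir, exist_ok=True)

# Assumed CSV layout: one row per epoch with the four history columns.
results = pd.read_csv('classifier-history.csv')   # hypothetical filename

fig, ax = plt.subplots(1, 2, figsize=(10, 4))

# Left: accuracy per epoch for the train and validation sets.
ax[0].plot(results['train_acc'], label='train')
ax[0].plot(results['val_acc'], label='val')
ax[0].set_xlabel('epoch')
ax[0].set_ylabel('accuracy')
ax[0].legend()

# Right: loss per epoch for the train and validation sets.
ax[1].plot(results['train_loss'], label='train')
ax[1].plot(results['val_loss'], label='val')
ax[1].set_xlabel('epoch')
ax[1].set_ylabel('loss')
ax[1].legend()

fig.tight_layout()
fig.savefig(fig_save_dir + 'classifier-metrics.pdf', format='pdf', bbox_inches='tight')
```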
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 57,
|
||||
@@ -1125,35 +551,6 @@
|
||||
"fig.tight_layout()\n",
|
||||
"fig.savefig(fig_save_dir + 'classifier-metrics.pdf', format='pdf', bbox_inches='tight')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 75,
|
||||
"id": "61bb2edd",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.84376"
|
||||
]
|
||||
},
|
||||
"execution_count": 75,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"results.iloc[20:35, 4].mean()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "61551dcb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {