{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "b88ce481", "metadata": {}, "outputs": [], "source": [ "from __future__ import print_function, division\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from torch.optim import lr_scheduler\n", "import torch.backends.cudnn as cudnn\n", "import numpy as np\n", "import pandas as pd\n", "import torchvision\n", "from sklearn.model_selection import KFold\n", "import fiftyone as fo\n", "import fiftyone.brain as fob\n", "from torchvision import datasets, models, transforms\n", "from torchvision.models import resnet50, ResNet50_Weights\n", "from torch.utils.data import Dataset, DataLoader,TensorDataset,random_split,SubsetRandomSampler\n", "import matplotlib.pyplot as plt\n", "import time\n", "import os\n", "import copy\n", "import random" ] }, { "cell_type": "code", "execution_count": 2, "id": "d0cc05b1", "metadata": {}, "outputs": [], "source": [ "cudnn.benchmark = True\n", "plt.ion() # interactive mode\n", "\n", "data_transforms = {\n", " 'train': transforms.Compose([\n", " transforms.RandomResizedCrop(224),\n", " transforms.RandomHorizontalFlip(),\n", " transforms.ToTensor(),\n", " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", " ]),\n", " 'val': transforms.Compose([\n", " transforms.Resize(256),\n", " transforms.CenterCrop(224),\n", " transforms.ToTensor(),\n", " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", " ]),\n", "}" ] }, { "cell_type": "code", "execution_count": 3, "id": "a779e636", "metadata": {}, "outputs": [], "source": [ "def imshow(inp, title=None):\n", " \"\"\"Imshow for Tensor.\"\"\"\n", " inp = inp.numpy().transpose((1, 2, 0))\n", " mean = np.array([0.485, 0.456, 0.406])\n", " std = np.array([0.229, 0.224, 0.225])\n", " inp = std * inp + mean\n", " inp = np.clip(inp, 0, 1)\n", " plt.imshow(inp)\n", " if title is not None:\n", " plt.title(title)\n", " plt.pause(0.001) # pause a bit so that plots are 
updated" ] }, { "cell_type": "code", "execution_count": 4, "id": "32d3d5c6", "metadata": {}, "outputs": [], "source": [ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "torch.manual_seed(42)\n", "np.random.seed(42)\n", "\n", "data_dir = 'plantsdata'\n", "dataset = datasets.ImageFolder(os.path.join(data_dir))\n", "\n", "# 80/20 split\n", "train_dataset, val_dataset = random_split(dataset, [0.9, 0.1])\n", "\n", "dataset_size = len(dataset)\n", "\n", "train_dataset.dataset.transform = data_transforms['train']\n", "val_dataset.dataset.transform = data_transforms['val']\n", "\n", "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4,\n", " shuffle=True, num_workers=4)\n", "val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4,\n", " shuffle=True, num_workers=4)\n", "\n", "dataloaders = {'train': train_loader, 'val': val_loader}\n", "\n", "dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}\n", "\n", "class_names = dataset.classes\n", "\n", "num_epochs = 50\n", "batch_size = 4" ] }, { "cell_type": "code", "execution_count": 5, "id": "4b0f5b61", "metadata": {}, "outputs": [], "source": [ "def train_model(model, criterion, optimizer, scheduler, num_epochs=25):\n", " since = time.time()\n", "\n", " best_model_wts = copy.deepcopy(model.state_dict())\n", " best_acc = 0.0\n", "\n", " for epoch in range(num_epochs):\n", " print(f'Epoch {epoch}/{num_epochs - 1}')\n", " print('-' * 10)\n", "\n", " # Each epoch has a training and validation phase\n", " for phase in ['train', 'val']:\n", " if phase == 'train':\n", " model.train() # Set model to training mode\n", " else:\n", " model.eval() # Set model to evaluate mode\n", "\n", " running_loss = 0.0\n", " running_corrects = 0\n", "\n", " # Iterate over data.\n", " for inputs, labels in dataloaders[phase]:\n", " inputs = inputs.to(device)\n", " labels = labels.to(device)\n", "\n", " # zero the parameter gradients\n", " optimizer.zero_grad()\n", "\n", " 
# forward\n", " # track history if only in train\n", " with torch.set_grad_enabled(phase == 'train'):\n", " outputs = model(inputs)\n", " _, preds = torch.max(outputs, 1)\n", " loss = criterion(outputs, labels)\n", "\n", " # backward + optimize only if in training phase\n", " if phase == 'train':\n", " loss.backward()\n", " optimizer.step()\n", "\n", " # statistics\n", " running_loss += loss.item() * inputs.size(0)\n", " running_corrects += torch.sum(preds == labels.data)\n", " if phase == 'train':\n", " scheduler.step()\n", "\n", " epoch_loss = running_loss / dataset_sizes[phase]\n", " epoch_acc = running_corrects.double() / dataset_sizes[phase]\n", "\n", " if phase == 'train':\n", " history['train_loss'].append(epoch_loss)\n", " history['train_acc'].append(epoch_acc)\n", " else:\n", " history['val_loss'].append(epoch_loss)\n", " history['val_acc'].append(epoch_acc)\n", " print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')\n", "\n", " # deep copy the model\n", " if phase == 'val' and epoch_acc > best_acc:\n", " best_acc = epoch_acc\n", " best_model_wts = copy.deepcopy(model.state_dict())\n", "\n", " print()\n", "\n", " time_elapsed = time.time() - since\n", " print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')\n", " print(f'Best val Acc: {best_acc:4f}')\n", "\n", " # load best model weights\n", " model.load_state_dict(best_model_wts)\n", " return model" ] }, { "cell_type": "code", "execution_count": null, "id": "be15db0c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0/99\n", "----------\n", "train Loss: 0.5582 Acc: 0.6978\n", "val Loss: 0.5307 Acc: 0.7667\n", "\n", "Epoch 1/99\n", "----------\n", "train Loss: 0.3797 Acc: 0.8268\n", "val Loss: 0.4450 Acc: 0.8222\n", "\n", "Epoch 2/99\n", "----------\n", "train Loss: 0.2861 Acc: 0.8845\n", "val Loss: 0.3655 Acc: 0.8333\n", "\n", "Epoch 3/99\n", "----------\n", "train Loss: 0.2418 Acc: 0.8894\n", "val Loss: 0.3731 Acc: 0.8333\n", 
"\n", "Epoch 4/99\n", "----------\n", "train Loss: 0.1718 Acc: 0.9361\n", "val Loss: 0.3951 Acc: 0.8667\n", "\n", "Epoch 5/99\n", "----------\n", "train Loss: 0.1248 Acc: 0.9595\n", "val Loss: 0.3524 Acc: 0.8333\n", "\n", "Epoch 6/99\n", "----------\n", "train Loss: 0.0979 Acc: 0.9668\n", "val Loss: 0.2959 Acc: 0.8556\n", "\n", "Epoch 7/99\n", "----------\n", "train Loss: 0.0772 Acc: 0.9828\n", "val Loss: 0.3463 Acc: 0.8444\n", "\n", "Epoch 8/99\n", "----------\n", "train Loss: 0.0639 Acc: 0.9791\n", "val Loss: 0.2962 Acc: 0.8667\n", "\n", "Epoch 9/99\n", "----------\n", "train Loss: 0.0633 Acc: 0.9791\n", "val Loss: 0.3204 Acc: 0.8444\n", "\n", "Epoch 10/99\n", "----------\n", "train Loss: 0.0621 Acc: 0.9791\n", "val Loss: 0.2782 Acc: 0.8667\n", "\n", "Epoch 11/99\n", "----------\n", "train Loss: 0.0621 Acc: 0.9803\n", "val Loss: 0.3503 Acc: 0.8667\n", "\n", "Epoch 12/99\n", "----------\n", "train Loss: 0.0585 Acc: 0.9828\n", "val Loss: 0.2996 Acc: 0.8667\n", "\n", "Epoch 13/99\n", "----------\n", "train Loss: 0.0574 Acc: 0.9828\n", "val Loss: 0.2762 Acc: 0.8778\n", "\n", "Epoch 14/99\n", "----------\n", "train Loss: 0.0370 Acc: 0.9939\n", "val Loss: 0.2943 Acc: 0.8778\n", "\n", "Epoch 15/99\n", "----------\n", "train Loss: 0.0406 Acc: 0.9914\n", "val Loss: 0.3017 Acc: 0.8556\n", "\n", "Epoch 16/99\n", "----------\n", "train Loss: 0.0477 Acc: 0.9840\n", "val Loss: 0.3728 Acc: 0.8556\n", "\n", "Epoch 17/99\n", "----------\n", "train Loss: 0.0431 Acc: 0.9877\n", "val Loss: 0.2876 Acc: 0.8778\n", "\n", "Epoch 18/99\n", "----------\n", "train Loss: 0.0478 Acc: 0.9889\n", "val Loss: 0.2877 Acc: 0.8778\n", "\n", "Epoch 19/99\n", "----------\n", "train Loss: 0.0452 Acc: 0.9914\n", "val Loss: 0.3009 Acc: 0.9000\n", "\n", "Epoch 20/99\n", "----------\n", "train Loss: 0.0417 Acc: 0.9914\n", "val Loss: 0.3057 Acc: 0.8667\n", "\n", "Epoch 21/99\n", "----------\n", "train Loss: 0.0389 Acc: 0.9926\n", "val Loss: 0.3701 Acc: 0.8667\n", "\n", "Epoch 22/99\n", "----------\n", 
"train Loss: 0.0469 Acc: 0.9840\n", "val Loss: 0.2676 Acc: 0.9111\n", "\n", "Epoch 23/99\n", "----------\n", "train Loss: 0.0496 Acc: 0.9889\n", "val Loss: 0.3054 Acc: 0.8667\n", "\n", "Epoch 24/99\n", "----------\n", "train Loss: 0.0511 Acc: 0.9902\n", "val Loss: 0.3788 Acc: 0.8778\n", "\n", "Epoch 25/99\n", "----------\n", "train Loss: 0.0408 Acc: 0.9914\n", "val Loss: 0.3294 Acc: 0.8556\n", "\n", "Epoch 26/99\n", "----------\n", "train Loss: 0.0433 Acc: 0.9926\n", "val Loss: 0.3117 Acc: 0.8556\n", "\n", "Epoch 27/99\n", "----------\n", "train Loss: 0.0404 Acc: 0.9926\n", "val Loss: 0.3222 Acc: 0.8667\n", "\n", "Epoch 28/99\n", "----------\n", "train Loss: 0.0465 Acc: 0.9877\n", "val Loss: 0.3083 Acc: 0.8556\n", "\n", "Epoch 29/99\n", "----------\n", "train Loss: 0.0523 Acc: 0.9853\n", "val Loss: 0.2602 Acc: 0.8889\n", "\n", "Epoch 30/99\n", "----------\n", "train Loss: 0.0347 Acc: 0.9951\n", "val Loss: 0.3026 Acc: 0.8667\n", "\n", "Epoch 31/99\n", "----------\n", "train Loss: 0.0554 Acc: 0.9828\n", "val Loss: 0.3376 Acc: 0.8556\n", "\n", "Epoch 32/99\n", "----------\n", "train Loss: 0.0522 Acc: 0.9877\n", "val Loss: 0.3147 Acc: 0.8667\n", "\n", "Epoch 33/99\n", "----------\n", "train Loss: 0.0508 Acc: 0.9914\n", "val Loss: 0.2795 Acc: 0.8667\n", "\n", "Epoch 34/99\n", "----------\n", "train Loss: 0.0362 Acc: 0.9926\n", "val Loss: 0.2940 Acc: 0.8667\n", "\n", "Epoch 35/99\n", "----------\n", "train Loss: 0.0421 Acc: 0.9877\n", "val Loss: 0.3091 Acc: 0.8667\n", "\n", "Epoch 36/99\n", "----------\n", "train Loss: 0.0389 Acc: 0.9975\n", "val Loss: 0.2861 Acc: 0.8667\n", "\n", "Epoch 37/99\n", "----------\n", "train Loss: 0.0392 Acc: 0.9939\n", "val Loss: 0.3131 Acc: 0.8889\n", "\n", "Epoch 38/99\n", "----------\n", "train Loss: 0.0420 Acc: 0.9877\n", "val Loss: 0.2981 Acc: 0.8556\n", "\n", "Epoch 39/99\n", "----------\n", "train Loss: 0.0440 Acc: 0.9902\n", "val Loss: 0.3115 Acc: 0.8889\n", "\n", "Epoch 40/99\n", "----------\n", "train Loss: 0.0436 Acc: 0.9877\n", 
# Fine-tune an ImageNet-pretrained ResNet-50 with a fresh 2-output head.
history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
model_ft = resnet50(weights=ResNet50_Weights.DEFAULT)
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=100)

# The accuracy entries in `history` are 0-dim (possibly CUDA) tensors;
# convert them to plain Python floats so the history is easy to tabulate/plot.
history_new = {
    'train_loss': history['train_loss'].copy(),
    'val_loss': history['val_loss'].copy(),
    'train_acc': [acc.detach().cpu().item() for acc in history['train_acc']],
    'val_acc': [acc.detach().cpu().item() for acc in history['val_acc']],
}
| \n", " | train_loss | \n", "val_loss | \n", "train_acc | \n", "val_acc | \n", "
|---|---|---|---|---|
| 0 | \n", "0.576820 | \n", "0.430790 | \n", "0.685306 | \n", "0.859259 | \n", "
| 1 | \n", "0.390554 | \n", "0.330808 | \n", "0.834850 | \n", "0.851852 | \n", "
| 2 | \n", "0.276992 | \n", "0.311592 | \n", "0.892068 | \n", "0.888889 | \n", "
| 3 | \n", "0.199594 | \n", "0.279608 | \n", "0.915475 | \n", "0.903704 | \n", "
| 4 | \n", "0.152077 | \n", "0.272182 | \n", "0.941482 | \n", "0.888889 | \n", "
| 5 | \n", "0.139206 | \n", "0.324749 | \n", "0.945384 | \n", "0.881481 | \n", "
| 6 | \n", "0.089517 | \n", "0.260757 | \n", "0.976593 | \n", "0.911111 | \n", "
| 7 | \n", "0.062301 | \n", "0.231712 | \n", "0.984395 | \n", "0.911111 | \n", "
| 8 | \n", "0.065700 | \n", "0.212535 | \n", "0.985696 | \n", "0.903704 | \n", "
| 9 | \n", "0.059114 | \n", "0.253683 | \n", "0.985696 | \n", "0.874074 | \n", "
| 10 | \n", "0.055415 | \n", "0.253353 | \n", "0.986996 | \n", "0.911111 | \n", "
| 11 | \n", "0.045581 | \n", "0.250287 | \n", "0.990897 | \n", "0.896296 | \n", "
| 12 | \n", "0.043951 | \n", "0.223316 | \n", "0.993498 | \n", "0.918519 | \n", "
| 13 | \n", "0.057628 | \n", "0.252697 | \n", "0.986996 | \n", "0.911111 | \n", "
| 14 | \n", "0.041168 | \n", "0.260735 | \n", "0.993498 | \n", "0.881481 | \n", "
| 15 | \n", "0.038926 | \n", "0.248857 | \n", "0.990897 | \n", "0.903704 | \n", "
| 16 | \n", "0.047686 | \n", "0.236497 | \n", "0.988296 | \n", "0.896296 | \n", "
| 17 | \n", "0.044452 | \n", "0.246685 | \n", "0.985696 | \n", "0.911111 | \n", "
| 18 | \n", "0.042139 | \n", "0.239525 | \n", "0.992198 | \n", "0.911111 | \n", "
| 19 | \n", "0.034043 | \n", "0.258453 | \n", "0.993498 | \n", "0.903704 | \n", "
| 20 | \n", "0.055712 | \n", "0.233888 | \n", "0.984395 | \n", "0.925926 | \n", "
| 21 | \n", "0.036809 | \n", "0.245363 | \n", "0.993498 | \n", "0.888889 | \n", "
| 22 | \n", "0.036963 | \n", "0.255232 | \n", "0.994798 | \n", "0.896296 | \n", "
| 23 | \n", "0.039264 | \n", "0.248907 | \n", "0.992198 | \n", "0.903704 | \n", "
| 24 | \n", "0.040695 | \n", "0.263689 | \n", "0.993498 | \n", "0.881481 | \n", "
| 25 | \n", "0.046714 | \n", "0.277939 | \n", "0.988296 | \n", "0.881481 | \n", "
| 26 | \n", "0.044117 | \n", "0.304999 | \n", "0.988296 | \n", "0.896296 | \n", "
| 27 | \n", "0.042990 | \n", "0.236786 | \n", "0.990897 | \n", "0.940741 | \n", "
| 28 | \n", "0.037008 | \n", "0.293796 | \n", "0.993498 | \n", "0.888889 | \n", "
| 29 | \n", "0.033433 | \n", "0.242197 | \n", "0.993498 | \n", "0.918519 | \n", "
| 30 | \n", "0.038702 | \n", "0.231981 | \n", "0.992198 | \n", "0.896296 | \n", "
| 31 | \n", "0.044133 | \n", "0.225656 | \n", "0.992198 | \n", "0.911111 | \n", "
| 32 | \n", "0.045103 | \n", "0.235305 | \n", "0.988296 | \n", "0.925926 | \n", "
| 33 | \n", "0.048956 | \n", "0.257209 | \n", "0.988296 | \n", "0.896296 | \n", "
| 34 | \n", "0.048990 | \n", "0.254565 | \n", "0.990897 | \n", "0.896296 | \n", "
| 35 | \n", "0.050165 | \n", "0.271521 | \n", "0.986996 | \n", "0.881481 | \n", "
| 36 | \n", "0.048851 | \n", "0.223436 | \n", "0.984395 | \n", "0.918519 | \n", "
| 37 | \n", "0.037578 | \n", "0.233252 | \n", "0.992198 | \n", "0.896296 | \n", "
| 38 | \n", "0.056870 | \n", "0.238889 | \n", "0.980494 | \n", "0.911111 | \n", "
| 39 | \n", "0.030565 | \n", "0.292963 | \n", "0.993498 | \n", "0.896296 | \n", "
| 40 | \n", "0.040223 | \n", "0.228972 | \n", "0.992198 | \n", "0.903704 | \n", "
| 41 | \n", "0.041459 | \n", "0.246075 | \n", "0.990897 | \n", "0.903704 | \n", "
| 42 | \n", "0.037677 | \n", "0.263450 | \n", "0.990897 | \n", "0.903704 | \n", "
| 43 | \n", "0.047407 | \n", "0.273632 | \n", "0.992198 | \n", "0.888889 | \n", "
| 44 | \n", "0.047997 | \n", "0.241980 | \n", "0.985696 | \n", "0.911111 | \n", "
| 45 | \n", "0.043608 | \n", "0.230462 | \n", "0.989597 | \n", "0.911111 | \n", "
| 46 | \n", "0.048454 | \n", "0.312999 | \n", "0.985696 | \n", "0.866667 | \n", "
| 47 | \n", "0.057154 | \n", "0.274626 | \n", "0.980494 | \n", "0.903704 | \n", "
| 48 | \n", "0.050744 | \n", "0.231532 | \n", "0.988296 | \n", "0.903704 | \n", "
| 49 | \n", "0.049543 | \n", "0.229090 | \n", "0.990897 | \n", "0.903704 | \n", "