2024-03-14 18:30:11 +01:00

1026 lines
146 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "747ddcf2",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33me1527193\u001b[0m (\u001b[33mflower-classification\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
]
}
],
"source": [
"import ast\n",
"import os\n",
"import random\n",
"import time\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"import torch\n",
"import wandb\n",
"# Log in to Weights & Biases; the run metrics are fetched through its API further down.\n",
"wandb.login()\n",
"\n",
"def set_size(width, fraction=1, subplots=(1, 1)):\n",
"    \"\"\"Compute a figure size that needs no rescaling when included in LaTeX.\n",
"\n",
"    The height is derived from the width via the golden ratio and then\n",
"    scaled by the rows-to-columns ratio of the subplot grid.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    width: float\n",
"        Document textwidth or columnwidth in pts\n",
"    fraction: float, optional\n",
"        Fraction of the width which the figure should occupy\n",
"    subplots: sequence of int, optional\n",
"        Subplot grid as (rows, columns); only the first two entries are read\n",
"\n",
"    Returns\n",
"    -------\n",
"    fig_dim: tuple\n",
"        (width, height) of the figure in inches\n",
"    \"\"\"\n",
"    # One inch is 72.27 pt\n",
"    pt_to_inch = 1 / 72.27\n",
"\n",
"    # Golden ratio for an aesthetic height, see https://disq.us/p/2940ij3\n",
"    golden_ratio = (5 ** 0.5 - 1) / 2\n",
"\n",
"    width_in = (width * fraction) * pt_to_inch\n",
"\n",
"    # Stretch the single-plot height by the shape of the subplot grid\n",
"    n_rows, n_cols = subplots[0], subplots[1]\n",
"    height_in = width_in * golden_ratio * (n_rows / n_cols)\n",
"\n",
"    return (width_in, height_in)\n",
"\n",
"# Seed every random number generator imported above so reruns are reproducible.\n",
"SEED = 42\n",
"random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"np.random.seed(SEED)  # kept last: it returns None, so the cell displays no output"
]
},
{
"cell_type": "markdown",
"id": "4d29e56f-ee81-4f43-96ff-99bb22c52f6a",
"metadata": {},
"source": [
"# Download Metrics from WandB"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "76cc2ca7",
"metadata": {},
"outputs": [],
"source": [
"api = wandb.Api()\n",
"\n",
"# Runs are looked up by \"<entity>/<project-name>\".\n",
"runs = api.runs(\"flower-classification/pytorch-sweeps-demo\")\n",
"\n",
"summary_list, config_list, name_list = [], [], []\n",
"for run in runs:\n",
"    # Output metrics such as accuracy live in .summary; ._json_dict is used\n",
"    # to omit large files.\n",
"    summary_list.append(run.summary._json_dict)\n",
"\n",
"    # Hyperparameters live in .config; special keys starting with \"_\" are skipped.\n",
"    hyperparameters = {}\n",
"    for key, value in run.config.items():\n",
"        if key.startswith('_'):\n",
"            continue\n",
"        hyperparameters[key] = value\n",
"    config_list.append(hyperparameters)\n",
"\n",
"    # Human-readable name of the run.\n",
"    name_list.append(run.name)\n",
"\n",
"# One row per run; the summary and config dicts each stay in a single column.\n",
"runs_df = pd.DataFrame(dict(summary=summary_list, config=config_list, name=name_list))\n",
"runs_df.to_csv(\"hyp-metrics.csv\")"
]
},
{
"cell_type": "markdown",
"id": "821aeb7d-784b-4e38-b4f8-c49245ee25ce",
"metadata": {},
"source": [
"# Transform Metrics\n",
"\n",
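"The `summary` column contains most of the metrics we are interested in (`test/precision`, …), but each row stores them as a single dictionary; the `config` column stores the hyperparameters the same way. The next cell expands both dictionaries into one column per key."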
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "353f9082",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Unnamed: 0</th>\n",
" <th>name</th>\n",
" <th>_step</th>\n",
" <th>_timestamp</th>\n",
" <th>test/recall</th>\n",
" <th>test/f1-score</th>\n",
" <th>test/epoch_acc</th>\n",
" <th>test/epoch_loss</th>\n",
" <th>train/epoch_loss</th>\n",
" <th>epoch</th>\n",
" <th>...</th>\n",
" <th>test/batch_loss</th>\n",
" <th>eps</th>\n",
" <th>gamma</th>\n",
" <th>epochs</th>\n",
" <th>beta_one</th>\n",
" <th>beta_two</th>\n",
" <th>optimizer</th>\n",
" <th>step_size</th>\n",
" <th>batch_size</th>\n",
" <th>learning_rate</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>fiery-sweep-26</td>\n",
" <td>2059</td>\n",
" <td>1.680693e+09</td>\n",
" <td>0.617021</td>\n",
" <td>0.707317</td>\n",
" <td>0.733333</td>\n",
" <td>0.566462</td>\n",
" <td>0.424106</td>\n",
" <td>9</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>1.000000e-01</td>\n",
" <td>0.1</td>\n",
" <td>10</td>\n",
" <td>0.99</td>\n",
" <td>0.900</td>\n",
" <td>adam</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>0.0003</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>radiant-sweep-25</td>\n",
" <td>1039</td>\n",
" <td>1.680693e+09</td>\n",
" <td>0.822222</td>\n",
" <td>0.747475</td>\n",
" <td>0.722222</td>\n",
" <td>0.645458</td>\n",
" <td>0.64979</td>\n",
" <td>9</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>1.000000e+00</td>\n",
" <td>0.5</td>\n",
" <td>10</td>\n",
" <td>0.99</td>\n",
" <td>0.900</td>\n",
" <td>adam</td>\n",
" <td>2</td>\n",
" <td>8</td>\n",
" <td>0.0003</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2</td>\n",
" <td>blooming-sweep-24</td>\n",
" <td>1039</td>\n",
" <td>1.680692e+09</td>\n",
" <td>0.783784</td>\n",
" <td>0.852941</td>\n",
" <td>0.888889</td>\n",
" <td>0.348129</td>\n",
" <td>0.016143</td>\n",
" <td>9</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>1.000000e-08</td>\n",
" <td>0.5</td>\n",
" <td>10</td>\n",
" <td>0.90</td>\n",
" <td>0.999</td>\n",
" <td>sgd</td>\n",
" <td>5</td>\n",
" <td>8</td>\n",
" <td>0.0030</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>3</td>\n",
" <td>visionary-sweep-23</td>\n",
" <td>529</td>\n",
" <td>1.680692e+09</td>\n",
" <td>0.833333</td>\n",
" <td>0.795455</td>\n",
" <td>0.800000</td>\n",
" <td>0.555318</td>\n",
" <td>0.532423</td>\n",
" <td>9</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>1.000000e+00</td>\n",
" <td>0.1</td>\n",
" <td>10</td>\n",
" <td>0.90</td>\n",
" <td>0.900</td>\n",
" <td>sgd</td>\n",
" <td>2</td>\n",
" <td>16</td>\n",
" <td>0.0003</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>4</td>\n",
" <td>ancient-sweep-22</td>\n",
" <td>410</td>\n",
" <td>1.680692e+09</td>\n",
" <td>0.884615</td>\n",
" <td>0.707692</td>\n",
" <td>0.577778</td>\n",
" <td>1.560271</td>\n",
" <td>0.75081</td>\n",
" <td>1</td>\n",
" <td>...</td>\n",
" <td>NaN</td>\n",
" <td>1.000000e-08</td>\n",
" <td>0.5</td>\n",
" <td>10</td>\n",
" <td>0.90</td>\n",
" <td>0.990</td>\n",
" <td>adam</td>\n",
" <td>7</td>\n",
" <td>4</td>\n",
" <td>0.0100</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>133</th>\n",
" <td>133</td>\n",
" <td>different-sweep-5</td>\n",
" <td>1159</td>\n",
" <td>1.678732e+09</td>\n",
" <td>0.714286</td>\n",
" <td>0.813953</td>\n",
" <td>0.822222</td>\n",
" <td>0.493642</td>\n",
" <td>0.518635</td>\n",
" <td>9</td>\n",
" <td>...</td>\n",
" <td>0.506896</td>\n",
" <td>NaN</td>\n",
" <td>0.5</td>\n",
" <td>10</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>sgd</td>\n",
" <td>3</td>\n",
" <td>8</td>\n",
" <td>0.0001</td>\n",
" </tr>\n",
" <tr>\n",
" <th>134</th>\n",
" <td>134</td>\n",
" <td>wise-sweep-4</td>\n",
" <td>1159</td>\n",
" <td>1.678731e+09</td>\n",
" <td>0.846154</td>\n",
" <td>0.835443</td>\n",
" <td>0.855556</td>\n",
" <td>0.548264</td>\n",
" <td>0.54292</td>\n",
" <td>9</td>\n",
" <td>...</td>\n",
" <td>0.515937</td>\n",
" <td>NaN</td>\n",
" <td>0.5</td>\n",
" <td>10</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>sgd</td>\n",
" <td>2</td>\n",
" <td>8</td>\n",
" <td>0.0001</td>\n",
" </tr>\n",
" <tr>\n",
" <th>135</th>\n",
" <td>135</td>\n",
" <td>misty-sweep-3</td>\n",
" <td>2289</td>\n",
" <td>1.678731e+09</td>\n",
" <td>0.775000</td>\n",
" <td>0.849315</td>\n",
" <td>0.877778</td>\n",
" <td>0.241948</td>\n",
" <td>0.020604</td>\n",
" <td>9</td>\n",
" <td>...</td>\n",
" <td>1.758836</td>\n",
" <td>NaN</td>\n",
" <td>0.5</td>\n",
" <td>10</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>sgd</td>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>0.0030</td>\n",
" </tr>\n",
" <tr>\n",
" <th>136</th>\n",
" <td>136</td>\n",
" <td>unique-sweep-2</td>\n",
" <td>1159</td>\n",
" <td>1.678730e+09</td>\n",
" <td>0.684211</td>\n",
" <td>0.753623</td>\n",
" <td>0.811111</td>\n",
" <td>0.479234</td>\n",
" <td>0.42905</td>\n",
" <td>9</td>\n",
" <td>...</td>\n",
" <td>0.455120</td>\n",
" <td>NaN</td>\n",
" <td>0.1</td>\n",
" <td>10</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>sgd</td>\n",
" <td>3</td>\n",
" <td>8</td>\n",
" <td>0.0003</td>\n",
" </tr>\n",
" <tr>\n",
" <th>137</th>\n",
" <td>137</td>\n",
" <td>polar-sweep-1</td>\n",
" <td>2289</td>\n",
" <td>1.678730e+09</td>\n",
" <td>0.863636</td>\n",
" <td>0.883721</td>\n",
" <td>0.888889</td>\n",
" <td>0.544247</td>\n",
" <td>0.024021</td>\n",
" <td>9</td>\n",
" <td>...</td>\n",
" <td>2.532007</td>\n",
" <td>NaN</td>\n",
" <td>0.5</td>\n",
" <td>10</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>sgd</td>\n",
" <td>7</td>\n",
" <td>4</td>\n",
" <td>0.0030</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>138 rows × 25 columns</p>\n",
"</div>"
],
"text/plain": [
" Unnamed: 0 name _step _timestamp test/recall \\\n",
"0 0 fiery-sweep-26 2059 1.680693e+09 0.617021 \n",
"1 1 radiant-sweep-25 1039 1.680693e+09 0.822222 \n",
"2 2 blooming-sweep-24 1039 1.680692e+09 0.783784 \n",
"3 3 visionary-sweep-23 529 1.680692e+09 0.833333 \n",
"4 4 ancient-sweep-22 410 1.680692e+09 0.884615 \n",
".. ... ... ... ... ... \n",
"133 133 different-sweep-5 1159 1.678732e+09 0.714286 \n",
"134 134 wise-sweep-4 1159 1.678731e+09 0.846154 \n",
"135 135 misty-sweep-3 2289 1.678731e+09 0.775000 \n",
"136 136 unique-sweep-2 1159 1.678730e+09 0.684211 \n",
"137 137 polar-sweep-1 2289 1.678730e+09 0.863636 \n",
"\n",
" test/f1-score test/epoch_acc test/epoch_loss train/epoch_loss epoch \\\n",
"0 0.707317 0.733333 0.566462 0.424106 9 \n",
"1 0.747475 0.722222 0.645458 0.64979 9 \n",
"2 0.852941 0.888889 0.348129 0.016143 9 \n",
"3 0.795455 0.800000 0.555318 0.532423 9 \n",
"4 0.707692 0.577778 1.560271 0.75081 1 \n",
".. ... ... ... ... ... \n",
"133 0.813953 0.822222 0.493642 0.518635 9 \n",
"134 0.835443 0.855556 0.548264 0.54292 9 \n",
"135 0.849315 0.877778 0.241948 0.020604 9 \n",
"136 0.753623 0.811111 0.479234 0.42905 9 \n",
"137 0.883721 0.888889 0.544247 0.024021 9 \n",
"\n",
" ... test/batch_loss eps gamma epochs beta_one beta_two \\\n",
"0 ... NaN 1.000000e-01 0.1 10 0.99 0.900 \n",
"1 ... NaN 1.000000e+00 0.5 10 0.99 0.900 \n",
"2 ... NaN 1.000000e-08 0.5 10 0.90 0.999 \n",
"3 ... NaN 1.000000e+00 0.1 10 0.90 0.900 \n",
"4 ... NaN 1.000000e-08 0.5 10 0.90 0.990 \n",
".. ... ... ... ... ... ... ... \n",
"133 ... 0.506896 NaN 0.5 10 NaN NaN \n",
"134 ... 0.515937 NaN 0.5 10 NaN NaN \n",
"135 ... 1.758836 NaN 0.5 10 NaN NaN \n",
"136 ... 0.455120 NaN 0.1 10 NaN NaN \n",
"137 ... 2.532007 NaN 0.5 10 NaN NaN \n",
"\n",
" optimizer step_size batch_size learning_rate \n",
"0 adam 3 4 0.0003 \n",
"1 adam 2 8 0.0003 \n",
"2 sgd 5 8 0.0030 \n",
"3 sgd 2 16 0.0003 \n",
"4 adam 7 4 0.0100 \n",
".. ... ... ... ... \n",
"133 sgd 3 8 0.0001 \n",
"134 sgd 2 8 0.0001 \n",
"135 sgd 3 4 0.0030 \n",
"136 sgd 3 8 0.0003 \n",
"137 sgd 7 4 0.0030 \n",
"\n",
"[138 rows x 25 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv('hyp-metrics.csv', delimiter=',')\n",
"\n",
"# The CSV stores every `summary`/`config` dict as its repr string. Parse the\n",
"# strings with ast.literal_eval, which only accepts Python literals, instead\n",
"# of eval(), which would execute any code found in the file.\n",
"df['summary'] = df['summary'].map(ast.literal_eval)\n",
"df['config'] = df['config'].map(ast.literal_eval)\n",
"\n",
"# Expand both dict columns into one column per key and drop the originals.\n",
"df = df.join(pd.json_normalize(df['summary'])).drop('summary', axis='columns')\n",
"df = df.join(pd.json_normalize(df['config'])).drop('config', axis='columns')\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4679b2f8",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/run/user/1000/ipykernel_39845/2346208349.py:1: FutureWarning: In a future version of pandas all arguments of Series.sort_values will be keyword-only.\n",
" df['learning_rate'].value_counts().sort_values(0)\n"
]
},
{
"data": {
"text/plain": [
"0.0100 21\n",
"0.1000 21\n",
"0.0003 23\n",
"0.0010 23\n",
"0.0001 23\n",
"0.0030 27\n",
"Name: learning_rate, dtype: int64"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Number of runs per learning rate, least frequent first. sort_values() is\n",
"# called without the positional axis argument, which pandas is making keyword-only.\n",
"df['learning_rate'].value_counts().sort_values()\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "1b1a54fc",
"metadata": {},
"outputs": [],
"source": [
"# Plot styling (with grid this time)\n",
"width = 418  # passed to set_size() as the document width in pts\n",
"\n",
"latex_rc = {\n",
"    'text.usetex': True,\n",
"    'font.family': 'serif',\n",
"    'axes.labelsize': 16,\n",
"    'font.size': 16,\n",
"    'legend.fontsize': 11,\n",
"    'xtick.labelsize': 12,\n",
"    'ytick.labelsize': 12,\n",
"}\n",
"sns.set_theme(style='whitegrid', rc=latex_rc)\n",
"\n",
"# Directory the figure PDFs are written to\n",
"fig_save_dir = '../../thesis/graphics/'"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "00efa25b",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 578.387x357.463 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Rename the plotted columns so that their names contain spaces instead of underscores.\n",
"df_prepared = df.rename(columns={'learning_rate': 'learning rate', 'batch_size': 'batch size'})\n",
"batch_palette = sns.cubehelix_palette(5, light=0.8, gamma=1.2)\n",
"lr_ticks = [0.0001, 0.0003, 0.001, 0.003, 0.01, 0.1]\n",
"lr_tick_labels = ['0.0001', '0.0003', '0.001', '0.003', '0.01', '0.1']\n",
"\n",
"fig, ax = plt.subplots(1, 1, figsize=set_size(width, subplots=(1,1)))\n",
"sns.scatterplot(\n",
"    data=df_prepared, x=\"learning rate\", y=\"test/f1-score\", hue=\"batch size\",\n",
"    palette=batch_palette, sizes=(5, 30), linewidth=0, s=50, ax=ax,\n",
")\n",
"\n",
"# Logarithmic x axis with one explicitly labelled tick per learning rate\n",
"ax.set_xscale('log')\n",
"ax.set_xticks(lr_ticks)\n",
"ax.set_xticklabels(labels=lr_tick_labels)\n",
"\n",
"fig.tight_layout()\n",
"fig.savefig(fig_save_dir + 'classifier-hyp-metrics.pdf', format='pdf', bbox_inches='tight')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "44e275ab",
"metadata": {},
"outputs": [],
"source": [
"# Hyperparameter search space: every name maps to its list of candidate values.\n",
"parameters_dict = {\n",
"    'optimizer': {'values': ['adam', 'sgd']},\n",
"    'batch_size': {'values': [4, 8, 16, 32, 64]},\n",
"    'learning_rate': {'values': [0.0001, 0.0003, 0.001, 0.003, 0.01, 0.1]},\n",
"    'step_size': {'values': [2, 3, 5, 7]},\n",
"    'gamma': {'values': [0.1, 0.5]},\n",
"    'beta_one': {'values': [0.9, 0.99]},\n",
"    'beta_two': {'values': [0.5, 0.9, 0.99, 0.999]},\n",
"    'eps': {'values': [1e-08, 0.1, 1]},\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7d3c2860",
"metadata": {},
"outputs": [],
"source": [
"# One row per hyperparameter, with its candidate values joined into a single string.\n",
"params = pd.DataFrame.from_dict(parameters_dict).transpose()\n",
"params['values'] = [\n",
"    ', '.join(str(candidate) for candidate in candidates)\n",
"    for candidates in params['values']\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "acc3a77e",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>values</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>optimizer</th>\n",
" <td>adam, sgd</td>\n",
" </tr>\n",
" <tr>\n",
" <th>batch_size</th>\n",
" <td>4, 8, 16, 32, 64</td>\n",
" </tr>\n",
" <tr>\n",
" <th>learning_rate</th>\n",
" <td>0.0001, 0.0003, 0.001, 0.003, 0.01, 0.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>step_size</th>\n",
" <td>2, 3, 5, 7</td>\n",
" </tr>\n",
" <tr>\n",
" <th>gamma</th>\n",
" <td>0.1, 0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>beta_one</th>\n",
" <td>0.9, 0.99</td>\n",
" </tr>\n",
" <tr>\n",
" <th>beta_two</th>\n",
" <td>0.5, 0.9, 0.99, 0.999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>eps</th>\n",
" <td>1e-08, 0.1, 1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" values\n",
"optimizer adam, sgd\n",
"batch_size 4, 8, 16, 32, 64\n",
"learning_rate 0.0001, 0.0003, 0.001, 0.003, 0.01, 0.1\n",
"step_size 2, 3, 5, 7\n",
"gamma 0.1, 0.5\n",
"beta_one 0.9, 0.99\n",
"beta_two 0.5, 0.9, 0.99, 0.999\n",
"eps 1e-08, 0.1, 1"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Display the search space table built in the previous cell.\n",
"params"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "73a26951",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>optimizer</th>\n",
" <th>batch_size</th>\n",
" <th>learning_rate</th>\n",
" <th>step_size</th>\n",
" <th>gamma</th>\n",
" <th>beta_one</th>\n",
" <th>beta_two</th>\n",
" <th>eps</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>values</th>\n",
" <td>adam</td>\n",
" <td>4</td>\n",
" <td>0.0001</td>\n",
" <td>2</td>\n",
" <td>0.1</td>\n",
" <td>0.9</td>\n",
" <td>0.5</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>values</th>\n",
" <td>adam</td>\n",
" <td>4</td>\n",
" <td>0.0001</td>\n",
" <td>2</td>\n",
" <td>0.1</td>\n",
" <td>0.9</td>\n",
" <td>0.5</td>\n",
" <td>0.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>values</th>\n",
" <td>adam</td>\n",
" <td>4</td>\n",
" <td>0.0001</td>\n",
" <td>2</td>\n",
" <td>0.1</td>\n",
" <td>0.9</td>\n",
" <td>0.5</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>values</th>\n",
" <td>adam</td>\n",
" <td>4</td>\n",
" <td>0.0001</td>\n",
" <td>2</td>\n",
" <td>0.1</td>\n",
" <td>0.9</td>\n",
" <td>0.9</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>values</th>\n",
" <td>adam</td>\n",
" <td>4</td>\n",
" <td>0.0001</td>\n",
" <td>2</td>\n",
" <td>0.1</td>\n",
" <td>0.9</td>\n",
" <td>0.9</td>\n",
" <td>0.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>values</th>\n",
" <td>sgd</td>\n",
" <td>64</td>\n",
" <td>0.1</td>\n",
" <td>7</td>\n",
" <td>0.5</td>\n",
" <td>0.99</td>\n",
" <td>0.99</td>\n",
" <td>0.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>values</th>\n",
" <td>sgd</td>\n",
" <td>64</td>\n",
" <td>0.1</td>\n",
" <td>7</td>\n",
" <td>0.5</td>\n",
" <td>0.99</td>\n",
" <td>0.99</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>values</th>\n",
" <td>sgd</td>\n",
" <td>64</td>\n",
" <td>0.1</td>\n",
" <td>7</td>\n",
" <td>0.5</td>\n",
" <td>0.99</td>\n",
" <td>0.999</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>values</th>\n",
" <td>sgd</td>\n",
" <td>64</td>\n",
" <td>0.1</td>\n",
" <td>7</td>\n",
" <td>0.5</td>\n",
" <td>0.99</td>\n",
" <td>0.999</td>\n",
" <td>0.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>values</th>\n",
" <td>sgd</td>\n",
" <td>64</td>\n",
" <td>0.1</td>\n",
" <td>7</td>\n",
" <td>0.5</td>\n",
" <td>0.99</td>\n",
" <td>0.999</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>11520 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" optimizer batch_size learning_rate step_size gamma beta_one beta_two \\\n",
"values adam 4 0.0001 2 0.1 0.9 0.5 \n",
"values adam 4 0.0001 2 0.1 0.9 0.5 \n",
"values adam 4 0.0001 2 0.1 0.9 0.5 \n",
"values adam 4 0.0001 2 0.1 0.9 0.9 \n",
"values adam 4 0.0001 2 0.1 0.9 0.9 \n",
"... ... ... ... ... ... ... ... \n",
"values sgd 64 0.1 7 0.5 0.99 0.99 \n",
"values sgd 64 0.1 7 0.5 0.99 0.99 \n",
"values sgd 64 0.1 7 0.5 0.99 0.999 \n",
"values sgd 64 0.1 7 0.5 0.99 0.999 \n",
"values sgd 64 0.1 7 0.5 0.99 0.999 \n",
"\n",
" eps \n",
"values 0.0 \n",
"values 0.1 \n",
"values 1 \n",
"values 0.0 \n",
"values 0.1 \n",
"... ... \n",
"values 0.1 \n",
"values 1 \n",
"values 0.0 \n",
"values 0.1 \n",
"values 1 \n",
"\n",
"[11520 rows x 8 columns]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Expand the search space into one row per hyperparameter combination by\n",
"# exploding the value lists column after column.\n",
"grid = pd.DataFrame.from_dict(parameters_dict)\n",
"for column in ['optimizer', 'batch_size', 'learning_rate', 'step_size', 'gamma', 'beta_one', 'beta_two', 'eps']:\n",
"    grid = grid.explode(column)\n",
"grid"
]
},
{
"cell_type": "markdown",
"id": "0d01bf18",
"metadata": {},
"source": [
"# F1-score stratified 10-fold cross validation"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "bb567230",
"metadata": {},
"outputs": [],
"source": [
"N_FOLDS = 10   # folds of the stratified cross validation\n",
"N_EPOCHS = 25  # epochs per fold\n",
"\n",
"\n",
"def load_fold_scores(csv_path, metric_column, n_folds=N_FOLDS, n_epochs=N_EPOCHS):\n",
"    \"\"\"Load the F1-scores of a cross-validation run in long format.\n",
"\n",
"    The CSV must contain exactly ``n_folds * n_epochs`` rows, otherwise the\n",
"    column assignments below raise. Rows are assumed to be ordered fold by\n",
"    fold, with the epochs of each fold in ascending order (TODO confirm\n",
"    against the export).\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    csv_path: str\n",
"        Path of the CSV file with the scores\n",
"    metric_column: str\n",
"        Name of the CSV column that holds the F1-score\n",
"    n_folds: int, optional\n",
"        Number of folds in the file\n",
"    n_epochs: int, optional\n",
"        Number of epochs per fold\n",
"\n",
"    Returns\n",
"    -------\n",
"    scores: pandas.DataFrame\n",
"        Columns ``epoch``, ``fold``, ``variable`` and ``value``\n",
"    \"\"\"\n",
"    scores = pd.read_csv(csv_path, delimiter=',')\n",
"    # 0..n_epochs-1 repeated once per fold, and each fold id repeated n_epochs times\n",
"    scores['epoch'] = np.tile(np.arange(n_epochs), n_folds)\n",
"    scores['fold'] = np.repeat(np.arange(n_folds), n_epochs)\n",
"    return pd.melt(scores[['epoch', 'fold', metric_column]], ['epoch', 'fold'])\n",
"\n",
"\n",
"f_scores_test = load_fold_scores('f1-scores-folds.csv', 'StratifiedKFold-ROC - test/f1-score')\n",
"f_scores_train = load_fold_scores('f1-scores-folds-train.csv', 'StratifiedKFold-ROC - train/f1-score')"
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "493e415e",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 578.387x714.925 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Top panel: train F1-score per fold; bottom panel: test F1-score per fold.\n",
"fig, ax = plt.subplots(2, 1, figsize=set_size(width, subplots=(2,1)), sharex=True)\n",
"\n",
"panels = zip(ax, (f_scores_train, f_scores_test), ('train/f1-score', 'test/f1-score'))\n",
"for axis, scores, label in panels:\n",
"    sns.lineplot(\n",
"        data=scores, x=\"epoch\", y=\"value\", hue='fold',\n",
"        palette=sns.cubehelix_palette(10, light=0.8, gamma=1.2),\n",
"        linewidth=1, ax=axis,\n",
"    )\n",
"    axis.set_ylabel(label)\n",
"\n",
"fig.tight_layout()\n",
"fig.savefig(fig_save_dir + 'classifier-hyp-folds-f1.pdf', format='pdf', bbox_inches='tight')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c642d40",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}