{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Hyperparameter search for text classification (PyTorch)\n", "\n", "In this tutorial, we show how to apply hyperparameter optimization to a text classification example from the PyTorch documentation.\n", "\n", "**Reference**:\n", " This tutorial is based on materials from the PyTorch documentation: [Text classification with the torchtext library](https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html)" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "!pip3 install deephyper\n", "!pip3 install ray\n", "!pip3 install torch torchtext torchdata" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import ray\n", "import json\n", "import pandas as pd\n", "from functools import partial\n", "\n", "import torch\n", "\n", "from torchtext.data.utils import get_tokenizer\n", "from torchtext.data.functional import to_map_style_dataset\n", "from torchtext.vocab import build_vocab_from_iterator\n", "\n", "from torch.utils.data import DataLoader\n", "from torch.utils.data.dataset import random_split\n", "\n", "from torch import nn" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "
`max_evals (int)`
: Maximum number of evaluations to perform. Defaults to -1, which means no limit on the number of evaluations.

`timeout (int)`
: Time budget (in seconds) after which the search stops. Defaults to None, which means no time limit.

| | batch_size | learning_rate | num_epochs | job_id | objective | timestamp_submit | timestamp_gather |
|---|---|---|---|---|---|---|---|
| 0 | 13 | 0.396163 | 10 | 1 | 0.876310 | 0.167126 | 21.861172 |
| 1 | 23 | 0.176352 | 13 | 2 | 0.836810 | 21.893403 | 46.585527 |
| 2 | 85 | 0.615317 | 9 | 3 | 0.849250 | 46.603550 | 62.330578 |
| 3 | 9 | 4.647854 | 15 | 4 | 0.888524 | 62.348272 | 96.357171 |
| 4 | 22 | 1.862895 | 7 | 5 | 0.892131 | 96.375062 | 111.058422 |
| 5 | 13 | 8.390420 | 7 | 6 | 0.892155 | 111.077491 | 127.247439 |
| 6 | 56 | 5.709700 | 7 | 7 | 0.888726 | 127.265297 | 140.431804 |
| 7 | 47 | 1.073937 | 17 | 8 | 0.891607 | 140.449749 | 168.683609 |
| 8 | 116 | 0.356077 | 8 | 9 | 0.757940 | 168.787206 | 183.260858 |
| 9 | 39 | 3.127017 | 7 | 10 | 0.891643 | 183.278615 | 196.925297 |
| 10 | 512 | 0.348708 | 6 | 11 | 0.478179 | 197.162700 | 207.922961 |
| 11 | 512 | 3.726319 | 17 | 12 | 0.879036 | 208.166723 | 232.997138 |
| 12 | 512 | 8.765370 | 6 | 13 | 0.864667 | 233.320723 | 244.213859 |
| 13 | 512 | 1.910815 | 5 | 14 | 0.737810 | 244.466962 | 254.037378 |
| 14 | 512 | 0.313972 | 6 | 15 | 0.471179 | 254.293386 | 265.038085 |
| 15 | 512 | 0.390254 | 17 | 16 | 0.659024 | 265.296111 | 291.446844 |
| 16 | 512 | 0.314117 | 6 | 17 | 0.463964 | 291.793700 | 302.993824 |
| 17 | 512 | 0.406197 | 6 | 18 | 0.495940 | 303.264247 | 314.411603 |
| 18 | 512 | 0.320737 | 6 | 19 | 0.476452 | 314.679307 | 325.618635 |
| 19 | 512 | 0.310132 | 6 | 20 | 0.455762 | 325.959387 | 336.946107 |
| 20 | 512 | 0.277178 | 6 | 21 | 0.438250 | 337.212328 | 348.157166 |
| 21 | 512 | 0.276917 | 6 | 22 | 0.449714 | 348.422956 | 359.220447 |
| 22 | 512 | 0.273978 | 6 | 23 | 0.453071 | 359.485135 | 370.697273 |
| 23 | 512 | 0.298666 | 6 | 24 | 0.450143 | 371.048602 | 382.304554 |
| 24 | 512 | 0.280337 | 6 | 25 | 0.463952 | 382.571858 | 393.747633 |
| 25 | 512 | 0.292463 | 6 | 26 | 0.469250 | 394.017767 | 405.152517 |
| 26 | 512 | 0.273828 | 6 | 27 | 0.459488 | 405.424290 | 416.641872 |
| 27 | 512 | 0.306901 | 6 | 28 | 0.461964 | 416.986107 | 428.063347 |
| 28 | 512 | 0.281089 | 6 | 29 | 0.458131 | 428.335923 | 439.447637 |
| 29 | 512 | 0.278833 | 6 | 30 | 0.436702 | 439.720846 | 450.642993 |
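Each row of the results table is one evaluation. `objective` is the value the search maximizes, and the two timestamps record when a job was submitted and when its result was gathered. The following is a minimal sketch that uses plain Python instead of pandas. It copies four rows from the table and shows how to pick the best configuration and estimate how long each evaluation took. It assumes the timestamps are in seconds measured from the start of the search. The variable names `rows`, `best` and `durations` are illustrative only.

```python
# Four rows copied from the results table (column names kept as in the table).
rows = [
    {"job_id": 1, "batch_size": 13, "learning_rate": 0.396163, "num_epochs": 10,
     "objective": 0.876310, "timestamp_submit": 0.167126, "timestamp_gather": 21.861172},
    {"job_id": 5, "batch_size": 22, "learning_rate": 1.862895, "num_epochs": 7,
     "objective": 0.892131, "timestamp_submit": 96.375062, "timestamp_gather": 111.058422},
    {"job_id": 6, "batch_size": 13, "learning_rate": 8.390420, "num_epochs": 7,
     "objective": 0.892155, "timestamp_submit": 111.077491, "timestamp_gather": 127.247439},
    {"job_id": 11, "batch_size": 512, "learning_rate": 0.348708, "num_epochs": 6,
     "objective": 0.478179, "timestamp_submit": 197.162700, "timestamp_gather": 207.922961},
]

# The search maximizes "objective", so the best configuration has the largest value.
best = max(rows, key=lambda r: r["objective"])
print("best job:", best["job_id"],
      {k: best[k] for k in ("batch_size", "learning_rate", "num_epochs")})

# Assumption: timestamps are seconds since the search started, so the
# difference approximates the wall-clock duration of each evaluation.
durations = {r["job_id"]: r["timestamp_gather"] - r["timestamp_submit"] for r in rows}
for job_id, seconds in durations.items():
    print(f"job {job_id}: {seconds:.1f} s")
```

You can apply the same lookup to the full results. If they are loaded into a pandas DataFrame (called `df` here for illustration), `df.loc[df["objective"].idxmax()]` returns the best row.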