{ "cells": [ { "cell_type": "markdown", "id": "f7ef2e27", "metadata": { "id": "f7ef2e27" }, "source": [ "# Multi-Fidelity Hyperparameter Optimization with Keras\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deephyper/tutorials/blob/main/tutorials/colab/HPS_basic_classification_with_tabular_data/notebook.ipynb)\n", "\n", "In this tutorial we show how to apply hyperparameter optimization to a basic example from the Keras documentation. We build on the previous tutorial, which uses the same example, and add multi-fidelity to it. The purpose of multi-fidelity is to dynamically manage the budget (also called fidelity) allocated to evaluating a hyperparameter configuration. For example, when training a deep neural network, training can be continued or stopped after each epoch based on the performance observed so far and a stopping policy.\n", "\n", "In DeepHyper, the multi-fidelity agent is designed separately from the hyperparameter search agent. Both can communicate, but from an API perspective they are different objects. Multi-fidelity agents are called `Stopper` in DeepHyper, and their documentation can be found at [deephyper.stopper](https://deephyper.readthedocs.io/en/latest/_autosummary/deephyper.stopper.html).\n", "\n", "In this notebook, we will demonstrate how to use multi-fidelity inside sequential Bayesian optimization. When moving to a distributed setting, it is important to use a shared database accessible by all workers; otherwise the multi-fidelity scheme may not work properly. 
An example of database instantiation for parallel computing is explained in: [Introduction to Distributed Bayesian Optimization (DBO) with MPI (Communication) and Redis (Storage)](https://deephyper.readthedocs.io/en/latest/tutorials/tutorials/scripts/02_Intro_to_DBO/README.html).\n", "\n", "**Reference**:\n", " This tutorial is based on materials from the Keras Documentation: [Structured data classification from scratch](https://keras.io/examples/structured_data/structured_data_classification_from_scratch/)\n", "\n", "Let us start by installing DeepHyper!\n", " \n", "
`max_evals (int)`
: Defines the maximum number of evaluations that we want to perform. Defaults to -1
for an infinite number.

`timeout (int)`
: Defines a time budget (in seconds) before stopping the search. Defaults to None
for an infinite time budget.

| | p:activation | p:batch_size | p:dropout_rate | p:learning_rate | p:units | objective | job_id | m:timestamp_submit | m:timestamp_gather | m:timestamp_start | m:timestamp_end | m:loss | m:val_loss | m:accuracy | m:val_accuracy | m:budget |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | relu | 32 | 0.500000 | 0.001000 | 32 | 0.803279 | 0 | 3.554752 | 7.302983 | 1.692627e+09 | 1.692627e+09 | [0.6227220296859741, 0.5570479035377502, 0.538... | [0.5288254618644714, 0.4989699125289917, 0.475... | [0.6487603187561035, 0.7438016533851624, 0.756... | [0.7868852615356445, 0.8032786846160889, 0.819... | 35 |
| 1 | linear | 9 | 0.147090 | 0.001889 | 49 | 0.803279 | 1 | 7.347784 | 10.729823 | 1.692627e+09 | 1.692627e+09 | [0.49330684542655945, 0.3671776056289673, 0.30... | [0.3687109649181366, 0.3877825140953064, 0.386... | [0.7479338645935059, 0.8264462947845459, 0.851... | [0.8032786846160889, 0.868852436542511, 0.8524... | 27 |
| 2 | softsign | 12 | 0.499104 | 0.000029 | 100 | 0.786885 | 2 | 10.749194 | 14.626382 | 1.692627e+09 | 1.692627e+09 | [0.710996150970459, 0.7085673213005066, 0.7007... | [0.6798619031906128, 0.6712697148323059, 0.663... | [0.5289255976676941, 0.5495867729187012, 0.561... | [0.5409836173057556, 0.5409836173057556, 0.557... | 59 |
| 3 | softsign | 27 | 0.582597 | 0.000215 | 97 | 0.868852 | 3 | 14.645453 | 18.470026 | 1.692627e+09 | 1.692627e+09 | [0.6582441329956055, 0.6144227385520935, 0.614... | [0.6063965559005737, 0.5792577862739563, 0.555... | [0.6157024502754211, 0.6528925895690918, 0.648... | [0.7213114500045776, 0.7540983557701111, 0.737... | 64 |
| 4 | selu | 28 | 0.469018 | 0.000618 | 110 | 0.786885 | 4 | 18.490789 | 21.121490 | 1.692627e+09 | 1.692627e+09 | [0.6250562071800232, 0.5189632177352905, 0.483... | [0.49002861976623535, 0.423252135515213, 0.386... | [0.6818181872367859, 0.7231404781341553, 0.731... | [0.7704917788505554, 0.8196721076965332, 0.786... | 4 |
| 5 | linear | 15 | 0.289261 | 0.000777 | 68 | 0.819672 | 5 | 21.140275 | 23.720370 | 1.692627e+09 | 1.692627e+09 | [0.6295762062072754, 0.4781874716281891, 0.421... | [0.5124895572662354, 0.433324933052063, 0.4010... | [0.6735537052154541, 0.7603305578231812, 0.801... | [0.8032786846160889, 0.7868852615356445, 0.803... | 4 |
| 6 | elu | 133 | 0.007165 | 0.000024 | 88 | 0.688525 | 6 | 23.738940 | 26.728599 | 1.692627e+09 | 1.692627e+09 | [0.6447474956512451, 0.6408497095108032, 0.641... | [0.597851037979126, 0.5967557430267334, 0.5956... | [0.6280992031097412, 0.6322314143180847, 0.640... | [0.688524603843689, 0.688524603843689, 0.68852... | 4 |
| 7 | elu | 184 | 0.365660 | 0.000276 | 126 | 0.704918 | 7 | 26.748085 | 29.293730 | 1.692627e+09 | 1.692627e+09 | [0.6677700281143188, 0.6242449879646301, 0.632... | [0.6279993057250977, 0.612653374671936, 0.5987... | [0.6239669322967529, 0.6652892827987671, 0.673... | [0.7049180269241333, 0.7213114500045776, 0.721... | 4 |
| 8 | linear | 27 | 0.410357 | 0.003974 | 10 | 0.819672 | 8 | 29.312796 | 32.318316 | 1.692627e+09 | 1.692627e+09 | [0.6348783373832703, 0.4910481572151184, 0.433... | [0.5030055046081543, 0.4379017949104309, 0.406... | [0.7066115736961365, 0.7438016533851624, 0.801... | [0.7540983557701111, 0.8032786846160889, 0.803... | 31 |
| 9 | gelu | 37 | 0.586930 | 0.001828 | 94 | 0.786885 | 9 | 32.583841 | 35.622722 | 1.692627e+09 | 1.692627e+09 | [0.646159291267395, 0.5308106541633606, 0.4458... | [0.5199227333068848, 0.44885024428367615, 0.40... | [0.6363636255264282, 0.7479338645935059, 0.809... | [0.8196721076965332, 0.7868852615356445, 0.786... | 35 |
| 10 | softsign | 15 | 0.580355 | 0.000195 | 88 | 0.426230 | 10 | 35.721770 | 37.368073 | 1.692627e+09 | 1.692627e+09 | [0.7734954953193665] | [0.7477120161056519] | [0.41735535860061646] | [0.4262295067310333] | 1 |
| 11 | swish | 27 | 0.599008 | 0.000148 | 80 | 0.245902 | 11 | 37.469103 | 38.974784 | 1.692627e+09 | 1.692627e+09 | [0.828150749206543] | [0.810135006904602] | [0.3636363744735718] | [0.24590164422988892] | 1 |
| 12 | softplus | 11 | 0.590109 | 0.000299 | 97 | 0.770492 | 12 | 39.081071 | 41.960308 | 1.692627e+09 | 1.692627e+09 | [0.8678354024887085, 0.7065883874893188, 0.700... | [0.6125381588935852, 0.5351727604866028, 0.495... | [0.5206611752510071, 0.5909090638160706, 0.669... | [0.7049180269241333, 0.7704917788505554, 0.770... | 4 |
| 13 | softsign | 24 | 0.568922 | 0.000112 | 122 | 0.803279 | 13 | 42.066191 | 45.404559 | 1.692627e+09 | 1.692627e+09 | [0.724600076675415, 0.7176037430763245, 0.6875... | [0.6987729072570801, 0.674211859703064, 0.6525... | [0.4958677589893341, 0.5454545617103577, 0.607... | [0.5573770403862, 0.6065573692321777, 0.655737... | 40 |
| 14 | softsign | 29 | 0.578839 | 0.000159 | 105 | 0.803279 | 14 | 45.507679 | 48.744906 | 1.692627e+09 | 1.692627e+09 | [0.7335056662559509, 0.7066584229469299, 0.688... | [0.7332508563995361, 0.7050181031227112, 0.679... | [0.46694216132164, 0.5289255976676941, 0.52066... | [0.44262295961380005, 0.49180328845977783, 0.5... | 51 |
| 15 | linear | 16 | 0.287004 | 0.000726 | 69 | 0.786885 | 15 | 48.853081 | 52.410750 | 1.692627e+09 | 1.692627e+09 | [0.6595832109451294, 0.5204204320907593, 0.485... | [0.48184749484062195, 0.4187839925289154, 0.38... | [0.64462810754776, 0.7148760557174683, 0.74793... | [0.8032786846160889, 0.8032786846160889, 0.819... | 37 |
| 16 | tanh | 27 | 0.433537 | 0.005760 | 24 | 0.786885 | 16 | 52.515180 | 55.492876 | 1.692627e+09 | 1.692627e+09 | [0.4885702431201935, 0.4026663899421692, 0.338... | [0.3990103304386139, 0.3842548429965973, 0.395... | [0.7644628286361694, 0.8016529083251953, 0.847... | [0.8196721076965332, 0.8032786846160889, 0.852... | 29 |
| 17 | softsign | 32 | 0.597483 | 0.000243 | 8 | 0.803279 | 17 | 55.600845 | 59.004269 | 1.692627e+09 | 1.692627e+09 | [0.8090429306030273, 0.7903764247894287, 0.808... | [0.7881447076797485, 0.7788082361221313, 0.769... | [0.4586776793003082, 0.44628098607063293, 0.45... | [0.31147539615631104, 0.31147539615631104, 0.3... | 63 |
| 18 | swish | 24 | 0.584958 | 0.000213 | 111 | 0.852459 | 18 | 59.111170 | 62.916621 | 1.692627e+09 | 1.692627e+09 | [0.6998273134231567, 0.6901867985725403, 0.633... | [0.6759735941886902, 0.6463225483894348, 0.621... | [0.5289255976676941, 0.56611567735672, 0.65289... | [0.6065573692321777, 0.6721311211585999, 0.672... | 36 |
| 19 | swish | 25 | 0.598112 | 0.000129 | 124 | 0.409836 | 19 | 63.025823 | 65.628007 | 1.692627e+09 | 1.692627e+09 | [0.7799810171127319, 0.7728349566459656, 0.743... | [0.7869507670402527, 0.7617835998535156, 0.738... | [0.39256197214126587, 0.42148759961128235, 0.4... | [0.26229506731033325, 0.3442623019218445, 0.37... | 4 |
| 20 | swish | 24 | 0.503962 | 0.000188 | 108 | 0.852459 | 20 | 65.737362 | 69.012517 | 1.692627e+09 | 1.692627e+09 | [0.7826627492904663, 0.7401067018508911, 0.716... | [0.7268974184989929, 0.6923832893371582, 0.661... | [0.40495866537094116, 0.4710743725299835, 0.51... | [0.4098360538482666, 0.5409836173057556, 0.672... | 52 |
| 21 | swish | 23 | 0.508027 | 0.000195 | 121 | 0.786885 | 21 | 69.122440 | 72.220513 | 1.692627e+09 | 1.692627e+09 | [0.6466206312179565, 0.6104539036750793, 0.596... | [0.6054739952087402, 0.5772570967674255, 0.553... | [0.6363636255264282, 0.7355371713638306, 0.739... | [0.688524603843689, 0.7704917788505554, 0.7704... | 16 |
| 22 | swish | 21 | 0.599638 | 0.000220 | 115 | 0.803279 | 22 | 72.331612 | 75.464572 | 1.692627e+09 | 1.692627e+09 | [0.7254757285118103, 0.6856257915496826, 0.656... | [0.7080947756767273, 0.6708484292030334, 0.636... | [0.4834710657596588, 0.5289255976676941, 0.628... | [0.5081967115402222, 0.6393442749977112, 0.737... | 36 |
| 23 | swish | 24 | 0.559085 | 0.000214 | 88 | 0.721311 | 23 | 75.574948 | 78.101354 | 1.692627e+09 | 1.692627e+09 | [0.7135123610496521, 0.6844885349273682, 0.662... | [0.6889106631278992, 0.6591018438339233, 0.631... | [0.4834710657596588, 0.5619834661483765, 0.595... | [0.5081967115402222, 0.6229507923126221, 0.704... | 4 |
| 24 | softplus | 26 | 0.576747 | 0.000199 | 102 | 0.803279 | 24 | 78.210997 | 81.422945 | 1.692627e+09 | 1.692627e+09 | [0.8818848729133606, 0.8496137857437134, 0.756... | [0.7038170695304871, 0.6505393385887146, 0.610... | [0.4793388545513153, 0.4752066135406494, 0.537... | [0.4098360538482666, 0.8360655903816223, 0.819... | 27 |
| 25 | swish | 24 | 0.567181 | 0.000209 | 107 | 0.754098 | 25 | 81.533879 | 84.056534 | 1.692627e+09 | 1.692627e+09 | [0.7189084887504578, 0.7079837918281555, 0.677... | [0.6937562227249146, 0.6694052219390869, 0.646... | [0.5371900796890259, 0.557851254940033, 0.5785... | [0.5901639461517334, 0.6557376980781555, 0.721... | 4 |
| 26 | tanh | 24 | 0.503059 | 0.000118 | 100 | 0.754098 | 26 | 84.166216 | 86.679132 | 1.692627e+09 | 1.692627e+09 | [0.6868227124214172, 0.6927969455718994, 0.650... | [0.6453925967216492, 0.6199279427528381, 0.596... | [0.5826446413993835, 0.5785123705863953, 0.648... | [0.6721311211585999, 0.688524603843689, 0.7377... | 4 |
| 27 | softsign | 27 | 0.552740 | 0.000288 | 101 | 0.836066 | 27 | 86.789922 | 90.000142 | 1.692627e+09 | 1.692627e+09 | [0.6240310072898865, 0.6154924035072327, 0.591... | [0.5968618392944336, 0.5632079243659973, 0.535... | [0.6652892827987671, 0.6363636255264282, 0.739... | [0.8032786846160889, 0.8360655903816223, 0.819... | 27 |
| 28 | softsign | 26 | 0.585004 | 0.000215 | 104 | 0.836066 | 28 | 90.113099 | 93.519698 | 1.692627e+09 | 1.692627e+09 | [0.6899656057357788, 0.6779984831809998, 0.646... | [0.677489161491394, 0.6455464363098145, 0.6153... | [0.5619834661483765, 0.5743801593780518, 0.644... | [0.5573770403862, 0.6065573692321777, 0.639344... | 64 |
| 29 | softsign | 19 | 0.582967 | 0.000434 | 96 | 0.803279 | 29 | 93.630995 | 96.207705 | 1.692627e+09 | 1.692627e+09 | [0.7093287110328674, 0.5999704003334045, 0.584... | [0.6647931933403015, 0.5841431021690369, 0.525... | [0.5289255976676941, 0.6983470916748047, 0.706... | [0.5573770403862, 0.7540983557701111, 0.803278... | 4 |
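
The `m:budget` column records how many epochs each evaluation consumed, so it gives a quick view of how much training the multi-fidelity scheme saved. The snippet below is a minimal sketch of that analysis with pandas. It rebuilds a small DataFrame from a few rows copied from the results above rather than loading the real search output, and it assumes that the largest observed budget (64) corresponds to a full-length training run, which the results themselves do not state.

```python
import pandas as pd

# A few rows copied from the results table (job_id, objective, m:budget).
# In the notebook, the full DataFrame returned by the search would be used instead.
results = pd.DataFrame(
    {
        "job_id": [0, 3, 4, 10, 11, 21, 28],
        "objective": [0.803279, 0.868852, 0.786885, 0.426230, 0.245902, 0.786885, 0.836066],
        "m:budget": [35, 64, 4, 1, 1, 16, 64],
    }
)

# Assumption: the largest observed budget is the full training budget.
max_budget = results["m:budget"].max()

# Evaluations that consumed fewer epochs than the full budget were cut short.
stopped_early = results[results["m:budget"] < max_budget]

# Epochs actually consumed versus training every configuration to the full budget.
used = results["m:budget"].sum()
budget_fraction = used / (max_budget * len(results))

# Configuration with the highest objective (validation accuracy in this tutorial).
best = results.loc[results["objective"].idxmax()]

print(f"{len(stopped_early)} of {len(results)} evaluations used less than the full budget")
print(f"epochs used: {used} ({budget_fraction:.1%} of the full budget)")
print(f"best objective: {best['objective']:.6f} (job {int(best['job_id'])})")
```

Evaluations with a budget of 1 or 4 epochs are typically the ones whose early validation accuracy was poor, which is where stopping early saves the most compute.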