{ "cells": [ { "cell_type": "markdown", "id": "be0a4427", "metadata": {}, "source": [ "# Neural Architecture Search for Graph Neural Networks\n", "\n", "In this tutorial, we design a neural architecture search space for graph neural networks (GNNs), specifically for message passing neural networks (MPNNs).\n", "\n", "For the related paper, see https://ieeexplore.ieee.org/abstract/document/9378060.\n", "\n", "The search space has multiple input tensors: node features, edge features, edge pairs (source and target node indices), and node masks (which indicate the number of nodes before zero-padding). The output of the search space is customizable. In this tutorial, we use the QM7 dataset from the DeepChem library, which has a scalar output.\n", "\n", "There are two main variable nodes, `mpnn_cell` and `gather_cell`:\n", "* `mpnn_cell` is a message passing layer with a variety of choices for the activation, aggregation, and update functions, among others.\n", "* `gather_cell` is a global graph gather (readout) layer with a variety of global pooling functions.\n", "\n", "We also include skip connections in the search space; that is, the output of layer *n-1* can be passed directly as an input to layer *n+1*. Users can change the number of `mpnn_cell` layers and the maximum skip-connection distance to control how flexible the skip connections are.\n", "\n", "We used random search and aging evolution (also known as regularized evolution) to conduct the architecture search. In the paper, we found that aging evolution scales well. We also showed that the best architecture found by the search outperforms the *MoleculeNet* benchmarks. Users are welcome to further modify the search space to improve performance.\n", "\n", "## Install DeepChem and RDKit\n", "\n", "We need [DeepChem](https://deepchem.io) for the benchmark datasets. [RDKit](https://www.rdkit.org) is also required to convert molecular SMILES strings into graph representations."
] }, { "cell_type": "code", "execution_count": 1, "id": "d64db5f6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: deepchem in /Users/romainegele/miniforge3/envs/dh-arm/lib/python3.9/site-packages (2.5.0)\n", "Requirement already satisfied: joblib in /Users/romainegele/miniforge3/envs/dh-arm/lib/python3.9/site-packages (from deepchem) (1.1.0)\n", "Requirement already satisfied: scipy in /Users/romainegele/miniforge3/envs/dh-arm/lib/python3.9/site-packages (from deepchem) (1.7.2)\n", "Requirement already satisfied: pandas in /Users/romainegele/miniforge3/envs/dh-arm/lib/python3.9/site-packages (from deepchem) (1.3.4)\n", "Requirement already satisfied: numpy in /Users/romainegele/miniforge3/envs/dh-arm/lib/python3.9/site-packages (from deepchem) (1.21.4)\n", "Requirement already satisfied: scikit-learn in /Users/romainegele/miniforge3/envs/dh-arm/lib/python3.9/site-packages (from deepchem) (1.0.1)\n", "Requirement already satisfied: python-dateutil>=2.7.3 in /Users/romainegele/miniforge3/envs/dh-arm/lib/python3.9/site-packages (from pandas->deepchem) (2.8.2)\n", "Requirement already satisfied: pytz>=2017.3 in /Users/romainegele/miniforge3/envs/dh-arm/lib/python3.9/site-packages (from pandas->deepchem) (2021.3)\n", "Requirement already satisfied: threadpoolctl>=2.0.0 in /Users/romainegele/miniforge3/envs/dh-arm/lib/python3.9/site-packages (from scikit-learn->deepchem) (3.0.0)\n", "Requirement already satisfied: six>=1.5 in /Users/romainegele/miniforge3/envs/dh-arm/lib/python3.9/site-packages (from python-dateutil>=2.7.3->pandas->deepchem) (1.15.0)\n", "Collecting package metadata (current_repodata.json): done\n", "Solving environment: done\n", "\n", "\n", "==> WARNING: A newer version of conda exists. 
<==\n", " current version: 4.10.3\n", " latest version: 4.11.0\n", "\n", "Please update conda by running\n", "\n", " $ conda update -n base conda\n", "\n", "\n", "\n", "# All requested packages already installed.\n", "\n" ] } ], "source": [ "!pip install deepchem\n", "!conda install -c rdkit rdkit -y" ] }, { "cell_type": "markdown", "id": "f26b8bd1", "metadata": {}, "source": [ "## Imports and GPU Detection \n", "\n", "
\n", "| | arch_seq | id | objective | elapsed_sec | duration |\n",
"|---|---|---|---|---|---|\n",
"| 0 | [13876, 0, 2160, 0, 0, 4762, 0, 1, 1, 1] | 1 | -1.017522 | 44.933679 | 43.074681 |\n",
"| 1 | [5937, 0, 16471, 1, 0, 9397, 1, 1, 0, 2] | 2 | -1.047807 | 281.613137 | 236.678758 |\n",
"| 2 | [6123, 1, 1777, 1, 0, 9158, 0, 1, 1, 7] | 3 | -1.000114 | 339.803210 | 58.188634 |\n",
"| 3 | [5796, 1, 18368, 1, 1, 15679, 0, 1, 1, 6] | 4 | -0.928806 | 579.374068 | 239.570209 |\n",
"| 4 | [7159, 1, 1494, 0, 0, 9234, 0, 1, 1, 4] | 5 | -1.041694 | 611.826204 | 32.451197 |\n",
"| 5 | [4662, 0, 7258, 0, 0, 2419, 1, 1, 1, 5] | 6 | -1.070579 | 658.328865 | 46.501993 |\n",
"| 6 | [5898, 0, 12026, 1, 0, 4872, 1, 0, 1, 7] | 7 | -1.272499 | 710.819948 | 52.490447 |\n",
"| 7 | [6942, 1, 1192, 0, 0, 5577, 0, 1, 1, 5] | 8 | -1.075484 | 725.114781 | 14.294140 |\n",
"| 8 | [10451, 1, 12804, 1, 0, 5329, 1, 0, 1, 10] | 9 | -0.991375 | 796.194318 | 71.078856 |\n",
"| 9 | [6617, 1, 2963, 0, 0, 6786, 1, 1, 0, 3] | 10 | -0.999313 | 824.445577 | 28.250628 |\n", "