# Source code for deephyper.nas._nx_search_space
import abc
import traceback
from collections.abc import Iterable
import networkx as nx
from deephyper.core.exceptions.nas.space import (
NodeAlreadyAdded,
StructureHasACycle,
WrongSequenceToSetOperations,
)
from deephyper.nas.node import MimeNode, Node, VariableNode
class NxSearchSpace(abc.ABC):
    """A NxSearchSpace is a search space based on a networkx directed graph.

    The nodes of ``self.graph`` are :class:`Node` instances and an edge
    ``node1 -> node2`` means that the output of ``node1`` feeds ``node2``.
    Subclasses must implement :meth:`choices`, :meth:`sample` and :meth:`build`.

    Args:
        seed (int, optional): seed forwarded to ``Node.create_tensor`` when the
            tensors of the graph are built. Defaults to None.
        **kwargs: accepted for subclass compatibility; unused here.
    """

    def __init__(self, seed=None, **kwargs):
        self.graph = nx.DiGraph()
        self.seed = seed
        # Filled by ``set_output_node``: a single node when the graph has one
        # sink, otherwise a list of the sink nodes (possibly empty).
        self.output_node = None

    def plot(self, path):
        """Write the graph of the search space to ``path`` in Graphviz DOT format.

        Any exception raised while writing is not propagated: an error message
        is printed and the traceback is written with ``traceback.print_exc()``.
        The file is opened (and truncated) before writing is attempted.

        Args:
            path (str): path of the DOT file to create.
        """
        with open(path, "w") as f:
            try:
                nx.nx_agraph.write_dot(self.graph, f)
            except Exception:
                print("Error: can't create graphviz file...")
                traceback.print_exc()

    def __len__(self):
        """Number of nodes, of any kind, in the current search space.

        Use :attr:`num_nodes` to count only the :class:`VariableNode` instances.

        Returns:
            int: number of nodes in the graph.
        """
        return len(self.nodes)

    @property
    def nodes(self):
        """Nodes of the current search space.

        Returns:
            list: a new list containing every node of the graph.
        """
        return list(self.graph.nodes)

    def add_node(self, node):
        """Add a new node to the search space.

        Args:
            node (Node): node to add to the search space.

        Raises:
            TypeError: if 'node' is not an instance of Node.
            NodeAlreadyAdded: if 'node' has already been added to the search space.
        """
        if not isinstance(node, Node):
            raise TypeError("'node' argument should be an instance of Node!")
        # Test membership on the graph directly (hash lookup) instead of
        # building and scanning the full list returned by ``self.nodes``.
        if node in self.graph:
            raise NodeAlreadyAdded(node)
        self.graph.add_node(node)

    def connect(self, node1, node2):
        """Create a new connection in the search space graph.

        The edge created corresponds to: ``node1 -> node2``. Endpoints that are
        not yet in the graph are added by ``nx.DiGraph.add_edge``.

        Args:
            node1 (Node): source of the edge.
            node2 (Node): destination of the edge.

        Raises:
            StructureHasACycle: if the new edge creates a cycle. The edge is
                removed again before raising, so the graph stays acyclic.
        """
        assert isinstance(node1, Node)
        assert isinstance(node2, Node)
        edge_existed = self.graph.has_edge(node1, node2)
        self.graph.add_edge(node1, node2)
        if not nx.is_directed_acyclic_graph(self.graph):
            # Roll back the offending edge so that a caught StructureHasACycle
            # does not leave the search space permanently cyclic.
            if not edge_existed:
                self.graph.remove_edge(node1, node2)
            raise StructureHasACycle(
                f"the connection between {node1} -> {node2} is creating a cycle in the search_space's graph."
            )

    @property
    def size(self):
        """Size of the search space defined by its VariableNodes.

        Computed as the product of ``num_ops`` over every VariableNode, ignoring
        the nodes that have zero operations.

        Returns:
            int: the product, or 0 when no VariableNode has any operation.
        """
        op_counts = [
            vnode.num_ops for vnode in self.variable_nodes if vnode.num_ops != 0
        ]
        if not op_counts:
            return 0
        total = 1
        for count in op_counts:
            total *= count
        return total

    @property
    def max_num_ops(self):
        """Maximum number of operations across all VariableNodes of the search space.

        Returns:
            int: maximum number of operations for a VariableNode in the current search space.

        Raises:
            ValueError: if the search space contains no VariableNode (``max`` of an empty sequence).
        """
        return max(map(lambda n: n.num_ops, self.variable_nodes))

    @property
    def num_nodes(self):
        """Number of VariableNodes in the current search space.

        Returns:
            int: number of VariableNodes in the current search space.
        """
        return len(list(self.variable_nodes))

    @property
    def variable_nodes(self):
        """Iterator over the VariableNodes of the search space.

        Returns:
            (Iterator(VariableNode)): a new lazy ``filter`` over the VariableNodes
            of the search space; it can be consumed only once.
        """
        return filter(lambda n: isinstance(n, VariableNode), self.nodes)

    @property
    def mime_nodes(self):
        """Iterator over the MimeNodes of the search space.

        Returns:
            (Iterator(MimeNode)): a new lazy ``filter`` over the MimeNodes of the
            search space; it can be consumed only once.
        """
        return filter(lambda n: isinstance(n, MimeNode), self.nodes)

    def denormalize(self, indexes):
        """Denormalize a sequence of normalized indexes to get a sequence of absolute indexes.

        Useful when you want to compare the number of different search spaces.

        Args:
            indexes (Iterable): a sequence of normalized indexes, one per
                VariableNode. Any iterable (including generators) is accepted.

        Returns:
            list: A list of absolute indexes corresponding to the operations chosen
            with the relative indexes of `indexes`.

        Raises:
            WrongSequenceToSetOperations: if the number of indexes differs from
                the number of VariableNodes.
        """
        assert isinstance(
            indexes, Iterable
        ), 'Wrong argument, "indexes" should be of Iterable.'
        # ``Iterable`` does not guarantee ``len()`` (e.g. generators), so
        # materialize the input once before checking its length.
        indexes = list(indexes)
        variable_nodes = list(self.variable_nodes)
        if len(indexes) != len(variable_nodes):
            raise WrongSequenceToSetOperations(indexes, variable_nodes)
        return [
            vnode.denormalize(op_i) for op_i, vnode in zip(indexes, variable_nodes)
        ]

    def set_output_node(self):
        """Set the output node of the search space.

        The output node is computed only once (while ``self.output_node`` is
        None). It is made of the sink nodes of the graph, i.e. nodes without
        successors. It is unwrapped to the node itself when there is exactly one
        sink; otherwise the list of sinks is stored.

        :meta private:
        """
        if self.output_node is None:
            sinks = [n for n in self.graph.nodes() if self.graph.out_degree(n) == 0]
            self.output_node = sinks[0] if len(sinks) == 1 else sinks

    def create_tensor_aux(self, g, n, train=None):
        """Recursive function to create the tensors from the graph.

        A node whose ``_tensor`` is already set is returned as is. Otherwise the
        tensors of its predecessors are built first, recursively. List results
        are flattened, and the combined list is passed to ``n.create_tensor``.
        A node without predecessors is built with no input tensors.

        :meta private:

        Args:
            g (nx.DiGraph): a graph
            n (Node): a node of ``g``
            train (bool): True if the network is built for training, False if the
                network is built for validation/testing (for example False will
                deactivate Dropout).

        Returns:
            the tensor represented by n.

        Raises:
            RuntimeError: if a TypeError occurs while building the tensor of ``n``;
                the original TypeError is chained as the cause.
        """
        try:
            if n._tensor is not None:
                return n._tensor
            pred = list(g.predecessors(n))
            if not pred:
                return n.create_tensor(train=train, seed=self.seed)
            tensor_list = []
            for s_i in pred:
                tmp = self.create_tensor_aux(g, s_i, train=train)
                # A predecessor may produce several tensors: flatten them.
                if isinstance(tmp, list):
                    tensor_list.extend(tmp)
                else:
                    tensor_list.append(tmp)
            return n.create_tensor(tensor_list, train=train, seed=self.seed)
        except TypeError as err:
            raise RuntimeError(f"Failed to build tensors from :{n}") from err

    @abc.abstractmethod
    def choices(self):
        """Gives the possible choices for each decision variable of the search space.

        Returns:
            list: A list of tuple where each element corresponds to a discrete variable represented by ``(low, high)``.
        """

    @abc.abstractmethod
    def sample(self, choice=None):
        """Sample a ``tf.keras.Model`` from the search space.

        Args:
            choice (list, optional): A list of decision for the operations of this search space. Defaults to None, will generate a random sample.

        Returns:
            tf.keras.Model: A Tensorflow Keras model.
        """

    @abc.abstractmethod
    def build(self):
        """Build the current graph search space."""