Source code for deephyper.evaluator._mpi_comm
import asyncio
import functools
import logging
import traceback
from typing import Callable, Hashable

from deephyper.evaluator import Evaluator, Job, JobStatus
from deephyper.evaluator.mpi import MPI, MPICommExecutor
from deephyper.evaluator.storage import Storage
from deephyper.evaluator.storage._mpi_win_storage import MPIWinStorage

logger = logging.getLogger(__name__)


def catch_exception(run_func):
    """A wrapper function to execute the ``run_func`` passed by the user.

    This is used to catch remote exceptions.
    """
    try:
        code = 0
        result = run_func()
    except Exception:
        code = 1
        result = traceback.format_exc()
        print(f"{code=}, {result=}")
    return code, result
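

# Illustrative sketch (not from the original deephyper source): it shows how
# ``catch_exception`` converts a raising run-function into an error tuple
# ``(1, formatted_traceback)`` instead of letting the exception propagate on a
# remote rank. The names below are hypothetical.
def _example_catch_exception():
    def failing_run_function():
        raise ValueError("boom")

    code, result = catch_exception(failing_run_function)
    assert code == 1 and "ValueError: boom" in result
    return code, result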


class MPICommEvaluator(Evaluator):
    """This evaluator uses the ``mpi4py`` library as backend.

    This evaluator considers an already existing MPI context (with running
    processes), therefore it has less overhead than ``MPIPoolEvaluator``,
    which spawns processes dynamically.

    Args:
        run_function (callable):
            Function to be executed by the ``Evaluator``.
        num_workers (int, optional):
            Number of parallel workers used to compute the ``run_function``.
            Defaults to ``None``, which uses each MPI rank as a worker
            (minus the master rank).
        callbacks (list, optional):
            A list of callbacks to trigger custom actions at the creation or
            completion of jobs. Defaults to ``None``.
        run_function_kwargs (dict, optional):
            Keyword-arguments to pass to the ``run_function``. Defaults to ``None``.
        storage (Storage, optional):
            Storage used by the evaluator. Defaults to ``MPIWinStorage``.
        search_id (Hashable, optional):
            The id of the search to use in the corresponding storage. If
            ``None``, a new search identifier is created when the search is
            initialized.
        comm (optional):
            An MPI communicator. If ``None``, ``MPI.COMM_WORLD`` is used.
            Defaults to ``None``.
        root (int, optional):
            The rank of the master process. Defaults to ``0``.
    """

    def __init__(
        self,
        run_function: Callable,
        num_workers: int = None,
        callbacks=None,
        run_function_kwargs=None,
        storage: Storage = None,
        search_id: Hashable = None,
        comm=None,
        root=0,
    ):
        if not MPI.Is_initialized():
            MPI.Init_thread()

        self.comm = comm if comm else MPI.COMM_WORLD
        self.root = root

        if storage is None:
            logging.info(
                f"No storage was given to create {type(self).__name__} so using MPIWinStorage"
            )
            storage = MPIWinStorage(self.comm, root=self.root)

        if isinstance(storage, MPIWinStorage):
            if search_id is None:
                logging.info(
                    "No search_id was given and an MPIWinStorage is used. Creating new search."
                )
                if self.comm.Get_rank() == self.root:
                    search_id = storage.create_new_search()
            self.comm.Barrier()

        super().__init__(
            run_function=run_function,
            num_workers=num_workers,
            callbacks=callbacks,
            run_function_kwargs=run_function_kwargs,
            storage=storage,
            search_id=search_id,
        )

        self.num_workers = self.comm.Get_size() - 1  # 1 rank is the master
        self.sem = asyncio.Semaphore(self.num_workers)
        logging.info(f"Creating MPICommExecutor with {self.num_workers} max_workers...")

        if self.num_workers == 0 and self.comm.Get_size() <= 1:
            raise RuntimeError(
                "No workers were detected because there was only 1 MPI rank. The number of MPI "
                "ranks must be greater than 1."
            )

        self._comm_executor = None
        self._pool_executor = None
        logging.info("Creation of MPICommExecutor done")

    @property
    def is_master(self):
        return self.comm.Get_rank() == self.root

    def __enter__(self):
        self._comm_executor = MPICommExecutor(comm=self.comm, root=self.root)
        self._pool_executor = self._comm_executor.__enter__()
        return self

    def __exit__(self, type, value, traceback):
        if self.is_master:
            if self.loop is not None and not self.loop.is_closed():
                self.close()
            self._pool_executor.__exit__(type, value, traceback)
            self._pool_executor = None
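
    # Note (assumption based on mpi4py's documented ``MPICommExecutor``
    # behavior): only the ``root`` rank obtains a usable pool executor from
    # ``__enter__``; the remaining ranks act as workers that run the submitted
    # run-functions, so ``execute`` below is only driven from the master rank.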

    async def execute(self, job: Job) -> Job:
        async with self.sem:
            job.status = JobStatus.RUNNING

            running_job = job.create_running_job(self._stopper)

            run_function = functools.partial(
                job.run_function, running_job, **self.run_function_kwargs
            )

            run_function_future = self.loop.run_in_executor(self._pool_executor, run_function)

            if self.timeout is not None:
                try:
                    # ``shield`` keeps the remote task alive when ``wait_for``
                    # times out: the job is then marked CANCELLING and awaited
                    # until the run-function actually returns.
                    output = await asyncio.wait_for(
                        asyncio.shield(run_function_future), timeout=self.time_left
                    )
                except asyncio.TimeoutError:
                    job.status = JobStatus.CANCELLING
                    output = await run_function_future
                    job.status = JobStatus.CANCELLED
            else:
                output = await run_function_future

            return self._update_job_when_done(job, output)
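

# Illustrative usage sketch (not from the original deephyper source). It
# assumes the script is launched under MPI, e.g. ``mpiexec -np 4 python
# script.py``; ``_example_run_function`` is a hypothetical run-function over
# numeric parameters.
def _example_usage():
    def _example_run_function(job):
        # A run-function receives a job whose ``parameters`` holds the
        # suggested configuration (assumed numeric here) and returns a score.
        return -sum(x**2 for x in job.parameters.values())

    with MPICommEvaluator(_example_run_function) as evaluator:
        if evaluator.is_master:
            # Only the master rank drives the search, e.g. by handing
            # ``evaluator`` to a DeepHyper search object here.
            pass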