Source code for deephyper.evaluator.storage._mpi_win_mutable_mapping

import copy
import logging
import pickle
from collections.abc import MutableMapping
from typing import Hashable

import numpy as np

from deephyper.evaluator.mpi import MPI

# A good reference about one-sided communication with MPI
# https://enccs.github.io/intermediate-mpi/one-sided-concepts/


class MPIWinMutableMapping(MutableMapping):
    """Dict like object shared between MPI processes using one-sided communication.

    The full dictionary state is pickled into a single MPI window hosted on
    ``root``; every access re-reads (and every mutation re-writes) the whole
    pickle under an exclusive window lock.

    Args:
        default_value (dict): The default value of the mutable mapping at
            initialization. Defaults to ``None`` for empty dict.
        comm (MPI.Comm): An MPI communicator.
        size (int): The total size of the shared memory in bytes. Defaults to
            ``104857600`` for 100MB.
        root (int): The MPI rank where the shared memory window is hosted.
    """

    HEADER_SIZE = 8  # Reserve 8 bytes for size header

    # Use to share state when pickling arguments of function
    COUNTER = 0  # Counter of created instances
    CACHE = {}

    def __init__(
        self,
        default_value: dict = None,
        comm: MPI.Comm = MPI.COMM_WORLD,
        size: int = 104857600,
        root: int = 0,
    ):
        logging.info("Creating MPIWinMutableMapping ...")
        self.comm = comm
        self.root = root
        self.locked = False
        self._session_is_started = False
        self._session_is_read_only = False

        # Allocate memory (works on multiple nodes)
        logging.info("Allocating MPI.Win ...")

        # If all processes are in the local shared comm then we
        # can use shared memory
        local_comm = self.comm.Split_type(MPI.COMM_TYPE_SHARED)
        if local_comm.Get_size() == self.comm.Get_size():
            logging.info("Using MPI.Win.Allocate_shared")
            self.win = MPI.Win.Allocate_shared(size, 1, comm=comm)
            # Map the root rank's shared segment directly into a numpy view.
            buf, itemsize = self.win.Shared_query(self.root)
            self.shared_memory = np.ndarray(buffer=buf, dtype=np.byte, shape=(size,))
        else:
            logging.info("Using MPI.Win.Allocate")
            self.win = MPI.Win.Allocate(size, 1, comm=comm)
            # No shared segment across nodes: keep a local staging buffer and
            # move data with Get/Put instead.
            self.shared_memory = np.empty((size,), dtype=np.byte)
        logging.info("MPI.Win allocated")

        if default_value is None:
            self.local_dict = {}
        else:
            self.local_dict = copy.deepcopy(default_value)

        # Register the instance so it can be recovered by id when pickled.
        self._cache_id = MPIWinMutableMapping.COUNTER
        MPIWinMutableMapping.COUNTER += 1
        MPIWinMutableMapping.CACHE[self._cache_id] = self

        # NOTE(review): only the root rank publishes the initial state;
        # a default_value passed on non-root ranks is never written.
        if self.comm.Get_rank() == self.root:
            self.lock()
            self._write_dict()
            self.unlock()

        self.comm.Barrier()  # Synchronize processes
        logging.info("MPIWinMutableMapping created")

    def _lazy_read_dict(self):
        """Performs the read if not in a session."""
        if not self._session_is_started:
            self._read_dict()

    def _read_dict(self):
        """Read the dictionary state from the shared memory."""
        # Deserialize the dictionary from shared memory
        try:
            # Copy the whole window from the root rank into the local buffer.
            self.win.Get(self.shared_memory, target_rank=self.root)
            self.win.Flush(self.root)
            # First HEADER_SIZE bytes hold the pickle length (big-endian).
            size = int.from_bytes(self.shared_memory[: self.HEADER_SIZE], byteorder="big")
            if size > 0:
                raw_data = self.shared_memory[self.HEADER_SIZE : self.HEADER_SIZE + size].tobytes()
                self.local_dict = pickle.loads(raw_data)
            else:
                self.local_dict = {}
        except Exception as e:
            # Best-effort: any failure (e.g., corrupt/truncated pickle)
            # degrades to an empty mapping rather than crashing the caller.
            logging.error(f"Error reading shared memory: {e}")
            self.local_dict = {}

    def _lazy_write_dict(self):
        """Performs the write if not in a session."""
        if not self._session_is_started:
            self._write_dict()

    def _write_dict(self):
        """Write the dictionary state to the shared memory."""
        # Serialize the dictionary to shared memory
        serialized = pickle.dumps(self.local_dict)
        size = len(serialized)
        if size + self.HEADER_SIZE > self.shared_memory.size:
            raise ValueError("Shared memory is too small for the dictionary.")
        # Layout: [8-byte big-endian length][pickle payload][zero padding].
        self.shared_memory[: self.HEADER_SIZE] = np.frombuffer(
            size.to_bytes(self.HEADER_SIZE, byteorder="big"), dtype=np.byte
        )
        self.shared_memory[self.HEADER_SIZE : self.HEADER_SIZE + size] = np.frombuffer(
            serialized, dtype=np.byte
        )
        self.shared_memory[self.HEADER_SIZE + size :] = 0
        # Push the full buffer back to the root rank's window.
        self.win.Put(self.shared_memory, target_rank=self.root)
        self.win.Flush(self.root)

    def __getitem__(self, key):
        # Refresh the local copy under the lock, then read locally.
        self.lock()
        self._lazy_read_dict()
        self.unlock()
        return self.local_dict[key]

    def __setitem__(self, key, value):
        # Read-modify-write of the whole dict under a single lock.
        self.lock()
        self._lazy_read_dict()
        self.local_dict[key] = value
        self._lazy_write_dict()
        self.unlock()

    def __delitem__(self, key):
        self.lock()
        self._lazy_read_dict()
        del self.local_dict[key]
        self._lazy_write_dict()
        self.unlock()

    def __iter__(self):
        # Iterates over a snapshot: the local copy taken at call time.
        self.lock()
        self._lazy_read_dict()
        self.unlock()
        return iter(self.local_dict)

    def __len__(self):
        self.lock()
        self._lazy_read_dict()
        self.unlock()
        return len(self.local_dict)

    def __repr__(self):
        self.lock()
        self._lazy_read_dict()
        self.unlock()
        return repr(self.local_dict)
[docs] def __call__(self, read_only: bool = False): self._session_is_read_only = read_only return self
def __enter__(self): self.session_start() return self def __exit__(self, type, value, traceback): self.session_finish()
[docs] def lock(self): """Acquire the lock. Blocking operation.""" if not self.locked: self.win.Lock(self.root) self.locked = True
[docs] def unlock(self): """Release the lock.""" if self.locked and not self._session_is_started: self.win.Unlock(self.root) self.locked = False
def session_start(self, read_only: bool = False): if self._session_is_started: raise RuntimeError("A session has already been started without being finished!") self._session_is_started = True self._session_is_read_only = read_only self.lock() self._read_dict() def session_finish(self): assert self.locked if not self._session_is_started: raise RuntimeError("No session has been started!") if not self._session_is_read_only: self._write_dict() self._session_is_started = False self._session_is_read_only = True self.unlock() # This can create a deadlock if not called by all processes! def __del__(self): self.win.Free() self.CACHE.pop(self._cache_id)
[docs] def incr(self, key: Hashable, amount=1): """Atomic operator that increments and returns the resulting value.""" keys = key.split(".") assert len(keys) > 0 # Case where the key is at the root of the mapping if len(keys) == 1: self.lock() self._lazy_read_dict() self.local_dict[key] += amount self._lazy_write_dict() self.unlock() return self.local_dict[key] # Case where the key is JSON path of type "key0.key1.key2" else: self.lock() self._lazy_read_dict() mapping = self.local_dict for key in keys[:-1]: mapping = mapping[key] key = keys[-1] mapping[key] += amount self._lazy_write_dict() self.unlock() return mapping[key]