Source code for deephyper.core.utils._timeout

import logging
import multiprocessing
import multiprocessing.pool

from deephyper.core.exceptions import TimeoutReached


def terminate_on_timeout(timeout, func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` in a worker thread and bound the wait time.

    The call is submitted to a single-worker thread pool and the caller waits
    at most ``timeout`` seconds for the result.

    Example:
        >>> import functools
        >>> f_timeout = functools.partial(terminate_on_timeout, 10, f)
        >>> f_timeout(1, b=2)

    Note:
        Python threads cannot be forcibly stopped. When the timeout expires
        the caller is released and the pool is torn down, but ``func`` may
        keep running in its background thread until it returns by itself.

    Args:
        timeout (float): maximum number of seconds to wait for the result.
            ``None`` waits indefinitely.
        func (function): function to call.
        *args: positional arguments to pass to the function.
        **kwargs: keyword arguments to pass to the function.

    Returns:
        The value returned by ``func(*args, **kwargs)``.

    Raises:
        TimeoutReached: if no result is available after ``timeout`` seconds.
        Exception: any exception raised by ``func`` is re-raised to the caller.
    """
    # The context manager calls ``pool.terminate()`` on exit, so the pool's
    # helper threads are released on every path (result, timeout or error).
    with multiprocessing.pool.ThreadPool(processes=1) as pool:
        results = pool.apply_async(func, args, kwargs)
        # No further task will be submitted to this one-shot pool.
        pool.close()
        try:
            return results.get(timeout)
        except multiprocessing.TimeoutError as err:
            msg = f"Search timeout expired after {timeout} sec."
            logging.warning(msg)
            # Chain explicitly so the traceback reports the timeout as the
            # cause rather than as an error raised while handling another one.
            raise TimeoutReached(msg) from err