import os
from contextlib import contextmanager, ExitStack
from typing import Generator

from torch.distributed.elastic.multiprocessing.errors import record


__all__ = [
    "worker_main",
]

TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET"


@contextmanager
def _worker_server(socket_path: str) -> Generator[None, None, None]:
    """Run a ``_WorkerServer`` for the duration of the ``with`` body.

    The server is constructed from ``socket_path`` on entry and its
    ``shutdown()`` method is called on exit, whether the body finished
    normally or raised.

    Args:
        socket_path: Value passed straight to the ``_WorkerServer``
            constructor.

    Yields:
        None. The server object itself is not exposed to the caller.
    """
    # Function-scope import: the binding is only looked up when a server is
    # actually requested.
    from torch._C._distributed_c10d import _WorkerServer

    worker_server = _WorkerServer(socket_path)
    try:
        yield
    finally:
        worker_server.shutdown()


@contextmanager
@record
def worker_main() -> Generator[None, None, None]:
    """
    Context manager that wraps your main entry function.

    It combines the existing ``errors.record`` logic with a ``_WorkerServer``
    that exposes handlers via a unix socket whose path is read from the
    ``TORCH_WORKER_SERVER_SOCKET`` environment variable. When that variable is
    not set, no server is started and the body simply runs.

    NOTE(review): ``record`` decorates the generator function itself, so
    presumably it only wraps the call that creates the generator rather than
    the code executed inside the ``with`` body -- confirm that errors raised
    there are recorded as intended.

    Example

    ::

     @worker_main()
     def main():
         pass

     if __name__ == "__main__":
         main()

    """
    configured_socket = os.environ.get(TORCH_WORKER_SERVER_SOCKET)

    with ExitStack() as cleanup:
        # The server is optional; registering it on the stack lets a single
        # ``yield`` serve both the "with server" and "without server" cases.
        if configured_socket is not None:
            cleanup.enter_context(_worker_server(configured_socket))
        yield