# mypy: allow-untyped-defs from typing import List, Optional import torch import torch.distributed.distributed_c10d as c10d """ This file contains the op impls for the legacy (c10d_functional) functional collectives. These impls simply call into the native (_c10d_functional) functional collectives. """ def _broadcast(input, src, tag, ranks, group_size): group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) return torch.ops._c10d_functional.broadcast( input, src, group_name, ) def _all_reduce(input, reduce_op, tag, ranks, group_size): group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) return torch.ops._c10d_functional.all_reduce( input, reduce_op, group_name, ) def _all_reduce_coalesced(inputs, reduce_op, tag, ranks, group_size): group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) return torch.ops._c10d_functional.all_reduce_coalesced( inputs, reduce_op, group_name, ) def _all_gather_into_tensor(input, tag, ranks, group_size): group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) return torch.ops._c10d_functional.all_gather_into_tensor( input, group_size, group_name, ) def _all_gather_into_tensor_coalesced(input, tag, ranks, group_size): group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) return torch.ops._c10d_functional.all_gather_into_tensor_coalesced( input, group_size, group_name, ) def _reduce_scatter_tensor( input: torch.Tensor, reduce_op: str, tag: str, ranks: List[int], group_size: int, ): group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) return torch.ops._c10d_functional.reduce_scatter_tensor( input, reduce_op, group_size, group_name, ) def _reduce_scatter_tensor_coalesced( inputs: List[torch.Tensor], reduce_op: str, tag: str, ranks: List[int], group_size: int, ): group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) return torch.ops._c10d_functional.reduce_scatter_tensor_coalesced( inputs, reduce_op, group_size, group_name, ) def _all_to_all_single( input: torch.Tensor, output_split_sizes: Optional[List[int]], input_split_sizes: Optional[List[int]], tag: str, ranks: List[int], group_size: int, ): if output_split_sizes is None or input_split_sizes is None: assert output_split_sizes is None and input_split_sizes is None, ( "output_split_sizes and input_split_sizes must either be " "specified together or both set to None" ) output_split_sizes = [input.shape[0] // group_size] * group_size input_split_sizes = output_split_sizes group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) return torch.ops._c10d_functional.all_to_all_single( input, output_split_sizes, input_split_sizes, group_name, ) def _wait_tensor(tensor: torch.Tensor) -> torch.Tensor: return torch.ops._c10d_functional.wait_tensor(tensor)