# mypy: allow-untyped-defs from dataclasses import dataclass from functools import reduce from typing import List, Optional, Union from torch.distributed.remote_device import _remote_device @dataclass class ShardMetadata: """ Represents a shard of the overall Tensor including its offsets, lengths and device placement. Args: shard_offsets(List[int]): Offsets in the original tensor indicating the start offsets for this shard. Should have the same rank as the original tensor. shard_sizes(List[int]): Integers indicating the size of each dimension for this shard. Should have the same rank as the original tensor. placement(:class:`torch.distributed._remote_device`): Specifies the placement of this shard. """ __slots__ = ["shard_offsets", "shard_sizes", "placement"] shard_offsets: List[int] shard_sizes: List[int] placement: Optional[_remote_device] def __init__( self, shard_offsets: List[int], shard_sizes: List[int], placement: Optional[Union[str, _remote_device]] = None, ): self.shard_offsets = shard_offsets self.shard_sizes = shard_sizes if isinstance(placement, str): self.placement = _remote_device(placement) else: self.placement = placement if len(self.shard_offsets) != len(self.shard_sizes): raise ValueError( f"shard_offsets and shard_sizes should have " f"the same number of elements, found {len(self.shard_offsets)} " f"and {self.shard_sizes} respectively" ) for i in range(len(self.shard_offsets)): if self.shard_offsets[i] < 0: raise ValueError("shard_offsets should be >=0") if self.shard_sizes[i] < 0: raise ValueError("shard_sizes should be >= 0") def __hash__(self): def _hash_reduce(a, b): return (a << 8) + hash(b) res = reduce(_hash_reduce, self.shard_offsets, 37) res = reduce(_hash_reduce, self.shard_sizes, res) res = _hash_reduce(res, self.placement) return res