# mypy: allow-untyped-defs

# This module provides a FAST (on GPU) content addressable store for storages
# (and tensors on top of them) with VERY WEAK portability guarantees (e.g.,
# don't expect CPU/CUDA to address to the same hash, don't expect it to be
# portable across devices) that is NOT cryptographically secure.  In return,
# we are able to hash 40G of tensor data on GPU in less than a second,
# compared to running SHA-1 on CPU, which would take a minute or so.  The
# primary use case is for efficiently snapshotting intermediate tensor data
# for offline debugging, but it's been put in this module in case you think of
# another use case for it.  The hash function could be replaced with a
# straight reimplementation of SHA-1, which would give us much stronger
# portability guarantees.
#
# WARNING: THERE IS NO BC/FC GUARANTEE FOR THIS FORMAT!  If you need to format
# shift the result, consider packing it into a single torch.save object
# with traditional view sharing.
#
# Because of the weak portability guarantees, you can only write to the
# content store from a single process; we don't provide any capability
# of "reopening" a content store to add more things to it.  But we don't
# assume that you can keep all of the tensors you want to add to the store
# in memory at once, because you probably can't!  Nor do we assume that
# you know a priori whether or not two storages can be deduplicated.
#
# Note: only storages are content-addressed; tensors are name-addressed
#
# Note: our padding strategy means that [1, 0] and [1] int16 tensors would
# map to the same (padded) storage.  We think this will be immaterial for most
# users.

import ctypes
import functools
import hashlib
import os.path
import struct
from collections import defaultdict
from typing import Dict, Optional, Set

import torch
import torch._prims as prims
import torch._utils
import torch.nn.functional as F
from torch._C import default_generator
from torch.multiprocessing.reductions import StorageWeakRef


def lazy_compile(**compile_kwargs):
    """Lazily wrap a function with torch.compile on the first call

    This avoids eagerly importing dynamo.
    """

    def decorate_fn(fn):
        @functools.wraps(fn)
        def compile_hook(*args, **kwargs):
            compiled_fn = torch.compile(fn, **compile_kwargs)
            globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
            return compiled_fn(*args, **kwargs)

        return compile_hook

    return decorate_fn


# Use of torch.compile is mandatory for (1) good memory usage
# and (2) xor_sum implementation.  This is our first instance of
# using PT2 to implement a kernel in PyTorch; if we get AOT capabilities
# it would be good to apply it here.
@lazy_compile(dynamic=True)
def hash_storage_kernel(x):
    # The randint calls are carefully written to hit things we
    # have lowerings for in inductor.  Lack of unsigned 32-bit integer
    # is a pain.
    a = torch.randint(
        -(2**31), 2**31, x.shape, device=x.device, dtype=torch.int32
    ).abs()
    a = ((a % (2**31 - 1)) + 1).long()
    b = (
        torch.randint(-(2**31), 2**31, x.shape, device=x.device, dtype=torch.int32)
        .abs()
        .long()
    )
    # This is a standard shift-multiply universal hash family
    # plus xor sum hash, using Philox to generate random numbers.
    # Our Philox RNG is not deterministic across devices, so
    # don't use this for stable hashing.
    #
    # This assumes fixed length, so you're also obligated to bucket
    # by the length of the tensor as well.
    return prims.xor_sum((a * x + b).int(), [0])
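

# For orientation, here is a minimal, non-compiled sketch of the same hashing
# scheme.  It is illustrative only: the name reference_hash_kernel is made up,
# the random parameters come from ordinary torch.randint rather than the
# per-device Philox generator, and the xor reduction uses a plain Python fold
# instead of prims.xor_sum, so it is not bit-identical to hash_storage_kernel.
#
#   import functools
#   import operator
#
#   def reference_hash_kernel(x: torch.Tensor) -> int:
#       # x is the padded storage viewed as int32, as prepared by hash_storage
#       a = torch.randint(1, 2**31, x.shape, device=x.device, dtype=torch.int64)
#       b = torch.randint(0, 2**31, x.shape, device=x.device, dtype=torch.int64)
#       # shift-multiply universal hash of each 32-bit word...
#       h = (a * x + b).int()
#       # ...then xor-reduce down to a single 32-bit value
#       return functools.reduce(operator.xor, h.tolist(), 0)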


# Returns a hex digest of the data in the storage.  Guaranteed to be
# SHA-1 if stable_hash=True, otherwise it will be consistent for a single
# process run but not necessarily across processes.
def hash_storage(storage: torch.UntypedStorage, *, stable_hash: bool = False) -> str:
    import torch._dynamo
    from torch._dynamo.utils import is_compile_supported

    device_type = storage.device.type
    if stable_hash or not is_compile_supported(device_type):
        cpu_storage = storage.cpu()
        # TODO: make storage support buffer protocol so this isn't
        # necessary
        buf = (ctypes.c_byte * cpu_storage.nbytes()).from_address(
            cpu_storage.data_ptr()
        )
        sha1 = hashlib.sha1()
        sha1.update(buf)
        return sha1.hexdigest()

    # TODO: factor this into a random utility
    if device_type == "cpu":
        generator = default_generator
    elif device_type == "cuda":
        import torch.cuda

        generator = torch.cuda.default_generators[storage.device.index]
    else:
        raise AssertionError(f"unhandled device type {device_type}")
    state = generator.get_state()
    try:
        generator.manual_seed(0)
        x = torch.empty(0, dtype=torch.uint8, device=storage.device).set_(storage)  # type: ignore[call-overload]
        # The dtype-casting view cannot be compiled, and so the
        # padding/reshaping also needs to be done externally even
        # though it could be profitably fused
        pad = -x.numel() % 4
        if pad > 0:
            x = F.pad(x, (0, pad), "constant", 0)
        x = x.view(torch.int32)
        # We run the 32-bit hash five times with differing parameters to
        # reduce the chance of collision
        ITER = 5
        cs = [hash_storage_kernel(x).item() for _ in range(ITER)]
        return struct.pack(">" + "i" * ITER, *cs).hex()
    finally:
        generator.set_state(state)


class ContentStoreWriter:
    # Structure:
    #   storages/
    #     00/
    #       0000..00
    #   tensors/
    #     name
    def __init__(self, loc: str, stable_hash: bool = False) -> None:
        self.loc: str = loc
        self.seen_storage_hashes: Set[str] = set()
        self.stable_hash = stable_hash

    # TODO: offer some sort of non-blocking API to speed things up
    def write_storage(self, storage: torch.UntypedStorage) -> str:
        h = hash_storage(storage, stable_hash=self.stable_hash)
        if h in self.seen_storage_hashes:
            return h
        # TODO: consider not using torch.save for this; we don't actually
        # need any metadata for the storage
        subfolder = os.path.join(self.loc, "storages")
        os.makedirs(subfolder, exist_ok=True)
        target = os.path.join(subfolder, h)
        if os.path.exists(target):
            return h
        torch.save(storage, target)
        self.seen_storage_hashes.add(h)
        return h

    def compute_tensor_metadata(self, t: torch.Tensor, h=None):
        if h is None:
            h = hash_storage(t.untyped_storage(), stable_hash=self.stable_hash)
        return (
            t.dtype,
            h,
            t.storage_offset(),
            tuple(t.shape),
            t.stride(),
            torch._utils.get_tensor_metadata(t),
        )

    def write_tensor(self, name: str, t: torch.Tensor) -> None:
        storage = t.untyped_storage()
        h = self.write_storage(storage)
        # TODO: Support more advanced snapshotting of requires_grad/grad/etc
        d, f = os.path.split(name)
        payload = self.compute_tensor_metadata(t, h=h)
        subfolder = os.path.join(self.loc, "tensors", d)
        os.makedirs(subfolder, exist_ok=True)
        torch.save(payload, os.path.join(subfolder, f))
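

# Illustrative write-side usage (the store location and tensor names below are
# made up):
#
#   writer = ContentStoreWriter("/tmp/debug_snapshot")
#   t = torch.randn(16, 16)
#   writer.write_tensor("step0/activation", t)
#   # A view shares the underlying storage, so only tensor metadata is written
#   # the second time; the storage itself is deduplicated by its content hash.
#   writer.write_tensor("step0/activation_view", t[:4])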


class ContentStoreReader:
    def __init__(self, loc: str, *, cache=True) -> None:
        self.loc = loc
        self.storage_cache: Optional[
            Dict[Optional[torch.device], Dict[str, StorageWeakRef]]
        ] = None
        if cache:
            self.storage_cache = defaultdict(dict)

    def read_storage(self, h: str, *, device=None) -> torch.UntypedStorage:
        if device is not None:
            device = torch.device(device)
        ws = (
            self.storage_cache[device].get(h)
            if self.storage_cache is not None
            else None
        )
        s: Optional[torch.UntypedStorage]
        if ws is not None:
            s = torch.UntypedStorage._new_with_weak_ptr(ws.cdata)
            if s is not None:
                return s
        s = torch.load(
            os.path.join(self.loc, "storages", h),
            weights_only=True,
            map_location=device,
        )._untyped_storage
        assert s is not None
        if self.storage_cache is not None:
            self.storage_cache[device][h] = StorageWeakRef(s)
        return s

    def read_tensor_metadata(self, name: str):
        fn = os.path.join(self.loc, "tensors", name)
        if not os.path.exists(fn):
            raise FileNotFoundError(fn)
        return torch.load(fn, weights_only=True)

    def read_tensor(self, name: str, *, device=None) -> torch.Tensor:
        dtype, h, storage_offset, size, stride, metadata = self.read_tensor_metadata(
            name
        )
        storage = self.read_storage(h, device=device)
        t = torch.tensor([], dtype=dtype, device=storage.device)
        t.set_(storage, storage_offset, size, stride)
        torch._utils.set_tensor_metadata(t, metadata)
        return t
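

# Illustrative read-side usage, matching the hypothetical writer example above:
#
#   reader = ContentStoreReader("/tmp/debug_snapshot")
#   t = reader.read_tensor("step0/activation", device="cpu")
#   v = reader.read_tensor("step0/activation_view", device="cpu")
#
# With the default cache=True, storages are memoized per device via weak
# references, so tensors written against the same storage come back aliasing
# the same storage in memory, as long as it stays alive.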