from __future__ import annotations

from typing import Dict, List, TYPE_CHECKING

from .quantizer import QuantizationAnnotation, Quantizer

if TYPE_CHECKING:
    import torch
    from torch.fx import Node

__all__ = [
    "ComposableQuantizer",
]


class ComposableQuantizer(Quantizer):
    """
    ComposableQuantizer allows users to combine more than one quantizer into a
    single quantizer. This allows users to quantize a model with multiple
    quantizers. E.g., embedding quantization may be supported by one quantizer
    while linear layers and other ops might be supported by another quantizer.

    ComposableQuantizer is initialized with a list of `Quantizer` instances.
    The order of the composition matters since that is the order in which the
    quantizers will be applied. A later quantizer may annotate nodes that
    earlier quantizers left unannotated, but replacing or removing an
    annotation placed by an earlier quantizer raises ``RuntimeError`` during
    ``annotate``.

    Example:
    ```
    embedding_quantizer = EmbeddingQuantizer()
    linear_quantizer = MyLinearQuantizer()
    xnnpack_quantizer = XNNPackQuantizer()  # to handle ops not quantized by previous two quantizers
    composed_quantizer = ComposableQuantizer([embedding_quantizer, linear_quantizer, xnnpack_quantizer])
    prepared_m = prepare_pt2e(model, composed_quantizer)
    ```
    """

    def __init__(self, quantizers: List[Quantizer]):
        super().__init__()
        self.quantizers = quantizers
        # Node -> the annotation object recorded for it during the current
        # ``annotate`` call. Used to detect a later quantizer replacing or
        # deleting an annotation that an earlier quantizer placed.
        self._graph_annotations: Dict[Node, QuantizationAnnotation] = {}

    def _record_and_validate_annotations(
        self, gm: torch.fx.GraphModule, quantizer: Quantizer
    ) -> None:
        """Record new annotations on ``gm`` and check previously recorded ones.

        Args:
            gm: the graph module that ``quantizer`` has just annotated.
            quantizer: the quantizer that ran last; only used in error messages.

        Raises:
            RuntimeError: if a node's recorded annotation was replaced by a
                different object, or was removed from ``node.meta``.
        """
        for n in gm.graph.nodes:
            if "quantization_annotation" not in n.meta:
                # Not annotated now: it is only an error if it was before.
                if n in self._graph_annotations:
                    raise RuntimeError(
                        f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {n}"
                    )
                continue

            annotation = n.meta["quantization_annotation"]
            # Change detection is by object identity only: a quantizer that
            # mutates the recorded annotation object in place is not caught.
            if (
                n in self._graph_annotations
                and self._graph_annotations[n] is not annotation
            ):
                raise RuntimeError(
                    f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}"
                )
            self._graph_annotations[n] = annotation

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Run each composed quantizer's ``annotate`` on ``model``, in order.

        After every quantizer runs, the annotations on the graph are recorded
        and validated so that no quantizer overrides or removes an annotation
        made by an earlier one.

        Args:
            model: the graph module to annotate.

        Returns:
            The same ``model`` object that was passed in. The return values of
            the sub-quantizers' ``annotate`` calls are ignored, so they are
            expected to annotate ``model`` in place.

        Raises:
            RuntimeError: if a quantizer changes or removes an annotation that
                an earlier quantizer in the list placed.
        """
        # Start every call from a clean slate. Bookkeeping left over from a
        # previous call would keep that call's graph nodes alive. It would also
        # make re-annotating the same model fail spuriously whenever a
        # quantizer recreates its annotation objects.
        self._graph_annotations.clear()
        for quantizer in self.quantizers:
            quantizer.annotate(model)
            self._record_and_validate_annotations(model, quantizer)
        return model

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        """Apply each quantizer's ``transform_for_annotation`` in order.

        The output of one quantizer's transform is fed into the next one.

        Returns:
            The graph module produced by the last quantizer, or ``model``
            unchanged when ``self.quantizers`` is empty.
        """
        for quantizer in self.quantizers:
            model = quantizer.transform_for_annotation(model)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        """Intentionally a no-op.

        The composed quantizers' own ``validate`` hooks are not invoked here.
        """
        pass