""" We will recreate all the RNN modules as we require the modules to be decomposed into its building blocks to be able to observe. """ # mypy: allow-untyped-defs import numbers import warnings from typing import Optional, Tuple import torch from torch import Tensor __all__ = ["LSTMCell", "LSTM"] class LSTMCell(torch.nn.Module): r"""A quantizable long short-term memory (LSTM) cell. For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell` Examples:: >>> import torch.ao.nn.quantizable as nnqa >>> rnn = nnqa.LSTMCell(10, 20) >>> input = torch.randn(6, 10) >>> hx = torch.randn(3, 20) >>> cx = torch.randn(3, 20) >>> output = [] >>> for i in range(6): ... hx, cx = rnn(input[i], (hx, cx)) ... output.append(hx) """ _FLOAT_MODULE = torch.nn.LSTMCell def __init__( self, input_dim: int, hidden_dim: int, bias: bool = True, device=None, dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.input_size = input_dim self.hidden_size = hidden_dim self.bias = bias self.igates = torch.nn.Linear( input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs ) self.hgates = torch.nn.Linear( hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs ) self.gates = torch.ao.nn.quantized.FloatFunctional() self.input_gate = torch.nn.Sigmoid() self.forget_gate = torch.nn.Sigmoid() self.cell_gate = torch.nn.Tanh() self.output_gate = torch.nn.Sigmoid() self.fgate_cx = torch.ao.nn.quantized.FloatFunctional() self.igate_cgate = torch.ao.nn.quantized.FloatFunctional() self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional() self.ogate_cy = torch.ao.nn.quantized.FloatFunctional() self.initial_hidden_state_qparams: Tuple[float, int] = (1.0, 0) self.initial_cell_state_qparams: Tuple[float, int] = (1.0, 0) self.hidden_state_dtype: torch.dtype = torch.quint8 self.cell_state_dtype: torch.dtype = torch.quint8 def forward( self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None ) -> Tuple[Tensor, Tensor]: if hidden is None or hidden[0] is None or hidden[1] is None: hidden = self.initialize_hidden(x.shape[0], x.is_quantized) hx, cx = hidden igates = self.igates(x) hgates = self.hgates(hx) gates = self.gates.add(igates, hgates) input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1) input_gate = self.input_gate(input_gate) forget_gate = self.forget_gate(forget_gate) cell_gate = self.cell_gate(cell_gate) out_gate = self.output_gate(out_gate) fgate_cx = self.fgate_cx.mul(forget_gate, cx) igate_cgate = self.igate_cgate.mul(input_gate, cell_gate) fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate) cy = fgate_cx_igate_cgate # TODO: make this tanh a member of the module so its qparams can be configured tanh_cy = torch.tanh(cy) hy = self.ogate_cy.mul(out_gate, tanh_cy) return hy, cy def initialize_hidden( self, batch_size: int, is_quantized: bool = False ) -> Tuple[Tensor, Tensor]: h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros( (batch_size, self.hidden_size) ) if is_quantized: (h_scale, h_zp) = self.initial_hidden_state_qparams (c_scale, c_zp) = self.initial_cell_state_qparams h = torch.quantize_per_tensor( h, scale=h_scale, zero_point=h_zp, dtype=self.hidden_state_dtype ) c = torch.quantize_per_tensor( c, scale=c_scale, zero_point=c_zp, dtype=self.cell_state_dtype ) return h, c def _get_name(self): return "QuantizableLSTMCell" @classmethod def from_params(cls, wi, wh, bi=None, bh=None): """Uses the weights and biases to create a new LSTM cell. 


class _LSTMSingleLayer(torch.nn.Module):
    r"""A single one-directional LSTM layer.

    The difference between a layer and a cell is that the layer can process a
    sequence, while the cell only expects an instantaneous value.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        result = []
        seq_len = x.shape[0]
        for i in range(seq_len):
            hidden = self.cell(x[i], hidden)
            result.append(hidden[0])  # type: ignore[index]
        result_tensor = torch.stack(result, 0)
        return result_tensor, hidden

    @classmethod
    def from_params(cls, *args, **kwargs):
        cell = LSTMCell.from_params(*args, **kwargs)
        layer = cls(cell.input_size, cell.hidden_size, cell.bias)
        layer.cell = cell
        return layer


class _LSTMLayer(torch.nn.Module):
    r"""A single, optionally bi-directional, LSTM layer."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        bias: bool = True,
        batch_first: bool = False,
        bidirectional: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.layer_fw = _LSTMSingleLayer(
            input_dim, hidden_dim, bias=bias, **factory_kwargs
        )
        if self.bidirectional:
            self.layer_bw = _LSTMSingleLayer(
                input_dim, hidden_dim, bias=bias, **factory_kwargs
            )

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        if self.batch_first:
            x = x.transpose(0, 1)
        if hidden is None:
            hx_fw, cx_fw = (None, None)
        else:
            hx_fw, cx_fw = hidden
        hidden_bw: Optional[Tuple[Tensor, Tensor]] = None
        if self.bidirectional:
            if hx_fw is None:
                hx_bw = None
            else:
                hx_bw = hx_fw[1]
                hx_fw = hx_fw[0]
            if cx_fw is None:
                cx_bw = None
            else:
                cx_bw = cx_fw[1]
                cx_fw = cx_fw[0]
            if hx_bw is not None and cx_bw is not None:
                hidden_bw = hx_bw, cx_bw
        if hx_fw is None and cx_fw is None:
            hidden_fw = None
        else:
            hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(
                cx_fw
            )
        result_fw, hidden_fw = self.layer_fw(x, hidden_fw)

        if hasattr(self, "layer_bw") and self.bidirectional:
            x_reversed = x.flip(0)
            result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
            result_bw = result_bw.flip(0)

            result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
            if hidden_fw is None and hidden_bw is None:
                h = None
                c = None
            elif hidden_fw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_bw)
            elif hidden_bw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_fw)
            else:
                h = torch.stack([hidden_fw[0], hidden_bw[0]], 0)  # type: ignore[list-item]
                c = torch.stack([hidden_fw[1], hidden_bw[1]], 0)  # type: ignore[list-item]
        else:
            result = result_fw
            h, c = torch.jit._unwrap_optional(hidden_fw)  # type: ignore[assignment]

        if self.batch_first:
            result.transpose_(0, 1)

        return result, (h, c)

    @classmethod
    def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
        r"""
        There is no FP equivalent of this class. This function is here just to
        mimic the behavior of the `prepare` within the `torch.ao.quantization`
        flow.
        """
        assert hasattr(other, "qconfig") or (qconfig is not None)

        input_size = kwargs.get("input_size", other.input_size)
        hidden_size = kwargs.get("hidden_size", other.hidden_size)
        bias = kwargs.get("bias", other.bias)
        batch_first = kwargs.get("batch_first", other.batch_first)
        bidirectional = kwargs.get("bidirectional", other.bidirectional)

        layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
        layer.qconfig = getattr(other, "qconfig", qconfig)
        wi = getattr(other, f"weight_ih_l{layer_idx}")
        wh = getattr(other, f"weight_hh_l{layer_idx}")
        bi = getattr(other, f"bias_ih_l{layer_idx}", None)
        bh = getattr(other, f"bias_hh_l{layer_idx}", None)

        layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)

        if other.bidirectional:
            wi = getattr(other, f"weight_ih_l{layer_idx}_reverse")
            wh = getattr(other, f"weight_hh_l{layer_idx}_reverse")
            bi = getattr(other, f"bias_ih_l{layer_idx}_reverse", None)
            bh = getattr(other, f"bias_hh_l{layer_idx}_reverse", None)
            layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
        return layer
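

# An illustrative sketch of what ``_LSTMLayer.from_float`` (an internal helper)
# does for a single stacked layer: it reads the flat ``weight_ih_l{k}`` /
# ``weight_hh_l{k}`` (and bias) tensors off a float ``torch.nn.LSTM`` and wraps
# them in the decomposed single-direction layer above. The qconfig here is an
# assumption, used for illustration only:
#
#     float_lstm = torch.nn.LSTM(10, 20, num_layers=2)
#     float_lstm.qconfig = torch.ao.quantization.default_qconfig
#     layer0 = _LSTMLayer.from_float(float_lstm, layer_idx=0)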


class LSTM(torch.nn.Module):
    r"""A quantizable long short-term memory (LSTM).

    For the description and the argument types, please refer to
    :class:`~torch.nn.LSTM`.

    Attributes:
        layers : instances of the `_LSTMLayer`

    .. note::
        To access the weights and biases, you need to access them per layer.
        See examples below.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
        >>> # To get the weights:
        >>> # xdoctest: +SKIP
        >>> print(rnn.layers[0].weight_ih)
        tensor([[...]])
        >>> print(rnn.layers[0].weight_hh)
        AssertionError: There is no reverse path in the non-bidirectional layer
    """

    _FLOAT_MODULE = torch.nn.LSTM

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        # Default to eval mode. If we want to train, we will explicitly set it to training mode.
        self.training = False
        num_directions = 2 if bidirectional else 1

        if (
            not isinstance(dropout, numbers.Number)
            or not 0 <= dropout <= 1
            or isinstance(dropout, bool)
        ):
            raise ValueError(
                "dropout should be a number in range [0, 1] "
                "representing the probability of an element being "
                "zeroed"
            )
        if dropout > 0:
            warnings.warn(
                "dropout option for quantizable LSTM is ignored. "
                "If you are training, please use the nn.LSTM version "
                "followed by the `prepare` step."
            )
            if num_layers == 1:
                warnings.warn(
                    "dropout option adds dropout after all but last "
                    "recurrent layer, so non-zero dropout expects "
                    f"num_layers greater than 1, but got dropout={dropout} "
                    f"and num_layers={num_layers}"
                )

        layers = [
            _LSTMLayer(
                self.input_size,
                self.hidden_size,
                self.bias,
                batch_first=False,
                bidirectional=self.bidirectional,
                **factory_kwargs,
            )
        ]
        for layer in range(1, num_layers):
            layers.append(
                _LSTMLayer(
                    self.hidden_size,
                    self.hidden_size,
                    self.bias,
                    batch_first=False,
                    bidirectional=self.bidirectional,
                    **factory_kwargs,
                )
            )
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        if self.batch_first:
            x = x.transpose(0, 1)

        max_batch_size = x.size(1)
        num_directions = 2 if self.bidirectional else 1
        if hidden is None:
            zeros = torch.zeros(
                num_directions,
                max_batch_size,
                self.hidden_size,
                dtype=torch.float,
                device=x.device,
            )
            zeros.squeeze_(0)
            if x.is_quantized:
                zeros = torch.quantize_per_tensor(
                    zeros, scale=1.0, zero_point=0, dtype=x.dtype
                )
            hxcx = [(zeros, zeros) for _ in range(self.num_layers)]
        else:
            hidden_non_opt = torch.jit._unwrap_optional(hidden)
            if isinstance(hidden_non_opt[0], Tensor):
                hx = hidden_non_opt[0].reshape(
                    self.num_layers, num_directions, max_batch_size, self.hidden_size
                )
                cx = hidden_non_opt[1].reshape(
                    self.num_layers, num_directions, max_batch_size, self.hidden_size
                )
                hxcx = [
                    (hx[idx].squeeze(0), cx[idx].squeeze(0))
                    for idx in range(self.num_layers)
                ]
            else:
                hxcx = hidden_non_opt

        hx_list = []
        cx_list = []
        for idx, layer in enumerate(self.layers):
            x, (h, c) = layer(x, hxcx[idx])
            hx_list.append(torch.jit._unwrap_optional(h))
            cx_list.append(torch.jit._unwrap_optional(c))
        hx_tensor = torch.stack(hx_list)
        cx_tensor = torch.stack(cx_list)

        # We created an extra dimension for the bidirectional case,
        # so we need to collapse it here.
        hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1])
        cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1])

        if self.batch_first:
            x = x.transpose(0, 1)

        return x, (hx_tensor, cx_tensor)

    def _get_name(self):
        return "QuantizableLSTM"

    @classmethod
    def from_float(cls, other, qconfig=None):
        assert isinstance(other, cls._FLOAT_MODULE)
        assert hasattr(other, "qconfig") or qconfig
        observed = cls(
            other.input_size,
            other.hidden_size,
            other.num_layers,
            other.bias,
            other.batch_first,
            other.dropout,
            other.bidirectional,
        )
        observed.qconfig = getattr(other, "qconfig", qconfig)
        for idx in range(other.num_layers):
            observed.layers[idx] = _LSTMLayer.from_float(
                other, idx, qconfig, batch_first=False
            )

        # Prepare the model
        if other.training:
            observed.train()
            observed = torch.ao.quantization.prepare_qat(observed, inplace=True)
        else:
            observed.eval()
            observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed

    @classmethod
    def from_observed(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does float -> observed only
        raise NotImplementedError(
            "It looks like you are trying to convert a "
            "non-quantizable LSTM module. Please see "
            "the examples on quantizable LSTMs."
        )
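

# A minimal end-to-end sketch of the float -> observed part of the flow this
# module implements (a hedged example, not part of the library API).
# ``LSTM.from_float`` swaps in the decomposed layers and calls
# ``torch.ao.quantization.prepare`` internally, so pushing calibration data
# through the observed module records the statistics needed for conversion.
# The qconfig and the random calibration tensors are assumptions; converting to
# an actual quantized LSTM is handled elsewhere (see ``from_observed`` above).
if __name__ == "__main__":
    float_lstm = torch.nn.LSTM(10, 20, num_layers=2)
    float_lstm.eval()  # eval mode -> post-training (non-QAT) observation path
    float_lstm.qconfig = torch.ao.quantization.default_qconfig

    observed_lstm = LSTM.from_float(float_lstm)

    # Calibrate: the Linear/FloatFunctional submodules now carry observers.
    with torch.no_grad():
        for _ in range(4):
            observed_lstm(torch.randn(5, 3, 10))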