tico-0.1.0.dev250803-py3-none-any.whl → tico-0.1.0.dev251106-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in the public registry.
Files changed (133)
  1. tico/__init__.py +1 -1
  2. tico/config/v1.py +5 -0
  3. tico/passes/cast_mixed_type_args.py +2 -0
  4. tico/passes/convert_expand_to_slice_cat.py +153 -0
  5. tico/passes/convert_matmul_to_linear.py +312 -0
  6. tico/passes/convert_to_relu6.py +1 -1
  7. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -4
  8. tico/passes/ops.py +0 -1
  9. tico/passes/remove_redundant_assert_nodes.py +3 -1
  10. tico/passes/remove_redundant_expand.py +3 -1
  11. tico/quantization/__init__.py +6 -0
  12. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
  13. tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +30 -8
  14. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
  15. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  16. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  17. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  18. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  19. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
  20. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  21. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  22. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  23. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  24. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  25. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  26. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  27. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  28. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
  29. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  30. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  31. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  32. tico/quantization/config/base.py +26 -0
  33. tico/quantization/config/gptq.py +29 -0
  34. tico/quantization/config/pt2e.py +25 -0
  35. tico/quantization/config/ptq.py +119 -0
  36. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  37. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +7 -16
  38. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  39. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  40. tico/quantization/evaluation/metric.py +146 -0
  41. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  42. tico/quantization/passes/__init__.py +1 -0
  43. tico/{experimental/quantization → quantization}/public_interface.py +11 -18
  44. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  45. tico/quantization/quantizer_registry.py +73 -0
  46. tico/quantization/wrapq/__init__.py +1 -0
  47. tico/quantization/wrapq/dtypes.py +70 -0
  48. tico/quantization/wrapq/examples/__init__.py +1 -0
  49. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  50. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  51. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  52. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  53. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  54. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  55. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  56. tico/quantization/wrapq/mode.py +32 -0
  57. tico/quantization/wrapq/observers/__init__.py +1 -0
  58. tico/quantization/wrapq/observers/affine_base.py +128 -0
  59. tico/quantization/wrapq/observers/base.py +98 -0
  60. tico/quantization/wrapq/observers/ema.py +62 -0
  61. tico/quantization/wrapq/observers/identity.py +74 -0
  62. tico/quantization/wrapq/observers/minmax.py +39 -0
  63. tico/quantization/wrapq/observers/mx.py +60 -0
  64. tico/quantization/wrapq/qscheme.py +40 -0
  65. tico/quantization/wrapq/quantizer.py +179 -0
  66. tico/quantization/wrapq/utils/__init__.py +1 -0
  67. tico/quantization/wrapq/utils/introspection.py +167 -0
  68. tico/quantization/wrapq/utils/metrics.py +124 -0
  69. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  70. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  71. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  72. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  73. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  74. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  75. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  76. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  77. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  78. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  79. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  80. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  81. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  82. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  83. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  84. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  85. tico/quantization/wrapq/wrappers/nn/quant_silu.py +60 -0
  86. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  87. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  88. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  89. tico/quantization/wrapq/wrappers/registry.py +128 -0
  90. tico/serialize/circle_serializer.py +11 -4
  91. tico/serialize/operators/adapters/__init__.py +1 -0
  92. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  93. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  94. tico/serialize/operators/op_le.py +54 -0
  95. tico/serialize/operators/op_mm.py +15 -132
  96. tico/serialize/operators/op_rmsnorm.py +65 -0
  97. tico/utils/convert.py +20 -15
  98. tico/utils/dtype.py +22 -0
  99. tico/utils/register_custom_op.py +29 -4
  100. tico/utils/signature.py +247 -0
  101. tico/utils/utils.py +50 -53
  102. tico/utils/validate_args_kwargs.py +37 -0
  103. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/METADATA +49 -2
  104. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/RECORD +130 -73
  105. tico/experimental/quantization/__init__.py +0 -6
  106. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  107. tico/experimental/quantization/evaluation/metric.py +0 -109
  108. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  109. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  110. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  111. /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
  112. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  113. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  114. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  115. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  116. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  117. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  118. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  119. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  120. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  121. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  122. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  123. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  124. /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
  125. /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
  126. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  127. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  128. /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
  129. /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
  130. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/LICENSE +0 -0
  131. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/WHEEL +0 -0
  132. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/entry_points.txt +0 -0
  133. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251106.dist-info}/top_level.txt +0 -0
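
Most of the churn in the list above is a namespace promotion: the quantization stack moves from tico.experimental.quantization to tico.quantization, and a new tico.quantization.wrapq subpackage (observers, quantizer, PTQ wrappers for nn/llama/fairseq modules, plus examples) is added. As a rough sketch of what the move means for imports (module paths are taken from the file list; the specific symbols shown are the ones imported by the decoder-layer wrapper below, not a complete public API):

    # Old namespace (0.1.0.dev250803) -- modules lived under tico.experimental.quantization
    # from tico.experimental.quantization.public_interface import ...

    # New namespace (0.1.0.dev251106) -- the same modules promoted to tico.quantization,
    # plus the new wrapq subpackage for PTQ module wrappers.
    from tico.quantization.config.ptq import PTQConfig
    from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper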
tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py (new file)
@@ -0,0 +1,492 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -----------------------------------------------------------------------------
+# This file includes modifications based on fairseq
+# (https://github.com/facebookresearch/fairseq), originally licensed under
+# the MIT License. See the LICENSE file in the fairseq repository for details.
+# -----------------------------------------------------------------------------
+
+from typing import Dict, Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.fairseq.quant_mha import (
+    QuantFairseqMultiheadAttention,
+)
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+@try_register("fairseq.modules.transformer_layer.TransformerDecoderLayerBase")
+class QuantFairseqDecoderLayer(QuantModuleBase):
+    """
+    Quant-aware drop-in replacement for Fairseq TransformerDecoderLayerBase.
+
+    Design (inference-only):
+      - Keep LayerNorms and scalar head/residual scalers in FP.
+      - PTQ-wrap: self_attn, (optional) encoder_attn, fc1, fc2.
+      - Preserve Fairseq tensor contracts and incremental state handling.
+      - Remove training-time behaviors: dropout, activation-dropout, quant-noise, onnx_trace.
+
+    I/O:
+      - Input/Output use Fairseq shapes: [T, B, C].
+      - Forward returns: (x, attn, None) to match the original call sites in decoder.
+        * `attn` is from encoder-attention when requested (alignment).
+    """
+
+    def __init__(
+        self,
+        fp_layer: nn.Module,
+        *,
+        qcfg: Optional[PTQConfig] = None,
+        fp_name: Optional[str] = None,
+    ):
+        super().__init__(qcfg, fp_name=fp_name)
+
+        # --- read-only metadata copied from FP layer -----------------------
+        assert hasattr(fp_layer, "embed_dim")
+        assert hasattr(fp_layer, "normalize_before")
+        self.embed_dim: int = int(fp_layer.embed_dim)  # type: ignore[arg-type]
+        self.normalize_before: bool = bool(fp_layer.normalize_before)
+
+        # Cross-self attention flag (when True, key/value can include encoder_out)
+        self.cross_self_attention: bool = bool(
+            getattr(fp_layer, "cross_self_attention", False)
+        )
+
+        # Generate prefix
+        def _safe_prefix(name: Optional[str]) -> str:
+            # Avoid "None.*" strings causing collisions
+            return (
+                name
+                if (name is not None and name != "None" and name != "")
+                else f"{self.__class__.__name__}_{id(self)}"
+            )
+
+        prefix = _safe_prefix(fp_name)
+        # Self-attn (PTQ) ---------------------------------------------------
+        # Use our MHA wrapper with identical API to the FP module.
+        attn_cfg = qcfg.child("self_attn") if qcfg else None
+        assert hasattr(fp_layer, "self_attn") and isinstance(
+            fp_layer.self_attn, nn.Module
+        )
+        self.self_attn = QuantFairseqMultiheadAttention(
+            fp_layer.self_attn, qcfg=attn_cfg, fp_name=f"{prefix}.self_attn"
+        )
+
+        # Optional attention LayerNorm applied to self-attn output (scale_attn)
+        # Kept in FP; reuse original instance for weight parity.
+        self.attn_ln = getattr(fp_layer, "attn_ln", None)
+
+        # Optional per-head scaling after self-attn output (scale_heads)
+        # Keep exact Parameter reference if present (shape: [num_heads])
+        self.c_attn = getattr(fp_layer, "c_attn", None)
+
+        # Cache head meta for c_attn path
+        self.nh = int(getattr(self.self_attn, "num_heads"))
+        self.head_dim = int(getattr(self.self_attn, "head_dim"))
+
+        # Encoder-attn (PTQ) ------------------------------------------------
+        # Only present if the original layer was constructed with encoder_attn.
+        enc_attn_mod = getattr(fp_layer, "encoder_attn", None)
+        assert enc_attn_mod is not None
+        enc_cfg = qcfg.child("encoder_attn") if qcfg else None
+        self.encoder_attn = QuantFairseqMultiheadAttention(
+            enc_attn_mod, qcfg=enc_cfg, fp_name=f"{prefix}.encoder_attn"
+        )
+
+        # Feed-forward (PTQ) ------------------------------------------------
+        fc1_cfg = qcfg.child("fc1") if qcfg else None
+        fc2_cfg = qcfg.child("fc2") if qcfg else None
+        assert hasattr(fp_layer, "fc1") and isinstance(fp_layer.fc1, nn.Module)
+        assert hasattr(fp_layer, "fc2") and isinstance(fp_layer.fc2, nn.Module)
+        self.fc1 = PTQWrapper(fp_layer.fc1, qcfg=fc1_cfg, fp_name=f"{fp_name}.fc1")
+        self.fc2 = PTQWrapper(fp_layer.fc2, qcfg=fc2_cfg, fp_name=f"{fp_name}.fc2")
+
+        # LayerNorms
+        enc_attn_ln_cfg = qcfg.child("encoder_attn_layer_norm") if qcfg else None
+        attn_ln_cfg = qcfg.child("self_attn_layer_norm") if qcfg else None
+        final_ln_cfg = qcfg.child("final_layer_norm") if qcfg else None
+        assert hasattr(fp_layer, "encoder_attn_layer_norm") and isinstance(
+            fp_layer.encoder_attn_layer_norm, nn.Module
+        )
+        assert hasattr(fp_layer, "self_attn_layer_norm") and isinstance(
+            fp_layer.self_attn_layer_norm, nn.Module
+        )
+        assert hasattr(fp_layer, "final_layer_norm") and isinstance(
+            fp_layer.final_layer_norm, nn.Module
+        )
+        self.encoder_attn_layer_norm = PTQWrapper(
+            fp_layer.encoder_attn_layer_norm,
+            qcfg=enc_attn_ln_cfg,
+            fp_name=f"{fp_name}.encoder_attn_layer_norm",
+        )
+        self.self_attn_layer_norm = PTQWrapper(
+            fp_layer.self_attn_layer_norm,
+            qcfg=attn_ln_cfg,
+            fp_name=f"{fp_name}.self_attn_layer_norm",
+        )
+        self.final_layer_norm = PTQWrapper(
+            fp_layer.final_layer_norm,
+            qcfg=final_ln_cfg,
+            fp_name=f"{fp_name}.final_layer_norm",
+        )
+
+        # Optional FFN intermediate LayerNorm (scale_fc), FP
+        self.ffn_layernorm = getattr(fp_layer, "ffn_layernorm", None)
+
+        # Optional residual scaling (scale_resids), keep Parameter reference
+        self.w_resid = getattr(fp_layer, "w_resid", None)
+
+        # Activation function
+        self.activation_fn = fp_layer.activation_fn  # type: ignore[operator]
+        self.obs_activation_fn = self._make_obs("activation_fn")
+
+        # Alignment flag used by Fairseq (kept for API parity)
+        self.need_attn: bool = bool(getattr(fp_layer, "need_attn", True))
+
+        # No dropout / activation-dropout in inference wrapper
+        # (intentionally omitted)
+
+        # --- observers for external/self-attn KV cache inputs --------------
+        self.obs_prev_self_k_in = self._make_obs("prev_self_k_in")
+        self.obs_prev_self_v_in = self._make_obs("prev_self_v_in")
+
+    # ----------------------------------------------------------------------
+    def _maybe_apply_head_scale(self, x: Tensor) -> Tensor:
+        """
+        Optional per-head scaling (scale_heads) after self-attention.
+        x: [T, B, C]
+        """
+        if self.c_attn is None:
+            return x
+        T, B, _ = x.shape
+        x = x.view(T, B, self.nh, self.head_dim)  # [T,B,H,Dh]
+        # einsum over head dim: scales each head independently
+        x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)  # [T,B,H,Dh]
+        return x.reshape(T, B, self.nh * self.head_dim)  # [T,B,C]

+    # ----------------------------------------------------------------------
+    def forward(
+        self,
+        x: Tensor,  # [T,B,C]
+        encoder_out: Optional[Tensor] = None,  # [S,B,Ce] or None
+        encoder_padding_mask: Optional[Tensor] = None,  # [B,S] bool or additive float
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+        prev_self_attn_state: Optional[List[Tensor]] = None,
+        prev_attn_state: Optional[List[Tensor]] = None,
+        self_attn_mask: Optional[Tensor] = None,  # [T,T] or [B,T,T] or None
+        self_attn_padding_mask: Optional[Tensor] = None,  # [B,T] or [B,T,T] or None
+        need_attn: bool = False,
+        need_head_weights: bool = False,
+    ) -> Tuple[Tensor, Optional[Tensor], None]:
+        """
+        Mirrors the original forward, minus training-only logic.
+        Returns:
+            x': [T,B,C], attn (from encoder-attn when requested), None
+        """
+        if need_head_weights:
+            need_attn = True
+
+        # ---- (1) Self-Attention block ------------------------------------
+        residual = x
+        if self.normalize_before:
+            x = self.self_attn_layer_norm(x)
+
+        # Load provided cached self-attn state (for incremental decoding)
+        if prev_self_attn_state is not None:
+            prev_key, prev_value = prev_self_attn_state[:2]
+            saved_state: Dict[str, Optional[Tensor]] = {
+                "prev_key": prev_key,
+                "prev_value": prev_value,
+            }
+            if len(prev_self_attn_state) >= 3:
+                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
+            assert incremental_state is not None
+            self.self_attn._set_input_buffer(incremental_state, saved_state)
+
+        # Cross-self-attention: prepend encoder_out to K/V at the first step
+        y = x
+        if self.cross_self_attention:
+            _buf = self.self_attn._get_input_buffer(incremental_state)
+            no_cache_yet = not (
+                incremental_state is not None
+                and _buf is not None
+                and "prev_key" in _buf
+            )
+            if no_cache_yet:
+                if self_attn_mask is not None:
+                    assert encoder_out is not None
+                    # Grow attn mask to cover encoder timesteps (no autoregressive penalty for them)
+                    self_attn_mask = torch.cat(
+                        (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask),
+                        dim=1,
+                    )
+                if self_attn_padding_mask is not None:
+                    if encoder_padding_mask is None:
+                        assert encoder_out is not None
+                        encoder_padding_mask = self_attn_padding_mask.new_zeros(
+                            encoder_out.size(1), encoder_out.size(0)
+                        )
+                    # Concatenate encoder pad-mask in front of target pad-mask
+                    self_attn_padding_mask = torch.cat(
+                        (encoder_padding_mask, self_attn_padding_mask), dim=1
+                    )
+                assert encoder_out is not None
+                y = torch.cat((encoder_out, x), dim=0)  # [S+T, B, C]
+
+        # Self-attn; Fairseq never consumes self-attn weights for alignment here
+        x, _ = self.self_attn(
+            query=x,
+            key=y,
+            value=y,
+            key_padding_mask=self_attn_padding_mask,
+            incremental_state=incremental_state,
+            need_weights=False,
+            attn_mask=self_attn_mask,
+        )
+
+        # Optional per-head scaling and attn LayerNorm on self-attn output
+        x = self._maybe_apply_head_scale(x)
+        if self.attn_ln is not None:
+            x = self.attn_ln(x)
+
+        # Residual + (post-norm if applicable)
+        x = residual + x
+        if not self.normalize_before:
+            x = self.self_attn_layer_norm(x)
+
+        # ---- (2) Encoder-Decoder Attention block --------------------------
+        attn_out: Optional[Tensor] = None
+        assert encoder_out is not None
+        residual = x
+        assert self.encoder_attn_layer_norm is not None
+        if self.normalize_before:
+            x = self.encoder_attn_layer_norm(x)
+
+        # Load provided cached cross-attn state
+        if prev_attn_state is not None:
+            prev_key, prev_value = prev_attn_state[:2]
+            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
+            if len(prev_attn_state) >= 3:
+                saved_state["prev_key_padding_mask"] = prev_attn_state[2]
+            assert incremental_state is not None
+            self.encoder_attn._set_input_buffer(incremental_state, saved_state)
+
+        # Cross-attn (static_kv=True to reuse encoder K/V across steps)
+        assert self.encoder_attn is not None
+        x, attn_out = self.encoder_attn(
+            query=x,
+            key=encoder_out,
+            value=encoder_out,
+            key_padding_mask=encoder_padding_mask,
+            incremental_state=incremental_state,
+            static_kv=True,
+            need_weights=need_attn or self.need_attn,
+            need_head_weights=need_head_weights,
+        )
+
+        x = residual + x
+        if not self.normalize_before:
+            x = self.encoder_attn_layer_norm(x)
+
+        # ---- (3) Feed-Forward block --------------------------------------
+        residual = x
+        if self.normalize_before:
+            x = self.final_layer_norm(x)
+
+        # FFN: fc1 -> activation -> (optional LN) -> fc2
+        x = self.fc1(x)
+        x = self.activation_fn(x)  # type: ignore[operator]
+        x = self._fq(x, self.obs_activation_fn)
+        if self.ffn_layernorm is not None:
+            x = self.ffn_layernorm(x)
+        x = self.fc2(x)
+
+        # Optional residual scaling (scale_resids)
+        if self.w_resid is not None:
+            residual = torch.mul(self.w_resid, residual)
+
+        x = residual + x
+        if not self.normalize_before:
+            x = self.final_layer_norm(x)
+
+        # Return attn from encoder-attn branch when requested; self-attn weights are not returned.
+        return x, attn_out, None
+
+    def forward_external(
+        self,
+        x: Tensor,  # [1, B, C] (embedded current-step token)
+        *,
+        encoder_out: Optional[Tensor],  # [S, B, Ce]
+        encoder_padding_mask: Optional[
+            Tensor
+        ] = None,  # [B,S] bool or additive-float or [B,1,S] additive-float
+        prev_self_k: Optional[Tensor] = None,  # [B, H, Tprev, Dh]
+        prev_self_v: Optional[Tensor] = None,  # [B, H, Tprev, Dh]
+        self_attn_mask: Optional[
+            Tensor
+        ] = None,  # [1, 1, S_hist+1] or [B,1,S_hist+1] additive-float
+        need_attn: bool = False,
+        need_head_weights: bool = False,
+    ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor]:
+        """
+        Export-only single-step:
+        Returns (x_out[1,B,C], attn_from_cross, new_self_k[B,H,1,Dh], new_self_v[B,H,1,Dh]).
+        """
+        if need_head_weights:
+            need_attn = True
+
+        assert x.dim() == 3 and x.size(0) == 1, "x must be [1,B,C]"
+        B = x.size(1)
+
+        # ---- Self-Attention (uses MHA return_new_kv) ----------------------
+        x_tbc = x
+        if self.normalize_before:
+            x_tbc = self.self_attn_layer_norm(x_tbc)
+
+        # Provide prev KV via incremental_state so wrapper appends internally
+        incr: Dict[str, Dict[str, Optional[Tensor]]] = {}
+        if prev_self_k is not None and prev_self_v is not None:
+            # Attach observers to incoming caches
+            prev_self_k = self._fq(prev_self_k, self.obs_prev_self_k_in)
+            prev_self_v = self._fq(prev_self_v, self.obs_prev_self_v_in)
+            assert isinstance(prev_self_k, Tensor) and isinstance(prev_self_v, Tensor)
+            saved = {
+                "prev_key": prev_self_k.detach(),
+                "prev_value": prev_self_v.detach(),
+            }
+            self.self_attn._set_input_buffer(incr, saved)  # type: ignore[arg-type]
+
+        # Normalize self-attn additive mask to shapes wrapper accepts: [T,S] or [B,T,S]
+        attn_mask_for_wrapper = None
+        if self_attn_mask is not None:
+            if (
+                self_attn_mask.dim() == 3
+                and self_attn_mask.size(0) == B
+                and self_attn_mask.size(1) == 1
+            ):
+                attn_mask_for_wrapper = self_attn_mask  # [B,1,S]
+            elif (
+                self_attn_mask.dim() == 3
+                and self_attn_mask.size(0) == 1
+                and self_attn_mask.size(1) == 1
+            ):
+                attn_mask_for_wrapper = self_attn_mask[0]  # -> [1,S]
+            elif self_attn_mask.dim() == 2 and self_attn_mask.size(0) == 1:
+                attn_mask_for_wrapper = self_attn_mask  # [1,S]
+            else:
+                raise RuntimeError(
+                    "self_attn_mask must be [1,S] or [B,1,S] additive-float."
+                )
+            attn_mask_for_wrapper = attn_mask_for_wrapper.to(
+                dtype=x_tbc.dtype, device=x_tbc.device
+            )
+
+        x_sa, _, new_k_bh, new_v_bh = self.self_attn(
+            query=x_tbc,
+            key=x_tbc,
+            value=x_tbc,
+            key_padding_mask=None,
+            incremental_state=incr,
+            need_weights=False,
+            attn_mask=attn_mask_for_wrapper,
+            return_new_kv=True,  # <<< NEW: ask wrapper to return this step's K/V
+        )  # x_sa: [1,B,C]; new_k_bh/new_v_bh: [B*H, Tnew, Dh]
+
+        x_sa = self._maybe_apply_head_scale(x_sa)
+        if self.attn_ln is not None:
+            x_sa = self.attn_ln(x_sa)
+
+        x_tbc = x_tbc + x_sa
+        if not self.normalize_before:
+            x_tbc = self.self_attn_layer_norm(x_tbc)
+
+        # ---- Encoder-Decoder Attention -----------------------------------
+        assert encoder_out is not None, "encoder_out is required in export path"
+        residual = x_tbc
+        if self.normalize_before:
+            assert self.encoder_attn_layer_norm is not None
+            x_tbc = self.encoder_attn_layer_norm(x_tbc)
+
+        enc_kpm = encoder_padding_mask  # pass-through; wrapper handles bool/additive
+        x_ed, attn_out = self.encoder_attn(
+            query=x_tbc,
+            key=encoder_out,
+            value=encoder_out,
+            key_padding_mask=enc_kpm,
+            incremental_state=None,
+            static_kv=True,
+            need_weights=need_attn,
+            need_head_weights=need_head_weights,
+        )
+
+        x_tbc = residual + x_ed
+        if not self.normalize_before:
+            assert self.encoder_attn_layer_norm is not None
+            x_tbc = self.encoder_attn_layer_norm(x_tbc)
+
+        # ---- Feed-Forward -------------------------------------------------
+        residual = x_tbc
+        if self.normalize_before:
+            x_tbc = self.final_layer_norm(x_tbc)
+
+        x_tbc = self.fc1(x_tbc)
+        x_tbc = self.activation_fn(x_tbc)  # type: ignore[operator]
+        x_tbc = self._fq(x_tbc, self.obs_activation_fn)
+        if self.ffn_layernorm is not None:
+            x_tbc = self.ffn_layernorm(x_tbc)
+        x_tbc = self.fc2(x_tbc)
+
+        if self.w_resid is not None:
+            residual = torch.mul(self.w_resid, residual)
+
+        x_tbc = residual + x_tbc
+        if not self.normalize_before:
+            x_tbc = self.final_layer_norm(x_tbc)
+
+        return (
+            x_tbc,
+            attn_out,
+            new_k_bh,
+            new_v_bh,
+        )  # [1,B,C], attn, [B*H, Tnew, Dh], [B*H, Tnew, Dh]
+
+    def _all_observers(self) -> Iterable:
+        """
+        Expose all observers from child PTQ-wrapped modules.
+        This layer itself does not add extra per-tensor observers.
+        """
+        # local observers
+        yield from (
+            self.obs_activation_fn,
+            self.obs_prev_self_k_in,
+            self.obs_prev_self_v_in,
+        )
+
+        for m in (
+            self.self_attn,
+            self.encoder_attn,
+            self.fc1,
+            self.fc2,
+            self.encoder_attn_layer_norm,
+            self.self_attn_layer_norm,
+            self.final_layer_norm,
+        ):
+            if isinstance(m, QuantModuleBase) and m is not None:
+                yield from m._all_observers()
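
For orientation, a minimal usage sketch of the wrapper added above. It assumes an already-constructed fairseq TransformerDecoderLayerBase instance with an encoder-attention block (the constructor asserts encoder_attn is present); how that layer is built, and what observer defaults qcfg=None implies inside QuantModuleBase, are not shown in this diff, so treat the snippet as illustrative rather than as the package's documented API:

    import torch

    from tico.quantization.wrapq.wrappers.fairseq.quant_decoder_layer import (
        QuantFairseqDecoderLayer,
    )

    fp_layer = ...  # an existing fairseq TransformerDecoderLayerBase (with encoder_attn)

    # Wrap for PTQ; qcfg=None is accepted by the signature, or pass a PTQConfig
    # from tico.quantization.config.ptq to control per-submodule observers.
    qlayer = QuantFairseqDecoderLayer(fp_layer, qcfg=None, fp_name="decoder.layers.0")
    qlayer.eval()

    # Fairseq tensor contract: decoder input [T, B, C], encoder output [S, B, C].
    T, B, S, C = 4, 2, 16, qlayer.embed_dim
    x = torch.randn(T, B, C)
    enc = torch.randn(S, B, C)

    with torch.no_grad():
        out, attn, _ = qlayer(x, encoder_out=enc, encoder_padding_mask=None)
    # out: [T, B, C]; attn: encoder-attention weights when requested, else None

The @try_register decorator suggests these wrappers are normally discovered through the wrapq registry (for example via PTQWrapper) rather than instantiated by hand; direct construction here is only to show the constructor arguments and the [T, B, C] calling convention.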