tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. tico/__init__.py +9 -1
  2. tico/config/base.py +1 -1
  3. tico/config/v1.py +5 -0
  4. tico/passes/cast_aten_where_arg_type.py +1 -1
  5. tico/passes/cast_clamp_mixed_type_args.py +169 -0
  6. tico/passes/cast_mixed_type_args.py +4 -2
  7. tico/passes/const_prop_pass.py +1 -1
  8. tico/passes/convert_conv1d_to_conv2d.py +1 -1
  9. tico/passes/convert_expand_to_slice_cat.py +153 -0
  10. tico/passes/convert_matmul_to_linear.py +312 -0
  11. tico/passes/convert_to_relu6.py +1 -1
  12. tico/passes/decompose_addmm.py +0 -3
  13. tico/passes/decompose_batch_norm.py +2 -2
  14. tico/passes/decompose_fake_quantize.py +0 -3
  15. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
  16. tico/passes/decompose_group_norm.py +0 -3
  17. tico/passes/legalize_predefined_layout_operators.py +2 -11
  18. tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
  19. tico/passes/lower_to_slice.py +1 -1
  20. tico/passes/merge_consecutive_cat.py +1 -1
  21. tico/passes/ops.py +1 -1
  22. tico/passes/remove_redundant_assert_nodes.py +3 -1
  23. tico/passes/remove_redundant_expand.py +3 -6
  24. tico/passes/remove_redundant_reshape.py +5 -5
  25. tico/passes/segment_index_select.py +1 -1
  26. tico/quantization/__init__.py +6 -0
  27. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  28. tico/quantization/algorithm/gptq/quantizer.py +292 -0
  29. tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
  30. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
  31. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  32. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  33. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  34. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  35. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
  36. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  37. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  38. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  39. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  40. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  41. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  42. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  44. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
  45. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  46. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  47. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  48. tico/quantization/config/base.py +26 -0
  49. tico/quantization/config/gptq.py +29 -0
  50. tico/quantization/config/pt2e.py +25 -0
  51. tico/quantization/config/ptq.py +119 -0
  52. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  53. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
  54. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  55. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  56. tico/quantization/evaluation/metric.py +146 -0
  57. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  58. tico/quantization/passes/__init__.py +1 -0
  59. tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
  60. tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
  61. tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
  62. tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
  63. tico/{experimental/quantization → quantization}/public_interface.py +19 -18
  64. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  65. tico/quantization/quantizer_registry.py +73 -0
  66. tico/quantization/wrapq/__init__.py +1 -0
  67. tico/quantization/wrapq/dtypes.py +70 -0
  68. tico/quantization/wrapq/examples/__init__.py +1 -0
  69. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  70. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  71. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  72. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  73. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  74. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  75. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  76. tico/quantization/wrapq/mode.py +32 -0
  77. tico/quantization/wrapq/observers/__init__.py +1 -0
  78. tico/quantization/wrapq/observers/affine_base.py +128 -0
  79. tico/quantization/wrapq/observers/base.py +98 -0
  80. tico/quantization/wrapq/observers/ema.py +62 -0
  81. tico/quantization/wrapq/observers/identity.py +74 -0
  82. tico/quantization/wrapq/observers/minmax.py +39 -0
  83. tico/quantization/wrapq/observers/mx.py +60 -0
  84. tico/quantization/wrapq/qscheme.py +40 -0
  85. tico/quantization/wrapq/quantizer.py +179 -0
  86. tico/quantization/wrapq/utils/__init__.py +1 -0
  87. tico/quantization/wrapq/utils/introspection.py +167 -0
  88. tico/quantization/wrapq/utils/metrics.py +124 -0
  89. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  90. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  91. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  92. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  93. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  94. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  95. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  96. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  97. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  98. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  99. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  100. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  101. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  102. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  103. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  104. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  105. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  106. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  107. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  108. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  109. tico/quantization/wrapq/wrappers/registry.py +125 -0
  110. tico/serialize/circle_graph.py +12 -4
  111. tico/serialize/circle_mapping.py +76 -2
  112. tico/serialize/circle_serializer.py +253 -148
  113. tico/serialize/operators/adapters/__init__.py +1 -0
  114. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  115. tico/serialize/operators/op_any.py +7 -14
  116. tico/serialize/operators/op_avg_pool2d.py +11 -4
  117. tico/serialize/operators/op_clamp.py +5 -7
  118. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  119. tico/serialize/operators/op_conv2d.py +14 -6
  120. tico/serialize/operators/op_copy.py +26 -3
  121. tico/serialize/operators/op_cumsum.py +3 -1
  122. tico/serialize/operators/op_depthwise_conv2d.py +17 -7
  123. tico/serialize/operators/op_full_like.py +0 -2
  124. tico/serialize/operators/op_index_select.py +8 -1
  125. tico/serialize/operators/op_instance_norm.py +0 -6
  126. tico/serialize/operators/op_le.py +54 -0
  127. tico/serialize/operators/op_log1p.py +3 -2
  128. tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
  129. tico/serialize/operators/op_mm.py +15 -131
  130. tico/serialize/operators/op_mul.py +2 -8
  131. tico/serialize/operators/op_pow.py +3 -1
  132. tico/serialize/operators/op_repeat.py +12 -3
  133. tico/serialize/operators/op_reshape.py +1 -1
  134. tico/serialize/operators/op_rmsnorm.py +65 -0
  135. tico/serialize/operators/op_softmax.py +7 -14
  136. tico/serialize/operators/op_split_with_sizes.py +16 -8
  137. tico/serialize/operators/op_transpose_conv.py +11 -8
  138. tico/serialize/operators/op_view.py +2 -1
  139. tico/serialize/quant_param.py +5 -5
  140. tico/utils/convert.py +30 -17
  141. tico/utils/dtype.py +42 -0
  142. tico/utils/graph.py +1 -1
  143. tico/utils/model.py +2 -1
  144. tico/utils/padding.py +2 -2
  145. tico/utils/pytree_utils.py +134 -0
  146. tico/utils/record_input.py +102 -0
  147. tico/utils/register_custom_op.py +29 -4
  148. tico/utils/serialize.py +16 -3
  149. tico/utils/signature.py +247 -0
  150. tico/utils/torch_compat.py +52 -0
  151. tico/utils/utils.py +50 -58
  152. tico/utils/validate_args_kwargs.py +38 -3
  153. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  154. tico-0.1.0.dev251102.dist-info/RECORD +271 -0
  155. tico/experimental/quantization/__init__.py +0 -1
  156. tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
  157. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  158. tico/experimental/quantization/evaluation/metric.py +0 -109
  159. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
  160. tico-0.1.0.dev250714.dist-info/RECORD +0 -209
  161. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  162. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  163. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  164. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  165. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  166. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  167. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  168. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  169. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  170. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  171. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  172. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  173. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  174. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  175. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  176. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  177. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  178. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  179. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  180. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  181. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py (new file, +429 lines)
@@ -0,0 +1,429 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # -----------------------------------------------------------------------------
+ # This file includes modifications based on fairseq
+ # (https://github.com/facebookresearch/fairseq), originally licensed under
+ # the MIT License. See the LICENSE file in the fairseq repository for details.
+ # -----------------------------------------------------------------------------
+
+ import math
+ from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn, Tensor
+
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+ from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+ @try_register("fairseq.models.transformer.TransformerDecoderBase")
+ class QuantFairseqDecoder(QuantModuleBase):
+     """
+     Quant-aware drop-in replacement for Fairseq TransformerDecoderBase.
+
+     Design (inference-only):
+       - Keep embeddings, positional embeddings, LayerNorms, output_projection in FP.
+       - PTQ-wrap all TransformerDecoderLayerBase items via PTQWrapper (uses QuantFairseqDecoderLayer).
+       - Drop training-only logic (dropout, activation-dropout, quant-noise, checkpoint wrappers).
+       - Preserve Fairseq forward/extract_features contract, shapes, and incremental decoding behavior.
+
+     I/O:
+       - Forward(prev_output_tokens, encoder_out, incremental_state, ...) -> (logits, extra) like the original.
+       - `features_only=True` returns features without output projection.
+     """
+
+     def __init__(
+         self,
+         fp_decoder: nn.Module,
+         *,
+         qcfg: Optional[PTQConfig] = None,
+         fp_name: Optional[str] = None,
+     ):
+         super().__init__(qcfg, fp_name=fp_name)
+
+         # ---- carry config/meta (read-only views) --------------------------
+         assert hasattr(fp_decoder, "cfg")
+         self.cfg = fp_decoder.cfg
+         self.share_input_output_embed: bool = bool(
+             getattr(fp_decoder, "share_input_output_embed", False)
+         )
+
+         # Version buffer (parity with original)
+         version = getattr(fp_decoder, "version", None)
+         if isinstance(version, torch.Tensor):
+             self.register_buffer("version", version.clone(), persistent=False)
+         else:
+             self.register_buffer("version", torch.tensor([3.0]), persistent=False)
+
+         # Embeddings / positional encodings (FP; reuse modules)
+         assert hasattr(fp_decoder, "embed_tokens") and isinstance(
+             fp_decoder.embed_tokens, nn.Module
+         )
+         self.embed_tokens = fp_decoder.embed_tokens  # (B,T)->(B,T,C)
+
+         self.padding_idx: int = int(fp_decoder.padding_idx)  # type: ignore[arg-type]
+         self.max_target_positions: int = int(fp_decoder.max_target_positions)  # type: ignore[arg-type]
+
+         self.embed_positions = getattr(fp_decoder, "embed_positions", None)
+         self.layernorm_embedding = getattr(fp_decoder, "layernorm_embedding", None)
+
+         # Dimensions / projections (reuse)
+         self.embed_dim: int = int(getattr(fp_decoder, "embed_dim"))
+         self.output_embed_dim: int = int(getattr(fp_decoder, "output_embed_dim"))
+         self.project_in_dim = getattr(fp_decoder, "project_in_dim", None)
+         self.project_out_dim = getattr(fp_decoder, "project_out_dim", None)
+
+         # Scale factor (sqrt(embed_dim) unless disabled)
+         no_scale = bool(getattr(self.cfg, "no_scale_embedding", False))
+         self.embed_scale: float = 1.0 if no_scale else math.sqrt(self.embed_dim)
+
+         # Final decoder LayerNorm (may be None depending on cfg)
+         self.layer_norm = getattr(fp_decoder, "layer_norm", None)
+
+         # Output projection / adaptive softmax (reuse FP modules)
+         self.adaptive_softmax = getattr(fp_decoder, "adaptive_softmax", None)
+         self.output_projection = getattr(fp_decoder, "output_projection", None)
+
+         # ---- wrap decoder layers ------------------------------------------
+         assert hasattr(fp_decoder, "layers")
+         fp_layers = list(fp_decoder.layers)  # type: ignore[arg-type]
+         self.layers = nn.ModuleList()
+
+         # Safe prefix to avoid None-based name collisions in KV cache keys
+         def _safe_prefix(name: Optional[str]) -> str:
+             return (
+                 name
+                 if (name is not None and name != "" and name != "None")
+                 else f"{self.__class__.__name__}_{id(self)}"
+             )
+
+         prefix = _safe_prefix(fp_name)
+
+         # Prepare child PTQConfig namespaces: layers/<idx>
+         layers_qcfg = qcfg.child("layers") if qcfg else None
+         for i, layer in enumerate(fp_layers):
+             child_cfg = layers_qcfg.child(str(i)) if layers_qcfg else None
+             # Not every item is necessarily a TransformerDecoderLayerBase (e.g., BaseLayer).
+             # If there's no registered wrapper for a layer type, keep it FP.
+             try:
+                 wrapped = PTQWrapper(
+                     layer, qcfg=child_cfg, fp_name=f"{prefix}.layers.{i}"
+                 )
+             except NotImplementedError:
+                 wrapped = layer  # keep as-is (FP)
+             self.layers.append(wrapped)
+         self.num_layers = len(self.layers)
+
+         # choose a generous upper-bound; you can wire this from cfg if you like
+         self.mask_fill_value: float = -120.0
+         max_tgt = int(getattr(self.cfg, "max_target_positions", 2048))  # fallback: 2048
+
+         mask = torch.full((1, 1, max_tgt, max_tgt), float(self.mask_fill_value))
+         mask.triu_(1)  # upper triangle set to fill_value; diagonal/lower are zeros
+         self.register_buffer("causal_mask_template", mask, persistent=False)
+
+     def forward(
+         self,
+         prev_output_tokens: Tensor,  # [B, T]
+         encoder_out: Optional[Dict[str, List[Tensor]]] = None,
+         incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+         features_only: bool = False,
+         full_context_alignment: bool = False,
+         alignment_layer: Optional[int] = None,
+         alignment_heads: Optional[int] = None,
+         src_lengths: Optional[Any] = None,
+         return_all_hiddens: bool = False,
+     ):
+         """
+         Match the original API.
+         Returns:
+             (logits_or_features, extra_dict)
+         """
+         x, extra = self.extract_features_scriptable(
+             prev_output_tokens=prev_output_tokens,
+             encoder_out=encoder_out,
+             incremental_state=incremental_state,
+             full_context_alignment=full_context_alignment,
+             alignment_layer=alignment_layer,
+             alignment_heads=alignment_heads,
+         )
+         if not features_only:
+             x = self.output_layer(x)
+         return x, extra
+
+     def extract_features_scriptable(
+         self,
+         prev_output_tokens: Tensor,  # [B,T]
+         encoder_out: Optional[Dict[str, List[Tensor]]],
+         incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+         full_context_alignment: bool = False,
+         alignment_layer: Optional[int] = None,
+         alignment_heads: Optional[int] = None,
+     ) -> Tuple[Tensor, Dict[str, List[Optional[Tensor]]]]:
+         """
+         Feature path that mirrors Fairseq's implementation (minus training-only code).
+
+         Returns:
+             x: [B, T, C]
+             extra: {"attn": [attn or None], "inner_states": [T x B x C tensors]}
+         """
+         B, T = prev_output_tokens.size()
+         if alignment_layer is None:
+             alignment_layer = self.num_layers - 1
+
+         # Unpack encoder outputs in Fairseq dict format
+         enc: Optional[Tensor] = None
+         padding_mask: Optional[Tensor] = None
+         if encoder_out is not None and len(encoder_out.get("encoder_out", [])) > 0:
+             enc = encoder_out["encoder_out"][0]  # [S,B,Ce]
+         if (
+             encoder_out is not None
+             and len(encoder_out.get("encoder_padding_mask", [])) > 0
+         ):
+             padding_mask = encoder_out["encoder_padding_mask"][0]  # [B,S] (bool)
+
+         # Positional embeddings (support incremental decoding)
+         positions = None
+         if self.embed_positions is not None:
+             positions = self.embed_positions(
+                 prev_output_tokens, incremental_state=incremental_state
+             )
+
+         # In incremental mode, only the last step is consumed
+         if incremental_state is not None:
+             prev_output_tokens = prev_output_tokens[:, -1:]
+             if positions is not None:
+                 positions = positions[:, -1:]
+
+         # Prevent view quirks (TorchScript parity in original)
+         prev_output_tokens = prev_output_tokens.contiguous()
+
+         # Token embeddings (+ optional proj-in), + positions, + optional LN
+         x = self.embed_scale * self.embed_tokens(prev_output_tokens)  # [B,T,C]
+         if self.project_in_dim is not None:
+             x = self.project_in_dim(x)
+         if positions is not None:
+             x = x + positions
+         if self.layernorm_embedding is not None:
+             x = self.layernorm_embedding(x)
+
+         # No dropout / quant_noise (inference-only)
+
+         # B x T x C -> T x B x C
+         x = x.transpose(0, 1)
+
+         # Build self-attn masks
+         self_attn_padding_mask: Optional[Tensor] = None
+         if (
+             getattr(self.cfg, "cross_self_attention", False)
+             or prev_output_tokens.eq(self.padding_idx).any()
+         ):
+             self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)  # [B,T]
+
+         attn: Optional[Tensor] = None
+         inner_states: List[Optional[Tensor]] = [x]
+
+         for idx, layer in enumerate(self.layers):
+             # Causal mask unless full-context alignment or incremental decoding
+             if incremental_state is None and not full_context_alignment:
+                 Tq = x.size(0)
+                 self_attn_mask = self.buffered_future_mask(
+                     Tq, Tq, x=x
+                 )  # [Tq,Tq] additive float
+             else:
+                 self_attn_mask = None
+
+             x, layer_attn, _ = layer(
+                 x,
+                 enc,
+                 padding_mask,
+                 incremental_state,
+                 self_attn_mask=self_attn_mask,
+                 self_attn_padding_mask=self_attn_padding_mask,
+                 need_attn=bool(idx == alignment_layer),
+                 need_head_weights=bool(idx == alignment_layer),
+             )
+
+             inner_states.append(x)
+             if layer_attn is not None and idx == alignment_layer:
+                 attn = layer_attn.float().to(x)
+
+         # Average heads if needed
+         if attn is not None and alignment_heads is not None:
+             attn = attn[:alignment_heads]
+         if attn is not None:
+             attn = attn.mean(dim=0)  # [B,T,S]
+
+         # Optional final layer norm
+         if self.layer_norm is not None:
+             x = self.layer_norm(x)
+
+         # T x B x C -> B x T x C
+         x = x.transpose(0, 1)
+
+         # Optional proj-out
+         if self.project_out_dim is not None:
+             assert self.project_out_dim is not None
+             x = self.project_out_dim(x)
+
+         return x, {"attn": [attn], "inner_states": inner_states}
+
+     def output_layer(self, features: Tensor) -> Tensor:
+         """Project features to vocabulary size (or return features with adaptive softmax)."""
+         if self.adaptive_softmax is None:
+             assert self.output_projection is not None
+             return self.output_projection(features)  # type: ignore[operator]
+         else:
+             return features
+
+     def buffered_future_mask(
+         self, Tq: int, Ts: int, *, x: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Return additive float mask [Tq, Ts]: zeros on allowed, large-neg on disallowed.
+         Uses the prebuilt template; will re-build if you exceed template size.
+         """
+         assert isinstance(self.causal_mask_template, torch.Tensor)
+         Mmax = self.causal_mask_template.size(-1)
+         assert Tq <= Mmax and Ts <= Mmax
+         cm = self.causal_mask_template[..., :Tq, :Ts].to(device=x.device, dtype=x.dtype)
+         return cm.squeeze(0).squeeze(0)  # [Tq, Ts]
+
+     def max_positions(self) -> int:
+         """Maximum output length supported by the decoder (same policy as the original)."""
+         if self.embed_positions is None:
+             return self.max_target_positions
+         return min(self.max_target_positions, self.embed_positions.max_positions)
+
+     def get_normalized_probs(
+         self,
+         net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
+         log_probs: bool,
+         sample: Optional[Dict[str, Tensor]] = None,
+     ):
+         """Get normalized probabilities (or log probs) from a net's output."""
+         return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
+
+     def get_normalized_probs_scriptable(
+         self,
+         net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
+         log_probs: bool,
+         sample: Optional[Dict[str, Tensor]] = None,
+     ):
+         """Get normalized probabilities (or log probs) from a net's output."""
+
+         if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
+             if sample is not None:
+                 assert "target" in sample
+                 target = sample["target"]
+             else:
+                 target = None
+             out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
+             return out.exp_() if not log_probs else out
+
+         logits = net_output[0]
+         if log_probs:
+             return F.log_softmax(logits, dim=-1, dtype=torch.float32)
+         else:
+             return F.softmax(logits, dim=-1, dtype=torch.float32)
+
+     def reorder_incremental_state_scripting(
+         self,
+         incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+         new_order: Tensor,
+     ):
+         """Main entry point for reordering the incremental state.
+
+         Due to limitations in TorchScript, we call this function in
+         :class:`fairseq.sequence_generator.SequenceGenerator` instead of
+         calling :func:`reorder_incremental_state` directly.
+         """
+         for module in self.modules():
+             if hasattr(module, "reorder_incremental_state"):
+                 result = module.reorder_incremental_state(incremental_state, new_order)  # type: ignore[operator]
+                 if result is not None:
+                     incremental_state = result
+
+     def forward_external_step(
+         self,
+         prev_output_x: Tensor,  # [1, B, C]
+         *,
+         encoder_out_x: Tensor,  # [S, B, Ce]
+         encoder_padding_mask: Optional[
+             Tensor
+         ] = None,  # [B,S] or [B,1,S] additive-float
+         self_attn_mask: Optional[
+             Tensor
+         ] = None,  # [1,S_hist+1] or [B,1,S_hist+1] additive-float
+         prev_self_k_list: Optional[
+             List[Tensor]
+         ] = None,  # length=L; each [B,H,Tprev,Dh]
+         prev_self_v_list: Optional[
+             List[Tensor]
+         ] = None,  # length=L; each [B,H,Tprev,Dh]
+         need_attn: bool = False,
+         need_head_weights: bool = False,
+     ) -> Tuple[Tensor, List[Tensor], List[Tensor]]:
+         """
+         Export-only single-step decoder.
+         Returns:
+             - x_out: [1, B, C]
+             - new_self_k_list/new_self_v_list: lists of length L; each [B*H, Tnew, Dh]
+         """
+         assert (
+             prev_output_x.dim() == 3 and prev_output_x.size(0) == 1
+         ), "prev_output_x must be [1,B,C]"
+         L = self.num_layers
+         if prev_self_k_list is None:
+             prev_self_k_list = [None] * L  # type: ignore[list-item]
+         if prev_self_v_list is None:
+             prev_self_v_list = [None] * L  # type: ignore[list-item]
+         assert len(prev_self_k_list) == L and len(prev_self_v_list) == L
+
+         assert encoder_out_x.dim() == 3, "encoder_out_x must be [S,B,C]"
+         x = prev_output_x  # [1,B,C]
+         enc = encoder_out_x
+
+         new_k_list: List[Tensor] = []
+         new_v_list: List[Tensor] = []
+
+         for li, layer in enumerate(self.layers):
+             assert isinstance(layer, PTQWrapper)
+             x, _, k_new, v_new = layer.wrapped.forward_external(  # type: ignore[attr-defined, operator]
+                 x,
+                 encoder_out=enc,
+                 encoder_padding_mask=encoder_padding_mask,
+                 prev_self_k=prev_self_k_list[li],
+                 prev_self_v=prev_self_v_list[li],
+                 self_attn_mask=self_attn_mask,
+                 need_attn=need_attn and (li == L - 1),
+                 need_head_weights=need_head_weights and (li == L - 1),
+             )
+             new_k_list.append(k_new)  # [B*H, Tnew, Dh]
+             new_v_list.append(v_new)  # [B*H, Tnew, Dh]
+
+         if self.layer_norm is not None:
+             x = self.layer_norm(x.transpose(0, 1)).transpose(0, 1)
+
+         return x, new_k_list, new_v_list  # [1,B,C], lists of [B*H, Tnew, Dh]
+
+     def _all_observers(self) -> Iterable:
+         """Yield all observers from wrapped decoder layers (if any)."""
+         for m in self.layers:
+             if isinstance(m, QuantModuleBase):
+                 yield from m._all_observers()
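
Note (not part of the released diff above): the snippet below is a minimal, self-contained sketch of the additive causal-mask template that `__init__` registers and `buffered_future_mask` slices. The tensor sizes and fill value are illustrative placeholders, not values taken from a real fairseq config.

import torch

# Illustrative values only; in QuantFairseqDecoder these come from cfg and mask_fill_value.
max_tgt, fill_value = 8, -120.0
mask = torch.full((1, 1, max_tgt, max_tgt), fill_value)
mask.triu_(1)  # strictly-upper triangle = fill_value; diagonal and lower stay 0.0

# Slice the template down to the current query/key lengths, as buffered_future_mask does.
Tq = Ts = 4
cm = mask[..., :Tq, :Ts].squeeze(0).squeeze(0)  # additive float mask [Tq, Ts]
# Adding `cm` to raw attention scores suppresses future positions without a boolean mask.
print(cm)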