tico 0.1.0.dev250803__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff compares two publicly released versions of the package as published to their public registry and is provided for informational purposes only.
Files changed (133)
  1. tico/__init__.py +1 -1
  2. tico/config/v1.py +5 -0
  3. tico/passes/cast_mixed_type_args.py +2 -0
  4. tico/passes/convert_expand_to_slice_cat.py +153 -0
  5. tico/passes/convert_matmul_to_linear.py +312 -0
  6. tico/passes/convert_to_relu6.py +1 -1
  7. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -4
  8. tico/passes/ops.py +0 -1
  9. tico/passes/remove_redundant_assert_nodes.py +3 -1
  10. tico/passes/remove_redundant_expand.py +3 -1
  11. tico/quantization/__init__.py +6 -0
  12. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  13. tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +30 -8
  14. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
  15. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  16. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  17. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  18. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  19. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
  20. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  21. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  22. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  23. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  24. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  25. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  26. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  27. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  28. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
  29. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  30. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  31. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  32. tico/quantization/config/base.py +26 -0
  33. tico/quantization/config/gptq.py +29 -0
  34. tico/quantization/config/pt2e.py +25 -0
  35. tico/quantization/config/ptq.py +119 -0
  36. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  37. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +7 -16
  38. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  39. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  40. tico/quantization/evaluation/metric.py +146 -0
  41. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  42. tico/quantization/passes/__init__.py +1 -0
  43. tico/{experimental/quantization → quantization}/public_interface.py +11 -18
  44. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  45. tico/quantization/quantizer_registry.py +73 -0
  46. tico/quantization/wrapq/__init__.py +1 -0
  47. tico/quantization/wrapq/dtypes.py +70 -0
  48. tico/quantization/wrapq/examples/__init__.py +1 -0
  49. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  50. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  51. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  52. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  53. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  54. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  55. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  56. tico/quantization/wrapq/mode.py +32 -0
  57. tico/quantization/wrapq/observers/__init__.py +1 -0
  58. tico/quantization/wrapq/observers/affine_base.py +128 -0
  59. tico/quantization/wrapq/observers/base.py +98 -0
  60. tico/quantization/wrapq/observers/ema.py +62 -0
  61. tico/quantization/wrapq/observers/identity.py +74 -0
  62. tico/quantization/wrapq/observers/minmax.py +39 -0
  63. tico/quantization/wrapq/observers/mx.py +60 -0
  64. tico/quantization/wrapq/qscheme.py +40 -0
  65. tico/quantization/wrapq/quantizer.py +179 -0
  66. tico/quantization/wrapq/utils/__init__.py +1 -0
  67. tico/quantization/wrapq/utils/introspection.py +167 -0
  68. tico/quantization/wrapq/utils/metrics.py +124 -0
  69. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  70. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  71. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  72. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  73. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  74. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  75. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  76. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  77. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  78. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  79. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  80. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  81. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  82. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  83. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  84. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  85. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  86. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  87. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  88. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  89. tico/quantization/wrapq/wrappers/registry.py +125 -0
  90. tico/serialize/circle_serializer.py +11 -4
  91. tico/serialize/operators/adapters/__init__.py +1 -0
  92. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  93. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  94. tico/serialize/operators/op_le.py +54 -0
  95. tico/serialize/operators/op_mm.py +15 -132
  96. tico/serialize/operators/op_rmsnorm.py +65 -0
  97. tico/utils/convert.py +20 -15
  98. tico/utils/dtype.py +22 -0
  99. tico/utils/register_custom_op.py +29 -4
  100. tico/utils/signature.py +247 -0
  101. tico/utils/utils.py +50 -53
  102. tico/utils/validate_args_kwargs.py +37 -0
  103. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  104. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/RECORD +130 -73
  105. tico/experimental/quantization/__init__.py +0 -6
  106. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  107. tico/experimental/quantization/evaluation/metric.py +0 -109
  108. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  109. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  110. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  111. /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
  112. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  113. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  114. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  115. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  116. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  117. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  118. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  119. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  120. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  121. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  122. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  123. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  124. /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
  125. /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
  126. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  127. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  128. /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
  129. /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
  130. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  131. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  132. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  133. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
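The dominant change in this release is the promotion of the quantization stack out of the experimental namespace: modules that previously lived under tico/experimental/quantization/... now live under tico/quantization/..., alongside the new wrapq PTQ wrapper package. A minimal sketch of what the move means for downstream imports, using paths taken from the file list above (the exact symbols each module re-exports are not visible in this diff, so treat it as illustrative only):

    # Before (0.1.0.dev250803): the package lived under the experimental namespace,
    # e.g. tico/experimental/quantization/public_interface.py.
    #
    # After (0.1.0.dev251102): the same modules are importable from tico.quantization,
    # for example the imports used by the new fairseq wrappers shown below:
    from tico.quantization.config.ptq import PTQConfig
    from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
    from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase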
tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py (new file)
@@ -0,0 +1,331 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # -----------------------------------------------------------------------------
+ # This file includes modifications based on fairseq
+ # (https://github.com/facebookresearch/fairseq), originally licensed under
+ # the MIT License. See the LICENSE file in the fairseq repository for details.
+ # -----------------------------------------------------------------------------
+
+ import math
+ from typing import Dict, List, Literal, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+ from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+ @try_register("fairseq.models.transformer.TransformerEncoderBase")
+ class QuantFairseqEncoder(QuantModuleBase):
+     """
+     Quant-aware drop-in replacement for Fairseq TransformerEncoderBase.
+
+     Key design choices:
+       - Keep embeddings and LayerNorms in FP.
+       - Remove training-time logic (dropout, activation-dropout, quant_noise).
+       - Attention masks are handled statically inside the layer wrapper; this
+         encoder only does the original padding zero-out before the stack.
+
+     I/O contracts:
+       - Forward signature and returned dictionary are identical to the original
+         when `use_external_inputs=False`.
+       - When `use_external_inputs=True`, forward returns a single Tensor (T,B,C)
+         and completely skips embedding/positional/LN/mask-creation paths.
+       - Tensor shapes follow Fairseq convention.
+     """
+
+     def __init__(
+         self,
+         fp_encoder: nn.Module,
+         *,
+         qcfg: Optional[PTQConfig] = None,
+         fp_name: Optional[str] = None,
+         use_external_inputs: bool = False,  # export-mode flag
+         return_type: Literal["tensor", "dict"] = "dict",
+     ):
+         super().__init__(qcfg, fp_name=fp_name)
+         self.use_external_inputs = use_external_inputs
+         self.return_type: Literal["tensor", "dict"] = return_type
+
+         # --- carry basic config / metadata (read-only copies) ---------------
+         assert hasattr(fp_encoder, "cfg")
+         self.cfg = fp_encoder.cfg
+         self.return_fc: bool = bool(getattr(fp_encoder, "return_fc", False))
+
+         # Embedding stack ----------------------------------------------------
+         assert hasattr(fp_encoder, "embed_tokens") and isinstance(
+             fp_encoder.embed_tokens, nn.Module
+         )
+         self.embed_tokens = fp_encoder.embed_tokens  # keep FP embeddings
+
+         assert hasattr(fp_encoder, "padding_idx")
+         self.padding_idx: int = int(fp_encoder.padding_idx)  # type: ignore[arg-type]
+
+         # scale = sqrt(embed_dim) unless disabled
+         embed_dim = int(self.embed_tokens.embedding_dim)  # type: ignore[arg-type]
+         no_scale = bool(getattr(self.cfg, "no_scale_embedding", False))
+         self.embed_scale: float = 1.0 if no_scale else math.sqrt(embed_dim)
+
+         # Positional embeddings (keep as-is; no FQ)
+         self.embed_positions = getattr(fp_encoder, "embed_positions", None)
+         # Optional embedding LayerNorm
+         self.layernorm_embedding = getattr(fp_encoder, "layernorm_embedding", None)
+
+         # Final encoder LayerNorm (pre-norm stacks may set this to None)
+         self.layer_norm = getattr(fp_encoder, "layer_norm", None)
+
+         # Max positions (reuse for API parity)
+         self.max_source_positions: int = int(fp_encoder.max_source_positions)  # type: ignore[arg-type]
+
+         # --- wrap encoder layers with PTQWrapper ----------------------------
+         assert hasattr(fp_encoder, "layers")
+         fp_layers = list(fp_encoder.layers)  # type: ignore[arg-type]
+         self.layers = nn.ModuleList()
+
+         # Prepare child PTQConfig namespaces: layers/<idx>
+         layers_qcfg = qcfg.child("layers") if qcfg else None
+         for i, layer in enumerate(fp_layers):
+             child_cfg = layers_qcfg.child(str(i)) if layers_qcfg else None
+             self.layers.append(
+                 PTQWrapper(layer, qcfg=child_cfg, fp_name=f"{fp_name}.layers.{i}")
+             )
+
+         # Version buffer (keep for state_dict parity)
+         version = getattr(fp_encoder, "version", None)
+         if isinstance(version, torch.Tensor):
+             self.register_buffer("version", version.clone(), persistent=False)
+         else:
+             self.register_buffer("version", torch.tensor([3.0]), persistent=False)
+
+     # ----------------------------------------------------------------------
+     def forward_embedding(
+         self, src_tokens: Tensor, token_embedding: Optional[Tensor] = None
+     ) -> Tuple[Tensor, Tensor]:
+         """
+         Embed tokens and add positional embeddings. Dropout/quant_noise are removed.
+         Returns:
+             x (B, T, C), embed (B, T, C)  # embed is the token-only embedding
+         """
+         if token_embedding is None:
+             token_embedding = self.embed_tokens(src_tokens)
+         embed = token_embedding  # token-only
+
+         x = self.embed_scale * token_embedding
+         if self.embed_positions is not None:
+             x = x + self.embed_positions(src_tokens)
+         if self.layernorm_embedding is not None:
+             x = self.layernorm_embedding(x)
+         # No dropout, no quant_noise here (inference-only)
+         return x, embed
+
+     # ----------------------------------------------------------------------
+     def forward(
+         self,
+         src_tokens: Tensor,
+         src_lengths: Optional[Tensor] = None,
+         return_all_hiddens: bool = False,
+         token_embeddings: Optional[Tensor] = None,
+         *,
+         # External-inputs branch (used for export)
+         encoder_padding_mask: Optional[Tensor] = None,  # B x T (bool)
+     ) -> Tensor | Dict[str, List[Optional[Tensor]]]:
+         """
+         If `self.use_external_inputs` is True:
+           - Use only x_external and encoder_padding_mask.
+           - Return a single Tensor (T, B, C) for export friendliness.
+
+         Otherwise (False):
+           - Behave like the original Fairseq encoder forward and return dict-of-lists.
+         """
+         if self.use_external_inputs:
+             # ----- External-input mode: completely skip embedding/positional/LN/mask creation -----
+             x_external = src_tokens  # T x B x C (already embedded + transposed)
+
+             encoder_states: List[Tensor] = []
+             if return_all_hiddens:
+                 encoder_states.append(x_external)
+
+             for layer in self.layers:
+                 out = layer(x_external, encoder_padding_mask=encoder_padding_mask)
+                 x_external = (
+                     out[0] if (isinstance(out, tuple) and len(out) == 2) else out
+                 )
+                 if return_all_hiddens:
+                     encoder_states.append(x_external)
+
+             if self.layer_norm is not None:
+                 x_external = self.layer_norm(x_external)
+
+             if self.return_type == "dict":
+                 return {
+                     "encoder_out": [x_external],
+                     "encoder_padding_mask": [encoder_padding_mask],
+                     "encoder_states": encoder_states,  # type: ignore[dict-item]
+                 }
+             else:
+                 # For export, returning a single Tensor is simpler and more portable.
+                 return x_external
+
+         # ----- Original path (training/eval compatibility) ------------------
+
+         # Compute padding mask [B, T] (bool). We keep the original "has_pads" logic.
+         encoder_padding_mask = src_tokens.eq(self.padding_idx)
+         has_pads: Tensor = (
+             torch.tensor(src_tokens.device.type == "xla") or encoder_padding_mask.any()
+         )
+         if torch.jit.is_scripting():
+             has_pads = torch.tensor(1) if has_pads else torch.tensor(0)
+
+         # Embedding path (B,T,C). No dropout/quant_noise.
+         x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
+
+         # Zero out padded timesteps prior to the stack (same as original)
+         x = x * (
+             1 - encoder_padding_mask.unsqueeze(-1).type_as(x) * has_pads.type_as(x)
+         )
+
+         # B x T x C -> T x B x C
+         x = x.transpose(0, 1)
+
+         encoder_states: List[Tensor] = []  # type: ignore[no-redef]
+         fc_results: List[Optional[Tensor]] = []
+
+         if return_all_hiddens:
+             encoder_states.append(x)
+
+         # Encoder layers (each item is PTQ-wrapped and uses static additive masks internally)
+         for layer in self.layers:
+             out = layer(
+                 x, encoder_padding_mask=encoder_padding_mask if has_pads else None
+             )
+             if isinstance(out, tuple) and len(out) == 2:
+                 x, fc_res = out
+             else:
+                 x = out
+                 fc_res = None
+
+             if return_all_hiddens and not torch.jit.is_scripting():
+                 encoder_states.append(x)
+                 fc_results.append(fc_res)
+
+         if self.layer_norm is not None:
+             x = self.layer_norm(x)
+
+         # src_lengths (B, 1) int32, identical to original
+         src_lengths_out = (
+             src_tokens.ne(self.padding_idx)
+             .sum(dim=1, dtype=torch.int32)
+             .reshape(-1, 1)
+             .contiguous()
+         )
+
+         return {
+             "encoder_out": [x],  # T x B x C
+             "encoder_padding_mask": [encoder_padding_mask],  # B x T
+             "encoder_embedding": [encoder_embedding],  # B x T x C
+             "encoder_states": encoder_states,  # type: ignore[dict-item]  # List[T x B x C]
+             "fc_results": fc_results,  # type: ignore[dict-item]  # List[T x B x C]
+             "src_tokens": [],
+             "src_lengths": [src_lengths_out],
+         }
+
+     def forward_torchscript(self, net_input: Dict[str, Tensor]):
+         """A TorchScript-compatible version of forward.
+
+         Encoders which use additional arguments may want to override
+         this method for TorchScript compatibility.
+         """
+         if "encoder_padding_mask" in net_input:
+             return self.forward(
+                 src_tokens=net_input["src_tokens"],
+                 src_lengths=net_input["src_lengths"],
+                 encoder_padding_mask=net_input["encoder_padding_mask"],
+             )
+         else:
+             return self.forward(
+                 src_tokens=net_input["src_tokens"],
+                 src_lengths=net_input["src_lengths"],
+             )
+
+     # ----------------------------------------------------------------------
+     @torch.jit.export
+     def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
+         """
+         Match original API: reorder the batched dimension (B) according to new_order.
+         """
+         reordered = dict()  # type: ignore[var-annotated]
+         if len(encoder_out["encoder_out"]) == 0:
+             new_encoder_out = []
+         else:
+             new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
+         reordered["encoder_out"] = new_encoder_out
+         keys = [
+             "encoder_padding_mask",
+             "encoder_embedding",
+             "src_tokens",
+             "src_lengths",
+         ]
+         for k in keys:
+             if k not in encoder_out:
+                 continue
+             if len(encoder_out[k]) == 0:
+                 reordered[k] = []
+             else:
+                 reordered[k] = [encoder_out[k][0].index_select(0, new_order)]
+
+         if "encoder_states" in encoder_out:
+             encoder_states = encoder_out["encoder_states"]
+             if len(encoder_states) > 0:
+                 for idx, state in enumerate(encoder_states):
+                     encoder_states[idx] = state.index_select(1, new_order)
+             reordered["encoder_states"] = encoder_states
+
+         return reordered
+
+     @torch.jit.export
+     def _reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
+         """Dummy re-order for beamable enc-dec attention (API parity)."""
+         return encoder_out
+
+     def max_positions(self) -> int:
+         """Maximum input length supported by the encoder (same policy as the original)."""
+         if self.embed_positions is None:
+             return self.max_source_positions
+         return min(self.max_source_positions, self.embed_positions.max_positions)
+
+     def upgrade_state_dict_named(self, state_dict, name):
+         """
+         Forward-compat mapping for older checkpoints (mirror original behavior for LNs).
+         The actual remapping of per-layer norms is delegated to the wrapped layers.
+         """
+         for i, layer in enumerate(self.layers):
+             if hasattr(layer, "upgrade_state_dict_named"):
+                 layer.upgrade_state_dict_named(state_dict, f"{name}.layers.{i}")
+
+         version_key = f"{name}.version"
+         v = state_dict.get(version_key, torch.Tensor([1]))
+         if float(v[0].item()) < 2:
+             self.layer_norm = None
+             state_dict[version_key] = torch.Tensor([1])
+         return state_dict
+
+     def _all_observers(self):
+         for m in self.layers:
+             if isinstance(m, QuantModuleBase):
+                 yield from m._all_observers()
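The docstring above fixes two I/O contracts: with use_external_inputs=False the wrapper mirrors the original Fairseq encoder and returns the usual dict-of-lists, while use_external_inputs=True consumes an already embedded and transposed (T, B, C) tensor and can return a single tensor for export. A minimal usage sketch of the default path, assuming fairseq is installed, fp_encoder is an already-built TransformerEncoderBase, and PTQConfig is default-constructible (none of which is shown in this diff):

    import torch

    from tico.quantization.config.ptq import PTQConfig
    from tico.quantization.wrapq.wrappers.fairseq.quant_encoder import QuantFairseqEncoder

    # fp_encoder: an existing fairseq TransformerEncoderBase instance (assumed, not built here).
    quant_encoder = QuantFairseqEncoder(
        fp_encoder,
        qcfg=PTQConfig(),           # assumed default-constructible (see tico/quantization/config/ptq.py)
        fp_name="encoder",
        use_external_inputs=False,  # keep the original forward contract
        return_type="dict",
    )
    quant_encoder.eval()

    # Dummy (B, T) token batch; token ids are chosen to avoid the usual padding_idx.
    src_tokens = torch.randint(low=4, high=1000, size=(2, 16))
    src_lengths = torch.full((2,), 16)

    with torch.no_grad():
        out = quant_encoder(src_tokens, src_lengths=src_lengths)

    # Same contract as the FP encoder: out["encoder_out"][0] is T x B x C.
    print(out["encoder_out"][0].shape)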
tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py (new file)
@@ -0,0 +1,163 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # -----------------------------------------------------------------------------
+ # This file includes modifications based on fairseq
+ # (https://github.com/facebookresearch/fairseq), originally licensed under
+ # the MIT License. See the LICENSE file in the fairseq repository for details.
+ # -----------------------------------------------------------------------------
+
+ from typing import Optional
+
+ import torch.nn as nn
+ from torch import Tensor
+
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.wrapq.wrappers.fairseq.quant_mha import (
+     QuantFairseqMultiheadAttention,
+ )
+ from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+ from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+ @try_register("fairseq.modules.transformer_layer.TransformerEncoderLayerBase")
+ class QuantFairseqEncoderLayer(QuantModuleBase):
+     """
+     Quant-aware drop-in replacement for Fairseq TransformerEncoderLayerBase.
+
+     Design notes (inference-friendly):
+       - All training-time logic (dropout, activation-dropout) is removed.
+       - I/O shape follows Fairseq convention: [T, B, C].
+       - `return_fc` behavior is preserved (returns (x, fc_result) if enabled).
+     """
+
+     def __init__(
+         self,
+         fp_layer: nn.Module,
+         *,
+         qcfg: Optional[PTQConfig] = None,
+         fp_name: Optional[str] = None,
+     ):
+         super().__init__(qcfg, fp_name=fp_name)
+
+         # --- copy meta / config flags from FP layer (read-only) -------------
+         assert hasattr(fp_layer, "embed_dim")
+         assert hasattr(fp_layer, "normalize_before")
+         self.embed_dim: int = int(fp_layer.embed_dim)  # type: ignore[arg-type]
+         self.normalize_before: bool = bool(fp_layer.normalize_before)
+         self.return_fc: bool = bool(getattr(fp_layer, "return_fc", False))
+
+         # --- PTQ-wrapped submodules ----------------------------------------
+         attn_cfg = qcfg.child("self_attn") if qcfg else None
+         fc1_cfg = qcfg.child("fc1") if qcfg else None
+         fc2_cfg = qcfg.child("fc2") if qcfg else None
+         attn_ln_cfg = qcfg.child("self_attn_layer_norm") if qcfg else None
+         final_ln_cfg = qcfg.child("final_layer_norm") if qcfg else None
+
+         assert hasattr(fp_layer, "self_attn") and isinstance(
+             fp_layer.self_attn, nn.Module
+         )
+         assert hasattr(fp_layer, "fc1") and isinstance(fp_layer.fc1, nn.Module)
+         assert hasattr(fp_layer, "fc2") and isinstance(fp_layer.fc2, nn.Module)
+
+         self.self_attn = QuantFairseqMultiheadAttention(
+             fp_layer.self_attn, qcfg=attn_cfg, fp_name=f"{fp_name}.self_attn"
+         )
+         self.fc1 = PTQWrapper(fp_layer.fc1, qcfg=fc1_cfg, fp_name=f"{fp_name}.fc1")
+         self.fc2 = PTQWrapper(fp_layer.fc2, qcfg=fc2_cfg, fp_name=f"{fp_name}.fc2")
+
+         # LayerNorms
+         assert hasattr(fp_layer, "self_attn_layer_norm") and isinstance(
+             fp_layer.self_attn_layer_norm, nn.Module
+         )
+         assert hasattr(fp_layer, "final_layer_norm") and isinstance(
+             fp_layer.final_layer_norm, nn.Module
+         )
+         self.self_attn_layer_norm = PTQWrapper(
+             fp_layer.self_attn_layer_norm,
+             qcfg=attn_ln_cfg,
+             fp_name=f"{fp_name}.self_attn_layer_norm",
+         )
+         self.final_layer_norm = PTQWrapper(
+             fp_layer.final_layer_norm,
+             qcfg=final_ln_cfg,
+             fp_name=f"{fp_name}.final_layer_norm",
+         )
+
+         # Activation function
+         self.activation_fn = fp_layer.activation_fn  # type: ignore[operator]  # e.g., GELU/ReLU
+         self.obs_activation_fn = self._make_obs("activation_fn")
+
+     # ----------------------------------------------------------------------
+     def forward(
+         self,
+         x: Tensor,  # [T,B,C]
+         encoder_padding_mask: Optional[Tensor],
+         attn_mask: Optional[Tensor] = None,  # [T,S] boolean/byte or additive float
+     ):
+         """
+         Returns:
+             x' of shape [T, B, C] (or (x', fc_result) when return_fc=True)
+         """
+         # ---- Self-Attention block (pre-/post-norm kept as in FP layer) ----
+         residual = x
+         if self.normalize_before:
+             x = self.self_attn_layer_norm(x)
+
+         # Fairseq MHA expects [T,B,C]; our wrapped module keeps the same API
+         attn_out, _ = self.self_attn(
+             query=x,
+             key=x,
+             value=x,
+             key_padding_mask=encoder_padding_mask,  # additive float [B,S] or None
+             need_weights=False,
+             attn_mask=attn_mask,  # additive float [T,S] or None
+         )
+         x = residual + attn_out
+
+         if not self.normalize_before:
+             x = self.self_attn_layer_norm(x)
+
+         # ---- FFN block (no dropout/activation-dropout) --------------------
+         residual = x
+         if self.normalize_before:
+             x = self.final_layer_norm(x)
+
+         x = self.fc1(x)  # Linear
+         x = self.activation_fn(x)  # type: ignore[operator]
+         x = self._fq(x, self.obs_activation_fn)
+         x = self.fc2(x)  # Linear
+
+         fc_result = x  # keep before residual for optional return
+
+         x = residual + x
+         if not self.normalize_before:
+             x = self.final_layer_norm(x)
+
+         if self.return_fc:
+             return x, fc_result
+         return x
+
+     def _all_observers(self):
+         yield from (self.obs_activation_fn,)
+         for m in (
+             self.self_attn,
+             self.fc1,
+             self.fc2,
+             self.self_attn_layer_norm,
+             self.final_layer_norm,
+         ):
+             if isinstance(m, QuantModuleBase):
+                 yield from m._all_observers()
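At the layer level the contract is the Fairseq convention: activations are [T, B, C], encoder_padding_mask and attn_mask are passed through as additive float masks (or None), and with return_fc=True the pre-residual FFN output is returned alongside x. A minimal sketch of exercising one wrapped layer, assuming fp_layer is an already-built fairseq TransformerEncoderLayerBase and that running without a quantization config (qcfg=None) is acceptable:

    import torch

    from tico.quantization.wrapq.wrappers.fairseq.quant_encoder_layer import (
        QuantFairseqEncoderLayer,
    )

    # fp_layer: an existing fairseq TransformerEncoderLayerBase (assumed, not built here).
    quant_layer = QuantFairseqEncoderLayer(fp_layer, fp_name="encoder.layers.0")
    quant_layer.eval()

    T, B, C = 16, 2, quant_layer.embed_dim
    x = torch.randn(T, B, C)   # Fairseq layout: [T, B, C]
    padding_mask = None        # or an additive float mask of shape [B, S]

    with torch.no_grad():
        y = quant_layer(x, encoder_padding_mask=padding_mask)

    # y is [T, B, C]; with return_fc=True the call returns (y, fc_result) instead.
    print(y.shape)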