tico 0.1.0.dev250917__py3-none-any.whl → 0.1.0.dev250921__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (25)
  1. tico/__init__.py +1 -1
  2. tico/config/v1.py +3 -0
  3. tico/experimental/quantization/algorithm/gptq/quantizer.py +2 -2
  4. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +1 -1
  5. tico/experimental/quantization/config/__init__.py +1 -0
  6. tico/experimental/quantization/config/base.py +26 -0
  7. tico/experimental/quantization/config/gptq.py +29 -0
  8. tico/experimental/quantization/config/pt2e.py +25 -0
  9. tico/experimental/quantization/{config.py → config/smoothquant.py} +1 -35
  10. tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +191 -70
  11. tico/experimental/quantization/ptq/wrappers/fairseq/quant_decoder_layer.py +494 -0
  12. tico/experimental/quantization/ptq/wrappers/registry.py +1 -0
  13. tico/experimental/quantization/public_interface.py +1 -1
  14. tico/experimental/quantization/quantizer.py +1 -1
  15. tico/passes/convert_matmul_to_linear.py +200 -0
  16. tico/passes/convert_to_relu6.py +1 -1
  17. tico/serialize/circle_serializer.py +11 -4
  18. tico/serialize/operators/op_mm.py +15 -132
  19. tico/utils/convert.py +6 -1
  20. {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250921.dist-info}/METADATA +1 -1
  21. {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250921.dist-info}/RECORD +25 -19
  22. {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250921.dist-info}/LICENSE +0 -0
  23. {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250921.dist-info}/WHEEL +0 -0
  24. {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250921.dist-info}/entry_points.txt +0 -0
  25. {tico-0.1.0.dev250917.dist-info → tico-0.1.0.dev250921.dist-info}/top_level.txt +0 -0
tico/experimental/quantization/ptq/wrappers/fairseq/quant_decoder_layer.py
@@ -0,0 +1,494 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -----------------------------------------------------------------------------
+# This file includes modifications based on fairseq
+# (https://github.com/facebookresearch/fairseq), originally licensed under
+# the MIT License. See the LICENSE file in the fairseq repository for details.
+# -----------------------------------------------------------------------------
+
+from typing import Dict, Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+
+from tico.experimental.quantization.ptq.quant_config import QuantConfig
+from tico.experimental.quantization.ptq.wrappers.fairseq.quant_mha import (
+    QuantFairseqMultiheadAttention,
+)
+from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
+from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
+    QuantModuleBase,
+)
+from tico.experimental.quantization.ptq.wrappers.registry import try_register
+
+
+@try_register("fairseq.modules.transformer_layer.TransformerDecoderLayerBase")
+class QuantFairseqDecoderLayer(QuantModuleBase):
+    """
+    Quant-aware drop-in replacement for Fairseq TransformerDecoderLayerBase.
+
+    Design (inference-only):
+      - Keep LayerNorms and scalar head/residual scalers in FP.
+      - PTQ-wrap: self_attn, (optional) encoder_attn, fc1, fc2.
+      - Preserve Fairseq tensor contracts and incremental state handling.
+      - Remove training-time behaviors: dropout, activation-dropout, quant-noise, onnx_trace.
+
+    I/O:
+      - Input/Output use Fairseq shapes: [T, B, C].
+      - Forward returns: (x, attn, None) to match the original call sites in decoder.
+        * `attn` is from encoder-attention when requested (alignment).
+    """
+
+    def __init__(
+        self,
+        fp_layer: nn.Module,
+        *,
+        qcfg: Optional[QuantConfig] = None,
+        fp_name: Optional[str] = None,
+    ):
+        super().__init__(qcfg, fp_name=fp_name)
+
+        # --- read-only metadata copied from FP layer -----------------------
+        assert hasattr(fp_layer, "embed_dim")
+        assert hasattr(fp_layer, "normalize_before")
+        self.embed_dim: int = int(fp_layer.embed_dim)  # type: ignore[arg-type]
+        self.normalize_before: bool = bool(fp_layer.normalize_before)
+
+        # Cross-self attention flag (when True, key/value can include encoder_out)
+        self.cross_self_attention: bool = bool(
+            getattr(fp_layer, "cross_self_attention", False)
+        )
+
+        # Generate prefix
+        def _safe_prefix(name: Optional[str]) -> str:
+            # Avoid "None.*" strings causing collisions
+            return (
+                name
+                if (name is not None and name != "None" and name != "")
+                else f"{self.__class__.__name__}_{id(self)}"
+            )
+
+        prefix = _safe_prefix(fp_name)
+        # Self-attn (PTQ) ---------------------------------------------------
+        # Use our MHA wrapper with identical API to the FP module.
+        attn_cfg = qcfg.child("self_attn") if qcfg else None
+        assert hasattr(fp_layer, "self_attn") and isinstance(
+            fp_layer.self_attn, nn.Module
+        )
+        self.self_attn = QuantFairseqMultiheadAttention(
+            fp_layer.self_attn, qcfg=attn_cfg, fp_name=f"{prefix}.self_attn"
+        )
+
+        # Optional attention LayerNorm applied to self-attn output (scale_attn)
+        # Kept in FP; reuse original instance for weight parity.
+        self.attn_ln = getattr(fp_layer, "attn_ln", None)
+
+        # Optional per-head scaling after self-attn output (scale_heads)
+        # Keep exact Parameter reference if present (shape: [num_heads])
+        self.c_attn = getattr(fp_layer, "c_attn", None)
+
+        # Cache head meta for c_attn path
+        self.nh = int(getattr(self.self_attn, "num_heads"))
+        self.head_dim = int(getattr(self.self_attn, "head_dim"))
+
+        # Encoder-attn (PTQ) ------------------------------------------------
+        # Only present if the original layer was constructed with encoder_attn.
+        enc_attn_mod = getattr(fp_layer, "encoder_attn", None)
+        assert enc_attn_mod is not None
+        enc_cfg = qcfg.child("encoder_attn") if qcfg else None
+        self.encoder_attn = QuantFairseqMultiheadAttention(
+            enc_attn_mod, qcfg=enc_cfg, fp_name=f"{prefix}.encoder_attn"
+        )
+
+        # Feed-forward (PTQ) ------------------------------------------------
+        fc1_cfg = qcfg.child("fc1") if qcfg else None
+        fc2_cfg = qcfg.child("fc2") if qcfg else None
+        assert hasattr(fp_layer, "fc1") and isinstance(fp_layer.fc1, nn.Module)
+        assert hasattr(fp_layer, "fc2") and isinstance(fp_layer.fc2, nn.Module)
+        self.fc1 = PTQWrapper(fp_layer.fc1, qcfg=fc1_cfg, fp_name=f"{fp_name}.fc1")
+        self.fc2 = PTQWrapper(fp_layer.fc2, qcfg=fc2_cfg, fp_name=f"{fp_name}.fc2")
+
+        # LayerNorms
+        enc_attn_ln_cfg = qcfg.child("encoder_attn_layer_norm") if qcfg else None
+        attn_ln_cfg = qcfg.child("self_attn_layer_norm") if qcfg else None
+        final_ln_cfg = qcfg.child("final_layer_norm") if qcfg else None
+        assert hasattr(fp_layer, "encoder_attn_layer_norm") and isinstance(
+            fp_layer.encoder_attn_layer_norm, nn.Module
+        )
+        assert hasattr(fp_layer, "self_attn_layer_norm") and isinstance(
+            fp_layer.self_attn_layer_norm, nn.Module
+        )
+        assert hasattr(fp_layer, "final_layer_norm") and isinstance(
+            fp_layer.final_layer_norm, nn.Module
+        )
+        self.encoder_attn_layer_norm = PTQWrapper(
+            fp_layer.encoder_attn_layer_norm,
+            qcfg=enc_attn_ln_cfg,
+            fp_name=f"{fp_name}.encoder_attn_layer_norm",
+        )
+        self.self_attn_layer_norm = PTQWrapper(
+            fp_layer.self_attn_layer_norm,
+            qcfg=attn_ln_cfg,
+            fp_name=f"{fp_name}.self_attn_layer_norm",
+        )
+        self.final_layer_norm = PTQWrapper(
+            fp_layer.final_layer_norm,
+            qcfg=final_ln_cfg,
+            fp_name=f"{fp_name}.final_layer_norm",
+        )
+
+        # Optional FFN intermediate LayerNorm (scale_fc), FP
+        self.ffn_layernorm = getattr(fp_layer, "ffn_layernorm", None)
+
+        # Optional residual scaling (scale_resids), keep Parameter reference
+        self.w_resid = getattr(fp_layer, "w_resid", None)
+
+        # Activation function
+        self.activation_fn = fp_layer.activation_fn  # type: ignore[operator]
+        self.obs_activation_fn = self._make_obs("activation_fn")
+
+        # Alignment flag used by Fairseq (kept for API parity)
+        self.need_attn: bool = bool(getattr(fp_layer, "need_attn", True))
+
+        # No dropout / activation-dropout in inference wrapper
+        # (intentionally omitted)
+
+        # --- observers for external/self-attn KV cache inputs --------------
+        self.obs_prev_self_k_in = self._make_obs("prev_self_k_in")
+        self.obs_prev_self_v_in = self._make_obs("prev_self_v_in")
+
+    # ----------------------------------------------------------------------
+    def _maybe_apply_head_scale(self, x: Tensor) -> Tensor:
+        """
+        Optional per-head scaling (scale_heads) after self-attention.
+        x: [T, B, C]
+        """
+        if self.c_attn is None:
+            return x
+        T, B, _ = x.shape
+        x = x.view(T, B, self.nh, self.head_dim)  # [T,B,H,Dh]
+        # einsum over head dim: scales each head independently
+        x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)  # [T,B,H,Dh]
+        return x.reshape(T, B, self.nh * self.head_dim)  # [T,B,C]
+
+    # ----------------------------------------------------------------------
+    def forward(
+        self,
+        x: Tensor,  # [T,B,C]
+        encoder_out: Optional[Tensor] = None,  # [S,B,Ce] or None
+        encoder_padding_mask: Optional[Tensor] = None,  # [B,S] bool or additive float
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+        prev_self_attn_state: Optional[List[Tensor]] = None,
+        prev_attn_state: Optional[List[Tensor]] = None,
+        self_attn_mask: Optional[Tensor] = None,  # [T,T] or [B,T,T] or None
+        self_attn_padding_mask: Optional[Tensor] = None,  # [B,T] or [B,T,T] or None
+        need_attn: bool = False,
+        need_head_weights: bool = False,
+    ) -> Tuple[Tensor, Optional[Tensor], None]:
+        """
+        Mirrors the original forward, minus training-only logic.
+        Returns:
+            x': [T,B,C], attn (from encoder-attn when requested), None
+        """
+        if need_head_weights:
+            need_attn = True
+
+        # ---- (1) Self-Attention block ------------------------------------
+        residual = x
+        if self.normalize_before:
+            x = self.self_attn_layer_norm(x)
+
+        # Load provided cached self-attn state (for incremental decoding)
+        if prev_self_attn_state is not None:
+            prev_key, prev_value = prev_self_attn_state[:2]
+            saved_state: Dict[str, Optional[Tensor]] = {
+                "prev_key": prev_key,
+                "prev_value": prev_value,
+            }
+            if len(prev_self_attn_state) >= 3:
+                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
+            assert incremental_state is not None
+            self.self_attn._set_input_buffer(incremental_state, saved_state)
+
+        # Cross-self-attention: prepend encoder_out to K/V at the first step
+        y = x
+        if self.cross_self_attention:
+            _buf = self.self_attn._get_input_buffer(incremental_state)
+            no_cache_yet = not (
+                incremental_state is not None
+                and _buf is not None
+                and "prev_key" in _buf
+            )
+            if no_cache_yet:
+                if self_attn_mask is not None:
+                    assert encoder_out is not None
+                    # Grow attn mask to cover encoder timesteps (no autoregressive penalty for them)
+                    self_attn_mask = torch.cat(
+                        (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask),
+                        dim=1,
+                    )
+                if self_attn_padding_mask is not None:
+                    if encoder_padding_mask is None:
+                        assert encoder_out is not None
+                        encoder_padding_mask = self_attn_padding_mask.new_zeros(
+                            encoder_out.size(1), encoder_out.size(0)
+                        )
+                    # Concatenate encoder pad-mask in front of target pad-mask
+                    self_attn_padding_mask = torch.cat(
+                        (encoder_padding_mask, self_attn_padding_mask), dim=1
+                    )
+                assert encoder_out is not None
+                y = torch.cat((encoder_out, x), dim=0)  # [S+T, B, C]
+
+        # Self-attn; Fairseq never consumes self-attn weights for alignment here
+        x, _ = self.self_attn(
+            query=x,
+            key=y,
+            value=y,
+            key_padding_mask=self_attn_padding_mask,
+            incremental_state=incremental_state,
+            need_weights=False,
+            attn_mask=self_attn_mask,
+        )
+
+        # Optional per-head scaling and attn LayerNorm on self-attn output
+        x = self._maybe_apply_head_scale(x)
+        if self.attn_ln is not None:
+            x = self.attn_ln(x)
+
+        # Residual + (post-norm if applicable)
+        x = residual + x
+        if not self.normalize_before:
+            x = self.self_attn_layer_norm(x)
+
+        # ---- (2) Encoder-Decoder Attention block --------------------------
+        attn_out: Optional[Tensor] = None
+        assert encoder_out is not None
+        residual = x
+        assert self.encoder_attn_layer_norm is not None
+        if self.normalize_before:
+            x = self.encoder_attn_layer_norm(x)
+
+        # Load provided cached cross-attn state
+        if prev_attn_state is not None:
+            prev_key, prev_value = prev_attn_state[:2]
+            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
+            if len(prev_attn_state) >= 3:
+                saved_state["prev_key_padding_mask"] = prev_attn_state[2]
+            assert incremental_state is not None
+            self.encoder_attn._set_input_buffer(incremental_state, saved_state)
+
+        # Cross-attn (static_kv=True to reuse encoder K/V across steps)
+        assert self.encoder_attn is not None
+        x, attn_out = self.encoder_attn(
+            query=x,
+            key=encoder_out,
+            value=encoder_out,
+            key_padding_mask=encoder_padding_mask,
+            incremental_state=incremental_state,
+            static_kv=True,
+            need_weights=need_attn or self.need_attn,
+            need_head_weights=need_head_weights,
+        )
+
+        x = residual + x
+        if not self.normalize_before:
+            x = self.encoder_attn_layer_norm(x)
+
+        # ---- (3) Feed-Forward block --------------------------------------
+        residual = x
+        if self.normalize_before:
+            x = self.final_layer_norm(x)
+
+        # FFN: fc1 -> activation -> (optional LN) -> fc2
+        x = self.fc1(x)
+        x = self.activation_fn(x)  # type: ignore[operator]
+        x = self._fq(x, self.obs_activation_fn)
+        if self.ffn_layernorm is not None:
+            x = self.ffn_layernorm(x)
+        x = self.fc2(x)
+
+        # Optional residual scaling (scale_resids)
+        if self.w_resid is not None:
+            residual = torch.mul(self.w_resid, residual)
+
+        x = residual + x
+        if not self.normalize_before:
+            x = self.final_layer_norm(x)
+
+        # Return attn from encoder-attn branch when requested; self-attn weights are not returned.
+        return x, attn_out, None
+
+    def forward_external(
+        self,
+        x: Tensor,  # [1, B, C] (embedded current-step token)
+        *,
+        encoder_out: Optional[Tensor],  # [S, B, Ce]
+        encoder_padding_mask: Optional[
+            Tensor
+        ] = None,  # [B,S] bool or additive-float or [B,1,S] additive-float
+        prev_self_k: Optional[Tensor] = None,  # [B, H, Tprev, Dh]
+        prev_self_v: Optional[Tensor] = None,  # [B, H, Tprev, Dh]
+        self_attn_mask: Optional[
+            Tensor
+        ] = None,  # [1, 1, S_hist+1] or [B,1,S_hist+1] additive-float
+        need_attn: bool = False,
+        need_head_weights: bool = False,
+    ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor]:
+        """
+        Export-only single-step:
+        Returns (x_out[1,B,C], attn_from_cross, new_self_k[B,H,1,Dh], new_self_v[B,H,1,Dh]).
+        """
+        if need_head_weights:
+            need_attn = True
+
+        assert x.dim() == 3 and x.size(0) == 1, "x must be [1,B,C]"
+        B = x.size(1)
+
+        # ---- Self-Attention (uses MHA return_new_kv) ----------------------
+        x_tbc = x
+        if self.normalize_before:
+            x_tbc = self.self_attn_layer_norm(x_tbc)
+
+        # Provide prev KV via incremental_state so wrapper appends internally
+        incr: Dict[str, Dict[str, Optional[Tensor]]] = {}
+        if prev_self_k is not None and prev_self_v is not None:
+            # Attach observers to incoming caches
+            prev_self_k = self._fq(prev_self_k, self.obs_prev_self_k_in)
+            prev_self_v = self._fq(prev_self_v, self.obs_prev_self_v_in)
+            assert isinstance(prev_self_k, Tensor) and isinstance(prev_self_v, Tensor)
+            saved = {
+                "prev_key": prev_self_k.detach(),
+                "prev_value": prev_self_v.detach(),
+            }
+            self.self_attn._set_input_buffer(incr, saved)  # type: ignore[arg-type]
+
+        # Normalize self-attn additive mask to shapes wrapper accepts: [T,S] or [B,T,S]
+        attn_mask_for_wrapper = None
+        if self_attn_mask is not None:
+            if (
+                self_attn_mask.dim() == 3
+                and self_attn_mask.size(0) == B
+                and self_attn_mask.size(1) == 1
+            ):
+                attn_mask_for_wrapper = self_attn_mask  # [B,1,S]
+            elif (
+                self_attn_mask.dim() == 3
+                and self_attn_mask.size(0) == 1
+                and self_attn_mask.size(1) == 1
+            ):
+                attn_mask_for_wrapper = self_attn_mask[0]  # -> [1,S]
+            elif self_attn_mask.dim() == 2 and self_attn_mask.size(0) == 1:
+                attn_mask_for_wrapper = self_attn_mask  # [1,S]
+            else:
+                raise RuntimeError(
+                    "self_attn_mask must be [1,S] or [B,1,S] additive-float."
+                )
+            attn_mask_for_wrapper = attn_mask_for_wrapper.to(
+                dtype=x_tbc.dtype, device=x_tbc.device
+            )
+
+        x_sa, _, new_k_bh, new_v_bh = self.self_attn(
+            query=x_tbc,
+            key=x_tbc,
+            value=x_tbc,
+            key_padding_mask=None,
+            incremental_state=incr,
+            need_weights=False,
+            attn_mask=attn_mask_for_wrapper,
+            return_new_kv=True,  # <<< NEW: ask wrapper to return this step's K/V
+        )  # x_sa: [1,B,C]; new_k_bh/new_v_bh: [B*H, Tnew, Dh]
+
+        x_sa = self._maybe_apply_head_scale(x_sa)
+        if self.attn_ln is not None:
+            x_sa = self.attn_ln(x_sa)
+
+        x_tbc = x_tbc + x_sa
+        if not self.normalize_before:
+            x_tbc = self.self_attn_layer_norm(x_tbc)
+
+        # ---- Encoder-Decoder Attention -----------------------------------
+        assert encoder_out is not None, "encoder_out is required in export path"
+        residual = x_tbc
+        if self.normalize_before:
+            assert self.encoder_attn_layer_norm is not None
+            x_tbc = self.encoder_attn_layer_norm(x_tbc)
+
+        enc_kpm = encoder_padding_mask  # pass-through; wrapper handles bool/additive
+        x_ed, attn_out = self.encoder_attn(
+            query=x_tbc,
+            key=encoder_out,
+            value=encoder_out,
+            key_padding_mask=enc_kpm,
+            incremental_state=None,
+            static_kv=True,
+            need_weights=need_attn,
+            need_head_weights=need_head_weights,
+        )
+
+        x_tbc = residual + x_ed
+        if not self.normalize_before:
+            assert self.encoder_attn_layer_norm is not None
+            x_tbc = self.encoder_attn_layer_norm(x_tbc)
+
+        # ---- Feed-Forward -------------------------------------------------
+        residual = x_tbc
+        if self.normalize_before:
+            x_tbc = self.final_layer_norm(x_tbc)
+
+        x_tbc = self.fc1(x_tbc)
+        x_tbc = self.activation_fn(x_tbc)  # type: ignore[operator]
+        x_tbc = self._fq(x_tbc, self.obs_activation_fn)
+        if self.ffn_layernorm is not None:
+            x_tbc = self.ffn_layernorm(x_tbc)
+        x_tbc = self.fc2(x_tbc)
+
+        if self.w_resid is not None:
+            residual = torch.mul(self.w_resid, residual)
+
+        x_tbc = residual + x_tbc
+        if not self.normalize_before:
+            x_tbc = self.final_layer_norm(x_tbc)
+
+        return (
+            x_tbc,
+            attn_out,
+            new_k_bh,
+            new_v_bh,
+        )  # [1,B,C], attn, [B*H, Tnew, Dh], [B*H, Tnew, Dh]
+
+    def _all_observers(self) -> Iterable:
+        """
+        Expose all observers from child PTQ-wrapped modules.
+        This layer itself does not add extra per-tensor observers.
+        """
+        # local observers
+        yield from (
+            self.obs_activation_fn,
+            self.obs_prev_self_k_in,
+            self.obs_prev_self_v_in,
+        )
+
+        for m in (
+            self.self_attn,
+            self.encoder_attn,
+            self.fc1,
+            self.fc2,
+            self.encoder_attn_layer_norm,
+            self.self_attn_layer_norm,
+            self.final_layer_norm,
+        ):
+            if isinstance(m, QuantModuleBase) and m is not None:
+                yield from m._all_observers()
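
A minimal usage sketch for the new wrapper (not taken from the package): it assumes fp_layer is a fairseq TransformerDecoderLayerBase pulled from a trained model (e.g. model.decoder.layers[0]), that encoder and decoder share one embedding dimension, and that qcfg=None is acceptable (the constructor signature above allows it). The helper name wrap_and_probe and the fp_name label are illustrative only, and whether the observers actually record statistics during this pass depends on QuantModuleBase's calibration mode, which this diff does not show.

    import torch

    from tico.experimental.quantization.ptq.wrappers.fairseq.quant_decoder_layer import (
        QuantFairseqDecoderLayer,
    )


    def wrap_and_probe(fp_layer, T=8, S=16, B=2):
        """Wrap one fairseq decoder layer and run a single FP forward pass."""
        # fp_layer: assumed TransformerDecoderLayerBase, e.g. model.decoder.layers[0];
        # fp_name is arbitrary here and only used as a naming prefix for the wrapped submodules.
        qlayer = QuantFairseqDecoderLayer(
            fp_layer, qcfg=None, fp_name="decoder.layers.0"
        ).eval()

        C = qlayer.embed_dim
        x = torch.randn(T, B, C)  # decoder states, [T, B, C]
        enc = torch.randn(S, B, C)  # encoder output, [S, B, C]; same dim assumed
        enc_pad = torch.zeros(B, S, dtype=torch.bool)  # no source padding
        causal = torch.full((T, T), float("-inf")).triu(1)  # additive causal mask

        with torch.no_grad():
            y, attn, _ = qlayer(
                x,
                encoder_out=enc,
                encoder_padding_mask=enc_pad,
                self_attn_mask=causal,
                need_attn=True,  # request cross-attention weights
            )
        return y, attn  # y: [T, B, C]; attn: cross-attention weights or None
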
tico/experimental/quantization/ptq/wrappers/registry.py
@@ -33,6 +33,7 @@ _CORE_MODULES = (
     "tico.experimental.quantization.ptq.wrappers.llama.quant_decoder_layer",
     "tico.experimental.quantization.ptq.wrappers.llama.quant_mlp",
     # fairseq
+    "tico.experimental.quantization.ptq.wrappers.fairseq.quant_decoder_layer",
     "tico.experimental.quantization.ptq.wrappers.fairseq.quant_encoder",
     "tico.experimental.quantization.ptq.wrappers.fairseq.quant_encoder_layer",
     "tico.experimental.quantization.ptq.wrappers.fairseq.quant_mha",
tico/experimental/quantization/public_interface.py
@@ -22,7 +22,7 @@ from tico.experimental.quantization.algorithm.pt2e.quantizer import PT2EQuantize
 from tico.experimental.quantization.algorithm.smoothquant.quantizer import (
     SmoothQuantQuantizer,
 )
-from tico.experimental.quantization.config import BaseConfig
+from tico.experimental.quantization.config.base import BaseConfig
 from tico.experimental.quantization.quantizer import BaseQuantizer
 
 
tico/experimental/quantization/quantizer.py
@@ -17,7 +17,7 @@ from typing import Any, Dict, Optional
 
 import torch
 
-from tico.experimental.quantization.config import BaseConfig
+from tico.experimental.quantization.config.base import BaseConfig
 
 
 class BaseQuantizer(ABC):
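
The import change above (and the matching one in public_interface.py) tracks the split of tico/experimental/quantization/config.py into a config/ package with base.py, gptq.py, pt2e.py, and smoothquant.py, as listed in the file summary at the top. A short, hedged migration sketch for downstream code follows; whether the old flat import still resolves depends on what the new config/__init__.py re-exports, which this diff does not show.

    # New location of BaseConfig as of 0.1.0.dev250921:
    from tico.experimental.quantization.config.base import BaseConfig

    # Old location (0.1.0.dev250917 and earlier):
    # from tico.experimental.quantization.config import BaseConfig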