tico 0.1.0.dev250904__py3-none-any.whl → 0.1.0.dev251109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tico might be problematic.

Files changed (133)
  1. tico/__init__.py +1 -1
  2. tico/config/v1.py +5 -0
  3. tico/passes/cast_mixed_type_args.py +2 -0
  4. tico/passes/convert_expand_to_slice_cat.py +153 -0
  5. tico/passes/convert_matmul_to_linear.py +312 -0
  6. tico/passes/convert_to_relu6.py +1 -1
  7. tico/passes/decompose_fake_quantize_tensor_qparams.py +4 -3
  8. tico/passes/ops.py +0 -1
  9. tico/passes/remove_redundant_expand.py +3 -1
  10. tico/quantization/__init__.py +6 -0
  11. tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +161 -0
  12. tico/quantization/algorithm/fpi_gptq/quantizer.py +179 -0
  13. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
  14. tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +14 -6
  15. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
  16. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  17. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  18. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  19. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  20. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
  21. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  22. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  23. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  24. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  25. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  26. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  27. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  28. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  29. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
  30. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  31. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  32. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  33. tico/quantization/config/base.py +26 -0
  34. tico/quantization/config/fpi_gptq.py +29 -0
  35. tico/quantization/config/gptq.py +29 -0
  36. tico/quantization/config/pt2e.py +25 -0
  37. tico/{experimental/quantization/ptq/quant_config.py → quantization/config/ptq.py} +18 -10
  38. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -37
  39. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +6 -12
  40. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  41. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  42. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/public_interface.py +11 -18
  44. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  45. tico/quantization/quantizer_registry.py +73 -0
  46. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  47. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  48. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_linear.py +11 -10
  49. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_attn.py +10 -12
  50. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_decoder_layer.py +10 -9
  51. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_mlp.py +13 -13
  52. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  53. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/affine_base.py +3 -3
  54. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/base.py +2 -2
  55. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/ema.py +2 -2
  56. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/identity.py +1 -1
  57. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/minmax.py +2 -2
  58. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/mx.py +1 -1
  59. tico/quantization/wrapq/quantizer.py +179 -0
  60. tico/{experimental/quantization/ptq → quantization/wrapq}/utils/introspection.py +3 -5
  61. tico/{experimental/quantization/ptq → quantization/wrapq}/utils/metrics.py +3 -2
  62. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  63. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  64. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  65. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  66. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  67. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  68. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  69. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  70. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_attn.py +58 -21
  71. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_decoder_layer.py +21 -13
  72. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_mlp.py +5 -7
  73. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  74. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_layernorm.py +6 -7
  75. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_linear.py +7 -8
  76. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_silu.py +8 -9
  77. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/ptq_wrapper.py +4 -6
  78. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_elementwise.py +55 -17
  79. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_module_base.py +10 -9
  80. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/registry.py +17 -10
  81. tico/serialize/circle_serializer.py +11 -4
  82. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  83. tico/serialize/operators/op_le.py +54 -0
  84. tico/serialize/operators/op_mm.py +15 -132
  85. tico/utils/convert.py +20 -15
  86. tico/utils/register_custom_op.py +6 -4
  87. tico/utils/signature.py +7 -8
  88. tico/utils/validate_args_kwargs.py +12 -0
  89. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/METADATA +48 -2
  90. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/RECORD +128 -108
  91. tico/experimental/quantization/__init__.py +0 -6
  92. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  93. tico/experimental/quantization/ptq/examples/compare_ppl.py +0 -121
  94. tico/experimental/quantization/ptq/examples/debug_quant_outputs.py +0 -129
  95. tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +0 -165
  96. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  97. /tico/{experimental/quantization/algorithm/gptq → quantization/algorithm/fpi_gptq}/__init__.py +0 -0
  98. /tico/{experimental/quantization/algorithm/pt2e → quantization/algorithm/gptq}/__init__.py +0 -0
  99. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  100. /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
  101. /tico/{experimental/quantization/algorithm/pt2e/annotation → quantization/algorithm/pt2e}/__init__.py +0 -0
  102. /tico/{experimental/quantization/algorithm/pt2e/transformation → quantization/algorithm/pt2e/annotation}/__init__.py +0 -0
  103. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  104. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  105. /tico/{experimental/quantization/algorithm/smoothquant → quantization/algorithm/pt2e/transformation}/__init__.py +0 -0
  106. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  107. /tico/{experimental/quantization/evaluation → quantization/algorithm/smoothquant}/__init__.py +0 -0
  108. /tico/{experimental/quantization/evaluation/executor → quantization/config}/__init__.py +0 -0
  109. /tico/{experimental/quantization/passes → quantization/evaluation}/__init__.py +0 -0
  110. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  111. /tico/{experimental/quantization/ptq → quantization/evaluation/executor}/__init__.py +0 -0
  112. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  113. /tico/{experimental/quantization → quantization}/evaluation/metric.py +0 -0
  114. /tico/{experimental/quantization/ptq/examples → quantization/passes}/__init__.py +0 -0
  115. /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
  116. /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
  117. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  118. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  119. /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
  120. /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
  121. /tico/{experimental/quantization/ptq/observers → quantization/wrapq}/__init__.py +0 -0
  122. /tico/{experimental/quantization/ptq → quantization/wrapq}/dtypes.py +0 -0
  123. /tico/{experimental/quantization/ptq/utils → quantization/wrapq/examples}/__init__.py +0 -0
  124. /tico/{experimental/quantization/ptq → quantization/wrapq}/mode.py +0 -0
  125. /tico/{experimental/quantization/ptq/wrappers → quantization/wrapq/observers}/__init__.py +0 -0
  126. /tico/{experimental/quantization/ptq → quantization/wrapq}/qscheme.py +0 -0
  127. /tico/{experimental/quantization/ptq/wrappers/llama → quantization/wrapq/utils}/__init__.py +0 -0
  128. /tico/{experimental/quantization/ptq → quantization/wrapq}/utils/reduce_utils.py +0 -0
  129. /tico/{experimental/quantization/ptq/wrappers/nn → quantization/wrapq/wrappers}/__init__.py +0 -0
  130. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/LICENSE +0 -0
  131. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/WHEEL +0 -0
  132. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/entry_points.txt +0 -0
  133. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,381 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # -----------------------------------------------------------------------------
+ # This file includes modifications based on fairseq
+ # (https://github.com/facebookresearch/fairseq), originally licensed under
+ # the MIT License. See the LICENSE file in the fairseq repository for details.
+ # -----------------------------------------------------------------------------
+
+ from typing import Dict, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+ from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+ @try_register("fairseq.modules.multihead_attention.MultiheadAttention")
+ class QuantFairseqMultiheadAttention(QuantModuleBase):
+     """
+     Quant-aware drop-in for Fairseq MultiheadAttention.
+
+     - No xFormers / no torch F.multi_head_attention_forward fast-path.
+     - Self/cross attention + minimal incremental KV cache.
+     - Causal mask is pre-built statically; `key_padding_mask` is additive float.
+     - I/O shape: [T, B, C]
+
+     Runtime optimization flags
+     --------------------------
+     use_static_causal : bool
+         If True, reuse a precomputed upper-triangular causal mask template
+         instead of rebuilding it each forward step. Reduces per-step mask
+         construction overhead during incremental decoding.
+
+     assume_additive_key_padding : bool
+         If True, assume the `key_padding_mask` is already an additive float
+         tensor (large negative values at padded positions). Skips conversion
+         from boolean masks, reducing runtime overhead.
+     """
+
+     def __init__(
+         self,
+         fp_attn: nn.Module,
+         *,
+         qcfg: Optional[PTQConfig] = None,
+         fp_name: Optional[str] = None,
+         max_seq: int = 4096,
+         use_static_causal: bool = False,
+         mask_fill_value: float = -120.0,
+         assume_additive_key_padding: bool = False,
+     ):
+         super().__init__(qcfg, fp_name=fp_name)
+
+         self.use_static_causal = use_static_causal
+         self.mask_fill_value = mask_fill_value
+         self.assume_additive_key_padding = assume_additive_key_padding
+         self.embed_dim: int = int(fp_attn.embed_dim)  # type: ignore[arg-type]
+         self.num_heads: int = int(fp_attn.num_heads)  # type: ignore[arg-type]
+         self.head_dim: int = self.embed_dim // self.num_heads
+         assert self.head_dim * self.num_heads == self.embed_dim
+
+         self.self_attention: bool = bool(getattr(fp_attn, "self_attention", False))
+         self.encoder_decoder_attention: bool = bool(
+             getattr(fp_attn, "encoder_decoder_attention", False)
+         )
+         assert self.self_attention != self.encoder_decoder_attention
+
+         # PTQ-wrapped projections
+         qc = qcfg.child("q_proj") if qcfg else None
+         kc = qcfg.child("k_proj") if qcfg else None
+         vc = qcfg.child("v_proj") if qcfg else None
+         oc = qcfg.child("out_proj") if qcfg else None
+         assert hasattr(fp_attn, "q_proj") and hasattr(fp_attn, "k_proj")
+         assert hasattr(fp_attn, "v_proj") and hasattr(fp_attn, "out_proj")
+         assert isinstance(fp_attn.q_proj, nn.Module) and isinstance(
+             fp_attn.k_proj, nn.Module
+         )
+         assert isinstance(fp_attn.v_proj, nn.Module) and isinstance(
+             fp_attn.out_proj, nn.Module
+         )
+         self.q_proj = PTQWrapper(fp_attn.q_proj, qcfg=qc, fp_name=f"{fp_name}.q_proj")
+         self.k_proj = PTQWrapper(fp_attn.k_proj, qcfg=kc, fp_name=f"{fp_name}.k_proj")
+         self.v_proj = PTQWrapper(fp_attn.v_proj, qcfg=vc, fp_name=f"{fp_name}.v_proj")
+         self.out_proj = PTQWrapper(
+             fp_attn.out_proj, qcfg=oc, fp_name=f"{fp_name}.out_proj"
+         )
+
+         # scale & static causal mask
+         self.register_buffer(
+             "scale_const", torch.tensor(self.head_dim**-0.5), persistent=False
+         )
+         mask = torch.full((1, 1, max_seq, max_seq), float(self.mask_fill_value))
+         mask.triu_(1)
+         self.register_buffer("causal_mask_template", mask, persistent=False)
+
+         # observers (no *_proj_out here; PTQWrapper handles module outputs)
+         mk = self._make_obs
+         self.obs_query_in = mk("query_in")
+         self.obs_key_in = mk("key_in")
+         self.obs_value_in = mk("value_in")
+         self.obs_kpm_in = mk("kpm_in")
+         self.obs_causal_mask = mk("causal_mask")
+         self.obs_q_fold = mk("q_fold")
+         self.obs_k_fold = mk("k_fold")
+         self.obs_v_fold = mk("v_fold")
+         self.obs_scale = mk("scale")
+         self.obs_logits_raw = mk("logits_raw")
+         self.obs_logits = mk("logits_scaled")
+         self.obs_attn_mask_add = mk("obs_attn_mask_add")
+         self.obs_kp_mask_add = mk("obs_kp_mask_add")
+         self.obs_softmax = mk("softmax")
+         self.obs_attn_out = mk("attn_out")
+
+         safe_name = (
+             fp_name if (fp_name not in (None, "", "None")) else f"QuantFsMHA_{id(self)}"
+         )
+         assert safe_name is not None
+         self._state_key = safe_name + ".attn_state"
+
+     def _get_input_buffer(
+         self,
+         incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]],
+     ) -> Optional[Dict[str, Optional[torch.Tensor]]]:
+         """Return saved KV/mask dict or None."""
+         if incremental_state is None:
+             return None
+         return incremental_state.get(self._state_key, None)
+
+     def _set_input_buffer(
+         self,
+         incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]],
+         buffer: Dict[str, Optional[torch.Tensor]],
+     ):
+         """Store KV/mask dict in incremental_state."""
+         if incremental_state is not None:
+             incremental_state[self._state_key] = buffer
+         return incremental_state
+
+     # ---- utils ----
+     def _fold_heads(self, x: torch.Tensor, B: int) -> torch.Tensor:
+         # [T,B,E] -> [B*H, T, Dh]
+         T = x.size(0)
+         x = x.view(T, B, self.num_heads, self.head_dim).permute(1, 2, 0, 3).contiguous()
+         return x.view(B * self.num_heads, T, self.head_dim)
+
+     def _unfold_heads(self, x: torch.Tensor, B: int, T: int) -> torch.Tensor:
+         # [B*H, T, Dh] -> [T,B,E]
+         x = x.view(B, self.num_heads, T, self.head_dim).permute(2, 0, 1, 3).contiguous()
+         return x.view(T, B, self.embed_dim)
+
+     def forward(
+         self,
+         query: torch.Tensor,  # [Tq,B,C]
+         key: Optional[torch.Tensor],
+         value: Optional[torch.Tensor],
+         key_padding_mask: Optional[
+             torch.Tensor
+         ] = None,  # additive float (e.g. -120 at pads)
+         incremental_state: Optional[
+             Dict[str, Dict[str, Optional[torch.Tensor]]]
+         ] = None,
+         need_weights: bool = False,
+         static_kv: bool = False,
+         attn_mask: Optional[torch.Tensor] = None,  # if None -> internal causal
+         before_softmax: bool = False,
+         need_head_weights: bool = False,
+         return_new_kv: bool = False,
+     ) -> Union[
+         Tuple[torch.Tensor, Optional[torch.Tensor]],
+         Tuple[
+             torch.Tensor,
+             Optional[torch.Tensor],
+             Optional[torch.Tensor],
+             Optional[torch.Tensor],
+         ],
+     ]:
+
+         if need_head_weights:
+             need_weights = True
+
+         Tq, B, _ = query.shape
+         if self.self_attention:
+             key = query if key is None else key
+             value = query if value is None else value
+         else:
+             assert key is not None and value is not None
+
+         Tk, Bk, _ = key.shape
+         Tv, Bv, _ = value.shape
+         assert B == Bk == Bv
+
+         q = self.q_proj(self._fq(query, self.obs_query_in))
+         k = self.k_proj(self._fq(key, self.obs_key_in))
+         v = self.v_proj(self._fq(value, self.obs_value_in))
+
+         state = self._get_input_buffer(incremental_state)
+         if incremental_state is not None and state is None:
+             state = {}
+
+         # Capture "new" K/V for this call BEFORE concatenating with cache
+         new_k_bh: Optional[torch.Tensor] = None
+         new_v_bh: Optional[torch.Tensor] = None
+
+         # Fold heads
+         q = self._fq(self._fold_heads(q, B), self.obs_q_fold)
+         if state is not None and "prev_key" in state and static_kv:
+             # Cross-attention static_kv path: reuse cached KV; there is no new KV this call.
+             k = None
+             v = None
+         if k is not None:
+             k = self._fq(self._fold_heads(k, B), self.obs_k_fold)  # [B*H, Tnew, Dh]
+             if return_new_kv:
+                 new_k_bh = k.contiguous()
+         if v is not None:
+             v = self._fq(self._fold_heads(v, B), self.obs_v_fold)  # [B*H, Tnew, Dh]
+             if return_new_kv:
+                 new_v_bh = v.contiguous()
+
+         # Append/reuse cache
+         if state is not None:
+             pk = state.get("prev_key")
+             pv = state.get("prev_value")
+             if pk is not None:
+                 pk = pk.view(B * self.num_heads, -1, self.head_dim)
+                 k = pk if static_kv else torch.cat([pk, k], dim=1)
+             if pv is not None:
+                 pv = pv.view(B * self.num_heads, -1, self.head_dim)
+                 v = pv if static_kv else torch.cat([pv, v], dim=1)
+
+         assert k is not None and v is not None
+         Ts = k.size(1)
+
+         # Scaled dot-product
+         scale = self._fq(self.scale_const, self.obs_scale).to(q.dtype)
+         logits_raw = self._fq(
+             torch.bmm(q, k.transpose(1, 2)), self.obs_logits_raw
+         )  # [B*H,Tq,Ts]
+         logits = self._fq(logits_raw * scale, self.obs_logits)
+
+         assert isinstance(self.causal_mask_template, torch.Tensor)
+         # Masks
+         device = logits.device
+         if attn_mask is None and self.use_static_causal:
+             # Incremental decoding aware slicing:
+             # align the causal row(s) to the current time indices
+             start_q = max(Ts - Tq, 0)
+             cm = self.causal_mask_template[..., start_q : start_q + Tq, :Ts].to(
+                 device=device, dtype=logits.dtype
+             )
+             attn_mask = cm.squeeze(0).squeeze(0)  # [Tq,Ts]
+
+         if attn_mask is not None:
+             # Bool/byte mask -> additive float with large negatives
+             if not torch.is_floating_point(attn_mask):
+                 fill = self.causal_mask_template.new_tensor(self.mask_fill_value)
+                 attn_mask = torch.where(
+                     attn_mask.to(torch.bool), fill, fill.new_zeros(())
+                 )
+             attn_mask = self._fq(attn_mask, self.obs_causal_mask)
+             assert isinstance(attn_mask, torch.Tensor)
+
+             if not self.assume_additive_key_padding:
+                 # attn_mask -> [B*H,Tq,Ts]
+                 if attn_mask.dim() == 2:
+                     add_mask = attn_mask.unsqueeze(0).expand(logits.size(0), -1, -1)
+                 elif attn_mask.dim() == 3:
+                     add_mask = (
+                         attn_mask.unsqueeze(1)
+                         .expand(B, self.num_heads, Tq, Ts)
+                         .contiguous()
+                     )
+                     add_mask = add_mask.view(B * self.num_heads, Tq, Ts)
+                 else:
+                     raise RuntimeError("attn_mask must be [T,S] or [B,T,S]")
+             else:
+                 add_mask = attn_mask
+             logits = self._fq(logits + add_mask, self.obs_attn_mask_add)
+
+         if key_padding_mask is not None:
+             if not torch.is_floating_point(key_padding_mask):
+                 fill = self.causal_mask_template.new_tensor(self.mask_fill_value)
+                 kpm = torch.where(
+                     key_padding_mask.to(torch.bool), fill, fill.new_zeros(())
+                 )
+             else:
+                 kpm = key_padding_mask
+             kpm = self._fq(kpm, self.obs_kpm_in)
+
+             if not self.assume_additive_key_padding:
+                 # key_padding_mask: additive float already
+                 kpm = kpm.to(dtype=logits.dtype, device=device)
+                 if kpm.dim() == 2:  # [B,S]
+                     kpm = (
+                         kpm.view(B, 1, 1, Ts)
+                         .expand(B, self.num_heads, Tq, Ts)
+                         .contiguous()
+                     )
+                     kpm = kpm.view(B * self.num_heads, Tq, Ts)
+                 elif kpm.dim() == 3:  # [B,T,S]
+                     kpm = (
+                         kpm.unsqueeze(1).expand(B, self.num_heads, Tq, Ts).contiguous()
+                     )
+                     kpm = kpm.view(B * self.num_heads, Tq, Ts)
+                 else:
+                     raise RuntimeError(
+                         "key_padding_mask must be [B,S] or [B,T,S] (additive)"
+                     )
+             logits = self._fq(logits + kpm, self.obs_kp_mask_add)
+
+         if before_softmax:
+             if return_new_kv:
+                 return logits, v, new_k_bh, new_v_bh
+             return logits, v
+
+         # Softmax (float32) -> back to q.dtype
+         attn_probs = torch.softmax(logits, dim=-1, dtype=torch.float32).to(q.dtype)
+         attn_probs = self._fq(attn_probs, self.obs_softmax)
+
+         # Context + output proj
+         ctx = self._fq(torch.bmm(attn_probs, v), self.obs_attn_out)  # [B*H,Tq,Dh]
+         ctx = self._unfold_heads(ctx, B, Tq)  # [Tq,B,E]
+         out = self.out_proj(ctx)
+
+         # Weights (optional)
+         attn_weights_out: Optional[torch.Tensor] = None
+         if need_weights:
+             aw = (
+                 torch.softmax(logits, dim=-1, dtype=torch.float32)
+                 .view(B, self.num_heads, Tq, Ts)
+                 .transpose(1, 0)
+             )
+             if not need_head_weights:
+                 aw = aw.mean(dim=1)  # [B,Tq,Ts]
+             attn_weights_out = aw
+
+         # Cache write
+         if state is not None:
+             state["prev_key"] = k.view(B, self.num_heads, -1, self.head_dim).detach()
+             state["prev_value"] = v.view(B, self.num_heads, -1, self.head_dim).detach()
+             self._set_input_buffer(incremental_state, state)
+
+         if return_new_kv:
+             return out, attn_weights_out, new_k_bh, new_v_bh
+         return out, attn_weights_out
+
+     def _all_observers(self):
+         yield from (
+             self.obs_query_in,
+             self.obs_key_in,
+             self.obs_value_in,
+             self.obs_kpm_in,
+             self.obs_causal_mask,
+             self.obs_q_fold,
+             self.obs_k_fold,
+             self.obs_v_fold,
+             self.obs_scale,
+             self.obs_logits_raw,
+             self.obs_logits,
+             self.obs_attn_mask_add,
+             self.obs_kp_mask_add,
+             self.obs_softmax,
+             self.obs_attn_out,
+         )
+         for m in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):
+             if isinstance(m, QuantModuleBase):
+                 yield from m._all_observers()
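
The docstring above documents `assume_additive_key_padding`, which expects the caller to pre-convert a boolean padding mask into the additive-float form the wrapper otherwise rebuilds every step. A minimal sketch of that conversion (not part of the diff; it just mirrors the `torch.where` call in `forward`, and the -120.0 value matches the module's default `mask_fill_value`):

import torch

# Boolean [B, S] padding mask: True marks padded positions.
bool_kpm = torch.tensor([[False, False, True]])
# mask_fill_value defaults to -120.0 in QuantFairseqMultiheadAttention.
fill = torch.tensor(-120.0)
additive_kpm = torch.where(bool_kpm, fill, torch.zeros(()))
# -> tensor([[   0.,    0., -120.]]); pass this as key_padding_mask together with
#    assume_additive_key_padding=True to skip the per-step bool->float conversion.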
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE
@@ -12,17 +12,15 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Optional
+ from typing import Optional, Tuple

  import torch
  import torch.nn as nn

- from tico.experimental.quantization.ptq.quant_config import QuantConfig
- from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
- from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-     QuantModuleBase,
- )
- from tico.experimental.quantization.ptq.wrappers.registry import try_register
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+ from tico.quantization.wrapq.wrappers.registry import try_register


  @try_register(
@@ -34,7 +32,7 @@ class QuantLlamaAttention(QuantModuleBase):
          self,
          fp_attn: nn.Module,
          *,
-         qcfg: Optional[QuantConfig] = None,
+         qcfg: Optional[PTQConfig] = None,
          fp_name: Optional[str] = None,
      ):
          super().__init__(qcfg, fp_name=fp_name)
@@ -131,28 +129,38 @@ class QuantLlamaAttention(QuantModuleBase):
          x2n = self._fq(-x2, o_neg)
          return self._fq(torch.cat((x2n, x1), -1), o_cat)

+     @staticmethod
+     def _concat_kv(
+         past: Optional[Tuple[torch.Tensor, torch.Tensor]],
+         k_new: torch.Tensor,
+         v_new: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Concat along sequence dim (dim=2): (B, n_kv, S, H)."""
+         if past is None:
+             return k_new, v_new
+         past_k, past_v = past
+         k = torch.cat([past_k, k_new], dim=2)
+         v = torch.cat([past_v, v_new], dim=2)
+         return k, v
+
      def forward(
          self,
          hidden_states: torch.Tensor,
          position_embeddings: tuple[torch.Tensor, torch.Tensor],
          attention_mask: Optional[torch.Tensor] = None,
-         past_key_value=None,  # not supported yet
+         past_key_value=None,  # tuple(k, v) or HF Cache-like object
+         use_cache: Optional[bool] = False,
          cache_position: Optional[torch.LongTensor] = None,
          **kwargs,
      ):
-         if past_key_value is not None:
-             raise NotImplementedError(
-                 "QuantLlamaAttention does not support KV cache yet."
-             )
-
          hidden = self._fq(hidden_states, self.obs_hidden)
          B, S, _ = hidden.shape
          H = self.hdim

          # projections
-         q = self.q_proj(hidden).view(B, S, -1, H).transpose(1, 2)
-         k = self.k_proj(hidden).view(B, S, -1, H).transpose(1, 2)
-         v = self.v_proj(hidden).view(B, S, -1, H).transpose(1, 2)
+         q = self.q_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_h, S, H)
+         k = self.k_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_kv, S, H)
+         v = self.v_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_kv, S, H)

          # rope tables
          cos, sin = position_embeddings
@@ -176,14 +184,37 @@ class QuantLlamaAttention(QuantModuleBase):
          k_sin = self._fq(k_half * sin_u, self.obs_k_sin)
          k_rot = self._fq(k_cos + k_sin, self.obs_k_rot)

+         # --- build/update KV for attention & present_key_value -------------
+         present_key_value: Tuple[torch.Tensor, torch.Tensor]
+
+         # HF Cache path (if available)
+         if use_cache and hasattr(past_key_value, "update"):
+             # Many HF Cache impls use update(k, v) and return (k_total, v_total)
+             try:
+                 k_total, v_total = past_key_value.update(k_rot, v)
+                 present_key_value = (k_total, v_total)
+                 k_for_attn, v_for_attn = k_total, v_total
+             except Exception:
+                 # Fallback to tuple concat if Cache signature mismatches
+                 k_for_attn, v_for_attn = self._concat_kv(
+                     getattr(past_key_value, "kv", None), k_rot, v
+                 )
+                 present_key_value = (k_for_attn, v_for_attn)
+         else:
+             # Tuple or None path
+             pkv_tuple = past_key_value if isinstance(past_key_value, tuple) else None
+             k_for_attn, v_for_attn = self._concat_kv(pkv_tuple, k_rot, v)
+             present_key_value = (k_for_attn, v_for_attn)
+
          # logits
-         k_rep = k_rot.repeat_interleave(self.kv_rep, dim=1)
+         k_rep = k_for_attn.repeat_interleave(self.kv_rep, dim=1)  # (B, n_h, K, H)
          logits_raw = self._fq(q_rot @ k_rep.transpose(-2, -1), self.obs_logits_raw)
          scale = self._fq(self.scale_t, self.obs_scale)
          logits = self._fq(logits_raw * scale, self.obs_logits)

          if attention_mask is None or attention_mask.dtype == torch.bool:
-             _, _, q_len, k_len = logits.shape
+             _, _, q_len, _ = logits.shape
+             k_len = k_for_attn.size(2)
              assert isinstance(self.causal_mask_template, torch.Tensor)
              attention_mask = self.causal_mask_template[..., :q_len, :k_len].to(
                  hidden_states.device
@@ -196,7 +227,7 @@ class QuantLlamaAttention(QuantModuleBase):
          attn_weights = self._fq(attn_weights, self.obs_softmax)

          # attn out
-         v_rep = v.repeat_interleave(self.kv_rep, dim=1)
+         v_rep = v_for_attn.repeat_interleave(self.kv_rep, dim=1)  # (B, n_h, K, H)
          attn_out = (
              self._fq(attn_weights @ v_rep, self.obs_attn_out)
              .transpose(1, 2)
@@ -204,7 +235,13 @@ class QuantLlamaAttention(QuantModuleBase):
          )

          # final projection
-         return self.o_proj(attn_out), attn_weights
+         out = self.o_proj(attn_out)
+
+         # return with/without cache
+         if use_cache:
+             return out, attn_weights, present_key_value
+         else:
+             return out, attn_weights

      def _all_observers(self):
          # local first
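
The hunks above add a tuple-based KV cache to QuantLlamaAttention: `_concat_kv` appends new keys and values along the sequence dimension, and with `use_cache=True` the forward pass returns a third element, the present `(k, v)` pair, which a caller feeds back as `past_key_value` on the next step. A small shape-only sketch of that contract (illustrative, not taken from the diff):

import torch

B, n_kv, H = 1, 2, 16
past_k = torch.randn(B, n_kv, 5, H)   # 5 cached positions, layout (B, n_kv, S, H)
past_v = torch.randn(B, n_kv, 5, H)
k_new = torch.randn(B, n_kv, 1, H)    # one new decoding step
v_new = torch.randn(B, n_kv, 1, H)

# Same rule as QuantLlamaAttention._concat_kv: concatenate along dim=2 (sequence).
k_total = torch.cat([past_k, k_new], dim=2)
v_total = torch.cat([past_v, v_new], dim=2)
assert k_total.shape == (B, n_kv, 6, H) and v_total.shape == (B, n_kv, 6, H)

# With use_cache=True the wrapper returns (out, attn_weights, (k_total, v_total));
# passing that tuple back as past_key_value on the next call repeats this concat.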
@@ -17,16 +17,12 @@ from typing import Optional, Tuple
  import torch
  import torch.nn as nn

- from tico.experimental.quantization.ptq.quant_config import QuantConfig
- from tico.experimental.quantization.ptq.wrappers.llama.quant_attn import (
-     QuantLlamaAttention,
- )
- from tico.experimental.quantization.ptq.wrappers.llama.quant_mlp import QuantLlamaMLP
- from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
- from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-     QuantModuleBase,
- )
- from tico.experimental.quantization.ptq.wrappers.registry import try_register
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.wrapq.wrappers.llama.quant_attn import QuantLlamaAttention
+ from tico.quantization.wrapq.wrappers.llama.quant_mlp import QuantLlamaMLP
+ from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+ from tico.quantization.wrapq.wrappers.registry import try_register


  @try_register("transformers.models.llama.modeling_llama.LlamaDecoderLayer")
@@ -56,7 +52,7 @@ class QuantLlamaDecoderLayer(QuantModuleBase):
          self,
          fp_layer: nn.Module,
          *,
-         qcfg: Optional[QuantConfig] = None,
+         qcfg: Optional[PTQConfig] = None,
          fp_name: Optional[str] = None,
          return_type: Optional[str] = None,
      ):
@@ -136,7 +132,7 @@ class QuantLlamaDecoderLayer(QuantModuleBase):
              L = hidden_states.size(1)
              attention_mask = self._slice_causal(L, hidden_states.device)

-         hidden_states, _ = self.self_attn(
+         attn_out = self.self_attn(
              hidden_states=hidden_states,
              attention_mask=attention_mask,
              position_ids=position_ids,
@@ -147,7 +143,13 @@ class QuantLlamaDecoderLayer(QuantModuleBase):
              position_embeddings=position_embeddings,
              **kwargs,
          )
-         hidden_states = residual + hidden_states
+         if use_cache:
+             hidden_states_attn, _attn_weights, present_key_value = attn_out
+         else:
+             hidden_states_attn, _attn_weights = attn_out
+             present_key_value = None
+
+         hidden_states = residual + hidden_states_attn

          # ─── MLP block ─────────────────────────────────────────────────
          residual = hidden_states
@@ -155,6 +157,12 @@
          hidden_states = self.mlp(hidden_states)
          hidden_states = residual + hidden_states

+         # Return type policy:
+         # - If use_cache: always return (hidden_states, present_key_value)
+         # - Else: return as configured (tuple/tensor) for HF compatibility
+         if use_cache:
+             return hidden_states, present_key_value  # type: ignore[return-value]
+
          if self.return_type == "tuple":
              return (hidden_states,)
          elif self.return_type == "tensor":
@@ -17,12 +17,10 @@ from typing import Optional
  import torch
  import torch.nn as nn

- from tico.experimental.quantization.ptq.quant_config import QuantConfig
- from tico.experimental.quantization.ptq.wrappers.ptq_wrapper import PTQWrapper
- from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-     QuantModuleBase,
- )
- from tico.experimental.quantization.ptq.wrappers.registry import try_register
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+ from tico.quantization.wrapq.wrappers.registry import try_register


  @try_register("transformers.models.llama.modeling_llama.LlamaMLP")
@@ -31,7 +29,7 @@ class QuantLlamaMLP(QuantModuleBase):
          self,
          mlp_fp: nn.Module,
          *,
-         qcfg: Optional[QuantConfig] = None,
+         qcfg: Optional[PTQConfig] = None,
          fp_name: Optional[str] = None,
      ):
          super().__init__(qcfg, fp_name=fp_name)
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE
@@ -17,12 +17,11 @@ from typing import Iterable, Optional, Tuple
  import torch
  import torch.nn as nn

- from tico.experimental.quantization.ptq.mode import Mode
- from tico.experimental.quantization.ptq.quant_config import QuantConfig
- from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-     QuantModuleBase,
- )
- from tico.experimental.quantization.ptq.wrappers.registry import register
+ from tico.quantization.config.ptq import PTQConfig
+
+ from tico.quantization.wrapq.mode import Mode
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+ from tico.quantization.wrapq.wrappers.registry import register


  @register(nn.LayerNorm)
@@ -46,7 +45,7 @@ class QuantLayerNorm(QuantModuleBase):
          self,
          fp: nn.LayerNorm,
          *,
-         qcfg: Optional[QuantConfig] = None,
+         qcfg: Optional[PTQConfig] = None,
          fp_name: Optional[str] = None
      ):
          super().__init__(qcfg, fp_name=fp_name)
@@ -17,13 +17,12 @@ from typing import Optional
  import torch.nn as nn
  import torch.nn.functional as F

- from tico.experimental.quantization.ptq.mode import Mode
- from tico.experimental.quantization.ptq.qscheme import QScheme
- from tico.experimental.quantization.ptq.quant_config import QuantConfig
- from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-     QuantModuleBase,
- )
- from tico.experimental.quantization.ptq.wrappers.registry import register
+ from tico.quantization.config.ptq import PTQConfig
+
+ from tico.quantization.wrapq.mode import Mode
+ from tico.quantization.wrapq.qscheme import QScheme
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+ from tico.quantization.wrapq.wrappers.registry import register


  @register(nn.Linear)
@@ -34,7 +33,7 @@ class QuantLinear(QuantModuleBase):
          self,
          fp: nn.Linear,
          *,
-         qcfg: Optional[QuantConfig] = None,
+         qcfg: Optional[PTQConfig] = None,
          fp_name: Optional[str] = None
      ):
          super().__init__(qcfg, fp_name=fp_name)