tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. tico/__init__.py +9 -1
  2. tico/config/base.py +1 -1
  3. tico/config/v1.py +5 -0
  4. tico/passes/cast_aten_where_arg_type.py +1 -1
  5. tico/passes/cast_clamp_mixed_type_args.py +169 -0
  6. tico/passes/cast_mixed_type_args.py +4 -2
  7. tico/passes/const_prop_pass.py +1 -1
  8. tico/passes/convert_conv1d_to_conv2d.py +1 -1
  9. tico/passes/convert_expand_to_slice_cat.py +153 -0
  10. tico/passes/convert_matmul_to_linear.py +312 -0
  11. tico/passes/convert_to_relu6.py +1 -1
  12. tico/passes/decompose_addmm.py +0 -3
  13. tico/passes/decompose_batch_norm.py +2 -2
  14. tico/passes/decompose_fake_quantize.py +0 -3
  15. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
  16. tico/passes/decompose_group_norm.py +0 -3
  17. tico/passes/legalize_predefined_layout_operators.py +2 -11
  18. tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
  19. tico/passes/lower_to_slice.py +1 -1
  20. tico/passes/merge_consecutive_cat.py +1 -1
  21. tico/passes/ops.py +1 -1
  22. tico/passes/remove_redundant_assert_nodes.py +3 -1
  23. tico/passes/remove_redundant_expand.py +3 -6
  24. tico/passes/remove_redundant_reshape.py +5 -5
  25. tico/passes/segment_index_select.py +1 -1
  26. tico/quantization/__init__.py +6 -0
  27. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  28. tico/quantization/algorithm/gptq/quantizer.py +292 -0
  29. tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
  30. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
  31. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  32. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  33. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  34. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  35. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
  36. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  37. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  38. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  39. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  40. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  41. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  42. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  44. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
  45. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  46. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  47. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  48. tico/quantization/config/base.py +26 -0
  49. tico/quantization/config/gptq.py +29 -0
  50. tico/quantization/config/pt2e.py +25 -0
  51. tico/quantization/config/ptq.py +119 -0
  52. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  53. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
  54. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  55. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  56. tico/quantization/evaluation/metric.py +146 -0
  57. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  58. tico/quantization/passes/__init__.py +1 -0
  59. tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
  60. tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
  61. tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
  62. tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
  63. tico/{experimental/quantization → quantization}/public_interface.py +19 -18
  64. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  65. tico/quantization/quantizer_registry.py +73 -0
  66. tico/quantization/wrapq/__init__.py +1 -0
  67. tico/quantization/wrapq/dtypes.py +70 -0
  68. tico/quantization/wrapq/examples/__init__.py +1 -0
  69. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  70. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  71. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  72. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  73. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  74. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  75. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  76. tico/quantization/wrapq/mode.py +32 -0
  77. tico/quantization/wrapq/observers/__init__.py +1 -0
  78. tico/quantization/wrapq/observers/affine_base.py +128 -0
  79. tico/quantization/wrapq/observers/base.py +98 -0
  80. tico/quantization/wrapq/observers/ema.py +62 -0
  81. tico/quantization/wrapq/observers/identity.py +74 -0
  82. tico/quantization/wrapq/observers/minmax.py +39 -0
  83. tico/quantization/wrapq/observers/mx.py +60 -0
  84. tico/quantization/wrapq/qscheme.py +40 -0
  85. tico/quantization/wrapq/quantizer.py +179 -0
  86. tico/quantization/wrapq/utils/__init__.py +1 -0
  87. tico/quantization/wrapq/utils/introspection.py +167 -0
  88. tico/quantization/wrapq/utils/metrics.py +124 -0
  89. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  90. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  91. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  92. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  93. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  94. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  95. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  96. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  97. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  98. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  99. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  100. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  101. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  102. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  103. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  104. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  105. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  106. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  107. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  108. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  109. tico/quantization/wrapq/wrappers/registry.py +125 -0
  110. tico/serialize/circle_graph.py +12 -4
  111. tico/serialize/circle_mapping.py +76 -2
  112. tico/serialize/circle_serializer.py +253 -148
  113. tico/serialize/operators/adapters/__init__.py +1 -0
  114. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  115. tico/serialize/operators/op_any.py +7 -14
  116. tico/serialize/operators/op_avg_pool2d.py +11 -4
  117. tico/serialize/operators/op_clamp.py +5 -7
  118. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  119. tico/serialize/operators/op_conv2d.py +14 -6
  120. tico/serialize/operators/op_copy.py +26 -3
  121. tico/serialize/operators/op_cumsum.py +3 -1
  122. tico/serialize/operators/op_depthwise_conv2d.py +17 -7
  123. tico/serialize/operators/op_full_like.py +0 -2
  124. tico/serialize/operators/op_index_select.py +8 -1
  125. tico/serialize/operators/op_instance_norm.py +0 -6
  126. tico/serialize/operators/op_le.py +54 -0
  127. tico/serialize/operators/op_log1p.py +3 -2
  128. tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
  129. tico/serialize/operators/op_mm.py +15 -131
  130. tico/serialize/operators/op_mul.py +2 -8
  131. tico/serialize/operators/op_pow.py +3 -1
  132. tico/serialize/operators/op_repeat.py +12 -3
  133. tico/serialize/operators/op_reshape.py +1 -1
  134. tico/serialize/operators/op_rmsnorm.py +65 -0
  135. tico/serialize/operators/op_softmax.py +7 -14
  136. tico/serialize/operators/op_split_with_sizes.py +16 -8
  137. tico/serialize/operators/op_transpose_conv.py +11 -8
  138. tico/serialize/operators/op_view.py +2 -1
  139. tico/serialize/quant_param.py +5 -5
  140. tico/utils/convert.py +30 -17
  141. tico/utils/dtype.py +42 -0
  142. tico/utils/graph.py +1 -1
  143. tico/utils/model.py +2 -1
  144. tico/utils/padding.py +2 -2
  145. tico/utils/pytree_utils.py +134 -0
  146. tico/utils/record_input.py +102 -0
  147. tico/utils/register_custom_op.py +29 -4
  148. tico/utils/serialize.py +16 -3
  149. tico/utils/signature.py +247 -0
  150. tico/utils/torch_compat.py +52 -0
  151. tico/utils/utils.py +50 -58
  152. tico/utils/validate_args_kwargs.py +38 -3
  153. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  154. tico-0.1.0.dev251102.dist-info/RECORD +271 -0
  155. tico/experimental/quantization/__init__.py +0 -1
  156. tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
  157. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  158. tico/experimental/quantization/evaluation/metric.py +0 -109
  159. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
  160. tico-0.1.0.dev250714.dist-info/RECORD +0 -209
  161. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  162. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  163. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  164. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  165. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  166. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  167. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  168. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  169. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  170. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  171. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  172. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  173. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  174. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  175. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  176. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  177. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  178. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  179. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  180. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  181. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
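The headline change in this release is the promotion of the quantization stack out of the experimental namespace: every `tico/experimental/quantization/...` module now lives under `tico/quantization/...` (the `{experimental/quantization → quantization}` entries above), alongside a new `wrapq` sub-package of PTQ wrappers (entries 66–113). As a rough, hedged sketch of what the rename means for imports — based only on the module paths listed above, since the public symbols are not shown in this diff:

    # Import-path update implied by the renames above (illustrative only).
    # 0.1.0.dev250714 layout:
    #   from tico.experimental.quantization import quantizer
    # 0.1.0.dev251102 layout:
    from tico.quantization import quantizer              # moved module (entry 64)
    from tico.quantization.config.ptq import PTQConfig   # new config module (entry 51)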
@@ -0,0 +1,381 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # -----------------------------------------------------------------------------
+ # This file includes modifications based on fairseq
+ # (https://github.com/facebookresearch/fairseq), originally licensed under
+ # the MIT License. See the LICENSE file in the fairseq repository for details.
+ # -----------------------------------------------------------------------------
+
+ from typing import Dict, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+ from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+ @try_register("fairseq.modules.multihead_attention.MultiheadAttention")
+ class QuantFairseqMultiheadAttention(QuantModuleBase):
+     """
+     Quant-aware drop-in for Fairseq MultiheadAttention.
+
+     - No xFormers / no torch F.multi_head_attention_forward fast-path.
+     - Self/cross attention + minimal incremental KV cache.
+     - Causal mask is pre-built statically; `key_padding_mask` is additive float.
+     - I/O shape: [T, B, C]
+
+     Runtime optimization flags
+     --------------------------
+     use_static_causal : bool
+         If True, reuse a precomputed upper-triangular causal mask template
+         instead of rebuilding it each forward step. Reduces per-step mask
+         construction overhead during incremental decoding.
+
+     assume_additive_key_padding : bool
+         If True, assume the `key_padding_mask` is already an additive float
+         tensor (large negative values at padded positions). Skips conversion
+         from boolean masks, reducing runtime overhead.
+     """
+
+     def __init__(
+         self,
+         fp_attn: nn.Module,
+         *,
+         qcfg: Optional[PTQConfig] = None,
+         fp_name: Optional[str] = None,
+         max_seq: int = 4096,
+         use_static_causal: bool = False,
+         mask_fill_value: float = -120.0,
+         assume_additive_key_padding: bool = False,
+     ):
+         super().__init__(qcfg, fp_name=fp_name)
+
+         self.use_static_causal = use_static_causal
+         self.mask_fill_value = mask_fill_value
+         self.assume_additive_key_padding = assume_additive_key_padding
+         self.embed_dim: int = int(fp_attn.embed_dim)  # type: ignore[arg-type]
+         self.num_heads: int = int(fp_attn.num_heads)  # type: ignore[arg-type]
+         self.head_dim: int = self.embed_dim // self.num_heads
+         assert self.head_dim * self.num_heads == self.embed_dim
+
+         self.self_attention: bool = bool(getattr(fp_attn, "self_attention", False))
+         self.encoder_decoder_attention: bool = bool(
+             getattr(fp_attn, "encoder_decoder_attention", False)
+         )
+         assert self.self_attention != self.encoder_decoder_attention
+
+         # PTQ-wrapped projections
+         qc = qcfg.child("q_proj") if qcfg else None
+         kc = qcfg.child("k_proj") if qcfg else None
+         vc = qcfg.child("v_proj") if qcfg else None
+         oc = qcfg.child("out_proj") if qcfg else None
+         assert hasattr(fp_attn, "q_proj") and hasattr(fp_attn, "k_proj")
+         assert hasattr(fp_attn, "v_proj") and hasattr(fp_attn, "out_proj")
+         assert isinstance(fp_attn.q_proj, nn.Module) and isinstance(
+             fp_attn.k_proj, nn.Module
+         )
+         assert isinstance(fp_attn.v_proj, nn.Module) and isinstance(
+             fp_attn.out_proj, nn.Module
+         )
+         self.q_proj = PTQWrapper(fp_attn.q_proj, qcfg=qc, fp_name=f"{fp_name}.q_proj")
+         self.k_proj = PTQWrapper(fp_attn.k_proj, qcfg=kc, fp_name=f"{fp_name}.k_proj")
+         self.v_proj = PTQWrapper(fp_attn.v_proj, qcfg=vc, fp_name=f"{fp_name}.v_proj")
+         self.out_proj = PTQWrapper(
+             fp_attn.out_proj, qcfg=oc, fp_name=f"{fp_name}.out_proj"
+         )
+
+         # scale & static causal mask
+         self.register_buffer(
+             "scale_const", torch.tensor(self.head_dim**-0.5), persistent=False
+         )
+         mask = torch.full((1, 1, max_seq, max_seq), float(self.mask_fill_value))
+         mask.triu_(1)
+         self.register_buffer("causal_mask_template", mask, persistent=False)
+
+         # observers (no *_proj_out here; PTQWrapper handles module outputs)
+         mk = self._make_obs
+         self.obs_query_in = mk("query_in")
+         self.obs_key_in = mk("key_in")
+         self.obs_value_in = mk("value_in")
+         self.obs_kpm_in = mk("kpm_in")
+         self.obs_causal_mask = mk("causal_mask")
+         self.obs_q_fold = mk("q_fold")
+         self.obs_k_fold = mk("k_fold")
+         self.obs_v_fold = mk("v_fold")
+         self.obs_scale = mk("scale")
+         self.obs_logits_raw = mk("logits_raw")
+         self.obs_logits = mk("logits_scaled")
+         self.obs_attn_mask_add = mk("obs_attn_mask_add")
+         self.obs_kp_mask_add = mk("obs_kp_mask_add")
+         self.obs_softmax = mk("softmax")
+         self.obs_attn_out = mk("attn_out")
+
+         safe_name = (
+             fp_name if (fp_name not in (None, "", "None")) else f"QuantFsMHA_{id(self)}"
+         )
+         assert safe_name is not None
+         self._state_key = safe_name + ".attn_state"
+
+     def _get_input_buffer(
+         self,
+         incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]],
+     ) -> Optional[Dict[str, Optional[torch.Tensor]]]:
+         """Return saved KV/mask dict or None."""
+         if incremental_state is None:
+             return None
+         return incremental_state.get(self._state_key, None)
+
+     def _set_input_buffer(
+         self,
+         incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]],
+         buffer: Dict[str, Optional[torch.Tensor]],
+     ):
+         """Store KV/mask dict in incremental_state."""
+         if incremental_state is not None:
+             incremental_state[self._state_key] = buffer
+         return incremental_state
+
+     # ---- utils ----
+     def _fold_heads(self, x: torch.Tensor, B: int) -> torch.Tensor:
+         # [T,B,E] -> [B*H, T, Dh]
+         T = x.size(0)
+         x = x.view(T, B, self.num_heads, self.head_dim).permute(1, 2, 0, 3).contiguous()
+         return x.view(B * self.num_heads, T, self.head_dim)
+
+     def _unfold_heads(self, x: torch.Tensor, B: int, T: int) -> torch.Tensor:
+         # [B*H, T, Dh] -> [T,B,E]
+         x = x.view(B, self.num_heads, T, self.head_dim).permute(2, 0, 1, 3).contiguous()
+         return x.view(T, B, self.embed_dim)
+
+     def forward(
+         self,
+         query: torch.Tensor,  # [Tq,B,C]
+         key: Optional[torch.Tensor],
+         value: Optional[torch.Tensor],
+         key_padding_mask: Optional[
+             torch.Tensor
+         ] = None,  # additive float (e.g. -120 at pads)
+         incremental_state: Optional[
+             Dict[str, Dict[str, Optional[torch.Tensor]]]
+         ] = None,
+         need_weights: bool = False,
+         static_kv: bool = False,
+         attn_mask: Optional[torch.Tensor] = None,  # if None -> internal causal
+         before_softmax: bool = False,
+         need_head_weights: bool = False,
+         return_new_kv: bool = False,
+     ) -> Union[
+         Tuple[torch.Tensor, Optional[torch.Tensor]],
+         Tuple[
+             torch.Tensor,
+             Optional[torch.Tensor],
+             Optional[torch.Tensor],
+             Optional[torch.Tensor],
+         ],
+     ]:
+
+         if need_head_weights:
+             need_weights = True
+
+         Tq, B, _ = query.shape
+         if self.self_attention:
+             key = query if key is None else key
+             value = query if value is None else value
+         else:
+             assert key is not None and value is not None
+
+         Tk, Bk, _ = key.shape
+         Tv, Bv, _ = value.shape
+         assert B == Bk == Bv
+
+         q = self.q_proj(self._fq(query, self.obs_query_in))
+         k = self.k_proj(self._fq(key, self.obs_key_in))
+         v = self.v_proj(self._fq(value, self.obs_value_in))
+
+         state = self._get_input_buffer(incremental_state)
+         if incremental_state is not None and state is None:
+             state = {}
+
+         # Capture "new" K/V for this call BEFORE concatenating with cache
+         new_k_bh: Optional[torch.Tensor] = None
+         new_v_bh: Optional[torch.Tensor] = None
+
+         # Fold heads
+         q = self._fq(self._fold_heads(q, B), self.obs_q_fold)
+         if state is not None and "prev_key" in state and static_kv:
+             # Cross-attention static_kv path: reuse cached KV; there is no new KV this call.
+             k = None
+             v = None
+         if k is not None:
+             k = self._fq(self._fold_heads(k, B), self.obs_k_fold)  # [B*H, Tnew, Dh]
+             if return_new_kv:
+                 new_k_bh = k.contiguous()
+         if v is not None:
+             v = self._fq(self._fold_heads(v, B), self.obs_v_fold)  # [B*H, Tnew, Dh]
+             if return_new_kv:
+                 new_v_bh = v.contiguous()
+
+         # Append/reuse cache
+         if state is not None:
+             pk = state.get("prev_key")
+             pv = state.get("prev_value")
+             if pk is not None:
+                 pk = pk.view(B * self.num_heads, -1, self.head_dim)
+                 k = pk if static_kv else torch.cat([pk, k], dim=1)
+             if pv is not None:
+                 pv = pv.view(B * self.num_heads, -1, self.head_dim)
+                 v = pv if static_kv else torch.cat([pv, v], dim=1)
+
+         assert k is not None and v is not None
+         Ts = k.size(1)
+
+         # Scaled dot-product
+         scale = self._fq(self.scale_const, self.obs_scale).to(q.dtype)
+         logits_raw = self._fq(
+             torch.bmm(q, k.transpose(1, 2)), self.obs_logits_raw
+         )  # [B*H,Tq,Ts]
+         logits = self._fq(logits_raw * scale, self.obs_logits)
+
+         assert isinstance(self.causal_mask_template, torch.Tensor)
+         # Masks
+         device = logits.device
+         if attn_mask is None and self.use_static_causal:
+             # Incremental decoding aware slicing:
+             # align the causal row(s) to the current time indices
+             start_q = max(Ts - Tq, 0)
+             cm = self.causal_mask_template[..., start_q : start_q + Tq, :Ts].to(
+                 device=device, dtype=logits.dtype
+             )
+             attn_mask = cm.squeeze(0).squeeze(0)  # [Tq,Ts]
+
+         if attn_mask is not None:
+             # Bool/byte mask -> additive float with large negatives
+             if not torch.is_floating_point(attn_mask):
+                 fill = self.causal_mask_template.new_tensor(self.mask_fill_value)
+                 attn_mask = torch.where(
+                     attn_mask.to(torch.bool), fill, fill.new_zeros(())
+                 )
+             attn_mask = self._fq(attn_mask, self.obs_causal_mask)
+             assert isinstance(attn_mask, torch.Tensor)
+
+             if not self.assume_additive_key_padding:
+                 # attn_mask -> [B*H,Tq,Ts]
+                 if attn_mask.dim() == 2:
+                     add_mask = attn_mask.unsqueeze(0).expand(logits.size(0), -1, -1)
+                 elif attn_mask.dim() == 3:
+                     add_mask = (
+                         attn_mask.unsqueeze(1)
+                         .expand(B, self.num_heads, Tq, Ts)
+                         .contiguous()
+                     )
+                     add_mask = add_mask.view(B * self.num_heads, Tq, Ts)
+                 else:
+                     raise RuntimeError("attn_mask must be [T,S] or [B,T,S]")
+             else:
+                 add_mask = attn_mask
+             logits = self._fq(logits + add_mask, self.obs_attn_mask_add)
+
+         if key_padding_mask is not None:
+             if not torch.is_floating_point(key_padding_mask):
+                 fill = self.causal_mask_template.new_tensor(self.mask_fill_value)
+                 kpm = torch.where(
+                     key_padding_mask.to(torch.bool), fill, fill.new_zeros(())
+                 )
+             else:
+                 kpm = key_padding_mask
+             kpm = self._fq(kpm, self.obs_kpm_in)
+
+             if not self.assume_additive_key_padding:
+                 # key_padding_mask: additive float already
+                 kpm = kpm.to(dtype=logits.dtype, device=device)
+                 if kpm.dim() == 2:  # [B,S]
+                     kpm = (
+                         kpm.view(B, 1, 1, Ts)
+                         .expand(B, self.num_heads, Tq, Ts)
+                         .contiguous()
+                     )
+                     kpm = kpm.view(B * self.num_heads, Tq, Ts)
+                 elif kpm.dim() == 3:  # [B,T,S]
+                     kpm = (
+                         kpm.unsqueeze(1).expand(B, self.num_heads, Tq, Ts).contiguous()
+                     )
+                     kpm = kpm.view(B * self.num_heads, Tq, Ts)
+                 else:
+                     raise RuntimeError(
+                         "key_padding_mask must be [B,S] or [B,T,S] (additive)"
+                     )
+             logits = self._fq(logits + kpm, self.obs_kp_mask_add)
+
+         if before_softmax:
+             if return_new_kv:
+                 return logits, v, new_k_bh, new_v_bh
+             return logits, v
+
+         # Softmax (float32) -> back to q.dtype
+         attn_probs = torch.softmax(logits, dim=-1, dtype=torch.float32).to(q.dtype)
+         attn_probs = self._fq(attn_probs, self.obs_softmax)
+
+         # Context + output proj
+         ctx = self._fq(torch.bmm(attn_probs, v), self.obs_attn_out)  # [B*H,Tq,Dh]
+         ctx = self._unfold_heads(ctx, B, Tq)  # [Tq,B,E]
+         out = self.out_proj(ctx)
+
+         # Weights (optional)
+         attn_weights_out: Optional[torch.Tensor] = None
+         if need_weights:
+             aw = (
+                 torch.softmax(logits, dim=-1, dtype=torch.float32)
+                 .view(B, self.num_heads, Tq, Ts)
+                 .transpose(1, 0)
+             )
+             if not need_head_weights:
+                 aw = aw.mean(dim=1)  # [B,Tq,Ts]
+             attn_weights_out = aw
+
+         # Cache write
+         if state is not None:
+             state["prev_key"] = k.view(B, self.num_heads, -1, self.head_dim).detach()
+             state["prev_value"] = v.view(B, self.num_heads, -1, self.head_dim).detach()
+             self._set_input_buffer(incremental_state, state)
+
+         if return_new_kv:
+             return out, attn_weights_out, new_k_bh, new_v_bh
+         return out, attn_weights_out
+
+     def _all_observers(self):
+         yield from (
+             self.obs_query_in,
+             self.obs_key_in,
+             self.obs_value_in,
+             self.obs_kpm_in,
+             self.obs_causal_mask,
+             self.obs_q_fold,
+             self.obs_k_fold,
+             self.obs_v_fold,
+             self.obs_scale,
+             self.obs_logits_raw,
+             self.obs_logits,
+             self.obs_attn_mask_add,
+             self.obs_kp_mask_add,
+             self.obs_softmax,
+             self.obs_attn_out,
+         )
+         for m in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):
+             if isinstance(m, QuantModuleBase):
+                 yield from m._all_observers()
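For reference, a minimal usage sketch of the wrapper defined above. The import path follows entry 97 in the file list; the fairseq `MultiheadAttention(embed_dim, num_heads, self_attention=True)` constructor arguments and the surrounding calibration flow are assumptions, not taken from this diff:

    # Hedged sketch: wrap a fairseq self-attention module for PTQ observation.
    import torch
    from fairseq.modules.multihead_attention import MultiheadAttention  # ctor args assumed
    from tico.quantization.wrapq.wrappers.fairseq.quant_mha import (
        QuantFairseqMultiheadAttention,
    )

    fp_attn = MultiheadAttention(embed_dim=512, num_heads=8, self_attention=True)
    q_attn = QuantFairseqMultiheadAttention(
        fp_attn, qcfg=None, fp_name="decoder.layers.0.self_attn"
    )

    x = torch.randn(16, 2, 512)          # [T, B, C], per the class docstring
    out, attn_weights = q_attn(x, x, x)  # need_weights=False -> attn_weights is None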
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE
@@ -0,0 +1,276 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+ from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+ @try_register(
+     "transformers.models.llama.modeling_llama.LlamaAttention",
+     "transformers.models.llama.modeling_llama.LlamaSdpaAttention",
+ )
+ class QuantLlamaAttention(QuantModuleBase):
+     def __init__(
+         self,
+         fp_attn: nn.Module,
+         *,
+         qcfg: Optional[PTQConfig] = None,
+         fp_name: Optional[str] = None,
+     ):
+         super().__init__(qcfg, fp_name=fp_name)
+
+         cfg = fp_attn.config
+         assert hasattr(cfg, "hidden_size") and hasattr(cfg, "num_attention_heads")
+         assert hasattr(cfg, "num_key_value_heads")
+         assert isinstance(cfg.hidden_size, int) and isinstance(
+             cfg.num_attention_heads, int
+         )
+         assert isinstance(cfg.num_key_value_heads, int)
+         self.hdim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)
+         self.kv_rep = cfg.num_attention_heads // cfg.num_key_value_heads
+
+         # constant scale (1/√d)
+         self.scale_t = torch.tensor(self.hdim**-0.5)
+         self.obs_scale = self._make_obs("scale")
+
+         # ---- wrap q k v o projections via PTQWrapper ---------------
+         q_cfg = qcfg.child("q_proj") if qcfg else None
+         k_cfg = qcfg.child("k_proj") if qcfg else None
+         v_cfg = qcfg.child("v_proj") if qcfg else None
+         o_cfg = qcfg.child("o_proj") if qcfg else None
+         assert hasattr(fp_attn, "q_proj") and isinstance(
+             fp_attn.q_proj, torch.nn.Module
+         )
+         assert hasattr(fp_attn, "k_proj") and isinstance(
+             fp_attn.k_proj, torch.nn.Module
+         )
+         assert hasattr(fp_attn, "v_proj") and isinstance(
+             fp_attn.v_proj, torch.nn.Module
+         )
+         assert hasattr(fp_attn, "o_proj") and isinstance(
+             fp_attn.o_proj, torch.nn.Module
+         )
+         self.q_proj = PTQWrapper(
+             fp_attn.q_proj, qcfg=q_cfg, fp_name=f"{fp_name}.q_proj"
+         )
+         self.k_proj = PTQWrapper(
+             fp_attn.k_proj, qcfg=k_cfg, fp_name=f"{fp_name}.k_proj"
+         )
+         self.v_proj = PTQWrapper(
+             fp_attn.v_proj, qcfg=v_cfg, fp_name=f"{fp_name}.v_proj"
+         )
+         self.o_proj = PTQWrapper(
+             fp_attn.o_proj, qcfg=o_cfg, fp_name=f"{fp_name}.o_proj"
+         )
+
+         # ---- create arithmetic observers ---------------------------
+         mk = self._make_obs
+         self.obs_hidden = mk("hidden")
+
+         self.obs_cos = mk("cos")
+         self.obs_sin = mk("sin")
+
+         self.obs_causal_mask = mk("causal_mask")
+
+         # rotate-half sub-steps
+         self.obs_q_x1 = mk("q_x1")
+         self.obs_q_x2 = mk("q_x2")
+         self.obs_q_neg = mk("q_neg")
+         self.obs_q_cat = mk("q_cat")
+         self.obs_k_x1 = mk("k_x1")
+         self.obs_k_x2 = mk("k_x2")
+         self.obs_k_neg = mk("k_neg")
+         self.obs_k_cat = mk("k_cat")
+
+         # q / k paths
+         self.obs_q_cos = mk("q_cos")
+         self.obs_q_sin = mk("q_sin")
+         self.obs_q_rot = mk("q_rot")
+         self.obs_k_cos = mk("k_cos")
+         self.obs_k_sin = mk("k_sin")
+         self.obs_k_rot = mk("k_rot")
+
+         # logits / softmax / out
+         self.obs_logits_raw = mk("logits_raw")
+         self.obs_logits = mk("logits")
+         self.obs_mask_add = mk("mask_add")
+         self.obs_softmax = mk("softmax")
+         self.obs_attn_out = mk("attn_out")
+
+         # Static causal mask template
+         assert hasattr(cfg, "max_position_embeddings")
+         max_seq = cfg.max_position_embeddings
+         mask = torch.full((1, 1, max_seq, max_seq), float("-120"))  # type: ignore[arg-type]
+         mask.triu_(1)
+         self.register_buffer("causal_mask_template", mask, persistent=False)
+
+     def _rot(self, t, o_x1, o_x2, o_neg, o_cat):
+         x1, x2 = torch.chunk(t, 2, dim=-1)
+         x1 = self._fq(x1, o_x1)
+         x2 = self._fq(x2, o_x2)
+         x2n = self._fq(-x2, o_neg)
+         return self._fq(torch.cat((x2n, x1), -1), o_cat)
+
+     @staticmethod
+     def _concat_kv(
+         past: Optional[Tuple[torch.Tensor, torch.Tensor]],
+         k_new: torch.Tensor,
+         v_new: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Concat along sequence dim (dim=2): (B, n_kv, S, H)."""
+         if past is None:
+             return k_new, v_new
+         past_k, past_v = past
+         k = torch.cat([past_k, k_new], dim=2)
+         v = torch.cat([past_v, v_new], dim=2)
+         return k, v
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor],
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_value=None,  # tuple(k, v) or HF Cache-like object
+         use_cache: Optional[bool] = False,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ):
+         hidden = self._fq(hidden_states, self.obs_hidden)
+         B, S, _ = hidden.shape
+         H = self.hdim
+
+         # projections
+         q = self.q_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_h, S, H)
+         k = self.k_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_kv, S, H)
+         v = self.v_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_kv, S, H)
+
+         # rope tables
+         cos, sin = position_embeddings
+         cos = self._fq(cos, self.obs_cos)
+         sin = self._fq(sin, self.obs_sin)
+         cos_u, sin_u = cos.unsqueeze(1), sin.unsqueeze(1)
+
+         # q_rot
+         q_half = self._rot(
+             q, self.obs_q_x1, self.obs_q_x2, self.obs_q_neg, self.obs_q_cat
+         )
+         q_cos = self._fq(q * cos_u, self.obs_q_cos)
+         q_sin = self._fq(q_half * sin_u, self.obs_q_sin)
+         q_rot = self._fq(q_cos + q_sin, self.obs_q_rot)
+
+         # k_rot
+         k_half = self._rot(
+             k, self.obs_k_x1, self.obs_k_x2, self.obs_k_neg, self.obs_k_cat
+         )
+         k_cos = self._fq(k * cos_u, self.obs_k_cos)
+         k_sin = self._fq(k_half * sin_u, self.obs_k_sin)
+         k_rot = self._fq(k_cos + k_sin, self.obs_k_rot)
+
+         # --- build/update KV for attention & present_key_value -------------
+         present_key_value: Tuple[torch.Tensor, torch.Tensor]
+
+         # HF Cache path (if available)
+         if use_cache and hasattr(past_key_value, "update"):
+             # Many HF Cache impls use update(k, v) and return (k_total, v_total)
+             try:
+                 k_total, v_total = past_key_value.update(k_rot, v)
+                 present_key_value = (k_total, v_total)
+                 k_for_attn, v_for_attn = k_total, v_total
+             except Exception:
+                 # Fallback to tuple concat if Cache signature mismatches
+                 k_for_attn, v_for_attn = self._concat_kv(
+                     getattr(past_key_value, "kv", None), k_rot, v
+                 )
+                 present_key_value = (k_for_attn, v_for_attn)
+         else:
+             # Tuple or None path
+             pkv_tuple = past_key_value if isinstance(past_key_value, tuple) else None
+             k_for_attn, v_for_attn = self._concat_kv(pkv_tuple, k_rot, v)
+             present_key_value = (k_for_attn, v_for_attn)
+
+         # logits
+         k_rep = k_for_attn.repeat_interleave(self.kv_rep, dim=1)  # (B, n_h, K, H)
+         logits_raw = self._fq(q_rot @ k_rep.transpose(-2, -1), self.obs_logits_raw)
+         scale = self._fq(self.scale_t, self.obs_scale)
+         logits = self._fq(logits_raw * scale, self.obs_logits)
+
+         if attention_mask is None or attention_mask.dtype == torch.bool:
+             _, _, q_len, _ = logits.shape
+             k_len = k_for_attn.size(2)
+             assert isinstance(self.causal_mask_template, torch.Tensor)
+             attention_mask = self.causal_mask_template[..., :q_len, :k_len].to(
+                 hidden_states.device
+             )
+         attention_mask = self._fq(attention_mask, self.obs_causal_mask)
+         logits = self._fq(logits + attention_mask, self.obs_mask_add)
+
+         # softmax
+         attn_weights = torch.softmax(logits, -1, dtype=torch.float32).to(q.dtype)
+         attn_weights = self._fq(attn_weights, self.obs_softmax)
+
+         # attn out
+         v_rep = v_for_attn.repeat_interleave(self.kv_rep, dim=1)  # (B, n_h, K, H)
+         attn_out = (
+             self._fq(attn_weights @ v_rep, self.obs_attn_out)
+             .transpose(1, 2)
+             .reshape(B, S, -1)
+         )
+
+         # final projection
+         out = self.o_proj(attn_out)
+
+         # return with/without cache
+         if use_cache:
+             return out, attn_weights, present_key_value
+         else:
+             return out, attn_weights
+
+     def _all_observers(self):
+         # local first
+         yield from (
+             self.obs_hidden,
+             self.obs_scale,
+             self.obs_cos,
+             self.obs_sin,
+             self.obs_causal_mask,
+             self.obs_q_x1,
+             self.obs_q_x2,
+             self.obs_q_neg,
+             self.obs_q_cat,
+             self.obs_k_x1,
+             self.obs_k_x2,
+             self.obs_k_neg,
+             self.obs_k_cat,
+             self.obs_q_cos,
+             self.obs_q_sin,
+             self.obs_q_rot,
+             self.obs_k_cos,
+             self.obs_k_sin,
+             self.obs_k_rot,
+             self.obs_logits_raw,
+             self.obs_logits,
+             self.obs_mask_add,
+             self.obs_softmax,
+             self.obs_attn_out,
+         )
+         # recurse into children that are QuantModuleBase
+         for m in (self.q_proj, self.k_proj, self.v_proj, self.o_proj):
+             yield from m._all_observers()
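The q/k rotation that QuantLlamaAttention instruments observer-by-observer above is ordinary rotary position embedding (RoPE). A standalone sketch of the same arithmetic without the fake-quant calls, using illustrative shapes rather than anything taken from transformers:

    # RoPE as computed in the q_rot / k_rot paths above, minus the observers.
    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # mirrors _rot(): split the last dim, negate the second half, re-concatenate
        x1, x2 = torch.chunk(x, 2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    B, n_heads, S, H = 1, 4, 8, 16
    q = torch.randn(B, n_heads, S, H)
    cos = torch.randn(B, S, H).unsqueeze(1)  # broadcast over heads, like cos_u above
    sin = torch.randn(B, S, H).unsqueeze(1)

    q_rot = q * cos + rotate_half(q) * sin   # the q_cos + q_sin path in the wrapper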