tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. tico/__init__.py +9 -1
  2. tico/config/base.py +1 -1
  3. tico/config/v1.py +5 -0
  4. tico/passes/cast_aten_where_arg_type.py +1 -1
  5. tico/passes/cast_clamp_mixed_type_args.py +169 -0
  6. tico/passes/cast_mixed_type_args.py +4 -2
  7. tico/passes/const_prop_pass.py +1 -1
  8. tico/passes/convert_conv1d_to_conv2d.py +1 -1
  9. tico/passes/convert_expand_to_slice_cat.py +153 -0
  10. tico/passes/convert_matmul_to_linear.py +312 -0
  11. tico/passes/convert_to_relu6.py +1 -1
  12. tico/passes/decompose_addmm.py +0 -3
  13. tico/passes/decompose_batch_norm.py +2 -2
  14. tico/passes/decompose_fake_quantize.py +0 -3
  15. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
  16. tico/passes/decompose_group_norm.py +0 -3
  17. tico/passes/legalize_predefined_layout_operators.py +2 -11
  18. tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
  19. tico/passes/lower_to_slice.py +1 -1
  20. tico/passes/merge_consecutive_cat.py +1 -1
  21. tico/passes/ops.py +1 -1
  22. tico/passes/remove_redundant_assert_nodes.py +3 -1
  23. tico/passes/remove_redundant_expand.py +3 -6
  24. tico/passes/remove_redundant_reshape.py +5 -5
  25. tico/passes/segment_index_select.py +1 -1
  26. tico/quantization/__init__.py +6 -0
  27. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  28. tico/quantization/algorithm/gptq/quantizer.py +292 -0
  29. tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
  30. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
  31. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  32. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  33. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  34. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  35. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
  36. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  37. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  38. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  39. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  40. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  41. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  42. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  44. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
  45. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  46. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  47. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  48. tico/quantization/config/base.py +26 -0
  49. tico/quantization/config/gptq.py +29 -0
  50. tico/quantization/config/pt2e.py +25 -0
  51. tico/quantization/config/ptq.py +119 -0
  52. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  53. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
  54. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  55. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  56. tico/quantization/evaluation/metric.py +146 -0
  57. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  58. tico/quantization/passes/__init__.py +1 -0
  59. tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
  60. tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
  61. tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
  62. tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
  63. tico/{experimental/quantization → quantization}/public_interface.py +19 -18
  64. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  65. tico/quantization/quantizer_registry.py +73 -0
  66. tico/quantization/wrapq/__init__.py +1 -0
  67. tico/quantization/wrapq/dtypes.py +70 -0
  68. tico/quantization/wrapq/examples/__init__.py +1 -0
  69. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  70. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  71. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  72. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  73. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  74. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  75. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  76. tico/quantization/wrapq/mode.py +32 -0
  77. tico/quantization/wrapq/observers/__init__.py +1 -0
  78. tico/quantization/wrapq/observers/affine_base.py +128 -0
  79. tico/quantization/wrapq/observers/base.py +98 -0
  80. tico/quantization/wrapq/observers/ema.py +62 -0
  81. tico/quantization/wrapq/observers/identity.py +74 -0
  82. tico/quantization/wrapq/observers/minmax.py +39 -0
  83. tico/quantization/wrapq/observers/mx.py +60 -0
  84. tico/quantization/wrapq/qscheme.py +40 -0
  85. tico/quantization/wrapq/quantizer.py +179 -0
  86. tico/quantization/wrapq/utils/__init__.py +1 -0
  87. tico/quantization/wrapq/utils/introspection.py +167 -0
  88. tico/quantization/wrapq/utils/metrics.py +124 -0
  89. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  90. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  91. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  92. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  93. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  94. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  95. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  96. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  97. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  98. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  99. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  100. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  101. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  102. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  103. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  104. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  105. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  106. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  107. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  108. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  109. tico/quantization/wrapq/wrappers/registry.py +125 -0
  110. tico/serialize/circle_graph.py +12 -4
  111. tico/serialize/circle_mapping.py +76 -2
  112. tico/serialize/circle_serializer.py +253 -148
  113. tico/serialize/operators/adapters/__init__.py +1 -0
  114. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  115. tico/serialize/operators/op_any.py +7 -14
  116. tico/serialize/operators/op_avg_pool2d.py +11 -4
  117. tico/serialize/operators/op_clamp.py +5 -7
  118. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  119. tico/serialize/operators/op_conv2d.py +14 -6
  120. tico/serialize/operators/op_copy.py +26 -3
  121. tico/serialize/operators/op_cumsum.py +3 -1
  122. tico/serialize/operators/op_depthwise_conv2d.py +17 -7
  123. tico/serialize/operators/op_full_like.py +0 -2
  124. tico/serialize/operators/op_index_select.py +8 -1
  125. tico/serialize/operators/op_instance_norm.py +0 -6
  126. tico/serialize/operators/op_le.py +54 -0
  127. tico/serialize/operators/op_log1p.py +3 -2
  128. tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
  129. tico/serialize/operators/op_mm.py +15 -131
  130. tico/serialize/operators/op_mul.py +2 -8
  131. tico/serialize/operators/op_pow.py +3 -1
  132. tico/serialize/operators/op_repeat.py +12 -3
  133. tico/serialize/operators/op_reshape.py +1 -1
  134. tico/serialize/operators/op_rmsnorm.py +65 -0
  135. tico/serialize/operators/op_softmax.py +7 -14
  136. tico/serialize/operators/op_split_with_sizes.py +16 -8
  137. tico/serialize/operators/op_transpose_conv.py +11 -8
  138. tico/serialize/operators/op_view.py +2 -1
  139. tico/serialize/quant_param.py +5 -5
  140. tico/utils/convert.py +30 -17
  141. tico/utils/dtype.py +42 -0
  142. tico/utils/graph.py +1 -1
  143. tico/utils/model.py +2 -1
  144. tico/utils/padding.py +2 -2
  145. tico/utils/pytree_utils.py +134 -0
  146. tico/utils/record_input.py +102 -0
  147. tico/utils/register_custom_op.py +29 -4
  148. tico/utils/serialize.py +16 -3
  149. tico/utils/signature.py +247 -0
  150. tico/utils/torch_compat.py +52 -0
  151. tico/utils/utils.py +50 -58
  152. tico/utils/validate_args_kwargs.py +38 -3
  153. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  154. tico-0.1.0.dev251102.dist-info/RECORD +271 -0
  155. tico/experimental/quantization/__init__.py +0 -1
  156. tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
  157. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  158. tico/experimental/quantization/evaluation/metric.py +0 -109
  159. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
  160. tico-0.1.0.dev250714.dist-info/RECORD +0 -209
  161. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  162. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  163. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  164. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  165. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  166. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  167. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  168. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  169. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  170. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  171. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  172. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  173. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  174. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  175. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  176. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  177. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  178. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  179. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  180. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  181. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0

tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py
@@ -0,0 +1,176 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.wrapq.wrappers.llama.quant_attn import QuantLlamaAttention
+ from tico.quantization.wrapq.wrappers.llama.quant_mlp import QuantLlamaMLP
+ from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+ from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+ @try_register("transformers.models.llama.modeling_llama.LlamaDecoderLayer")
+ class QuantLlamaDecoderLayer(QuantModuleBase):
+     """
+     Quant-aware drop-in replacement for HF `LlamaDecoderLayer`.
+     Signature and return-value are identical to the original.
+
+     ▸ Attention & MLP blocks are replaced by their quantized counterparts
+     ▸ LayerNorms remain FP32 (no fake-quant)
+     ▸ A "static" causal mask is pre-built in `__init__` to avoid
+       dynamic boolean-to-float casts inside `forward`.
+
+     Notes on the causal mask
+     ------------------------
+     Building a boolean mask "inside" `forward` would introduce
+     non-deterministic dynamic ops that an integer-only accelerator cannot
+     fuse easily. Therefore we:
+
+     1. Pre-compute a full upper-triangular mask of size
+        `[1, 1, max_seq, max_seq]` in `__init__`.
+     2. In `forward`, if the caller passes `attention_mask=None`, we
+        slice the pre-computed template to the current sequence length.
+     """
+
+     def __init__(
+         self,
+         fp_layer: nn.Module,
+         *,
+         qcfg: Optional[PTQConfig] = None,
+         fp_name: Optional[str] = None,
+         return_type: Optional[str] = None,
+     ):
+         """
+         Q) Why do we need `return_type`?
+         A) Different versions of `transformers` wrap the decoder output in
+            different containers: a plain Tensor or a tuple.
+         """
+         self.return_type = return_type
+         if self.return_type is None:
+             import transformers
+
+             v = tuple(map(int, transformers.__version__.split(".")[:2]))
+             self.return_type = "tensor" if v >= (4, 54) else "tuple"
+         assert self.return_type is not None
+         super().__init__(qcfg, fp_name=fp_name)
+
+         # Child QuantConfigs -------------------------------------------------
+         attn_cfg = qcfg.child("self_attn") if qcfg else None
+         mlp_cfg = qcfg.child("mlp") if qcfg else None
+
+         # Quantized sub-modules ---------------------------------------------
+         assert hasattr(fp_layer, "self_attn") and isinstance(
+             fp_layer.self_attn, torch.nn.Module
+         )
+         assert hasattr(fp_layer, "mlp") and isinstance(fp_layer.mlp, torch.nn.Module)
+         self.self_attn = PTQWrapper(
+             fp_layer.self_attn, qcfg=attn_cfg, fp_name=f"{fp_name}.self_attn"
+         )
+         self.mlp = PTQWrapper(fp_layer.mlp, qcfg=mlp_cfg, fp_name=f"{fp_name}.mlp")
+
+         # LayerNorms remain FP (copied from fp_layer to keep weights)
+         assert hasattr(fp_layer, "input_layernorm") and isinstance(
+             fp_layer.input_layernorm, torch.nn.Module
+         )
+         assert hasattr(fp_layer, "post_attention_layernorm") and isinstance(
+             fp_layer.post_attention_layernorm, torch.nn.Module
+         )
+         self.input_layernorm = fp_layer.input_layernorm
+         self.post_attention_layernorm = fp_layer.post_attention_layernorm
+
+         # Static causal mask template ---------------------------------------
+         assert hasattr(fp_layer.self_attn, "config") and hasattr(
+             fp_layer.self_attn.config, "max_position_embeddings"
+         )
+         assert isinstance(fp_layer.self_attn.config.max_position_embeddings, int)
+         max_seq = fp_layer.self_attn.config.max_position_embeddings
+         mask = torch.full((1, 1, max_seq, max_seq), float("-120"))
+         mask.triu_(1)
+         self.register_buffer("causal_mask_template", mask, persistent=False)
+
+     def _slice_causal(self, seq_len: int, device: torch.device) -> torch.Tensor:
+         """Return `[1,1,L,L]` causal mask slice on *device*."""
+         assert isinstance(self.causal_mask_template, torch.Tensor)
+         return self.causal_mask_template[..., :seq_len, :seq_len].to(device)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional["Cache"] = None, # type: ignore[name-defined]
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+         cache_position: Optional[torch.LongTensor] = None,
+         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor] | torch.Tensor:
+         if output_attentions:
+             raise NotImplementedError(
+                 "QuantLlamaDecoderLayer does not support output attention yet."
+             )
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+
+         if attention_mask is None or attention_mask.dtype == torch.bool:
+             L = hidden_states.size(1)
+             attention_mask = self._slice_causal(L, hidden_states.device)
+
+         attn_out = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+             cache_position=cache_position,
+             position_embeddings=position_embeddings,
+             **kwargs,
+         )
+         if use_cache:
+             hidden_states_attn, _attn_weights, present_key_value = attn_out
+         else:
+             hidden_states_attn, _attn_weights = attn_out
+             present_key_value = None
+
+         hidden_states = residual + hidden_states_attn
+
+         # ─── MLP block ─────────────────────────────────────────────────
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         # Return type policy:
+         # - If use_cache: always return (hidden_states, present_key_value)
+         # - Else: return as configured (tuple/tensor) for HF compatibility
+         if use_cache:
+             return hidden_states, present_key_value # type: ignore[return-value]
+
+         if self.return_type == "tuple":
+             return (hidden_states,)
+         elif self.return_type == "tensor":
+             return hidden_states
+         else:
+             raise RuntimeError("Invalid return type.")
+
+     # No local observers; just recurse into children
+     def _all_observers(self):
+         yield from self.self_attn._all_observers()
+         yield from self.mlp._all_observers()
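
The static-mask policy described in the `QuantLlamaDecoderLayer` docstring can be reproduced standalone with plain torch; the sketch below only mirrors the `torch.full` / `triu_` / slicing pattern from `__init__` and `_slice_causal` above, with `max_seq=8` and `seq_len=4` chosen as arbitrary example values.

import torch

max_seq = 8
mask = torch.full((1, 1, max_seq, max_seq), float("-120"))
mask.triu_(1)  # positions strictly above the diagonal get -120, the rest stay 0

seq_len = 4
sliced = mask[..., :seq_len, :seq_len]  # what _slice_causal() yields for L = 4
print(sliced[0, 0])
# tensor([[   0., -120., -120., -120.],
#         [   0.,    0., -120., -120.],
#         [   0.,    0.,    0., -120.],
#         [   0.,    0.,    0.,    0.]])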

tico/quantization/wrapq/wrappers/llama/quant_mlp.py
@@ -0,0 +1,96 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+ from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+ @try_register("transformers.models.llama.modeling_llama.LlamaMLP")
+ class QuantLlamaMLP(QuantModuleBase):
+     def __init__(
+         self,
+         mlp_fp: nn.Module,
+         *,
+         qcfg: Optional[PTQConfig] = None,
+         fp_name: Optional[str] = None,
+     ):
+         super().__init__(qcfg, fp_name=fp_name)
+
+         # ----- child configs (hierarchical override) -------------------
+         gate_cfg = qcfg.child("gate_proj") if qcfg else None
+         up_cfg = qcfg.child("up_proj") if qcfg else None
+         down_cfg = qcfg.child("down_proj") if qcfg else None
+         act_cfg = qcfg.child("act_fn") if qcfg else None
+
+         # ----- wrap three Linear layers -------------------------------
+         assert hasattr(mlp_fp, "gate_proj") and isinstance(
+             mlp_fp.gate_proj, torch.nn.Module
+         )
+         assert hasattr(mlp_fp, "up_proj") and isinstance(
+             mlp_fp.up_proj, torch.nn.Module
+         )
+         assert hasattr(mlp_fp, "down_proj") and isinstance(
+             mlp_fp.down_proj, torch.nn.Module
+         )
+         self.gate_proj = PTQWrapper(
+             mlp_fp.gate_proj, qcfg=gate_cfg, fp_name=f"{fp_name}.gate_proj"
+         )
+         self.up_proj = PTQWrapper(
+             mlp_fp.up_proj, qcfg=up_cfg, fp_name=f"{fp_name}.up_proj"
+         )
+         self.down_proj = PTQWrapper(
+             mlp_fp.down_proj, qcfg=down_cfg, fp_name=f"{fp_name}.down_proj"
+         )
+
+         # ----- activation ---------------------------------------------
+         assert hasattr(mlp_fp, "act_fn") and isinstance(mlp_fp.act_fn, torch.nn.Module)
+         self.act_fn = PTQWrapper(
+             mlp_fp.act_fn, qcfg=act_cfg, fp_name=f"{fp_name}.act_fn"
+         )
+
+         # ----- local observers ----------------------------------------
+         self.act_in_obs = self._make_obs("act_in")
+         self.mul_obs = self._make_obs("mul")
+
+     def forward(self, x: torch.Tensor):
+         # 1) quantize input once
+         x_q = self._fq(x, self.act_in_obs)
+
+         # 2) parallel projections
+         g = self.gate_proj(x_q)
+         u = self.up_proj(x_q)
+
+         # 3) activation on gate
+         a = self.act_fn(g)
+
+         # 4) element-wise product
+         h = self._fq(a * u, self.mul_obs)
+
+         # 5) final projection
+         return self.down_proj(h)
+
+     def _all_observers(self):
+         # local first
+         yield self.act_in_obs
+         yield self.mul_obs
+         # recurse into children that are QuantModuleBase
+         for m in (self.gate_proj, self.up_proj, self.down_proj, self.act_fn):
+             yield from m._all_observers()

tico/quantization/wrapq/wrappers/nn/__init__.py
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE

tico/quantization/wrapq/wrappers/nn/quant_layernorm.py
@@ -0,0 +1,183 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Iterable, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+
+ from tico.quantization.config.ptq import PTQConfig
+
+ from tico.quantization.wrapq.mode import Mode
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+ from tico.quantization.wrapq.wrappers.registry import register
+
+
+ @register(nn.LayerNorm)
+ class QuantLayerNorm(QuantModuleBase):
+     """
+     QuantLayerNorm — drop-in replacement for nn.LayerNorm that quantizes
+     the elementary steps:
+         1) μ = mean(x, dims) (mean)
+         2) c = x - μ (sub)
+         3) s = c * c (square)
+         4) v = mean(s, dims) (variance)
+         5) e = v + eps (add-eps)
+         6) r = rsqrt(e) (rsqrt)
+         7) n = c * r (normalize)
+         8) y = (n * γ) + β (affine), with:
+            • affine_mul : n * γ
+            • affine_add : (n * γ) + β
+     """
+
+     def __init__(
+         self,
+         fp: nn.LayerNorm,
+         *,
+         qcfg: Optional[PTQConfig] = None,
+         fp_name: Optional[str] = None
+     ):
+         super().__init__(qcfg, fp_name=fp_name)
+         self.module = fp
+         self.eps = torch.tensor(self.module.eps)
+         # Number of trailing dims participating in normalization
+         # (PyTorch stores normalized_shape as a tuple even if an int was passed)
+         self._norm_ndim: int = len(fp.normalized_shape) # safe for int→tuple
+
+         # Activation / intermediate observers
+         self.act_in_obs = self._make_obs("act_in")
+         self.mean_obs = self._make_obs("mean")
+         self.centered_obs = self._make_obs("centered")
+         self.square_obs = self._make_obs("square")
+         self.var_obs = self._make_obs("var")
+         self.eps_obs = self._make_obs("eps")
+         self.add_eps_obs = self._make_obs("add_eps")
+         self.inv_std_obs = self._make_obs("inv_std")
+         self.norm_obs = self._make_obs("norm")
+         self.act_out_obs = self._make_obs("act_out")
+
+         # Optional affine parameter observers (γ, β)
+         self.weight_obs = None
+         self.bias_obs = None
+         self.affine_mul_obs = None
+         self.affine_add_obs = None
+         if self.module.elementwise_affine:
+             if self.module.weight is not None:
+                 self.weight_obs = self._make_obs("weight")
+             if self.module.bias is not None:
+                 self.bias_obs = self._make_obs("bias")
+             # Per-op observers for (n * w) and (+ b)
+             self.affine_mul_obs = self._make_obs("affine_mul")
+             self.affine_add_obs = self._make_obs("affine_add")
+
+     def enable_calibration(self) -> None:
+         """
+         Switch to CALIB mode and collect *fixed* ranges for affine params
+         immediately, since they do not change across inputs.
+         """
+         super().enable_calibration()
+         if self.module.elementwise_affine:
+             if self.weight_obs is not None and self.module.weight is not None:
+                 self.weight_obs.collect(self.module.weight)
+             if self.bias_obs is not None and self.module.bias is not None:
+                 self.bias_obs.collect(self.module.bias)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Determine reduction dims (last self._norm_ndim axes)
+         # Example: if x.ndim=4 and norm_ndim=2 → dims=(2,3)
+         dims = tuple(range(x.dim() - self._norm_ndim, x.dim()))
+
+         # 0) input
+         x_q = self._fq(x, self.act_in_obs)
+
+         # 1) mean
+         mu = x_q.mean(dim=dims, keepdim=True)
+         mu_q = self._fq(mu, self.mean_obs)
+
+         # 2) center
+         c = x_q - mu_q
+         c_q = self._fq(c, self.centered_obs)
+
+         # 3) square (elementwise mul)
+         s = c_q * c_q
+         s_q = self._fq(s, self.square_obs)
+
+         # 4) variance (via squared mean)
+         v = s_q.mean(dim=dims, keepdim=True)
+         v_q = self._fq(v, self.var_obs)
+
+         # 5) add eps
+         eps_q = self._fq(self.eps, self.eps_obs)
+         e = v_q + eps_q
+         e_q = self._fq(e, self.add_eps_obs)
+
+         # 6) inverse std
+         r = torch.rsqrt(e_q)
+         r_q = self._fq(r, self.inv_std_obs)
+
+         # 7) normalize
+         n = c_q * r_q
+         n_q = self._fq(n, self.norm_obs)
+
+         # 8) optional affine
+         if self.module.elementwise_affine:
+             w = self.module.weight
+             b = self.module.bias
+             if self._mode is Mode.QUANT:
+                 if self.weight_obs is not None and w is not None:
+                     w = self.weight_obs.fake_quant(w) # type: ignore[assignment]
+                 if self.bias_obs is not None and b is not None:
+                     b = self.bias_obs.fake_quant(b) # type: ignore[assignment]
+             y = n_q
+             # 8a) n * w (fake-quant the result of the mul)
+             if w is not None:
+                 y = y * w
+                 if self.affine_mul_obs is not None:
+                     y = self._fq(y, self.affine_mul_obs)
+
+             # 8b) (+ b) (fake-quant the result of the add)
+             if b is not None:
+                 y = y + b
+                 if self.affine_add_obs is not None:
+                     y = self._fq(y, self.affine_add_obs)
+         else:
+             y = n_q
+
+         # 9) output activation
+         return self._fq(y, self.act_out_obs)
+
+     def _all_observers(self) -> Iterable:
+         obs: Tuple = (
+             self.act_in_obs,
+             self.mean_obs,
+             self.centered_obs,
+             self.square_obs,
+             self.var_obs,
+             self.eps_obs,
+             self.add_eps_obs,
+             self.inv_std_obs,
+             self.norm_obs,
+             self.act_out_obs,
+         )
+         # Insert affine param observers if present
+         if self.module.elementwise_affine:
+             if self.weight_obs is not None:
+                 obs = (self.weight_obs,) + obs
+             if self.bias_obs is not None:
+                 obs = obs + (self.bias_obs,)
+             if self.affine_mul_obs is not None:
+                 obs = obs + (self.affine_mul_obs,)
+             if self.affine_add_obs is not None:
+                 obs = obs + (self.affine_add_obs,)
+         return obs
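
With the fake-quant steps skipped, the eight-step decomposition in the `QuantLayerNorm` docstring reproduces `F.layer_norm` up to numerical precision; a float-only sanity check in plain torch (no observers, tensor sizes chosen arbitrarily):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 5, 16)
ln = torch.nn.LayerNorm(16)

dims = (-1,)                              # last len(normalized_shape) axes
mu = x.mean(dim=dims, keepdim=True)       # 1) mean
c = x - mu                                # 2) center
v = (c * c).mean(dim=dims, keepdim=True)  # 3-4) square, variance
r = torch.rsqrt(v + ln.eps)               # 5-6) add eps, rsqrt
y = c * r * ln.weight + ln.bias           # 7-8) normalize, affine

ref = F.layer_norm(x, ln.normalized_shape, ln.weight, ln.bias, ln.eps)
assert torch.allclose(y, ref, atol=1e-6)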

tico/quantization/wrapq/wrappers/nn/quant_linear.py
@@ -0,0 +1,65 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from tico.quantization.config.ptq import PTQConfig
+
+ from tico.quantization.wrapq.mode import Mode
+ from tico.quantization.wrapq.qscheme import QScheme
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+ from tico.quantization.wrapq.wrappers.registry import register
+
+
+ @register(nn.Linear)
+ class QuantLinear(QuantModuleBase):
+     """Per-channel weight fake-quant, eager-output activation fake-quant."""
+
+     def __init__(
+         self,
+         fp: nn.Linear,
+         *,
+         qcfg: Optional[PTQConfig] = None,
+         fp_name: Optional[str] = None
+     ):
+         super().__init__(qcfg, fp_name=fp_name)
+         self.weight_obs = self._make_obs(
+             "weight", qscheme=QScheme.PER_CHANNEL_ASYMM, channel_axis=0
+         )
+         self.act_in_obs = self._make_obs("act_in")
+         self.act_out_obs = self._make_obs("act_out")
+         self.module = fp
+
+     def enable_calibration(self) -> None:
+         super().enable_calibration()
+         # immediately capture the fixed weight range
+         self.weight_obs.collect(self.module.weight)
+
+     def forward(self, x):
+         x_q = self._fq(x, self.act_in_obs)
+
+         w = self.module.weight
+         if self._mode is Mode.QUANT:
+             w = self.weight_obs.fake_quant(w)
+         b = self.module.bias
+
+         out = F.linear(x_q, w, b)
+
+         return self._fq(out, self.act_out_obs)
+
+     def _all_observers(self):
+         return (self.weight_obs, self.act_in_obs, self.act_out_obs)
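
For reference, per-channel asymmetric fake-quant over `channel_axis=0`, the scheme `QuantLinear` requests for its weight observer, boils down to the computation below. The actual observer lives under `tico/quantization/wrapq/observers/` and is not reproduced in this hunk, so this is only an illustration of the scheme with an assumed 8-bit integer range, not the library's implementation.

import torch

w = torch.randn(32, 16)              # nn.Linear weight: [out_features, in_features]
w_min = w.amin(dim=1, keepdim=True)  # one (min, max) pair per output channel (axis 0)
w_max = w.amax(dim=1, keepdim=True)

scale = (w_max - w_min).clamp(min=1e-12) / 255.0  # assumed uint8 grid [0, 255]
zero_point = torch.round(-w_min / scale)

w_int = torch.clamp(torch.round(w / scale) + zero_point, 0, 255)
w_fq = (w_int - zero_point) * scale  # "fake-quant": back to float, rounded to the grid

print((w - w_fq).abs().max())        # per-channel error is bounded by roughly scale / 2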

tico/quantization/wrapq/wrappers/nn/quant_silu.py
@@ -0,0 +1,59 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+ from tico.quantization.wrapq.wrappers.registry import register
+
+
+ @register(nn.SiLU)
+ class QuantSiLU(QuantModuleBase):
+     """
+     QuantSiLU — drop-in replacement for nn.SiLU that quantizes
+     both intermediate tensors:
+         • s = sigmoid(x) (logistic)
+         • y = x * s (mul)
+     """
+
+     def __init__(
+         self,
+         fp: nn.SiLU,
+         *,
+         qcfg: Optional[PTQConfig] = None,
+         fp_name: Optional[str] = None
+     ):
+         super().__init__(qcfg, fp_name=fp_name)
+         self.act_in_obs = self._make_obs("act_in")
+         self.sig_obs = self._make_obs("sigmoid")
+         self.mul_obs = self._make_obs("mul")
+         self.module = fp
+
+     def forward(self, x: torch.Tensor):
+         x_q = self._fq(x, self.act_in_obs)
+
+         s = torch.sigmoid(x_q)
+         s = self._fq(s, self.sig_obs)
+
+         y = x * s
+         y = self._fq(y, self.mul_obs)
+
+         return y
+
+     def _all_observers(self):
+         return (self.act_in_obs, self.sig_obs, self.mul_obs)

tico/quantization/wrapq/wrappers/ptq_wrapper.py
@@ -0,0 +1,69 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ import torch
+
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+ from tico.quantization.wrapq.wrappers.registry import lookup
+
+
+ class PTQWrapper(QuantModuleBase):
+     """
+     Adapter that turns a fp module into its quantized counterpart.
+
+     It is itself a QuantModuleBase so composite wrappers can treat
+     it exactly like any other quant module.
+     """
+
+     def __init__(
+         self,
+         module: torch.nn.Module,
+         qcfg: Optional[PTQConfig] = None,
+         *,
+         fp_name: Optional[str] = None,
+     ):
+         super().__init__(qcfg)
+         wrapped_cls = lookup(type(module))
+         if wrapped_cls is None:
+             raise NotImplementedError(f"No quant wrapper for {type(module).__name__}")
+         self.wrapped: QuantModuleBase = wrapped_cls(module, qcfg=qcfg, fp_name=fp_name) # type: ignore[arg-type, misc]
+
+     def forward(self, *args, **kwargs):
+         return self.wrapped(*args, **kwargs)
+
+     def _all_observers(self):
+         """
+         PTQWrapper itself owns NO observers (transparent node).
+         Returning an empty iterator prevents double-processing when parents
+         traverse the tree and then recurse into `self.wrapped`.
+         """
+         return () # no local observers
+
+     def named_observers(self):
+         """
+         Proxy to the wrapped module so debugging tools can still enumerate observers.
+         """
+         yield from self.wrapped.named_observers()
+
+     def get_observer(self, name: str):
+         """
+         Proxy to the wrapped module for direct lookup by name.
+         """
+         return self.wrapped.get_observer(name)
+
+     def extra_repr(self) -> str:
+         return self.wrapped.extra_repr()
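
A minimal sketch of the adapter in use, relying only on what appears in this diff (`@register(nn.Linear)` resolving to `QuantLinear`, `lookup`, and the inherited `enable_calibration`); the calibration recursion and the subsequent switch to QUANT mode live in `quant_module_base.py`, which is outside these hunks, so those parts are assumptions noted in comments.

import torch
import torch.nn as nn

from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
from tico.quantization.wrapq.wrappers.registry import lookup

fp = nn.Linear(16, 32)
print(lookup(type(fp)).__name__)    # "QuantLinear", registered via @register(nn.Linear)

q = PTQWrapper(fp, fp_name="proj")  # resolves nn.Linear -> QuantLinear under the hood
q.enable_calibration()              # assumption: the base class reaches q.wrapped's observers
with torch.no_grad():
    q(torch.randn(4, 16))           # observers collect activation ranges
# After calibration, switch to QUANT mode through the QuantModuleBase API (not shown
# in this diff) and run inference again to obtain fake-quantized outputs.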