tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181) hide show
  1. tico/__init__.py +9 -1
  2. tico/config/base.py +1 -1
  3. tico/config/v1.py +5 -0
  4. tico/passes/cast_aten_where_arg_type.py +1 -1
  5. tico/passes/cast_clamp_mixed_type_args.py +169 -0
  6. tico/passes/cast_mixed_type_args.py +4 -2
  7. tico/passes/const_prop_pass.py +1 -1
  8. tico/passes/convert_conv1d_to_conv2d.py +1 -1
  9. tico/passes/convert_expand_to_slice_cat.py +153 -0
  10. tico/passes/convert_matmul_to_linear.py +312 -0
  11. tico/passes/convert_to_relu6.py +1 -1
  12. tico/passes/decompose_addmm.py +0 -3
  13. tico/passes/decompose_batch_norm.py +2 -2
  14. tico/passes/decompose_fake_quantize.py +0 -3
  15. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
  16. tico/passes/decompose_group_norm.py +0 -3
  17. tico/passes/legalize_predefined_layout_operators.py +2 -11
  18. tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
  19. tico/passes/lower_to_slice.py +1 -1
  20. tico/passes/merge_consecutive_cat.py +1 -1
  21. tico/passes/ops.py +1 -1
  22. tico/passes/remove_redundant_assert_nodes.py +3 -1
  23. tico/passes/remove_redundant_expand.py +3 -6
  24. tico/passes/remove_redundant_reshape.py +5 -5
  25. tico/passes/segment_index_select.py +1 -1
  26. tico/quantization/__init__.py +6 -0
  27. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  28. tico/quantization/algorithm/gptq/quantizer.py +292 -0
  29. tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
  30. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
  31. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  32. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  33. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  34. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  35. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
  36. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  37. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  38. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  39. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  40. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  41. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  42. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  44. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
  45. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  46. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  47. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  48. tico/quantization/config/base.py +26 -0
  49. tico/quantization/config/gptq.py +29 -0
  50. tico/quantization/config/pt2e.py +25 -0
  51. tico/quantization/config/ptq.py +119 -0
  52. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  53. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
  54. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  55. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  56. tico/quantization/evaluation/metric.py +146 -0
  57. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  58. tico/quantization/passes/__init__.py +1 -0
  59. tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
  60. tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
  61. tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
  62. tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
  63. tico/{experimental/quantization → quantization}/public_interface.py +19 -18
  64. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  65. tico/quantization/quantizer_registry.py +73 -0
  66. tico/quantization/wrapq/__init__.py +1 -0
  67. tico/quantization/wrapq/dtypes.py +70 -0
  68. tico/quantization/wrapq/examples/__init__.py +1 -0
  69. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  70. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  71. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  72. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  73. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  74. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  75. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  76. tico/quantization/wrapq/mode.py +32 -0
  77. tico/quantization/wrapq/observers/__init__.py +1 -0
  78. tico/quantization/wrapq/observers/affine_base.py +128 -0
  79. tico/quantization/wrapq/observers/base.py +98 -0
  80. tico/quantization/wrapq/observers/ema.py +62 -0
  81. tico/quantization/wrapq/observers/identity.py +74 -0
  82. tico/quantization/wrapq/observers/minmax.py +39 -0
  83. tico/quantization/wrapq/observers/mx.py +60 -0
  84. tico/quantization/wrapq/qscheme.py +40 -0
  85. tico/quantization/wrapq/quantizer.py +179 -0
  86. tico/quantization/wrapq/utils/__init__.py +1 -0
  87. tico/quantization/wrapq/utils/introspection.py +167 -0
  88. tico/quantization/wrapq/utils/metrics.py +124 -0
  89. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  90. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  91. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  92. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  93. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  94. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  95. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  96. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  97. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  98. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  99. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  100. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  101. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  102. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  103. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  104. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  105. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  106. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  107. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  108. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  109. tico/quantization/wrapq/wrappers/registry.py +125 -0
  110. tico/serialize/circle_graph.py +12 -4
  111. tico/serialize/circle_mapping.py +76 -2
  112. tico/serialize/circle_serializer.py +253 -148
  113. tico/serialize/operators/adapters/__init__.py +1 -0
  114. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  115. tico/serialize/operators/op_any.py +7 -14
  116. tico/serialize/operators/op_avg_pool2d.py +11 -4
  117. tico/serialize/operators/op_clamp.py +5 -7
  118. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  119. tico/serialize/operators/op_conv2d.py +14 -6
  120. tico/serialize/operators/op_copy.py +26 -3
  121. tico/serialize/operators/op_cumsum.py +3 -1
  122. tico/serialize/operators/op_depthwise_conv2d.py +17 -7
  123. tico/serialize/operators/op_full_like.py +0 -2
  124. tico/serialize/operators/op_index_select.py +8 -1
  125. tico/serialize/operators/op_instance_norm.py +0 -6
  126. tico/serialize/operators/op_le.py +54 -0
  127. tico/serialize/operators/op_log1p.py +3 -2
  128. tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
  129. tico/serialize/operators/op_mm.py +15 -131
  130. tico/serialize/operators/op_mul.py +2 -8
  131. tico/serialize/operators/op_pow.py +3 -1
  132. tico/serialize/operators/op_repeat.py +12 -3
  133. tico/serialize/operators/op_reshape.py +1 -1
  134. tico/serialize/operators/op_rmsnorm.py +65 -0
  135. tico/serialize/operators/op_softmax.py +7 -14
  136. tico/serialize/operators/op_split_with_sizes.py +16 -8
  137. tico/serialize/operators/op_transpose_conv.py +11 -8
  138. tico/serialize/operators/op_view.py +2 -1
  139. tico/serialize/quant_param.py +5 -5
  140. tico/utils/convert.py +30 -17
  141. tico/utils/dtype.py +42 -0
  142. tico/utils/graph.py +1 -1
  143. tico/utils/model.py +2 -1
  144. tico/utils/padding.py +2 -2
  145. tico/utils/pytree_utils.py +134 -0
  146. tico/utils/record_input.py +102 -0
  147. tico/utils/register_custom_op.py +29 -4
  148. tico/utils/serialize.py +16 -3
  149. tico/utils/signature.py +247 -0
  150. tico/utils/torch_compat.py +52 -0
  151. tico/utils/utils.py +50 -58
  152. tico/utils/validate_args_kwargs.py +38 -3
  153. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  154. tico-0.1.0.dev251102.dist-info/RECORD +271 -0
  155. tico/experimental/quantization/__init__.py +0 -1
  156. tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
  157. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  158. tico/experimental/quantization/evaluation/metric.py +0 -109
  159. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
  160. tico-0.1.0.dev250714.dist-info/RECORD +0 -209
  161. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  162. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  163. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  164. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  165. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  166. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  167. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  168. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  169. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  170. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  171. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  172. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  173. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  174. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  175. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  176. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  177. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  178. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  179. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  180. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  181. {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,111 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Callable, Optional
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from tico.quantization.config.ptq import PTQConfig
21
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
22
+ from tico.quantization.wrapq.wrappers.registry import register
23
+
24
+
25
class QuantElementwise(QuantModuleBase):
    """
    Generic wrapper for any 1-to-1 element-wise op `y = f(x)`.

    The input goes through the `act_in` observer and the op's result through
    the `act_out` observer, both via `_fq` (collect in CALIB mode, fake-quant
    in QUANT mode, pass-through otherwise).

    Sub-classes only need to implement:
      • `FUNC`: a Callable that maps tensor→tensor, declared as
        `FUNC = staticmethod(...)` so that `self.FUNC(x)` does not bind `self`.
    """

    # subclass must set this
    FUNC: Callable[[torch.Tensor], torch.Tensor] | None = None

    def __init_subclass__(cls, **kwargs):
        # Runs for every subclass (never for QuantElementwise itself), so a
        # concrete wrapper that forgets FUNC fails at class-definition time.
        super().__init_subclass__(**kwargs)
        if cls.FUNC is None:
            raise NotImplementedError(
                f"{cls.__name__} must define a staticmethod `FUNC(tensor) -> tensor`"
            )

    def __init__(
        self,
        fp_module: nn.Module,
        *,
        qcfg: Optional[PTQConfig] = None,
        fp_name: Optional[str] = None,
    ):
        """
        Args:
            fp_module: The original floating-point module; kept as `self.module`.
            qcfg: Quantization config forwarded to `QuantModuleBase`.
            fp_name: Optional name of the wrapped module, forwarded to the base.
        """
        super().__init__(qcfg, fp_name=fp_name)
        self.module = fp_module
        self.act_in_obs = self._make_obs("act_in")
        self.act_out_obs = self._make_obs("act_out")

    # ------------------------------------------------------------
    def forward(self, x):
        """Apply `FUNC` between the input and output (fake-)quant steps."""
        func = self.FUNC
        if func is None:
            # Reachable e.g. when the base class itself is instantiated;
            # an explicit raise (unlike `assert`) survives `python -O`.
            raise NotImplementedError(f"{type(self).__name__} does not define `FUNC`")
        x_q = self._fq(x, self.act_in_obs)
        y = func(x_q)  # element-wise op
        return self._fq(y, self.act_out_obs)

    # ------------------------------------------------------------
    def _all_observers(self):
        """Return the input and output activation observers."""
        return (self.act_in_obs, self.act_out_obs)
68
+
69
+
70
+ """
71
+ Why `FUNC` is a `staticmethod`
72
+
73
+ - Prevents automatic binding: calling `self.FUNC(x)` will not inject `self`,
74
+ so the callable keeps the expected signature `Tensor -> Tensor`
75
+ (e.g., `torch.sigmoid(x)`), avoiding TypeErrors.
76
+
77
+ - Expresses purity and statelessness: `FUNC` is a pure, element-wise transform
78
+ that must not read or mutate module state (params, buffers, config).
79
+
80
+ - Tracing/export friendly (FX / TorchScript): the call is captured as
81
+ `call_function(torch.*)` instead of a bound `call_method`, which makes graph
82
+ rewrites/pattern-matching and backends' substitutions more reliable.
83
+
84
+ - Avoids submodule pollution: we keep a functional op (`torch.relu`) rather
85
+ than an `nn.Module` instance that would appear in the module tree.
86
+
87
+ - Small perf/alloc win: no bound-method objects are created on each call.
88
+ """
89
+
90
# Sigmoid
@register(nn.Sigmoid)
class QuantSigmoid(QuantElementwise):
    """Wrapper registered for `nn.Sigmoid`; its element-wise op is `torch.sigmoid`."""

    FUNC = staticmethod(torch.sigmoid)
94
+
95
+
96
# Tanh
@register(nn.Tanh)
class QuantTanh(QuantElementwise):
    """Wrapper registered for `nn.Tanh`; its element-wise op is `torch.tanh`."""

    FUNC = staticmethod(torch.tanh)
100
+
101
+
102
# ReLU
@register(nn.ReLU)
class QuantReLU(QuantElementwise):
    """
    Wrapper registered for `nn.ReLU`; its element-wise op is `torch.relu`.

    `torch.relu` is the out-of-place variant, so the wrapped module's
    `inplace` flag is not applied here.
    """

    FUNC = staticmethod(torch.relu)
106
+
107
+
108
# GELU (honours the wrapped module's `approximate` mode)
@register(nn.GELU)
class QuantGELU(QuantElementwise):
    """
    Wrapper registered for `nn.GELU`.

    `nn.GELU` carries an `approximate` setting ("none" or "tanh"). Calling
    `F.gelu(x)` with its default would always compute the exact variant and
    silently diverge from an `nn.GELU(approximate="tanh")` source module.
    `forward` therefore reads the setting from the wrapped module and passes
    it to `F.gelu`. `FUNC` stays the pure functional op to satisfy the base
    class contract.
    """

    FUNC = staticmethod(torch.nn.functional.gelu)

    def forward(self, x):
        x_q = self._fq(x, self.act_in_obs)
        # Fall back to the exact form if the wrapped module lacks the attribute.
        approximate = getattr(self.module, "approximate", "none")
        y = torch.nn.functional.gelu(x_q, approximate=approximate)
        return self._fq(y, self.act_out_obs)
@@ -0,0 +1,168 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from typing import Iterable, Optional, Tuple
17
+
18
+ import torch.nn as nn
19
+
20
+ from tico.quantization.config.ptq import PTQConfig
21
+
22
+ from tico.quantization.wrapq.mode import Mode
23
+ from tico.quantization.wrapq.observers.base import ObserverBase
24
+
25
+
26
class QuantModuleBase(nn.Module, ABC):
    """
    Abstract parent for EVERY wrapper.

    Responsibilities
    ----------------
    • Own *one* Mode enum (`NO_QUANT / CALIB / QUANT`), starting in `NO_QUANT`
    • Own a PTQConfig describing default / per-observer dtypes
    • Expose a canonical lifecycle, propagated to nested wrappers:
        enable_calibration()
        freeze_qparams()
    • Provide helper `_fq(x, observer)` (“fake-quant or collect”) so
      subclasses write arithmetic code without boilerplate.

    Subclasses must implement `_all_observers()`.
    """

    def __init__(
        self, qcfg: Optional[PTQConfig] = None, *, fp_name: Optional[str] = None
    ) -> None:
        """
        Args:
            qcfg: Quantization config. A default `PTQConfig()` is created when
                it is omitted (or falsy).
            fp_name: Optional name of the wrapped FP module, stored as `self.fp_name`.
        """
        super().__init__()
        self.qcfg = qcfg or PTQConfig()
        self._mode: Mode = Mode.NO_QUANT  # default state
        self.fp_name = fp_name

    def _child_quant_modules(self):
        """
        Yield the nearest QuantModuleBase descendants.

        Traversal starts at the direct children and descends only through the
        pure containers `nn.ModuleList`, `nn.ModuleDict` and `nn.Sequential`.
        Any other non-quant module is skipped together with its subtree.

        A QuantModuleBase that is found is yielded at most once, even if it is
        shared. It is NOT descended into here: its own lifecycle call handles
        its subtree, so recursion proceeds level by level.

        Modules are yielded depth-first in `children()` registration order.
        """
        seen = set()
        # Push in reverse so that popping from the end visits children in order.
        stack = list(self.children())[::-1]

        while stack:
            m = stack.pop()
            if isinstance(m, QuantModuleBase):
                if id(m) not in seen:
                    seen.add(id(m))
                    yield m
                # IMPORTANT: do not recurse into `m` here; its own call will handle its subtree
            elif isinstance(m, (nn.ModuleList, nn.ModuleDict, nn.Sequential)):
                # Transparent container: keep descending until we hit quant modules.
                stack.extend(reversed(list(m.children())))
            # Any other (non-quant, non-container) module is intentionally skipped.

    def enable_calibration(self) -> None:
        """
        Switch to CALIB mode.

        Every owned observer is enabled and `reset()`. The call is then
        propagated to the nearest nested wrappers.
        """
        self._mode = Mode.CALIB
        for obs in self._all_observers():
            obs.enabled = True
            obs.reset()

        # propagate to children
        for child in self._child_quant_modules():
            child.enable_calibration()

    def freeze_qparams(self) -> None:
        """
        Switch to QUANT mode.

        Every owned observer is disabled and `compute_qparams()` is called on
        it. The call is then propagated to the nearest nested wrappers.
        """
        self._mode = Mode.QUANT
        for obs in self._all_observers():
            obs.enabled = False
            obs.compute_qparams()

        # propagate to children
        for child in self._child_quant_modules():
            child.freeze_qparams()

    def _fq(self, x, obs: ObserverBase):
        """
        Fake-quant or collect.

        * CALIB:    feed a detached copy of `x` to `obs.collect` and return `x`.
        * QUANT:    return `obs.fake_quant(x)`.
        * NO_QUANT: return `x` unchanged.
        """
        if self._mode is Mode.CALIB:
            obs.collect(x.detach())
            return x
        if self._mode is Mode.QUANT:
            return obs.fake_quant(x)
        return x  # NO_QUANT

    @abstractmethod
    def _all_observers(self) -> Iterable[ObserverBase]:
        """Return every observer owned by this module."""
        ...

    def named_observers(self) -> Iterable[Tuple[str, ObserverBase]]:
        """Yield `(obs.name, obs)` for every observer owned by this module."""
        for obs in self._all_observers():
            yield obs.name, obs

    def get_observer(self, name: str) -> Optional[ObserverBase]:
        """Return the first owned observer whose `name` matches, else `None`."""
        for obs in self._all_observers():
            if obs.name == name:
                return obs
        return None

    def _make_obs(
        self,
        name: str,
        **default_kwargs,
    ) -> ObserverBase:
        """
        Instantiate an observer named *name*.

        Args:
            name: Observer name. It is used to fetch user overrides via
                `self.qcfg.get_kwargs(name)` and is passed to the observer as
                `name=`.
            **default_kwargs: Wrapper-level defaults for this observer.

        Precedence (3-tier) for keys:
          • observer: user > wrapper-default > PTQConfig.default_observer
          • dtype:    user > wrapper-default > PTQConfig.default_dtype
          • qscheme:  user > wrapper-default > PTQConfig.default_qscheme

        A key counts as supplied whenever it is present, even with value `None`.

        All other kwargs (e.g., channel_axis) are merged as
        user override > wrapper-default and forwarded to the observer class.
        """
        _UNSPEC = object()  # sentinel: distinguishes "absent" from an explicit None

        wrapper_defaults = default_kwargs.copy()
        user_cfg = self.qcfg.get_kwargs(name).copy()

        def pick3(user_val, wrap_val, global_val):
            # The first value that was actually supplied wins.
            if user_val is not _UNSPEC:
                return user_val
            if wrap_val is not _UNSPEC:
                return wrap_val
            return global_val

        # 1) resolve observer class
        user_observer = user_cfg.pop("observer", _UNSPEC)
        wrapper_observer = wrapper_defaults.pop("observer", _UNSPEC)
        obs_cls = pick3(user_observer, wrapper_observer, self.qcfg.default_observer)

        # 2) resolve dtype
        user_dtype = user_cfg.pop("dtype", _UNSPEC)
        wrapper_dtype = wrapper_defaults.pop("dtype", _UNSPEC)
        final_dtype = pick3(user_dtype, wrapper_dtype, self.qcfg.default_dtype)

        # 3) resolve qscheme
        user_qscheme = user_cfg.pop("qscheme", _UNSPEC)
        wrapper_qscheme = wrapper_defaults.pop("qscheme", _UNSPEC)
        final_qscheme = pick3(user_qscheme, wrapper_qscheme, self.qcfg.default_qscheme)

        # 4) merge remaining kwargs: user_cfg wins
        final_kw = wrapper_defaults
        final_kw.update(user_cfg)
        final_kw["dtype"] = final_dtype
        final_kw["qscheme"] = final_qscheme

        return obs_cls(**final_kw, name=name)

    # nice repr
    def extra_repr(self) -> str:
        return f"mode={self._mode.name.lower()}"
@@ -0,0 +1,125 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import importlib
16
+ from typing import Callable, Dict, Type
17
+
18
+ import torch.nn as nn
19
+
20
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
21
+
22
# FP module class (e.g. nn.Linear) -> QuantModuleBase subclass that wraps it.
# Filled by the `register` / `try_register` decorators.
_WRAPPERS: Dict[Type[nn.Module], Type[QuantModuleBase]] = dict()

# Flipped to True once `_lazy_init()` has imported every core module.
_IMPORT_ONCE = False

# Modules imported (in this order) by `_lazy_init()`; importing each one runs
# its `@register(...)` decorators.
# NOTE(review): wrappers/fairseq/quant_decoder.py ships in this package but is
# not listed here; confirm whether it registers anything.
_CORE_MODULES = (
    # generic / torch.nn
    "tico.quantization.wrapq.wrappers.quant_elementwise",
    "tico.quantization.wrapq.wrappers.nn.quant_layernorm",
    "tico.quantization.wrapq.wrappers.nn.quant_linear",
    "tico.quantization.wrapq.wrappers.nn.quant_silu",
    # llama
    "tico.quantization.wrapq.wrappers.llama.quant_attn",
    "tico.quantization.wrapq.wrappers.llama.quant_decoder_layer",
    "tico.quantization.wrapq.wrappers.llama.quant_mlp",
    # fairseq
    "tico.quantization.wrapq.wrappers.fairseq.quant_decoder_layer",
    "tico.quantization.wrapq.wrappers.fairseq.quant_encoder",
    "tico.quantization.wrapq.wrappers.fairseq.quant_encoder_layer",
    "tico.quantization.wrapq.wrappers.fairseq.quant_mha",
    # add future core wrappers here
)
40
+
41
+
42
def _lazy_init():
    """
    Deferred one-shot import of "core wrapper modules".

    Why not import everything when the program first starts?
    --------------------------------------------------------
    * **Avoid circular-import hell**
      Core wrappers often import `PTQWrapper`, which in turn calls
      `registry.lookup()`. Importing those files eagerly here would create a
      cycle (`registry → wrapper → registry`). Delaying the import until the
      *first* `lookup()` call lets Python finish constructing the registry
      module before any wrapper files are touched.

    * **Cold-start speed**
      Most user code never wraps layers explicitly; they only hit
      `PTQWrapper` if they are doing quantization. Deferring half-a-dozen
      heavyweight `import torch …` files until they are really needed
      reduces library start-up latency in the common path.

    * **Optional dependencies**
      Core wrappers listed in `_CORE_MODULES` are chosen to be dependency-free
      (pure PyTorch). Anything that needs `transformers`, `torchvision`,
      etc. uses the `@try_register()` decorator inside its own module. Those
      optional modules are *not* imported here, so users without the extra
      packages still get a clean import.

    Implementation notes
    --------------------
    * `_IMPORT_ONCE` is set only after every module in `_CORE_MODULES` has been
      imported. If an import raises, the exception reaches the caller and the
      next call retries the loop.
    * There is no lock. Concurrent first calls may each run the loop. That only
      re-requests modules which Python's import system executes once and then
      serves from `sys.modules`.
    * Each entry of `_CORE_MODULES` is a fully-qualified module string
      (e.g. "tico.quantization.wrapq.wrappers.nn.quant_linear"). Importing the
      module runs its `@register(...)` decorators, populating `_WRAPPERS`.
    * After the first successful call the function is a cheap no-op.
    """
    global _IMPORT_ONCE
    if _IMPORT_ONCE:
        return
    for mod in _CORE_MODULES:
        importlib.import_module(mod)  # triggers @register decorators
    _IMPORT_ONCE = True
83
+
84
+
85
+ # ───────────────────────────── decorator for always-present classes
86
def register(
    fp_cls: Type[nn.Module],
) -> Callable[[Type[QuantModuleBase]], Type[QuantModuleBase]]:
    """
    Class decorator that maps *fp_cls* to the decorated wrapper class.

    A later registration for the same *fp_cls* replaces the earlier one.
    The decorated class itself is returned unchanged.
    """

    def _add_to_registry(wrapper_cls: Type[QuantModuleBase]):
        _WRAPPERS.update({fp_cls: wrapper_cls})
        return wrapper_cls

    return _add_to_registry
94
+
95
+
96
+ # ───────────────────────────── conditional decorator
97
def try_register(
    *paths: str,
) -> Callable[[Type[QuantModuleBase]], Type[QuantModuleBase]]:
    """
    Conditional variant of `register` for FP classes from optional packages.

    Usage::

        @try_register("transformers.models.llama.modeling_llama.LlamaMLP")
        class MyWrapper(QuantModuleBase): ...

    Each *path* is a fully-qualified ``"package.module.ClassName"`` string.
    For every path:
      • if the module imports and exposes the class → register it (like `@register`)
      • if the module cannot be imported or the class is missing → skip it (NO-OP)

    The decorated class is always returned unchanged.
    """

    def _decorator(quant_cls: Type[QuantModuleBase]):
        for path in paths:
            module_name, _, cls_name = path.rpartition(".")
            try:
                mod = importlib.import_module(module_name)
                fp_cls = getattr(mod, cls_name)
            except (ImportError, AttributeError):
                # Optional dep missing or failing to import, or class renamed:
                # skip silently. `ImportError` (the base of ModuleNotFoundError)
                # also covers an installed package whose own imports fail,
                # e.g. on a version mismatch.
                continue
            _WRAPPERS[fp_cls] = quant_cls
        return quant_cls

    return _decorator
120
+
121
+
122
+ # ───────────────────────────── lookup
123
def lookup(fp_cls: Type[nn.Module]) -> Type[QuantModuleBase] | None:
    """
    Return the wrapper class registered for *fp_cls*, or ``None`` if there is none.

    The core wrapper modules are imported on the first call (see `_lazy_init`),
    so their `@register` decorators have run before the table is consulted.
    Matching is by exact class key; subclasses of a registered class do not match.
    """
    _lazy_init()
    wrapper = _WRAPPERS.get(fp_cls)
    return wrapper
@@ -24,9 +24,10 @@ from torch._subclasses.fake_tensor import FakeTensor
24
24
 
25
25
  from tico.serialize.circle_mapping import (
26
26
  extract_circle_dtype,
27
- extract_shape,
27
+ extract_circle_shape,
28
28
  str_to_circle_dtype,
29
29
  to_circle_dtype,
30
+ to_circle_shape,
30
31
  )
31
32
  from tico.serialize.pack import pack_buffer
32
33
  from tico.serialize.quant_param import QPARAM_KEY, QuantParam
@@ -151,7 +152,8 @@ class CircleSubgraph(circle.SubGraph.SubGraphT):
151
152
  self.name_to_node[tensor.name] = node
152
153
  assert node.meta.get("val") is not None
153
154
  tensor.type = extract_circle_dtype(node)
154
- tensor.shape = list(extract_shape(node))
155
+ tensor.shape, tensor.shapeSignature = extract_circle_shape(node) # type: ignore[assignment]
156
+
155
157
  if QPARAM_KEY in node.meta:
156
158
  tensor.quantization = to_circle_qparam(node.meta[QPARAM_KEY])
157
159
  tensor.type = str_to_circle_dtype(node.meta[QPARAM_KEY].dtype)
@@ -185,7 +187,7 @@ class CircleSubgraph(circle.SubGraph.SubGraphT):
185
187
  torch_t = torch.as_tensor(data=data)
186
188
  torch_t_shape = list(torch_t.size())
187
189
  tensor.type = to_circle_dtype(torch_dtype=torch_t.dtype)
188
- tensor.shape = torch_t_shape
190
+ tensor.shape, tensor.shapeSignature = to_circle_shape(torch_t_shape)
189
191
 
190
192
  buffer = circle.Buffer.BufferT()
191
193
  buffer.data = torch_t.flatten().cpu().numpy().view(np.uint8) # type: ignore[assignment]
@@ -199,6 +201,7 @@ class CircleSubgraph(circle.SubGraph.SubGraphT):
199
201
  self,
200
202
  prefix: str,
201
203
  shape: List[int],
204
+ shape_signature: Optional[List[int]],
202
205
  dtype: int,
203
206
  qparam: Optional[QuantParam] = None,
204
207
  source_node: Optional[torch.fx.Node] = None,
@@ -221,6 +224,8 @@ class CircleSubgraph(circle.SubGraph.SubGraphT):
221
224
  A name prefix used to generate a unique tensor name.
222
225
  shape : List[int]
223
226
  The shape of the tensor.
227
+ shape_signature : Optional[List[int]]
228
+ The shape signature of the tensor to express Dynamic Shape. Defaults to `None` for Static Shape.
224
229
  dtype : int
225
230
  The Circle-compatible dtype of the tensor. Use `to_circle_dtype()` to convert.
226
231
  qparam : Optional[QuantParam]
@@ -241,6 +246,9 @@ class CircleSubgraph(circle.SubGraph.SubGraphT):
241
246
  if source_node is not None:
242
247
  self.name_to_node[tensor.name] = source_node
243
248
  tensor.shape = shape
249
+ if shape_signature is not None:
250
+ tensor.shapeSignature = shape_signature
251
+
244
252
  if qparam is not None:
245
253
  tensor.quantization = to_circle_qparam(qparam)
246
254
  tensor.type = str_to_circle_dtype(qparam.dtype)
@@ -305,7 +313,7 @@ class CircleSubgraph(circle.SubGraph.SubGraphT):
305
313
  self, node: Union[torch.fx.Node, circle.Tensor.TensorT, ConstData]
306
314
  ) -> int:
307
315
  # return -1 if node is None. This is for generating CircleOutputExclude
308
- if node == None:
316
+ if node is None:
309
317
  return -1
310
318
 
311
319
  if hasattr(node, "name") and node.name in self.name_to_tid:
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Tuple, TYPE_CHECKING, Union
15
+ from typing import List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
16
16
 
17
17
  if TYPE_CHECKING:
18
18
  import torch.fx
@@ -128,6 +128,79 @@ def extract_shape(node: torch.fx.Node) -> torch.Size:
128
128
  return val_shape
129
129
 
130
130
 
131
def extract_circle_shape(node: torch.fx.Node) -> Tuple[List[int], Optional[List[int]]]:
    """
    Return the circle ``(shape, shape_signature)`` pair for *node*.

    The torch shape is obtained with `extract_shape(node)` and converted by
    `to_circle_shape`.
    """
    torch_shape = extract_shape(node)
    return to_circle_shape(torch_shape)
133
+
134
+
135
+ def to_circle_shape(
136
+ torch_shape: Union[
137
+ torch.Size, Sequence[int | torch.SymInt]
138
+ ], # Sequence[int | torch.SymInt] is added for type covariance
139
+ ) -> Tuple[List[int], Optional[List[int]]]:
140
+
141
+ if any(isinstance(s, torch.SymInt) for s in torch_shape):
142
+ # Follow dynamic shape spec
143
+ shape = []
144
+ shape_signature = []
145
+ for s in torch_shape:
146
+ if isinstance(s, torch.SymInt):
147
+ shape.append(1)
148
+ shape_signature.append(-1)
149
+ elif isinstance(s, int):
150
+ shape.append(s)
151
+ shape_signature.append(s)
152
+ else:
153
+ raise RuntimeError(f"Unsupported shape {torch_shape}")
154
+ return shape, shape_signature
155
+ else:
156
+ # Follow static shape spec
157
+ shape = []
158
+ shape_signature = None
159
+ for s in torch_shape:
160
+ if isinstance(s, int):
161
+ shape.append(s)
162
+ else:
163
+ assert False, "Cannot reach here"
164
+ return shape, shape_signature
165
+
166
+
167
def validate_circle_shape(shape: List[int], shape_signature: Optional[List[int]]):
    """
    Check that a circle tensor's ``shape`` / ``shape_signature`` pair is well formed.

    Rules (see https://github.com/Samsung/TICO/issues/244):
      * ``shape_signature`` is ``None`` for a static shape; otherwise it is a
        non-empty list of ints with the same length as ``shape``.
      * A ``-1`` in ``shape_signature`` marks a dynamic dim whose ``shape``
        entry must be ``1``. Any other entry must equal the ``shape`` entry.
      * Every entry of ``shape`` is an int.

    Raises:
        ValueError: on the first rule that is violated.
    """
    if shape_signature is not None:
        if len(shape_signature) == 0:
            raise ValueError(
                "Invalid circle shape: shape_signature must not be an empty list. "
                "For static shapes, use None instead of []."
            )
        if len(shape) != len(shape_signature):
            raise ValueError(
                f"Invalid circle shape: shape and shape_signature must have same length: {shape} {shape_signature}"
            )
        if any(not isinstance(sig, int) for sig in shape_signature):
            raise ValueError(
                f"circle tensor shape_signature must be all integer values. {shape_signature}"
            )
        for dim, sig in zip(shape, shape_signature):
            # Dynamic dims (-1) are stored as 1 in `shape`; static dims must match.
            expected = 1 if sig == -1 else sig
            if dim != expected:
                raise ValueError(
                    f"Invalid circle shape: {dim} {sig} {shape} {shape_signature}"
                )

    if any(not isinstance(dim, int) for dim in shape):
        raise ValueError(f"circle tensor shape must be all integer values. {shape}")
202
+
203
+
131
204
  # Return stride of node
132
205
  def extract_stride(node: torch.fx.Node) -> Tuple[int, ...]:
133
206
  assert node.meta is not None
@@ -157,7 +230,8 @@ def check_if_i32_range(axis: Union[list, int]):
157
230
  return all(INT32_MIN <= val <= INT32_MAX for val in values)
158
231
 
159
232
 
160
- def circle_legalize_dtype_to(values, *, dtype: torch.dtype):
233
+ # TODO: Revisit this dtype legalization function as it breaks SRP
234
+ def circle_legalize_dtype_to(values, *, dtype: torch.dtype) -> torch.Tensor:
161
235
  """
162
236
  Legalize data types from `torch.int64` to `torch.int32`.
163
237