tico 0.1.0.dev250803__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. tico/__init__.py +1 -1
  2. tico/config/v1.py +5 -0
  3. tico/passes/cast_mixed_type_args.py +2 -0
  4. tico/passes/convert_expand_to_slice_cat.py +153 -0
  5. tico/passes/convert_matmul_to_linear.py +312 -0
  6. tico/passes/convert_to_relu6.py +1 -1
  7. tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -4
  8. tico/passes/ops.py +0 -1
  9. tico/passes/remove_redundant_assert_nodes.py +3 -1
  10. tico/passes/remove_redundant_expand.py +3 -1
  11. tico/quantization/__init__.py +6 -0
  12. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
  13. tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +30 -8
  14. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
  15. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  16. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  17. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  18. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  19. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
  20. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  21. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  22. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  23. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  24. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  25. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  26. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  27. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  28. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
  29. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  30. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  31. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  32. tico/quantization/config/base.py +26 -0
  33. tico/quantization/config/gptq.py +29 -0
  34. tico/quantization/config/pt2e.py +25 -0
  35. tico/quantization/config/ptq.py +119 -0
  36. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
  37. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +7 -16
  38. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  39. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  40. tico/quantization/evaluation/metric.py +146 -0
  41. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  42. tico/quantization/passes/__init__.py +1 -0
  43. tico/{experimental/quantization → quantization}/public_interface.py +11 -18
  44. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  45. tico/quantization/quantizer_registry.py +73 -0
  46. tico/quantization/wrapq/__init__.py +1 -0
  47. tico/quantization/wrapq/dtypes.py +70 -0
  48. tico/quantization/wrapq/examples/__init__.py +1 -0
  49. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  50. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  51. tico/quantization/wrapq/examples/quantize_linear.py +107 -0
  52. tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
  53. tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
  54. tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
  55. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  56. tico/quantization/wrapq/mode.py +32 -0
  57. tico/quantization/wrapq/observers/__init__.py +1 -0
  58. tico/quantization/wrapq/observers/affine_base.py +128 -0
  59. tico/quantization/wrapq/observers/base.py +98 -0
  60. tico/quantization/wrapq/observers/ema.py +62 -0
  61. tico/quantization/wrapq/observers/identity.py +74 -0
  62. tico/quantization/wrapq/observers/minmax.py +39 -0
  63. tico/quantization/wrapq/observers/mx.py +60 -0
  64. tico/quantization/wrapq/qscheme.py +40 -0
  65. tico/quantization/wrapq/quantizer.py +179 -0
  66. tico/quantization/wrapq/utils/__init__.py +1 -0
  67. tico/quantization/wrapq/utils/introspection.py +167 -0
  68. tico/quantization/wrapq/utils/metrics.py +124 -0
  69. tico/quantization/wrapq/utils/reduce_utils.py +25 -0
  70. tico/quantization/wrapq/wrappers/__init__.py +1 -0
  71. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  72. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  73. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  74. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  75. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  76. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  77. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  78. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  79. tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
  80. tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
  81. tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
  82. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  83. tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
  84. tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
  85. tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
  86. tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
  87. tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
  88. tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
  89. tico/quantization/wrapq/wrappers/registry.py +125 -0
  90. tico/serialize/circle_serializer.py +11 -4
  91. tico/serialize/operators/adapters/__init__.py +1 -0
  92. tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
  93. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  94. tico/serialize/operators/op_le.py +54 -0
  95. tico/serialize/operators/op_mm.py +15 -132
  96. tico/serialize/operators/op_rmsnorm.py +65 -0
  97. tico/utils/convert.py +20 -15
  98. tico/utils/dtype.py +22 -0
  99. tico/utils/register_custom_op.py +29 -4
  100. tico/utils/signature.py +247 -0
  101. tico/utils/utils.py +50 -53
  102. tico/utils/validate_args_kwargs.py +37 -0
  103. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
  104. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/RECORD +130 -73
  105. tico/experimental/quantization/__init__.py +0 -6
  106. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  107. tico/experimental/quantization/evaluation/metric.py +0 -109
  108. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  109. /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
  110. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  111. /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
  112. /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
  113. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
  114. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  115. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  116. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
  117. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  118. /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
  119. /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
  120. /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
  121. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  122. /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
  123. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  124. /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
  125. /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
  126. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  127. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  128. /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
  129. /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
  130. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
  131. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
  132. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
  133. {tico-0.1.0.dev250803.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
tico/quantization/wrapq/wrappers/quant_elementwise.py
@@ -0,0 +1,111 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Callable, Optional
+
+ import torch
+ import torch.nn as nn
+
+ from tico.quantization.config.ptq import PTQConfig
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+ from tico.quantization.wrapq.wrappers.registry import register
+
+
+ class QuantElementwise(QuantModuleBase):
+     """
+     Generic wrapper for any 1-to-1 element-wise op `y = f(x)`.
+
+     Sub-classes only need to implement:
+         • `FUNC`: a Callable that maps tensor→tensor
+     """
+
+     # subclass must set this
+     FUNC: Callable[[torch.Tensor], torch.Tensor] | None = None
+
+     def __init_subclass__(cls, **kwargs):
+         super().__init_subclass__(**kwargs)
+         if cls is QuantElementwise:
+             return
+         if cls.FUNC is None:
+             raise NotImplementedError(
+                 f"{cls.__name__} must define a staticmethod `FUNC(tensor) -> tensor`"
+             )
+
+     def __init__(
+         self,
+         fp_module: nn.Module,
+         *,
+         qcfg: Optional[PTQConfig] = None,
+         fp_name: Optional[str] = None,
+     ):
+         super().__init__(qcfg, fp_name=fp_name)
+         self.module = fp_module
+         self.act_in_obs = self._make_obs("act_in")
+         self.act_out_obs = self._make_obs("act_out")
+
+     # ------------------------------------------------------------
+     def forward(self, x):
+         x_q = self._fq(x, self.act_in_obs)
+         assert self.FUNC is not None
+         y = self.FUNC(x_q)  # element-wise op
+         y_q = self._fq(y, self.act_out_obs)
+         return y_q
+
+     # ------------------------------------------------------------
+     def _all_observers(self):
+         return (self.act_in_obs, self.act_out_obs)
+
+
+ """
+ Why `FUNC` is a `staticmethod`
+
+ - Prevents automatic binding: calling `self.FUNC(x)` will not inject `self`,
+   so the callable keeps the expected signature `Tensor -> Tensor`
+   (e.g., `torch.sigmoid(x)`), avoiding TypeErrors.
+
+ - Expresses purity and statelessness: `FUNC` is a pure, element-wise transform
+   that must not read or mutate module state (params, buffers, config).
+
+ - Tracing/export friendly (FX / TorchScript): the call is captured as
+   `call_function(torch.*)` instead of a bound `call_method`, which makes graph
+   rewrites/pattern-matching and backends' substitutions more reliable.
+
+ - Avoids submodule pollution: we keep a functional op (`torch.relu`) rather
+   than an `nn.Module` instance that would appear in the module tree.
+
+ - Small perf/alloc win: no bound-method objects are created on each call.
+ """
+
+ # Sigmoid
+ @register(nn.Sigmoid)
+ class QuantSigmoid(QuantElementwise):
+     FUNC = staticmethod(torch.sigmoid)
+
+
+ # Tanh
+ @register(nn.Tanh)
+ class QuantTanh(QuantElementwise):
+     FUNC = staticmethod(torch.tanh)
+
+
+ # ReLU
+ @register(nn.ReLU)
+ class QuantReLU(QuantElementwise):
+     FUNC = staticmethod(torch.relu)
+
+
+ # GELU (approximate)
+ @register(nn.GELU)
+ class QuantGELU(QuantElementwise):
+     FUNC = staticmethod(torch.nn.functional.gelu)
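The wrappers above are exercised through the `QuantModuleBase` calibrate-then-freeze lifecycle. A minimal usage sketch (not part of the wheel; it only assumes the import paths and signatures visible in this diff):

```python
# Hypothetical usage of the QuantElementwise wrappers added above.
import torch
import torch.nn as nn

from tico.quantization.wrapq.wrappers.quant_elementwise import QuantSigmoid

q = QuantSigmoid(nn.Sigmoid())   # wrap the float module

q.enable_calibration()           # observers start collecting statistics
with torch.no_grad():
    for _ in range(8):
        q(torch.randn(4, 16))    # calibration batches

q.freeze_qparams()               # observers compute scale / zero-point
y = q(torch.randn(4, 16))        # forward now fake-quantizes input and output
```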
tico/quantization/wrapq/wrappers/quant_module_base.py
@@ -0,0 +1,168 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import ABC, abstractmethod
+ from typing import Iterable, Optional, Tuple
+
+ import torch.nn as nn
+
+ from tico.quantization.config.ptq import PTQConfig
+
+ from tico.quantization.wrapq.mode import Mode
+ from tico.quantization.wrapq.observers.base import ObserverBase
+
+
+ class QuantModuleBase(nn.Module, ABC):
+     """
+     Abstract parent for EVERY wrapper.
+
+     Responsibilities
+     ----------------
+     • Own *one* Mode enum (`NO_QUANT / CALIB / QUANT`)
+     • Own a PTQConfig describing default / per-observer dtypes
+     • Expose a canonical lifecycle:
+           enable_calibration()
+           freeze_qparams()
+     • Provide helper `_fq(x, observer)` (“fake-quant or collect”) so
+       subclasses write arithmetic code without boilerplate.
+     """
+
+     def __init__(
+         self, qcfg: Optional[PTQConfig] = None, *, fp_name: Optional[str] = None
+     ) -> None:
+         super().__init__()
+         self.qcfg = qcfg or PTQConfig()
+         self._mode: Mode = Mode.NO_QUANT  # default state
+         self.fp_name = fp_name
+
+     def _child_quant_modules(self):
+         """
+         Yield immediate QuantModuleBase *descendants*, skipping over pure containers
+         (e.g., ModuleList/Sequential/ModuleDict). Once a QuantModuleBase is found,
+         do NOT descend into it here—let recursion happen level by level.
+         """
+         seen = set()
+         stack = list(self.children())  # start from direct children
+
+         while stack:
+             m = stack.pop()
+             if isinstance(m, QuantModuleBase):
+                 if id(m) not in seen:
+                     seen.add(id(m))
+                     yield m
+                 # IMPORTANT: do not recurse into `m` here; its own call will handle its subtree
+             elif isinstance(m, (nn.ModuleList, nn.ModuleDict, nn.Sequential)):
+                 # `m` is a container or a non-quant leaf: keep descending until we hit quant modules
+                 stack.extend(list(m.children()))
+
+     def enable_calibration(self) -> None:
+         self._mode = Mode.CALIB
+         for obs in self._all_observers():
+             obs.enabled = True
+             obs.reset()
+
+         # propagate to children
+         for child in self._child_quant_modules():
+             child.enable_calibration()
+
+     def freeze_qparams(self) -> None:
+         self._mode = Mode.QUANT
+         for obs in self._all_observers():
+             obs.enabled = False
+             obs.compute_qparams()
+
+         # propagate to children
+         for child in self._child_quant_modules():
+             child.freeze_qparams()
+
+     def _fq(self, x, obs: ObserverBase):
+         """Fake-quant or collect."""
+         if self._mode is Mode.CALIB:
+             obs.collect(x.detach())
+             return x
+         if self._mode is Mode.QUANT:
+             return obs.fake_quant(x)
+         return x  # NO_QUANT
+
+     @abstractmethod
+     def _all_observers(self) -> Iterable[ObserverBase]:
+         """Return every observer owned by this module."""
+         ...
+
+     def named_observers(self) -> Iterable[Tuple[str, ObserverBase]]:
+         for obs in self._all_observers():
+             yield obs.name, obs
+
+     def get_observer(self, name: str) -> Optional[ObserverBase]:
+         for obs in self._all_observers():
+             if obs.name == name:
+                 return obs
+         return None
+
+     def _make_obs(
+         self,
+         name: str,
+         **default_kwargs,
+     ) -> ObserverBase:
+         """
+         Instantiate an observer named *name*.
+
+         Precedence (3-tier) for keys:
+           • observer: user > wrapper-default > PTQConfig.default_observer
+           • dtype: user > wrapper-default > PTQConfig.default_dtype
+           • qscheme: user > wrapper-default > PTQConfig.default_qscheme
+
+         Other kwargs (e.g., qscheme, channel_axis, etc.) remain:
+           user override > wrapper-default
+         """
+         _UNSPEC = object()
+
+         wrapper_defaults = default_kwargs.copy()
+         user_cfg = self.qcfg.get_kwargs(name).copy()
+
+         def pick3(user_val, wrap_val, global_val):
+             return (
+                 user_val
+                 if user_val is not _UNSPEC
+                 else wrap_val
+                 if wrap_val is not _UNSPEC
+                 else global_val
+             )
+
+         # 1) resolve observer class
+         user_observer = user_cfg.pop("observer", _UNSPEC)
+         wrapper_observer = wrapper_defaults.pop("observer", _UNSPEC)
+         obs_cls = pick3(user_observer, wrapper_observer, self.qcfg.default_observer)
+
+         # 2) resolve dtype
+         user_dtype = user_cfg.pop("dtype", _UNSPEC)
+         wrapper_dtype = wrapper_defaults.pop("dtype", _UNSPEC)
+         final_dtype = pick3(user_dtype, wrapper_dtype, self.qcfg.default_dtype)
+
+         # 3) resolve qscheme
+         user_qscheme = user_cfg.pop("qscheme", _UNSPEC)
+         wrapper_qscheme = wrapper_defaults.pop("qscheme", _UNSPEC)
+         final_qscheme = pick3(user_qscheme, wrapper_qscheme, self.qcfg.default_qscheme)
+
+         # 4) merge remaining kwargs: user_cfg wins
+         final_kw = wrapper_defaults
+         final_kw.update(user_cfg)
+         final_kw["dtype"] = final_dtype
+         final_kw["qscheme"] = final_qscheme
+
+         return obs_cls(**final_kw, name=name)
+
+     # nice repr
+     def extra_repr(self) -> str:
+         return f"mode={self._mode.name.lower()}"
tico/quantization/wrapq/wrappers/registry.py
@@ -0,0 +1,125 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import importlib
+ from typing import Callable, Dict, Type
+
+ import torch.nn as nn
+
+ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+
+ _WRAPPERS: Dict[Type[nn.Module], Type[QuantModuleBase]] = {}
+ _IMPORT_ONCE = False
+ _CORE_MODULES = (
+     "tico.quantization.wrapq.wrappers.quant_elementwise",
+     "tico.quantization.wrapq.wrappers.nn.quant_layernorm",
+     "tico.quantization.wrapq.wrappers.nn.quant_linear",
+     "tico.quantization.wrapq.wrappers.nn.quant_silu",
+     # llama
+     "tico.quantization.wrapq.wrappers.llama.quant_attn",
+     "tico.quantization.wrapq.wrappers.llama.quant_decoder_layer",
+     "tico.quantization.wrapq.wrappers.llama.quant_mlp",
+     # fairseq
+     "tico.quantization.wrapq.wrappers.fairseq.quant_decoder_layer",
+     "tico.quantization.wrapq.wrappers.fairseq.quant_encoder",
+     "tico.quantization.wrapq.wrappers.fairseq.quant_encoder_layer",
+     "tico.quantization.wrapq.wrappers.fairseq.quant_mha",
+     # add future core wrappers here
+ )
+
+
+ def _lazy_init():
+     """
+     Deferred one-shot import of "core wrapper modules".
+
+     Why not import everything when the program first starts?
+     --------------------------------------------------------
+     * **Avoid circular-import hell**
+       Core wrappers often import `PTQWrapper`, which in turn calls
+       `registry.lookup()`. Importing those files eagerly here would create a
+       cycle (`registry → wrapper → registry`). Delaying the import until the
+       *first* `lookup()` call lets Python finish constructing the registry
+       module before any wrapper files are touched.
+
+     * **Cold-start speed**
+       Most user code never wraps layers explicitly; they only hit
+       `PTQWrapper` if they are doing quantization. Deferring half-a-dozen
+       heavyweight `import torch …` files until they are really needed
+       reduces library start-up latency in the common path.
+
+     * **Optional dependencies**
+       Core wrappers listed in `_CORE_MODULES` are chosen to be dependency-free
+       (pure PyTorch). Anything that needs `transformers`, `torchvision`,
+       etc. uses the `@try_register()` decorator inside its own module. Those
+       optional modules are *not* imported here, so users without the extra
+       packages still get a clean import.
+
+     Implementation notes
+     --------------------
+     * `_IMPORT_ONCE` guard ensures we execute the import loop only once,
+       even if `lookup()` is called from multiple threads.
+     * Each path in `_CORE_MODULES` is a "fully-qualified module string"
+       (e.g. "ptq.wrappers.linear_quant"). Importing the module runs all
+       its `@register(nn.Layer)` decorators, populating `_WRAPPERS`.
+     * After the first call the function becomes a cheap constant-time no-op.
+     """
+     global _IMPORT_ONCE
+     if _IMPORT_ONCE:
+         return
+     for mod in _CORE_MODULES:
+         __import__(mod)  # triggers decorators
+     _IMPORT_ONCE = True
+
+
+ # ───────────────────────────── decorator for always-present classes
+ def register(
+     fp_cls: Type[nn.Module],
+ ) -> Callable[[Type[QuantModuleBase]], Type[QuantModuleBase]]:
+     def _decorator(quant_cls: Type[QuantModuleBase]):
+         _WRAPPERS[fp_cls] = quant_cls
+         return quant_cls
+
+     return _decorator
+
+
+ # ───────────────────────────── conditional decorator
+ def try_register(
+     *paths: str,
+ ) -> Callable[[Type[QuantModuleBase]], Type[QuantModuleBase]]:
+     """
+     @try_register("transformers.models.llama.modeling_llama.LlamaMLP")
+
+     • If import succeeds → behave like `@register`
+     • If module/class not found → become a NO-OP
+     """
+
+     def _decorator(quant_cls: Type[QuantModuleBase]):
+         for path in paths:
+             module_name, _, cls_name = path.rpartition(".")
+             try:
+                 mod = importlib.import_module(module_name)
+                 fp_cls = getattr(mod, cls_name)
+                 _WRAPPERS[fp_cls] = quant_cls
+             except (ModuleNotFoundError, AttributeError):
+                 # optional dep missing or class renamed – skip silently
+                 pass
+         return quant_cls
+
+     return _decorator
+
+
+ # ───────────────────────────── lookup
+ def lookup(fp_cls: Type[nn.Module]) -> Type[QuantModuleBase] | None:
+     _lazy_init()
+     return _WRAPPERS.get(fp_cls)
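On the consumer side, `lookup()` is the single entry point: the first call runs `_lazy_init()`, which imports every module in `_CORE_MODULES` and lets their `@register` decorators populate `_WRAPPERS`. A hedged sketch of how a caller might swap modules for their quantized wrappers (the helper function below is hypothetical):

```python
# Hypothetical helper showing how the registry above is consulted.
import torch.nn as nn

from tico.quantization.wrapq.wrappers import registry


def wrap_if_supported(module: nn.Module, qcfg=None) -> nn.Module:
    quant_cls = registry.lookup(type(module))  # first call triggers _lazy_init()
    if quant_cls is None:
        return module                          # no wrapper registered: keep as-is
    return quant_cls(module, qcfg=qcfg)


wrapped = wrap_if_supported(nn.Sigmoid())      # -> QuantSigmoid instance

# Optional-dependency wrappers register themselves in their own modules with, e.g.:
#   @try_register("transformers.models.llama.modeling_llama.LlamaMLP")
# so a missing `transformers` install simply leaves that entry unregistered.
```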
tico/serialize/circle_serializer.py
@@ -20,6 +20,7 @@ import torch
  from circle_schema import circle
  from torch.export.exported_program import ConstantArgument, ExportedProgram, InputKind
 
+ from tico.config import CompileConfigBase, get_default_config
  from tico.serialize.circle_mapping import to_circle_dtype, to_circle_shape
  from tico.serialize.operators import *
  from tico.serialize.circle_graph import CircleModel, CircleSubgraph
@@ -47,7 +48,9 @@ def _initialize_model() -> tuple[CircleModel, CircleSubgraph]:
      return model, graph
 
 
- def build_circle(ep: ExportedProgram) -> bytes:
+ def build_circle(
+     ep: ExportedProgram, config: CompileConfigBase = get_default_config()
+ ) -> bytes:
      """Convert ExportedProgram to Circle format.
 
      Args:
@@ -68,9 +71,13 @@ def build_circle(ep: ExportedProgram) -> bytes:
      for in_spec in ep.graph_signature.input_specs:
          if in_spec.kind != InputKind.USER_INPUT:
              continue
-         # NoneType ConstantArgument is ignored.
-         if isinstance(in_spec.arg, ConstantArgument) and in_spec.arg.value == None:
-             continue
+         if isinstance(in_spec.arg, ConstantArgument):
+             # ConstantArgument is ignored when option is given
+             if config.get("remove_constant_input"):
+                 continue
+             # NoneType ConstantArgument is ignored.
+             if in_spec.arg.value == None:
+                 continue
          arg_name = in_spec.arg.name
          graph.add_input(arg_name)
          logger.debug(f"Registered input: {arg_name}")
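The new `config` parameter makes input registration configurable: with the default config the old behavior is preserved (only `None`-valued `ConstantArgument`s are dropped), while a config whose `remove_constant_input` option is truthy drops every `ConstantArgument` user input. A hedged usage sketch (the model is a stand-in, and how the flag is toggled on the config object is not shown in this diff):

```python
# Hedged sketch of the new build_circle(ep, config) call shape.
import torch
import torch.nn as nn

from tico.config import get_default_config
from tico.serialize.circle_serializer import build_circle


class Tiny(nn.Module):
    def forward(self, x):
        return torch.relu(x)


ep = torch.export.export(Tiny().eval(), (torch.randn(1, 4),))

circle_bytes = build_circle(ep)      # default config: only None constants are dropped

cfg = get_default_config()           # a CompileConfigBase instance
# If cfg.get("remove_constant_input") returns True, ConstantArgument user inputs
# are skipped entirely during graph-input registration (see the hunk above).
circle_bytes = build_circle(ep, cfg)
```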
tico/serialize/operators/adapters/__init__.py
@@ -0,0 +1 @@
+ # DO NOT REMOVE THIS FILE
tico/serialize/operators/adapters/llama_rmsnorm.py
@@ -0,0 +1,35 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from contextlib import contextmanager
+
+ import torch
+
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
+
+
+ def llama_rmsnorm_forward_adapter(self: LlamaRMSNorm, hidden_states: torch.Tensor):
+     return torch.ops.circle_custom.rms_norm(
+         hidden_states, self.weight, self.variance_epsilon
+     )
+
+
+ @contextmanager
+ def patched_llama_rmsnorm():
+     orig = LlamaRMSNorm.forward
+     LlamaRMSNorm.forward = llama_rmsnorm_forward_adapter
+     try:
+         yield
+     finally:
+         LlamaRMSNorm.forward = orig
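The adapter monkey-patches `LlamaRMSNorm.forward` so that, for the duration of the context, RMSNorm lowers to the `circle_custom.rms_norm` op (serialized by the new `op_rmsnorm.py` visitor) instead of being decomposed into elementwise arithmetic. A hedged sketch of the intended use (requires `transformers`; the sizes are placeholders, and it assumes tico has already registered the custom op):

```python
# Hedged sketch: exporting an RMSNorm with the context manager added above.
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm

from tico.serialize.operators.adapters.llama_rmsnorm import patched_llama_rmsnorm

norm = LlamaRMSNorm(hidden_size=64)

with patched_llama_rmsnorm():
    # Inside the context, forward() calls torch.ops.circle_custom.rms_norm,
    # so the exported graph carries the custom op.
    ep = torch.export.export(norm.eval(), (torch.randn(1, 8, 64),))

# On exit the original LlamaRMSNorm.forward is restored.
```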
tico/serialize/operators/op_constant_pad_nd.py
@@ -28,6 +28,42 @@ from tico.utils.errors import InvalidArgumentError
  from tico.utils.validate_args_kwargs import ConstantPadNdArgs
 
 
+ def convert_to_circle_padding(pad, input_shape_len):
+     MAX_RANK = 4
+
+     if not (1 <= input_shape_len <= MAX_RANK):
+         raise InvalidArgumentError(
+             f"Input rank must be between 1 and {MAX_RANK}, got {input_shape_len}"
+         )
+
+     if len(pad) % 2 != 0 or len(pad) < 2 or len(pad) > 8:
+         raise InvalidArgumentError(
+             f"Pad length must be an even number between 2 and 8, got {len(pad)}"
+         )
+
+     if len(pad) == 2:
+         padding = [[pad[0], pad[1]]]
+     elif len(pad) == 4:
+         padding = [[pad[2], pad[3]], [pad[0], pad[1]]]
+     elif len(pad) == 6:
+         padding = [[pad[4], pad[5]], [pad[2], pad[3]], [pad[0], pad[1]]]
+     elif len(pad) == 8:
+         padding = [
+             [pad[6], pad[7]],
+             [pad[4], pad[5]],
+             [pad[2], pad[3]],
+             [pad[0], pad[1]],
+         ]
+     else:
+         assert False, "Cannot reach here"
+
+     # Fill [0, 0] padding for the rest of dimension
+     while len(padding) < input_shape_len:
+         padding.insert(0, [0, 0])
+
+     return padding
+
+
  @register_node_visitor
  class ConstantPadNdVisitor(NodeVisitor):
      target: List[torch._ops.OpOverload] = [torch.ops.aten.constant_pad_nd.default]
@@ -45,19 +81,13 @@ class ConstantPadNdVisitor(NodeVisitor):
          val = args.value
 
          if val != 0:
-             raise InvalidArgumentError("Only support 0 value padding.")
+             raise InvalidArgumentError(f"Only support 0 value padding. pad:{pad}")
 
          input_shape_len = len(extract_shape(input_))
-         padding_size = [[pad[2], pad[3]], [pad[0], pad[1]]]
-         if input_shape_len == 3:
-             padding_size = [[0, 0]] + padding_size
-         elif input_shape_len == 4:
-             padding_size = [[0, 0], [0, 0]] + padding_size
-         else:
-             raise InvalidArgumentError("Only support 3D/4D inputs.")
-
-         paddings = torch.tensor(padding_size, dtype=torch.int32)
-         inputs = [input_, paddings]
+
+         padding = convert_to_circle_padding(pad, input_shape_len)
+
+         inputs = [input_, torch.tensor(padding, dtype=torch.int32)]
          outputs = [node]
 
          op_index = get_op_index(
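For reference, `aten.constant_pad_nd` orders its `pad` list from the last dimension outward, while Circle's PAD expects one `[before, after]` pair per dimension in dimension order; the helper above reverses the pairs and left-fills `[0, 0]` for unpadded leading dimensions. A worked example of that mapping:

```python
# Worked examples for convert_to_circle_padding (added above).
# pad = [left, right, top, bottom] pads the last two dimensions of the input.
convert_to_circle_padding([1, 2, 3, 4], input_shape_len=4)
# -> [[0, 0], [0, 0], [3, 4], [1, 2]]

convert_to_circle_padding([1, 2], input_shape_len=3)
# -> [[0, 0], [0, 0], [1, 2]]
```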
tico/serialize/operators/op_le.py
@@ -0,0 +1,54 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict, List, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import torch._ops
+     import torch.fx
+ import torch
+ from circle_schema import circle
+
+ from tico.serialize.circle_graph import CircleSubgraph
+ from tico.serialize.operators.hashable_opcode import OpCode
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+ from tico.utils.validate_args_kwargs import LeArgs
+
+
+ @register_node_visitor
+ class LeVisitor(NodeVisitor):
+     target: List[torch._ops.OpOverload] = [
+         torch.ops.aten.le.Scalar,
+         torch.ops.aten.le.Tensor,
+     ]
+
+     def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+         super().__init__(op_codes, graph)
+
+     def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
+         args = LeArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+         input = args.input
+         other = args.other
+
+         op_index = get_op_index(
+             circle.BuiltinOperator.BuiltinOperator.LESS_EQUAL, self._op_codes
+         )
+
+         inputs = [input, other]
+         outputs = [node]
+
+         operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+         return operator
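With this visitor registered, `aten.le.Scalar` and `aten.le.Tensor` both lower to Circle's LESS_EQUAL builtin. A hedged end-to-end sketch (the `tico.convert` call shape is an assumption based on tico's public API, not something shown in this diff):

```python
# Hedged sketch: a comparison module that now serializes through LeVisitor.
import torch
import torch.nn as nn

import tico


class Mask(nn.Module):
    def forward(self, x):
        return x.le(0.5)  # aten.le.Scalar -> Circle LESS_EQUAL


cm = tico.convert(Mask().eval(), (torch.randn(2, 3),))
cm.save("mask.circle")
```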