tico 0.1.0.dev250904__py3-none-any.whl → 0.1.0.dev251109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (133)
  1. tico/__init__.py +1 -1
  2. tico/config/v1.py +5 -0
  3. tico/passes/cast_mixed_type_args.py +2 -0
  4. tico/passes/convert_expand_to_slice_cat.py +153 -0
  5. tico/passes/convert_matmul_to_linear.py +312 -0
  6. tico/passes/convert_to_relu6.py +1 -1
  7. tico/passes/decompose_fake_quantize_tensor_qparams.py +4 -3
  8. tico/passes/ops.py +0 -1
  9. tico/passes/remove_redundant_expand.py +3 -1
  10. tico/quantization/__init__.py +6 -0
  11. tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +161 -0
  12. tico/quantization/algorithm/fpi_gptq/quantizer.py +179 -0
  13. tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
  14. tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +14 -6
  15. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
  16. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
  17. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
  18. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
  19. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
  20. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
  21. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
  22. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
  23. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
  24. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
  25. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
  26. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
  27. tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
  28. tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
  29. tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
  30. tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
  31. tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
  32. tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
  33. tico/quantization/config/base.py +26 -0
  34. tico/quantization/config/fpi_gptq.py +29 -0
  35. tico/quantization/config/gptq.py +29 -0
  36. tico/quantization/config/pt2e.py +25 -0
  37. tico/{experimental/quantization/ptq/quant_config.py → quantization/config/ptq.py} +18 -10
  38. tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -37
  39. tico/{experimental/quantization → quantization}/evaluation/evaluate.py +6 -12
  40. tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
  41. tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
  42. tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
  43. tico/{experimental/quantization → quantization}/public_interface.py +11 -18
  44. tico/{experimental/quantization → quantization}/quantizer.py +1 -1
  45. tico/quantization/quantizer_registry.py +73 -0
  46. tico/quantization/wrapq/examples/compare_ppl.py +230 -0
  47. tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
  48. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_linear.py +11 -10
  49. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_attn.py +10 -12
  50. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_decoder_layer.py +10 -9
  51. tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_mlp.py +13 -13
  52. tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
  53. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/affine_base.py +3 -3
  54. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/base.py +2 -2
  55. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/ema.py +2 -2
  56. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/identity.py +1 -1
  57. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/minmax.py +2 -2
  58. tico/{experimental/quantization/ptq → quantization/wrapq}/observers/mx.py +1 -1
  59. tico/quantization/wrapq/quantizer.py +179 -0
  60. tico/{experimental/quantization/ptq → quantization/wrapq}/utils/introspection.py +3 -5
  61. tico/{experimental/quantization/ptq → quantization/wrapq}/utils/metrics.py +3 -2
  62. tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
  63. tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
  64. tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
  65. tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
  66. tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
  67. tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
  68. tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
  69. tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
  70. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_attn.py +58 -21
  71. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_decoder_layer.py +21 -13
  72. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_mlp.py +5 -7
  73. tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
  74. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_layernorm.py +6 -7
  75. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_linear.py +7 -8
  76. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_silu.py +8 -9
  77. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/ptq_wrapper.py +4 -6
  78. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_elementwise.py +55 -17
  79. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_module_base.py +10 -9
  80. tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/registry.py +17 -10
  81. tico/serialize/circle_serializer.py +11 -4
  82. tico/serialize/operators/op_constant_pad_nd.py +41 -11
  83. tico/serialize/operators/op_le.py +54 -0
  84. tico/serialize/operators/op_mm.py +15 -132
  85. tico/utils/convert.py +20 -15
  86. tico/utils/register_custom_op.py +6 -4
  87. tico/utils/signature.py +7 -8
  88. tico/utils/validate_args_kwargs.py +12 -0
  89. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/METADATA +48 -2
  90. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/RECORD +128 -108
  91. tico/experimental/quantization/__init__.py +0 -6
  92. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
  93. tico/experimental/quantization/ptq/examples/compare_ppl.py +0 -121
  94. tico/experimental/quantization/ptq/examples/debug_quant_outputs.py +0 -129
  95. tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +0 -165
  96. /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
  97. /tico/{experimental/quantization/algorithm/gptq → quantization/algorithm/fpi_gptq}/__init__.py +0 -0
  98. /tico/{experimental/quantization/algorithm/pt2e → quantization/algorithm/gptq}/__init__.py +0 -0
  99. /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
  100. /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
  101. /tico/{experimental/quantization/algorithm/pt2e/annotation → quantization/algorithm/pt2e}/__init__.py +0 -0
  102. /tico/{experimental/quantization/algorithm/pt2e/transformation → quantization/algorithm/pt2e/annotation}/__init__.py +0 -0
  103. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
  104. /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
  105. /tico/{experimental/quantization/algorithm/smoothquant → quantization/algorithm/pt2e/transformation}/__init__.py +0 -0
  106. /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
  107. /tico/{experimental/quantization/evaluation → quantization/algorithm/smoothquant}/__init__.py +0 -0
  108. /tico/{experimental/quantization/evaluation/executor → quantization/config}/__init__.py +0 -0
  109. /tico/{experimental/quantization/passes → quantization/evaluation}/__init__.py +0 -0
  110. /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
  111. /tico/{experimental/quantization/ptq → quantization/evaluation/executor}/__init__.py +0 -0
  112. /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
  113. /tico/{experimental/quantization → quantization}/evaluation/metric.py +0 -0
  114. /tico/{experimental/quantization/ptq/examples → quantization/passes}/__init__.py +0 -0
  115. /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
  116. /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
  117. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
  118. /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
  119. /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
  120. /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
  121. /tico/{experimental/quantization/ptq/observers → quantization/wrapq}/__init__.py +0 -0
  122. /tico/{experimental/quantization/ptq → quantization/wrapq}/dtypes.py +0 -0
  123. /tico/{experimental/quantization/ptq/utils → quantization/wrapq/examples}/__init__.py +0 -0
  124. /tico/{experimental/quantization/ptq → quantization/wrapq}/mode.py +0 -0
  125. /tico/{experimental/quantization/ptq/wrappers → quantization/wrapq/observers}/__init__.py +0 -0
  126. /tico/{experimental/quantization/ptq → quantization/wrapq}/qscheme.py +0 -0
  127. /tico/{experimental/quantization/ptq/wrappers/llama → quantization/wrapq/utils}/__init__.py +0 -0
  128. /tico/{experimental/quantization/ptq → quantization/wrapq}/utils/reduce_utils.py +0 -0
  129. /tico/{experimental/quantization/ptq/wrappers/nn → quantization/wrapq/wrappers}/__init__.py +0 -0
  130. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/LICENSE +0 -0
  131. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/WHEEL +0 -0
  132. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/entry_points.txt +0 -0
  133. {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/top_level.txt +0 -0
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_silu.py CHANGED
@@ -17,18 +17,17 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
-@register(nn.SiLU)
+@try_register("torch.nn.SiLU", "transformers.activations.SiLUActivation")
 class QuantSiLU(QuantModuleBase):
     """
-    QuantSiLU — drop-in replacement for nn.SiLU that quantizes
-    both intermediate tensors:
+    QuantSiLU — drop-in quantized implementation of the SiLU operation.
+
+    This module quantizes both intermediate tensors:
         • s = sigmoid(x)   (logistic)
         • y = x * s        (mul)
     """
@@ -37,7 +36,7 @@ class QuantSiLU(QuantModuleBase):
         self,
         fp: nn.SiLU,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None
     ):
         super().__init__(qcfg, fp_name=fp_name)
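Note: the class-based `register(nn.SiLU)` becomes `try_register`, which takes dotted-path strings so the wrapper can also cover classes from optional dependencies (here `transformers.activations.SiLUActivation`) without a hard import. A minimal sketch of how such a decorator can work (an illustrative assumption, not the actual tico implementation):

```python
import importlib

_WRAPPERS: dict = {}  # module class -> wrapper class (as in registry.py)

def try_register(*dotted_paths):
    """Register `wrapper_cls` for every path that resolves to an importable class."""
    def deco(wrapper_cls):
        for path in dotted_paths:
            module_name, _, attr = path.rpartition(".")
            try:
                target = getattr(importlib.import_module(module_name), attr)
            except (ImportError, AttributeError):
                continue  # optional dependency (e.g. transformers) not installed
            _WRAPPERS[target] = wrapper_cls
        return wrapper_cls
    return deco
```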
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/ptq_wrapper.py CHANGED
@@ -16,11 +16,9 @@ from typing import Optional
 
 import torch
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import lookup
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import lookup
 
 
 class PTQWrapper(QuantModuleBase):
@@ -34,7 +32,7 @@ class PTQWrapper(QuantModuleBase):
     def __init__(
         self,
        module: torch.nn.Module,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         *,
         fp_name: Optional[str] = None,
     ):
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_elementwise.py CHANGED
@@ -12,16 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable, Optional
+from typing import Any, Optional
 
 import torch
 import torch.nn as nn
 
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
-from tico.experimental.quantization.ptq.wrappers.registry import register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import register
 
 
 class QuantElementwise(QuantModuleBase):
@@ -33,7 +31,7 @@ class QuantElementwise(QuantModuleBase):
     """
 
     # subclass must set this
-    FUNC: Callable[[torch.Tensor], torch.Tensor] | None = None
+    FUNC: Any = None
 
     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)
@@ -48,7 +46,7 @@ class QuantElementwise(QuantModuleBase):
         self,
         fp_module: nn.Module,
         *,
-        qcfg: Optional[QuantConfig] = None,
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
     ):
         super().__init__(qcfg, fp_name=fp_name)
@@ -70,7 +68,7 @@ class QuantElementwise(QuantModuleBase):
 
 
 """
-Why `FUNC` is a `staticmethod`
+Q1) Why `FUNC` is a `staticmethod`
 
 - Prevents automatic binding: calling `self.FUNC(x)` will not inject `self`,
   so the callable keeps the expected signature `Tensor -> Tensor`
@@ -87,27 +85,67 @@ Why `FUNC` is a `staticmethod`
   than an `nn.Module` instance that would appear in the module tree.
 
 - Small perf/alloc win: no bound-method objects are created on each call.
+
+Q2) Why we define small Python wrappers (_relu, _tanh, etc.)
+
+- torch.relu / torch.tanh / torch.sigmoid are CPython built-ins.
+  Their type is `builtin_function_or_method`, not a Python `FunctionType`.
+  This causes `torch.export` (and FX tracing) to fail with:
+  "expected FunctionType, found builtin_function_or_method".
+
+- By defining a thin Python wrapper (e.g., `def _tanh(x): return torch.tanh(x)`),
+  we convert it into a normal Python function object (`FunctionType`),
+  which satisfies export/tracing requirements.
+
+- Functionally, this adds zero overhead and preserves semantics,
+  but makes the callable introspectable (has __code__, __name__, etc.)
+  and compatible with TorchDynamo / FX graph capture.
+
+- It also keeps FUNC pure and stateless, ensuring the elementwise op
+  is represented as `call_function(_tanh)` in the traced graph
+  rather than a bound `call_method` or module attribute access.
 """
 
-# Sigmoid
+
+def _relu(x: torch.Tensor) -> torch.Tensor:
+    return torch.relu(x)
+
+
+def _tanh(x: torch.Tensor) -> torch.Tensor:
+    return torch.tanh(x)
+
+
+def _sigmoid(x: torch.Tensor) -> torch.Tensor:
+    return torch.sigmoid(x)
+
+
+def _gelu(x: torch.Tensor) -> torch.Tensor:
+    return torch.nn.functional.gelu(x)
+
+
 @register(nn.Sigmoid)
 class QuantSigmoid(QuantElementwise):
-    FUNC = staticmethod(torch.sigmoid)
+    @staticmethod
+    def FUNC(x: torch.Tensor) -> torch.Tensor:
+        return _sigmoid(x)
 
 
-# Tanh
 @register(nn.Tanh)
 class QuantTanh(QuantElementwise):
-    FUNC = staticmethod(torch.tanh)
+    @staticmethod
+    def FUNC(x: torch.Tensor) -> torch.Tensor:
+        return _tanh(x)
 
 
-# ReLU
 @register(nn.ReLU)
 class QuantReLU(QuantElementwise):
-    FUNC = staticmethod(torch.relu)
+    @staticmethod
+    def FUNC(x: torch.Tensor) -> torch.Tensor:
        return _relu(x)
 
 
-# GELU (approximate)
 @register(nn.GELU)
 class QuantGELU(QuantElementwise):
-    FUNC = staticmethod(torch.nn.functional.gelu)
+    @staticmethod
+    def FUNC(x: torch.Tensor) -> torch.Tensor:
+        return _gelu(x)
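The Q2 rationale is easy to check: PyTorch's top-level ufuncs are C builtins, while a one-line wrapper is an ordinary Python function object. A quick verification snippet:

```python
import types

import torch

def _tanh(x: torch.Tensor) -> torch.Tensor:
    return torch.tanh(x)

print(isinstance(torch.tanh, types.BuiltinFunctionType))  # True (the type export rejects)
print(isinstance(torch.tanh, types.FunctionType))         # False (no __code__ to trace)
print(isinstance(_tanh, types.FunctionType))              # True (traceable wrapper)
```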
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_module_base.py CHANGED
@@ -17,9 +17,10 @@ from typing import Iterable, Optional, Tuple
 
 import torch.nn as nn
 
-from tico.experimental.quantization.ptq.mode import Mode
-from tico.experimental.quantization.ptq.observers.base import ObserverBase
-from tico.experimental.quantization.ptq.quant_config import QuantConfig
+from tico.quantization.config.ptq import PTQConfig
+
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.observers.base import ObserverBase
 
 
 class QuantModuleBase(nn.Module, ABC):
@@ -29,7 +30,7 @@ class QuantModuleBase(nn.Module, ABC):
     Responsibilities
     ----------------
     • Own *one* Mode enum (`NO_QUANT / CALIB / QUANT`)
-    • Own a QuantConfig describing default / per-observer dtypes
+    • Own a PTQConfig describing default / per-observer dtypes
     • Expose a canonical lifecycle:
         enable_calibration()
         freeze_qparams()
@@ -38,10 +39,10 @@ class QuantModuleBase(nn.Module, ABC):
     """
 
     def __init__(
-        self, qcfg: Optional[QuantConfig] = None, *, fp_name: Optional[str] = None
+        self, qcfg: Optional[PTQConfig] = None, *, fp_name: Optional[str] = None
     ) -> None:
         super().__init__()
-        self.qcfg = qcfg or QuantConfig()
+        self.qcfg = qcfg or PTQConfig()
         self._mode: Mode = Mode.NO_QUANT  # default state
         self.fp_name = fp_name
 
@@ -118,9 +119,9 @@ class QuantModuleBase(nn.Module, ABC):
         Instantiate an observer named *name*.
 
         Precedence (3-tier) for keys:
-          • observer: user > wrapper-default > QuantConfig.default_observer
-          • dtype:    user > wrapper-default > QuantConfig.default_dtype
-          • qscheme:  user > wrapper-default > QuantConfig.default_qscheme
+          • observer: user > wrapper-default > PTQConfig.default_observer
+          • dtype:    user > wrapper-default > PTQConfig.default_dtype
+          • qscheme:  user > wrapper-default > PTQConfig.default_qscheme
 
         Other kwargs (e.g., qscheme, channel_axis, etc.) remain:
            user override > wrapper-default
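The 3-tier precedence amounts to a first-non-None lookup. A minimal sketch of the resolution order (a hypothetical helper; the real observer-factory internals may differ):

```python
def resolve(user_value, wrapper_default, config_default):
    """Return the first non-None candidate, highest precedence first."""
    for candidate in (user_value, wrapper_default, config_default):
        if candidate is not None:
            return candidate
    return None

# e.g. the dtype for one observer: user kwarg, else the wrapper's default,
# else PTQConfig.default_dtype
```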
tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/registry.py CHANGED
@@ -17,20 +17,27 @@ from typing import Callable, Dict, Type
 
 import torch.nn as nn
 
-from tico.experimental.quantization.ptq.wrappers.quant_module_base import (
-    QuantModuleBase,
-)
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
 
 _WRAPPERS: Dict[Type[nn.Module], Type[QuantModuleBase]] = {}
 _IMPORT_ONCE = False
 _CORE_MODULES = (
-    "tico.experimental.quantization.ptq.wrappers.quant_elementwise",
-    "tico.experimental.quantization.ptq.wrappers.nn.quant_layernorm",
-    "tico.experimental.quantization.ptq.wrappers.nn.quant_linear",
-    "tico.experimental.quantization.ptq.wrappers.nn.quant_silu",
-    "tico.experimental.quantization.ptq.wrappers.llama.quant_attn",
-    "tico.experimental.quantization.ptq.wrappers.llama.quant_decoder_layer",
-    "tico.experimental.quantization.ptq.wrappers.llama.quant_mlp",
+    "tico.quantization.wrapq.wrappers.quant_elementwise",
+    ## nn ##
+    "tico.quantization.wrapq.wrappers.nn.quant_layernorm",
+    "tico.quantization.wrapq.wrappers.nn.quant_linear",
+    # This includes not only `nn.SiLU` but also `SiLUActivation` from transformers,
+    # as they are the same operation.
+    "tico.quantization.wrapq.wrappers.nn.quant_silu",
+    ## llama ##
+    "tico.quantization.wrapq.wrappers.llama.quant_attn",
+    "tico.quantization.wrapq.wrappers.llama.quant_decoder_layer",
+    "tico.quantization.wrapq.wrappers.llama.quant_mlp",
+    ## fairseq ##
+    "tico.quantization.wrapq.wrappers.fairseq.quant_decoder_layer",
+    "tico.quantization.wrapq.wrappers.fairseq.quant_encoder",
+    "tico.quantization.wrapq.wrappers.fairseq.quant_encoder_layer",
+    "tico.quantization.wrapq.wrappers.fairseq.quant_mha",
     # add future core wrappers here
 )
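For context on `_IMPORT_ONCE` / `_CORE_MODULES`: the registry fills itself lazily. Importing each core wrapper module runs its `register` / `try_register` decorators, which populate `_WRAPPERS` as a side effect. A sketch of the assumed `lookup` control flow (the real function may differ):

```python
import importlib

_WRAPPERS: dict = {}   # illustrative re-declaration of the module state above
_IMPORT_ONCE = False
_CORE_MODULES = ("tico.quantization.wrapq.wrappers.quant_elementwise",)  # abridged

def lookup(module_cls):
    global _IMPORT_ONCE
    if not _IMPORT_ONCE:
        for name in _CORE_MODULES:
            importlib.import_module(name)  # decorators populate _WRAPPERS
        _IMPORT_ONCE = True
    return _WRAPPERS.get(module_cls)  # None when no quant wrapper is registered
```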
 
tico/serialize/circle_serializer.py CHANGED
@@ -20,6 +20,7 @@ import torch
 from circle_schema import circle
 from torch.export.exported_program import ConstantArgument, ExportedProgram, InputKind
 
+from tico.config import CompileConfigBase, get_default_config
 from tico.serialize.circle_mapping import to_circle_dtype, to_circle_shape
 from tico.serialize.operators import *
 from tico.serialize.circle_graph import CircleModel, CircleSubgraph
@@ -47,7 +48,9 @@ def _initialize_model() -> tuple[CircleModel, CircleSubgraph]:
     return model, graph
 
 
-def build_circle(ep: ExportedProgram) -> bytes:
+def build_circle(
+    ep: ExportedProgram, config: CompileConfigBase = get_default_config()
+) -> bytes:
     """Convert ExportedProgram to Circle format.
 
     Args:
@@ -68,9 +71,13 @@ def build_circle(ep: ExportedProgram) -> bytes:
     for in_spec in ep.graph_signature.input_specs:
         if in_spec.kind != InputKind.USER_INPUT:
             continue
-        # NoneType ConstantArgument is ignored.
-        if isinstance(in_spec.arg, ConstantArgument) and in_spec.arg.value == None:
-            continue
+        if isinstance(in_spec.arg, ConstantArgument):
+            # ConstantArgument is ignored when the option is given
+            if config.get("remove_constant_input"):
+                continue
+            # NoneType ConstantArgument is ignored.
+            if in_spec.arg.value == None:
+                continue
         arg_name = in_spec.arg.name
         graph.add_input(arg_name)
         logger.debug(f"Registered input: {arg_name}")
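With the new parameter, serialization becomes configurable: when `remove_constant_input` is set, every `ConstantArgument` user input is dropped from the Circle graph, not only `None`-valued ones. Hedged usage sketch (the model is illustrative; only `build_circle(ep, config)` and the `remove_constant_input` key come from this diff):

```python
import torch
from torch.export import export

from tico.config import get_default_config
from tico.serialize.circle_serializer import build_circle

ep = export(torch.nn.Linear(4, 2), (torch.randn(1, 4),))

config = get_default_config()            # defaults keep non-None constant inputs
circle_bytes = build_circle(ep, config)  # bytes of the serialized Circle model
```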
tico/serialize/operators/op_constant_pad_nd.py CHANGED
@@ -28,6 +28,42 @@ from tico.utils.errors import InvalidArgumentError
 from tico.utils.validate_args_kwargs import ConstantPadNdArgs
 
 
+def convert_to_circle_padding(pad, input_shape_len):
+    MAX_RANK = 4
+
+    if not (1 <= input_shape_len <= MAX_RANK):
+        raise InvalidArgumentError(
+            f"Input rank must be between 1 and {MAX_RANK}, got {input_shape_len}"
+        )
+
+    if len(pad) % 2 != 0 or len(pad) < 2 or len(pad) > 8:
+        raise InvalidArgumentError(
+            f"Pad length must be an even number between 2 and 8, got {len(pad)}"
+        )
+
+    if len(pad) == 2:
+        padding = [[pad[0], pad[1]]]
+    elif len(pad) == 4:
+        padding = [[pad[2], pad[3]], [pad[0], pad[1]]]
+    elif len(pad) == 6:
+        padding = [[pad[4], pad[5]], [pad[2], pad[3]], [pad[0], pad[1]]]
+    elif len(pad) == 8:
+        padding = [
+            [pad[6], pad[7]],
+            [pad[4], pad[5]],
+            [pad[2], pad[3]],
+            [pad[0], pad[1]],
+        ]
+    else:
+        assert False, "Cannot reach here"
+
+    # Fill [0, 0] padding for the rest of the dimensions
+    while len(padding) < input_shape_len:
+        padding.insert(0, [0, 0])
+
+    return padding
+
+
 @register_node_visitor
 class ConstantPadNdVisitor(NodeVisitor):
     target: List[torch._ops.OpOverload] = [torch.ops.aten.constant_pad_nd.default]
@@ -45,19 +81,13 @@ class ConstantPadNdVisitor(NodeVisitor):
         val = args.value
 
         if val != 0:
-            raise InvalidArgumentError("Only support 0 value padding.")
+            raise InvalidArgumentError(f"Only support 0 value padding. pad:{pad}")
 
         input_shape_len = len(extract_shape(input_))
-        padding_size = [[pad[2], pad[3]], [pad[0], pad[1]]]
-        if input_shape_len == 3:
-            padding_size = [[0, 0]] + padding_size
-        elif input_shape_len == 4:
-            padding_size = [[0, 0], [0, 0]] + padding_size
-        else:
-            raise InvalidArgumentError("Only support 3D/4D inputs.")
-
-        paddings = torch.tensor(padding_size, dtype=torch.int32)
-        inputs = [input_, paddings]
+
+        padding = convert_to_circle_padding(pad, input_shape_len)
+
+        inputs = [input_, torch.tensor(padding, dtype=torch.int32)]
         outputs = [node]
 
         op_index = get_op_index(
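Worked example for the new `convert_to_circle_padding` helper: `aten.constant_pad_nd` lists `(before, after)` pairs starting from the last dimension, while Circle expects one `[before, after]` row per dimension, outermost first, with untouched leading dimensions zero-filled:

```python
# pad = [1, 2, 3, 4]: dim -1 padded by (1, 2), dim -2 padded by (3, 4)
convert_to_circle_padding([1, 2, 3, 4], input_shape_len=3)
# -> [[0, 0], [3, 4], [1, 2]]
#    dim 0: unpadded, dim 1: [3, 4], dim 2: [1, 2]
```

This also lifts the old 3D/4D-only restriction to any input rank from 1 to 4.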
tico/serialize/operators/op_le.py ADDED
@@ -0,0 +1,54 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+from tico.utils.validate_args_kwargs import LeArgs
+
+
+@register_node_visitor
+class LeVisitor(NodeVisitor):
+    target: List[torch._ops.OpOverload] = [
+        torch.ops.aten.le.Scalar,
+        torch.ops.aten.le.Tensor,
+    ]
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+        super().__init__(op_codes, graph)
+
+    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
+        args = LeArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input = args.input
+        other = args.other
+
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.LESS_EQUAL, self._op_codes
+        )
+
+        inputs = [input, other]
+        outputs = [node]
+
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+        return operator
tico/serialize/operators/op_mm.py CHANGED
@@ -20,7 +20,7 @@ if TYPE_CHECKING:
 import torch
 from circle_schema import circle
 
-from tico.serialize.circle_graph import CircleSubgraph, is_const
+from tico.serialize.circle_graph import CircleSubgraph
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index
@@ -28,9 +28,9 @@ from tico.utils.validate_args_kwargs import MatmulArgs
 
 
 @register_node_visitor
-class MatmulDefaultVisitor(NodeVisitor):
+class MatmulVisitor(NodeVisitor):
     """
-    Convert matmul to equavalent BatchMatMul or FullyConnected with Transpose.
+    Convert matmul to Circle BatchMatMul.
     """
 
     target: List[torch._ops.OpOverload] = [torch.ops.aten.mm.default]
@@ -38,131 +38,7 @@ class MatmulDefaultVisitor(NodeVisitor):
     def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
         super().__init__(op_codes, graph)
 
-    # NOTE: Matmul is equivalent to Batch MatMul (batch=1)
-    def define_bmm_node(self, inputs, outputs) -> circle.Operator.OperatorT:
-        def set_bmm_option(operator):
-            operator.builtinOptionsType = (
-                circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
-            )
-            option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
-            option.adjointLhs, option.adjointRhs = False, False
-            option.asymmetricQuantizeInputs = False
-            operator.builtinOptions = option
-
-        op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
-        )
-        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
-        set_bmm_option(operator)
-
-        return operator
-
-    def define_transpose_node(self, inputs, outputs) -> circle.Operator.OperatorT:
-        def set_transpose_option(operator):
-            operator.builtinOptionsType = (
-                circle.BuiltinOptions.BuiltinOptions.TransposeOptions
-            )
-            option = circle.TransposeOptions.TransposeOptionsT()
-            operator.builtinOptions = option
-
-        transpose_op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.TRANSPOSE, self._op_codes
-        )
-        operator = create_builtin_operator(
-            self.graph, transpose_op_index, inputs, outputs
-        )
-        set_transpose_option(operator)
-        return operator
-
-    def define_fc_node(self, inputs, outputs) -> circle.Operator.OperatorT:
-        def set_fc_option(operator):
-            operator.builtinOptionsType = (
-                circle.BuiltinOptions.BuiltinOptions.FullyConnectedOptions
-            )
-            option = circle.FullyConnectedOptions.FullyConnectedOptionsT()
-
-            option.fusedActivationFunction = (
-                circle.ActivationFunctionType.ActivationFunctionType.NONE
-            )
-            option.weightsFormat = (
-                circle.FullyConnectedOptionsWeightsFormat.FullyConnectedOptionsWeightsFormat.DEFAULT
-            )
-            option.keepNumDims = False
-            option.asymmetricQuantizeInputs = False
-            option.quantizedBiasType = circle.TensorType.TensorType.FLOAT32
-
-            operator.builtinOptions = option
-
-        fc_op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.FULLY_CONNECTED, self._op_codes
-        )
-        operator = create_builtin_operator(self.graph, fc_op_index, inputs, outputs)
-        set_fc_option(operator)
-        return operator
-
-    """
-    Define FullyConnnected with Tranpose operator.
-    Note that those sets of operators are equivalent.
-    (1) Matmul
-        matmul( lhs[H, K], rhs[K, W'] ) -> output(H, W')
-
-    (2) Transpose + FullyConneccted
-        transpose( rhs[K, W'] ) -> trs_output[W', K]
-        fullyconnected( lhs[H, K], trs_output[W', K] ) -> output(H, W')
-    """
-
-    def define_fc_with_transpose(
-        self, node, inputs, outputs
-    ) -> circle.Operator.OperatorT:
-        lhs, rhs = inputs
-
-        # get transpose shape
-        rhs_tid: int = self.graph.get_tid_registered(rhs)
-        rhs_tensor: circle.Tensor.TensorT = self.graph.tensors[rhs_tid]
-        rhs_name: str = rhs.name
-        rhs_type: int = rhs_tensor.type
-        rhs_shape: List[int] = rhs_tensor.shape
-        assert len(rhs_shape) == 2, len(rhs_shape)
-        rhs_shape_transpose = [rhs_shape[1], rhs_shape[0]]
-
-        # create transpose output tensor
-        trs_output = self.graph.add_tensor_from_scratch(
-            prefix=f"{rhs_name}_transposed_output",
-            shape=rhs_shape_transpose,
-            shape_signature=None,
-            dtype=rhs_type,
-            source_node=node,
-        )
-        trs_perm = self.graph.add_const_tensor(data=[1, 0], source_node=node)
-        trs_operator = self.define_transpose_node([rhs, trs_perm], [trs_output])
-        self.graph.add_operator(trs_operator)
-
-        # define fc node
-        fc_input = lhs
-        fc_weight = trs_output
-        fc_shape = [fc_weight.shape[0]]
-        fc_bias = self.graph.add_const_tensor(
-            data=[0.0] * fc_shape[0], source_node=node
-        )
-
-        operator = self.define_fc_node([fc_input, fc_weight, fc_bias], outputs)
-
-        return operator
-
-    def define_node(
-        self, node: torch.fx.Node, prior_latency=True
-    ) -> circle.Operator.OperatorT:
-        """
-        NOTE: Possibility of accuracy-latency trade-off
-        From ONE compiler's perspective:
-        - BMM uses per-tensor quantization for both rhs and lhs.
-        - FC uses per-channel quantization for weight and per-tensor for input.
-        Thus, FC is better in terms of accuracy.
-        FC necessarily involves an additional transpose operation to be identical with mm.
-        If transposed operand is const, it can be optimized by constant folding.
-        Thus, convert FC only if tranpose can be folded.
-        TODO set prior_latency outside
-        """
+    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
         args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         input = args.input
         other = args.other
@@ -170,9 +46,16 @@ class MatmulDefaultVisitor(NodeVisitor):
         inputs = [input, other]
         outputs = [node]
 
-        if not is_const(other) and prior_latency:
-            operator = self.define_bmm_node(inputs, outputs)
-        else:
-            operator = self.define_fc_with_transpose(node, inputs, outputs)
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
+        )
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+        operator.builtinOptionsType = (
+            circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
+        )
+        option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
+        option.adjointLhs, option.adjointRhs = False, False
+        option.asymmetricQuantizeInputs = False
+        operator.builtinOptions = option
 
         return operator
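The visitor is now a thin mapping from `aten.mm` onto Circle's `BATCH_MATMUL`: a 2-D matmul is exactly a batch-1 BMM. The accuracy-oriented FullyConnected-with-Transpose rewrite moved out of serialization into the new `ConvertMatmulToLinear` graph pass (file 5 above). The equivalence the mapping relies on:

```python
import torch

a, b = torch.randn(3, 4), torch.randn(4, 5)
# mm(a, b) equals bmm with a leading batch dimension of 1
assert torch.allclose(torch.mm(a, b), torch.bmm(a[None], b[None])[0])
```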
tico/utils/convert.py CHANGED
@@ -20,26 +20,14 @@ import torch
 from torch.export import export, ExportedProgram
 
 from tico.config import CompileConfigBase, get_default_config
-from tico.experimental.quantization.passes.fold_quant_ops import FoldQuantOps
-from tico.experimental.quantization.passes.insert_quantize_on_dtype_mismatch import (
-    InsertQuantizeOnDtypeMismatch,
-)
-from tico.experimental.quantization.passes.propagate_qparam_backward import (
-    PropagateQParamBackward,
-)
-from tico.experimental.quantization.passes.propagate_qparam_forward import (
-    PropagateQParamForward,
-)
-from tico.experimental.quantization.passes.quantize_bias import QuantizeBias
-from tico.experimental.quantization.passes.remove_weight_dequant_op import (
-    RemoveWeightDequantOp,
-)
 from tico.passes.cast_aten_where_arg_type import CastATenWhereArgType
 from tico.passes.cast_clamp_mixed_type_args import CastClampMixedTypeArgs
 from tico.passes.cast_mixed_type_args import CastMixedTypeArgs
 from tico.passes.const_prop_pass import ConstPropPass
 from tico.passes.convert_conv1d_to_conv2d import ConvertConv1dToConv2d
+from tico.passes.convert_expand_to_slice_cat import ConvertExpandToSliceCat
 from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape
+from tico.passes.convert_matmul_to_linear import ConvertMatmulToLinear
 from tico.passes.convert_repeat_to_expand_copy import ConvertRepeatToExpandCopy
 from tico.passes.convert_to_relu6 import ConvertToReLU6
 from tico.passes.decompose_addmm import DecomposeAddmm
@@ -72,6 +60,14 @@ from tico.passes.remove_redundant_slice import RemoveRedundantSlice
 from tico.passes.remove_redundant_to_copy import RemoveRedundantToCopy
 from tico.passes.restore_linear import RestoreLinear
 from tico.passes.segment_index_select import SegmentIndexSelectConst
+from tico.quantization.passes.fold_quant_ops import FoldQuantOps
+from tico.quantization.passes.insert_quantize_on_dtype_mismatch import (
+    InsertQuantizeOnDtypeMismatch,
+)
+from tico.quantization.passes.propagate_qparam_backward import PropagateQParamBackward
+from tico.quantization.passes.propagate_qparam_forward import PropagateQParamForward
+from tico.quantization.passes.quantize_bias import QuantizeBias
+from tico.quantization.passes.remove_weight_dequant_op import RemoveWeightDequantOp
 from tico.serialize.circle_serializer import build_circle
 from tico.serialize.operators.node_visitor import get_support_targets
 from tico.utils import logging
@@ -141,6 +137,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
         or torch.__version__.startswith("2.7")
         or torch.__version__.startswith("2.8")
         or torch.__version__.startswith("2.9")
+        or torch.__version__.startswith("2.10")
     ):
         return run_decompositions(exported_program)
     else:
@@ -249,6 +246,14 @@ def convert_exported_module_to_circle(
         ConstPropPass(),
         SegmentIndexSelectConst(),
         LegalizeCausalMaskValue(enabled=config.get("legalize_causal_mask_value")),
+        ConvertExpandToSliceCat(enabled=config.get("convert_expand_to_slice_cat")),
+        ConvertMatmulToLinear(
+            enable_lhs_const=config.get("convert_lhs_const_mm_to_fc"),
+            enable_rhs_const=config.get("convert_rhs_const_mm_to_fc"),
+            enable_single_batch_lhs_const_bmm=config.get(
+                "convert_single_batch_lhs_const_bmm_to_fc"
+            ),
+        ),
         LowerToResizeNearestNeighbor(),
         LegalizePreDefinedLayoutOperators(),
         LowerPow2ToMul(),
@@ -287,7 +292,7 @@ def convert_exported_module_to_circle(
 
     check_unsupported_target(exported_program)
     check_training_ops(exported_program)
-    circle_program = build_circle(exported_program)
+    circle_program = build_circle(exported_program, config)
 
     return circle_program
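The new rewrites are gated by compile-config keys: `convert_expand_to_slice_cat`, `convert_lhs_const_mm_to_fc`, `convert_rhs_const_mm_to_fc`, and `convert_single_batch_lhs_const_bmm_to_fc` (presumably added in `tico/config/v1.py`, file 2 above). A hedged sketch of constructing the matmul-to-linear pass standalone; the constructor kwargs come from this diff, the comments read the semantics off the flag names, and the invocation API is an assumption:

```python
from tico.passes.convert_matmul_to_linear import ConvertMatmulToLinear

rewrite = ConvertMatmulToLinear(
    enable_lhs_const=True,                    # mm with constant LHS -> FullyConnected
    enable_rhs_const=True,                    # mm with constant RHS -> FullyConnected
    enable_single_batch_lhs_const_bmm=False,  # keep batch-1 const-LHS bmm as BMM
)
# Assumed pass interface (not confirmed by this diff):
# rewrite.call(exported_program)
```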