tico 0.1.0.dev250904__py3-none-any.whl → 0.1.0.dev251109__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_mixed_type_args.py +2 -0
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_fake_quantize_tensor_qparams.py +4 -3
- tico/passes/ops.py +0 -1
- tico/passes/remove_redundant_expand.py +3 -1
- tico/quantization/__init__.py +6 -0
- tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +161 -0
- tico/quantization/algorithm/fpi_gptq/quantizer.py +179 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +24 -3
- tico/{experimental/quantization → quantization}/algorithm/gptq/quantizer.py +14 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +6 -8
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/fpi_gptq.py +29 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/{experimental/quantization/ptq/quant_config.py → quantization/config/ptq.py} +18 -10
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -37
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +6 -12
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/public_interface.py +11 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_linear.py +11 -10
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_attn.py +10 -12
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_decoder_layer.py +10 -9
- tico/{experimental/quantization/ptq → quantization/wrapq}/examples/quantize_llama_mlp.py +13 -13
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/affine_base.py +3 -3
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/base.py +2 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/ema.py +2 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/identity.py +1 -1
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/minmax.py +2 -2
- tico/{experimental/quantization/ptq → quantization/wrapq}/observers/mx.py +1 -1
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/utils/introspection.py +3 -5
- tico/{experimental/quantization/ptq → quantization/wrapq}/utils/metrics.py +3 -2
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_attn.py +58 -21
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_decoder_layer.py +21 -13
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/llama/quant_mlp.py +5 -7
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_layernorm.py +6 -7
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_linear.py +7 -8
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/nn/quant_silu.py +8 -9
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/ptq_wrapper.py +4 -6
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_elementwise.py +55 -17
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/quant_module_base.py +10 -9
- tico/{experimental/quantization/ptq → quantization/wrapq}/wrappers/registry.py +17 -10
- tico/serialize/circle_serializer.py +11 -4
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_mm.py +15 -132
- tico/utils/convert.py +20 -15
- tico/utils/register_custom_op.py +6 -4
- tico/utils/signature.py +7 -8
- tico/utils/validate_args_kwargs.py +12 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/METADATA +48 -2
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/RECORD +128 -108
- tico/experimental/quantization/__init__.py +0 -6
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/ptq/examples/compare_ppl.py +0 -121
- tico/experimental/quantization/ptq/examples/debug_quant_outputs.py +0 -129
- tico/experimental/quantization/ptq/examples/quantize_with_gptq.py +0 -165
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/gptq → quantization/algorithm/fpi_gptq}/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/pt2e → quantization/algorithm/gptq}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +0 -0
- /tico/{experimental/quantization/algorithm/pt2e/annotation → quantization/algorithm/pt2e}/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/pt2e/transformation → quantization/algorithm/pt2e/annotation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization/algorithm/smoothquant → quantization/algorithm/pt2e/transformation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/algorithm/smoothquant}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/metric.py +0 -0
- /tico/{experimental/quantization/ptq/examples → quantization/passes}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +0 -0
- /tico/{experimental/quantization/ptq/observers → quantization/wrapq}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/dtypes.py +0 -0
- /tico/{experimental/quantization/ptq/utils → quantization/wrapq/examples}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/mode.py +0 -0
- /tico/{experimental/quantization/ptq/wrappers → quantization/wrapq/observers}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/qscheme.py +0 -0
- /tico/{experimental/quantization/ptq/wrappers/llama → quantization/wrapq/utils}/__init__.py +0 -0
- /tico/{experimental/quantization/ptq → quantization/wrapq}/utils/reduce_utils.py +0 -0
- /tico/{experimental/quantization/ptq/wrappers/nn → quantization/wrapq/wrappers}/__init__.py +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250904.dist-info → tico-0.1.0.dev251109.dist-info}/top_level.txt +0 -0
tico/quantization/wrapq/wrappers/nn/quant_silu.py
CHANGED

@@ -17,18 +17,17 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from tico.…
-from tico.…
-
-)
-from tico.experimental.quantization.ptq.wrappers.registry import register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
 
 
-@…
+@try_register("torch.nn.SiLU", "transformers.activations.SiLUActivation")
 class QuantSiLU(QuantModuleBase):
     """
-    QuantSiLU — drop-in…
-
+    QuantSiLU — drop-in quantized implementation of the SiLU operation.
+
+    This module quantizes both intermediate tensors:
         • s = sigmoid(x)   (logistic)
         • y = x * s        (mul)
     """
@@ -37,7 +36,7 @@ class QuantSiLU(QuantModuleBase):
         self,
         fp: nn.SiLU,
         *,
-        qcfg: Optional[…
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None
     ):
         super().__init__(qcfg, fp_name=fp_name)
tico/quantization/wrapq/wrappers/ptq_wrapper.py
CHANGED

@@ -16,11 +16,9 @@ from typing import Optional
 
 import torch
 
-from tico.…
-from tico.…
-
-)
-from tico.experimental.quantization.ptq.wrappers.registry import lookup
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import lookup
 
 
 class PTQWrapper(QuantModuleBase):
@@ -34,7 +32,7 @@ class PTQWrapper(QuantModuleBase):
     def __init__(
         self,
         module: torch.nn.Module,
-        qcfg: Optional[…
+        qcfg: Optional[PTQConfig] = None,
         *,
         fp_name: Optional[str] = None,
     ):
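Taken together with the lifecycle documented in quant_module_base.py below (enable_calibration() then freeze_qparams()), the wrapper API reads roughly as follows. A minimal usage sketch: PTQWrapper, PTQConfig, and the two lifecycle calls are taken from this diff; the module choice and the calibration loop are illustrative.

```python
import torch
import torch.nn as nn

from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper

# Wrap an fp32 module; PTQWrapper's `lookup` presumably resolves
# nn.SiLU to the QuantSiLU wrapper shown above.
q_act = PTQWrapper(nn.SiLU(), qcfg=PTQConfig(), fp_name="model.act")

q_act.enable_calibration()        # Mode.CALIB: observers record ranges
for _ in range(8):
    q_act(torch.randn(4, 16))     # feed representative activations
q_act.freeze_qparams()            # Mode.QUANT: quantize with frozen qparams

y = q_act(torch.randn(4, 16))
```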
tico/quantization/wrapq/wrappers/quant_elementwise.py
CHANGED

@@ -12,16 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import …
+from typing import Any, Optional
 
 import torch
 import torch.nn as nn
 
-from tico.…
-from tico.…
-
-)
-from tico.experimental.quantization.ptq.wrappers.registry import register
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import register
 
 
 class QuantElementwise(QuantModuleBase):
@@ -33,7 +31,7 @@ class QuantElementwise(QuantModuleBase):
     """
 
     # subclass must set this
-    FUNC: …
+    FUNC: Any = None
 
     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)
@@ -48,7 +46,7 @@ class QuantElementwise(QuantModuleBase):
         self,
         fp_module: nn.Module,
         *,
-        qcfg: Optional[…
+        qcfg: Optional[PTQConfig] = None,
         fp_name: Optional[str] = None,
     ):
         super().__init__(qcfg, fp_name=fp_name)
@@ -70,7 +68,7 @@ class QuantElementwise(QuantModuleBase):
 
 
 """
-Why `FUNC` is a `staticmethod`
+Q1) Why `FUNC` is a `staticmethod`
 
 - Prevents automatic binding: calling `self.FUNC(x)` will not inject `self`,
   so the callable keeps the expected signature `Tensor -> Tensor`
@@ -87,27 +85,67 @@ Why `FUNC` is a `staticmethod`
   than an `nn.Module` instance that would appear in the module tree.
 
 - Small perf/alloc win: no bound-method objects are created on each call.
+
+Q2) Why we define small Python wrappers (_relu, _tanh, etc.)
+
+- torch.relu / torch.tanh / torch.sigmoid are CPython built-ins.
+  Their type is `builtin_function_or_method`, not a Python `FunctionType`.
+  This causes `torch.export` (and FX tracing) to fail with:
+      "expected FunctionType, found builtin_function_or_method".
+
+- By defining a thin Python wrapper (e.g., `def _tanh(x): return torch.tanh(x)`),
+  we convert it into a normal Python function object (`FunctionType`),
+  which satisfies export/tracing requirements.
+
+- Functionally, this adds zero overhead and preserves semantics,
+  but makes the callable introspectable (has __code__, __name__, etc.)
+  and compatible with TorchDynamo / FX graph capture.
+
+- It also keeps FUNC pure and stateless, ensuring the elementwise op
+  is represented as `call_function(_tanh)` in the traced graph
+  rather than a bound `call_method` or module attribute access.
 """
 
-
+
+def _relu(x: torch.Tensor) -> torch.Tensor:
+    return torch.relu(x)
+
+
+def _tanh(x: torch.Tensor) -> torch.Tensor:
+    return torch.tanh(x)
+
+
+def _sigmoid(x: torch.Tensor) -> torch.Tensor:
+    return torch.sigmoid(x)
+
+
+def _gelu(x: torch.Tensor) -> torch.Tensor:
+    return torch.nn.functional.gelu(x)
+
+
 @register(nn.Sigmoid)
 class QuantSigmoid(QuantElementwise):
-    …
+    @staticmethod
+    def FUNC(x: torch.Tensor) -> torch.Tensor:
+        return _sigmoid(x)
 
 
-# Tanh
 @register(nn.Tanh)
 class QuantTanh(QuantElementwise):
-    …
+    @staticmethod
+    def FUNC(x: torch.Tensor) -> torch.Tensor:
+        return _tanh(x)
 
 
-# ReLU
 @register(nn.ReLU)
 class QuantReLU(QuantElementwise):
-    …
+    @staticmethod
+    def FUNC(x: torch.Tensor) -> torch.Tensor:
+        return _relu(x)
 
 
-# GELU (approximate)
 @register(nn.GELU)
 class QuantGELU(QuantElementwise):
-    …
+    @staticmethod
+    def FUNC(x: torch.Tensor) -> torch.Tensor:
+        return _gelu(x)
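The Q2 rationale above is easy to verify with plain PyTorch; nothing below is tico-specific:

```python
import types

import torch

def _tanh(x: torch.Tensor) -> torch.Tensor:
    return torch.tanh(x)

# The builtin is implemented in C, so it is not a Python function object...
print(type(torch.tanh).__name__)                   # builtin_function_or_method
print(isinstance(torch.tanh, types.FunctionType))  # False
# ...while the one-line wrapper is a regular FunctionType with __code__,
# __name__, etc., which is what export/tracing front ends expect.
print(isinstance(_tanh, types.FunctionType))       # True
print(_tanh.__name__, _tanh.__code__ is not None)  # _tanh True
```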
tico/quantization/wrapq/wrappers/quant_module_base.py
CHANGED

@@ -17,9 +17,10 @@ from typing import Iterable, Optional, Tuple
 
 import torch.nn as nn
 
-from tico.…
-
-from tico.…
+from tico.quantization.config.ptq import PTQConfig
+
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.observers.base import ObserverBase
 
 
 class QuantModuleBase(nn.Module, ABC):
@@ -29,7 +30,7 @@ class QuantModuleBase(nn.Module, ABC):
     Responsibilities
     ----------------
     • Own *one* Mode enum (`NO_QUANT / CALIB / QUANT`)
-    • Own a …
+    • Own a PTQConfig describing default / per-observer dtypes
     • Expose a canonical lifecycle:
         enable_calibration()
         freeze_qparams()
@@ -38,10 +39,10 @@ class QuantModuleBase(nn.Module, ABC):
     """
 
     def __init__(
-        self, qcfg: Optional[…
+        self, qcfg: Optional[PTQConfig] = None, *, fp_name: Optional[str] = None
    ) -> None:
         super().__init__()
-        self.qcfg = qcfg or …
+        self.qcfg = qcfg or PTQConfig()
         self._mode: Mode = Mode.NO_QUANT  # default state
         self.fp_name = fp_name
 
@@ -118,9 +119,9 @@ class QuantModuleBase(nn.Module, ABC):
        Instantiate an observer named *name*.
 
        Precedence (3-tier) for keys:
-          • observer: user > wrapper-default > …
-          • dtype: user > wrapper-default > …
-          • qscheme: user > wrapper-default > …
+          • observer: user > wrapper-default > PTQConfig.default_observer
+          • dtype: user > wrapper-default > PTQConfig.default_dtype
+          • qscheme: user > wrapper-default > PTQConfig.default_qscheme
 
        Other kwargs (e.g., qscheme, channel_axis, etc.) remain:
            user override > wrapper-default
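The 3-tier precedence amounts to a chained lookup; an illustrative helper (not tico code) that mirrors the rule above:

```python
from typing import Any, Mapping

def resolve(key: str,
            user: Mapping[str, Any],
            wrapper_default: Mapping[str, Any],
            qcfg_default: Any) -> Any:
    """user override > wrapper default > PTQConfig default."""
    if key in user:
        return user[key]
    if key in wrapper_default:
        return wrapper_default[key]
    return qcfg_default

# e.g. the dtype tier: the user's choice wins over the wrapper's default.
assert resolve("dtype", {"dtype": "int16"}, {"dtype": "uint8"}, "uint8") == "int16"
```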
tico/quantization/wrapq/wrappers/registry.py
CHANGED

@@ -17,20 +17,27 @@ from typing import Callable, Dict, Type
 
 import torch.nn as nn
 
-from tico.…
-    QuantModuleBase,
-)
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
 
 _WRAPPERS: Dict[Type[nn.Module], Type[QuantModuleBase]] = {}
 _IMPORT_ONCE = False
 _CORE_MODULES = (
-    "tico.…
-
-    "tico.…
-    "tico.…
-
-
-    "tico.…
+    "tico.quantization.wrapq.wrappers.quant_elementwise",
+    ## nn ##
+    "tico.quantization.wrapq.wrappers.nn.quant_layernorm",
+    "tico.quantization.wrapq.wrappers.nn.quant_linear",
+    # This includes not only `nn.SiLU` but also `SiLUActivation` from transformers
+    # as they are same operation.
+    "tico.quantization.wrapq.wrappers.nn.quant_silu",
+    ## llama ##
+    "tico.quantization.wrapq.wrappers.llama.quant_attn",
+    "tico.quantization.wrapq.wrappers.llama.quant_decoder_layer",
+    "tico.quantization.wrapq.wrappers.llama.quant_mlp",
+    ## fairseq ##
+    "tico.quantization.wrapq.wrappers.fairseq.quant_decoder_layer",
+    "tico.quantization.wrapq.wrappers.fairseq.quant_encoder",
+    "tico.quantization.wrapq.wrappers.fairseq.quant_encoder_layer",
+    "tico.quantization.wrapq.wrappers.fairseq.quant_mha",
     # add future core wrappers here
 )
 
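quant_silu.py above switches from `register` to the new `try_register`, which takes dotted class paths instead of class objects, so one wrapper can also cover classes from optional dependencies (transformers, fairseq) without importing them unconditionally. A self-contained sketch of that pattern, not the actual tico implementation:

```python
import importlib
from typing import Dict, Type

import torch.nn as nn

_WRAPPERS: Dict[type, type] = {}

def register(fp_cls: Type[nn.Module]):
    """Map a float module class to its quant wrapper class."""
    def deco(wrapper_cls: type) -> type:
        _WRAPPERS[fp_cls] = wrapper_cls
        return wrapper_cls
    return deco

def try_register(*dotted_paths: str):
    """Like `register`, but resolves each class by import path and
    silently skips paths whose package is not installed."""
    def deco(wrapper_cls: type) -> type:
        for path in dotted_paths:
            module_name, _, cls_name = path.rpartition(".")
            try:
                fp_cls = getattr(importlib.import_module(module_name), cls_name)
            except (ImportError, AttributeError):
                continue  # optional dependency absent; skip this mapping
            _WRAPPERS[fp_cls] = wrapper_cls
        return wrapper_cls
    return deco
```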
tico/serialize/circle_serializer.py
CHANGED

@@ -20,6 +20,7 @@ import torch
 from circle_schema import circle
 from torch.export.exported_program import ConstantArgument, ExportedProgram, InputKind
 
+from tico.config import CompileConfigBase, get_default_config
 from tico.serialize.circle_mapping import to_circle_dtype, to_circle_shape
 from tico.serialize.operators import *
 from tico.serialize.circle_graph import CircleModel, CircleSubgraph
@@ -47,7 +48,9 @@ def _initialize_model() -> tuple[CircleModel, CircleSubgraph]:
     return model, graph
 
 
-def build_circle(ep: ExportedProgram) -> bytes:
+def build_circle(
+    ep: ExportedProgram, config: CompileConfigBase = get_default_config()
+) -> bytes:
     """Convert ExportedProgram to Circle format.
 
     Args:
@@ -68,9 +71,13 @@ def build_circle(ep: ExportedProgram) -> bytes:
     for in_spec in ep.graph_signature.input_specs:
         if in_spec.kind != InputKind.USER_INPUT:
             continue
-
-
-
+        if isinstance(in_spec.arg, ConstantArgument):
+            # ConstantArgument is ignored when option is given
+            if config.get("remove_constant_input"):
+                continue
+            # NoneType ConstantArgument is ignored.
+            if in_spec.arg.value == None:
+                continue
         arg_name = in_spec.arg.name
         graph.add_input(arg_name)
         logger.debug(f"Registered input: {arg_name}")
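With the added parameter, callers can thread a compile config into serialization while old call sites keep working through the default argument; convert.py below now passes its config explicitly. A minimal direct-call sketch (the toy model is illustrative; build_circle and get_default_config are from this diff):

```python
import torch
from torch.export import export

from tico.config import get_default_config
from tico.serialize.circle_serializer import build_circle

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1

ep = export(AddOne(), (torch.randn(2, 3),))

blob = build_circle(ep)                        # old signature still works
blob = build_circle(ep, get_default_config())  # or with an explicit config
with open("add_one.circle", "wb") as f:
    f.write(blob)
```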
tico/serialize/operators/op_constant_pad_nd.py
CHANGED

@@ -28,6 +28,42 @@ from tico.utils.errors import InvalidArgumentError
 from tico.utils.validate_args_kwargs import ConstantPadNdArgs
 
 
+def convert_to_circle_padding(pad, input_shape_len):
+    MAX_RANK = 4
+
+    if not (1 <= input_shape_len <= MAX_RANK):
+        raise InvalidArgumentError(
+            f"Input rank must be between 1 and {MAX_RANK}, got {input_shape_len}"
+        )
+
+    if len(pad) % 2 != 0 or len(pad) < 2 or len(pad) > 8:
+        raise InvalidArgumentError(
+            f"Pad length must be an even number between 2 and 8, got {len(pad)}"
+        )
+
+    if len(pad) == 2:
+        padding = [[pad[0], pad[1]]]
+    elif len(pad) == 4:
+        padding = [[pad[2], pad[3]], [pad[0], pad[1]]]
+    elif len(pad) == 6:
+        padding = [[pad[4], pad[5]], [pad[2], pad[3]], [pad[0], pad[1]]]
+    elif len(pad) == 8:
+        padding = [
+            [pad[6], pad[7]],
+            [pad[4], pad[5]],
+            [pad[2], pad[3]],
+            [pad[0], pad[1]],
+        ]
+    else:
+        assert False, "Cannot reach here"
+
+    # Fill [0, 0] padding for the rest of dimension
+    while len(padding) < input_shape_len:
+        padding.insert(0, [0, 0])
+
+    return padding
+
+
 @register_node_visitor
 class ConstantPadNdVisitor(NodeVisitor):
     target: List[torch._ops.OpOverload] = [torch.ops.aten.constant_pad_nd.default]
@@ -45,19 +81,13 @@ class ConstantPadNdVisitor(NodeVisitor):
         val = args.value
 
         if val != 0:
-            raise InvalidArgumentError("Only support 0 value padding.")
+            raise InvalidArgumentError(f"Only support 0 value padding. pad:{pad}")
 
         input_shape_len = len(extract_shape(input_))
-
-
-
-
-            padding_size = [[0, 0], [0, 0]] + padding_size
-        else:
-            raise InvalidArgumentError("Only support 3D/4D inputs.")
-
-        paddings = torch.tensor(padding_size, dtype=torch.int32)
-        inputs = [input_, paddings]
+
+        padding = convert_to_circle_padding(pad, input_shape_len)
+
+        inputs = [input_, torch.tensor(padding, dtype=torch.int32)]
         outputs = [node]
 
         op_index = get_op_index(
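For orientation: `torch.nn.functional.pad` lists (before, after) pairs starting from the last dimension, so the helper above reverses the pairs into per-dimension rows and left-fills [0, 0] for the leading, unpadded dimensions. A worked example derived directly from the code:

```python
# Pad order in PyTorch: (W_left, W_right, H_top, H_bottom) for an NCHW input.
pad = [1, 2, 0, 3]
print(convert_to_circle_padding(pad, 4))
# -> [[0, 0], [0, 0], [0, 3], [1, 2]]
#    N and C stay unpadded; H gets (0, 3), W gets (1, 2).
```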
tico/serialize/operators/op_le.py
ADDED

@@ -0,0 +1,54 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+from tico.utils.validate_args_kwargs import LeArgs
+
+
+@register_node_visitor
+class LeVisitor(NodeVisitor):
+    target: List[torch._ops.OpOverload] = [
+        torch.ops.aten.le.Scalar,
+        torch.ops.aten.le.Tensor,
+    ]
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+        super().__init__(op_codes, graph)
+
+    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
+        args = LeArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input = args.input
+        other = args.other
+
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.LESS_EQUAL, self._op_codes
+        )
+
+        inputs = [input, other]
+        outputs = [node]
+
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+        return operator
tico/serialize/operators/op_mm.py
CHANGED

@@ -20,7 +20,7 @@ if TYPE_CHECKING:
 import torch
 from circle_schema import circle
 
-from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.circle_graph import CircleSubgraph
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index
@@ -28,9 +28,9 @@ from tico.utils.validate_args_kwargs import MatmulArgs
 
 
 @register_node_visitor
-class MatmulDefaultVisitor(NodeVisitor):
+class MatmulVisitor(NodeVisitor):
     """
-    Convert matmul to …
+    Convert matmul to Circle BatchMatMul
     """
 
     target: List[torch._ops.OpOverload] = [torch.ops.aten.mm.default]
@@ -38,131 +38,7 @@ class MatmulDefaultVisitor(NodeVisitor):
     def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
         super().__init__(op_codes, graph)
 
-
-    def define_bmm_node(self, inputs, outputs) -> circle.Operator.OperatorT:
-        def set_bmm_option(operator):
-            operator.builtinOptionsType = (
-                circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
-            )
-            option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
-            option.adjointLhs, option.adjointRhs = False, False
-            option.asymmetricQuantizeInputs = False
-            operator.builtinOptions = option
-
-        op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
-        )
-        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
-        set_bmm_option(operator)
-
-        return operator
-
-    def define_transpose_node(self, inputs, outputs) -> circle.Operator.OperatorT:
-        def set_transpose_option(operator):
-            operator.builtinOptionsType = (
-                circle.BuiltinOptions.BuiltinOptions.TransposeOptions
-            )
-            option = circle.TransposeOptions.TransposeOptionsT()
-            operator.builtinOptions = option
-
-        transpose_op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.TRANSPOSE, self._op_codes
-        )
-        operator = create_builtin_operator(
-            self.graph, transpose_op_index, inputs, outputs
-        )
-        set_transpose_option(operator)
-        return operator
-
-    def define_fc_node(self, inputs, outputs) -> circle.Operator.OperatorT:
-        def set_fc_option(operator):
-            operator.builtinOptionsType = (
-                circle.BuiltinOptions.BuiltinOptions.FullyConnectedOptions
-            )
-            option = circle.FullyConnectedOptions.FullyConnectedOptionsT()
-
-            option.fusedActivationFunction = (
-                circle.ActivationFunctionType.ActivationFunctionType.NONE
-            )
-            option.weightsFormat = (
-                circle.FullyConnectedOptionsWeightsFormat.FullyConnectedOptionsWeightsFormat.DEFAULT
-            )
-            option.keepNumDims = False
-            option.asymmetricQuantizeInputs = False
-            option.quantizedBiasType = circle.TensorType.TensorType.FLOAT32
-
-            operator.builtinOptions = option
-
-        fc_op_index = get_op_index(
-            circle.BuiltinOperator.BuiltinOperator.FULLY_CONNECTED, self._op_codes
-        )
-        operator = create_builtin_operator(self.graph, fc_op_index, inputs, outputs)
-        set_fc_option(operator)
-        return operator
-
-    """
-    Define FullyConnnected with Tranpose operator.
-    Note that those sets of operators are equivalent.
-    (1) Matmul
-        matmul( lhs[H, K], rhs[K, W'] ) -> output(H, W')
-
-    (2) Transpose + FullyConneccted
-        transpose( rhs[K, W'] ) -> trs_output[W', K]
-        fullyconnected( lhs[H, K], trs_output[W', K] ) -> output(H, W')
-    """
-
-    def define_fc_with_transpose(
-        self, node, inputs, outputs
-    ) -> circle.Operator.OperatorT:
-        lhs, rhs = inputs
-
-        # get transpose shape
-        rhs_tid: int = self.graph.get_tid_registered(rhs)
-        rhs_tensor: circle.Tensor.TensorT = self.graph.tensors[rhs_tid]
-        rhs_name: str = rhs.name
-        rhs_type: int = rhs_tensor.type
-        rhs_shape: List[int] = rhs_tensor.shape
-        assert len(rhs_shape) == 2, len(rhs_shape)
-        rhs_shape_transpose = [rhs_shape[1], rhs_shape[0]]
-
-        # create transpose output tensor
-        trs_output = self.graph.add_tensor_from_scratch(
-            prefix=f"{rhs_name}_transposed_output",
-            shape=rhs_shape_transpose,
-            shape_signature=None,
-            dtype=rhs_type,
-            source_node=node,
-        )
-        trs_perm = self.graph.add_const_tensor(data=[1, 0], source_node=node)
-        trs_operator = self.define_transpose_node([rhs, trs_perm], [trs_output])
-        self.graph.add_operator(trs_operator)
-
-        # define fc node
-        fc_input = lhs
-        fc_weight = trs_output
-        fc_shape = [fc_weight.shape[0]]
-        fc_bias = self.graph.add_const_tensor(
-            data=[0.0] * fc_shape[0], source_node=node
-        )
-
-        operator = self.define_fc_node([fc_input, fc_weight, fc_bias], outputs)
-
-        return operator
-
-    def define_node(
-        self, node: torch.fx.Node, prior_latency=True
-    ) -> circle.Operator.OperatorT:
-        """
-        NOTE: Possibility of accuracy-latency trade-off
-        From ONE compiler's perspective:
-        - BMM uses per-tensor quantization for both rhs and lhs.
-        - FC uses per-channel quantization for weight and per-tensor for input.
-        Thus, FC is better in terms of accuracy.
-        FC necessarily involves an additional transpose operation to be identical with mm.
-        If transposed operand is const, it can be optimized by constant folding.
-        Thus, convert FC only if tranpose can be folded.
-        TODO set prior_latency outside
-        """
+    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
         args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         input = args.input
         other = args.other
@@ -170,9 +46,16 @@ class MatmulDefaultVisitor(NodeVisitor):
         inputs = [input, other]
         outputs = [node]
 
-
-
-
-
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
+        )
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+        operator.builtinOptionsType = (
+            circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
+        )
+        option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
+        option.adjointLhs, option.adjointRhs = False, False
+        option.asymmetricQuantizeInputs = False
+        operator.builtinOptions = option
 
         return operator
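The removed FC-with-transpose path relied on the identity its docstring spelled out; judging by the new pass names in convert.py below, that conversion now lives in the dedicated ConvertMatmulToLinear graph pass, leaving this visitor as a plain BatchMatMul lowering. The identity itself is easy to check:

```python
import torch

H, K, W = 3, 5, 7
lhs, rhs = torch.randn(H, K), torch.randn(K, W)

mm_out = torch.mm(lhs, rhs)                        # what aten.mm computes
fc_out = torch.nn.functional.linear(lhs, rhs.t())  # FC weight layout is [W, K]
assert torch.allclose(mm_out, fc_out, atol=1e-5)
```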
tico/utils/convert.py
CHANGED

@@ -20,26 +20,14 @@ import torch
 from torch.export import export, ExportedProgram
 
 from tico.config import CompileConfigBase, get_default_config
-from tico.experimental.quantization.passes.fold_quant_ops import FoldQuantOps
-from tico.experimental.quantization.passes.insert_quantize_on_dtype_mismatch import (
-    InsertQuantizeOnDtypeMismatch,
-)
-from tico.experimental.quantization.passes.propagate_qparam_backward import (
-    PropagateQParamBackward,
-)
-from tico.experimental.quantization.passes.propagate_qparam_forward import (
-    PropagateQParamForward,
-)
-from tico.experimental.quantization.passes.quantize_bias import QuantizeBias
-from tico.experimental.quantization.passes.remove_weight_dequant_op import (
-    RemoveWeightDequantOp,
-)
 from tico.passes.cast_aten_where_arg_type import CastATenWhereArgType
 from tico.passes.cast_clamp_mixed_type_args import CastClampMixedTypeArgs
 from tico.passes.cast_mixed_type_args import CastMixedTypeArgs
 from tico.passes.const_prop_pass import ConstPropPass
 from tico.passes.convert_conv1d_to_conv2d import ConvertConv1dToConv2d
+from tico.passes.convert_expand_to_slice_cat import ConvertExpandToSliceCat
 from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape
+from tico.passes.convert_matmul_to_linear import ConvertMatmulToLinear
 from tico.passes.convert_repeat_to_expand_copy import ConvertRepeatToExpandCopy
 from tico.passes.convert_to_relu6 import ConvertToReLU6
 from tico.passes.decompose_addmm import DecomposeAddmm
@@ -72,6 +60,14 @@ from tico.passes.remove_redundant_slice import RemoveRedundantSlice
 from tico.passes.remove_redundant_to_copy import RemoveRedundantToCopy
 from tico.passes.restore_linear import RestoreLinear
 from tico.passes.segment_index_select import SegmentIndexSelectConst
+from tico.quantization.passes.fold_quant_ops import FoldQuantOps
+from tico.quantization.passes.insert_quantize_on_dtype_mismatch import (
+    InsertQuantizeOnDtypeMismatch,
+)
+from tico.quantization.passes.propagate_qparam_backward import PropagateQParamBackward
+from tico.quantization.passes.propagate_qparam_forward import PropagateQParamForward
+from tico.quantization.passes.quantize_bias import QuantizeBias
+from tico.quantization.passes.remove_weight_dequant_op import RemoveWeightDequantOp
 from tico.serialize.circle_serializer import build_circle
 from tico.serialize.operators.node_visitor import get_support_targets
 from tico.utils import logging
@@ -141,6 +137,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
         or torch.__version__.startswith("2.7")
         or torch.__version__.startswith("2.8")
         or torch.__version__.startswith("2.9")
+        or torch.__version__.startswith("2.10")
     ):
         return run_decompositions(exported_program)
     else:
@@ -249,6 +246,14 @@ def convert_exported_module_to_circle(
             ConstPropPass(),
             SegmentIndexSelectConst(),
             LegalizeCausalMaskValue(enabled=config.get("legalize_causal_mask_value")),
+            ConvertExpandToSliceCat(enabled=config.get("convert_expand_to_slice_cat")),
+            ConvertMatmulToLinear(
+                enable_lhs_const=config.get("convert_lhs_const_mm_to_fc"),
+                enable_rhs_const=config.get("convert_rhs_const_mm_to_fc"),
+                enable_single_batch_lhs_const_bmm=config.get(
+                    "convert_single_batch_lhs_const_bmm_to_fc"
+                ),
+            ),
             LowerToResizeNearestNeighbor(),
             LegalizePreDefinedLayoutOperators(),
             LowerPow2ToMul(),
@@ -287,7 +292,7 @@ def convert_exported_module_to_circle(
 
     check_unsupported_target(exported_program)
     check_training_ops(exported_program)
-    circle_program = build_circle(exported_program)
+    circle_program = build_circle(exported_program, config)
 
     return circle_program
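All of the new passes are gated on config keys, presumably the five new entries in tico/config/v1.py from the summary above. A sketch for inspecting them; the key names come from this diff, while the assumption that get_default_config() exposes every key through get() is mine:

```python
from tico.config import get_default_config

config = get_default_config()
for key in (
    "convert_expand_to_slice_cat",
    "convert_lhs_const_mm_to_fc",
    "convert_rhs_const_mm_to_fc",
    "convert_single_batch_lhs_const_bmm_to_fc",
    "remove_constant_input",
):
    print(key, config.get(key))
```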