tico-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +42 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +191 -0
- tico/passes/cast_mixed_type_args.py +187 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +160 -0
- tico/passes/convert_layout_op_to_reshape.py +85 -0
- tico/passes/convert_repeat_to_expand_copy.py +89 -0
- tico/passes/convert_to_relu6.py +181 -0
- tico/passes/decompose_addmm.py +124 -0
- tico/passes/decompose_batch_norm.py +192 -0
- tico/passes/decompose_fake_quantize.py +134 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
- tico/passes/decompose_group_norm.py +275 -0
- tico/passes/decompose_grouped_conv2d.py +209 -0
- tico/passes/decompose_slice_scatter.py +169 -0
- tico/passes/extract_dtype_kwargs.py +122 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +108 -0
- tico/passes/legalize_predefined_layout_operators.py +386 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
- tico/passes/lower_to_slice.py +230 -0
- tico/passes/merge_consecutive_cat.py +80 -0
- tico/passes/ops.py +78 -0
- tico/passes/remove_nop.py +84 -0
- tico/passes/remove_redundant_assert_nodes.py +51 -0
- tico/passes/remove_redundant_expand.py +66 -0
- tico/passes/remove_redundant_permute.py +122 -0
- tico/passes/remove_redundant_reshape.py +436 -0
- tico/passes/remove_redundant_slice.py +62 -0
- tico/passes/remove_redundant_to_copy.py +86 -0
- tico/passes/restore_linear.py +115 -0
- tico/passes/segment_index_select.py +145 -0
- tico/pt2_to_circle.py +105 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +319 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +240 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_abs.py +53 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +150 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +192 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +126 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +186 -0
- tico/serialize/operators/op_copy.py +164 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +95 -0
- tico/serialize/operators/op_depthwise_conv2d.py +199 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_leaky_relu.py +60 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +86 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_dim.py +70 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +177 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +141 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +100 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +99 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +94 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +296 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +282 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/mx/__init__.py +1 -0
- tico/utils/mx/elemwise_ops.py +267 -0
- tico/utils/mx/formats.py +125 -0
- tico/utils/mx/mx_ops.py +270 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +609 -0
- tico/utils/serialize.py +42 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +406 -0
- tico/utils/validate_args_kwargs.py +1149 -0
- tico-0.1.0.dist-info/LICENSE +241 -0
- tico-0.1.0.dist-info/METADATA +354 -0
- tico-0.1.0.dist-info/RECORD +206 -0
- tico-0.1.0.dist-info/WHEEL +5 -0
- tico-0.1.0.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,185 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Tuple
+
+import numpy as np
+import torch
+
+from circle_schema import circle
+from numpy.typing import DTypeLike
+
+from tico.utils import logging
+from tico.utils.model import CircleModel
+
+
+def quantize(
+    data: np.ndarray, scale: float, zero_point: int, dtype: DTypeLike
+) -> np.ndarray:
+    """
+    Quantize the given data using the specified scale, zero point, and data type.
+    This function takes input data and applies quantization using the formula:
+        round(data / scale) + zero_point
+    The result is clamped to the range of the specified data type.
+    """
+    logger = logging.getLogger(__name__)
+    dtype = np.dtype(dtype)
+    assert dtype == np.uint8 or dtype == np.int16, f"Invalid dtype: {dtype}"
+    if dtype == np.int16:
+        assert zero_point == 0
+
+    # Convert input to a NumPy array if necessary
+    if not isinstance(data, np.ndarray):
+        data = np.array(data)
+    # Perform quantization
+    if not scale:
+        logger.warning("WARNING: scale value is 0. 1e-7 will be used instead.")
+        scale = 1e-7
+    rescaled = np.round(data / scale) + zero_point
+    # Clamp the values
+    clipped = np.clip(rescaled, np.iinfo(dtype).min, np.iinfo(dtype).max)
+    # Convert to the specified dtype
+    return clipped.astype(dtype)
+
+
+def dequantize(
+    data: np.ndarray, scale: float, zero_point: int, dtype: DTypeLike
+) -> np.ndarray:
+    """
+    Dequantize the given quantized data using the specified scale and zero point.
+    This function reverses the quantization process by applying the formula:
+        (quantized_value - zero_point) * scale
+    """
+    dtype = np.dtype(dtype)
+    assert dtype == np.uint8 or dtype == np.int16, f"Invalid dtype: {dtype}"
+    if dtype == np.int16:
+        assert zero_point == 0
+
+    # Convert input to a NumPy array if necessary
+    if not isinstance(data, np.ndarray):
+        data = np.array(data)
+    # Perform dequantization
+    ret = (data.astype(np.float32) - zero_point) * scale
+    # np.float32 * np.int64 = np.float64
+    return ret.astype(np.float32)
+
+
+def get_graph_input_output(
+    circle_model: CircleModel,
+) -> Tuple[List[circle.Tensor.Tensor], List[circle.Tensor.Tensor]]:
+    """
+    Retrieve the inputs and the outputs from the circle model, and return them
+    as two lists.
+    """
+    circle_buf: bytes = circle_model.circle_binary
+    circle_fb: circle.Model.Model = circle.Model.Model.GetRootAs(circle_buf, 0)
+    assert circle_fb.SubgraphsLength() == 1, "Only support single graph."
+    circle_graph = circle_fb.Subgraphs(0)
+    circle_inputs: List[circle.Tensor.Tensor] = [
+        circle_graph.Tensors(circle_graph.Inputs(i))
+        for i in range(circle_graph.InputsLength())
+    ]
+    circle_outputs: List[circle.Tensor.Tensor] = [
+        circle_graph.Tensors(circle_graph.Outputs(o))
+        for o in range(circle_graph.OutputsLength())
+    ]
+
+    return circle_inputs, circle_outputs
+
+
+def find_invalid_types(
+    input: List[torch.Tensor] | List[np.ndarray], allowed_types: List
+) -> List:
+    """
+    Identifies the types of items in a list that are not allowed and removes duplicates.
+
+    Parameters
+    ----------
+    input
+        List of items to check.
+    allowed_types
+        List of allowed types (e.g. [int, str])
+    Returns
+    -------
+        A list of unique types that are not allowed in the input list.
+    """
+    # Use a set comprehension for uniqueness
+    invalid_types = {
+        type(item) for item in input if not isinstance(item, tuple(allowed_types))
+    }
+    return list(invalid_types)
+
+
+def plot_two_outputs(x_values: torch.Tensor, y_values: torch.Tensor):
+    """
+    Plot two values on a 2D graph using plotext.
+
+    Returns
+    -------
+    A figure built from plotext.
+
+    Example
+    -------
+    >>> x_values = torch.tensor([1, 2, 3, 4, 5])
+    >>> y_values = torch.tensor([10, 20, 30, 40, 50])
+    >>> fig = plot_two_outputs(x_values, y_values)
+    >>> print(fig)
+    """
+    x_np = x_values.numpy().reshape(-1)
+    y_np = y_values.numpy().reshape(-1)
+    min_value = min([x_np.min(), y_np.min()])
+    max_value = max([x_np.max(), y_np.max()])
+
+    interval = max_value - min_value
+    interval = 1.0 if interval == 0.0 else interval  # Avoid zero interval
+
+    # Enlarge axis
+    axis_min = min_value - interval * 0.05
+    axis_max = max_value + interval * 0.05
+
+    import plotext as plt
+
+    plt.clear_data()
+    plt.xlim(axis_min, axis_max)
+    plt.ylim(axis_min, axis_max)
+    plt.plotsize(width=50, height=25)
+    plt.scatter(x_np, y_np, marker="dot")
+    plt.theme("clear")
+
+    return plt.build()
+
+
+def ensure_list(inputs: Any | Tuple[Any] | List[Any]) -> List[Any]:
+    """
+    Ensures that the given input is converted into a list.
+
+    - If the input is a single element, it is wrapped in a list.
+    - If the input is a tuple, it is converted to a list.
+    - If the input is already a list, it is returned unchanged.
+
+    Example
+    -------
+    >>> ensure_list(42)
+    [42]
+    >>> ensure_list((1, 2, 3))
+    [1, 2, 3]
+    >>> ensure_list([4, 5, 6])
+    [4, 5, 6]
+    """
+    if isinstance(inputs, list):
+        return inputs
+    elif isinstance(inputs, tuple):
+        return list(inputs)
+    else:
+        return [inputs]
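For orientation, here is a minimal round-trip sketch of the `quantize`/`dequantize` helpers above, using only NumPy and the same formulas (`round(x / scale) + zero_point` to quantize, `(q - zero_point) * scale` to dequantize). The concrete scale and zero-point values are illustrative, not taken from the package:

```python
import numpy as np

# uint8 affine quantization, mirroring the helpers above (illustrative qparams).
scale, zero_point = 0.05, 128
x = np.array([-3.2, -0.1, 0.0, 0.7, 3.1], dtype=np.float32)

q = np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)
x_hat = ((q.astype(np.float32) - zero_point) * scale).astype(np.float32)

print(q)      # [ 64 126 128 142 190]
print(x_hat)  # [-3.2 -0.1  0.   0.7  3.1]

# For in-range values, the round-trip error is at most half a quantization step.
assert np.all(np.abs(x - x_hat) <= scale / 2 + 1e-6)
```
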
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
@@ -0,0 +1,154 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import copy
+
+import torch
+from torch.export import ExportedProgram
+
+from tico.serialize.quant_param import QPARAM_KEY, QuantParam
+from tico.utils import logging
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.utils import get_quant_dtype
+from tico.utils.validate_args_kwargs import (
+    DequantizePerTensorArgs,
+    QuantizePerTensorArgs,
+)
+
+
+@trace_graph_diff_on_pass
+class FoldQuantOps(PassBase):
+    """
+    This pass folds the (Q - DQ) pattern into the previous op. After quantization from
+    torch, activation ops have an (op - Q - DQ) pattern.
+
+    To export a quantized circle model, this pass removes the (Q - DQ) nodes and saves
+    their quantization info to the previous op's metadata.
+
+    ────────────────────────────────────────────────────────────────
+    BEFORE                           AFTER
+    ────────────────────────────────────────────────────────────────
+    op(float) ─ Q ─ DQ ─ …           op(float, meta[QPARAM])
+
+    op ─ Q1 ─ DQ1 ─ Q2 ─ DQ2         op(meta[QPARAM]) ─ Q2
+         ▲                                ▲
+         │ (Q1, DQ1 folded)               │ (re-quantization kept)
+
+    op ─ Q ─┬─ DQ0                   op(meta[QPARAM])
+            ├─ DQ1                   (each DQ* folded, Q dropped when orphaned)
+            └─ DQ2
+    ────────────────────────────────────────────────────────────────
+
+    Algorithm
+    ---------
+    1. Iterate over *all* Dequantize nodes.
+    2. For each DQ, verify it is driven by a Quantize node `q` and that
+       `q` and `dq` share identical (scale, zero-point, dtype).
+    3. a) If the producer op has **no** QPARAM, attach one, then replace
+          *this* DQ's usages with the producer op.
+       b) If the producer is already quantized with a different dtype,
+          this is a *re-quantization*: attach QPARAM to `q` and keep it,
+          but still remove the DQ.
+    4. After all replacements, run `graph.eliminate_dead_code()`.
+       Any Quantize that became orphaned because *all* its DQs were folded
+       is deleted automatically.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph: torch.fx.Graph = graph_module.graph
+        for dq in graph.nodes:
+            if dq.op != "call_function":
+                continue
+            if (
+                dq.target
+                != torch.ops.quantized_decomposed.dequantize_per_tensor.default
+            ):
+                continue
+            dq_args = DequantizePerTensorArgs(*dq.args, **dq.kwargs)
+
+            q = dq_args.input
+            if q.target != torch.ops.quantized_decomposed.quantize_per_tensor.default:
+                continue
+            q_args = QuantizePerTensorArgs(*q.args, **q.kwargs)  # type: ignore[arg-type]
+            op = q_args.tensor
+
+            # Check if Q and DQ have the same quant param
+            if q_args.scale != dq_args.scale:
+                continue
+            if q_args.zero_p != dq_args.zero_point:
+                continue
+            if q_args.dtype != dq_args.dtype:
+                continue
+
+            # ───────────────────────────────────────────
+            # Case 1: op not yet quantized
+            # ───────────────────────────────────────────
+            if QPARAM_KEY not in op.meta:
+                qparam = QuantParam()
+                qparam.scale = [q_args.scale]
+                qparam.zero_point = [q_args.zero_p]
+                qparam.dtype = get_quant_dtype(q_args.quant_min, q_args.quant_max)
+                op.meta[QPARAM_KEY] = qparam
+
+                dq.replace_all_uses_with(op, propagate_meta=False)
+
+                logger.debug(f"{q.name} and {dq.name} are folded to {op.name}.")
+            # ───────────────────────────────────────────
+            # Case 2: op already quantized
+            #   2.1 same dtype → nothing to do
+            #   2.2 diff dtype → leave Q in place
+            # ───────────────────────────────────────────
+            else:
+                op_qparam: QuantParam = op.meta[QPARAM_KEY]
+                qdq_dtype = get_quant_dtype(q_args.quant_min, q_args.quant_max)
+
+                if op_qparam.dtype != qdq_dtype:
+                    # Attach QPARAM to Q once
+                    if QPARAM_KEY not in q.meta:
+                        qparam = QuantParam()
+                        qparam.scale = [q_args.scale]
+                        qparam.zero_point = [q_args.zero_p]
+                        qparam.dtype = qdq_dtype
+                        q.meta[QPARAM_KEY] = qparam
+                        assert len(q.users) == 1, "Fix me unless"
+
+                    dq.replace_all_uses_with(q, propagate_meta=False)
+                    logger.debug(f"{dq.name} is folded ({q.name} is left).")
+                else:
+                    # Same dtype → the Quantize–Dequantize pair is redundant.
+                    assert op_qparam.scale and op_qparam.scale[0] == q_args.scale
+                    assert (
+                        op_qparam.zero_point
+                        and op_qparam.zero_point[0] == q_args.zero_p
+                    )
+                    dq.replace_all_uses_with(op, propagate_meta=False)
+                    logger.debug(f"Removed redundant {dq.name}")
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        # Run only once.
+        return PassResult(False)
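The core folding rule (steps 2 and 3a of the algorithm above) can be illustrated without torch at all. Below is a minimal sketch with a hypothetical `FakeNode` standing in for `torch.fx.Node` and a plain dict standing in for `QuantParam`; the real pass operates on the exported FX graph and also handles the re-quantization case:

```python
from dataclasses import dataclass, field

@dataclass
class FakeNode:
    """Hypothetical stand-in for torch.fx.Node (only the metadata dict)."""
    name: str
    meta: dict = field(default_factory=dict)

def fold_qdq(op: FakeNode, q: dict, dq: dict) -> bool:
    """Fold an (op - Q - DQ) chain when Q and DQ carry identical qparams."""
    if (q["scale"], q["zero_point"], q["dtype"]) != (
        dq["scale"],
        dq["zero_point"],
        dq["dtype"],
    ):
        return False  # mismatched pair: leave the graph untouched
    # Case 1 of the pass: the producer has no QPARAM yet, so attach one.
    op.meta.setdefault(
        "qparam",
        {"scale": [q["scale"]], "zero_point": [q["zero_point"]], "dtype": q["dtype"]},
    )
    return True  # DQ's users would be rewired to op; the orphaned Q becomes dead code

op = FakeNode("linear1")
qparams = {"scale": 0.05, "zero_point": 128, "dtype": "uint8"}
assert fold_qdq(op, qparams, dict(qparams))
assert op.meta["qparam"]["dtype"] == "uint8"
```
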
@@ -0,0 +1,345 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import copy
+
+import torch
+from torch.export import ExportedProgram
+
+from tico.serialize.quant_param import QPARAM_KEY, QuantParam
+from tico.utils import logging
+from tico.utils.errors import NotYetSupportedError
+from tico.utils.graph import create_node
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.utils import quant_min_max, set_new_meta_val
+from tico.utils.validate_args_kwargs import (
+    BmmArgs,
+    LinearArgs,
+    MulTensorArgs,
+    PermuteArgs,
+    ReshapeArgs,
+)
+
+
+def qparam_dtype(node: torch.fx.Node) -> str:
+    assert QPARAM_KEY in node.meta
+    return node.meta[QPARAM_KEY].dtype
+
+
+# Convert an i16 qparam to a u8 qparam.
+# scale and zero_point are inferred from the i16 qparam.
+def _i16_to_u8(qparam: QuantParam) -> QuantParam:
+    # Assume per-tensor quantization
+    assert qparam.scale is not None and len(qparam.scale) == 1
+    assert qparam.dtype == "int16"
+
+    s16_scale = qparam.scale[0]
+    max_ = s16_scale * 32767  # numeric_limits<int16>
+    min_ = -max_
+
+    u8_scale = (max_ - min_) / 255
+    u8_zerop = round(-min_ / u8_scale)
+
+    new_qparam = QuantParam()
+    new_qparam.scale = [u8_scale]
+    new_qparam.zero_point = [u8_zerop]
+    new_qparam.dtype = "uint8"
+
+    return new_qparam
+
+
+# Convert a u8 qparam to an i16 qparam.
+# scale is inferred from the u8 qparam.
+def _u8_to_i16(qparam: QuantParam) -> QuantParam:
+    # Assume per-tensor quantization
+    assert qparam.scale is not None and len(qparam.scale) == 1
+    assert qparam.zero_point is not None and len(qparam.zero_point) == 1
+    assert qparam.dtype == "uint8"
+
+    u8_scale = qparam.scale[0]
+    u8_zerop = qparam.zero_point[0]
+    max_ = u8_scale * (255 - u8_zerop)
+    min_ = u8_scale * (-u8_zerop)
+
+    abs_max = max([max_, min_], key=abs)
+    s16_scale = abs_max / 32767
+    s16_zerop = 0
+
+    new_qparam = QuantParam()
+    new_qparam.scale = [s16_scale]
+    new_qparam.zero_point = [s16_zerop]
+    new_qparam.dtype = "int16"
+
+    return new_qparam
+
+
+@trace_graph_diff_on_pass
+class InsertQuantizeOnDtypeMismatch(PassBase):
+    """
+    Insert a Quantize Op wherever circle's type inference would otherwise be violated.
+    Example. FullyConnected
+    [BEFORE]
+      Op (uint8) - aten.linear.default (int16)
+    [AFTER]
+      Op (uint8) - aten.linear.default (uint8) - quantized_decomposed.quantize_per_tensor.default (int16)
+    Why is this pass necessary?
+    - For some operators, circle's type inference pass overwrites the output's dtype with
+      the input's dtype. In the above example, the fully-connected layer (aten.linear.default)'s
+      output dtype (int16) would be updated to the input dtype (uint8), which breaks the semantics.
+      This problem can occur in tools (ex: circle2circle) that automatically apply type inference.
+    - To resolve the issue, we insert Quantize operators so that circle's type inference logic is not violated.
+    - NOTE In some cases, the Quantize Op is inserted before the operator.
+
+    Let's assume a Reshape Op whose input is int16 and output is uint8. There are two possible
+    places to insert a Quantize Op.
+
+    1. Insert Quantize before Reshape.
+
+    ```
+    Predecessor (int16) -> Quantize (uint8) -> Reshape (uint8) -> ...
+    ```
+
+    2. Insert Quantize after Reshape.
+
+    ```
+    Predecessor (int16) -> Reshape (int16) -> Quantize (uint8) -> ...
+    ```
+
+    Comparing 1) and 2), the difference is whether the Reshape operation is conducted in uint8 or int16.
+    We go with 1), which does Reshape in uint8, for faster execution. Note that Reshape does not
+    change the values, so its dtype does not affect accuracy.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph: torch.fx.Graph = graph_module.graph
+
+        def _insert_quantize_op_before(node, inp):
+            qparam: QuantParam = node.meta[QPARAM_KEY]
+            assert qparam.scale is not None
+            assert qparam.zero_point is not None
+            scale = qparam.scale[0]
+            zerop = qparam.zero_point[0]
+            min_, max_ = quant_min_max(qparam.dtype)
+            dtype = getattr(torch, qparam.dtype)
+
+            with graph.inserting_before(node):
+                q_args = (inp, scale, zerop, min_, max_, dtype)
+                quantize = create_node(
+                    graph,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    args=q_args,
+                    origin=node,
+                )
+            quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
+            set_new_meta_val(quantize)
+
+            node.replace_input_with(inp, quantize)
+
+            return quantize
+
+        def _insert_quantize_op_after(node):
+            qparam: QuantParam = node.meta[QPARAM_KEY]
+            assert qparam.scale is not None
+            assert qparam.zero_point is not None
+            scale = qparam.scale[0]
+            zerop = qparam.zero_point[0]
+            min_, max_ = quant_min_max(qparam.dtype)
+            dtype = getattr(torch, qparam.dtype)
+            with graph.inserting_after(node):
+                q_args = (node, scale, zerop, min_, max_, dtype)
+                quantize = create_node(
+                    graph,
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    args=q_args,
+                )
+
+            node.replace_all_uses_with(quantize, propagate_meta=True)
+            quantize.replace_input_with(quantize, node)
+
+            quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
+
+            return quantize
+
+        for node in graph.nodes:
+            if node.op != "call_function":
+                continue
+            if node.target == torch.ops.aten.linear.default:
+                lin_args = LinearArgs(*node.args, **node.kwargs)
+                inp = lin_args.input
+
+                if QPARAM_KEY not in inp.meta:
+                    continue
+
+                if QPARAM_KEY not in node.meta:
+                    continue
+
+                if qparam_dtype(inp) == qparam_dtype(node):
+                    continue
+
+                if qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
+                    quantize = _insert_quantize_op_after(node)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+
+                    # Update node's qparam from i16 to u8.
+                    # NOTE This can severely degrade accuracy. It is
+                    # important to mitigate this accuracy drop in the backend.
+                    node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted after {node.name}."
+                    )
+                else:
+                    raise NotYetSupportedError("Unsupported dtype")
+
+            elif node.target == torch.ops.aten.mul.Tensor:
+                mul_args = MulTensorArgs(*node.args, **node.kwargs)
+                x = mul_args.input
+                y = mul_args.other
+
+                if not isinstance(x, torch.fx.Node):
+                    continue
+                if not isinstance(y, torch.fx.Node):
+                    continue
+
+                if QPARAM_KEY not in x.meta:
+                    continue
+                if QPARAM_KEY not in y.meta:
+                    continue
+                if QPARAM_KEY not in node.meta:
+                    continue
+
+                if qparam_dtype(x) == qparam_dtype(node):
+                    continue
+
+                if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
+                    quantize = _insert_quantize_op_after(node)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                    node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted after {node.name}."
+                    )
+                else:
+                    raise NotYetSupportedError("Unsupported dtype")
+
+            elif node.target == torch.ops.aten.bmm.default:
+                bmm_args = BmmArgs(*node.args, **node.kwargs)
+                x = bmm_args.input
+                y = bmm_args.mat2
+
+                if QPARAM_KEY not in x.meta:
+                    continue
+                if QPARAM_KEY not in y.meta:
+                    continue
+                if QPARAM_KEY not in node.meta:
+                    continue
+
+                if qparam_dtype(x) == qparam_dtype(node):
+                    continue
+
+                if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
+                    quantize = _insert_quantize_op_after(node)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                    node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted after {node.name}."
+                    )
+                else:
+                    raise NotYetSupportedError("Unsupported dtype")
+
+            elif node.target == torch.ops.aten.permute.default:
+                per_args = PermuteArgs(*node.args, **node.kwargs)
+                inp = per_args.input
+
+                if QPARAM_KEY not in inp.meta:
+                    continue
+
+                if QPARAM_KEY not in node.meta:
+                    continue
+
+                if qparam_dtype(inp) == qparam_dtype(node):
+                    continue
+
+                if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
+                    # A new Quantize Op (s16 to u8) is inserted before (not after)
+                    # the permute Op to reduce the tensor size earlier.
+                    quantize = _insert_quantize_op_before(node, inp)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted before {node.name}."
+                    )
+                elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
+                    quantize = _insert_quantize_op_after(node)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                    node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted after {node.name}."
+                    )
+                else:
+                    raise NotYetSupportedError("Unsupported dtype")
+            elif node.target == torch.ops.aten.reshape.default:
+                reshape_args = ReshapeArgs(*node.args, **node.kwargs)
+                inp = reshape_args.input
+
+                if QPARAM_KEY not in inp.meta:
+                    continue
+
+                if QPARAM_KEY not in node.meta:
+                    continue
+
+                if qparam_dtype(inp) == qparam_dtype(node):
+                    continue
+
+                if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
+                    # A new Quantize Op (s16 to u8) is inserted before (not after)
+                    # the reshape Op to reduce the tensor size earlier.
+                    quantize = _insert_quantize_op_before(node, inp)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted before {node.name}."
+                    )
+                elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
+                    quantize = _insert_quantize_op_after(node)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                    node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted after {node.name}."
+                    )
+                else:
+                    raise NotYetSupportedError("Unsupported dtype")
+
+            # TODO Support more ops.
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        # Run only once.
+        return PassResult(False)
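As a quick sanity check on the conversion arithmetic above, here is a standalone sketch of `_i16_to_u8` and `_u8_to_i16` on plain floats (illustrative scale and zero-point values, per-tensor only):

```python
# Standalone replication of the per-tensor qparam conversions above.
def i16_to_u8(s16_scale: float):
    max_ = s16_scale * 32767  # int16 qparams are symmetric (zero_point == 0)
    min_ = -max_
    u8_scale = (max_ - min_) / 255
    u8_zerop = round(-min_ / u8_scale)
    return u8_scale, u8_zerop

def u8_to_i16(u8_scale: float, u8_zerop: int):
    max_ = u8_scale * (255 - u8_zerop)
    min_ = u8_scale * (-u8_zerop)
    abs_max = max([max_, min_], key=abs)
    return abs_max / 32767, 0

# An int16 scale of 0.001 covers roughly ±32.767, so the derived uint8 scale
# is about 0.257 with a zero-point near the middle of [0, 255].
print(i16_to_u8(0.001))     # ≈ (0.257, 127 or 128)
print(u8_to_i16(0.2, 100))  # float range [-20.0, 31.0] -> (≈9.46e-4, 0)
```

Note that when the pass rewrites a node's qparam this way, the Quantize op it inserts carries a copy of the node's original qparam, so the dtypes seen at the graph boundary are preserved.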