tico 0.1.0.dev250411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
tico/experimental/quantization/evaluation/utils.py
@@ -0,0 +1,185 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Tuple
+
+import numpy as np
+import torch
+
+from circle_schema import circle
+from numpy.typing import DTypeLike
+
+from tico.utils import logging
+from tico.utils.model import CircleModel
+
+
+def quantize(
+    data: np.ndarray, scale: float, zero_point: int, dtype: DTypeLike
+) -> np.ndarray:
+    """
+    Quantize the given data using the specified scale, zero point, and data type.
+    This function takes input data and applies quantization using the formula:
+        round(data / scale) + zero_point
+    The result is clamped to the range of the specified data type.
+    """
+    logger = logging.getLogger(__name__)
+    dtype = np.dtype(dtype)
+    assert dtype == np.uint8 or dtype == np.int16, f"Invalid dtype: {dtype}"
+    if dtype == np.int16:
+        assert zero_point == 0
+
+    # Convert input to a NumPy array if necessary
+    if not isinstance(data, np.ndarray):
+        data = np.array(data)
+    # Perform quantization
+    if not scale:
+        logger.warning("WARNING: scale value is 0. 1e-7 will be used instead.")
+        scale = 1e-7
+    rescaled = np.round(data / scale) + zero_point
+    # Clamp the values
+    clipped = np.clip(rescaled, np.iinfo(dtype).min, np.iinfo(dtype).max)
+    # Convert to the specified dtype
+    return clipped.astype(dtype)
+
+
+def dequantize(
+    data: np.ndarray, scale: float, zero_point: int, dtype: DTypeLike
+) -> np.ndarray:
+    """
+    Dequantize the given quantized data using the specified scale and zero point.
+    This function reverses the quantization process by applying the formula:
+        (quantized_value - zero_point) * scale
+    """
+    dtype = np.dtype(dtype)
+    assert dtype == np.uint8 or dtype == np.int16, f"Invalid dtype: {dtype}"
+    if dtype == np.int16:
+        assert zero_point == 0
+
+    # Convert input to a NumPy array if necessary
+    if not isinstance(data, np.ndarray):
+        data = np.array(data)
+    # Perform dequantization
+    ret = (data.astype(np.float32) - zero_point) * scale
+    # np.float32 * np.int64 = np.float64
+    return ret.astype(np.float32)
+
+
+def get_graph_input_output(
+    circle_model: CircleModel,
+) -> Tuple[List[circle.Tensor.Tensor], List[circle.Tensor.Tensor]]:
+    """
+    Retrieve the inputs and the outputs from the circle model, and return them
+    as two lists.
+    """
+    circle_buf: bytes = circle_model.circle_binary
+    circle_fb: circle.Model.Model = circle.Model.Model.GetRootAs(circle_buf, 0)
+    assert circle_fb.SubgraphsLength() == 1, "Only support single graph."
+    circle_graph = circle_fb.Subgraphs(0)
+    circle_inputs: List[circle.Tensor.Tensor] = [
+        circle_graph.Tensors(circle_graph.Inputs(i))
+        for i in range(circle_graph.InputsLength())
+    ]
+    circle_outputs: List[circle.Tensor.Tensor] = [
+        circle_graph.Tensors(circle_graph.Outputs(o))
+        for o in range(circle_graph.OutputsLength())
+    ]
+
+    return circle_inputs, circle_outputs
+
+
+def find_invalid_types(
+    input: List[torch.Tensor] | List[np.ndarray], allowed_types: List
+) -> List:
+    """
+    Identifies the types of items in a list that are not allowed and removes duplicates.
+
+    Parameters
+    -----------
+    input
+        List of items to check.
+    allowed_types
+        List of allowed types (e.g. [int, str])
+    Returns
+    --------
+    A list of unique types that are not allowed in the input list.
+    """
+    # Use set comprehension for uniqueness
+    invalid_types = {
+        type(item) for item in input if not isinstance(item, tuple(allowed_types))
+    }
+    return list(invalid_types)
+
+
+def plot_two_outputs(x_values: torch.Tensor, y_values: torch.Tensor):
+    """
+    Plot two values on a 2D graph using plotext.
+
+    Returns
+    --------
+    A figure built from plotext.
+
+    Example
+    --------
+    >>> x_values = torch.tensor([1, 2, 3, 4, 5])
+    >>> y_values = torch.tensor([10, 20, 30, 40, 50])
+    >>> fig = plot_two_outputs(x_values, y_values)
+    >>> print(fig)
+    """
+    x_np = x_values.numpy().reshape(-1)
+    y_np = y_values.numpy().reshape(-1)
+    min_value = min([x_np.min(), y_np.min()])
+    max_value = max([x_np.max(), y_np.max()])
+
+    interval = max_value - min_value
+    interval = 1.0 if interval == 0.0 else interval  # Avoid zero interval
+
+    # Enlarge axis
+    axis_min = min_value - interval * 0.05
+    axis_max = max_value + interval * 0.05
+
+    import plotext as plt
+
+    plt.clear_data()
+    plt.xlim(axis_min, axis_max)
+    plt.ylim(axis_min, axis_max)
+    plt.plotsize(width=50, height=25)
+    plt.scatter(x_np, y_np, marker="dot")
+    plt.theme("clear")
+
+    return plt.build()
+
+
+def ensure_list(inputs: Any | Tuple[Any] | List[Any]) -> List[Any]:
+    """
+    Ensures that the given input is converted into a list.
+
+    - If the input is a single element, it wraps it into a list.
+    - If the input is a tuple, it converts the tuple to a list.
+    - If the input is already a list, it returns the input unchanged.
+
+    Example
+    --------
+    >>> ensure_list(42)
+    [42]
+    >>> ensure_list((1, 2, 3))
+    [1, 2, 3]
+    >>> ensure_list([4, 5, 6])
+    [4, 5, 6]
+    """
+    if isinstance(inputs, list):
+        return inputs
+    elif isinstance(inputs, tuple):
+        return list(inputs)
+    else:
+        return [inputs]
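For orientation, here is a minimal round-trip sketch of the `quantize`/`dequantize` helpers above (a usage sketch assuming the `tico` package and its `circle_schema` dependency are installed; the values are illustrative, not taken from the package):

```python
# Round-trip sketch for the evaluation helpers (illustrative values only).
import numpy as np

from tico.experimental.quantization.evaluation.utils import dequantize, quantize

data = np.array([-1.0, 0.0, 0.5, 1.0], dtype=np.float32)
scale, zero_point = 0.01, 128  # uint8 affine quantization parameters

q = quantize(data, scale, zero_point, np.uint8)  # round(x / s) + zp, clipped to [0, 255]
x = dequantize(q, scale, zero_point, np.uint8)   # (q - zp) * s, back to float32

print(q)                                    # [ 28 128 178 228]
print(np.abs(x - data).max() <= scale / 2)  # True: rounding error is at most s / 2
```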
tico/experimental/quantization/passes/__init__.py
@@ -0,0 +1 @@
+# DO NOT REMOVE THIS FILE
tico/experimental/quantization/passes/fold_quant_ops.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch.export import ExportedProgram
+
+from tico.serialize.quant_param import QPARAM_KEY, QuantParam, to_qparam_dtype
+from tico.utils import logging
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import (
+    DequantizePerTensorArgs,
+    QuantizePerTensorArgs,
+)
+
+
+@trace_graph_diff_on_pass
+class FoldQuantOps(PassBase):
+    """
+    This pass folds the (Q - DQ) pattern into the preceding op. After quantization in torch,
+    activation ops have an (op - Q - DQ) pattern.
+
+    To export a quantized circle model, this pass removes the (Q - DQ) nodes and saves their
+    quantization info to the preceding op's metadata.
+
+    [BEFORE]
+        op (float) - Quantize - Dequantize - (float)
+
+    [AFTER]
+        op (float with meta[QPARAM_KEY])
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph: torch.fx.Graph = graph_module.graph
+        for dq in graph.nodes:
+            if dq.op != "call_function":
+                continue
+            if (
+                dq.target
+                != torch.ops.quantized_decomposed.dequantize_per_tensor.default
+            ):
+                continue
+            dq_args = DequantizePerTensorArgs(*dq.args, **dq.kwargs)
+
+            q = dq_args.input
+            if q.target != torch.ops.quantized_decomposed.quantize_per_tensor.default:
+                continue
+            q_args = QuantizePerTensorArgs(*q.args, **q.kwargs)  # type: ignore[arg-type]
+            op = q_args.tensor
+
+            # Check if Q and DQ have the same quant param
+            if q_args.scale != dq_args.scale:
+                continue
+            if q_args.zero_p != dq_args.zero_point:
+                continue
+            if q_args.dtype != dq_args.dtype:
+                continue
+
+            if QPARAM_KEY not in op.meta:
+                qparam = QuantParam()
+                qparam.scale = [q_args.scale]
+                qparam.zero_point = [q_args.zero_p]
+                assert "val" in q.meta and hasattr(q.meta["val"], "dtype")
+                qparam.dtype = to_qparam_dtype(q.meta["val"].dtype)
+                op.meta[QPARAM_KEY] = qparam
+
+            dq.replace_all_uses_with(op, propagate_meta=False)
+
+            logger.debug(f"{q.name} and {dq.name} are folded to {op.name}.")
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        # Run only once.
+        return PassResult(False)
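Why is deleting the pair safe? A Quantize-Dequantize pair with matching scale, zero point, and dtype only snaps a value onto the quantization grid, and snapping twice changes nothing, so once the qparam is recorded in the producer's `meta[QPARAM_KEY]` the pair itself carries no extra information. A small standalone check of that claim (illustrative values, reusing the evaluation helpers shown earlier; not part of the package):

```python
# Q-DQ with matching qparams maps any value to its nearest grid point;
# a second application is a no-op, which is what makes the fold lossless.
import numpy as np

from tico.experimental.quantization.evaluation.utils import dequantize, quantize

scale, zero_point = 0.02, 128
x = np.array([0.123, -0.456, 0.789], dtype=np.float32)

once = dequantize(quantize(x, scale, zero_point, np.uint8), scale, zero_point, np.uint8)
twice = dequantize(quantize(once, scale, zero_point, np.uint8), scale, zero_point, np.uint8)

print(once)                         # values snapped to multiples of the scale
print(np.array_equal(once, twice))  # True: Q-DQ is idempotent
```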
tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py
@@ -0,0 +1,289 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import copy
+
+import torch
+from torch.export import ExportedProgram
+
+from tico.serialize.quant_param import QPARAM_KEY, QuantParam
+from tico.utils import logging
+from tico.utils.errors import NotYetSupportedError
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.utils import quant_min_max, set_new_meta_val
+from tico.utils.validate_args_kwargs import (
+    BmmArgs,
+    LinearArgs,
+    MulTensorArgs,
+    PermuteArgs,
+)
+
+
+def qparam_dtype(node: torch.fx.Node) -> str:
+    assert QPARAM_KEY in node.meta
+    return node.meta[QPARAM_KEY].dtype
+
+
+# Convert i16 qparam to u8 qparam
+# scale and zero_point are inferred from the i16 qparam
+def _i16_to_u8(qparam: QuantParam) -> QuantParam:
+    # Assume per-tensor quantization
+    assert qparam.scale is not None and len(qparam.scale) == 1
+    assert qparam.dtype == "int16"
+
+    s16_scale = qparam.scale[0]
+    max_ = s16_scale * 32767  # numeric_limits<int16>
+    min_ = -max_
+
+    u8_scale = (max_ - min_) / 255
+    u8_zerop = round(-min_ / u8_scale)
+
+    new_qparam = QuantParam()
+    new_qparam.scale = [u8_scale]
+    new_qparam.zero_point = [u8_zerop]
+    new_qparam.dtype = "uint8"
+
+    return new_qparam
+
+
+# Convert u8 qparam to i16 qparam
+# scale is inferred from the u8 qparam
+def _u8_to_i16(qparam: QuantParam) -> QuantParam:
+    # Assume per-tensor quantization
+    assert qparam.scale is not None and len(qparam.scale) == 1
+    assert qparam.zero_point is not None and len(qparam.zero_point) == 1
+    assert qparam.dtype == "uint8"
+
+    u8_scale = qparam.scale[0]
+    u8_zerop = qparam.zero_point[0]
+    max_ = u8_scale * (255 - u8_zerop)
+    min_ = u8_scale * (-u8_zerop)
+
+    abs_max = max([max_, min_], key=abs)
+    s16_scale = abs_max / 32767
+    s16_zerop = 0
+
+    new_qparam = QuantParam()
+    new_qparam.scale = [s16_scale]
+    new_qparam.zero_point = [s16_zerop]
+    new_qparam.dtype = "int16"
+
+    return new_qparam
+
+
+@trace_graph_diff_on_pass
+class InsertQuantizeOnDtypeMismatch(PassBase):
+    """
+    Insert a Quantize op where circle's type inference would otherwise be violated.
+    Example. FullyConnected
+    [BEFORE]
+      Op (uint8) - aten.linear.default (int16)
+    [AFTER]
+      Op (uint8) - aten.linear.default (uint8) - quantized_decomposed.quantize_per_tensor.default (int16)
+    Why is this pass necessary?
+    - For some operators, circle's type inference pass overwrites the output's dtype with
+      the input's dtype. In the above example, the fully-connected layer's (aten.linear.default)
+      output dtype (int16) would be updated to the input dtype (uint8), which breaks the semantics.
+      This problem can occur in tools (e.g., circle2circle) that automatically apply type inference.
+    - To resolve the issue, we insert Quantize operators so that circle's type inference logic
+      is not violated.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph: torch.fx.Graph = graph_module.graph
+
+        def _insert_quantize_op_before(node, inp):
+            qparam: QuantParam = node.meta[QPARAM_KEY]
+            assert qparam.scale is not None
+            assert qparam.zero_point is not None
+            scale = qparam.scale[0]
+            zerop = qparam.zero_point[0]
+            min_, max_ = quant_min_max(qparam.dtype)
+            dtype = getattr(torch, qparam.dtype)
+
+            with graph.inserting_before(node):
+                q_args = (inp, scale, zerop, min_, max_, dtype)
+                quantize = graph.call_function(
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    args=q_args,
+                )
+                quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
+                set_new_meta_val(quantize)
+
+            node.replace_input_with(inp, quantize)
+
+            return quantize
+
+        def _insert_quantize_op_after(node):
+            qparam: QuantParam = node.meta[QPARAM_KEY]
+            assert qparam.scale is not None
+            assert qparam.zero_point is not None
+            scale = qparam.scale[0]
+            zerop = qparam.zero_point[0]
+            min_, max_ = quant_min_max(qparam.dtype)
+            dtype = getattr(torch, qparam.dtype)
+            with graph.inserting_after(node):
+                q_args = (node, scale, zerop, min_, max_, dtype)
+                quantize = graph.call_function(
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                    args=q_args,
+                )
+
+            node.replace_all_uses_with(quantize, propagate_meta=True)
+            quantize.replace_input_with(quantize, node)
+
+            quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
+
+            return quantize
+
+        for node in graph.nodes:
+            if node.op != "call_function":
+                continue
+            if node.target == torch.ops.aten.linear.default:
+                lin_args = LinearArgs(*node.args, **node.kwargs)
+                inp = lin_args.input
+
+                if QPARAM_KEY not in inp.meta:
+                    continue
+
+                if QPARAM_KEY not in node.meta:
+                    continue
+
+                if qparam_dtype(inp) == qparam_dtype(node):
+                    continue
+
+                if qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
+                    quantize = _insert_quantize_op_after(node)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+
+                    # Update node's qparam from i16 to u8
+                    # NOTE This would severely degrade accuracy. It is
+                    # important to mitigate this accuracy drop in the backend.
+                    node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted after {node.name}."
+                    )
+                else:
+                    raise NotYetSupportedError("Unsupported dtype")
+
+            elif node.target == torch.ops.aten.mul.Tensor:
+                mul_args = MulTensorArgs(*node.args, **node.kwargs)
+                x = mul_args.input
+                y = mul_args.other
+
+                if not isinstance(x, torch.fx.Node):
+                    continue
+                if not isinstance(y, torch.fx.Node):
+                    continue
+
+                if QPARAM_KEY not in x.meta:
+                    continue
+                if QPARAM_KEY not in y.meta:
+                    continue
+                if QPARAM_KEY not in node.meta:
+                    continue
+
+                if qparam_dtype(x) == qparam_dtype(node):
+                    continue
+
+                if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
+                    quantize = _insert_quantize_op_after(node)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                    node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted after {node.name}."
+                    )
+                else:
+                    raise NotYetSupportedError("Unsupported dtype")
+
+            elif node.target == torch.ops.aten.bmm.default:
+                bmm_args = BmmArgs(*node.args, **node.kwargs)
+                x = bmm_args.input
+                y = bmm_args.mat2
+
+                if QPARAM_KEY not in x.meta:
+                    continue
+                if QPARAM_KEY not in y.meta:
+                    continue
+                if QPARAM_KEY not in node.meta:
+                    continue
+
+                if qparam_dtype(x) == qparam_dtype(node):
+                    continue
+
+                if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
+                    quantize = _insert_quantize_op_after(node)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                    node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted after {node.name}."
+                    )
+                else:
+                    raise NotYetSupportedError("Unsupported dtype")
+
+            elif node.target == torch.ops.aten.permute.default:
+                per_args = PermuteArgs(*node.args, **node.kwargs)
+                inp = per_args.input
+
+                if QPARAM_KEY not in inp.meta:
+                    continue
+
+                if QPARAM_KEY not in node.meta:
+                    continue
+
+                if qparam_dtype(inp) == qparam_dtype(node):
+                    continue
+
+                if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
+                    # A new Quantize Op (s16 to u8) is inserted before (not after)
+                    # the permute Op to reduce the tensor size earlier
+                    quantize = _insert_quantize_op_before(node, inp)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                    node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted before {node.name}."
+                    )
+                elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
+                    quantize = _insert_quantize_op_after(node)
+
+                    quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+                    node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
+                    logger.debug(
+                        f"quantize_per_tensor.default is inserted after {node.name}."
+                    )
+                else:
+                    raise NotYetSupportedError("Unsupported dtype")
+
+            # TODO Support more ops.
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        # Run only once.
+        return PassResult(False)
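The `_i16_to_u8` and `_u8_to_i16` helpers above are plain interval arithmetic over the two representable ranges. A standalone re-derivation with illustrative numbers (this mirrors the helpers rather than importing the private functions):

```python
# int16 qparams are symmetric (zero_point == 0), so the representable range
# is [-s16 * 32767, s16 * 32767]; uint8 re-covers that range with only 255
# steps and a mid-range zero point.
s16_scale = 0.001
max_ = s16_scale * 32767            # int16 covers [-32.767, 32.767]
min_ = -max_

u8_scale = (max_ - min_) / 255      # ~0.257: a much coarser grid
u8_zerop = round(-min_ / u8_scale)  # 128, i.e. mid-range

# And back: the uint8 range [s * (0 - zp), s * (255 - zp)] is mapped onto a
# symmetric int16 grid via its largest absolute bound.
abs_max = max(abs(u8_scale * (255 - u8_zerop)), abs(u8_scale * (0 - u8_zerop)))
s16_back = abs_max / 32767          # ~0.00100, close to the original scale

print(u8_scale, u8_zerop, s16_back)
```

Going from 65534 int16 steps to 255 uint8 steps over the same range is exactly the accuracy drop the NOTE in the linear branch warns about.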
tico/experimental/quantization/passes/propagate_qparam_backward.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import copy
+
+import torch
+from torch.export import ExportedProgram
+
+from tico.serialize.quant_param import QPARAM_KEY
+from tico.utils import logging
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import CatArgs, PermuteArgs, ReshapeArgs
+
+
+@trace_graph_diff_on_pass
+class PropagateQParamBackward(PassBase):
+    """
+    This pass propagates quantization parameters backward.
+
+    BEFORE)
+
+    node -> reshape (with meta[QPARAM_KEY])
+
+    AFTER)
+
+    node (with meta[QPARAM_KEY]) -> reshape (with meta[QPARAM_KEY])
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph: torch.fx.Graph = graph_module.graph
+
+        def _propagate_qparam_if_possible(src: torch.fx.Node, dst: torch.fx.Node):
+            if QPARAM_KEY not in src.meta:
+                return
+
+            if (
+                QPARAM_KEY in dst.meta
+                and src.meta[QPARAM_KEY].dtype != dst.meta[QPARAM_KEY].dtype
+            ):
+                return
+
+            dst.meta[QPARAM_KEY] = copy.deepcopy(src.meta[QPARAM_KEY])
+
+            logger.debug(f"{src.name}'s quantparam is propagated to {dst.name}.")
+
+        # Do reverse-order traversal for backward propagation
+        for node in reversed(graph.nodes):
+            if node.op != "call_function":
+                continue
+            if node.target == torch.ops.aten.cat.default:
+                concat_args = CatArgs(*node.args, **node.kwargs)
+                concat_inputs = concat_args.tensors
+
+                for concat_input in concat_inputs:
+                    _propagate_qparam_if_possible(node, concat_input)
+            elif node.target == torch.ops.aten.reshape.default:
+                args = ReshapeArgs(*node.args, **node.kwargs)
+                _propagate_qparam_if_possible(node, args.input)
+            elif node.target == torch.ops.aten.permute.default:
+                permute_args = PermuteArgs(*node.args, **node.kwargs)
+                _propagate_qparam_if_possible(node, permute_args.input)
+            # TODO Support more ops.
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        # Run only once.
+        return PassResult(False)
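The core of the pass is `_propagate_qparam_if_possible`: copy the consumer's qparam onto its producer unless the producer already carries a qparam with a different dtype. A standalone sketch of that rule (illustrative: plain dicts stand in for `node.meta`, and the string key is a stand-in for the package's `QPARAM_KEY` constant):

```python
# Backward-propagation rule in isolation (not the package's code path).
import copy

QPARAM = "qparam"  # stand-in for tico.serialize.quant_param.QPARAM_KEY

def propagate_if_possible(src_meta: dict, dst_meta: dict) -> None:
    if QPARAM not in src_meta:
        return  # nothing to propagate
    if QPARAM in dst_meta and src_meta[QPARAM]["dtype"] != dst_meta[QPARAM]["dtype"]:
        return  # dtype conflict: keep the destination's own qparam
    dst_meta[QPARAM] = copy.deepcopy(src_meta[QPARAM])

# A reshape that already carries a qparam (e.g. after FoldQuantOps) ...
reshape_meta = {QPARAM: {"scale": [0.02], "zero_point": [128], "dtype": "uint8"}}
input_meta: dict = {}

# ... hands it backward to its input, so producer and consumer agree.
propagate_if_possible(reshape_meta, input_meta)
print(input_meta[QPARAM]["dtype"])  # uint8
```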