tico 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +42 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +191 -0
- tico/passes/cast_mixed_type_args.py +187 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +160 -0
- tico/passes/convert_layout_op_to_reshape.py +85 -0
- tico/passes/convert_repeat_to_expand_copy.py +89 -0
- tico/passes/convert_to_relu6.py +181 -0
- tico/passes/decompose_addmm.py +124 -0
- tico/passes/decompose_batch_norm.py +192 -0
- tico/passes/decompose_fake_quantize.py +134 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
- tico/passes/decompose_group_norm.py +275 -0
- tico/passes/decompose_grouped_conv2d.py +209 -0
- tico/passes/decompose_slice_scatter.py +169 -0
- tico/passes/extract_dtype_kwargs.py +122 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +108 -0
- tico/passes/legalize_predefined_layout_operators.py +386 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
- tico/passes/lower_to_slice.py +230 -0
- tico/passes/merge_consecutive_cat.py +80 -0
- tico/passes/ops.py +78 -0
- tico/passes/remove_nop.py +84 -0
- tico/passes/remove_redundant_assert_nodes.py +51 -0
- tico/passes/remove_redundant_expand.py +66 -0
- tico/passes/remove_redundant_permute.py +122 -0
- tico/passes/remove_redundant_reshape.py +436 -0
- tico/passes/remove_redundant_slice.py +62 -0
- tico/passes/remove_redundant_to_copy.py +86 -0
- tico/passes/restore_linear.py +115 -0
- tico/passes/segment_index_select.py +145 -0
- tico/pt2_to_circle.py +105 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +319 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +240 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_abs.py +53 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +150 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +192 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +126 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +186 -0
- tico/serialize/operators/op_copy.py +164 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +95 -0
- tico/serialize/operators/op_depthwise_conv2d.py +199 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_leaky_relu.py +60 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +86 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_dim.py +70 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +177 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +141 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +100 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +99 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +94 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +296 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +282 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/mx/__init__.py +1 -0
- tico/utils/mx/elemwise_ops.py +267 -0
- tico/utils/mx/formats.py +125 -0
- tico/utils/mx/mx_ops.py +270 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +609 -0
- tico/utils/serialize.py +42 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +406 -0
- tico/utils/validate_args_kwargs.py +1149 -0
- tico-0.1.0.dist-info/LICENSE +241 -0
- tico-0.1.0.dist-info/METADATA +354 -0
- tico-0.1.0.dist-info/RECORD +206 -0
- tico-0.1.0.dist-info/WHEEL +5 -0
- tico-0.1.0.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,181 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.utils import logging
|
23
|
+
from tico.utils.graph import create_node
|
24
|
+
from tico.utils.passes import PassBase, PassResult
|
25
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
26
|
+
from tico.utils.validate_args_kwargs import ClampArgs, HardTanhArgs
|
27
|
+
|
28
|
+
|
29
|
+
class Converter:  # type: ignore[empty-body]
    """Base interface for one "rewrite this node as ReLU6" rule.

    The defaults defined here never match and never rewrite anything;
    concrete rules override both ``match`` and ``convert``.
    """

    def __init__(self):
        super().__init__()

    def match(self, node) -> bool:  # type: ignore[empty-body]
        """Tell whether ``node`` is a pattern this rule rewrites.

        The base rule matches nothing, so it always answers ``False``.
        """
        return False

    def convert(self, exported_program, node) -> None:  # type: ignore[empty-body]
        """Rewrite ``node`` inside ``exported_program``.

        The base rule has nothing to rewrite and returns ``None``.
        """
        return None
|
38
|
+
|
39
|
+
|
40
|
+
class ConvertHardTanhToReLU6(Converter):
    """Rule that rewrites ``aten.hardtanh`` with bounds (0, 6) as ``aten.relu6``."""

    def __init__(self):
        super().__init__()

    def match(self, node) -> bool:
        """Accept only ``aten.hardtanh.default`` nodes whose bounds are 0 and 6."""
        is_hardtanh = node.target == torch.ops.aten.hardtanh.default
        if not is_hardtanh:
            return False

        hardtanh = HardTanhArgs(*node.args, **node.kwargs)
        lower = hardtanh.min_val
        upper = hardtanh.max_val

        # NOTE: int and float are both covered by pytorch implicit type conversion
        return lower == 0.0 and upper == 6.0

    def convert(self, exported_program, node):
        """Create ``relu6(x)`` right after ``node`` and move all users of ``node`` onto it."""
        graph = exported_program.graph_module.graph
        source = HardTanhArgs(*node.args, **node.kwargs).input

        with graph.inserting_after(node):
            relu6 = create_node(graph, torch.ops.aten.relu6.default, args=(source,))
            node.replace_all_uses_with(relu6, propagate_meta=True)
|
64
|
+
|
65
|
+
|
66
|
+
class ConvertClampToReLU6(Converter):
    """Rule that rewrites a single ``aten.clamp`` with bounds (0, 6) as ``aten.relu6``."""

    def __init__(self):
        super().__init__()

    def match(self, node) -> bool:
        """Accept only ``aten.clamp.default`` nodes whose ``min`` is 0 and ``max`` is 6."""
        is_clamp = node.target == torch.ops.aten.clamp.default
        if not is_clamp:
            return False

        clamp = ClampArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        lower = clamp.min
        upper = clamp.max

        # NOTE: int and float are both covered by pytorch implicit type conversion
        return lower == 0 and upper == 6

    def convert(self, exported_program, node):
        """Create ``relu6(x)`` right after ``node`` and move all users of ``node`` onto it."""
        graph = exported_program.graph_module.graph
        source = ClampArgs(*node.args, **node.kwargs).input  # type: ignore[arg-type]

        with graph.inserting_after(node):
            relu6 = create_node(graph, torch.ops.aten.relu6.default, args=(source,))
            node.replace_all_uses_with(relu6, propagate_meta=True)
|
90
|
+
|
91
|
+
|
92
|
+
class ConvertDoubleClampsToReLU6(Converter):
    """Rule that rewrites two chained ``aten.clamp`` nodes as one ``aten.relu6``."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def _bounds(clamp_args):
        """Return ``(min, max)`` of a clamp, with a missing bound replaced by -inf / +inf."""
        lower = float("-inf") if clamp_args.min is None else clamp_args.min
        upper = float("inf") if clamp_args.max is None else clamp_args.max
        return lower, upper

    def match(self, node) -> bool:
        """
        Match a clamp fed by another clamp when the pair is equivalent to clamp(input, 0, 6).

                                         (equivalent)
               input                        input
                 |                            |
          node_prev (min, max)           node (0, 6)
                 |                            |
            node (min', max')                 |
                 |                            |
               output                       output

        *where max(min, min') == 0 and min(max, max') == 6, so the pair is
         equivalent to clamp(input, 0, 6)

        TODO Make this step more generic. For now we only support the case above.
        """
        clamp_op = torch.ops.aten.clamp.default
        if not node.target == clamp_op:
            return False

        outer = ClampArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        node_prev = outer.input
        if not node_prev.target == clamp_op:
            return False

        inner = ClampArgs(*node_prev.args, **node_prev.kwargs)  # type: ignore[arg-type]
        min_val, max_val = self._bounds(outer)
        min_val_prev, max_val_prev = self._bounds(inner)

        # NOTE: int and float are both covered by pytorch implicit type conversion
        return bool(
            max(min_val, min_val_prev) == 0 and min(max_val, max_val_prev) == 6
        )

    def convert(self, exported_program, node):
        """Create ``relu6`` on the first clamp's input and move all users of ``node`` onto it."""
        graph = exported_program.graph_module.graph

        outer = ClampArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        first_clamp = outer.input
        inner = ClampArgs(*first_clamp.args, **first_clamp.kwargs)  # type: ignore[arg-type]
        source = inner.input

        with graph.inserting_after(node):
            relu6 = create_node(graph, torch.ops.aten.relu6.default, args=(source,))
            node.replace_all_uses_with(relu6, propagate_meta=True)
|
146
|
+
|
147
|
+
|
148
|
+
@trace_graph_diff_on_pass
class ConvertToReLU6(PassBase):
    """Pass that applies the ReLU6 rewrite rules to every ``call_function`` node.

    The rules are tried in the order they are listed in ``self.converters``;
    only the first rule that matches a node is applied to it.
    """

    def __init__(self):
        super().__init__()
        self.converters: List[Converter] = [
            ConvertHardTanhToReLU6(),
            ConvertClampToReLU6(),
            ConvertDoubleClampsToReLU6(),
        ]

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Rewrite matching nodes and report whether at least one was rewritten."""
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False

        for node in graph.nodes:
            if node.op != "call_function":
                continue

            # First rule (in declaration order) that accepts this node, if any.
            rule = next((c for c in self.converters if c.match(node)), None)
            if rule is None:
                continue

            rule.convert(exported_program, node)
            modified = True
            logger.debug(f"{node.name} is replaced with ReLU6 operator")

        # The replaced nodes no longer have users; drop them and rebuild the module code.
        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
@@ -0,0 +1,124 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.serialize.circle_mapping import extract_shape
|
23
|
+
from tico.utils import logging
|
24
|
+
from tico.utils.graph import add_placeholder, create_node
|
25
|
+
from tico.utils.passes import PassBase, PassResult
|
26
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
|
+
from tico.utils.utils import is_target_node, set_new_meta_val
|
28
|
+
from tico.utils.validate_args_kwargs import AddmmArgs
|
29
|
+
|
30
|
+
|
31
|
+
@trace_graph_diff_on_pass
class DecomposeAddmm(PassBase):
    """
    Decompose addmm into mm, mul and add nodes.

        out = beta * input + alpha * (mat1 @ mat2)

    [BEFORE]

      input    mat1    mat2    beta    alpha
        |       |       |       |       |
        --------------addmm--------------
                        |
                       out

    [AFTER]

      input  beta      mat1  mat2     alpha
        |     |          |    |         |
        ---mul---        ---mm---       |
            |                |          |
            |                -----mul-----
            |                      |
            ---------add------------
                      |
                     out

    A scale factor equal to 1 skips its mul; a scale factor equal to 0 replaces
    the scaled operand with a zero tensor added through ``add_placeholder``.
    """

    def __init__(self):
        super().__init__()

    def _scale(self, exported_program, graph, origin, operand, factor, zeros_name):
        """Return ``operand`` scaled by ``factor``.

        ``factor == 1`` returns ``operand`` itself, ``factor == 0`` returns a zero
        tensor of ``operand``'s shape registered under ``zeros_name``, and any
        other value returns a new ``aten.mul.Tensor`` node.
        """
        if factor == 1:
            return operand
        if factor == 0:
            return add_placeholder(
                exported_program,
                torch.zeros(extract_shape(operand)),
                zeros_name,
            )
        return create_node(
            graph, torch.ops.aten.mul.Tensor, (operand, factor), origin=origin
        )

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Replace every ``aten.addmm.default`` node; report whether any was replaced."""
        # Not used below; the call is kept so the pass behaves exactly as before.
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        graph: torch.fx.Graph = gm.graph
        modified = False

        for node in graph.nodes:
            if not is_target_node(node, torch.ops.aten.addmm.default):
                continue

            addmm = AddmmArgs(*node.args, **node.kwargs)

            with graph.inserting_before(node):
                # out = beta * input + alpha * (mat1 @ mat2)
                product = create_node(
                    graph,
                    torch.ops.aten.mm.default,
                    (addmm.mat1, addmm.mat2),
                    origin=node,
                )
                # Set before scaling: the alpha == 0 path reads the product's shape.
                set_new_meta_val(product)

                scaled_input = self._scale(
                    exported_program,
                    graph,
                    node,
                    addmm.input,
                    addmm.beta,
                    f"{node.name}_beta_zeros",
                )
                scaled_product = self._scale(
                    exported_program,
                    graph,
                    node,
                    product,
                    addmm.alpha,
                    f"{node.name}_alpha_zeros",
                )

                total = create_node(
                    graph, torch.ops.aten.add.Tensor, (scaled_input, scaled_product)
                )

            node.replace_all_uses_with(total, propagate_meta=True)
            modified = True

        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()

        return PassResult(modified)
|
@@ -0,0 +1,192 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.serialize.circle_mapping import extract_shape
|
23
|
+
from tico.utils import logging
|
24
|
+
from tico.utils.errors import NotYetSupportedError
|
25
|
+
from tico.utils.graph import (
|
26
|
+
add_placeholder,
|
27
|
+
create_node,
|
28
|
+
get_first_user_input,
|
29
|
+
get_torch_buffer_value,
|
30
|
+
get_torch_param_value,
|
31
|
+
is_torch_buffer,
|
32
|
+
is_torch_param,
|
33
|
+
)
|
34
|
+
from tico.utils.passes import PassBase, PassResult
|
35
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
36
|
+
from tico.utils.utils import is_target_node
|
37
|
+
from tico.utils.validate_args_kwargs import NativeBatchNormLegitNoTrainingArgs
|
38
|
+
|
39
|
+
|
40
|
+
@trace_graph_diff_on_pass
class DecomposeBatchNorm(PassBase):
    """
    [BatchNorm]

    With fixed running statistics the op folds into one aten.mul followed by
    one aten.add:

        W = weight / sqrt(var + eps)
        B = bias - (mean * weight) / sqrt(var + eps)
        Y = X * W + B

    [before]

        input (tensor, weight, bias, running_mean, running_var, momentum, eps)
          |
        BatchNorm
          |
        output

    [after]

        input
        (tensor)
          |   W
          |  /
         mul
          |   B
          |  /
         add
          |
        output
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Fold supported batch-norm nodes into mul + add; report whether any was folded.

        Raises:
            NotYetSupportedError: if ``running_mean`` or ``running_var`` is missing.
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        graph: torch.fx.Graph = gm.graph
        modified = False

        batch_norm_op = torch.ops.aten._native_batch_norm_legit_no_training.default
        for node in graph.nodes:
            if not is_target_node(node, batch_norm_op):
                continue

            bn = NativeBatchNormLegitNoTrainingArgs(*node.args)
            input_ = bn.input
            weight = bn.weight
            bias = bn.bias
            running_mean = bn.running_mean
            running_var = bn.running_var
            eps = bn.eps

            if not running_mean:
                raise NotYetSupportedError("running_mean=None is not supported yet")
            if not running_var:
                raise NotYetSupportedError("running_var=None is not supported yet")

            # Only the cases generated from the torch.nn.BatchNorm2d module are
            # supported: weight and bias (when given) must be parameters, and
            # running_mean and running_var must be buffers. Anything else is
            # left untouched.
            if weight and not is_torch_param(weight, exported_program):
                continue
            if bias and not is_torch_param(bias, exported_program):
                continue
            if not is_torch_buffer(running_mean, exported_program):
                continue
            if not is_torch_buffer(running_var, exported_program):
                continue

            input_shape = extract_shape(input_)
            assert len(input_shape) == 4
            C = input_shape[1]

            # A missing weight acts as all ones, a missing bias as all zeros.
            if weight:
                weight_value = get_torch_param_value(weight, exported_program)
            else:
                weight_value = torch.tensor([1] * C)
            if bias:
                bias_value = get_torch_param_value(bias, exported_program)
            else:
                bias_value = torch.tensor([0] * C)
            mean_value = get_torch_buffer_value(running_mean, exported_program)
            var_value = get_torch_buffer_value(running_var, exported_program)

            for value in (weight_value, bias_value, mean_value, var_value):
                assert isinstance(value, torch.Tensor)

            assert (
                weight_value.shape
                == bias_value.shape
                == mean_value.shape
                == var_value.shape
            )
            # Calculate constants for mul and add
            mul_const = weight_value / torch.sqrt(var_value + eps)
            add_const = bias_value - (mul_const * mean_value)
            # N, C, H, W
            assert len(mul_const) == len(add_const) == C
            # reshape along with channel dimension
            mul_const = mul_const.view(1, mul_const.shape[0], 1, 1)
            add_const = add_const.view(1, add_const.shape[0], 1, 1)

            # Placeholder nodes must be the first N nodes in the nodes list of a graph.
            # Therefore, insert the newly created placeholders at the start of the node list.
            first_user_input = get_first_user_input(exported_program)
            with exported_program.graph.inserting_before(first_user_input):
                mul_const_node = add_placeholder(
                    exported_program,
                    mul_const,
                    prefix=f"{node.name}_mul_const",
                )
                add_const_node = add_placeholder(
                    exported_program,
                    add_const,
                    prefix=f"{node.name}_add_const",
                )

            with gm.graph.inserting_before(node):
                mul = create_node(
                    graph,
                    torch.ops.aten.mul.Tensor,
                    args=(input_, mul_const_node),
                    origin=node,
                )
                add = create_node(
                    graph,
                    torch.ops.aten.add.Tensor,
                    args=(mul, add_const_node),
                )
                # NOTE(review): takes the first user of the batch-norm node and
                # presumably expects it to be the getitem selecting the normalized
                # output; this raises if the node has no users - confirm.
                get_item, *_ = node.users.keys()
                get_item.replace_all_uses_with(add, propagate_meta=True)

            logger.debug(f"{node.name} is decomposed to {mul.name} and {add.name}")
            modified = True

        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()

        return PassResult(modified)
|
@@ -0,0 +1,134 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
|
21
|
+
# To import torch.ops.quantized_decomposed related operator
|
22
|
+
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
|
23
|
+
from torch.export import ExportedProgram
|
24
|
+
|
25
|
+
from tico.utils import logging
|
26
|
+
from tico.utils.graph import create_node
|
27
|
+
from tico.utils.passes import PassBase, PassResult
|
28
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
29
|
+
from tico.utils.validate_args_kwargs import FakeQuantizePerChannelArgs
|
30
|
+
|
31
|
+
|
32
|
+
def get_quant_type(min: int, max: int) -> torch.dtype:
    """Map a (quant_min, quant_max) pair to the torch dtype used to store it.

    Supported ranges are 0..15 and 0..255 (both ``torch.uint8``) and
    -32768..32767 and -32767..32767 (both ``torch.int16``).

    Raises:
        RuntimeError: if the pair is not one of the supported ranges.
    """
    supported = (
        # torch can't represent "uint4".
        # Let's set torch.uint8 and infer dtype with quant_min/quant_max instead.
        (0, 15, torch.uint8),
        (0, 255, torch.uint8),
        (-32768, 32767, torch.int16),
        (-32767, 32767, torch.int16),
    )
    for lower, upper, dtype in supported:
        if min == lower and max == upper:
            return dtype

    raise RuntimeError(f"Not supported min/max values: {min}/{max}")
|
45
|
+
|
46
|
+
|
47
|
+
@trace_graph_diff_on_pass
class DecomposeFakeQuantize(PassBase):
    """
    Decompose fake quantize operator to quant/dequant operators.
    Otherwise, it can't be converted to the edge IR because fake quantize operator is not Aten Canonical.

    [Before]
    def forward(self, x):
        fake_quantize_per_tensor_affine = torch.ops.aten.fake_quantize_per_tensor_affine.default(tensor, scale, zero_p, quant_min, quant_max);  x = None
        return (fake_quantize_per_tensor_affine,)

    [After]
    def forward(self, x):
        quantize_per_tensor_default = torch.ops.quantized_decomposed.quantize_per_tensor.default(tensor, scale, zero_p, quant_min, quant_max, dtype = ${torch.dtype});  x = None
        dequantize_per_tensor_default = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor_default, scale, zero_p, quant_min, quant_max, dtype = ${torch.dtype});  quantize_per_tensor_default = None
        return (dequantize_per_tensor_default,)

    The per-channel fake quantize operator is handled the same way with the
    per-channel quantize/dequantize operators.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Split each fake-quantize node into quantize + dequantize; report whether any was split."""
        # Not used below; the call is kept so the pass behaves exactly as before.
        logger = logging.getLogger(__name__)
        modified = False

        gm = exported_program.graph_module
        g = gm.graph
        qd = torch.ops.quantized_decomposed  # type: ignore[return]

        per_tensor_targets = (torch.ops.aten.fake_quantize_per_tensor_affine.default,)
        per_channel_targets = (torch.ops.aten.fake_quantize_per_channel_affine.default,)

        for node in gm.graph.nodes:
            if node.op != "call_function":
                continue

            if node.target in per_tensor_targets:
                # tensor, scale, zero_p, quant_min, quant_max
                assert len(node.args) == 5
                quant_min, quant_max = node.args[3], node.args[4]
                dtype = get_quant_type(quant_min, quant_max)
                quantize_op = qd.quantize_per_tensor.default
                dequantize_op = qd.dequantize_per_tensor.default
            elif node.target in per_channel_targets:
                fq_args = FakeQuantizePerChannelArgs(*node.args, **node.kwargs)
                dtype = get_quant_type(fq_args.quant_min, fq_args.quant_max)
                quantize_op = qd.quantize_per_channel.default
                dequantize_op = qd.dequantize_per_channel.default
            else:
                continue

            # The decomposed operators take the storage dtype as an extra keyword.
            quant_kwargs = {**node.kwargs, "dtype": dtype}
            with gm.graph.inserting_before(node):
                quant = create_node(
                    g,
                    quantize_op,
                    args=node.args,
                    kwargs=quant_kwargs,
                    origin=node,
                )
                # Dequantize reuses the quantize arguments, with the quantized
                # node in place of the original tensor.
                dequant = create_node(
                    g,
                    dequantize_op,
                    args=(quant, *quant.args[1:]),
                    kwargs=quant.kwargs,
                )
            node.replace_all_uses_with(dequant, propagate_meta=True)
            modified = True

        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()

        return PassResult(modified)
|