tico 0.1.0.dev250411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
tico/passes/const_prop_pass.py
@@ -0,0 +1,307 @@
+# Portions of this file are adapted from code originally authored by
+# Meta Platforms, Inc. and affiliates, licensed under the BSD-style
+# license found in the LICENSE file in the root directory of their source tree.
+
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# https://github.com/pytorch/executorch/blob/61ddee5/exir/passes/constant_prop_pass.py
+
+from collections import OrderedDict
+from typing import List, Mapping, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch._export.utils import (
+    get_buffer,
+    get_lifted_tensor_constant,
+    get_param,
+    is_buffer,
+    is_lifted_tensor_constant,
+    is_param,
+)
+from torch.export import ExportedProgram
+from torch.export.exported_program import InputKind, InputSpec
+from torch.utils import _pytree as pytree
+
+from tico.serialize.circle_graph import _PRIMITIVE_TYPES
+from tico.utils import logging
+from tico.utils.graph import create_input_spec, generate_fqn, get_first_user_input
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import (
+    trace_const_diff_on_pass,
+    trace_graph_diff_on_pass,
+)
+from tico.utils.utils import get_fake_mode
+
+
+def get_constant_placeholder_to_tensor_dict(
+    exported_program: ExportedProgram,
+) -> OrderedDict[torch.fx.Node, torch.Tensor]:
+    """
+    Returns a dictionary mapping each constant placeholder node to its constant tensor.
+    """
+    const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict()
+    graph_module = exported_program.graph_module
+    graph: torch.fx.Graph = graph_module.graph
+    for node in graph.nodes:
+        if node.op != "placeholder":
+            continue
+        tensor: Optional[torch.Tensor] = None
+        if is_param(exported_program, node):
+            tensor = get_param(exported_program, node)
+        elif is_buffer(exported_program, node):
+            tensor = get_buffer(exported_program, node)
+        elif is_lifted_tensor_constant(exported_program, node):
+            tensor = get_lifted_tensor_constant(exported_program, node)
+
+        if tensor is not None:
+            assert node not in const_node_to_tensor
+            const_node_to_tensor[node] = tensor
+
+    return const_node_to_tensor
+
+
+def has_constant_data(arg, const_node_to_tensor=None) -> bool:
+    """
+    Check if `arg` has constant data.
+
+    Assume that `const_node_to_tensor` is retrieved from the exported program.
+    When a node is a placeholder, the only way to check whether it is constant is to consult the exported program.
+    """
+    if isinstance(arg, (tuple, list)):
+        return all(has_constant_data(a, const_node_to_tensor) for a in arg)
+    elif isinstance(arg, dict):
+        return all(has_constant_data(a, const_node_to_tensor) for a in arg.values())
+    elif isinstance(
+        arg,
+        _PRIMITIVE_TYPES,
+    ):
+        return True
+    elif not isinstance(arg, torch.fx.Node):
+        return False
+    elif const_node_to_tensor is not None and arg in const_node_to_tensor:
+        return True
+
+    return False
+
+
+def get_data(
+    arg,
+    exported_program: ExportedProgram,
+    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
+):
+    if isinstance(arg, (tuple, list)):
+        return (get_data(x, exported_program, const_node_to_tensor) for x in arg)
+    elif isinstance(arg, _PRIMITIVE_TYPES):
+        return arg
+    elif arg in const_node_to_tensor:
+        return const_node_to_tensor[arg]
+    return None
+
+
+def propagate_constants(
+    exported_program: ExportedProgram,
+) -> OrderedDict[torch.fx.Node, torch.Tensor]:
+    """
+    Propagates constants and returns a dictionary mapping graph nodes to their constant tensors.
+    """
+    const_node_to_tensor = get_constant_placeholder_to_tensor_dict(exported_program)
+
+    graph_module = exported_program.graph_module
+    graph: torch.fx.Graph = graph_module.graph
+    for node in graph.nodes:
+        if node.op != "call_function":
+            continue
+        if node.target in [
+            torch.ops.quantized_decomposed.dequantize_per_channel.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+        ]:
+            continue
+        if not has_constant_data(
+            [node.args, node.kwargs],
+            const_node_to_tensor,
+        ):
+            continue
+
+        args_data, kwargs_data = pytree.tree_map(
+            lambda x: get_data(x, exported_program, const_node_to_tensor),
+            (node.args, node.kwargs),
+        )
+
+        # Propagate the constant because all of its args are constant tensors.
+        with torch.no_grad():
+            prop_constant_tensor = node.target(*args_data, **kwargs_data)
+        const_node_to_tensor[node] = prop_constant_tensor
+
+    return const_node_to_tensor
+
+
+def erase_constant_node(
+    exported_program: ExportedProgram,
+    node: torch.fx.Node,
+) -> None:
+    """
+    Remove the corresponding tensor from the param/constants dict.
+
+    Q) Isn't it necessary to remove a node from `inputs_to_parameters`, `inputs_to_lifted_tensor_constants`
+    and `inputs_to_buffers` as well? Why do they just call `get`?
+    A) They internally use `exported_program.graph_signature.input_specs`, and the `input_specs` are updated
+    at the end of the const_prop_pass.
+    """
+    signature = exported_program.graph_signature
+    if name := signature.inputs_to_parameters.get(node.name, None):
+        exported_program.state_dict.pop(name, None)
+    elif name := signature.inputs_to_lifted_tensor_constants.get(node.name, None):
+        exported_program.constants.pop(name, None)
+    elif name := signature.inputs_to_buffers.get(node.name, None):
+        exported_program.constants.pop(name, None)
+        exported_program.state_dict.pop(name, None)
+
+    # Remove from graph.
+    exported_program.graph.erase_node(node)
+
+
+def create_constant_placeholder(
+    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
+    exported_program: ExportedProgram,
+) -> List[torch.fx.Node]:
+    """
+    This function creates constant placeholder nodes for the given constant nodes (`const_node_to_tensor`) and replaces the original nodes with them.
+    """
+    placeholders = []
+
+    fake_mode = get_fake_mode(exported_program)
+    first_user_input = get_first_user_input(exported_program)
+    if not first_user_input:
+        # Placeholder nodes must be the first N nodes in the nodes list of a graph.
+        # Therefore, insert the newly created placeholders at the start of the node list.
+        assert exported_program.graph.nodes
+        first_node = list(exported_program.graph.nodes)[0]
+        first_user_input = first_node
+
+    # Iterate over nodes in reverse order to insert each created placeholder before `first_user_input`.
+    for node, prop_constant_tensor in reversed(const_node_to_tensor.items()):
+        if all(x in const_node_to_tensor for x in node.users):
+            # All users of this constant node are also constant, so we don't need to create a new constant node.
+            erase_constant_node(exported_program, node)
+            continue
+
+        if node.op == "placeholder":
+            continue
+
+        # Register `prop_constant_tensor` under a fresh fully-qualified name.
+        prop_constant_tensor_fqn = generate_fqn(
+            "_prop_tensor_constant", exported_program
+        )
+
+        # Insert a new placeholder node for the propagated constant tensor.
+        with exported_program.graph.inserting_before(first_user_input):
+            const_placeholder_node = exported_program.graph.placeholder(
+                prop_constant_tensor_fqn
+            )
+
+        # The key here must match the "target" arg of the InputSpec created for this node.
+        exported_program.constants[prop_constant_tensor_fqn] = prop_constant_tensor
+
+        # Replace the original node with the new constant node.
+        node.replace_all_uses_with(const_placeholder_node, propagate_meta=True)
+        exported_program.graph.erase_node(node)
+
+        # Update the metadata of the new placeholder node.
+        const_placeholder_node.meta["val"] = fake_mode.from_tensor(
+            prop_constant_tensor, static_shapes=True
+        )
+        const_placeholder_node.meta["val"].constant = prop_constant_tensor
+
+        placeholders.append(const_placeholder_node)
+
+    return placeholders
+
+
+def create_input_specs(
+    placeholders: List[torch.fx.Node],
+) -> dict[str, InputSpec]:
+    name_to_spec: dict[str, InputSpec] = {}
+
+    # https://pytorch.org/docs/stable/export.ir_spec.html#placeholder
+    # %name = placeholder[target = name](args = ())
+    for node in placeholders:
+        name_to_spec[node.name] = create_input_spec(node, InputKind.CONSTANT_TENSOR)
+
+    return name_to_spec
+
+
+@trace_graph_diff_on_pass
+@trace_const_diff_on_pass
+class ConstPropPass(PassBase):
+    """
+    Performs constant folding and constant propagation.
+
+    NOTE The exported program guarantees that parameters, buffers, and constant tensors are lifted out of the graph as inputs.
+    This means the pass needs to update the input specs after folding the constant nodes.
+    # ref: https://pytorch.org/docs/stable/export.html#torch.export.ExportGraphSignature
+
+    [WHAT IT DOES]
+    [1] Propagate the constants.
+    [2] Get propagated data from constant nodes.
+    [3] Create the constant placeholder nodes according to the propagated data.
+    [4] Create input specs according to the created placeholders.
+    [5] Update the input specs.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph: torch.fx.Graph = graph_module.graph
+
+        # [1], [2]
+        const_node_to_tensor: OrderedDict[
+            torch.fx.Node, torch.Tensor
+        ] = propagate_constants(exported_program)
+        # [3]
+        placeholders = create_constant_placeholder(
+            const_node_to_tensor, exported_program
+        )
+        # [4]
+        new_name_to_spec = create_input_specs(placeholders)
+
+        # [5]
+        # Get existing input specs.
+        existing_name_to_spec = {
+            s.arg.name: s for s in exported_program.graph_signature.input_specs
+        }
+        # Add the new constants to the existing input specs dict.
+        existing_name_to_spec.update(new_name_to_spec)
+        # Generate the new input specs.
+        new_input_specs = []
+        for node in exported_program.graph.nodes:
+            if node.op != "placeholder":
+                continue
+            assert node.name in existing_name_to_spec, node.name
+            new_input_specs.append(existing_name_to_spec[node.name])
+        exported_program.graph_signature.input_specs = new_input_specs
+
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+
+        logger.debug("Constant nodes are propagated")
+        # Constant folding completes in a single run, so set `modified` to False.
+        modified = False
+        return PassResult(modified)
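For orientation, here is a minimal standalone sketch of what this pass does to an exported program (illustrative only, not package code; the module `M` and its sizes are assumptions):

import torch
from torch.export import export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.ones(4))

    def forward(self, x):
        c = self.w * 2.0  # constant subgraph: lifted parameter * scalar
        return x + c

ep = export(M(), (torch.randn(4),))
# Before the pass, the `mul` appears as a call_function node whose args are
# all constant (a lifted parameter input and a scalar), so `has_constant_data`
# holds and `propagate_constants` evaluates it eagerly under `torch.no_grad`.
# The pass then stores the result under a "_prop_tensor_constant*" key in
# `ep.constants`, inserts a matching placeholder, and rebuilds the graph
# signature's input specs, i.e. steps [1]-[5] of the docstring above.
#
# Running the pass (with this package installed):
#   from tico.passes.const_prop_pass import ConstPropPass
#   ConstPropPass().call(ep)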
tico/passes/convert_conv1d_to_conv2d.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch.export import ExportedProgram
+
+from tico.serialize.circle_graph import extract_shape
+from tico.utils import logging
+from tico.utils.errors import NotYetSupportedError
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import Conv1DArgs
+
+
+@trace_graph_diff_on_pass
+class ConvertConv1dToConv2d(PassBase):
+    """
+    This pass converts `torch.ops.aten.conv1d` to `torch.ops.aten.conv2d`
+    because Circle does not support `conv1d`.
+
+    [before]
+
+          input            weight
+      (tensor,dim=3)   (tensor,dim=3)
+            |                |
+          conv1d<------------+
+            |
+          output
+      (tensor,dim=3)
+
+    [after]
+
+          input            weight
+      (tensor,dim=3)   (tensor,dim=3)
+            |                |
+        unsqueeze        unsqueeze
+         (dim=4)          (dim=4)
+            |                |
+          conv2d<------------+
+            |
+         squeeze
+         (dim=3)
+            |
+          output
+      (tensor,dim=3)
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> bool:
+        logger = logging.getLogger(__name__)
+        modified = False
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+
+        # conv1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] dilation=1, SymInt groups=1) -> Tensor
+        # conv1d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, str padding="valid", SymInt[1] dilation=1, SymInt groups=1) -> Tensor
+        args = Conv1DArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input = args.input
+        weight = args.weight
+        bias = args.bias
+        stride = args.stride
+        padding = args.padding
+        dilation = args.dilation
+        groups = args.groups
+
+        input_shape = extract_shape(input)
+        if len(input_shape) != 3:
+            raise NotYetSupportedError(
+                f"Only support 3D input tensor: node's input shape: {input_shape}"
+            )
+
+        with graph.inserting_after(input):
+            input_unsqueeze = graph_module.graph.call_function(
+                torch.ops.aten.unsqueeze.default,
+                args=(input, 3),
+            )
+
+        with graph.inserting_after(weight):
+            weight_unsqueeze = graph_module.graph.call_function(
+                torch.ops.aten.unsqueeze.default,
+                args=(weight, 3),
+            )
+
+        with graph.inserting_before(node):
+            if isinstance(padding, list):
+                conv2d_op = torch.ops.aten.conv2d.default
+            elif isinstance(padding, str):
+                conv2d_op = torch.ops.aten.conv2d.padding
+
+            conv2d = graph_module.graph.call_function(
+                conv2d_op,
+                args=(
+                    input_unsqueeze,
+                    weight_unsqueeze,
+                    bias,
+                    [*stride, 1],
+                    [*padding, 0] if isinstance(padding, list) else padding,
+                    [*dilation, 1],
+                    groups,
+                ),
+                # NOTE `node.kwargs` were already folded into `Conv1DArgs` above;
+                # passing them again would duplicate arguments at call time.
+            )
+
+            conv_out_squeeze = graph_module.graph.call_function(
+                torch.ops.aten.squeeze.dims,
+                args=(conv2d, [3]),
+            )
+
+        node.replace_all_uses_with(conv_out_squeeze, propagate_meta=True)
+
+        logger.debug(f"{node.name} is replaced with {conv2d.name}")
+        modified = True
+        return modified
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        target_conv_op = (torch.ops.aten.conv1d.default, torch.ops.aten.conv1d.padding)
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        modified = False
+        for node in graph.nodes:
+            if node.op != "call_function":
+                continue
+
+            if node.target not in target_conv_op:
+                continue
+            modified |= self.convert(exported_program, node)
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        return PassResult(modified)
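The rewrite's shape algebra can be checked numerically in plain PyTorch (a standalone sketch, not package code; the tensor sizes are arbitrary assumptions): unsqueeze input and weight at dim 3, run conv2d with the extra stride/padding/dilation entries that the pass appends, then squeeze dim 3 away again.

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 16)  # (N, C_in, L)
w = torch.randn(8, 4, 3)   # (C_out, C_in, K)
ref = F.conv1d(x, w, stride=[1], padding=[1], dilation=[1])

out = F.conv2d(
    x.unsqueeze(3),    # (N, C_in, L, 1)
    w.unsqueeze(3),    # (C_out, C_in, K, 1)
    stride=[1, 1],     # [*stride, 1]
    padding=[1, 0],    # [*padding, 0]
    dilation=[1, 1],   # [*dilation, 1]
).squeeze(3)           # back to (N, C_out, L)

assert torch.allclose(ref, out, atol=1e-5)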
tico/passes/convert_layout_op_to_reshape.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch.export import ExportedProgram
+
+from tico.passes import ops
+from tico.serialize.circle_mapping import extract_shape
+from tico.utils import logging
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.validate_args_kwargs import SqueezeArgs, UnSqueezeArgs, ViewArgs
+
+
+@trace_graph_diff_on_pass
+class ConvertLayoutOpToReshape(PassBase):
+    """
+    This pass converts layout-transformation ops (view/squeeze/unsqueeze) to reshape when possible.
+    This is helpful for further optimization.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        modified = False
+
+        def convert(node, input):
+            out_shape = list(extract_shape(node))
+
+            with graph.inserting_after(node):
+                reshape_node = graph.call_function(
+                    torch.ops.aten.reshape.default,
+                    args=(input, out_shape),
+                )
+
+            node.replace_all_uses_with(reshape_node, propagate_meta=True)
+
+            logger.debug(f"{node.name} is replaced with {reshape_node.name}")
+
+        for node in graph.nodes:
+            if node.op != "call_function":
+                continue
+
+            if node.target in ops.aten.view:
+                view_args = ViewArgs(*node.args, **node.kwargs)
+                convert(node, view_args.input)
+                modified = True
+                continue
+            elif node.target in ops.aten.unsqueeze:
+                unsqueeze_args = UnSqueezeArgs(*node.args, **node.kwargs)
+                convert(node, unsqueeze_args.input)
+                modified = True
+                continue
+            elif node.target in ops.aten.squeeze:
+                squeeze_args = SqueezeArgs(*node.args, **node.kwargs)
+                convert(node, squeeze_args.input)
+                modified = True
+                continue
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        return PassResult(modified)
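The rewrite is sound because view/squeeze/unsqueeze only relabel dimensions without moving data, so reshaping the op's input directly to the op's output shape yields an equal tensor. A standalone check (not package code):

import torch

x = torch.randn(2, 1, 6)
for y in (x.view(3, 4), x.squeeze(1), x.unsqueeze(0)):
    # Reshaping to the layout op's output shape reproduces it exactly.
    assert torch.equal(y, x.reshape(y.shape))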
tico/passes/convert_repeat_to_expand_copy.py
@@ -0,0 +1,90 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch.export import ExportedProgram
+
+from tico.utils import logging
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+
+
+@trace_graph_diff_on_pass
+class ConvertRepeatToExpandCopy(PassBase):
+    """
+    aten.repeat.default is converted to aten.expand_copy.default.
+    Why? There is no CircleNode mapped to `repeat`,
+    so we convert it using the existing aten.expand_copy.default.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        modified = False
+        for node in graph.nodes:
+            if node.op != "call_function":
+                continue
+
+            if node.target != torch.ops.aten.repeat.default:
+                continue
+
+            assert len(node.args) == 2
+
+            tensor, repeats = node.args
+            assert isinstance(tensor, torch.fx.Node)
+            assert isinstance(repeats, list)
+
+            tensor_shape: List[int] = [int(dim) for dim in tensor.meta["val"].shape]
+
+            # Check if it is possible to convert to aten.expand_copy.default.
+            cannot_convert = False
+            extending_idx = len(repeats) - len(tensor_shape)
+            for idx, dim in enumerate(tensor_shape):
+                if not (dim == 1 or repeats[extending_idx + idx] == 1):
+                    cannot_convert = True
+            if cannot_convert:
+                continue
+
+            size = []
+            for idx, repeats_dim in enumerate(repeats):
+                if idx < extending_idx:
+                    size.append(repeats_dim)
+                else:
+                    size.append(repeats_dim * tensor_shape[idx - extending_idx])
+
+            expand_copy_args = (tensor, size)
+
+            with graph.inserting_after(node):
+                expand_copy_node = graph.call_function(
+                    torch.ops.aten.expand_copy.default, args=expand_copy_args
+                )
+                node.replace_all_uses_with(expand_copy_node, propagate_meta=True)
+
+            modified = True
+            logger.debug(f"{node.name} is replaced with expand_copy operator")
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        return PassResult(modified)
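The convertibility rule above, restated in plain PyTorch (a standalone sketch, not package code): `repeat` coincides with broadcasting semantics exactly when every repeated source dimension has size 1, and the computed `size` follows the two branches of the loop.

import torch

x = torch.randn(1, 3)
# repeats = [2, 4, 1]; extending_idx = len(repeats) - len(shape) = 1.
# Every repeated source dim has size 1 (only dim 0 is repeated), so the
# rewrite applies with size = [2, 4 * 1, 1 * 3] = [2, 4, 3].
assert torch.equal(x.repeat(2, 4, 1), x.expand(2, 4, 3))

# A dim of size 3 repeated twice would duplicate data, which broadcasting
# cannot express; the pass leaves such `repeat` nodes untouched.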