tico 0.1.0.dev250411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
tico/utils/graph.py
ADDED
@@ -0,0 +1,200 @@
|
|
1
|
+
# Portions of this file are adapted from code originally authored by
|
2
|
+
# Meta Platforms, Inc. and affiliates, licensed under the BSD-style
|
3
|
+
# license found in the LICENSE file in the root directory of their source tree.
|
4
|
+
|
5
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
6
|
+
#
|
7
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
8
|
+
# you may not use this file except in compliance with the License.
|
9
|
+
# You may obtain a copy of the License at
|
10
|
+
#
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
12
|
+
#
|
13
|
+
# Unless required by applicable law or agreed to in writing, software
|
14
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
15
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
16
|
+
# See the License for the specific language governing permissions and
|
17
|
+
# limitations under the License.
|
18
|
+
|
19
|
+
from typing import Optional, TYPE_CHECKING
|
20
|
+
|
21
|
+
if TYPE_CHECKING:
|
22
|
+
import torch.fx
|
23
|
+
import torch
|
24
|
+
from torch.export import ExportedProgram
|
25
|
+
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
|
26
|
+
|
27
|
+
from tico.utils.utils import get_fake_mode
|
28
|
+
|
29
|
+
|
30
|
+
def is_torch_param(node: torch.fx.Node, ep: ExportedProgram):
|
31
|
+
assert node.op == "placeholder"
|
32
|
+
|
33
|
+
return node.name in ep.graph_signature.inputs_to_parameters
|
34
|
+
|
35
|
+
|
36
|
+
def is_torch_buffer(node: torch.fx.Node, ep: ExportedProgram):
|
37
|
+
assert node.op == "placeholder"
|
38
|
+
|
39
|
+
return node.name in ep.graph_signature.inputs_to_buffers
|
40
|
+
|
41
|
+
|
42
|
+
def get_torch_param_value(node: torch.fx.Node, ep: ExportedProgram):
|
43
|
+
assert isinstance(node, torch.fx.Node)
|
44
|
+
assert node.op == "placeholder"
|
45
|
+
assert (
|
46
|
+
node.name in ep.graph_signature.inputs_to_parameters
|
47
|
+
), "Node {node.name} is not in the parameters" # FIX CALLER UNLESS
|
48
|
+
|
49
|
+
param_name = ep.graph_signature.inputs_to_parameters[node.name]
|
50
|
+
named_params = dict(ep.named_parameters())
|
51
|
+
assert param_name in named_params
|
52
|
+
|
53
|
+
return named_params[param_name].data
|
54
|
+
|
55
|
+
|
56
|
+
def get_torch_buffer_value(node: torch.fx.Node, ep: ExportedProgram):
|
57
|
+
assert isinstance(node, torch.fx.Node)
|
58
|
+
assert node.op == "placeholder"
|
59
|
+
assert (
|
60
|
+
node.name in ep.graph_signature.inputs_to_buffers
|
61
|
+
), "Node {node.name} is not in the buffers" # FIX CALLER UNLESS
|
62
|
+
|
63
|
+
buf_name = ep.graph_signature.inputs_to_buffers[node.name]
|
64
|
+
named_buf = dict(ep.named_buffers())
|
65
|
+
assert buf_name in named_buf
|
66
|
+
|
67
|
+
return named_buf[buf_name]
|
68
|
+
|
69
|
+
|
70
|
+
def get_first_user_input(exported_program: ExportedProgram) -> Optional[torch.fx.Node]:
|
71
|
+
"""Returns the first user input node in the graph."""
|
72
|
+
first_user_input: Optional[torch.fx.Node] = None
|
73
|
+
graph_module = exported_program.graph_module
|
74
|
+
graph: torch.fx.Graph = graph_module.graph
|
75
|
+
for node in graph.nodes:
|
76
|
+
if (
|
77
|
+
node.op == "placeholder"
|
78
|
+
and node.name in exported_program.graph_signature.user_inputs
|
79
|
+
):
|
80
|
+
first_user_input = node
|
81
|
+
break
|
82
|
+
|
83
|
+
return first_user_input
|
84
|
+
|
85
|
+
|
86
|
+
def generate_fqn(prefix: str, exported_program: ExportedProgram):
|
87
|
+
"""
|
88
|
+
Generate fully-qualized name for constants.
|
89
|
+
|
90
|
+
This function prevents `exported_program.constants` from having duplicate keys.
|
91
|
+
"""
|
92
|
+
cnt = len(exported_program.constants)
|
93
|
+
while True:
|
94
|
+
if f"{prefix}{cnt}" in exported_program.constants:
|
95
|
+
cnt += 1
|
96
|
+
continue
|
97
|
+
break
|
98
|
+
return f"{prefix}{cnt}"
|
99
|
+
|
100
|
+
|
101
|
+
def create_input_spec(node, input_kind: InputKind):
|
102
|
+
"""
|
103
|
+
@ref https://pytorch.org/docs/stable/export.ir_spec.html#placeholder
|
104
|
+
"""
|
105
|
+
if input_kind == InputKind.CONSTANT_TENSOR:
|
106
|
+
return InputSpec(
|
107
|
+
kind=InputKind.CONSTANT_TENSOR,
|
108
|
+
arg=TensorArgument(name=node.name),
|
109
|
+
target=node.target, # type: ignore[arg-type]
|
110
|
+
persistent=True,
|
111
|
+
)
|
112
|
+
else:
|
113
|
+
raise NotImplementedError("NYI")
|
114
|
+
|
115
|
+
|
116
|
+
def validate_input_specs(exported_program):
|
117
|
+
name_to_spec_dict = {
|
118
|
+
s.arg.name: s for s in exported_program.graph_signature.input_specs
|
119
|
+
}
|
120
|
+
|
121
|
+
for node in exported_program.graph.nodes:
|
122
|
+
if node.op != "placeholder":
|
123
|
+
continue
|
124
|
+
|
125
|
+
if node.name not in name_to_spec_dict:
|
126
|
+
raise RuntimeError(
|
127
|
+
"Placeholder node {node.name} does not have corresponding input spec!"
|
128
|
+
)
|
129
|
+
|
130
|
+
|
131
|
+
def add_placeholder(
|
132
|
+
exported_program: ExportedProgram,
|
133
|
+
tensor: torch.Tensor,
|
134
|
+
prefix: str,
|
135
|
+
) -> torch.fx.Node:
|
136
|
+
"""
|
137
|
+
Add a placeholder to the graph and update the exported program.
|
138
|
+
"""
|
139
|
+
fqn_name = generate_fqn(prefix, exported_program)
|
140
|
+
|
141
|
+
# Get fake mode before adding placeholder
|
142
|
+
fake_mode = get_fake_mode(exported_program)
|
143
|
+
|
144
|
+
first_user_input = get_first_user_input(exported_program)
|
145
|
+
if not first_user_input:
|
146
|
+
# Placeholder nodes must be the first N nodes in the nodes list of a graph.
|
147
|
+
# Therefore, insert the newly created placeholders at the start of the node list.
|
148
|
+
assert exported_program.graph.nodes
|
149
|
+
first_node = list(exported_program.graph.nodes)[0]
|
150
|
+
first_user_input = first_node
|
151
|
+
|
152
|
+
# Add a placeholder to the graph.
|
153
|
+
with exported_program.graph.inserting_before(first_user_input):
|
154
|
+
const_node = exported_program.graph.placeholder(fqn_name)
|
155
|
+
|
156
|
+
const_node.meta["val"] = fake_mode.from_tensor(tensor, static_shapes=True)
|
157
|
+
const_node.meta["val"].constant = tensor
|
158
|
+
|
159
|
+
# Add a new constant to the exported program.
|
160
|
+
exported_program.constants[const_node.name] = tensor
|
161
|
+
|
162
|
+
# Use update (instead of append) if this assert is violated
|
163
|
+
assert const_node.name not in [
|
164
|
+
s.arg.name for s in exported_program.graph_signature.input_specs
|
165
|
+
]
|
166
|
+
|
167
|
+
# Append the new input spec.
|
168
|
+
exported_program.graph_signature.input_specs.append(
|
169
|
+
create_input_spec(const_node, InputKind.CONSTANT_TENSOR)
|
170
|
+
)
|
171
|
+
|
172
|
+
# Get old input specs
|
173
|
+
name_to_spec_dict = {
|
174
|
+
s.arg.name: s for s in exported_program.graph_signature.input_specs
|
175
|
+
}
|
176
|
+
|
177
|
+
# Add the new constants to input specs dict.
|
178
|
+
name_to_spec_dict.update(
|
179
|
+
{const_node.name: create_input_spec(const_node, InputKind.CONSTANT_TENSOR)}
|
180
|
+
)
|
181
|
+
|
182
|
+
# Generate new input spec *in the same order of nodes*
|
183
|
+
# IMPORTANT Input specs and their placeholder nodes must have the same order.
|
184
|
+
new_input_specs = []
|
185
|
+
for node in exported_program.graph.nodes:
|
186
|
+
if node.op != "placeholder":
|
187
|
+
continue
|
188
|
+
new_input_specs.append(name_to_spec_dict[node.name])
|
189
|
+
exported_program.graph_signature.input_specs = new_input_specs
|
190
|
+
|
191
|
+
return const_node
|
192
|
+
|
193
|
+
|
194
|
+
def is_single_value_tensor(t: torch.Tensor):
|
195
|
+
if len(t.size()) == 0:
|
196
|
+
return True
|
197
|
+
if len(t.size()) == 1 and t.size()[0] == 1:
|
198
|
+
return True
|
199
|
+
|
200
|
+
return False
|
tico/utils/logging.py
ADDED
@@ -0,0 +1,45 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import logging
|
16
|
+
import os
|
17
|
+
|
18
|
+
|
19
|
+
def _loggerLevel():
|
20
|
+
TICO_LOG = os.environ.get("TICO_LOG")
|
21
|
+
if TICO_LOG == "1":
|
22
|
+
log_level = logging.FATAL
|
23
|
+
elif TICO_LOG == "2":
|
24
|
+
log_level = logging.WARNING
|
25
|
+
elif TICO_LOG == "3":
|
26
|
+
log_level = logging.INFO
|
27
|
+
elif TICO_LOG == "4":
|
28
|
+
log_level = logging.DEBUG
|
29
|
+
else:
|
30
|
+
log_level = logging.WARNING
|
31
|
+
return log_level
|
32
|
+
|
33
|
+
|
34
|
+
LOG_LEVEL = _loggerLevel()
|
35
|
+
|
36
|
+
|
37
|
+
def getLogger(name: str):
|
38
|
+
"""
|
39
|
+
Get logger with setting log level according to the `TICO_LOG` environment variable.
|
40
|
+
"""
|
41
|
+
logging.basicConfig()
|
42
|
+
logger = logging.getLogger(name)
|
43
|
+
logger.setLevel(LOG_LEVEL)
|
44
|
+
|
45
|
+
return logger
|
tico/utils/model.py
ADDED
@@ -0,0 +1,37 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from __future__ import annotations
|
16
|
+
|
17
|
+
from typing import Any
|
18
|
+
|
19
|
+
from tico.interpreter import infer
|
20
|
+
|
21
|
+
|
22
|
+
class CircleModel:
|
23
|
+
def __init__(self, circle_binary: bytes):
|
24
|
+
self.circle_binary = circle_binary
|
25
|
+
|
26
|
+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
27
|
+
return infer.infer(self.circle_binary, *args, **kwargs)
|
28
|
+
|
29
|
+
@staticmethod
|
30
|
+
def load(circle_path: str) -> CircleModel:
|
31
|
+
with open(circle_path, "rb") as f:
|
32
|
+
buf = bytes(f.read())
|
33
|
+
return CircleModel(buf)
|
34
|
+
|
35
|
+
def save(self, circle_path: str) -> None:
|
36
|
+
with open(circle_path, "wb") as f:
|
37
|
+
f.write(self.circle_binary)
|
tico/utils/padding.py
ADDED
@@ -0,0 +1,47 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import torch
|
16
|
+
|
17
|
+
from tico.utils.errors import InvalidArgumentError
|
18
|
+
|
19
|
+
SAME = 0
|
20
|
+
VALID = 1
|
21
|
+
|
22
|
+
|
23
|
+
def is_valid_padding(padding: str | list):
|
24
|
+
if isinstance(padding, str):
|
25
|
+
return padding == "valid"
|
26
|
+
|
27
|
+
if isinstance(padding, list):
|
28
|
+
assert len(padding) == 2, "Padding should be a list of length 2."
|
29
|
+
return padding == [0, 0]
|
30
|
+
|
31
|
+
raise InvalidArgumentError("Invalid padding.")
|
32
|
+
|
33
|
+
|
34
|
+
def is_same_padding(
|
35
|
+
padding: str | list, input_shape: list | torch.Size, output_shape: list | torch.Size
|
36
|
+
):
|
37
|
+
if isinstance(padding, str):
|
38
|
+
return padding == "same"
|
39
|
+
|
40
|
+
if isinstance(padding, list):
|
41
|
+
assert len(padding) == 2, "Padding should be a list of length 2."
|
42
|
+
|
43
|
+
input_HW = input_shape[1:2] # N H W C
|
44
|
+
output_HW = output_shape[1:2] # N H W C
|
45
|
+
return input_HW == output_HW
|
46
|
+
|
47
|
+
raise InvalidArgumentError("Invalid padding.")
|
tico/utils/passes.py
ADDED
@@ -0,0 +1,76 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from abc import ABC, abstractmethod
|
16
|
+
from dataclasses import dataclass
|
17
|
+
from enum import Enum
|
18
|
+
from typing import List
|
19
|
+
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
|
23
|
+
@dataclass
|
24
|
+
class PassResult:
|
25
|
+
modified: bool
|
26
|
+
|
27
|
+
|
28
|
+
class PassBase(ABC):
|
29
|
+
"""
|
30
|
+
Base interface for passes.
|
31
|
+
"""
|
32
|
+
|
33
|
+
@abstractmethod
|
34
|
+
def call(self, exported_program: ExportedProgram) -> PassResult:
|
35
|
+
pass
|
36
|
+
|
37
|
+
|
38
|
+
class PassStrategy(Enum):
|
39
|
+
# Run passes until there are no changes.
|
40
|
+
UNTIL_NO_CHANGE = (1,)
|
41
|
+
# Same as `UNTIL_NO_CHANGE` but it starts agian from the beginning.
|
42
|
+
RESTART = (2,)
|
43
|
+
|
44
|
+
|
45
|
+
class PassManager:
|
46
|
+
def __init__(
|
47
|
+
self,
|
48
|
+
passes: List[PassBase],
|
49
|
+
strategy: PassStrategy = PassStrategy.RESTART,
|
50
|
+
):
|
51
|
+
self.passes: List[PassBase] = passes
|
52
|
+
self.strategy: PassStrategy = strategy
|
53
|
+
|
54
|
+
def run(self, exported_program: ExportedProgram):
|
55
|
+
MAXIMUM_STEP_COUNT = 1000
|
56
|
+
step = 0
|
57
|
+
while True:
|
58
|
+
modified = False
|
59
|
+
for _pass in self.passes:
|
60
|
+
# Automatically update the signatures of the input and output.
|
61
|
+
# https://github.com/pytorch/executorch/issues/4013#issuecomment-2187161844
|
62
|
+
with exported_program.graph_module._set_replace_hook(
|
63
|
+
exported_program.graph_signature.get_replace_hook()
|
64
|
+
):
|
65
|
+
result = _pass.call(exported_program)
|
66
|
+
modified = modified or result.modified
|
67
|
+
if modified and self.strategy == PassStrategy.RESTART:
|
68
|
+
break
|
69
|
+
|
70
|
+
if not modified:
|
71
|
+
break
|
72
|
+
step += 1
|
73
|
+
|
74
|
+
assert (
|
75
|
+
step < MAXIMUM_STEP_COUNT
|
76
|
+
), f"Loop iterated for {MAXIMUM_STEP_COUNT} times. Circular loop is suspected in {self.passes}"
|