tico-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +42 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +191 -0
- tico/passes/cast_mixed_type_args.py +187 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +160 -0
- tico/passes/convert_layout_op_to_reshape.py +85 -0
- tico/passes/convert_repeat_to_expand_copy.py +89 -0
- tico/passes/convert_to_relu6.py +181 -0
- tico/passes/decompose_addmm.py +124 -0
- tico/passes/decompose_batch_norm.py +192 -0
- tico/passes/decompose_fake_quantize.py +134 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
- tico/passes/decompose_group_norm.py +275 -0
- tico/passes/decompose_grouped_conv2d.py +209 -0
- tico/passes/decompose_slice_scatter.py +169 -0
- tico/passes/extract_dtype_kwargs.py +122 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +108 -0
- tico/passes/legalize_predefined_layout_operators.py +386 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
- tico/passes/lower_to_slice.py +230 -0
- tico/passes/merge_consecutive_cat.py +80 -0
- tico/passes/ops.py +78 -0
- tico/passes/remove_nop.py +84 -0
- tico/passes/remove_redundant_assert_nodes.py +51 -0
- tico/passes/remove_redundant_expand.py +66 -0
- tico/passes/remove_redundant_permute.py +122 -0
- tico/passes/remove_redundant_reshape.py +436 -0
- tico/passes/remove_redundant_slice.py +62 -0
- tico/passes/remove_redundant_to_copy.py +86 -0
- tico/passes/restore_linear.py +115 -0
- tico/passes/segment_index_select.py +145 -0
- tico/pt2_to_circle.py +105 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +319 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +240 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_abs.py +53 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +150 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +192 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +126 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +186 -0
- tico/serialize/operators/op_copy.py +164 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +95 -0
- tico/serialize/operators/op_depthwise_conv2d.py +199 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_leaky_relu.py +60 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +86 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_dim.py +70 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +177 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +141 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +100 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +99 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +94 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +296 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +282 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/mx/__init__.py +1 -0
- tico/utils/mx/elemwise_ops.py +267 -0
- tico/utils/mx/formats.py +125 -0
- tico/utils/mx/mx_ops.py +270 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +609 -0
- tico/utils/serialize.py +42 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +406 -0
- tico/utils/validate_args_kwargs.py +1149 -0
- tico-0.1.0.dist-info/LICENSE +241 -0
- tico-0.1.0.dist-info/METADATA +354 -0
- tico-0.1.0.dist-info/RECORD +206 -0
- tico-0.1.0.dist-info/WHEEL +5 -0
- tico-0.1.0.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dist-info/top_level.txt +1 -0
tico/serialize/circle_serializer.py
@@ -0,0 +1,240 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import operator
from typing import Dict

import flatbuffers
import torch
from circle_schema import circle
from torch.export.exported_program import (
    ConstantArgument,
    ExportedProgram,
    InputKind,
    TensorArgument,
)

from tico.serialize.circle_mapping import to_circle_dtype
from tico.serialize.operators import *
from tico.serialize.circle_graph import CircleModel, CircleSubgraph
from tico.serialize.operators.hashable_opcode import OpCode
from tico.serialize.operators.node_visitor import get_node_visitors
from tico.utils import logging
from tico.utils.serialize import finalise_tensor_names


multiple_output_ops = [
    torch.ops.aten.split_with_sizes.default,
    torch.ops.aten.max.dim,
]


# Build circle model from ExportedProgram
# Return raw bytes of circle model
def build_circle(edge_program: ExportedProgram) -> bytes:
    logger = logging.getLogger(__name__)

    builder = flatbuffers.Builder()

    # Init Model
    model = CircleModel()

    # Add empty buffer at the front (convention)
    model.add_buffer(circle.Buffer.BufferT())

    # Create an empty subgraph (assume a single subgraph)
    graph = CircleSubgraph(model)

    # Export tensors
    logger.debug("---------------Export tensors--------------")
    buf_name_to_data = {name: buf for name, buf in edge_program.named_buffers()}
    for node in edge_program.graph.nodes:
        if node.op == "call_function":
            if node.target in multiple_output_ops:
                continue
            node_val = node.meta["val"]
            if node_val.layout != torch.strided:
                raise RuntimeError(
                    f"Only support dense tensors (node layout: {node_val.layout})"
                )
            graph.add_tensor_from_node(node)
            logger.debug(f"call_function: {node.name} tensor exported.")

        # placeholder: function input (including parameters, buffers, constant tensors)
        elif node.op == "placeholder":
            # placeholder invariants
            assert node.args is None or len(node.args) == 0  # Not support default param

            # parameters
            if node.name in edge_program.graph_signature.inputs_to_parameters:
                param_name = edge_program.graph_signature.inputs_to_parameters[
                    node.name
                ]
                param_data = edge_program.state_dict[param_name]

                assert isinstance(
                    param_data, torch.Tensor
                ), "Expect parameters to be a tensor"
                param_value = param_data.cpu().detach().numpy()

                graph.add_tensor_from_node(node, param_value)
                logger.debug(f"placeholder(param): {node.name} tensor exported.")
            elif node.name in edge_program.graph_signature.inputs_to_buffers:
                buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
                assert buffer_name in buf_name_to_data
                buffer_data = buf_name_to_data[buffer_name]
                assert isinstance(
                    buffer_data, torch.Tensor
                ), "Expect buffers to be a tensor"
                buffer_value = buffer_data.cpu().detach().numpy()

                graph.add_tensor_from_node(node, buffer_value)
                logger.debug(f"placeholder(buffer): {node.name} tensor exported.")
            elif (
                node.name
                in edge_program.graph_signature.inputs_to_lifted_tensor_constants
            ):
                ctensor_name = (
                    edge_program.graph_signature.inputs_to_lifted_tensor_constants[
                        node.name
                    ]
                )
                ctensor_data = edge_program.constants[ctensor_name]

                assert isinstance(
                    ctensor_data, torch.Tensor
                ), "Expect constant tensor to be a tensor"
                ctensor_value = ctensor_data.cpu().detach().numpy()

                graph.add_tensor_from_node(node, ctensor_value)
                logger.debug(
                    f"placeholder(constant tensor): {node.name} tensor exported."
                )
            else:
                user_inputs = [
                    specs
                    for specs in edge_program.graph_signature.input_specs
                    if specs.kind == InputKind.USER_INPUT
                ]
                constant_inputs = [
                    specs
                    for specs in user_inputs
                    if isinstance(specs.arg, ConstantArgument)
                ]
                name_to_value = {
                    specs.arg.name: specs.arg.value for specs in constant_inputs
                }
                # NoneType ConstantArgument is ignored.
                if node.name in name_to_value and name_to_value[node.name] == None:
                    continue
                graph.add_tensor_from_node(node)
                logger.debug(f"placeholder: {node.name} tensor exported.")

        # get_attr: retrieve parameter
        elif node.op == "get_attr":
            # node.name: Place where fetched attribute is saved
            # node.target: Attribute in the module
            attr_tensor = getattr(node.graph.owning_module, node.target)
            assert isinstance(attr_tensor, torch.Tensor)

            graph.add_tensor_from_scratch(
                prefix=node.name,
                shape=list(attr_tensor.shape),
                dtype=to_circle_dtype(attr_tensor.dtype),
                source_node=node,
            )

            logger.debug(f"get_attr: {node.name} tensor exported.")

        # output: function output
        elif node.op == "output":
            # output node itself does not need a buffer
            # argument of output node is assumed to be exported beforehand
            for output in node.args[0]:
                if isinstance(output, torch.fx.Node):
                    assert graph.has_tensor(output.name)
            continue

        # call_method: call method
        elif node.op == "call_method":
            raise AssertionError("Not yet implemented")

        # call_module: call 'forward' of module
        elif node.op == "call_module":
            raise AssertionError("Not yet implemented")

        else:
            # Add more if fx.Node is extended
            raise AssertionError(f"Unknown fx.Node op {node.op}")

    # Register inputs
    logger.debug("---------------Register inputs--------------")
    for in_spec in edge_program.graph_signature.input_specs:
        if in_spec.kind != InputKind.USER_INPUT:
            continue
        # NoneType ConstantArgument is ignored.
        if isinstance(in_spec.arg, ConstantArgument) and in_spec.arg.value == None:
            continue
        arg_name = in_spec.arg.name
        graph.add_input(arg_name)
        logger.debug(f"Registered input: {arg_name}")

    # Register outputs
    logger.debug("---------------Register outputs--------------")
    for user_output in edge_program.graph_signature.user_outputs:
        if user_output == None:
            logger.debug(f"Ignore 'None' output")
            continue

        graph.add_output(user_output)
        logger.debug(f"Registered output: {user_output}")

    # Export operators
    logger.debug("---------------Export operators--------------")
    op_codes: Dict[OpCode, int] = {}
    visitors = get_node_visitors(op_codes, graph)
    for node in edge_program.graph.nodes:
        if node.op != "call_function":
            continue

        opcode = node.target
        if opcode == operator.getitem:
            continue
        if opcode not in visitors:
            raise RuntimeError(f"{opcode} is not yet supported")
        circle_op = visitors[opcode].define_node(node)

        if circle_op:
            graph.add_operator(circle_op)
            logger.debug(f"call_function: {node.name} ({opcode}) Op exported.")

    # Register subgraph
    finalise_tensor_names(graph)
    model.subgraphs.append(graph)

    # Encode operator codes
    model.operatorCodes = [
        code for code, _ in sorted(op_codes.items(), key=lambda x: x[1])
    ]

    # Description
    model.description = "circle"

    # Set version
    model.version = 0

    # Finish model
    builder.Finish(model.Pack(builder), "CIR0".encode("utf8"))
    buf = builder.Output()

    return bytes(buf)
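A minimal usage sketch (not part of the package diff): `build_circle` takes a `torch.export` `ExportedProgram` and returns the Circle FlatBuffers payload as raw bytes. In practice tico's conversion pipeline runs its lowering passes before serialization, so feeding an unprocessed export straight to `build_circle` may hit unsupported ops; the snippet below only illustrates the input/output contract, and the module and file names are my own.

# Hypothetical sketch; assumes the exported program is already in the shape
# tico's passes produce (unsupported aten ops would raise a RuntimeError here).
import torch
from torch.export import export

from tico.serialize.circle_serializer import build_circle


class TinyAbs(torch.nn.Module):
    def forward(self, x):
        return torch.abs(x)


exported = export(TinyAbs(), (torch.randn(1, 4),))
circle_bytes = build_circle(exported)  # raw Circle FlatBuffers bytes
with open("tiny_abs.circle", "wb") as f:
    f.write(circle_bytes)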
tico/serialize/operators/__init__.py
@@ -0,0 +1,28 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
from os.path import basename, dirname, isfile, join

from tico.utils.register_custom_op import RegisterOps


# Register custom ops to torch namespace
RegisterOps()

# Load all modules in the current directory
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [
    basename(f)[:-3] for f in modules if isfile(f) and not f.endswith("__init__.py")
]
tico/serialize/operators/hashable_opcode.py
@@ -0,0 +1,43 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from circle_schema import circle


class OpCode(circle.OperatorCode.OperatorCodeT):
    """
    Wrapper class for operator code in circle schema
    This implements __eq__ and __hash__ for use with dict()
    """

    def __init__(self):
        super().__init__()

    def __eq__(self, other):
        if self.version != other.version:
            return False

        if self.builtinCode == circle.BuiltinOperator.BuiltinOperator.CUSTOM:
            return self.customCode == other.customCode

        return self.builtinCode == other.builtinCode

    def __hash__(self):
        val = (
            self.deprecatedBuiltinCode,
            self.customCode,
            self.version,
            self.builtinCode,
        )
        return hash(val)
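A small sketch of why `__eq__` and `__hash__` matter (illustrative only, not from the package): the serializer keeps a `Dict[OpCode, int]` mapping each operator code to its index in `model.operatorCodes`, so two codes describing the same builtin operator must collapse to a single dict entry.

# Minimal sketch: OpCode instances for the same builtin op act as one dict key.
from circle_schema import circle

from tico.serialize.operators.hashable_opcode import OpCode


def make_add_code() -> OpCode:
    code = OpCode()
    code.builtinCode = circle.BuiltinOperator.BuiltinOperator.ADD
    return code


op_codes = {}  # OpCode -> opcode index, as maintained during serialization
first = make_add_code()
second = make_add_code()

op_codes.setdefault(first, len(op_codes))
op_codes.setdefault(second, len(op_codes))  # equal key, no new entry added

assert first == second
assert len(op_codes) == 1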
tico/serialize/operators/node_visitor.py
@@ -0,0 +1,80 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Type, TYPE_CHECKING

if TYPE_CHECKING:
    import torch.fx
import torch
from circle_schema import circle

from tico.serialize.circle_graph import CircleSubgraph
from tico.serialize.operators.hashable_opcode import OpCode


class NodeVisitor:
    """
    Node visitor for lowering edge IR to circle
    """

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        # For setting opcode index in circle model
        # This is updated during serialization
        self._op_codes = op_codes
        self.graph = graph

    # Define circle model operator
    def define_node(
        self,
        node: torch.fx.node.Node,
    ) -> circle.Operator.OperatorT:
        raise NotImplementedError("NodeVisitor must be extended.")


# container for all node visitors
_node_visitor_dict: Dict[torch._ops.OpOverload, Type[NodeVisitor]] = {}


# Decorator for each visitor
def register_node_visitor(visitor):
    for target in visitor.target:
        _node_visitor_dict[target] = visitor
    return visitor


def get_node_visitor(target: torch._ops.OpOverload) -> Type[NodeVisitor]:
    """
    Get a single node visitor (for unittest purpose)
    """
    _visitor = _node_visitor_dict.get(target, None)

    if not _visitor:
        raise LookupError(f"NodeVisitor for {target} is not registered")

    return _visitor


# Get all node visitors
def get_node_visitors(
    op_codes: Dict[OpCode, int], graph: CircleSubgraph
) -> Dict[torch._ops.OpOverload, NodeVisitor]:
    node_visitors = {}
    for target, visitor in _node_visitor_dict.items():
        node_visitors[target] = visitor(op_codes, graph)

    return node_visitors


def get_support_targets():
    return _node_visitor_dict.keys()
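The docstring on `get_node_visitor` says it exists for unit tests; below is a hedged sketch of that usage (the test name and assertions are mine, not from the package). Note that visitors only self-register once the `op_*.py` modules are imported, which `from tico.serialize.operators import *` triggers through the package `__all__`, mirroring what `circle_serializer.py` does.

# Illustrative test sketch; importing the operators package runs each
# op_*.py module so its @register_node_visitor decorator executes.
import torch

from tico.serialize.operators import *  # noqa: F401,F403
from tico.serialize.operators.node_visitor import (
    get_node_visitor,
    get_support_targets,
)


def test_abs_visitor_is_registered():
    assert torch.ops.aten.abs.default in get_support_targets()
    visitor_cls = get_node_visitor(torch.ops.aten.abs.default)
    assert torch.ops.aten.abs.default in visitor_cls.target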
tico/serialize/operators/op_abs.py
@@ -0,0 +1,53 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, TYPE_CHECKING

if TYPE_CHECKING:
    import torch._ops
    import torch.fx
import torch
from circle_schema import circle

from tico.serialize.circle_graph import CircleSubgraph
from tico.serialize.operators.hashable_opcode import OpCode
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
from tico.utils.validate_args_kwargs import AbsArgs


@register_node_visitor
class AbsVisitor(NodeVisitor):
    target: List[torch._ops.OpOverload] = [torch.ops.aten.abs.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.ABS, self._op_codes
        )

        args = AbsArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input

        inputs = [input]
        outputs = [node]

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        return operator
tico/serialize/operators/op_add.py
@@ -0,0 +1,69 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, TYPE_CHECKING

if TYPE_CHECKING:
    import torch._ops
    import torch.fx
import torch
from circle_schema import circle

from tico.serialize.circle_graph import CircleSubgraph
from tico.serialize.operators.hashable_opcode import OpCode
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
from tico.utils.validate_args_kwargs import AddTensorArgs


@register_node_visitor
class AddVisitor(NodeVisitor):
    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.add.Tensor,
        torch.ops.aten.add.Scalar,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        args = AddTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
        other = args.other

        inputs = [input, other]
        outputs = [node]

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.ADD, self._op_codes
        )

        inputs = [input, other]
        outputs = [node]

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        # Op-specific option
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.AddOptions
        option = circle.AddOptions.AddOptionsT()
        option.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        option.potScaleInt16 = False
        operator.builtinOptions = option

        return operator
tico/serialize/operators/op_alias_copy.py
@@ -0,0 +1,64 @@
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, TYPE_CHECKING

if TYPE_CHECKING:
    import torch._ops
    import torch.fx
import torch
from circle_schema import circle

from tico.serialize.operators.hashable_opcode import OpCode
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
from tico.utils.validate_args_kwargs import AliasCopyArgs


@register_node_visitor
class AliasCopyVisitor(NodeVisitor):
    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.alias.default,
        torch.ops.aten.alias_copy.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        args = AliasCopyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.TRANSPOSE, self._op_codes
        )

        permute = torch.IntTensor(list(range(len(input.meta["val"].shape))))

        inputs = [input, permute]
        outputs = [node]

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        # Op-specific option
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.TransposeOptions
        )
        option = circle.TransposeOptions.TransposeOptionsT()
        operator.builtinOptions = option

        return operator