tico-0.1.0.dev250411-py3-none-any.whl
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
tico/serialize/operators/op_cat.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+from tico.utils.validate_args_kwargs import CatArgs
+
+
+@register_node_visitor
+class CatVisitor(NodeVisitor):
+    target: List[torch._ops.OpOverload] = [torch.ops.aten.cat.default]
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph):
+        super().__init__(op_codes, graph)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+    ) -> circle.Operator.OperatorT:
+        args = CatArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        tensors = args.tensors
+        dim = args.dim
+
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.CONCATENATION, self._op_codes
+        )
+
+        inputs = tensors
+        outputs = [node]
+
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+        # Op-specific option
+        operator.builtinOptionsType = (
+            circle.BuiltinOptions.BuiltinOptions.ConcatenationOptions
+        )
+        option = circle.ConcatenationOptions.ConcatenationOptionsT()
+
+        option.axis = dim
+
+        option.fusedActivationFunction = (
+            circle.ActivationFunctionType.ActivationFunctionType.NONE
+        )
+        operator.builtinOptions = option
+
+        return operator
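For context, a standalone sketch (not part of the diff) of the semantics this visitor serializes: `aten.cat`'s `dim` argument carries over directly as the CONCATENATION option's `axis`, with no fused activation.

```python
import torch

a = torch.randn(2, 3)
b = torch.randn(2, 5)

# CatVisitor lowers this to a Circle CONCATENATION whose
# ConcatenationOptions.axis equals dim (here 1), fused activation NONE.
out = torch.cat([a, b], dim=1)
assert out.shape == (2, 8)
```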
tico/serialize/operators/op_clamp.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.passes import ops
+
+from tico.serialize.circle_graph import (
+    CircleSubgraph,
+    extract_circle_dtype,
+    extract_shape,
+)
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+from tico.utils.validate_args_kwargs import ClampArgs
+
+
+@register_node_visitor
+class ClampVisitor(NodeVisitor):
+    target: List[torch._ops.OpOverload] = ops.aten.clamp
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+        super().__init__(op_codes, graph)
+
+    def define_minimum_node(
+        self,
+        inputs: List[torch.fx.Node | circle.Tensor.TensorT | int | float],
+        outputs: List[torch.fx.Node | circle.Tensor.TensorT],
+    ) -> circle.Operator.OperatorT:
+
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.MINIMUM, self._op_codes
+        )
+
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+        # Op-specific option
+        operator.builtinOptionsType = (
+            circle.BuiltinOptions.BuiltinOptions.MaximumMinimumOptions
+        )
+        option = circle.MaximumMinimumOptions.MaximumMinimumOptionsT()
+
+        operator.builtinOptions = option
+        return operator
+
+    def define_maximum_node(
+        self,
+        inputs: List[torch.fx.Node | circle.Tensor.TensorT | int | float],
+        outputs: List[torch.fx.Node | circle.Tensor.TensorT],
+    ) -> circle.Operator.OperatorT:
+
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.MAXIMUM, self._op_codes
+        )
+
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+        # Op-specific option
+        operator.builtinOptionsType = (
+            circle.BuiltinOptions.BuiltinOptions.MaximumMinimumOptions
+        )
+        option = circle.MaximumMinimumOptions.MaximumMinimumOptionsT()
+
+        operator.builtinOptions = option
+
+        return operator
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+    ) -> circle.Operator.OperatorT:
+        args = ClampArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input = args.input
+        min_val = args.min
+        max_val = args.max
+
+        if min_val is None and max_val is None:
+            raise ValueError("Both min and max cannot be None")
+
+        elif min_val is not None and max_val is None:
+            # min only
+            return self.define_maximum_node([input, min_val], [node])
+
+        elif min_val is None and max_val is not None:
+            # max only
+            return self.define_minimum_node([input, max_val], [node])
+
+        elif min_val is not None and max_val is not None:
+            input_shape = extract_shape(input)
+            input_dtype = extract_circle_dtype(input)
+            minimum_tensor = self.graph.add_tensor_from_scratch(
+                prefix=f"{input.name}_min", dtype=input_dtype, shape=list(input_shape)
+            )
+            minimum_operator = self.define_minimum_node(
+                [input, max_val], [minimum_tensor]
+            )
+            self.graph.add_operator(minimum_operator)
+
+            maximum_operator = self.define_maximum_node(
+                [minimum_tensor, min_val], [node]
+            )
+            return maximum_operator
+
+        else:
+            raise RuntimeError("Cannot reach here")
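The branching in `define_node` relies on the identity `clamp(x, lo, hi) == maximum(minimum(x, hi), lo)`; a quick standalone check of that decomposition (not part of the diff):

```python
import torch

x = torch.tensor([-2.0, 0.5, 3.0])
lo, hi = -1.0, 1.0

# The two-operator lowering: MINIMUM against max first, then MAXIMUM against
# min, mirroring the minimum_tensor -> maximum_operator chain above.
decomposed = torch.maximum(torch.minimum(x, torch.tensor(hi)), torch.tensor(lo))
assert torch.equal(decomposed, torch.clamp(x, min=lo, max=hi))
```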
tico/serialize/operators/op_clone.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+from tico.utils.validate_args_kwargs import CloneArgs
+
+
+@register_node_visitor
+class CloneVisitor(NodeVisitor):
+    """
+    Clone tensor
+    TODO: Support dim_order and memory_format
+    Transpose may be required if 'memory_format' differs from the input tensor's 'memory_format'
+    """
+
+    target: List[torch._ops.OpOverload] = [torch.ops.aten.clone.default]
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph):
+        super().__init__(op_codes, graph)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+    ) -> circle.Operator.OperatorT:
+        if "memory_format" in node.kwargs:
+            # TODO: Support dim_order and memory_format
+            pass
+
+        args = CloneArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input = args.input
+
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.TRANSPOSE, self._op_codes
+        )
+
+        permute = torch.IntTensor(list(range(len(input.meta["val"].shape))))
+
+        inputs = [input, permute]
+        outputs = [node]
+
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+        # Op-specific option
+        operator.builtinOptionsType = (
+            circle.BuiltinOptions.BuiltinOptions.TransposeOptions
+        )
+        option = circle.TransposeOptions.TransposeOptionsT()
+        operator.builtinOptions = option
+
+        return operator
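The lowering above encodes `clone` as a TRANSPOSE with an identity permutation, which is value-preserving; a standalone sketch (not part of the diff) of that equivalence:

```python
import torch

x = torch.randn(2, 3, 4)

# Identity permutation [0, 1, ..., ndim-1], built the same way as
# `permute` above from the rank of input.meta["val"].shape.
identity_perm = list(range(x.dim()))
assert torch.equal(x.permute(identity_perm), x.clone())
```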
tico/serialize/operators/op_constant_pad_nd.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_mapping import extract_shape
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+from tico.utils.errors import InvalidArgumentError
+from tico.utils.validate_args_kwargs import ConstantPadNdArgs
+
+
+@register_node_visitor
+class ConstantPadNdVisitor(NodeVisitor):
+    target: List[torch._ops.OpOverload] = [torch.ops.aten.constant_pad_nd.default]
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph):
+        super().__init__(op_codes, graph)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+    ) -> circle.Operator.OperatorT:
+        args = ConstantPadNdArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input_ = args.input
+        pad = args.pad
+        val = args.value
+
+        if val != 0:
+            raise InvalidArgumentError("Only support 0 value padding.")
+
+        input_shape_len = len(extract_shape(input_))
+        padding_size = [[pad[2], pad[3]], [pad[0], pad[1]]]
+        if input_shape_len == 3:
+            padding_size = [[0, 0]] + padding_size
+        elif input_shape_len == 4:
+            padding_size = [[0, 0], [0, 0]] + padding_size
+        else:
+            raise InvalidArgumentError("Only support 3D/4D inputs.")
+
+        paddings = torch.tensor(padding_size, dtype=torch.int32)
+        inputs = [input_, paddings]
+        outputs = [node]
+
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.PAD, self._op_codes
+        )
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.PadOptions
+        option = circle.PadOptions.PadOptionsT()
+        operator.builtinOptions = option
+
+        return operator
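The `padding_size` construction above reorders torch's trailing-dims-first `pad` list `[left, right, top, bottom]` into one `[before, after]` row per dimension, which is what Circle's PAD expects; a standalone sketch (not part of the diff) of that reordering for the 4D case:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 4, 5)
pad = [1, 2, 3, 4]  # torch order: [left, right, top, bottom]

# Row-per-dimension layout, built the same way as padding_size above for
# 4D inputs: two leading [0, 0] rows, then the H row, then the W row.
paddings = [[0, 0], [0, 0], [pad[2], pad[3]], [pad[0], pad[1]]]

y = F.pad(x, pad, value=0.0)
assert list(y.shape) == [
    s + before + after for s, (before, after) in zip(x.shape, paddings)
]
```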
tico/serialize/operators/op_conv2d.py
@@ -0,0 +1,181 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+from tico.utils.define import define_pad_node
+from tico.utils.padding import is_same_padding, is_valid_padding, SAME, VALID
+from tico.utils.validate_args_kwargs import Conv2DArgs
+
+
+@register_node_visitor
+class Conv2dVisitor(NodeVisitor):
+    """
+    NOTE
+    - CircleConv2D in circle supports only a padding type ('VALID', 'SAME'), whereas nn.Conv2d accepts a padding type ('valid', 'same'), a padding value (int),
+      or a padding value (tuple -> [pad_h, pad_w]).
+    ref: https://tensorflow.org/api_docs/python/tf/nn/conv2d
+
+    [1] With valid/same padding: CircleConv2D (only)
+
+    [ATEN IR]
+    Input[NHWC] ---- circle_custom.conv2d[NHWC] ---- OUTPUT[NHWC]
+    Weight[NHWC] ---/
+    Bias ----------/
+
+    [CIRCLE IR]
+    Input[NHWC] ---- CircleConv2D[NHWC] ---- OUTPUT[NHWC]
+    Weight[NHWC] ---/
+    Bias ----------/
+
+    [2] With additional padding: CirclePad + CircleConv2D
+
+    [ATEN IR]
+    Input[NHWC] ---- circle_custom.conv2d[NHWC] ---- OUTPUT[NHWC]
+    Weight[NHWC] ---/
+    Bias ----------/
+
+    [CIRCLE IR]
+    Input[NHWC] ---- CirclePad[NHWC] ---- CircleConv2D[NHWC] ---- OUTPUT[NHWC]
+    Weight[NHWC] ------/
+    Bias -------------/
+    """
+
+    target: List[torch._ops.OpOverload] = [
+        torch.ops.circle_custom.conv2d,
+        torch.ops.circle_custom.conv2d.padding,
+    ]
+
+    def define_conv2d_node(
+        self, padding: int, stride: List, dilation: List, inputs: List, outputs: List
+    ) -> circle.Operator.OperatorT:
+        def set_conv2d_option(operator, stride, dilation):
+            operator.builtinOptionsType = (
+                circle.BuiltinOptions.BuiltinOptions.Conv2DOptions
+            )
+            option = circle.Conv2DOptions.Conv2DOptionsT()
+            option.padding = padding
+            option.strideH = stride[0]
+            option.strideW = stride[1]
+            option.dilationHFactor = dilation[0]
+            option.dilationWFactor = dilation[1]
+            option.fusedActivationFunction = (
+                circle.ActivationFunctionType.ActivationFunctionType.NONE
+            )
+            operator.builtinOptions = option
+
+        conv2d_op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.CONV_2D, self._op_codes
+        )
+        operator = create_builtin_operator(self.graph, conv2d_op_index, inputs, outputs)
+        set_conv2d_option(operator, stride, dilation)
+        return operator
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph):
+        super().__init__(op_codes, graph)
+
+    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
+        # conv2d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor
+        # conv2d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, str padding="valid", SymInt[2] dilation=1, SymInt groups=1) -> Tensor
+        args = Conv2DArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        input_ = args.input
+        weight = args.weight
+        bias = args.bias
+        stride = args.stride
+        padding = args.padding
+        dilation = args.dilation
+        groups = args.groups
+
+        assert groups == 1, "Only support group 1 conv2d"
+
+        input_dtype: int = extract_circle_dtype(input_)
+        input_shape = list(extract_shape(input_))
+        assert len(input_shape) == 4, len(input_shape)
+        output_shape = extract_shape(node)
+        assert len(output_shape) == 4, len(output_shape)
+
+        conv_input: torch.fx.node.Node | circle.Tensor.TensorT = input_
+        weight_shape = list(extract_shape(weight))
+
+        if is_valid_padding(padding):
+            conv2d_padding_type = VALID
+        elif is_same_padding(padding, input_shape, output_shape):
+            conv2d_padding_type = SAME
+        else:
+            assert isinstance(padding, list) and len(padding) == 2
+
+            conv2d_padding_type = VALID
+
+            # Padding is neither VALID nor SAME, so use VALID padding and add a padding operator before the conv2d operator.
+            # When data_format is "NHWC", padding should be [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
+            paddings = torch.tensor(
+                [
+                    [0, 0],
+                    [padding[0], padding[0]],
+                    [padding[1], padding[1]],
+                    [0, 0],
+                ],
+                dtype=torch.int32,
+            )
+            pad_output_shape = [
+                input_shape[0],
+                input_shape[1],
+                input_shape[2],
+                input_shape[3],
+            ]
+            # Add (pad_top + pad_bottom) to pad_output_shape_h
+            pad_output_shape[1] += padding[0] * 2
+            # Add (pad_left + pad_right) to pad_output_shape_w
+            pad_output_shape[2] += padding[1] * 2
+            # Create the padded output tensor
+
+            pad_output = self.graph.add_tensor_from_scratch(
+                prefix=f"{node.name}_input_pad_output",
+                shape=pad_output_shape,
+                dtype=input_dtype,
+            )
+            # CirclePad
+            pad_operator = define_pad_node(
+                self.graph, self._op_codes, [input_, paddings], [pad_output]
+            )
+            self.graph.add_operator(pad_operator)
+            conv_input = pad_output
+
+        if bias is None:
+            # luci-interpreter can't run a conv without bias. Let's add a zero vector for bias.
+            assert len(weight_shape) == 4
+            out_channel = weight_shape[0]
+            bias = [0.0] * out_channel  # type: ignore[assignment]
+
+        # Conv2D
+        conv2d_operator = self.define_conv2d_node(
+            conv2d_padding_type,  # 'SAME'(0) or 'VALID'(1)
+            stride,
+            dilation,
+            [conv_input, weight, bias],
+            [node],
+        )
+
+        return conv2d_operator
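The explicit-padding branch above relies on the standard equivalence between a convolution with explicit padding and a zero-pad followed by a padding-free ('VALID') convolution; a standalone PyTorch check (not part of the diff; NCHW here purely for brevity):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 16, 16)
w = torch.randn(4, 8, 3, 3)
pad_h, pad_w = 2, 1  # an explicit [pad_h, pad_w] that is neither VALID nor SAME

direct = F.conv2d(x, w, padding=(pad_h, pad_w))

# Zero-pad H and W first, then convolve with no padding, mirroring the
# CirclePad + CircleConv2D(VALID) pair emitted above.
padded = F.pad(x, (pad_w, pad_w, pad_h, pad_h))
two_step = F.conv2d(padded, w, padding=0)
assert torch.allclose(direct, two_step, atol=1e-5)
```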
tico/serialize/operators/op_copy.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING, Union
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+from tico.utils.errors import NotYetSupportedError
+from tico.utils.validate_args_kwargs import CopyArgs
+
+
+@register_node_visitor
+class CopyVisitor(NodeVisitor):
+    """
+    NOTE `torch.Tensor.copy_`'s behavior matches `Reshape` of CIRCLE.
+    - Because `torch.Tensor.copy_` is an in-place operator, `dst` is converted to `Shape` of CIRCLE.
+    - After that, the `dst` converted to `Shape` is connected to the shape input of `Reshape`.
+    - `src` is connected to the tensor input of `Reshape`.
+    - If `dst` is not converted to `Shape`:
+        [dst]  [src]
+                 |
+             [Reshape]
+    - If `dst` is converted to `Shape`:
+        [dst]  [src]
+          |      |
+       [Shape]   |
+           \    /
+         [Reshape]
+    """
+
+    target: List[torch._ops.OpOverload] = [torch.ops.aten.copy.default]
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+        super().__init__(op_codes, graph)
+
+    def check_to_do_broadcast(self, dst: List[int], src: List[int]) -> bool:
+        return dst != src
+
+    def define_broadcast_to_node(
+        self,
+        inputs: List[Union[circle.Tensor.TensorT, torch.Tensor]],
+        outputs: List[circle.Tensor.TensorT],
+    ) -> circle.Operator.OperatorT:
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.BROADCAST_TO, self._op_codes
+        )
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+        operator.builtinOptionsType = (
+            circle.BuiltinOptions.BuiltinOptions.BroadcastToOptions
+        )
+
+        option = circle.BroadcastToOptions.BroadcastToOptionsT()
+        operator.builtinOptions = option
+        return operator
+
+    def define_shape_node(
+        self, inputs: List[torch.fx.Node], outputs: List[circle.Tensor.TensorT]
+    ) -> circle.Operator.OperatorT:
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.SHAPE, self._op_codes
+        )
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ShapeOptions
+
+        option = circle.ShapeOptions.ShapeOptionsT()
+        option.outType = circle.TensorType.TensorType.INT32
+        operator.builtinOptions = option
+        return operator
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+    ) -> circle.Operator.OperatorT:
+        if len(node.args) == 3:
+            raise NotYetSupportedError("'non_blocking' is not supported yet.")
+
+        assert len(node.args) == 2, len(node.args)
+
+        args = CopyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        dst = args.dst
+        src = args.src
+
+        # To connect 'dst' to the Reshape node in the graph, 'dst' must be converted to a Shape op.
+        dst_tensor: circle.Tensor.TensorT = self.graph.get_tensor(dst)
+        dst_shape: List[int] = dst_tensor.shape
+        dst_shape_tensor = torch.as_tensor(dst_shape, dtype=torch.int32)
+
+        dst_shape_shape = [len(dst_shape)]
+        dst_name: str = dst.name
+
+        shape_output = self.graph.add_tensor_from_scratch(
+            prefix=f"{dst_name}_shape_output",
+            shape=dst_shape_shape,
+            dtype=circle.TensorType.TensorType.INT32,
+        )
+
+        shape_operator = self.define_shape_node([dst], [shape_output])
+        self.graph.add_operator(shape_operator)
+
+        src_tensor: circle.Tensor.TensorT = self.graph.get_tensor(src)
+        src_shape: List[int] = src_tensor.shape
+
+        # The src tensor must be broadcastable with the dst tensor.
+        do_broadcast = self.check_to_do_broadcast(dst_shape, src_shape)
+        if do_broadcast:
+            # Create broadcast_to output tensor
+            src_name: str = src.name
+            src_type: int = src_tensor.type
+
+            broadcast_to_output: circle.Tensor.TensorT = (
+                self.graph.add_tensor_from_scratch(
+                    prefix=f"{src_name}_broadcast_to_output",
+                    shape=dst_shape,
+                    dtype=src_type,
+                )
+            )
+
+            broadcast_to_operator: circle.Operator.OperatorT = (
+                self.define_broadcast_to_node(
+                    [src_tensor, dst_shape_tensor], [broadcast_to_output]
+                )
+            )
+            self.graph.add_operator(broadcast_to_operator)
+            inputs: List = [broadcast_to_output, shape_output]
+        else:
+            inputs = [src, shape_output]
+
+        outputs = [node]
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.RESHAPE, self._op_codes
+        )
+
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+        # Op-specific option
+        operator.builtinOptionsType = (
+            circle.BuiltinOptions.BuiltinOptions.ReshapeOptions
+        )
+        option = circle.ReshapeOptions.ReshapeOptionsT()
+        option.newShape = dst_shape
+
+        operator.builtinOptions = option
+        return operator
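The broadcast-then-reshape lowering above matches `aten.copy`'s out-of-place semantics, where `src` is broadcast to `dst`'s shape; a standalone sketch (not part of the diff):

```python
import torch

dst = torch.zeros(2, 3)
src = torch.tensor([1.0, 2.0, 3.0])  # broadcastable to dst's shape

# BROADCAST_TO to dst's shape, then RESHAPE to dst's shape, as the
# do_broadcast path above does.
emulated = torch.broadcast_to(src, dst.shape).reshape(dst.shape)
assert torch.equal(emulated, dst.clone().copy_(src))
```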