tico 0.1.0.dev250411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,43 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from circle_schema import circle
|
16
|
+
|
17
|
+
|
18
|
+
class OpCode(circle.OperatorCode.OperatorCodeT):
    """
    Wrapper class for operator code in circle schema.

    Implements ``__eq__`` and ``__hash__`` so instances can be used as ``dict``
    keys. Two opcodes are equal when their ``version`` and ``builtinCode``
    match; for ``CUSTOM`` operators the ``customCode`` must match as well.
    ``__hash__`` is computed from exactly those fields, so equal opcodes are
    guaranteed to hash equal (a requirement for correct dict lookups).
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _key(code):
        """Return the tuple of fields that defines opcode identity for ``code``."""
        is_custom = code.builtinCode == circle.BuiltinOperator.BuiltinOperator.CUSTOM
        return (
            code.version,
            code.builtinCode,
            # customCode only distinguishes CUSTOM operators; ignore it otherwise
            # so that builtin opcodes never differ on an unused field.
            code.customCode if is_custom else None,
        )

    def __eq__(self, other):
        # Let Python fall back to its default handling for unrelated types
        # instead of raising AttributeError on missing fields.
        if not isinstance(other, circle.OperatorCode.OperatorCodeT):
            return NotImplemented
        # Comparing builtinCode first keeps equality symmetric: a CUSTOM opcode
        # can never equal a builtin one just because their customCode matches.
        return self._key(self) == self._key(other)

    def __hash__(self):
        # Must hash only the fields __eq__ compares (equal => same hash).
        return hash(self._key(self))
|
@@ -0,0 +1,80 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, Type, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from circle_schema import circle
|
21
|
+
|
22
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
23
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
24
|
+
|
25
|
+
|
26
|
+
class NodeVisitor:
    """
    Base class for visitors that lower an edge-IR node to a circle operator.

    Concrete visitors list the op overloads they handle in a ``target``
    attribute and override ``define_node``.
    """

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        # Subgraph that the visitor emits operators/tensors into.
        self.graph = graph
        # Opcode -> index table for the circle model; it is shared and
        # updated during serialization.
        self._op_codes = op_codes

    def define_node(
        self,
        node: torch.fx.node.Node,
    ) -> circle.Operator.OperatorT:
        """Create the circle operator for ``node``. Subclasses must override."""
        raise NotImplementedError("NodeVisitor must be extended.")
|
43
|
+
|
44
|
+
|
45
|
+
# Registry: op overload -> visitor class that serializes it. Entries are added
# by the ``register_node_visitor`` decorator when visitor modules are imported.
_node_visitor_dict: Dict[torch._ops.OpOverload, Type[NodeVisitor]] = {}


def register_node_visitor(visitor):
    """
    Class decorator registering ``visitor`` for every op in ``visitor.target``.

    A later registration for the same op replaces the earlier one. The class
    is returned unchanged so the function can be used with ``@`` syntax.
    """
    _node_visitor_dict.update({op: visitor for op in visitor.target})
    return visitor
|
54
|
+
|
55
|
+
|
56
|
+
def get_node_visitor(target: torch._ops.OpOverload) -> Type[NodeVisitor]:
    """
    Return the visitor class registered for ``target`` (for unittest purpose).

    Raises:
        LookupError: if no visitor has been registered for ``target``.
    """
    try:
        return _node_visitor_dict[target]
    except KeyError:
        raise LookupError(f"NodeVisitor for {target} is not registered") from None
|
66
|
+
|
67
|
+
|
68
|
+
# Get all node visitors
def get_node_visitors(
    op_codes: Dict[OpCode, int], graph: CircleSubgraph
) -> Dict[torch._ops.OpOverload, NodeVisitor]:
    """
    Instantiate every registered visitor for one serialization session.

    Each registered op gets its own instance, even when several ops map to
    the same visitor class.

    Args:
        op_codes: shared opcode -> index table passed to every visitor.
        graph: subgraph the visitors emit into.

    Returns:
        Mapping from op overload to a visitor instance.
    """
    return {
        op: visitor_cls(op_codes, graph)
        for op, visitor_cls in _node_visitor_dict.items()
    }
|
77
|
+
|
78
|
+
|
79
|
+
def get_support_targets():
    """
    Return the op overloads that currently have a registered node visitor.

    The result is a live ``dict_keys`` view of the registry, so it also
    reflects visitors registered after this call.
    """
    return _node_visitor_dict.keys()
|
@@ -0,0 +1,69 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import AddTensorArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class AddVisitor(NodeVisitor):
    """
    Lower ``aten.add`` (Tensor and Scalar overloads) to a circle ADD operator
    with no fused activation.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.add.Tensor,
        torch.ops.aten.add.Scalar,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the ADD operator producing ``node`` from ``input`` and ``other``."""
        args = AddTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        # NOTE(review): aten.add computes ``input + alpha * other``. This
        # lowering emits a plain ADD, which is only correct for alpha == 1 --
        # confirm that AddTensorArgs rejects other alpha values.
        lhs = args.input
        rhs = args.other

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.ADD, self._op_codes
        )

        inputs = [lhs, rhs]
        outputs = [node]
        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        # Op-specific option: no fused activation, potScaleInt16 disabled.
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.AddOptions
        option = circle.AddOptions.AddOptionsT()
        option.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        option.potScaleInt16 = False
        operator.builtinOptions = option

        return operator
|
@@ -0,0 +1,64 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
24
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
25
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
26
|
+
from tico.utils.validate_args_kwargs import AliasCopyArgs
|
27
|
+
|
28
|
+
|
29
|
+
@register_node_visitor
class AliasCopyVisitor(NodeVisitor):
    """
    Serialize ``aten.alias`` / ``aten.alias_copy`` as a circle TRANSPOSE whose
    permutation is the identity, so the output keeps the input's layout.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.alias.default,
        torch.ops.aten.alias_copy.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        src = AliasCopyArgs(*node.args, **node.kwargs).input  # type: ignore[arg-type]

        transpose_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.TRANSPOSE, self._op_codes
        )

        # Identity permutation [0, 1, ..., rank - 1] as an int32 CPU tensor.
        rank = len(src.meta["val"].shape)
        identity_perm = torch.arange(rank, dtype=torch.int32, device="cpu")

        operator = create_builtin_operator(
            self.graph, transpose_index, [src, identity_perm], [node]
        )

        # TRANSPOSE carries an empty options table.
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.TransposeOptions
        )
        operator.builtinOptions = circle.TransposeOptions.TransposeOptionsT()

        return operator
|
@@ -0,0 +1,142 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from circle_schema import circle
|
21
|
+
|
22
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
23
|
+
from tico.serialize.circle_mapping import (
|
24
|
+
circle_legalize_dtype_to,
|
25
|
+
extract_circle_dtype,
|
26
|
+
extract_shape,
|
27
|
+
extract_torch_dtype,
|
28
|
+
)
|
29
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
30
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
31
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
32
|
+
from tico.utils.validate_args_kwargs import AnyArgs
|
33
|
+
|
34
|
+
|
35
|
+
@register_node_visitor
class AnyVisitor(NodeVisitor):
    """
    Let's take NotEqual0 -> ReduceMax workaround for non-bool inputs.

    [RESTRICTION]
      1. ReduceAny is not supported (luci-interpreter)
    [CASE: BOOL]
      (Bool tensors don't need 'Not Equal 0' at the first step.)
      bool[d0..dN] --- Reduce Max ---> bool[]
    [CASE: NON-BOOL (FLOAT, INT, ...)]
      T[d0..dN] --- Not Equal 0 ---> bool[d0,...dN]
                --- Reduce Max ---> bool[]

    * [d0..dN] means a tensor with any shape
    * [] means Scalar
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.any.default,
        torch.ops.aten.any.dim,
        torch.ops.aten.any.dims,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_max_node(
        self, inputs: List, outputs: List, keepdims: bool
    ) -> circle.Operator.OperatorT:
        """Build a REDUCE_MAX operator with the given ``keepDims`` option."""
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.REDUCE_MAX, self._op_codes
        )

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ReducerOptions
        )
        option = circle.ReducerOptions.ReducerOptionsT()
        option.keepDims = keepdims

        operator.builtinOptions = option

        return operator

    def define_ne_node(self, inputs: List, outputs: List) -> circle.Operator.OperatorT:
        """Build a NOT_EQUAL operator."""
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.NOT_EQUAL, self._op_codes
        )

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.NotEqualOptions
        )
        option = circle.NotEqualOptions.NotEqualOptionsT()
        operator.builtinOptions = option
        return operator

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Lower ``aten.any`` to an optional NOT_EQUAL followed by REDUCE_MAX.

        The NOT_EQUAL operator (non-bool inputs only) is added to the graph
        here; the REDUCE_MAX operator is returned to the caller.

        Raises:
            TypeError: if ``dim`` is not None, an int, or a list/tuple of ints.
        """
        args = AnyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_node = args.input
        dim = args.dim
        keepdim = args.keepdim

        input_shape = list(extract_shape(input_node))

        # Reduction axes converted to int32 operands.
        if dim is None:
            # aten.any.default: reduce over every axis.
            dim_i32 = tuple(
                circle_legalize_dtype_to(d, dtype=torch.int32)
                for d in range(len(input_shape))
            )
        elif isinstance(dim, int):
            dim_i32 = circle_legalize_dtype_to(dim, dtype=torch.int32)
        elif isinstance(dim, (list, tuple)):
            # aten.any.dims takes ``int[]``, which FX graphs carry as a list.
            dim_i32 = tuple(circle_legalize_dtype_to(d, dtype=torch.int32) for d in dim)
        else:
            raise TypeError(f"Unsupported 'dim' argument for aten.any: {dim!r}")

        reduce_input: torch.fx.node.Node | circle.Tensor.TensorT = input_node

        input_dtype = extract_torch_dtype(input_node)
        if input_dtype != torch.bool:
            # Map every element to (x != 0) so REDUCE_MAX yields a bool result.
            ne_tensor: circle.Tensor.TensorT = self.graph.add_tensor_from_scratch(
                prefix=f"{input_node.name}_ne",
                shape=input_shape,
                dtype=circle.TensorType.TensorType.BOOL,
            )
            zero = torch.Tensor([0]).to(input_dtype)
            ne_node = self.define_ne_node([input_node, zero], [ne_tensor])
            self.graph.add_operator(ne_node)

            reduce_input = ne_tensor

        reduce_node: circle.Operator.OperatorT = self.define_max_node(
            [reduce_input, dim_i32], [node], keepdim
        )

        return reduce_node
|
@@ -0,0 +1,61 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.utils.validate_args_kwargs import ArangeStartStepArgs
|
27
|
+
|
28
|
+
|
29
|
+
@register_node_visitor
class ArangeStartStepVisitor(NodeVisitor):
    """
    Fuse arange_start_step to const_tensor.

    The arange values are computed at serialization time and written into the
    buffer of the tensor named after ``node``; no circle operator is emitted.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.arange.start_step]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        args = ArangeStartStepArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        start = args.start
        end = args.end
        step = args.step

        # aten::arange.start_step declares ``Scalar step``, so indexing it
        # (``step[0]``) fails on the normal scalar case. A one-element
        # sequence is still accepted in case the args wrapper packs it.
        if step is None:
            delta = 1
        elif isinstance(step, (list, tuple)):
            delta = step[0]
        else:
            delta = step

        # Mirror torch.arange: the result is integral only when start, end
        # and step are all integers; otherwise float32 is used.
        arange_dtype: torch.dtype = torch.float32
        if all(isinstance(v, int) for v in (start, end, delta)):
            arange_dtype = torch.int64

        output_data = torch.arange(start=start, end=end, step=delta, dtype=arange_dtype)
        self.graph.update_tensor_buffer(output_data, node.name)

        # Nothing to emit: the output is a constant tensor.
        return None  # type: ignore[return-value]
|
@@ -0,0 +1,62 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.circle_mapping import circle_legalize_dtype_to
|
25
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
26
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
27
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
28
|
+
from tico.utils.validate_args_kwargs import ArgMaxArgs
|
29
|
+
|
30
|
+
|
31
|
+
@register_node_visitor
class ArgMaxVisitor(NodeVisitor):
    """Lower ``aten.argmax`` to a circle ARG_MAX operator with int64 output."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.argmax.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        # NOTE(review): aten.argmax also allows dim=None (flattened argmax)
        # and keepdim=True; neither is handled here -- confirm ArgMaxArgs
        # rules them out.
        parsed = ArgMaxArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        src = parsed.tensor
        axis = parsed.dim

        argmax_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.ARG_MAX, self._op_codes
        )

        # The axis operand is converted to int32 before use as an input.
        axis_i32 = circle_legalize_dtype_to(axis, dtype=torch.int32)
        operator = create_builtin_operator(
            self.graph, argmax_index, [src, axis_i32], [node]
        )

        # Op-specific option: indices are produced as INT64.
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.ArgMaxOptions
        options = circle.ArgMaxOptions.ArgMaxOptionsT()
        options.outputType = circle.TensorType.TensorType.INT64
        operator.builtinOptions = options

        return operator
|
@@ -0,0 +1,112 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
|
25
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
26
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
27
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
28
|
+
from tico.utils.define import define_pad_node
|
29
|
+
from tico.utils.validate_args_kwargs import AvgPool2dArgs
|
30
|
+
|
31
|
+
|
32
|
+
@register_node_visitor
class AvgPool2DVisitor(NodeVisitor):
    """Serialize ``circle_custom.avgpool2d`` into Circle operators.

    Non-zero explicit padding is materialized as a separate PAD operator.
    Its output feeds an AVERAGE_POOL_2D that always uses VALID padding.
    When every padding amount is zero, the PAD is skipped because it would
    be an identity op.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.circle_custom.avgpool2d]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def _define_padding_node(
        self, input_tensor: torch.fx.Node, padding: List[int]
    ) -> "circle.Tensor.TensorT":
        """Add a PAD operator for ``input_tensor`` and return its output tensor.

        ``padding[0]`` is applied on both sides of dim 1 and ``padding[1]``
        on both sides of dim 2. Dims 0 and 3 are left untouched.

        NOTE(review): dims 1 and 2 are treated as the spatial (H, W) axes,
        i.e. the input is presumably NHWC. Confirm against the producer of
        ``circle_custom.avgpool2d``.
        """
        # One [before, after] row per input dimension.
        padding_vec = torch.tensor(
            [
                [0, 0],
                [padding[0], padding[0]],
                [padding[1], padding[1]],
                [0, 0],
            ],
            dtype=torch.int32,
        )
        input_shape = list(extract_shape(input_tensor))
        input_dtype: int = extract_circle_dtype(input_tensor)
        padded_input_shape = [
            input_shape[0],
            input_shape[1] + 2 * padding[0],
            input_shape[2] + 2 * padding[1],
            input_shape[3],
        ]
        # NOTE(review): the padded tensor is created with only a shape and
        # dtype. Confirm that quantized inputs do not also need their
        # quantization parameters propagated here.
        padded_input_tensor = self.graph.add_tensor_from_scratch(
            prefix=f"{input_tensor.name}_pad_output",
            shape=padded_input_shape,
            dtype=input_dtype,
        )
        pad_operator = define_pad_node(
            self.graph,
            self._op_codes,
            [input_tensor, padding_vec],
            [padded_input_tensor],
        )
        self.graph.add_operator(pad_operator)
        return padded_input_tensor

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the AVERAGE_POOL_2D operator for ``node``.

        If ``padding`` is given and contains a non-zero amount, a PAD
        operator is added to the graph first. Its output then becomes the
        pool's input.

        Returns:
            The AVERAGE_POOL_2D operator. Any PAD operator is added to the
            graph directly via ``self.graph.add_operator``.
        """
        args = AvgPool2dArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_tensor = args.input
        kernel_size = args.kernel_size
        stride = args.stride
        padding = args.padding

        avgpool_input: torch.fx.Node | circle.Tensor.TensorT = input_tensor
        if padding is not None:
            assert isinstance(padding, list), type(padding)
            # An all-zero PAD is an identity op. Emitting it would only add
            # an extra operator and intermediate tensor to the graph.
            if any(p != 0 for p in padding):
                avgpool_input = self._define_padding_node(input_tensor, padding)

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.AVERAGE_POOL_2D,
            self._op_codes,
        )
        operator = create_builtin_operator(
            self.graph, op_index, [avgpool_input], [node]
        )

        # Op-specific option
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.Pool2DOptions
        option = circle.Pool2DOptions.Pool2DOptionsT()

        # Circle Padding enum value (SAME = 0, VALID = 1). Any explicit
        # padding has already been applied by the PAD operator above.
        VALID = 1
        option.padding = VALID
        option.strideH = stride[0]
        option.strideW = stride[1]
        option.filterHeight = kernel_size[0]
        option.filterWidth = kernel_size[1]
        option.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )

        operator.builtinOptions = option

        return operator
|
@@ -0,0 +1,62 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import BmmArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class BatchMatmulVisitor(NodeVisitor):
    """Serialize ``aten.bmm`` (batched matrix product) as Circle BATCH_MATMUL."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.bmm.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Create the BATCH_MATMUL operator computing ``input @ mat2`` for ``node``.

        Both operands are forwarded unchanged. The adjoint flags on both
        sides are disabled.
        """
        bmm_args = BmmArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        lhs, rhs = bmm_args.input, bmm_args.mat2

        batch_matmul_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
        )
        operator = create_builtin_operator(
            self.graph, batch_matmul_index, [lhs, rhs], [node]
        )

        # BatchMatMul-specific options: no adjoint on either operand.
        bmm_options = circle.BatchMatMulOptions.BatchMatMulOptionsT()
        bmm_options.adjointLhs = False
        bmm_options.adjointRhs = False
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
        )
        operator.builtinOptions = bmm_options
        return operator
|