tico 0.1.0.dev250411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,60 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import ExpArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class ExpVisitor(NodeVisitor):
    """Serialize ``aten.exp`` into a Circle ``EXP`` operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.exp.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit an EXP operator reading the node's ``input`` and producing ``node``.

        Returns:
            The created operator, with an ``ExpOptions`` table attached.
        """
        exp_args = ExpArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        exp_opcode_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.EXP,
            self._op_codes,
        )
        exp_operator = create_builtin_operator(
            self.graph, exp_opcode_index, [exp_args.input], [node]
        )

        # Attach the op-specific options table; no ExpOptions fields are set here.
        exp_operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ExpOptions
        )
        exp_operator.builtinOptions = circle.ExpOptions.ExpOptionsT()

        return exp_operator
|
@@ -0,0 +1,91 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_mapping import circle_legalize_dtype_to
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import ExpandArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class ExpandVisitor(NodeVisitor):
    """Serialize ``aten.expand`` / ``aten.expand_copy`` into Circle ``BROADCAST_TO``."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.expand.default,
        torch.ops.aten.expand_copy.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_expand_copy_node(self, inputs, outputs) -> circle.Operator.OperatorT:
        """Create a BROADCAST_TO operator from ``inputs`` to ``outputs``.

        Args:
            inputs: operand list; ``define_node`` passes ``[input, target_shape]``.
            outputs: output node list.

        Returns:
            The operator, with a ``BroadcastToOptions`` table attached.
        """
        broadcast_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.BROADCAST_TO, self._op_codes
        )

        broadcast_op = create_builtin_operator(
            self.graph, broadcast_index, inputs, outputs
        )
        broadcast_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.BroadcastToOptions
        )
        broadcast_op.builtinOptions = circle.BroadcastToOptions.BroadcastToOptionsT()
        return broadcast_op

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Lower an expand node to BROADCAST_TO with a concrete int32 target shape.

        The requested ``size`` is legalized to int32. ``-1`` entries are replaced
        by the corresponding input dimension, and the expansion is checked for
        validity with assertions.

        Raises:
            AssertionError: a newly prepended dimension is < 1, or an input
                dimension is neither 1 nor equal to its expanded size.
        """
        expand_args = ExpandArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        src = expand_args.input
        size = expand_args.size

        src_tid: int = self.graph.get_tid_registered(src)
        src_tensor: circle.Tensor.TensorT = self.graph.tensors[src_tid]
        src_shape: List[int] = src_tensor.shape

        # How many leading dimensions expand() prepends in front of the input.
        n_new_dims = len(size) - len(src_shape)

        target_shape = circle_legalize_dtype_to(size, dtype=torch.int32)
        for pos, wanted in enumerate(target_shape):
            if pos < n_new_dims:
                assert (
                    wanted >= 1
                ), "A dim value(less than 1) isn't allowed in the extending_rank."

            # In PyTorch, -1 means "keep this dimension's size unchanged", but
            # circle in ONE does not support that. Substitute the concrete size
            # taken from the aligned input dimension.
            if wanted == -1:
                target_shape[pos] = src_shape[pos - n_new_dims]

        for pos, src_dim in enumerate(src_shape):
            assert (
                src_dim == 1 or src_dim == target_shape[n_new_dims + pos]
            ), f"The size of dimension to be expanded ({src_dim}) must be 1 or the expanded size ({target_shape[n_new_dims + pos]})."

        return self.define_expand_copy_node([src, target_shape], [node])
|
@@ -0,0 +1,48 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.utils.validate_args_kwargs import FullArgs
|
27
|
+
|
28
|
+
|
29
|
+
@register_node_visitor
class FullVisitor(NodeVisitor):
    """Fold ``aten.full`` into a constant buffer; no Circle operator is emitted."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.full.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Materialize the filled tensor and store it as ``node``'s buffer.

        Returns:
            None. The result is a constant, so no operator is created.
        """
        args = FullArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        size = args.size
        fill_value = args.fill_value

        # `torch.full(size, fill_value)` with no dtype infers the dtype from the
        # Python type of `fill_value` (e.g. an int gives int64). That can disagree
        # with the dtype recorded for this node by tracing (e.g. a float tensor
        # filled with an integer literal), so use the recorded dtype when present.
        # If no tensor meta is available, dtype=None keeps the previous inference.
        val = node.meta.get("val")
        dtype = val.dtype if isinstance(val, torch.Tensor) else None

        output_data = torch.full(size, fill_value, dtype=dtype)

        self.graph.update_tensor_buffer(output_data, node.name)

        return None  # type: ignore[return-value]
|
@@ -0,0 +1,55 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.circle_mapping import to_circle_dtype
|
25
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
26
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
27
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
28
|
+
from tico.utils.validate_args_kwargs import FullLikeArgs
|
29
|
+
|
30
|
+
|
31
|
+
@register_node_visitor
class FullLikeVisitor(NodeVisitor):
    """Fold ``aten.full_like`` into a constant buffer; no Circle operator is emitted."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.full_like.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the filled tensor from the input's size/dtype and store it as a buffer.

        Returns:
            None. The result is a constant, so no operator is created.
        """
        like_args = FullLikeArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        template = like_args.input.meta["val"]

        # Only the input's size and dtype matter, so the result can be
        # computed before inference.
        filled = torch.full(
            size=template.size(),
            fill_value=like_args.fill_value,
            dtype=template.dtype,
        )

        self.graph.update_tensor_buffer(filled, node.name)

        return None  # type: ignore[return-value]
|
@@ -0,0 +1,54 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import GeArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class GeVisitor(NodeVisitor):
    """Serialize ``aten.ge`` (Scalar and Tensor overloads) into Circle ``GREATER_EQUAL``."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.ge.Scalar,
        torch.ops.aten.ge.Tensor,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
        """Emit ``GREATER_EQUAL(input, other)`` producing ``node``; no options table is set."""
        ge_args = GeArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        operands = [ge_args.input, ge_args.other]

        ge_opcode_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.GREATER_EQUAL, self._op_codes
        )

        return create_builtin_operator(self.graph, ge_opcode_index, operands, [node])
|
@@ -0,0 +1,59 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import GeluArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class GeluVisitor(NodeVisitor):
    """Serialize ``aten.gelu`` into a Circle ``GELU`` operator."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.gelu.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit a GELU operator; set ``approximate`` when the node asks for tanh.

        Returns:
            The created operator with a ``GeluOptions`` table attached.
        """
        gelu_args = GeluArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        gelu_opcode_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.GELU, self._op_codes
        )
        gelu_operator = create_builtin_operator(
            self.graph, gelu_opcode_index, [gelu_args.input], [node]
        )
        gelu_operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.GeluOptions
        )

        gelu_options = circle.GeluOptions.GeluOptionsT()
        # Only the "tanh" approximation (passed as a kwarg) turns the flag on;
        # otherwise the options table keeps its default value.
        if node.kwargs.get("approximate") == "tanh":
            gelu_options.approximate = True
        gelu_operator.builtinOptions = gelu_options

        return gelu_operator
|
@@ -0,0 +1,54 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import GtArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class GtVisitor(NodeVisitor):
    """Serialize ``aten.gt`` (Scalar and Tensor overloads) into Circle ``GREATER``."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.gt.Scalar,
        torch.ops.aten.gt.Tensor,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
        """Emit ``GREATER(input, other)`` producing ``node``; no options table is set."""
        gt_args = GtArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        operands = [gt_args.input, gt_args.other]

        gt_opcode_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.GREATER, self._op_codes
        )

        return create_builtin_operator(self.graph, gt_opcode_index, operands, [node])
|
@@ -0,0 +1,82 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.errors import NotYetSupportedError
|
28
|
+
from tico.utils.validate_args_kwargs import IndexArgs
|
29
|
+
|
30
|
+
|
31
|
+
@register_node_visitor
class IndexTensorVisitor(NodeVisitor):
    """Serialize single-index ``aten.index.Tensor`` into a Circle ``GATHER``."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.index.Tensor]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Lower ``input[..., index, ...]`` (exactly one non-None index) to GATHER.

        The GATHER axis is the position of the single non-None entry in
        ``indices``.

        Returns:
            The created GATHER operator with ``GatherOptions.axis`` set.

        Raises:
            NotYetSupportedError: more than one non-None index is given, or the
                index is a boolean mask.
        """
        args = IndexArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        tensor = args.input
        indices = args.indices

        # TODO Support multiple indices
        if len(indices) - indices.count(None) > 1:  # type: ignore[arg-type]
            raise NotYetSupportedError(
                "Multiple indices is not supported yet in aten.index.Tensor"
            )

        # Find the lonely index and the axis it applies to.
        # ex. indices = [None, tensor, None] # index: tensor, axis: 1
        # ex. indices = [1] # index: 1, axis 0
        # ex. indices = [tensor] # index: tensor, axis 0
        index = None
        axis = None
        for axis_, index_ in enumerate(indices):
            if index_ is not None:
                index = index_  # type: ignore[assignment]
                axis = axis_  # type: ignore[assignment]
                break

        assert index is not None, index
        assert axis is not None, axis

        # A boolean index is a mask ("select where True"), not a list of
        # positions. GATHER expects integer positions, so lowering a mask to it
        # would silently compute something else. Reject it explicitly.
        # Non-node indices (e.g. a plain int) have no `meta` and pass through.
        index_val = getattr(index, "meta", {}).get("val")
        if isinstance(index_val, torch.Tensor) and index_val.dtype == torch.bool:
            raise NotYetSupportedError(
                "Boolean mask index is not supported yet in aten.index.Tensor"
            )

        inputs = [tensor, index]
        outputs = [node]

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.GATHER, self._op_codes
        )
        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        # Op-specific option
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.GatherOptions
        option = circle.GatherOptions.GatherOptionsT()
        option.axis = axis  # type: ignore[assignment]
        # TODO option.batchDims
        operator.builtinOptions = option

        return operator
|
@@ -0,0 +1,64 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.circle_mapping import circle_legalize_dtype_to
|
25
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
26
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
27
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
28
|
+
from tico.utils.validate_args_kwargs import IndexSelectArgs
|
29
|
+
|
30
|
+
|
31
|
+
@register_node_visitor
class IndexSelectVisitor(NodeVisitor):
    """
    Lower `aten.index_select` to a Circle GATHER operator.

    Input ---- GATHER(axis=dim) ---- OUTPUT
    Index ---/
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.index_select.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
        """
        Emit a GATHER operator for an `index_select(input, dim, index)` node.

        The operands are forwarded unchanged as [input, index] and `dim` is
        stored as GatherOptions.axis.

        Returns:
            The created Circle operator.
        """
        args = IndexSelectArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        input_ = args.input
        dim = args.dim
        index = args.index

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.GATHER,
            self._op_codes,
        )

        # GatherOptions.axis is a scalar int field, so it must hold a plain
        # Python int. `circle_legalize_dtype_to` legalizes `dim` to int32 but
        # may return a tensor/array rather than an int; unwrap it here so a
        # tensor object is never stored in the flatbuffer option.
        dim_i32 = circle_legalize_dtype_to(dim, dtype=torch.int32)
        axis = int(dim_i32.item()) if hasattr(dim_i32, "item") else int(dim_i32)

        inputs = [input_, index]
        outputs = [node]

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.GatherOptions
        option = circle.GatherOptions.GatherOptionsT()
        option.axis = axis
        operator.builtinOptions = option

        return operator
|
@@ -0,0 +1,91 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_mapping import extract_shape
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import InstanceNormArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class InstanceNormVisitor(NodeVisitor):
    """
    Input [NHWC] ---- circle_custom.instance_norm [NHWC] ---- OUTPUT[NHWC]
    Weight -------/
    Bias   -------/
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.circle_custom.instance_norm,
    ]

    def define_instance_norm_node(
        self, eps, inputs, outputs
    ) -> circle.Operator.OperatorT:
        """
        Create a Circle INSTANCE_NORM operator.

        Args:
            eps: value stored in InstanceNormOptions.epsilon.
            inputs: operator operands; `define_node` passes [input, weight, bias].
            outputs: operator outputs; `define_node` passes [node].

        Returns:
            The created operator with InstanceNormOptions attached and no
            fused activation function.
        """
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.INSTANCE_NORM, self._op_codes
        )
        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        option = circle.InstanceNormOptions.InstanceNormOptionsT()
        option.epsilon = eps
        option.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.InstanceNormOptions
        )
        operator.builtinOptions = option
        return operator

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
        """
        Serialize a `circle_custom.instance_norm` node as INSTANCE_NORM.

        Only a rank-4 input is accepted (asserted below).
        """
        args = InstanceNormArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        # Training-related arguments (running_mean, running_var,
        # use_input_stats, momentum, cudnn_enabled) are intentionally ignored;
        # only input, weight, bias and eps are serialized.
        input_ = args.input

        input_shape = list(extract_shape(input_))
        assert len(input_shape) == 4, len(input_shape)

        return self.define_instance_norm_node(
            args.eps,
            [input_, args.weight, args.bias],
            [node],
        )
|