tico 0.1.0.dev250411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,54 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import PReLUArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
|
31
|
+
class PReLUVisitor(NodeVisitor):
|
32
|
+
target: List[torch._ops.OpOverload] = [torch.ops.aten.prelu.default]
|
33
|
+
|
34
|
+
def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
|
35
|
+
super().__init__(op_codes, graph)
|
36
|
+
|
37
|
+
def define_node(
|
38
|
+
self,
|
39
|
+
node: torch.fx.Node,
|
40
|
+
) -> circle.Operator.OperatorT:
|
41
|
+
op_index = get_op_index(
|
42
|
+
circle.BuiltinOperator.BuiltinOperator.PRELU, self._op_codes
|
43
|
+
)
|
44
|
+
|
45
|
+
args = PReLUArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
|
46
|
+
input = args.input
|
47
|
+
weight = args.weight
|
48
|
+
|
49
|
+
inputs = [input, weight]
|
50
|
+
outputs = [node]
|
51
|
+
|
52
|
+
operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
|
53
|
+
|
54
|
+
return operator
|
@@ -0,0 +1,79 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import QuantizePerTensorArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
|
31
|
+
class QuantizePerTensorDefaultVisitor(NodeVisitor):
|
32
|
+
target: List[torch._ops.OpOverload] = [
|
33
|
+
torch.ops.quantized_decomposed.quantize_per_tensor.default
|
34
|
+
]
|
35
|
+
|
36
|
+
def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
|
37
|
+
super().__init__(op_codes, graph)
|
38
|
+
|
39
|
+
def define_node(
|
40
|
+
self,
|
41
|
+
node: torch.fx.Node,
|
42
|
+
) -> circle.Operator.OperatorT:
|
43
|
+
args = QuantizePerTensorArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
|
44
|
+
tensor = args.tensor
|
45
|
+
scale = args.scale
|
46
|
+
zero_p = args.zero_p
|
47
|
+
quant_min = args.quant_min
|
48
|
+
quant_max = args.quant_max
|
49
|
+
|
50
|
+
output_tensor: circle.Tensor.TensorT = self.graph.get_tensor(node)
|
51
|
+
assert output_tensor.quantization is not None
|
52
|
+
|
53
|
+
# Tensor should have qparam when it's exported
|
54
|
+
# The qparam must match with the arguments of this Op
|
55
|
+
assert output_tensor.quantization.scale[0] == scale
|
56
|
+
assert output_tensor.quantization.zeroPoint[0] == zero_p
|
57
|
+
|
58
|
+
if output_tensor.type == circle.TensorType.TensorType.UINT8:
|
59
|
+
assert quant_min == 0 and quant_max == 255
|
60
|
+
elif output_tensor.type == circle.TensorType.TensorType.INT16:
|
61
|
+
# Some frameworks use -32767 as quant_min of int16
|
62
|
+
assert quant_min in (-32768, -32767) and quant_max == 32767
|
63
|
+
|
64
|
+
inputs = [tensor]
|
65
|
+
outputs = [node]
|
66
|
+
|
67
|
+
op_index = get_op_index(
|
68
|
+
circle.BuiltinOperator.BuiltinOperator.QUANTIZE, self._op_codes
|
69
|
+
)
|
70
|
+
operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
|
71
|
+
|
72
|
+
# Op-specific option
|
73
|
+
operator.builtinOptionsType = (
|
74
|
+
circle.BuiltinOptions.BuiltinOptions.QuantizeOptions
|
75
|
+
)
|
76
|
+
option = circle.QuantizeOptions.QuantizeOptionsT()
|
77
|
+
operator.builtinOptions = option
|
78
|
+
|
79
|
+
return operator
|
@@ -0,0 +1,64 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.circle_mapping import extract_torch_dtype
|
25
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
26
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
27
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
28
|
+
from tico.utils.validate_args_kwargs import ReciprocalArgs
|
29
|
+
|
30
|
+
|
31
|
+
@register_node_visitor
|
32
|
+
class ReciprocalVisitor(NodeVisitor):
|
33
|
+
target: List[torch._ops.OpOverload] = [torch.ops.aten.reciprocal.default]
|
34
|
+
|
35
|
+
def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
|
36
|
+
super().__init__(op_codes, graph)
|
37
|
+
|
38
|
+
def define_node(
|
39
|
+
self,
|
40
|
+
node: torch.fx.Node,
|
41
|
+
) -> circle.Operator.OperatorT:
|
42
|
+
op_index = get_op_index(
|
43
|
+
circle.BuiltinOperator.BuiltinOperator.DIV, self._op_codes
|
44
|
+
)
|
45
|
+
|
46
|
+
args = ReciprocalArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
|
47
|
+
input = args.input
|
48
|
+
|
49
|
+
input_tensor = torch.tensor(1, dtype=extract_torch_dtype(input))
|
50
|
+
x = self.graph.add_const_tensor(input_tensor)
|
51
|
+
inputs = [x, input]
|
52
|
+
outputs = [node]
|
53
|
+
|
54
|
+
operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
|
55
|
+
|
56
|
+
# Op-specific option
|
57
|
+
operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.DivOptions
|
58
|
+
option = circle.DivOptions.DivOptionsT()
|
59
|
+
option.fusedActivationFunction = (
|
60
|
+
circle.ActivationFunctionType.ActivationFunctionType.NONE
|
61
|
+
)
|
62
|
+
operator.builtinOptions = option
|
63
|
+
|
64
|
+
return operator
|
@@ -0,0 +1,53 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import ReluArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
|
31
|
+
class ReluVisitor(NodeVisitor):
|
32
|
+
target: List[torch._ops.OpOverload] = [torch.ops.aten.relu.default]
|
33
|
+
|
34
|
+
def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
|
35
|
+
super().__init__(op_codes, graph)
|
36
|
+
|
37
|
+
def define_node(
|
38
|
+
self,
|
39
|
+
node: torch.fx.Node,
|
40
|
+
) -> circle.Operator.OperatorT:
|
41
|
+
op_index = get_op_index(
|
42
|
+
circle.BuiltinOperator.BuiltinOperator.RELU, self._op_codes
|
43
|
+
)
|
44
|
+
|
45
|
+
args = ReluArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
|
46
|
+
input = args.input
|
47
|
+
|
48
|
+
inputs = [input]
|
49
|
+
outputs = [node]
|
50
|
+
|
51
|
+
operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
|
52
|
+
|
53
|
+
return operator
|
@@ -0,0 +1,52 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import Relu6Args
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
|
31
|
+
class Relu6Visitor(NodeVisitor):
|
32
|
+
target: List[torch._ops.OpOverload] = [
|
33
|
+
torch.ops.aten.relu6.default,
|
34
|
+
]
|
35
|
+
|
36
|
+
def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
|
37
|
+
super().__init__(op_codes, graph)
|
38
|
+
|
39
|
+
def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
|
40
|
+
args = Relu6Args(*node.args, **node.kwargs) # type: ignore[arg-type]
|
41
|
+
input = args.input
|
42
|
+
|
43
|
+
op_index = get_op_index(
|
44
|
+
circle.BuiltinOperator.BuiltinOperator.RELU6, self._op_codes
|
45
|
+
)
|
46
|
+
|
47
|
+
inputs = [input]
|
48
|
+
outputs = [node]
|
49
|
+
|
50
|
+
operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
|
51
|
+
|
52
|
+
return operator
|
@@ -0,0 +1,99 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
|
25
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
26
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
27
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
28
|
+
from tico.utils.errors import InvalidArgumentError, NotYetSupportedError
|
29
|
+
from tico.utils.validate_args_kwargs import RepeatArgs
|
30
|
+
|
31
|
+
|
32
|
+
@register_node_visitor
|
33
|
+
class RepeatVisitor(NodeVisitor):
|
34
|
+
target: List[torch._ops.OpOverload] = [torch.ops.aten.repeat.default]
|
35
|
+
|
36
|
+
def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
|
37
|
+
super().__init__(op_codes, graph)
|
38
|
+
|
39
|
+
def define_node(
|
40
|
+
self,
|
41
|
+
node: torch.fx.Node,
|
42
|
+
) -> circle.Operator.OperatorT:
|
43
|
+
args = RepeatArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
|
44
|
+
input = args.input
|
45
|
+
repeats = args.repeats
|
46
|
+
|
47
|
+
for r in repeats:
|
48
|
+
if r == 0:
|
49
|
+
# TODO: Support r == 0 case
|
50
|
+
raise NotYetSupportedError("Only support positive repeat value")
|
51
|
+
elif r < 0:
|
52
|
+
raise InvalidArgumentError("Only support positive repeat value")
|
53
|
+
|
54
|
+
tensor_shape = extract_shape(input)
|
55
|
+
assert len(tensor_shape) <= len(repeats)
|
56
|
+
if len(tensor_shape) != len(repeats):
|
57
|
+
# TODO Support len(tensor_shape) < len(repeats)
|
58
|
+
raise NotYetSupportedError(
|
59
|
+
"Length of both input tensor and repeats vector should be same."
|
60
|
+
)
|
61
|
+
repeat_dim_cnt = len(repeats) - repeats.count(1)
|
62
|
+
tensor_dtype = extract_circle_dtype(input)
|
63
|
+
op_index = get_op_index(
|
64
|
+
circle.BuiltinOperator.BuiltinOperator.CONCATENATION, self._op_codes
|
65
|
+
)
|
66
|
+
concat_input: torch.fx.Node | circle.Tensor.TensorT = input
|
67
|
+
concat_output: torch.fx.node.Node | circle.Tensor.TensorT = node
|
68
|
+
for idx, r in enumerate(repeats):
|
69
|
+
# concat along idx dimension
|
70
|
+
if r > 1:
|
71
|
+
# Except last created concat, a tensor should be created.
|
72
|
+
if repeat_dim_cnt > 1:
|
73
|
+
repeated_shape = list(tensor_shape)
|
74
|
+
repeated_shape[idx] = repeated_shape[idx] * r
|
75
|
+
concat_output = self.graph.add_tensor_from_scratch(
|
76
|
+
prefix=f"{node.name}_concat_{idx}",
|
77
|
+
shape=repeated_shape,
|
78
|
+
dtype=tensor_dtype,
|
79
|
+
)
|
80
|
+
inputs = [concat_input] * r
|
81
|
+
if repeat_dim_cnt == 1:
|
82
|
+
outputs: List[torch.fx.node.Node | circle.Tensor.TensorT] = [node]
|
83
|
+
else:
|
84
|
+
outputs = [concat_output]
|
85
|
+
operator = create_builtin_operator(
|
86
|
+
self.graph, op_index, inputs, outputs
|
87
|
+
)
|
88
|
+
operator.builtinOptionsType = (
|
89
|
+
circle.BuiltinOptions.BuiltinOptions.ConcatenationOptions
|
90
|
+
)
|
91
|
+
option = circle.ConcatenationOptions.ConcatenationOptionsT()
|
92
|
+
option.axis = idx
|
93
|
+
operator.builtinOptions = option
|
94
|
+
if repeat_dim_cnt > 1:
|
95
|
+
self.graph.add_operator(operator)
|
96
|
+
concat_input = concat_output
|
97
|
+
repeat_dim_cnt -= 1
|
98
|
+
|
99
|
+
return operator
|
@@ -0,0 +1,73 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph, is_const
|
24
|
+
from tico.serialize.circle_mapping import circle_legalize_dtype_to
|
25
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
26
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
27
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
28
|
+
from tico.utils.errors import NotYetSupportedError
|
29
|
+
from tico.utils.validate_args_kwargs import ReshapeArgs
|
30
|
+
|
31
|
+
|
32
|
+
@register_node_visitor
|
33
|
+
class ReshapeVisitor(NodeVisitor):
|
34
|
+
target: List[torch._ops.OpOverload] = [
|
35
|
+
torch.ops.aten.reshape.default,
|
36
|
+
]
|
37
|
+
|
38
|
+
def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
|
39
|
+
super().__init__(op_codes, graph)
|
40
|
+
|
41
|
+
def define_node(
|
42
|
+
self,
|
43
|
+
node: torch.fx.Node,
|
44
|
+
) -> circle.Operator.OperatorT:
|
45
|
+
op_index = get_op_index(
|
46
|
+
circle.BuiltinOperator.BuiltinOperator.RESHAPE,
|
47
|
+
self._op_codes,
|
48
|
+
)
|
49
|
+
args = ReshapeArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
|
50
|
+
input = args.input
|
51
|
+
size = args.size
|
52
|
+
|
53
|
+
if isinstance(size, int):
|
54
|
+
raise NotYetSupportedError("scalar size conversion is not supported yet.")
|
55
|
+
|
56
|
+
assert is_const(size), type(size)
|
57
|
+
|
58
|
+
size_i32 = circle_legalize_dtype_to(size, dtype=torch.int32)
|
59
|
+
inputs = [input, size_i32]
|
60
|
+
outputs = [node]
|
61
|
+
|
62
|
+
operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
|
63
|
+
|
64
|
+
# Op-specific option
|
65
|
+
operator.builtinOptionsType = (
|
66
|
+
circle.BuiltinOptions.BuiltinOptions.ReshapeOptions
|
67
|
+
)
|
68
|
+
option = circle.ReshapeOptions.ReshapeOptionsT()
|
69
|
+
option.newShape = size_i32
|
70
|
+
|
71
|
+
operator.builtinOptions = option
|
72
|
+
|
73
|
+
return operator
|
@@ -0,0 +1,70 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_mapping import circle_legalize_dtype_to
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import ResizeNearestNeighborArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class ResizeNearestNeighborVisitor(NodeVisitor):
    """Lower ``circle_custom.resize_nearest_neighbor`` to Circle RESIZE_NEAREST_NEIGHBOR.

    The node's ``size`` argument is legalized to int32 and passed as the
    operator's second input. The resize options are currently hard-coded
    (see ``define_node``).
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.circle_custom.resize_nearest_neighbor
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the RESIZE_NEAREST_NEIGHBOR operator for ``node``.

        Args:
            node: FX node whose args/kwargs are parsed by
                ``ResizeNearestNeighborArgs`` into ``input`` and ``size``.

        Returns:
            The created operator, with ResizeNearestNeighborOptions attached.
        """
        # TODO Support generic algorithm
        parsed = ResizeNearestNeighborArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        source = parsed.input
        new_size = circle_legalize_dtype_to(parsed.size, dtype=torch.int32)

        opcode_idx = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.RESIZE_NEAREST_NEIGHBOR,
            self._op_codes,
        )
        resize_op = create_builtin_operator(
            self.graph, opcode_idx, [source, new_size], [node]
        )

        # TODO Consider these options; both flags are fixed to False for now.
        resize_opts = circle.ResizeNearestNeighborOptions.ResizeNearestNeighborOptionsT()
        # alignCorners: if True, the centers of the 4 corner pixels of the input
        # and output tensors are aligned, preserving the corner pixel values.
        resize_opts.alignCorners = False
        # halfPixelCenters: if True, pixel centers are assumed to be at
        # (0.5, 0.5); alignCorners must then be False.
        resize_opts.halfPixelCenters = False

        resize_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ResizeNearestNeighborOptions
        )
        resize_op.builtinOptions = resize_opts
        return resize_op
|
@@ -0,0 +1,53 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import RsqrtArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class RsqrtVisitor(NodeVisitor):
    """Lower ``aten.rsqrt.default`` to a single-input Circle RSQRT operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.rsqrt.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Create an RSQRT operator that reads the node's input and writes ``node``.

        Args:
            node: FX node whose args/kwargs are parsed by ``RsqrtArgs``.

        Returns:
            The operator produced by ``create_builtin_operator``.
        """
        rsqrt_idx = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.RSQRT, self._op_codes
        )
        operand = RsqrtArgs(*node.args, **node.kwargs).input  # type: ignore[arg-type]
        return create_builtin_operator(self.graph, rsqrt_idx, [operand], [node])
|
@@ -0,0 +1,51 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.circle_mapping import extract_torch_dtype
|
25
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
26
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
27
|
+
from tico.utils.validate_args_kwargs import ScalarTensorArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class ScalarTensorVisitor(NodeVisitor):
    """Handle ``aten.scalar_tensor.default`` by storing a constant buffer.

    No Circle operator is built here. The scalar is materialized as a tensor
    and written into the graph's buffer registered under the node's name.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.scalar_tensor.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Write ``node``'s scalar value into the tensor buffer for ``node.name``.

        Args:
            node: FX node whose args/kwargs are parsed by ``ScalarTensorArgs``.

        Returns:
            Always ``None``, despite the annotation. NOTE(review): presumably the
            serializer treats ``None`` as "no operator to emit"; confirm against
            the caller.
        """
        value = ScalarTensorArgs(*node.args, **node.kwargs).scalar  # type: ignore[arg-type]

        # Take the dtype from the node itself: torch.scalar_tensor produces a
        # float tensor even when the scalar is an int.
        const_tensor = torch.scalar_tensor(value, dtype=extract_torch_dtype(node))
        self.graph.update_tensor_buffer(const_tensor, node.name)

        return None  # type: ignore[return-value]
|