tico 0.1.0.dev250411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,70 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import LinearArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class LinearVisitor(NodeVisitor):
    """Visitor that emits a Circle FULLY_CONNECTED operator for ``aten.linear``."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.linear.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the FULLY_CONNECTED operator for ``node``.

        Args:
            node: FX node whose args/kwargs are parsed with ``LinearArgs``.

        Returns:
            The created operator with ``FullyConnectedOptions`` attached.
        """
        fc_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.FULLY_CONNECTED, self._op_codes
        )
        linear_args = LinearArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        # Operands are handed over in the order (input, weight, bias);
        # the bias value is forwarded exactly as parsed.
        fc_operator = create_builtin_operator(
            self.graph,
            fc_op_index,
            [linear_args.input, linear_args.weight, linear_args.bias],
            [node],
        )

        # Op-specific option
        fc_option = circle.FullyConnectedOptions.FullyConnectedOptionsT()
        fc_option.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        fc_option.weightsFormat = (
            circle.FullyConnectedOptionsWeightsFormat.FullyConnectedOptionsWeightsFormat.DEFAULT
        )
        fc_option.keepNumDims = True
        fc_option.asymmetricQuantizeInputs = False

        fc_operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.FullyConnectedOptions
        )
        fc_operator.builtinOptions = fc_option

        return fc_operator
|
@@ -0,0 +1,53 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import LogArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class LogVisitor(NodeVisitor):
    """Visitor that emits a Circle LOG operator for ``aten.log``."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.log.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the LOG operator for ``node``.

        Args:
            node: FX node whose args/kwargs are parsed with ``LogArgs``.

        Returns:
            The created operator; no builtin options are set on it.
        """
        log_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.LOG, self._op_codes
        )
        log_args = LogArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        # Single operand in, the node itself as the only output.
        return create_builtin_operator(
            self.graph, log_op_index, [log_args.input], [node]
        )
|
@@ -0,0 +1,83 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.circle_mapping import (
|
25
|
+
extract_circle_dtype,
|
26
|
+
extract_shape,
|
27
|
+
extract_torch_dtype,
|
28
|
+
)
|
29
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
30
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
31
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
32
|
+
from tico.utils.validate_args_kwargs import Log1pArgs
|
33
|
+
|
34
|
+
|
35
|
+
@register_node_visitor
class Log1pVisitor(NodeVisitor):
    """Visitor that lowers ``aten.log1p`` to an ADD followed by a LOG.

    The ADD (input + 1) is registered on the graph directly; the LOG
    operator is returned to the caller.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.log1p.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_add_node(self, inputs: List, outputs: List) -> circle.Operator.OperatorT:
        """Create an ADD operator with no fused activation.

        Args:
            inputs: Operands of the addition.
            outputs: Output tensors of the addition.

        Returns:
            The ADD operator with ``AddOptions`` attached.
        """
        add_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.ADD, self._op_codes
        )
        add_operator = create_builtin_operator(
            self.graph, add_op_index, inputs, outputs
        )

        add_option = circle.AddOptions.AddOptionsT()
        add_option.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        add_option.potScaleInt16 = False

        add_operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.AddOptions
        add_operator.builtinOptions = add_option

        return add_operator

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build ``log(input + 1)`` for ``node``.

        Args:
            node: FX node whose args/kwargs are parsed with ``Log1pArgs``.

        Returns:
            The LOG operator that consumes the intermediate sum tensor.
        """
        log1p_args = Log1pArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        src = log1p_args.input

        # Intermediate tensor for (input + 1): same shape and dtype as the input.
        sum_tensor: circle.Tensor.TensorT = self.graph.add_tensor_from_scratch(
            prefix=f"{src.name}_add",
            shape=list(extract_shape(src)),
            dtype=extract_circle_dtype(src),
        )
        # One-element constant 1, converted to the input's torch dtype.
        one = torch.tensor([1]).to(extract_torch_dtype(src))

        # The ADD goes onto the graph here, before the LOG op index is looked up.
        self.graph.add_operator(self.define_add_node([src, one], [sum_tensor]))

        log_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.LOG, self._op_codes
        )
        return create_builtin_operator(
            self.graph, log_op_index, [sum_tensor], [node]
        )
|
@@ -0,0 +1,63 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import LogicalAndArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class LogicalAndVisitor(NodeVisitor):
    """Visitor that emits a Circle LOGICAL_AND operator for ``aten.logical_and``."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.logical_and.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the LOGICAL_AND operator for ``node``.

        Args:
            node: FX node whose args/kwargs are parsed with ``LogicalAndArgs``.

        Returns:
            The created operator with default ``LogicalAndOptions`` attached.
        """
        and_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.LOGICAL_AND,
            self._op_codes,
        )
        and_args = LogicalAndArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        and_operator = create_builtin_operator(
            self.graph, and_op_index, [and_args.input, and_args.other], [node]
        )

        # Op-specific option (left at its defaults)
        and_operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.LogicalAndOptions
        )
        and_operator.builtinOptions = circle.LogicalAndOptions.LogicalAndOptionsT()

        return and_operator
|
@@ -0,0 +1,62 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import LogicalNotArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class LogicalNotVisitor(NodeVisitor):
    """Visitor that emits a Circle LOGICAL_NOT operator for ``aten.logical_not``."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.logical_not.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the LOGICAL_NOT operator for ``node``.

        Args:
            node: FX node whose args/kwargs are parsed with ``LogicalNotArgs``.

        Returns:
            The created operator with default ``LogicalNotOptions`` attached.
        """
        not_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.LOGICAL_NOT,
            self._op_codes,
        )
        not_args = LogicalNotArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        not_operator = create_builtin_operator(
            self.graph, not_op_index, [not_args.input], [node]
        )

        # Op-specific option (left at its defaults)
        not_operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.LogicalNotOptions
        )
        not_operator.builtinOptions = circle.LogicalNotOptions.LogicalNotOptionsT()

        return not_operator
|
@@ -0,0 +1,61 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import LtArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class LtVisitor(NodeVisitor):
    """Visitor that emits a Circle LESS operator for ``aten.lt.Tensor``."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.lt.Tensor]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the LESS operator for ``node``.

        Args:
            node: FX node whose args/kwargs are parsed with ``LtArgs``.

        Returns:
            The created operator with default ``LessOptions`` attached.
        """
        less_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.LESS,
            self._op_codes,
        )
        lt_args = LtArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        less_operator = create_builtin_operator(
            self.graph, less_op_index, [lt_args.input, lt_args.other], [node]
        )

        # Op-specific option (left at its defaults)
        less_operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.LessOptions
        less_operator.builtinOptions = circle.LessOptions.LessOptionsT()

        return less_operator
|
@@ -0,0 +1,140 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from enum import IntEnum
|
16
|
+
from typing import Dict, List, TYPE_CHECKING
|
17
|
+
|
18
|
+
if TYPE_CHECKING:
|
19
|
+
import torch._ops
|
20
|
+
import torch.fx
|
21
|
+
import torch
|
22
|
+
from circle_schema import circle
|
23
|
+
|
24
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
25
|
+
from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
|
26
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
27
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
28
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
29
|
+
from tico.utils.validate_args_kwargs import MaxPool2dWithIndicesArgs
|
30
|
+
|
31
|
+
|
32
|
+
class PaddingType(IntEnum):
    """Integer padding modes written into ``Pool2DOptions.padding``."""

    # Chosen when the pooled output shape equals the input shape.
    SAME = 0
    # Default mode; used together with an explicit pad operator otherwise.
    VALID = 1
|
35
|
+
|
36
|
+
|
37
|
+
@register_node_visitor
class MaxPool2DWithIndicesVisitor(NodeVisitor):
    """Visitor that emits a Circle MAX_POOL_2D for ``circle_custom.maxpool2d``.

    When the input and output shapes differ, a PADV2 operator filled with
    ``-inf`` is registered in front of the pooling operator and the pooling
    itself uses VALID padding.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.circle_custom.maxpool2d]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_padV2_node(
        self, inputs: List, outputs: List
    ) -> circle.Operator.OperatorT:
        """Create a PADV2 operator with default ``PadV2Options``.

        Args:
            inputs: Operands of the pad operator.
            outputs: Output tensors of the pad operator.

        Returns:
            The PADV2 operator (not yet added to the graph).
        """
        padv2_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.PADV2, self._op_codes
        )
        padv2_operator = create_builtin_operator(
            self.graph, padv2_op_index, inputs, outputs
        )
        padv2_operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.PadV2Options
        )
        padv2_operator.builtinOptions = circle.PadV2Options.PadV2OptionsT()
        return padv2_operator

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the MAX_POOL_2D operator for ``node``.

        Args:
            node: FX node whose args/kwargs are parsed with
                ``MaxPool2dWithIndicesArgs``.

        Returns:
            The pooling operator with ``Pool2DOptions`` attached.
        """
        # max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
        pool_args = MaxPool2dWithIndicesArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        src = pool_args.input
        kernel_size = pool_args.kernel_size
        stride = pool_args.stride
        padding = pool_args.padding

        def build_padded_input():
            """Register a PADV2 over ``src`` and return its output tensor."""
            assert isinstance(padding, list), type(padding)
            pad_h, pad_w = padding[0], padding[1]
            # Only dims 1 and 2 are padded, symmetrically.
            # NOTE(review): presumably the input is laid out NHWC - confirm.
            paddings = torch.tensor(
                [[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]],
                dtype=torch.int32,
            )
            src_shape = list(extract_shape(src))
            src_dtype: int = extract_circle_dtype(src)
            padded_shape = [
                src_shape[0],
                src_shape[1] + pad_h * 2,
                src_shape[2] + pad_w * 2,
                src_shape[3],
            ]
            # create padded input tensor
            padded_tensor = self.graph.add_tensor_from_scratch(
                prefix=f"{src.name}_pad_output",
                shape=padded_shape,
                dtype=src_dtype,
            )
            # -inf fill value for the padded border.
            self.graph.add_operator(
                self.define_padV2_node(
                    [src, paddings, float("-inf")], [padded_tensor]
                )
            )
            return padded_tensor

        pool_input: torch.fx.Node | circle.Tensor.TensorT = src
        padding_type = PaddingType.VALID
        if padding is not None:
            if extract_shape(src) == extract_shape(node):
                # Output shape equals input shape: use SAME, no explicit pad.
                padding_type = PaddingType.SAME
            else:
                # Keep VALID and pad explicitly in front of the pooling op.
                pool_input = build_padded_input()

        pool_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.MAX_POOL_2D,
            self._op_codes,
        )
        pool_operator = create_builtin_operator(
            self.graph, pool_op_index, [pool_input], [node]
        )

        # Op-specific option
        pool_option = circle.Pool2DOptions.Pool2DOptionsT()
        pool_option.padding = int(padding_type)
        pool_option.strideH = stride[0]
        pool_option.strideW = stride[1]
        pool_option.filterHeight = kernel_size[0]
        pool_option.filterWidth = kernel_size[1]
        pool_option.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )

        pool_operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.Pool2DOptions
        pool_operator.builtinOptions = pool_option

        return pool_operator
|
@@ -0,0 +1,53 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import MaximumArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class MaximumVisitor(NodeVisitor):
    """Node visitor that serializes ``aten.maximum.default`` FX nodes.

    Each visited node is emitted as a Circle builtin ``MAXIMUM`` operator.
    """

    # Overloads handled by this visitor.
    target: List[torch._ops.OpOverload] = [torch.ops.aten.maximum.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
        """Build the Circle ``MAXIMUM`` operator for ``node``.

        Args:
            node: FX node whose args/kwargs are parsed through ``MaximumArgs``.

        Returns:
            The operator produced by ``create_builtin_operator``, taking the
            parsed ``input`` and ``other`` as inputs and ``node`` as its only
            output.
        """
        # Parse the call arguments first so a malformed node is reported
        # before the op-code table is consulted.
        parsed = MaximumArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        lhs = parsed.input
        rhs = parsed.other

        maximum_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.MAXIMUM, self._op_codes
        )

        return create_builtin_operator(
            self.graph, maximum_op_index, [lhs, rhs], [node]
        )
|
@@ -0,0 +1,66 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_mapping import circle_legalize_dtype_to
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import MeanDimArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class MeanVisitor(NodeVisitor):
    """Node visitor that serializes ``aten.mean.dim`` FX nodes.

    Each visited node is emitted as a Circle builtin ``MEAN`` operator
    carrying ``ReducerOptions``.
    """

    # Overloads handled by this visitor.
    target: List[torch._ops.OpOverload] = [torch.ops.aten.mean.dim]

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the Circle ``MEAN`` operator for ``node``.

        Args:
            node: FX node whose args/kwargs are parsed through ``MeanDimArgs``.

        Returns:
            The operator produced by ``create_builtin_operator``, with
            ``ReducerOptions`` attached as its builtin options.
        """
        mean_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.MEAN, self._op_codes
        )

        parsed = MeanDimArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        source = parsed.input
        axes = parsed.dim
        keep_dims = parsed.keep_dims

        # The reduction axes travel as the operator's second input, after
        # going through circle_legalize_dtype_to with dtype=torch.int32.
        axes_i32 = circle_legalize_dtype_to(axes, dtype=torch.int32)

        operator = create_builtin_operator(
            self.graph, mean_op_index, [source, axes_i32], [node]
        )

        # Op-specific option
        reducer_option = circle.ReducerOptions.ReducerOptionsT()
        if keep_dims:
            # Only overwrite keepDims when requested; otherwise the value set
            # by ReducerOptionsT() is left untouched.
            reducer_option.keepDims = keep_dims

        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ReducerOptions
        )
        operator.builtinOptions = reducer_option

        return operator
|