tico 0.1.0.dev250411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,65 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import SelectCopyIntArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class SelectCopyIntVisitor(NodeVisitor):
    """Serialize `aten.select.int` / `aten.select_copy.int` as a circle GATHER operator.

    The selected position is passed to GATHER as a single int32 index tensor
    and the selected dimension is stored as the GATHER axis.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.select_copy.int,
        torch.ops.aten.select.int,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Create the GATHER operator that produces `node`.

        Args:
            node: FX node whose positional/keyword arguments are parsed by
                `SelectCopyIntArgs` into `input`, `dim` and `index`.

        Returns:
            The circle operator with `GatherOptions.axis` set to `dim`.
        """
        args = SelectCopyIntArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
        dim = args.dim
        index = args.index

        if index < 0:
            # torch counts a negative `index` from the end of `dim`.
            # NOTE(review): GATHER indices are assumed to have to lie in
            # [0, extent) - confirm against the circle runtime. The index is
            # therefore rewritten to its non-negative equivalent; non-negative
            # indices are passed through untouched.
            input_shape: List[int] = self.graph.get_tensor(input).shape
            index += int(input_shape[dim])

        indices = torch.as_tensor(index, dtype=torch.int32)
        inputs = [input, indices]
        outputs = [node]

        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.GATHER, self._op_codes
        )
        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        # Op-specific option
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.GatherOptions
        option = circle.GatherOptions.GatherOptionsT()
        option.axis = dim
        # TODO option.batchDims
        operator.builtinOptions = option

        return operator
|
@@ -0,0 +1,56 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import SigmoidArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class SigmoidVisitor(NodeVisitor):
    """Serialize `aten.sigmoid.default` as a circle LOGISTIC operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.sigmoid.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the LOGISTIC operator taking the sigmoid operand and producing `node`."""
        # `Sigmoid` operation is implemented as a `Logistic` operation in circle.
        # https://github.com/Samsung/ONE/blob/170382a/nnpackage/schema/circle_schema.fbs#L288
        logistic_code = circle.BuiltinOperator.BuiltinOperator.LOGISTIC
        op_index = get_op_index(logistic_code, self._op_codes)

        sigmoid_args = SigmoidArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        return create_builtin_operator(
            self.graph, op_index, [sigmoid_args.input], [node]
        )
|
@@ -0,0 +1,53 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import SinArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class SinVisitor(NodeVisitor):
    """Serialize `aten.sin.default` as a circle SIN operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.sin.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the SIN operator taking the sine operand and producing `node`."""
        sin_code = circle.BuiltinOperator.BuiltinOperator.SIN
        op_index = get_op_index(sin_code, self._op_codes)

        sin_args = SinArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        return create_builtin_operator(
            self.graph, op_index, [sin_args.input], [node]
        )
|
@@ -0,0 +1,155 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import copy
|
16
|
+
from typing import Dict, List, TYPE_CHECKING
|
17
|
+
|
18
|
+
if TYPE_CHECKING:
|
19
|
+
import torch._ops
|
20
|
+
import torch.fx
|
21
|
+
import torch
|
22
|
+
from circle_schema import circle
|
23
|
+
|
24
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
25
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
26
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
27
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
28
|
+
from tico.utils.errors import InvalidArgumentError
|
29
|
+
from tico.utils.validate_args_kwargs import SliceArgs
|
30
|
+
|
31
|
+
|
32
|
+
@register_node_visitor
class SliceCopyVisitor(NodeVisitor):
    """
    Serialize `aten.slice.Tensor` / `aten.slice_copy.Tensor` as a circle STRIDED_SLICE.

    NOTE `torch.slice_copy`'s behavior matches with `strided slice` of CIRCLE, not `slice`.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.slice.Tensor,
        torch.ops.aten.slice_copy.Tensor,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    @staticmethod
    def _clamp_bound(bound: int, extent: int) -> int:
        """Map a slice bound onto the closed range [0, extent].

        With M = extent:
          (-inf, -M) -> 0
          [-M, 0)    -> bound % M   (counted from the end)
          [0, M)     -> bound
          [M, +inf)  -> M
        """
        if bound < -extent:
            return 0
        if bound < 0:
            return bound % extent
        if bound < extent:
            return bound
        return extent

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Create the STRIDED_SLICE operator that produces `node`.

        Args:
            node: FX node whose arguments are parsed by `SliceArgs` into
                `input`, `dim`, `start`, `end` and `step`. A `None` start,
                end or step falls back to 0, the extent of `dim`, and 1.

        Returns:
            The circle operator whose begin/end/stride inputs select the full
            range on every dimension except `dim`.

        Raises:
            InvalidArgumentError: if the normalized `end` is not greater than
                the normalized `start` (an empty slice).
        """
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.STRIDED_SLICE, self._op_codes
        )

        args = SliceArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
        dim = args.dim
        start = args.start
        end = args.end
        step = args.step

        input_tensor: circle.Tensor.TensorT = self.graph.get_tensor(input)
        input_shape: List[int] = input_tensor.shape

        if start is None:
            start = 0
        if end is None:
            end = input_shape[dim]
        if step is None:
            step = 1

        assert dim is not None
        assert (
            -len(input_shape) <= dim < len(input_shape)
        ), "Cannot reach here (Dimension Out of Range error must be thrown by torch)"

        if dim < 0:
            dim = dim % len(input_shape)

        assert isinstance(start, int), type(start)
        assert isinstance(end, int), type(end)
        assert isinstance(step, int), type(step)

        extent = input_shape[dim]
        start = self._clamp_bound(start, extent)
        end = self._clamp_bound(end, extent)

        assert 0 <= dim and dim < len(input_shape), dim
        # `start` may equal `extent` after clamping (start >= M). That case
        # must fall through to the `end <= start` check below so it is
        # reported as InvalidArgumentError rather than tripping an assertion.
        assert 0 <= start and start <= extent, start
        assert 0 <= end and end <= extent, end
        assert 0 < step, "Restriction of torch.slice_copy"

        if end <= start:
            # CONSTRAINTS
            # In torch, 'end <= start' condition generates zero tensor with a
            # peculiar shape - ex. tensor([], size=(5,0,5))
            # In circle, it's not accepted at all.
            raise InvalidArgumentError(
                f"end({end}) must be greater than start ({start})"
            )

        # Build new arguments
        rank = len(input_shape)

        begin_shape = [0] * rank
        begin_shape[dim] = start
        begin_shape_tensor = torch.as_tensor(begin_shape, dtype=torch.int32)

        end_shape = copy.deepcopy(input_shape)
        end_shape[dim] = end
        end_shape_tensor = torch.as_tensor(end_shape, dtype=torch.int32)

        stride_shape = [1] * rank
        stride_shape[dim] = step
        stride_shape_tensor = torch.as_tensor(stride_shape, dtype=torch.int32)

        inputs = [input, begin_shape_tensor, end_shape_tensor, stride_shape_tensor]
        outputs = [node]

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.StridedSliceOptions
        )

        # The options object is attached with its default field values.
        option = circle.StridedSliceOptions.StridedSliceOptionsT()

        operator.builtinOptions = option
        return operator
|
@@ -0,0 +1,100 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
24
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
25
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
26
|
+
from tico.utils.errors import NotYetSupportedError
|
27
|
+
from tico.utils.utils import HAS_TORCH_OVER_25
|
28
|
+
from tico.utils.validate_args_kwargs import SafeSoftmaxArgs, SoftmaxArgs
|
29
|
+
|
30
|
+
|
31
|
+
@register_node_visitor
class SoftMaxVisitor(NodeVisitor):
    """Serialize `aten._softmax` (and `aten._safe_softmax` when available) as circle SOFTMAX."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten._softmax.default]
    if HAS_TORCH_OVER_25:
        # NOTE: Let's treat _safe_softmax as normal _softmax as its usage is for training.
        # In order for optimization during inference, it can be replaced to softmax.
        # ref: https://github.com/pytorch/pytorch/pull/133882
        target.append(torch.ops.aten._safe_softmax.default)

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_softmax_node(self, inputs, outputs) -> circle.Operator.OperatorT:
        """Create a SOFTMAX operator over `inputs`/`outputs` with `beta` fixed to 1.0."""
        softmax_code = circle.BuiltinOperator.BuiltinOperator.SOFTMAX
        op_index = get_op_index(softmax_code, self._op_codes)
        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        softmax_option = circle.SoftmaxOptions.SoftmaxOptionsT()
        softmax_option.beta = 1.0

        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.SoftmaxOptions
        )
        operator.builtinOptions = softmax_option
        return operator

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Create the SOFTMAX operator that produces `node`.

        Note that Currently, Softmax operator is supported only when `dim` is last dimension and `half_to_float` is False.

        Raises:
            NotYetSupportedError: if `half_to_float` is set, or if `dim` is not
                the last dimension of the input.
        """
        target_op = node.target
        if target_op == torch.ops.aten._softmax.default:
            # aten._softmax
            args = SoftmaxArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
            if args.half_to_float:
                raise NotYetSupportedError(
                    "softmax with half to float conversion is not supported on circle."
                )
        elif target_op == torch.ops.aten._safe_softmax.default:
            # aten._safe_softmax
            args = SafeSoftmaxArgs(*node.args, **node.kwargs)  # type: ignore[arg-type, assignment]

        input: torch.fx.Node = args.input
        axis: int = args.dim

        tensor_id: int = self.graph.get_tid_registered(input)
        registered: circle.Tensor.TensorT = self.graph.tensors[tensor_id]
        shape: List[int] = registered.shape

        # Rewrite a negative axis to its non-negative equivalent.
        if axis < 0:
            axis = axis % len(shape)

        if axis != len(shape) - 1:
            raise NotYetSupportedError("softmax only supports last dimension for now.")

        return self.define_softmax_node([input], [node])
|
@@ -0,0 +1,96 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING, Union
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
from torch._subclasses.fake_tensor import FakeTensor
|
23
|
+
|
24
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
25
|
+
from tico.serialize.circle_mapping import circle_legalize_dtype_to, to_circle_dtype
|
26
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
27
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
28
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
29
|
+
from tico.utils.validate_args_kwargs import SplitWithSizesArgs
|
30
|
+
|
31
|
+
|
32
|
+
@register_node_visitor
class SplitWithSizesVisitor(NodeVisitor):
    """Serialize `aten.split_with_sizes.default` as a circle SPLIT_V operator."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.split_with_sizes.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Create the SPLIT_V operator for `node`.

        Args:
            node: FX node whose arguments are parsed by `SplitWithSizesArgs`
                into `input`, `split_sizes` and `dim`.

        Returns:
            The circle operator with one output per entry of `split_sizes`
            and `SplitVOptions.numSplits` set to that count.
        """
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.SPLIT_V, self._op_codes
        )
        args = SplitWithSizesArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
        split_sizes = args.split_sizes
        axis = args.dim

        split_sizes_i32 = [
            circle_legalize_dtype_to(split_size, dtype=torch.int32)
            for split_size in split_sizes
        ]
        axis_i32 = circle_legalize_dtype_to(axis, dtype=torch.int32)
        inputs = [input, split_sizes_i32, axis_i32]

        # `split_with_sizes` has multiple output tensors and they are represented as `getitem`.
        # Therefore, unlike other ops, node itself doesn't become a circle tensor. Instead, each
        # `getitem` will be a circle tensor.
        # Further, torch module having `split_with_sizes` may sometimes return selected outputs.
        # At that time, `getitem` nodes are generated only for the output selected. Since
        # one-compiler assumes that `CircleSplitV` always has all the outputs, let's add unused
        # output tensors to compensate this restriction.

        # Key every user by the output index it extracts (its second argument).
        # Looking users up by index keeps each output slot tied to its own
        # user even when several users refer to the same index; the first such
        # user in `node.users` order is the one kept.
        user_by_index: Dict[int, torch.fx.Node] = {}
        for user in node.users:
            user_by_index.setdefault(user.args[1], user)  # type: ignore[arg-type]

        outputs: List[Union[circle.Tensor.TensorT, torch.fx.node.Node]] = []
        for idx in range(len(split_sizes)):
            user_node = user_by_index.get(idx)
            if user_node is not None:
                outputs.append(user_node)
                continue

            # Let's add unused output tensor to satisfy circle split_v operator scheme
            node_val = node.meta.get("val")
            assert isinstance(node_val, list)
            fake_tensor = node_val[idx]
            assert isinstance(fake_tensor, FakeTensor)
            shape = list(fake_tensor.size())
            dtype = to_circle_dtype(fake_tensor.dtype)
            tensor = self.graph.add_tensor_from_scratch(
                f"{node.name}_unused_{idx}", shape, dtype
            )
            outputs.append(tensor)

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.SplitVOptions
        option = circle.SplitVOptions.SplitVOptionsT()
        option.numSplits = len(split_sizes)
        operator.builtinOptions = option

        return operator
|
@@ -0,0 +1,55 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import SqrtArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class SqrtVisitor(NodeVisitor):
    """Node visitor that emits a Circle `SQRT` operator for `aten.sqrt.default`."""

    # Torch overloads handled by this visitor.
    target: List[torch._ops.OpOverload] = [torch.ops.aten.sqrt.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the Circle operator for a single `sqrt` FX node.

        Args:
            node: FX node whose args/kwargs are parsed through `SqrtArgs`.

        Returns:
            The operator produced by `create_builtin_operator`, with the
            parsed input as its only input and `node` as its only output.
        """
        # Look up the opcode index before parsing the arguments, in the same
        # order as the other visitors in this package.
        sqrt_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.SQRT, self._op_codes
        )

        parsed = SqrtArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        operand = parsed.input

        # `sqrt` does not have option, so the operator is returned as created.
        return create_builtin_operator(
            self.graph, sqrt_op_index, [operand], [node]
        )
|
@@ -0,0 +1,73 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import SqueezeArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class SqueezeVisitor(NodeVisitor):
    """Node visitor that emits a Circle `SQUEEZE` operator for `aten.squeeze*.dims`."""

    # Torch overloads handled by this visitor.
    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.squeeze.dims,
        torch.ops.aten.squeeze_copy.dims,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Build the Circle operator for a single `squeeze` FX node.

        Only the requested dims whose extent in the input's `meta["val"]`
        size is exactly 1 are forwarded to `SqueezeOptions.squeezeDims`.

        Args:
            node: FX node whose args/kwargs are parsed through `SqueezeArgs`.

        Returns:
            The operator produced by `create_builtin_operator`, carrying a
            `SqueezeOptions` table.
        """
        squeeze_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.SQUEEZE,
            self._op_codes,
        )

        parsed = SqueezeArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        source = parsed.input
        requested_dims = parsed.dims

        # Keep only the requested axes that really have extent 1 in the input.
        input_shape = source.meta["val"].size()
        squeeze_dims: List = []
        if requested_dims:
            for axis in requested_dims:
                if input_shape[axis] == 1:
                    squeeze_dims.append(axis)

        operator = create_builtin_operator(
            self.graph, squeeze_op_index, [source], [node]
        )

        # Op-specific option
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.SqueezeOptions
        )
        squeeze_option = circle.SqueezeOptions.SqueezeOptionsT()
        # NOTE(review): when no requested axis has extent 1, `squeezeDims` is
        # left at its default; confirm how the Circle consumer interprets an
        # unset list in that case.
        if squeeze_dims:
            squeeze_option.squeezeDims = squeeze_dims
        operator.builtinOptions = squeeze_option

        return operator
|