tico 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +42 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +191 -0
- tico/passes/cast_mixed_type_args.py +187 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +160 -0
- tico/passes/convert_layout_op_to_reshape.py +85 -0
- tico/passes/convert_repeat_to_expand_copy.py +89 -0
- tico/passes/convert_to_relu6.py +181 -0
- tico/passes/decompose_addmm.py +124 -0
- tico/passes/decompose_batch_norm.py +192 -0
- tico/passes/decompose_fake_quantize.py +134 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
- tico/passes/decompose_group_norm.py +275 -0
- tico/passes/decompose_grouped_conv2d.py +209 -0
- tico/passes/decompose_slice_scatter.py +169 -0
- tico/passes/extract_dtype_kwargs.py +122 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +108 -0
- tico/passes/legalize_predefined_layout_operators.py +386 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
- tico/passes/lower_to_slice.py +230 -0
- tico/passes/merge_consecutive_cat.py +80 -0
- tico/passes/ops.py +78 -0
- tico/passes/remove_nop.py +84 -0
- tico/passes/remove_redundant_assert_nodes.py +51 -0
- tico/passes/remove_redundant_expand.py +66 -0
- tico/passes/remove_redundant_permute.py +122 -0
- tico/passes/remove_redundant_reshape.py +436 -0
- tico/passes/remove_redundant_slice.py +62 -0
- tico/passes/remove_redundant_to_copy.py +86 -0
- tico/passes/restore_linear.py +115 -0
- tico/passes/segment_index_select.py +145 -0
- tico/pt2_to_circle.py +105 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +319 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +240 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_abs.py +53 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +150 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +192 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +126 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +186 -0
- tico/serialize/operators/op_copy.py +164 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +95 -0
- tico/serialize/operators/op_depthwise_conv2d.py +199 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_leaky_relu.py +60 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +86 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_dim.py +70 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +177 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +141 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +100 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +99 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +94 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +296 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +282 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/mx/__init__.py +1 -0
- tico/utils/mx/elemwise_ops.py +267 -0
- tico/utils/mx/formats.py +125 -0
- tico/utils/mx/mx_ops.py +270 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +609 -0
- tico/utils/serialize.py +42 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +406 -0
- tico/utils/validate_args_kwargs.py +1149 -0
- tico-0.1.0.dist-info/LICENSE +241 -0
- tico-0.1.0.dist-info/METADATA +354 -0
- tico-0.1.0.dist-info/RECORD +206 -0
- tico-0.1.0.dist-info/WHEEL +5 -0
- tico-0.1.0.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,150 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from circle_schema import circle
|
21
|
+
|
22
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
23
|
+
from tico.serialize.circle_mapping import (
|
24
|
+
circle_legalize_dtype_to,
|
25
|
+
extract_circle_dtype,
|
26
|
+
extract_shape,
|
27
|
+
extract_torch_dtype,
|
28
|
+
)
|
29
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
30
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
31
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
32
|
+
from tico.utils.validate_args_kwargs import AnyArgs
|
33
|
+
|
34
|
+
|
35
|
+
@register_node_visitor
class AnyVisitor(NodeVisitor):
    """
    Let's take NotEqual0 -> ReduceMax workaround for non-bool inputs.
    [RESTRICTION]
      1. ReduceAny is not supported (luci-interpreter)
    [CASE: BOOL]
      (Bool tensors don't need 'Not Equal 0' at the first step.)
      bool[d0..dN] --- Reduce Max ---> bool[]
    [CASE: NON-BOOL (int, float, ...)]
      int/float[d0..dN] --- Not Equal 0 ---> bool[d0,...dN]
                        --- Reduce Max ---> bool[]
      * [d0..dN] means a tensor with any shape
      * [] means Scalar
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.any.default,
        torch.ops.aten.any.dim,
        torch.ops.aten.any.dims,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_max_node(
        self, inputs: List, outputs: List, keepdims: bool
    ) -> circle.Operator.OperatorT:
        """
        Create a Circle REDUCE_MAX operator.

        Args:
            inputs: [data, axes] operator inputs.
            outputs: operator outputs.
            keepdims: value written to ReducerOptions.keepDims.

        Returns:
            The created operator (not yet added to the graph).
        """
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.REDUCE_MAX, self._op_codes
        )

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ReducerOptions
        )
        option = circle.ReducerOptions.ReducerOptionsT()
        option.keepDims = keepdims

        operator.builtinOptions = option

        return operator

    def define_ne_node(self, inputs: List, outputs: List) -> circle.Operator.OperatorT:
        """
        Create a Circle NOT_EQUAL operator.

        Args:
            inputs: [lhs, rhs] operator inputs.
            outputs: operator outputs.

        Returns:
            The created operator (not yet added to the graph).
        """
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.NOT_EQUAL, self._op_codes
        )

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.NotEqualOptions
        )
        option = circle.NotEqualOptions.NotEqualOptionsT()
        operator.builtinOptions = option
        return operator

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Serialize `aten.any` as (optional NOT_EQUAL 0) followed by REDUCE_MAX.

        Returns:
            The REDUCE_MAX operator producing `node`. The NOT_EQUAL operator, when
            needed, is added to the graph directly.
        """
        args = AnyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
        dim = args.dim
        keepdim = args.keepdim

        input_shape = list(extract_shape(input))

        # Reduction axes as int32 values. With no `dim`, reduce over every axis.
        # `aten.any.dims` carries a sequence that may be a list or a tuple.
        dim_i32 = None
        if dim is None:
            dim_i32 = tuple(
                circle_legalize_dtype_to(d, dtype=torch.int32)
                for d in range(len(input_shape))
            )
        elif isinstance(dim, int):
            dim_i32 = circle_legalize_dtype_to(dim, dtype=torch.int32)
        elif isinstance(dim, (list, tuple)):
            dim_i32 = tuple(circle_legalize_dtype_to(d, dtype=torch.int32) for d in dim)
        assert dim_i32 is not None, f"Unsupported dim: {dim!r}"

        input_dtype = extract_torch_dtype(input)
        reduce_input: torch.fx.node.Node | circle.Tensor.TensorT = input

        # any(x) == max(x != 0) for every non-bool dtype; bool inputs can be
        # max-reduced directly.
        if input_dtype != torch.bool:
            ne_tensor: circle.Tensor.TensorT = self.graph.add_tensor_from_scratch(
                prefix=f"{input.name}_ne",
                shape=input_shape,
                dtype=circle.TensorType.TensorType.BOOL,
                source_node=input,
            )
            zero = torch.Tensor([0]).to(input_dtype)
            ne_node = self.define_ne_node([input, zero], [ne_tensor])
            self.graph.add_operator(ne_node)
            reduce_input = ne_tensor

        reduce_node: circle.Operator.OperatorT = self.define_max_node(
            [reduce_input, dim_i32], [node], keepdim
        )

        return reduce_node
|
@@ -0,0 +1,61 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.utils.validate_args_kwargs import ArangeStartStepArgs
|
27
|
+
|
28
|
+
|
29
|
+
@register_node_visitor
class ArangeStartStepVisitor(NodeVisitor):
    """
    Fuse arange_start_step to const_tensor

    No Circle operator is emitted. The arange result is computed eagerly and
    written as the buffer of the tensor named after the node.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.arange.start_step]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Materialize `torch.arange(start, end, step)` into the node's tensor buffer.

        Returns:
            None, since no operator is created for this node.
        """
        args = ArangeStartStepArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        start = args.start
        end = args.end
        step = args.step

        # `aten.arange.start_step` takes a scalar `step` (default 1). A one-element
        # sequence is still accepted for backward compatibility.
        delta = 1
        if isinstance(step, (list, tuple)):
            delta = step[0]
        elif step is not None:
            delta = step

        # Mirror torch.arange dtype inference: int64 only when start, end and step
        # are all integers, otherwise the float default. An explicit dtype kwarg wins.
        arange_dtype: torch.dtype = torch.float32
        if all(isinstance(v, int) for v in (start, end, delta)):
            arange_dtype = torch.int64
        explicit_dtype = node.kwargs.get("dtype")
        if explicit_dtype is not None:
            arange_dtype = explicit_dtype

        output_data = torch.arange(start=start, end=end, step=delta, dtype=arange_dtype)
        self.graph.update_tensor_buffer(output_data, node.name)

        return None  # type: ignore[return-value]
|
@@ -0,0 +1,62 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.circle_mapping import circle_legalize_dtype_to
|
25
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
26
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
27
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
28
|
+
from tico.utils.validate_args_kwargs import ArgMaxArgs
|
29
|
+
|
30
|
+
|
31
|
+
@register_node_visitor
class ArgMaxVisitor(NodeVisitor):
    """
    Serialize `aten.argmax.default` into a Circle ARG_MAX operator.

    The reduction axis is legalized to int32 and passed as the second operator
    input. The operator is configured to emit INT64 indices.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.argmax.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        parsed = ArgMaxArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        source, axis = parsed.tensor, parsed.dim

        arg_max_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.ARG_MAX, self._op_codes
        )

        axis_i32 = circle_legalize_dtype_to(axis, dtype=torch.int32)
        arg_max_op = create_builtin_operator(
            self.graph, arg_max_index, [source, axis_i32], [node]
        )

        # ARG_MAX-specific options: indices are produced as INT64.
        arg_max_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ArgMaxOptions
        )
        arg_max_options = circle.ArgMaxOptions.ArgMaxOptionsT()
        arg_max_options.outputType = circle.TensorType.TensorType.INT64
        arg_max_op.builtinOptions = arg_max_options

        return arg_max_op
|
@@ -0,0 +1,192 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import math
|
16
|
+
from typing import Dict, List, TYPE_CHECKING
|
17
|
+
|
18
|
+
if TYPE_CHECKING:
|
19
|
+
import torch._ops
|
20
|
+
import torch.fx
|
21
|
+
import torch
|
22
|
+
from circle_schema import circle
|
23
|
+
|
24
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
25
|
+
from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
|
26
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
27
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
28
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
29
|
+
from tico.utils.define import define_pad_node
|
30
|
+
from tico.utils.errors import NotYetSupportedError
|
31
|
+
from tico.utils.validate_args_kwargs import AvgPool2dArgs
|
32
|
+
|
33
|
+
|
34
|
+
@register_node_visitor
class AvgPool2DVisitor(NodeVisitor):
    """
    This class defines how to serialize AvgPool2D operation into Circle IR.

                          Torch                 |           Circle

       count_include_pad: True/False            | (count_include_pad): Always False
       padding: number (could be valid, same, or etc) | padding: "valid"/"same"

    * Circle's avgpool2d has no option for count_include_pad, so we always set it as False.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.circle_custom.avgpool2d]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def has_padding(self, args: AvgPool2dArgs) -> bool:
        """Return True when either spatial padding amount is non-zero."""
        padding = args.padding
        if padding[0] == 0 and padding[1] == 0:
            return False
        else:
            return True

    def has_same_padding(self, args: AvgPool2dArgs) -> bool:
        """
        Return True when torch's symmetric padding selects exactly the same pooling
        windows as Circle's "SAME" padding on both spatial axes.

        Circle (TFLite) SAME padding produces out = ceil(in / stride) and pads
        pad_total = max((out - 1) * stride + kernel - in, 0) elements, with
        pad_total // 2 placed before the data. Window i then starts at
        i * stride - pad_before. Torch's window i starts at i * stride - padding,
        so both match when the output sizes agree and padding == pad_before.

        Spatial axes are input axes 1 and 2, the same axes `define_node` pads.
        """
        input_shape = list(extract_shape(args.input))
        kernel_size = args.kernel_size
        stride = args.stride
        assert stride
        padding = args.padding
        # TODO Update this function when supporting ceil_mode = True
        assert args.ceil_mode is False

        for axis, in_size in enumerate(input_shape[1:3]):
            k, s, p = kernel_size[axis], stride[axis], padding[axis]
            torch_out = math.floor((in_size + p * 2 - k) / s + 1)
            same_out = math.ceil(in_size / s)
            if torch_out != same_out:
                return False
            same_pad_total = max((same_out - 1) * s + k - in_size, 0)
            if p != same_pad_total // 2:
                return False

        return True

    def define_avgpool_node(self, inputs, outputs, padding, stride, kernel_size):
        """
        Create a Circle AVERAGE_POOL_2D operator.

        Args:
            inputs: operator inputs.
            outputs: operator outputs.
            padding: "SAME" or "VALID".
            stride: (strideH, strideW).
            kernel_size: (filterHeight, filterWidth).

        Returns:
            The created operator (not yet added to the graph).
        """
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.AVERAGE_POOL_2D,
            self._op_codes,
        )
        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        # Op-specific option
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.Pool2DOptions
        option = circle.Pool2DOptions.Pool2DOptionsT()

        padding_codes = {"SAME": 0, "VALID": 1}
        assert padding in padding_codes

        option.padding = padding_codes[padding]
        option.strideH = stride[0]
        option.strideW = stride[1]
        option.filterHeight = kernel_size[0]
        option.filterWidth = kernel_size[1]
        option.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )

        operator.builtinOptions = option
        return operator

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        PSEUDO CODE

        if count_include_pad == True:
            (Circle cannot represent count_include_pad=True in AvgPool2D. Therefore we manually add zero padding node.)
            DEFINE zero padding node
            DEFINE avgpool node with no padding (valid)
        if count_include_pad == False:
            (Lucky! Circle can represent count_include_pad=False)
            DEFINE avgpool node with same/valid padding.

            (However, it cannot represent all paddings. So, if the padding is not same or valid, we throw an error.)
            if the padding is neither same nor valid:
                THROW an error.
        """
        args = AvgPool2dArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input = args.input
        kernel_size = args.kernel_size
        stride = args.stride
        padding = args.padding
        count_include_pad = args.count_include_pad

        avgpool_input: torch.fx.Node | circle.Tensor.TensorT = input

        def define_padding_node():
            """Add a PAD on input axes 1 and 2 and return the padded tensor."""
            assert isinstance(padding, list), type(padding)
            padding_vec = torch.tensor(
                [
                    [0, 0],
                    [padding[0], padding[0]],
                    [padding[1], padding[1]],
                    [0, 0],
                ],
                dtype=torch.int32,
            )
            input_shape = list(extract_shape(input))
            input_dtype: int = extract_circle_dtype(input)
            padded_input_shape = [
                input_shape[0],
                input_shape[1] + padding[0] * 2,
                input_shape[2] + padding[1] * 2,
                input_shape[3],
            ]
            # create padded input tensor
            padded_input_tensor = self.graph.add_tensor_from_scratch(
                prefix=f"{input.name}_pad_output",
                shape=padded_input_shape,
                dtype=input_dtype,
                source_node=node,
            )
            pad_operator = define_pad_node(
                self.graph, self._op_codes, [input, padding_vec], [padded_input_tensor]
            )
            self.graph.add_operator(pad_operator)
            return padded_input_tensor

        # NOTE(review): ceil_mode is only checked on the SAME path (has_same_padding);
        # the VALID paths below assume floor-mode output sizes — confirm upstream
        # rejects ceil_mode=True where it would change the output shape.
        if count_include_pad is True:
            # Add padding before avgpool2d
            # Circle's avgpool2d does not support count_include_pad=True, so we need to add padding manually
            if self.has_padding(args):
                avgpool_input = define_padding_node()

            result = self.define_avgpool_node(
                [avgpool_input], [node], "VALID", stride, kernel_size
            )
        elif count_include_pad is False:
            if not self.has_padding(args):  # valid padding
                result = self.define_avgpool_node(
                    [avgpool_input], [node], "VALID", stride, kernel_size
                )
            elif self.has_same_padding(args):
                result = self.define_avgpool_node(
                    [avgpool_input], [node], "SAME", stride, kernel_size
                )
            else:
                # CASE: count_include_pad is False and not VALID/SAME padding
                #
                # Implement this when it's needed.
                # If needed, the ratio-masking idea in https://github.com/Samsung/TICO/pull/119 may help.
                raise NotYetSupportedError(
                    f"Padding({padding}) with count_include_pad({count_include_pad}) is not supported yet."
                )
        else:
            raise RuntimeError("Cannot reach here")

        return result
|
@@ -0,0 +1,62 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import BmmArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class BatchMatmulVisitor(NodeVisitor):
    """
    Serialize `aten.bmm.default` into a Circle BATCH_MATMUL operator.

    Both operands are used as-is: adjointLhs and adjointRhs are False.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.bmm.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        bmm_args = BmmArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        lhs, rhs = bmm_args.input, bmm_args.mat2

        opcode_idx = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
        )
        batch_matmul = create_builtin_operator(
            self.graph, opcode_idx, [lhs, rhs], [node]
        )

        # BATCH_MATMUL-specific options.
        batch_matmul.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
        )
        bmm_options = circle.BatchMatMulOptions.BatchMatMulOptionsT()
        bmm_options.adjointLhs = False
        bmm_options.adjointRhs = False
        batch_matmul.builtinOptions = bmm_options

        return batch_matmul
|
@@ -0,0 +1,66 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
24
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
25
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
26
|
+
from tico.utils.validate_args_kwargs import CatArgs
|
27
|
+
|
28
|
+
|
29
|
+
@register_node_visitor
class CatVisitor(NodeVisitor):
    """
    Serialize `aten.cat.default` into a Circle CONCATENATION operator.

    Every tensor in `tensors` becomes an operator input. `dim` is written
    unchanged as the concatenation axis, and no fused activation is applied.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.cat.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        cat_args = CatArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        pieces, axis = cat_args.tensors, cat_args.dim

        concat_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.CONCATENATION, self._op_codes
        )
        concat_op = create_builtin_operator(self.graph, concat_index, pieces, [node])

        # CONCATENATION-specific options.
        concat_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ConcatenationOptions
        )
        concat_options = circle.ConcatenationOptions.ConcatenationOptionsT()
        concat_options.axis = axis
        concat_options.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        concat_op.builtinOptions = concat_options

        return concat_op
|