tico 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +42 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/quantize_bias.py +123 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +191 -0
- tico/passes/cast_mixed_type_args.py +187 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +160 -0
- tico/passes/convert_layout_op_to_reshape.py +85 -0
- tico/passes/convert_repeat_to_expand_copy.py +89 -0
- tico/passes/convert_to_relu6.py +181 -0
- tico/passes/decompose_addmm.py +124 -0
- tico/passes/decompose_batch_norm.py +192 -0
- tico/passes/decompose_fake_quantize.py +134 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
- tico/passes/decompose_group_norm.py +275 -0
- tico/passes/decompose_grouped_conv2d.py +209 -0
- tico/passes/decompose_slice_scatter.py +169 -0
- tico/passes/extract_dtype_kwargs.py +122 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +108 -0
- tico/passes/legalize_predefined_layout_operators.py +386 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
- tico/passes/lower_to_slice.py +230 -0
- tico/passes/merge_consecutive_cat.py +80 -0
- tico/passes/ops.py +78 -0
- tico/passes/remove_nop.py +84 -0
- tico/passes/remove_redundant_assert_nodes.py +51 -0
- tico/passes/remove_redundant_expand.py +66 -0
- tico/passes/remove_redundant_permute.py +122 -0
- tico/passes/remove_redundant_reshape.py +436 -0
- tico/passes/remove_redundant_slice.py +62 -0
- tico/passes/remove_redundant_to_copy.py +86 -0
- tico/passes/restore_linear.py +115 -0
- tico/passes/segment_index_select.py +145 -0
- tico/pt2_to_circle.py +105 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +319 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +240 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_abs.py +53 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +150 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +192 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +126 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +186 -0
- tico/serialize/operators/op_copy.py +164 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +95 -0
- tico/serialize/operators/op_depthwise_conv2d.py +199 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_leaky_relu.py +60 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +86 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_dim.py +70 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +177 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +141 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +100 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +99 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +94 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +296 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +282 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/mx/__init__.py +1 -0
- tico/utils/mx/elemwise_ops.py +267 -0
- tico/utils/mx/formats.py +125 -0
- tico/utils/mx/mx_ops.py +270 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +609 -0
- tico/utils/serialize.py +42 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +406 -0
- tico/utils/validate_args_kwargs.py +1149 -0
- tico-0.1.0.dist-info/LICENSE +241 -0
- tico-0.1.0.dist-info/METADATA +354 -0
- tico-0.1.0.dist-info/RECORD +206 -0
- tico-0.1.0.dist-info/WHEEL +5 -0
- tico-0.1.0.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,66 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_mapping import circle_legalize_dtype_to
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import MeanDimArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class MeanVisitor(NodeVisitor):
    """
    Emit a Circle ``MEAN`` operator for ``aten.mean.dim`` nodes.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.mean.dim]

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Build the ``MEAN`` operator for ``node``.

        The operator takes two inputs: the tensor to reduce and the reduction
        dims legalized to int32. ``ReducerOptions`` are attached as the
        builtin options.
        """
        mean_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.MEAN, self._op_codes
        )

        mean_args = MeanDimArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        # The reduction dims travel as the second operator input, as int32.
        axes_i32 = circle_legalize_dtype_to(mean_args.dim, dtype=torch.int32)

        mean_op = create_builtin_operator(
            self.graph, mean_op_index, [mean_args.input, axes_i32], [node]
        )

        # Op-specific option
        reducer_option = circle.ReducerOptions.ReducerOptionsT()
        # keepDims is only written when keepdim is truthy; otherwise the value
        # that ReducerOptionsT() starts with is left untouched.
        if mean_args.keepdim:
            reducer_option.keepDims = mean_args.keepdim

        mean_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ReducerOptions
        )
        mean_op.builtinOptions = reducer_option

        return mean_op
|
@@ -0,0 +1,53 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import MinimumArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class MinimumVisitor(NodeVisitor):
    """
    Emit a Circle ``MINIMUM`` operator for ``aten.minimum.default`` nodes.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.minimum.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
        """
        Build the ``MINIMUM`` operator whose inputs are the node's ``input``
        and ``other`` arguments and whose single output is ``node`` itself.
        """
        minimum_args = MinimumArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        minimum_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.MINIMUM, self._op_codes
        )

        return create_builtin_operator(
            self.graph,
            minimum_op_index,
            [minimum_args.input, minimum_args.other],
            [node],
        )
|
@@ -0,0 +1,177 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph, is_const
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import MatmulArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class MatmulDefaultVisitor(NodeVisitor):
    """
    Convert matmul to equivalent BatchMatMul or FullyConnected with Transpose.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.mm.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    # NOTE: Matmul is equivalent to Batch MatMul (batch=1)
    def define_bmm_node(self, inputs, outputs) -> circle.Operator.OperatorT:
        """
        Build a ``BATCH_MATMUL`` operator with both adjoint flags and
        ``asymmetricQuantizeInputs`` set to False.
        """
        bmm_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.BATCH_MATMUL, self._op_codes
        )
        bmm_op = create_builtin_operator(self.graph, bmm_op_index, inputs, outputs)

        bmm_option = circle.BatchMatMulOptions.BatchMatMulOptionsT()
        bmm_option.adjointLhs = False
        bmm_option.adjointRhs = False
        bmm_option.asymmetricQuantizeInputs = False

        bmm_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.BatchMatMulOptions
        )
        bmm_op.builtinOptions = bmm_option

        return bmm_op

    def define_transpose_node(self, inputs, outputs) -> circle.Operator.OperatorT:
        """
        Build a ``TRANSPOSE`` operator carrying a freshly constructed
        ``TransposeOptionsT``.
        """
        trs_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.TRANSPOSE, self._op_codes
        )
        trs_op = create_builtin_operator(self.graph, trs_op_index, inputs, outputs)

        trs_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.TransposeOptions
        )
        trs_op.builtinOptions = circle.TransposeOptions.TransposeOptionsT()

        return trs_op

    def define_fc_node(self, inputs, outputs) -> circle.Operator.OperatorT:
        """
        Build a ``FULLY_CONNECTED`` operator with no fused activation, the
        default weights format, ``keepNumDims`` off, symmetric input
        quantization and a FLOAT32 quantized-bias type.
        """
        fc_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.FULLY_CONNECTED, self._op_codes
        )
        fc_op = create_builtin_operator(self.graph, fc_op_index, inputs, outputs)

        fc_option = circle.FullyConnectedOptions.FullyConnectedOptionsT()
        fc_option.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        fc_option.weightsFormat = (
            circle.FullyConnectedOptionsWeightsFormat.FullyConnectedOptionsWeightsFormat.DEFAULT
        )
        fc_option.keepNumDims = False
        fc_option.asymmetricQuantizeInputs = False
        fc_option.quantizedBiasType = circle.TensorType.TensorType.FLOAT32

        fc_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.FullyConnectedOptions
        )
        fc_op.builtinOptions = fc_option

        return fc_op

    def define_fc_with_transpose(
        self, node, inputs, outputs
    ) -> circle.Operator.OperatorT:
        """
        Define FullyConnected with Transpose operator.

        Note that those sets of operators are equivalent.
        (1) Matmul
          matmul( lhs[H, K], rhs[K, W'] ) -> output(H, W')

        (2) Transpose + FullyConnected
          transpose( rhs[K, W'] ) -> trs_output[W', K]
          fullyconnected( lhs[H, K], trs_output[W', K] ) -> output(H, W')

        The transpose operator is added to the graph here; the returned
        FullyConnected operator is left for the caller to handle.
        """
        lhs, rhs = inputs

        # Read dtype and shape from the tensor already registered for rhs.
        rhs_tensor: circle.Tensor.TensorT = self.graph.tensors[
            self.graph.get_tid_registered(rhs)
        ]
        rhs_name: str = rhs.name
        rhs_dtype: int = rhs_tensor.type
        rhs_dims: List[int] = rhs_tensor.shape
        assert len(rhs_dims) == 2, len(rhs_dims)
        k_dim, w_dim = rhs_dims

        # Tensor that receives transpose(rhs): shape [W', K].
        trs_output = self.graph.add_tensor_from_scratch(
            prefix=f"{rhs_name}_transposed_output",
            shape=[w_dim, k_dim],
            dtype=rhs_dtype,
            source_node=node,
        )
        trs_perm = self.graph.add_const_tensor(data=[1, 0], source_node=node)
        self.graph.add_operator(
            self.define_transpose_node([rhs, trs_perm], [trs_output])
        )

        # FullyConnected takes a bias input; an all-zero one with one entry per
        # row of the transposed weight is supplied.
        zero_bias = self.graph.add_const_tensor(
            data=[0.0] * trs_output.shape[0], source_node=node
        )

        return self.define_fc_node([lhs, trs_output, zero_bias], outputs)

    def define_node(
        self, node: torch.fx.Node, prior_latency=True
    ) -> circle.Operator.OperatorT:
        """
        NOTE: Possibility of accuracy-latency trade-off
        From ONE compiler's perspective:
        - BMM uses per-tensor quantization for both rhs and lhs.
        - FC uses per-channel quantization for weight and per-tensor for input.
        Thus, FC is better in terms of accuracy.
        FC necessarily involves an additional transpose operation to be identical with mm.
        If transposed operand is const, it can be optimized by constant folding.
        Thus, convert FC only if transpose can be folded.
        TODO set prior_latency outside
        """
        mm_args = MatmulArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        rhs = mm_args.other

        operands = [mm_args.input, rhs]
        results = [node]

        # FC (+ transpose) is chosen when rhs is const or when latency is not
        # prioritized; otherwise BMM is emitted.
        if is_const(rhs) or not prior_latency:
            return self.define_fc_with_transpose(node, operands, results)

        return self.define_bmm_node(operands, results)
|
@@ -0,0 +1,99 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import MulScalarArgs, MulTensorArgs
|
28
|
+
|
29
|
+
|
30
|
+
class BaseMulVisitor(NodeVisitor):
    """
    Shared implementation that emits a Circle ``MUL`` operator.

    This class is not decorated with ``register_node_visitor`` and declares
    no ``target`` of its own.
    """

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.node.Node,
    ) -> circle.Operator.OperatorT:
        """
        Build a ``MUL`` operator with no fused activation.

        The operator inputs are the node's positional arguments, in order;
        ``node.kwargs`` is not consulted here.
        """
        mul_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.MUL, self._op_codes
        )

        mul_op = create_builtin_operator(
            self.graph, mul_op_index, [*node.args], [node]
        )

        # Op-specific option
        mul_option = circle.MulOptions.MulOptionsT()
        mul_option.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        mul_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.MulOptions
        mul_op.builtinOptions = mul_option

        return mul_op
|
56
|
+
|
57
|
+
|
58
|
+
@register_node_visitor
class MulTensorVisitor(BaseMulVisitor):
    """
    Emit a Circle ``MUL`` operator for ``aten.mul.Tensor`` nodes.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.mul.Tensor]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Check the node's arguments against ``MulTensorArgs`` and delegate the
        operator construction to ``BaseMulVisitor.define_node``.
        """
        # The constructed object is not used afterwards; the call is kept so a
        # node whose args/kwargs do not fit MulTensorArgs still fails here.
        # NOTE(review): presumably MulTensorArgs performs validation on
        # construction - confirm in tico.utils.validate_args_kwargs.
        MulTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        return super().define_node(node)
|
78
|
+
|
79
|
+
|
80
|
+
@register_node_visitor
class MulScalarVisitor(BaseMulVisitor):
    """
    Emit a Circle ``MUL`` operator for ``aten.mul.Scalar`` nodes.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.mul.Scalar]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Check the node's arguments against ``MulScalarArgs`` and delegate the
        operator construction to ``BaseMulVisitor.define_node``.
        """
        # The constructed object is not used afterwards; the call is kept so a
        # node whose args/kwargs do not fit MulScalarArgs still fails here.
        # NOTE(review): presumably MulScalarArgs performs validation on
        # construction - confirm in tico.utils.validate_args_kwargs.
        MulScalarArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        return super().define_node(node)
|
@@ -0,0 +1,54 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import NeTensorArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class NeVisitor(NodeVisitor):
    """
    Emit a Circle ``NOT_EQUAL`` operator for ``aten.ne.Scalar`` and
    ``aten.ne.Tensor`` nodes.
    """

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.ne.Scalar,
        torch.ops.aten.ne.Tensor,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
        """
        Build the ``NOT_EQUAL`` operator from the node's ``input`` and
        ``other`` arguments.

        Both registered overloads are parsed with ``NeTensorArgs``.
        """
        ne_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.NOT_EQUAL, self._op_codes
        )

        ne_args = NeTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        return create_builtin_operator(
            self.graph, ne_op_index, [ne_args.input, ne_args.other], [node]
        )
|
@@ -0,0 +1,59 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import NegArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class NegVisitor(NodeVisitor):
    """
    Emit a Circle ``NEG`` operator for ``aten.neg.default`` nodes.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.neg.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Build the single-input ``NEG`` operator for ``node`` and attach a
        freshly constructed ``NegOptionsT`` as its builtin options.
        """
        neg_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.NEG,
            self._op_codes,
        )
        neg_args = NegArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        neg_op = create_builtin_operator(
            self.graph, neg_op_index, [neg_args.input], [node]
        )

        # Op-specific option
        neg_op.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.NegOptions
        neg_op.builtinOptions = circle.NegOptions.NegOptionsT()

        return neg_op
|
@@ -0,0 +1,65 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.circle_mapping import circle_legalize_dtype_to
|
25
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
26
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
27
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
28
|
+
from tico.utils.validate_args_kwargs import PermuteArgs
|
29
|
+
|
30
|
+
|
31
|
+
@register_node_visitor
class PermuteVisitor(NodeVisitor):
    """
    Emit a Circle ``TRANSPOSE`` operator for ``aten.permute.default`` nodes.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten.permute.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """
        Build the ``TRANSPOSE`` operator for ``node``.

        The operator takes the tensor to permute and the permutation dims
        legalized to int32 as its two inputs.
        """
        transpose_op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.TRANSPOSE,
            self._op_codes,
        )

        permute_args = PermuteArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        # The permutation is passed as the second operator input, as int32.
        perm_i32 = circle_legalize_dtype_to(permute_args.dims, dtype=torch.int32)

        transpose_op = create_builtin_operator(
            self.graph, transpose_op_index, [permute_args.input, perm_i32], [node]
        )

        # Op-specific option
        transpose_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.TransposeOptions
        )
        transpose_op.builtinOptions = circle.TransposeOptions.TransposeOptionsT()

        return transpose_op
|