tico 0.1.0.dev250411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,71 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.errors import NotYetSupportedError
|
28
|
+
from tico.utils.validate_args_kwargs import SubTensorArgs
|
29
|
+
|
30
|
+
|
31
|
+
@register_node_visitor
class SubVisitor(NodeVisitor):
    """Serialize ``aten.sub.Tensor`` into a Circle SUB operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.sub.Tensor]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit ``SUB(input, other)`` producing ``node``.

        Args:
            node: the ``aten.sub.Tensor`` FX node to serialize.

        Returns:
            The Circle SUB operator with its SubOptions filled in.

        Raises:
            NotYetSupportedError: if ``alpha`` is given with a value other than 1.
        """
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.SUB, self._op_codes
        )

        args = SubTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type, call-arg]

        input = args.input
        other = args.other
        alpha = args.alpha

        # aten.sub.Tensor computes ``input - alpha * other``. An explicit
        # ``alpha == 1`` is the identity scaling and maps onto plain SUB; any
        # other factor would need an extra MUL that is not emitted here.
        if alpha is not None and alpha != 1:
            raise NotYetSupportedError(
                "'alpha' of aten::sub.Tensor is not supported yet"
            )

        inputs = [input, other]
        outputs = [node]

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        # Op-specific option
        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.SubOptions
        option = circle.SubOptions.SubOptionsT()
        option.fusedActivationFunction = (
            circle.ActivationFunctionType.ActivationFunctionType.NONE
        )
        option.potScaleInt16 = False
        operator.builtinOptions = option

        return operator
|
@@ -0,0 +1,63 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.circle_mapping import circle_legalize_dtype_to
|
25
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
26
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
27
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
28
|
+
from tico.utils.validate_args_kwargs import SumDimIntListArgs
|
29
|
+
|
30
|
+
|
31
|
+
@register_node_visitor
class SumVisitor(NodeVisitor):
    """Serialize ``aten.sum.dim_IntList`` into a Circle SUM (reducer) operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.sum.dim_IntList]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit ``SUM(input, axes)`` producing ``node``.

        The reduction axes are legalized to int32 before being wired in as the
        second operand, and ``keepdim`` is forwarded to ReducerOptions.keepDims.
        """
        parsed = SumDimIntListArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        # Reduction axes are passed to SUM as an int32 operand.
        axes_i32 = circle_legalize_dtype_to(parsed.dim, dtype=torch.int32)

        sum_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.SUM, self._op_codes
        )
        sum_op = create_builtin_operator(
            self.graph, sum_index, [parsed.input, axes_i32], [node]
        )

        reducer_opts = circle.ReducerOptions.ReducerOptionsT()
        reducer_opts.keepDims = parsed.keepdim
        sum_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ReducerOptions
        )
        sum_op.builtinOptions = reducer_opts
        return sum_op
|
@@ -0,0 +1,54 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
25
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
26
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
27
|
+
from tico.utils.validate_args_kwargs import TanhArgs
|
28
|
+
|
29
|
+
|
30
|
+
@register_node_visitor
class TanhVisitor(NodeVisitor):
    """Serialize ``aten.tanh.default`` into a Circle TANH operator."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.tanh.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit a TANH operator that maps the single operand to ``node``.

        TANH carries no builtin options, so none are set.
        """
        operand = TanhArgs(*node.args, **node.kwargs).input  # type: ignore[arg-type]

        tanh_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.TANH, self._op_codes
        )
        return create_builtin_operator(self.graph, tanh_index, [operand], [node])
|
@@ -0,0 +1,105 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_mapping import (
|
24
|
+
extract_circle_dtype,
|
25
|
+
extract_torch_dtype,
|
26
|
+
to_circle_dtype,
|
27
|
+
)
|
28
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
29
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
30
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
31
|
+
from tico.utils.errors import NotYetSupportedError
|
32
|
+
from tico.utils.validate_args_kwargs import ToCopyArgs
|
33
|
+
|
34
|
+
|
35
|
+
@register_node_visitor
class ToCopyVisitor(NodeVisitor):
    """Serialize ``aten._to_copy.default`` as a Circle CAST operator.

    Only the ``dtype``, ``device`` and ``layout`` keyword arguments are
    accepted. ``device`` is ignored. A requested ``layout`` must match the
    input's layout, because no layout conversion is emitted.
    """

    target: List[torch._ops.OpOverload] = [torch.ops.aten._to_copy.default]

    def __init__(self, op_codes: Dict[OpCode, int], graph):
        super().__init__(op_codes, graph)

    def define_cast_node(
        self,
        inputs: List[torch.fx.Node],
        outputs: List[torch.fx.Node],
        in_type: int,
        out_type: int,
    ):
        """Create a CAST operator from ``inputs`` to ``outputs``.

        Args:
            inputs: operand nodes of the cast.
            outputs: result nodes of the cast.
            in_type: Circle tensor type of the source (stored in CastOptions).
            out_type: Circle tensor type of the destination (stored in CastOptions).

        Returns:
            The CAST operator with its CastOptions filled in.
        """
        op_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.CAST, self._op_codes
        )

        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)

        operator.builtinOptionsType = circle.BuiltinOptions.BuiltinOptions.CastOptions
        option = circle.CastOptions.CastOptionsT()
        option.inDataType = in_type
        option.outDataType = out_type
        operator.builtinOptions = option

        return operator

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Serialize ``node`` (an ``aten._to_copy`` call) as a CAST.

        The target dtype is the ``dtype`` kwarg when given. Otherwise it is the
        dtype recorded on ``node`` itself.

        Raises:
            NotYetSupportedError: if kwargs other than dtype/device/layout are
                present, or if the requested layout differs from the input's.
        """
        supported_kwargs = ["dtype", "device", "layout"]
        unsupported_node_kargs = [k for k in node.kwargs if k not in supported_kwargs]
        if unsupported_node_kargs:
            raise NotYetSupportedError(
                f"Support only {supported_kwargs} kwargs now. Do not support {unsupported_node_kargs}"
            )

        args = ToCopyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type, call-arg]
        input = args.input
        dtype = args.dtype

        input_meta = input.meta["val"]
        # https://pytorch.org/docs/stable/tensor_attributes.html#torch-layout
        # layout is two types: torch.strided(dense Tensors), torch.sparse_coo(sparse COO Tensors)
        # The requested layout is a kwarg of *this* node and must be compared
        # with the input tensor's layout attribute (not the tensor itself).
        # layout=None means "keep the input layout" and is always accepted.
        node_layout = node.kwargs.get("layout")
        if node_layout is not None and node_layout != input_meta.layout:
            raise NotYetSupportedError(
                f"Only support when node and its input have same layout: (input layout: {input_meta.layout}), (node layout: {node_layout})."
            )

        if dtype is not None:
            target_type = dtype
        else:
            # device and layout are meaningless
            target_type = extract_torch_dtype(node)
        assert isinstance(target_type, torch.dtype), type(target_type)

        # define cast node
        in_type: int = extract_circle_dtype(input)
        out_type: int = to_circle_dtype(target_type)
        inputs = [input]
        outputs = [node]
        operator = self.define_cast_node(inputs, outputs, in_type, out_type)

        # TODO Support layout conversion

        return operator
|
@@ -0,0 +1,66 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.circle_mapping import circle_legalize_dtype_to
|
25
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
26
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
27
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
28
|
+
from tico.utils.validate_args_kwargs import UnSqueezeArgs
|
29
|
+
|
30
|
+
|
31
|
+
@register_node_visitor
class UnsqueezeVisitor(NodeVisitor):
    """Serialize ``aten.unsqueeze`` / ``aten.unsqueeze_copy`` as Circle EXPAND_DIMS."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.unsqueeze.default,
        torch.ops.aten.unsqueeze_copy.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit ``EXPAND_DIMS(input, axis)`` producing ``node``.

        The axis is legalized to int32 and passed as the second operand.
        ExpandDimsOptions is attached with its default fields.
        """
        expand_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.EXPAND_DIMS, self._op_codes
        )

        parsed = UnSqueezeArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        # EXPAND_DIMS receives the new axis as an int32 operand.
        axis_i32 = circle_legalize_dtype_to(parsed.dim, dtype=torch.int32)

        expand_op = create_builtin_operator(
            self.graph, expand_index, [parsed.input, axis_i32], [node]
        )
        expand_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ExpandDimsOptions
        )
        expand_op.builtinOptions = circle.ExpandDimsOptions.ExpandDimsOptionsT()
        return expand_op
|
@@ -0,0 +1,74 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph, is_const
|
24
|
+
from tico.serialize.circle_mapping import circle_legalize_dtype_to
|
25
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
26
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
27
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
28
|
+
from tico.utils.validate_args_kwargs import ViewArgs
|
29
|
+
|
30
|
+
|
31
|
+
@register_node_visitor
class ViewVisitor(NodeVisitor):
    """Serialize ``aten.view`` / ``aten.view_copy`` as a Circle RESHAPE."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.aten.view,
        torch.ops.aten.view.default,
        torch.ops.aten.view_copy.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit ``RESHAPE(input, shape)`` producing ``node``.

        The target shape must be a constant sequence. It is legalized to int32
        and used both as the second operand and as ReshapeOptions.newShape.

        Raises:
            Exception: if the size argument is a single int.
        """
        reshape_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.RESHAPE,
            self._op_codes,
        )
        parsed = ViewArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        target_shape = parsed.size

        # Only constant target shapes are handled here.
        assert is_const(target_shape), type(target_shape)
        if isinstance(target_shape, int):
            raise Exception("scalar size conversion is not supported yet.")

        shape_i32 = circle_legalize_dtype_to(target_shape, dtype=torch.int32)
        reshape_op = create_builtin_operator(
            self.graph, reshape_index, [parsed.input, shape_i32], [node]
        )

        # Op-specific option: the same int32 shape is also stored statically.
        reshape_opts = circle.ReshapeOptions.ReshapeOptionsT()
        reshape_opts.newShape = shape_i32
        reshape_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.ReshapeOptions
        )
        reshape_op.builtinOptions = reshape_opts
        return reshape_op
|
@@ -0,0 +1,82 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List, TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch._ops
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from circle_schema import circle
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
+
from tico.serialize.circle_mapping import extract_torch_dtype
|
25
|
+
|
26
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
27
|
+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
28
|
+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
29
|
+
from tico.utils.validate_args_kwargs import WhereSelfArgs
|
30
|
+
|
31
|
+
|
32
|
+
@register_node_visitor
class WhereVisitor(NodeVisitor):
    """Serialize ``aten.where.self`` as a Circle SELECT_V2 operator."""

    target: List[torch._ops.OpOverload] = [torch.ops.aten.where.self]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit ``SELECT_V2(condition, input, other)`` producing ``node``.

        Raises:
            RuntimeError: if the two value branches have different dtypes.
        """
        select_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.SELECT_V2,
            self._op_codes,
        )

        parsed = WhereSelfArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        cond, on_true, on_false = parsed.condition, parsed.input, parsed.other

        def dtype_of(value):
            # FX nodes expose dtype through their metadata; other values are
            # expected to carry a ``.dtype`` attribute directly.
            if isinstance(value, torch.fx.node.Node):
                return extract_torch_dtype(value)
            return value.dtype  # type: ignore[union-attr]

        true_dtype = dtype_of(on_true)
        false_dtype = dtype_of(on_false)
        if true_dtype != false_dtype:
            raise RuntimeError(
                f"Data type of arguments are not matched. result_true: {true_dtype}, result_false: {false_dtype}"
            )

        select_op = create_builtin_operator(
            self.graph, select_index, [cond, on_true, on_false], [node]
        )

        # Op-specific option: SelectV2Options is attached with default fields.
        select_op.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.SelectV2Options
        )
        select_op.builtinOptions = circle.SelectV2Options.SelectV2OptionsT()
        return select_op
|
@@ -0,0 +1,51 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Dict, List
|
16
|
+
|
17
|
+
from circle_schema import circle
|
18
|
+
|
19
|
+
from tico.serialize.operators.hashable_opcode import OpCode
|
20
|
+
|
21
|
+
|
22
|
+
def create_builtin_opcode(opcode: int) -> OpCode:
    """Build a version-1 builtin ``OpCode`` entry for ``opcode``.

    Two opcode fields are filled in:

    * ``deprecatedBuiltinCode`` is an int8 field, so it can hold at most 127.
      127 is reserved as a placeholder for larger opcodes, so any opcode of
      127 or more is clamped to 127 here.
    * ``builtinCode`` always receives the full, unclamped opcode value.
    """
    # Clamp to the int8 placeholder when the opcode does not fit.
    legacy_code = opcode if opcode < 127 else 127

    result = OpCode()
    result.deprecatedBuiltinCode = legacy_code
    result.builtinCode = opcode
    result.version = 1
    return result
|
31
|
+
|
32
|
+
|
33
|
+
def get_op_index(opcode: int, opcode_map: Dict[OpCode, int]) -> int:
    """Return the operator-code table index for ``opcode``, registering it if new.

    ``opcode_map`` is mutated in place. An opcode that is not yet present is
    assigned the next free index, which is the map's size before insertion.
    An opcode that is already present keeps its existing index.
    """
    key = create_builtin_opcode(opcode)

    # Already registered: reuse the existing index.
    if key in opcode_map:
        return opcode_map[key]

    # First occurrence: append at the end of the table.
    next_index = len(opcode_map)
    opcode_map[key] = next_index
    return next_index
|
41
|
+
|
42
|
+
|
43
|
+
# TODO Move this to CircleSubGraph
def create_builtin_operator(
    graph, op_index: int, inputs: List, outputs: List
) -> circle.Operator.OperatorT:
    """Create a Circle builtin operator and resolve its operand ids via ``graph``.

    Args:
        graph: object exposing ``get_tid(x)``. Each input and output is passed
            through it, presumably to obtain tensor ids (confirm against the
            graph class).
        op_index: index into the operator-code table, stored as
            ``opcodeIndex``.
        inputs: operator inputs, converted in order with ``graph.get_tid``.
        outputs: operator outputs, converted in order with ``graph.get_tid``.

    Returns:
        A ``circle.Operator.OperatorT`` with ``opcodeIndex``, ``inputs`` and
        ``outputs`` set. Builtin options are not set here; callers attach them.
    """
    op = circle.Operator.OperatorT()
    op.opcodeIndex = op_index

    to_tid = graph.get_tid
    op.inputs = list(map(to_tid, inputs))
    op.outputs = list(map(to_tid, outputs))
    return op
|
tico/serialize/pack.py
ADDED
@@ -0,0 +1,35 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import numpy as np
|
16
|
+
|
17
|
+
|
18
|
+
def pack_buffer(flat_data: np.ndarray, dtype: str) -> np.ndarray:
    """Pack a 1-D array into a compact byte buffer for a sub-byte ``dtype``.

    Only ``"uint4"`` is supported. Each input element must be a uint8 holding
    a value in [0, 15], and two consecutive elements share one output byte:

    * element ``2k`` goes to the low nibble of byte ``k``;
    * element ``2k + 1`` goes to the high nibble of byte ``k``;
    * for an odd-length input, the last byte's high nibble is zero.

    Args:
        flat_data: 1-D array of unpacked values.
        dtype: name of the target packed dtype (currently only ``"uint4"``).

    Returns:
        A uint8 array of length ``(len(flat_data) + 1) // 2``.

    Raises:
        RuntimeError: if ``flat_data`` is not 1-D, is not uint8, or holds a
            value outside [0, 15].
        NotImplementedError: for any ``dtype`` other than ``"uint4"``.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if flat_data.ndim != 1:
        raise RuntimeError(
            f"pack_buffer expects 1-D data, but got shape {flat_data.shape}."
        )

    if dtype != "uint4":
        raise NotImplementedError(f"NYI dtype: {dtype}")

    if flat_data.dtype != np.uint8:
        raise RuntimeError("uint4 data should be saved in uint8.")
    # uint8 values are never negative, so only the upper bound needs checking.
    if flat_data.size > 0 and int(flat_data.max()) > 15:
        raise RuntimeError("uint4 data should be in the range [0, 15].")

    # Even-indexed elements fill the low nibbles. A step-2 slice yields exactly
    # (numel + 1) // 2 of them, i.e. one per output byte.
    packed = flat_data[0::2].copy()
    # Odd-indexed elements fill the high nibbles. There is one fewer of them
    # when numel is odd, which leaves the last byte's high nibble at zero.
    high = flat_data[1::2]
    packed[: high.shape[0]] |= (high << 4).astype(np.uint8)
    return packed
|
@@ -0,0 +1,42 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""
Quantization parameter definitions.

``QPARAM_KEY`` is the key in a ``torch.fx.Node``'s ``meta`` dict under which
the node's ``QuantParam`` is saved. Retrieve it with ``node.meta[QPARAM_KEY]``.
"""
# Key into ``torch.fx.Node.meta`` that holds the node's QuantParam.
QPARAM_KEY: str = "_quantization_parameters_"
|
21
|
+
|
22
|
+
from dataclasses import dataclass
|
23
|
+
from typing import List, Optional
|
24
|
+
|
25
|
+
import torch
|
26
|
+
|
27
|
+
|
28
|
+
@dataclass
class QuantParam:
    """Quantization parameters that can be attached to a node.

    Every field is optional. A field left at its default (``None``, or ``""``
    for ``dtype``) is treated as unset.
    """

    # Scale values. NOTE(review): presumably one entry per quantized channel,
    # or a single entry for per-tensor quantization; confirm.
    scale: Optional[List[float]] = None
    # Zero-point values, presumably paired element-wise with ``scale``.
    zero_point: Optional[List[int]] = None
    # Axis that per-channel parameters apply to (assumed; confirm with users).
    quantized_dimension: Optional[int] = None
    # Minimum values (likely observed ranges; confirm).
    min: Optional[List[float]] = None
    # Maximum values (likely observed ranges; confirm).
    max: Optional[List[float]] = None
    # NOTE dtype is a string so new dtypes (ex: uint4) are easy to add,
    # even when torch has no matching dtype.
    dtype: str = ""
|
37
|
+
|
38
|
+
|
39
|
+
def to_qparam_dtype(dtype: torch.dtype) -> str:
    """Convert a torch dtype to the string form stored in ``QuantParam.dtype``.

    The ``torch.`` prefix of ``str(dtype)`` is dropped, so for example
    ``torch.uint8`` becomes ``"uint8"``.
    """
    prefix = "torch."
    name = str(dtype)
    assert name.startswith(prefix)
    return name[len(prefix):]
|
tico/utils/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
# DO NOT REMOVE THIS FILE
|