tico 0.1.0.dev250411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +31 -0
- tico/config/__init__.py +4 -0
- tico/config/base.py +37 -0
- tico/config/factory.py +41 -0
- tico/config/v1.py +35 -0
- tico/experimental/__init__.py +1 -0
- tico/experimental/quantization/__init__.py +1 -0
- tico/experimental/quantization/algorithm/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
- tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
- tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
- tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
- tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
- tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
- tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
- tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
- tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
- tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
- tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
- tico/experimental/quantization/config.py +68 -0
- tico/experimental/quantization/evaluation/__init__.py +1 -0
- tico/experimental/quantization/evaluation/backend.py +20 -0
- tico/experimental/quantization/evaluation/evaluate.py +223 -0
- tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
- tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
- tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
- tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
- tico/experimental/quantization/evaluation/metric.py +109 -0
- tico/experimental/quantization/evaluation/utils.py +185 -0
- tico/experimental/quantization/passes/__init__.py +1 -0
- tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
- tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
- tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
- tico/experimental/quantization/public_interface.py +108 -0
- tico/experimental/quantization/quantizer.py +71 -0
- tico/interpreter/__init__.py +1 -0
- tico/interpreter/infer.py +116 -0
- tico/interpreter/interpreter.py +93 -0
- tico/passes/__init__.py +1 -0
- tico/passes/cast_aten_where_arg_type.py +185 -0
- tico/passes/cast_mixed_type_args.py +186 -0
- tico/passes/const_prop_pass.py +307 -0
- tico/passes/convert_conv1d_to_conv2d.py +151 -0
- tico/passes/convert_layout_op_to_reshape.py +84 -0
- tico/passes/convert_repeat_to_expand_copy.py +90 -0
- tico/passes/convert_to_relu6.py +180 -0
- tico/passes/decompose_addmm.py +127 -0
- tico/passes/decompose_batch_norm.py +198 -0
- tico/passes/decompose_fake_quantize.py +126 -0
- tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
- tico/passes/decompose_group_norm.py +258 -0
- tico/passes/decompose_grouped_conv2d.py +202 -0
- tico/passes/decompose_slice_scatter.py +167 -0
- tico/passes/extract_dtype_kwargs.py +121 -0
- tico/passes/fill_meta_val.py +57 -0
- tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
- tico/passes/legalize_causal_mask_value.py +113 -0
- tico/passes/legalize_predefined_layout_operators.py +383 -0
- tico/passes/lower_pow2_to_mul.py +75 -0
- tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
- tico/passes/lower_to_slice.py +112 -0
- tico/passes/merge_consecutive_cat.py +82 -0
- tico/passes/ops.py +75 -0
- tico/passes/remove_nop.py +85 -0
- tico/passes/remove_redundant_assert_nodes.py +50 -0
- tico/passes/remove_redundant_expand.py +70 -0
- tico/passes/remove_redundant_permute.py +102 -0
- tico/passes/remove_redundant_reshape.py +431 -0
- tico/passes/remove_redundant_slice.py +64 -0
- tico/passes/remove_redundant_to_copy.py +84 -0
- tico/passes/restore_linear.py +113 -0
- tico/passes/segment_index_select.py +143 -0
- tico/pt2_to_circle.py +101 -0
- tico/serialize/__init__.py +1 -0
- tico/serialize/circle_graph.py +264 -0
- tico/serialize/circle_mapping.py +177 -0
- tico/serialize/circle_serializer.py +232 -0
- tico/serialize/operators/__init__.py +28 -0
- tico/serialize/operators/hashable_opcode.py +43 -0
- tico/serialize/operators/node_visitor.py +80 -0
- tico/serialize/operators/op_add.py +69 -0
- tico/serialize/operators/op_alias_copy.py +64 -0
- tico/serialize/operators/op_any.py +142 -0
- tico/serialize/operators/op_arange_start_step.py +61 -0
- tico/serialize/operators/op_argmax.py +62 -0
- tico/serialize/operators/op_avg_pool2d.py +112 -0
- tico/serialize/operators/op_bmm.py +62 -0
- tico/serialize/operators/op_cat.py +66 -0
- tico/serialize/operators/op_clamp.py +123 -0
- tico/serialize/operators/op_clone.py +71 -0
- tico/serialize/operators/op_constant_pad_nd.py +72 -0
- tico/serialize/operators/op_conv2d.py +181 -0
- tico/serialize/operators/op_copy.py +162 -0
- tico/serialize/operators/op_cos.py +59 -0
- tico/serialize/operators/op_cumsum.py +92 -0
- tico/serialize/operators/op_depthwise_conv2d.py +198 -0
- tico/serialize/operators/op_dequantize_per_channel.py +82 -0
- tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
- tico/serialize/operators/op_div.py +62 -0
- tico/serialize/operators/op_embedding.py +60 -0
- tico/serialize/operators/op_eq.py +64 -0
- tico/serialize/operators/op_exp.py +60 -0
- tico/serialize/operators/op_expand.py +91 -0
- tico/serialize/operators/op_full.py +48 -0
- tico/serialize/operators/op_full_like.py +55 -0
- tico/serialize/operators/op_ge.py +54 -0
- tico/serialize/operators/op_gelu.py +59 -0
- tico/serialize/operators/op_gt.py +54 -0
- tico/serialize/operators/op_index.py +82 -0
- tico/serialize/operators/op_index_select.py +64 -0
- tico/serialize/operators/op_instance_norm.py +91 -0
- tico/serialize/operators/op_linear.py +70 -0
- tico/serialize/operators/op_log.py +53 -0
- tico/serialize/operators/op_log1p.py +83 -0
- tico/serialize/operators/op_logical_and.py +63 -0
- tico/serialize/operators/op_logical_not.py +62 -0
- tico/serialize/operators/op_lt.py +61 -0
- tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
- tico/serialize/operators/op_maximum.py +53 -0
- tico/serialize/operators/op_mean.py +66 -0
- tico/serialize/operators/op_minimum.py +53 -0
- tico/serialize/operators/op_mm.py +174 -0
- tico/serialize/operators/op_mul.py +99 -0
- tico/serialize/operators/op_ne.py +54 -0
- tico/serialize/operators/op_neg.py +59 -0
- tico/serialize/operators/op_permute.py +65 -0
- tico/serialize/operators/op_pow.py +138 -0
- tico/serialize/operators/op_prelu.py +54 -0
- tico/serialize/operators/op_quantize_per_tensor.py +79 -0
- tico/serialize/operators/op_reciprocal.py +64 -0
- tico/serialize/operators/op_relu.py +53 -0
- tico/serialize/operators/op_relu6.py +52 -0
- tico/serialize/operators/op_repeat.py +99 -0
- tico/serialize/operators/op_reshape.py +73 -0
- tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
- tico/serialize/operators/op_rsqrt.py +53 -0
- tico/serialize/operators/op_scalar_tensor.py +51 -0
- tico/serialize/operators/op_select_copy.py +65 -0
- tico/serialize/operators/op_sigmoid.py +56 -0
- tico/serialize/operators/op_sin.py +53 -0
- tico/serialize/operators/op_slice.py +155 -0
- tico/serialize/operators/op_softmax.py +100 -0
- tico/serialize/operators/op_split_with_sizes.py +96 -0
- tico/serialize/operators/op_sqrt.py +55 -0
- tico/serialize/operators/op_squeeze.py +73 -0
- tico/serialize/operators/op_sub.py +71 -0
- tico/serialize/operators/op_sum.py +63 -0
- tico/serialize/operators/op_tanh.py +54 -0
- tico/serialize/operators/op_to_copy.py +105 -0
- tico/serialize/operators/op_unsqueeze.py +66 -0
- tico/serialize/operators/op_view.py +74 -0
- tico/serialize/operators/op_where.py +82 -0
- tico/serialize/operators/utils.py +51 -0
- tico/serialize/pack.py +35 -0
- tico/serialize/quant_param.py +42 -0
- tico/utils/__init__.py +1 -0
- tico/utils/convert.py +292 -0
- tico/utils/define.py +35 -0
- tico/utils/diff_graph.py +181 -0
- tico/utils/errors.py +35 -0
- tico/utils/graph.py +200 -0
- tico/utils/logging.py +45 -0
- tico/utils/model.py +37 -0
- tico/utils/padding.py +47 -0
- tico/utils/passes.py +76 -0
- tico/utils/register_custom_op.py +562 -0
- tico/utils/trace_decorators.py +101 -0
- tico/utils/utils.py +314 -0
- tico/utils/validate_args_kwargs.py +1114 -0
- tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
- tico-0.1.0.dev250411.dist-info/METADATA +17 -0
- tico-0.1.0.dev250411.dist-info/RECORD +196 -0
- tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
- tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
- tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,113 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.passes import ops
|
23
|
+
|
24
|
+
from tico.utils import logging
|
25
|
+
from tico.utils.passes import PassBase, PassResult
|
26
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
27
|
+
from tico.utils.validate_args_kwargs import AddTensorArgs
|
28
|
+
|
29
|
+
|
30
|
+
@trace_graph_diff_on_pass
class LegalizeCausalMaskValue(PassBase):
    """
    This pass replaces occurrences of -inf in attention masks with a large negative
    finite value (e.g., -120) to ensure numerical stability in computations,
    particularly in softmax operations.

    Only `aten.add` nodes that take a lifted tensor constant as an operand are
    considered, and the constant is rewritten only when every element is either 0
    or below -1.0e38.

    This pass can be enabled only when
      1. The model will be quantized later (e.g., by circle-quantizer).
      2. Softmax kernel of our backend does not support masking.
      3. `Add with -inf` is used only for masking.
    """

    def __init__(self, enabled: bool = False, mask_value: float = -120):
        """
        Args:
            enabled: When False (default), `call` returns immediately without
                touching the program.
            mask_value: Finite value written in place of the -inf-like mask entries.
        """
        super().__init__()
        self.enabled = enabled
        self.mask_value = mask_value

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Rewrite -inf-like attention mask constants of `exported_program` in place.

        Returns:
            Always `PassResult(False)`; see the comment at the return statement.
        """
        if not self.enabled:
            return PassResult(False)

        new_mask = self.mask_value
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph

        # Maps a placeholder node name to its key in `exported_program.constants`.
        # The membership test and the lookup below must use this same mapping:
        # `graph_signature.lifted_tensor_constants` holds the constant *targets*,
        # which are not guaranteed to equal the placeholder node names.
        constant_key_by_placeholder = (
            exported_program.graph_signature.inputs_to_lifted_tensor_constants
        )

        # WHY Use -1.e+38, not -float('inf') or torch.finfo(torch.float32).min?
        #
        # torch.finfo(torch.float32).min is -3.4028234663852886e+38 but it changes
        # while processed in const prop or other passes.
        # Therefore, use a rounded value and compare to know it's very large
        # negative number.
        fp32_minus_inf_rounded = -1.0e38

        for node in graph.nodes:
            if node.op != "call_function":
                continue

            if node.target not in ops.aten.add:
                continue

            assert len(node.args) == 2

            args = AddTensorArgs(*node.args, **node.kwargs)

            # Prefer `input` over `other` when both are lifted tensor constants.
            mask_node = None
            for operand in (args.input, args.other):
                if (
                    isinstance(operand, torch.fx.Node)
                    and operand.name in constant_key_by_placeholder
                ):
                    mask_node = operand
                    break
            if mask_node is None:
                continue

            mask_node_name = constant_key_by_placeholder[mask_node.name]
            mask_data = exported_program.constants[mask_node_name]

            is_masked = mask_data < fp32_minus_inf_rounded
            # Only treat the constant as a mask when it holds nothing but zeros
            # and -inf-like values.
            if not torch.all(torch.logical_or(mask_data == 0, is_masked)):
                continue

            exported_program.constants[mask_node_name] = torch.where(
                is_masked,
                torch.tensor(new_mask, dtype=mask_data.dtype),
                mask_data,
            )
            logger.debug(
                f"{mask_node.name}'s mask data are changed from '-inf' to {new_mask}"
            )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        # The pass intentionally reports "not modified" even after rewriting a
        # constant ("to run only once" in the original code) -- presumably so an
        # iterating pass runner does not schedule it again; confirm with the runner.
        return PassResult(False)
|
@@ -0,0 +1,383 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from types import NoneType
|
16
|
+
from typing import Optional, TYPE_CHECKING
|
17
|
+
|
18
|
+
if TYPE_CHECKING:
|
19
|
+
import torch.fx
|
20
|
+
import torch
|
21
|
+
from torch.export import ExportedProgram
|
22
|
+
|
23
|
+
from tico.serialize.circle_graph import extract_shape
|
24
|
+
from tico.utils import logging
|
25
|
+
from tico.utils.errors import NotYetSupportedError
|
26
|
+
from tico.utils.passes import PassBase, PassResult
|
27
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
28
|
+
from tico.utils.validate_args_kwargs import (
|
29
|
+
AvgPool2dArgs,
|
30
|
+
Conv2DArgs,
|
31
|
+
DequantizePerChannelArgs,
|
32
|
+
DequantizePerTensorArgs,
|
33
|
+
InstanceNormArgs,
|
34
|
+
MaxPool2dWithIndicesArgs,
|
35
|
+
)
|
36
|
+
|
37
|
+
|
38
|
+
def get_permute_weight_input(conv_args: Conv2DArgs) -> torch.fx.Node:
    """
    Retrieves the weight input for the permute operation.

    This function extracts the weight tensor from the given convolution arguments.

    If the weight is in floating point format, it is returned directly.
    If the weight is quantized and followed by a Dequantize operation, the function
    returns the input of the Dequantize node (i.e., the original quantized weight)

    Args:
        conv_args: Parsed arguments of the convolution node.

    Returns:
        `conv_args.weight` itself, or the `input` of the dequantize node that
        produces it.
    """
    weight = conv_args.weight

    dq_args: Optional[DequantizePerChannelArgs | DequantizePerTensorArgs] = None
    # Keyword arguments must be forwarded with `**`; unpacking the dict with `*`
    # would pass its keys as extra positional values.
    if weight.target == torch.ops.quantized_decomposed.dequantize_per_channel.default:
        dq_args = DequantizePerChannelArgs(*weight.args, **weight.kwargs)  # type: ignore[arg-type]
    elif weight.target == torch.ops.quantized_decomposed.dequantize_per_tensor.default:
        dq_args = DequantizePerTensorArgs(*weight.args, **weight.kwargs)  # type: ignore[arg-type]

    # Falls back to `weight` when it is not produced by a dequantize node.
    return getattr(dq_args, "input", weight)
|
57
|
+
|
58
|
+
|
59
|
+
@trace_graph_diff_on_pass
class LegalizePreDefinedLayoutOperators(PassBase):
    """
    Pytorch basically assumes NCHW memory format. But, Circle assumes NHWC.
    Specifically, some operators have kernels only for NHWC memory format.
    So, we need to permute the dimensions accordingly.

    NOTE. This pass DOES NOT CHANGE node.kwargs["memory_format"]. It changes memory
    formats by inserting `aten.permute` operators.

    [1] aten.conv2d with group = 1 (circle_custom.conv2d)

        [BEFORE PASS]
          Input[NCHW] ------------------- aten.conv2d[NCHW] ---- OUTPUT[NCHW]
          Weight[NCHW] - (aten.dequantize) ---/
          Bias --------- (aten.dequantize) --/

        [AFTER PASS]
          Input[NCHW] ---- aten.permute(NCHW_to_NHWC) ---------- circle_custom.conv2d[NHWC] ---- aten.permute(NHWC_to_NCHW) ---- OUTPUT[NCHW]
          Weight[NCHW] - (aten.dequantize) - aten.permute(NCHW_to_NHWC) ---/
          Bias --------- (aten.dequantize) -------------------------------/

    [2] aten.conv2d with group == Input[C] (circle_custom.depthwise_conv2d)

        NOTE: Weight layout is CNHW (IOHW)

        [BEFORE PASS]
          Input[NCHW] -------------- aten.conv2d[NCHW] ---- OUTPUT[NCHW]
          Weight[CNHW] - (aten.dequantize) --/
          Bias ----------(aten.dequantize) -/

        [AFTER PASS]
          Input[NCHW] ---- aten.permute(NCHW_to_NHWC) ---- circle_custom.depthwise_conv2d[NHWC] ---- aten.permute(NHWC_to_NCHW) ---- OUTPUT[NCHW]
          Weight[CNHW] - (aten.dequantize) - aten.permute(CNHW_to_NHWC) ---/
          Bias ----------(aten.dequantize) -------------------------------/

    NOTE. When the conv2d weight comes from a dequantize node, the code places the
    weight permute on the dequantize node's input (i.e. before the dequantize).

    max_pool2d_with_indices, avg_pool2d and instance_norm are legalized the same
    way: the input is permuted to NHWC, the node is re-emitted as the matching
    `circle_custom` operator, and the result is permuted back to NCHW.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _insert_permute(graph, consumer, source, perm):
        """
        Insert `aten.permute(source, perm)` right after `source` and make
        `consumer` read the permuted value instead of `source`.

        Only the first positional occurrence of `source` in `consumer.args` is
        rewired. Returns the new permute node.
        """
        with graph.inserting_after(source):
            permuted = graph.call_function(
                torch.ops.aten.permute.default,
                args=(source, perm),
            )
        consumer.update_arg(consumer.args.index(source), permuted)
        return permuted

    @staticmethod
    def _replace_with_circle_op(graph, node, circle_target, replaced):
        """
        Emit `circle_target` with `node`'s current args/kwargs followed by an
        NHWC -> NCHW permute, both inserted before `node`, and redirect every
        user of `replaced` to that permute.

        Returns the newly created circle operator node.
        """
        NHWC_to_NCHW = [0, 3, 1, 2]
        with graph.inserting_before(node):
            circle_op = graph.call_function(
                circle_target,
                args=node.args,
                kwargs=node.kwargs,
            )
            out_permute = graph.call_function(
                torch.ops.aten.permute.default,
                args=(circle_op, NHWC_to_NCHW),
            )
        # Not set meta for propagating replacing node's meta.
        replaced.replace_all_uses_with(out_permute, propagate_meta=True)
        return circle_op

    def legalize_conv2d(self, exported_program, node) -> bool:
        """
        Replace an `aten.conv2d` node with `circle_custom.conv2d` (groups == 1) or
        `circle_custom.depthwise_conv2d` (groups == input channels).

        Raises:
            NotYetSupportedError: for non-4D inputs, other `groups` values, or a
                `padding` that is neither a list nor a str.
        """
        logger = logging.getLogger(__name__)
        graph = exported_program.graph_module.graph

        # conv2d (Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor
        # conv2d.padding (Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, str padding="valid", SymInt[2] dilation=1, SymInt groups=1) -> Tensor
        args = Conv2DArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_ = args.input
        padding = args.padding
        groups = args.groups

        input_shape = extract_shape(input_)
        if not (len(input_shape) == 4):
            raise NotYetSupportedError(
                f"Only support 4D input tensor: node's input shape: {input_shape}"
            )

        if not (groups == 1 or groups == input_shape[1]):
            raise NotYetSupportedError(
                f"Only support groups=1 or groups=input_channels: node's groups: {groups}, input channels: {input_shape[1]}"
            )

        # groups == 1 takes precedence, so a single-channel input with groups == 1
        # is treated as a regular convolution.
        if groups == 1:
            op_packet = torch.ops.circle_custom.conv2d
            weight_perm = [0, 2, 3, 1]  # OIHW_to_OHWI
        else:
            op_packet = torch.ops.circle_custom.depthwise_conv2d
            weight_perm = [1, 2, 3, 0]  # O1HW_to_1HWO

        # Resolve the target before mutating the graph so that an unsupported
        # padding type fails cleanly instead of leaving half-inserted permutes.
        if isinstance(padding, list):
            legalized_op = op_packet
        elif isinstance(padding, str):
            legalized_op = op_packet.padding
        else:
            raise NotYetSupportedError(
                f"Only support list or str padding: node's padding: {padding}"
            )

        # input permute
        self._insert_permute(graph, node, input_, [0, 2, 3, 1])  # NCHW_to_NHWC

        # weight permute
        weight = get_permute_weight_input(args)
        if args.weight.target in [
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        ]:
            # The permute feeds the dequantize node, not the convolution itself.
            dq = args.weight
            self._insert_permute(graph, dq, weight, weight_perm)
            # Need to update dq.meta["val"] in FillMetaVal pass.
            del dq.meta["val"]
        else:
            self._insert_permute(graph, node, weight, weight_perm)

        circle_op = self._replace_with_circle_op(graph, node, legalized_op, node)

        logger.debug(f"{node.name} is replaced with {circle_op.name}")
        return True

    def legalize_instance_norm(self, exported_program, node) -> bool:
        """
        Replace an `aten.instance_norm` node with `circle_custom.instance_norm`.

        Raises:
            NotYetSupportedError: unless `use_input_stats` is True, running
                statistics are None, and both weight and bias are given.
        """
        logger = logging.getLogger(__name__)
        graph = exported_program.graph_module.graph

        # instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor
        args = InstanceNormArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_ = args.input

        if not args.use_input_stats:
            raise NotYetSupportedError("Only support use_input_stats is True.")
        if args.running_mean is not None:
            raise NotYetSupportedError("Only support running_mean=None")
        if args.running_var is not None:
            raise NotYetSupportedError("Only support running_var=None")

        if args.weight is None:
            # TODO Support weight=None
            raise NotYetSupportedError("Only support weight is not None.")
        if args.bias is None:
            # TODO Support bias=None
            raise NotYetSupportedError("Only support bias is not None.")

        # input permute
        self._insert_permute(graph, node, input_, [0, 2, 3, 1])  # NCHW_to_NHWC

        # circle instnorm + output permute
        circle_instnorm = self._replace_with_circle_op(
            graph, node, torch.ops.circle_custom.instance_norm, node
        )

        logger.debug(f"{node.name} is replaced with {circle_instnorm.name}")
        return True

    def legalize_max_pool2d_with_indices(self, exported_program, node) -> bool:
        """
        Replace an `aten.max_pool2d_with_indices` node with
        `circle_custom.maxpool2d`.

        Raises:
            NotYetSupportedError: when `ceil_mode` is set or the node does not
                have exactly one user.
        """
        logger = logging.getLogger(__name__)
        graph = exported_program.graph_module.graph

        # max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
        args = MaxPool2dWithIndicesArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_ = args.input

        if args.ceil_mode:
            raise NotYetSupportedError("Only support non-ceil model.")
        if len(node.users) != 1:
            raise NotYetSupportedError(
                "Only support maxpool2d with 'return_indices=False'."
            )
        # NOTE(review): the single user is assumed to be the getitem that selects
        # the pooled values (index 0); this is not verified here.
        (get_item,) = node.users

        # input permute
        self._insert_permute(graph, node, input_, [0, 2, 3, 1])  # NCHW_to_NHWC

        # The circle op returns a single tensor, so it stands in for the
        # get_item node rather than for the tuple-producing node itself.
        circle_maxpool2d = self._replace_with_circle_op(
            graph, node, torch.ops.circle_custom.maxpool2d, get_item
        )

        logger.debug(f"{node.name} is replaced with {circle_maxpool2d.name}")
        return True

    def legalize_avg_pool2d(self, exported_program, node) -> bool:
        """
        Replace an `aten.avg_pool2d` node with `circle_custom.avgpool2d`.

        Raises:
            NotYetSupportedError: when `ceil_mode` is set, `count_include_pad` is
                False, or `divisor_override` is given.
        """
        logger = logging.getLogger(__name__)
        graph = exported_program.graph_module.graph

        # avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
        args = AvgPool2dArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_ = args.input

        if args.ceil_mode:
            raise NotYetSupportedError("Only support non-ceil model.")
        if not args.count_include_pad:
            # NOTE count_include_pad = False can be partially supported with SAME padding in circle.
            raise NotYetSupportedError(
                "For the case that the count_include_pad is False is not yet supported."
            )
        if args.divisor_override is not None:
            raise NotYetSupportedError(
                "For the case that the divisor_override is not None is not yet supported."
            )

        # input permute
        self._insert_permute(graph, node, input_, [0, 2, 3, 1])  # NCHW_to_NHWC

        circle_avgpool2d = self._replace_with_circle_op(
            graph, node, torch.ops.circle_custom.avgpool2d, node
        )

        logger.debug(f"{node.name} is replaced with {circle_avgpool2d.name}")
        return True

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Legalize every supported layout-sensitive operator in the graph.

        Returns:
            A `PassResult` telling whether any node was replaced.
        """
        target_to_legalize_func = {
            torch.ops.aten.conv2d.default: self.legalize_conv2d,
            torch.ops.aten.conv2d.padding: self.legalize_conv2d,
            torch.ops.aten.max_pool2d_with_indices.default: self.legalize_max_pool2d_with_indices,
            torch.ops.aten.avg_pool2d.default: self.legalize_avg_pool2d,
            torch.ops.aten.instance_norm.default: self.legalize_instance_norm,
        }

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if not node.op == "call_function":
                continue

            if node.target not in target_to_legalize_func:
                continue
            modified |= target_to_legalize_func[node.target](exported_program, node)

        # The original aten nodes have no users left at this point and are
        # dropped here.
        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
|
@@ -0,0 +1,75 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import TYPE_CHECKING
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import torch.fx
|
19
|
+
import torch
|
20
|
+
from torch.export import ExportedProgram
|
21
|
+
|
22
|
+
from tico.utils import logging
|
23
|
+
from tico.utils.passes import PassBase, PassResult
|
24
|
+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
25
|
+
|
26
|
+
|
27
|
+
@trace_graph_diff_on_pass
class LowerPow2ToMul(PassBase):
    """
    Rewrite a square expressed with pow as a multiplication of the operand by itself.

    E.g. `Pow(in_, 2)` -> `Mul(in_, in_)`
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Replace every `aten.pow.Tensor_Scalar` node whose exponent equals 2 with
        `aten.mul.Tensor`, and report whether the graph changed.
        """
        log = logging.getLogger(__name__)

        gm = exported_program.graph_module
        fx_graph = gm.graph

        changed = False
        for pow_node in fx_graph.nodes:
            is_pow_scalar = (
                pow_node.op == "call_function"
                and pow_node.target == torch.ops.aten.pow.Tensor_Scalar
            )
            if not is_pow_scalar:
                continue

            arg_count = len(pow_node.args)
            assert arg_count == 2, arg_count
            base, exponent = pow_node.args
            assert isinstance(base, torch.fx.Node), type(base)

            # Any other exponent is left untouched.
            if exponent != 2:
                continue

            # The multiplication is placed right after the pow node it supersedes.
            with fx_graph.inserting_after(pow_node):
                square = fx_graph.call_function(
                    torch.ops.aten.mul.Tensor,
                    args=(base, base),
                    kwargs={},
                )

            pow_node.replace_all_uses_with(square, propagate_meta=True)
            changed = True
            log.debug(f"{pow_node.name} is replaced with {square.name}")

        fx_graph.eliminate_dead_code()
        fx_graph.lint()
        gm.recompile()

        return PassResult(changed)
|