tico 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. tico/__init__.py +42 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +154 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +345 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/quantize_bias.py +123 -0
  55. tico/experimental/quantization/passes/remove_weight_dequant_op.py +177 -0
  56. tico/experimental/quantization/public_interface.py +108 -0
  57. tico/experimental/quantization/quantizer.py +71 -0
  58. tico/interpreter/__init__.py +1 -0
  59. tico/interpreter/infer.py +116 -0
  60. tico/interpreter/interpreter.py +93 -0
  61. tico/passes/__init__.py +1 -0
  62. tico/passes/cast_aten_where_arg_type.py +191 -0
  63. tico/passes/cast_mixed_type_args.py +187 -0
  64. tico/passes/const_prop_pass.py +307 -0
  65. tico/passes/convert_conv1d_to_conv2d.py +160 -0
  66. tico/passes/convert_layout_op_to_reshape.py +85 -0
  67. tico/passes/convert_repeat_to_expand_copy.py +89 -0
  68. tico/passes/convert_to_relu6.py +181 -0
  69. tico/passes/decompose_addmm.py +124 -0
  70. tico/passes/decompose_batch_norm.py +192 -0
  71. tico/passes/decompose_fake_quantize.py +134 -0
  72. tico/passes/decompose_fake_quantize_tensor_qparams.py +294 -0
  73. tico/passes/decompose_group_norm.py +275 -0
  74. tico/passes/decompose_grouped_conv2d.py +209 -0
  75. tico/passes/decompose_slice_scatter.py +169 -0
  76. tico/passes/extract_dtype_kwargs.py +122 -0
  77. tico/passes/fill_meta_val.py +57 -0
  78. tico/passes/fuse_leading_unsqueeze_reshape.py +112 -0
  79. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  80. tico/passes/legalize_causal_mask_value.py +108 -0
  81. tico/passes/legalize_predefined_layout_operators.py +386 -0
  82. tico/passes/lower_pow2_to_mul.py +75 -0
  83. tico/passes/lower_to_resize_nearest_neighbor.py +235 -0
  84. tico/passes/lower_to_slice.py +230 -0
  85. tico/passes/merge_consecutive_cat.py +80 -0
  86. tico/passes/ops.py +78 -0
  87. tico/passes/remove_nop.py +84 -0
  88. tico/passes/remove_redundant_assert_nodes.py +51 -0
  89. tico/passes/remove_redundant_expand.py +66 -0
  90. tico/passes/remove_redundant_permute.py +122 -0
  91. tico/passes/remove_redundant_reshape.py +436 -0
  92. tico/passes/remove_redundant_slice.py +62 -0
  93. tico/passes/remove_redundant_to_copy.py +86 -0
  94. tico/passes/restore_linear.py +115 -0
  95. tico/passes/segment_index_select.py +145 -0
  96. tico/pt2_to_circle.py +105 -0
  97. tico/serialize/__init__.py +1 -0
  98. tico/serialize/circle_graph.py +319 -0
  99. tico/serialize/circle_mapping.py +177 -0
  100. tico/serialize/circle_serializer.py +240 -0
  101. tico/serialize/operators/__init__.py +28 -0
  102. tico/serialize/operators/hashable_opcode.py +43 -0
  103. tico/serialize/operators/node_visitor.py +80 -0
  104. tico/serialize/operators/op_abs.py +53 -0
  105. tico/serialize/operators/op_add.py +69 -0
  106. tico/serialize/operators/op_alias_copy.py +64 -0
  107. tico/serialize/operators/op_any.py +150 -0
  108. tico/serialize/operators/op_arange_start_step.py +61 -0
  109. tico/serialize/operators/op_argmax.py +62 -0
  110. tico/serialize/operators/op_avg_pool2d.py +192 -0
  111. tico/serialize/operators/op_bmm.py +62 -0
  112. tico/serialize/operators/op_cat.py +66 -0
  113. tico/serialize/operators/op_clamp.py +126 -0
  114. tico/serialize/operators/op_clone.py +71 -0
  115. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  116. tico/serialize/operators/op_conv2d.py +186 -0
  117. tico/serialize/operators/op_copy.py +164 -0
  118. tico/serialize/operators/op_cos.py +59 -0
  119. tico/serialize/operators/op_cumsum.py +95 -0
  120. tico/serialize/operators/op_depthwise_conv2d.py +199 -0
  121. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  122. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  123. tico/serialize/operators/op_div.py +62 -0
  124. tico/serialize/operators/op_embedding.py +60 -0
  125. tico/serialize/operators/op_eq.py +64 -0
  126. tico/serialize/operators/op_exp.py +60 -0
  127. tico/serialize/operators/op_expand.py +91 -0
  128. tico/serialize/operators/op_full.py +48 -0
  129. tico/serialize/operators/op_full_like.py +55 -0
  130. tico/serialize/operators/op_ge.py +54 -0
  131. tico/serialize/operators/op_gelu.py +59 -0
  132. tico/serialize/operators/op_gt.py +54 -0
  133. tico/serialize/operators/op_index.py +82 -0
  134. tico/serialize/operators/op_index_select.py +64 -0
  135. tico/serialize/operators/op_instance_norm.py +91 -0
  136. tico/serialize/operators/op_leaky_relu.py +60 -0
  137. tico/serialize/operators/op_linear.py +70 -0
  138. tico/serialize/operators/op_log.py +53 -0
  139. tico/serialize/operators/op_log1p.py +86 -0
  140. tico/serialize/operators/op_logical_and.py +63 -0
  141. tico/serialize/operators/op_logical_not.py +62 -0
  142. tico/serialize/operators/op_lt.py +61 -0
  143. tico/serialize/operators/op_max_dim.py +70 -0
  144. tico/serialize/operators/op_max_pool2d_with_indices.py +155 -0
  145. tico/serialize/operators/op_maximum.py +53 -0
  146. tico/serialize/operators/op_mean.py +66 -0
  147. tico/serialize/operators/op_minimum.py +53 -0
  148. tico/serialize/operators/op_mm.py +177 -0
  149. tico/serialize/operators/op_mul.py +99 -0
  150. tico/serialize/operators/op_ne.py +54 -0
  151. tico/serialize/operators/op_neg.py +59 -0
  152. tico/serialize/operators/op_permute.py +65 -0
  153. tico/serialize/operators/op_pow.py +141 -0
  154. tico/serialize/operators/op_prelu.py +54 -0
  155. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  156. tico/serialize/operators/op_reciprocal.py +64 -0
  157. tico/serialize/operators/op_relu.py +53 -0
  158. tico/serialize/operators/op_relu6.py +52 -0
  159. tico/serialize/operators/op_repeat.py +100 -0
  160. tico/serialize/operators/op_reshape.py +73 -0
  161. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  162. tico/serialize/operators/op_rsqrt.py +53 -0
  163. tico/serialize/operators/op_scalar_tensor.py +51 -0
  164. tico/serialize/operators/op_select_copy.py +65 -0
  165. tico/serialize/operators/op_sigmoid.py +56 -0
  166. tico/serialize/operators/op_sin.py +53 -0
  167. tico/serialize/operators/op_slice.py +155 -0
  168. tico/serialize/operators/op_softmax.py +100 -0
  169. tico/serialize/operators/op_split_with_sizes.py +99 -0
  170. tico/serialize/operators/op_sqrt.py +55 -0
  171. tico/serialize/operators/op_squeeze.py +73 -0
  172. tico/serialize/operators/op_sub.py +71 -0
  173. tico/serialize/operators/op_sum.py +63 -0
  174. tico/serialize/operators/op_tanh.py +54 -0
  175. tico/serialize/operators/op_to_copy.py +105 -0
  176. tico/serialize/operators/op_unsqueeze.py +66 -0
  177. tico/serialize/operators/op_view.py +74 -0
  178. tico/serialize/operators/op_where.py +82 -0
  179. tico/serialize/operators/utils.py +94 -0
  180. tico/serialize/pack.py +35 -0
  181. tico/serialize/quant_param.py +42 -0
  182. tico/utils/__init__.py +1 -0
  183. tico/utils/convert.py +296 -0
  184. tico/utils/define.py +35 -0
  185. tico/utils/diff_graph.py +181 -0
  186. tico/utils/errors.py +35 -0
  187. tico/utils/graph.py +282 -0
  188. tico/utils/logging.py +45 -0
  189. tico/utils/model.py +37 -0
  190. tico/utils/mx/__init__.py +1 -0
  191. tico/utils/mx/elemwise_ops.py +267 -0
  192. tico/utils/mx/formats.py +125 -0
  193. tico/utils/mx/mx_ops.py +270 -0
  194. tico/utils/padding.py +47 -0
  195. tico/utils/passes.py +76 -0
  196. tico/utils/register_custom_op.py +609 -0
  197. tico/utils/serialize.py +42 -0
  198. tico/utils/trace_decorators.py +101 -0
  199. tico/utils/utils.py +406 -0
  200. tico/utils/validate_args_kwargs.py +1149 -0
  201. tico-0.1.0.dist-info/LICENSE +241 -0
  202. tico-0.1.0.dist-info/METADATA +354 -0
  203. tico-0.1.0.dist-info/RECORD +206 -0
  204. tico-0.1.0.dist-info/WHEEL +5 -0
  205. tico-0.1.0.dist-info/entry_points.txt +3 -0
  206. tico-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,102 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+ from torch.utils import _pytree as pytree
22
+
23
+ from tico.passes import ops
24
+ from tico.serialize.circle_mapping import extract_shape
25
+ from tico.utils import logging
26
+ from tico.utils.passes import PassBase, PassResult
27
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
+ from tico.utils.utils import is_target_node
29
+
30
+
31
@trace_graph_diff_on_pass
class FuseRedundantReshapeToMean(PassBase):
    """
    This pass removes redundant `aten.reshape` operators that can be fused to `aten.mean` with `keepdim=True`.

    Shape(aten.reshape(aten.mean(input))) == Shape(aten.mean(input, keepdim=True))

    For every `aten.mean.dim` node whose only user is an `aten.reshape`, the
    mean is re-evaluated (on `meta["val"]`) with `keepdim=True`. If that shape
    equals the reshape's output shape, the mean node is rewritten to use
    `keepdim=True`, the reshape's users are redirected to the mean, and the
    now-unused reshape is dropped by dead code elimination.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if not is_target_node(node, torch.ops.aten.mean.dim):
                continue

            # If mean is being used in other nodes, do not fuse it.
            if len(node.users) != 1:
                continue

            user_node = next(iter(node.users))
            if not is_target_node(user_node, ops.aten.reshape):
                continue

            mean_args, mean_kwargs = pytree.tree_map_only(
                torch.fx.Node,
                lambda n: n.meta["val"],
                (node.args, node.kwargs),
            )
            # Signature of aten.mean.dim is as follows.
            # mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
            # `keepdim` in `node.kwargs` is moved to `node.args` in `run_decompositions`.
            # `dtype` in `node.kwargs` is not moved
            assert len(mean_args) in (2, 3)  # keepdim exists or not
            assert len(mean_kwargs) <= 1  # dtype exists or not

            # Always force keepdim=True. An explicit `keepdim` already present
            # in args (possibly False) is overridden rather than reused,
            # otherwise `mean(x, dim, False).reshape(...)` could never be fused.
            keep_dims = True
            fused_mean_args = (*mean_args[:2], keep_dims)
            fused_val = node.target(*fused_mean_args, **mean_kwargs)

            # Check if both shapes are same
            # 1. Shape(aten.reshape(aten.mean))
            # 2. Shape(aten.mean(keepdim=True))
            if fused_val.size() != extract_shape(user_node):
                continue

            # update args: (input, dim, keepdim=True)
            node.args = (*node.args[:2], keep_dims)
            node.meta["val"] = fused_val
            user_node.replace_all_uses_with(node, propagate_meta=False)

            modified = True
            logger.debug(f"{user_node.name} is replaced with {node.name}")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
@@ -0,0 +1,108 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.passes import ops
23
+ from tico.utils import logging
24
+ from tico.utils.passes import PassBase, PassResult
25
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
26
+ from tico.utils.utils import is_target_node
27
+ from tico.utils.validate_args_kwargs import AddTensorArgs
28
+
29
+
30
@trace_graph_diff_on_pass
class LegalizeCausalMaskValue(PassBase):
    """
    This pass replaces occurrences of -inf in attention masks with a large negative finite value (-120 by default) to ensure numerical stability in computations, particularly in softmax operations.

    This pass should be enabled only when
    1. The model will be quantized later (e.g., by circle-quantizer).
    2. Softmax kernel of our backend does not support masking.
    3. `Add with -inf` is used only for masking.

    A "mask" here is an operand of `aten.add` that is a lifted tensor constant
    whose elements are all either 0 or (rounded) -inf. Only such constants are
    rewritten, in `exported_program.constants`; the graph itself is untouched.

    Args:
        enabled: When False (the default), `call` does nothing.
        mask_value: Finite value written in place of the -inf entries.
    """

    def __init__(self, enabled: bool = False, mask_value: float = -120):
        super().__init__()
        self.enabled = enabled
        self.mask_value = mask_value

    def call(self, exported_program: ExportedProgram) -> PassResult:
        if not self.enabled:
            return PassResult(False)

        new_mask = self.mask_value
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph

        # Placeholder-name -> key of `exported_program.constants`. Membership is
        # tested against this same mapping that the lookup below uses, so a node
        # that passes the test can always be resolved to its constant.
        input_to_constant = (
            exported_program.graph_signature.inputs_to_lifted_tensor_constants
        )

        # WHY Use -1.e+38, not -float('inf') or torch.finfo(torch.float32).min?
        #
        # torch.finfo(torch.float32).min is -3.4028234663852886e+38 but it changes while processed in const prop or other passes.
        # Therefore, use a rounded value and compare to know it's very large negative number.
        fp32_minus_inf_rounded = -1.0e38

        for node in graph.nodes:
            if not is_target_node(node, ops.aten.add):
                continue

            args = AddTensorArgs(*node.args, **node.kwargs)

            # The first operand (input, then other) that is a lifted constant.
            mask_node = None
            for operand in (args.input, args.other):
                if (
                    isinstance(operand, torch.fx.Node)
                    and operand.name in input_to_constant
                ):
                    mask_node = operand
                    break
            if mask_node is None:
                continue

            mask_node_name = input_to_constant[mask_node.name]
            mask_data = exported_program.constants[mask_node_name]

            is_mask = torch.logical_or(
                mask_data == 0, mask_data < fp32_minus_inf_rounded
            )
            if not torch.all(is_mask):
                continue

            exported_program.constants[mask_node_name] = torch.where(
                mask_data < fp32_minus_inf_rounded,
                torch.tensor(new_mask, dtype=mask_data.dtype),
                mask_data,
            )
            # Logged only when the constant was actually rewritten.
            logger.debug(
                f"{mask_node.name}'s mask data are changed from '-inf' to {new_mask}"
            )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        # Intentionally reports "not modified" even after a rewrite: To run only once.
        return PassResult(False)
@@ -0,0 +1,386 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from types import NoneType
16
+ from typing import Optional, TYPE_CHECKING
17
+
18
+ if TYPE_CHECKING:
19
+ import torch.fx
20
+ import torch
21
+ from torch.export import ExportedProgram
22
+
23
+ from tico.serialize.circle_graph import extract_shape
24
+ from tico.utils import logging
25
+ from tico.utils.errors import NotYetSupportedError
26
+ from tico.utils.graph import create_node
27
+ from tico.utils.passes import PassBase, PassResult
28
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
29
+ from tico.utils.utils import is_target_node
30
+ from tico.utils.validate_args_kwargs import (
31
+ AvgPool2dArgs,
32
+ Conv2DArgs,
33
+ DequantizePerChannelArgs,
34
+ DequantizePerTensorArgs,
35
+ InstanceNormArgs,
36
+ MaxPool2dWithIndicesArgs,
37
+ )
38
+
39
+
40
def get_permute_weight_input(conv_args: Conv2DArgs) -> torch.fx.Node:
    """
    Retrieves the weight input for the permute operation.

    This function extracts the weight tensor from the given convolution arguments.

    If the weight is in floating point format, it is returned directly.
    If the weight is quantized and followed by a Dequantize operation, the function
    returns the input of the Dequantize node (i.e., the original quantized weight).

    Args:
        conv_args: Parsed arguments of an `aten.conv2d` node.

    Returns:
        The node the weight permute should be applied to: either
        `conv_args.weight` itself, or the `input` of the
        `quantized_decomposed.dequantize_per_{channel,tensor}` node feeding it.
    """
    weight = conv_args.weight

    dq_args: Optional[DequantizePerChannelArgs | DequantizePerTensorArgs] = None
    # kwargs must be forwarded as keyword arguments (`**`); a single `*` would
    # splat the dict's keys as extra positional arguments.
    if weight.target == torch.ops.quantized_decomposed.dequantize_per_channel.default:
        dq_args = DequantizePerChannelArgs(*weight.args, **weight.kwargs)  # type: ignore[arg-type]
    elif weight.target == torch.ops.quantized_decomposed.dequantize_per_tensor.default:
        dq_args = DequantizePerTensorArgs(*weight.args, **weight.kwargs)  # type: ignore[arg-type]

    # `getattr(None, "input", weight)` falls back to the weight itself.
    return getattr(dq_args, "input", weight)
59
+
60
+
61
@trace_graph_diff_on_pass
class LegalizePreDefinedLayoutOperators(PassBase):
    """
    PyTorch basically assumes NCHW memory format, but Circle assumes NHWC. Specifically, some operators have kernels only for NHWC memory format.
    So, we need to permute the dimensions accordingly.

    NOTE. This pass DOES NOT CHANGE node.kwargs["memory_format"]. It changes memory formats by inserting `aten.permute` operators.

    [1] aten.conv2d with group = 1 (circle_custom.conv2d)

      [BEFORE PASS]
        Input[NCHW] ------------------- aten.conv2d[NCHW] ---- OUTPUT[NCHW]
        Weight[NCHW] - (aten.dequantize) ---/
        Bias --------- (aten.dequantize) --/

      [AFTER PASS]
        Input[NCHW] ---- aten.permute(NCHW_to_NHWC) ---------- circle_custom.conv2d[NHWC] ---- aten.permute(NHWC_to_NCHW) ---- OUTPUT[NCHW]
        Weight[NCHW] - aten.permute(NCHW_to_NHWC) - (aten.dequantize) ---/
        Bias --------- (aten.dequantize) -------------------------------/

    [2] aten.conv2d with group == Input[C] (circle_custom.depthwise_conv2d)

      NOTE: Weight layout is CNHW (IOHW)

      [BEFORE PASS]
        Input[NCHW] -------------- aten.conv2d[NCHW] ---- OUTPUT[NCHW]
        Weight[CNHW] - (aten.dequantize) --/
        Bias ----------(aten.dequantize) -/

      [AFTER PASS]
        Input[NCHW] ---- aten.permute(NCHW_to_NHWC) ---- circle_custom.depthwise_conv2d[NHWC] ---- aten.permute(NHWC_to_NCHW) ---- OUTPUT[NCHW]
        Weight[CNHW] - aten.permute(CNHW_to_NHWC) - (aten.dequantize) ---/
        Bias ----------(aten.dequantize) -------------------------------/

    When the weight is dequantized, the permute is inserted on the quantized
    weight (the dequantize node's input), i.e. before the dequantize.

    [3] aten.max_pool2d_with_indices / aten.avg_pool2d / aten.instance_norm
        (circle_custom.maxpool2d / circle_custom.avgpool2d / circle_custom.instance_norm)

        Input[NCHW] ---- aten.permute(NCHW_to_NHWC) ---- circle_custom.<op>[NHWC] ---- aten.permute(NHWC_to_NCHW) ---- OUTPUT[NCHW]
    """

    # Permutation that maps an NCHW tensor to NHWC.
    _NCHW_TO_NHWC = (0, 2, 3, 1)
    # Permutation that maps an NHWC tensor back to NCHW.
    _NHWC_TO_NCHW = (0, 3, 1, 2)

    def __init__(self):
        super().__init__()

    def _permute_input_to_nhwc(self, graph, node, input_):
        """
        Insert `aten.permute(input_, NCHW_to_NHWC)` right after `input_` and make
        `node` consume the permuted tensor in place of `input_`.
        """
        with graph.inserting_after(input_):
            input_permute = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(input_, list(self._NCHW_TO_NHWC)),
                origin=input_,
            )
            node.update_arg(node.args.index(input_), input_permute)
        return input_permute

    def _create_nhwc_op(self, graph, node, legalized_op):
        """
        Insert, right before `node`, a `legalized_op` node reusing `node`'s current
        (already permuted) args/kwargs, followed by an NHWC->NCHW permute of its
        output. Returns `(circle_op, output_permute)`.
        """
        with graph.inserting_before(node):
            circle_op = create_node(
                graph, legalized_op, args=node.args, kwargs=node.kwargs, origin=node
            )
            out_permute = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(circle_op, list(self._NHWC_TO_NCHW)),
            )
        return circle_op, out_permute

    def legalize_conv2d(self, exported_program, node) -> bool:
        logger = logging.getLogger(__name__)
        graph = exported_program.graph_module.graph

        # conv2d (Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor
        # conv2d.padding (Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, str padding="valid", SymInt[2] dilation=1, SymInt groups=1) -> Tensor
        args = Conv2DArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_ = args.input
        padding = args.padding
        groups = args.groups

        input_shape = extract_shape(input_)
        if len(input_shape) != 4:
            raise NotYetSupportedError(
                f"Only support 4D input tensor: node's input shape: {input_shape}"
            )

        if not (groups == 1 or groups == input_shape[1]):
            raise NotYetSupportedError(
                f"Only support groups=1 or groups=input_channels: node's groups: {groups}, input channels: {input_shape[1]}"
            )

        # groups == 1 takes precedence (also covers a 1-channel input).
        is_depthwise = groups != 1
        if is_depthwise:
            weight_perm = [1, 2, 3, 0]  # O1HW_to_1HWO (circle_custom.depthwise_conv2d)
        else:
            weight_perm = [0, 2, 3, 1]  # OIHW_to_OHWI (circle_custom.conv2d)

        # Pick the target op before mutating the graph so an unsupported padding
        # type fails without leaving half-inserted permutes behind.
        if isinstance(padding, list):
            legalized_op = (
                torch.ops.circle_custom.depthwise_conv2d
                if is_depthwise
                else torch.ops.circle_custom.conv2d
            )
        elif isinstance(padding, str):
            legalized_op = (
                torch.ops.circle_custom.depthwise_conv2d.padding
                if is_depthwise
                else torch.ops.circle_custom.conv2d.padding
            )
        else:
            raise NotYetSupportedError(
                f"Only support list or str padding: node's padding: {padding}"
            )

        # input permute
        self._permute_input_to_nhwc(graph, node, input_)

        # weight permute (applied on the quantized weight when it is dequantized)
        weight = get_permute_weight_input(args)
        with graph.inserting_after(weight):
            weight_permute = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(weight, weight_perm),
                origin=weight,
            )
            if args.weight.target in [
                torch.ops.quantized_decomposed.dequantize_per_channel.default,
                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            ]:
                dq = args.weight
                dq.update_arg(dq.args.index(weight), weight_permute)
                # Need to update dq.meta["val"] in FillMetaVal pass.
                del dq.meta["val"]
            else:
                node.update_arg(node.args.index(weight), weight_permute)

        circle_op, conv_out_permute = self._create_nhwc_op(graph, node, legalized_op)
        node.replace_all_uses_with(conv_out_permute, propagate_meta=True)

        logger.debug(f"{node.name} is replaced with {circle_op.name}")
        return True

    def legalize_instance_norm(self, exported_program, node) -> bool:
        logger = logging.getLogger(__name__)
        graph = exported_program.graph_module.graph

        # instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor
        args = InstanceNormArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_ = args.input

        if not args.use_input_stats:
            raise NotYetSupportedError("Only support use_input_stats is True.")
        if args.running_mean is not None:
            raise NotYetSupportedError("Only support running_mean=None")
        if args.running_var is not None:
            raise NotYetSupportedError("Only support running_var=None")

        if args.weight is None:
            # TODO Support weight=None
            raise NotYetSupportedError("Only support weight is not None.")
        if args.bias is None:
            # TODO Support bias=None
            raise NotYetSupportedError("Only support bias is not None.")

        # The inserted permutes are 4-element (NCHW <-> NHWC).
        input_shape = extract_shape(input_)
        if len(input_shape) != 4:
            raise NotYetSupportedError(
                f"Only support 4D input tensor: node's input shape: {input_shape}"
            )

        self._permute_input_to_nhwc(graph, node, input_)
        circle_instnorm, instnorm_out_permute = self._create_nhwc_op(
            graph, node, torch.ops.circle_custom.instance_norm
        )
        node.replace_all_uses_with(instnorm_out_permute, propagate_meta=True)

        logger.debug(f"{node.name} is replaced with {circle_instnorm.name}")
        return True

    def legalize_max_pool2d_with_indices(self, exported_program, node) -> bool:
        logger = logging.getLogger(__name__)
        graph = exported_program.graph_module.graph

        # max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
        args = MaxPool2dWithIndicesArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_ = args.input
        if args.ceil_mode:
            raise NotYetSupportedError("Only support non-ceil mode.")

        # The op returns (values, indices); its users select an element by index
        # (assumed `getitem(node, idx)`). Only a single user reading index 0 (the
        # pooled values) is supported: the indices output is dropped here, so a
        # lone `getitem(node, 1)` must not be rewired to the pooled values.
        users = list(node.users)
        if len(users) != 1 or tuple(users[0].args[1:]) != (0,):
            raise NotYetSupportedError(
                "Only support maxpool2d with 'return_indices=False'."
            )
        (get_item,) = users

        self._permute_input_to_nhwc(graph, node, input_)
        circle_maxpool2d, maxpool_out_permute = self._create_nhwc_op(
            graph, node, torch.ops.circle_custom.maxpool2d
        )
        get_item.replace_all_uses_with(maxpool_out_permute, propagate_meta=True)

        logger.debug(f"{node.name} is replaced with {circle_maxpool2d.name}")
        return True

    def legalize_avg_pool2d(self, exported_program, node) -> bool:
        logger = logging.getLogger(__name__)
        graph = exported_program.graph_module.graph

        # avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
        args = AvgPool2dArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_ = args.input
        if args.ceil_mode:
            raise NotYetSupportedError("Only support non-ceil mode.")
        if args.divisor_override is not None:
            raise NotYetSupportedError(
                "For the case that the divisor_override is not None is not yet supported."
            )

        self._permute_input_to_nhwc(graph, node, input_)
        circle_avgpool2d, avgpool_out_permute = self._create_nhwc_op(
            graph, node, torch.ops.circle_custom.avgpool2d
        )
        node.replace_all_uses_with(avgpool_out_permute, propagate_meta=True)

        logger.debug(f"{node.name} is replaced with {circle_avgpool2d.name}")
        return True

    def call(self, exported_program: ExportedProgram) -> PassResult:
        target_to_legalize_func = {
            torch.ops.aten.conv2d.default: self.legalize_conv2d,
            torch.ops.aten.conv2d.padding: self.legalize_conv2d,
            torch.ops.aten.max_pool2d_with_indices.default: self.legalize_max_pool2d_with_indices,
            torch.ops.aten.avg_pool2d.default: self.legalize_avg_pool2d,
            torch.ops.aten.instance_norm.default: self.legalize_instance_norm,
        }
        # Built once rather than on every node.
        targets = list(target_to_legalize_func)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if not is_target_node(node, targets):
                continue

            modified |= target_to_legalize_func[node.target](exported_program, node)

        # Drops the original (now unused) nodes.
        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
@@ -0,0 +1,75 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.utils import logging
23
+ from tico.utils.graph import create_node
24
+ from tico.utils.passes import PassBase, PassResult
25
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
26
+ from tico.utils.utils import is_target_node
27
+ from tico.utils.validate_args_kwargs import PowTensorScalarArgs
28
+
29
+
30
@trace_graph_diff_on_pass
class LowerPow2ToMul(PassBase):
    """
    Rewrite ``aten.pow.Tensor_Scalar`` nodes whose exponent equals 2 into an
    ``aten.mul.Tensor`` node that multiplies the input by itself.

    E.g. `Pow(in_, 2)` -> `Mul(in_, in_)`
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Replace every squaring ``pow`` in the graph with a self-multiplication.

        Args:
            exported_program: Program whose ``graph_module.graph`` is edited
                in place.

        Returns:
            PassResult flagging whether at least one ``pow`` was replaced.
        """
        logger = logging.getLogger(__name__)

        gm = exported_program.graph_module
        fx_graph = gm.graph
        changed = False

        for node in fx_graph.nodes:
            if not is_target_node(node, torch.ops.aten.pow.Tensor_Scalar):
                continue

            pow_args = PowTensorScalarArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
            base, exponent = pow_args.input, pow_args.exponent

            # Only the exact square is lowered; other exponents stay as pow.
            if exponent != 2:
                continue

            # Place the mul right after the pow so it occupies the same
            # position in the graph.
            with fx_graph.inserting_after(node):
                mul_node = create_node(
                    fx_graph,
                    torch.ops.aten.mul.Tensor,
                    args=(base, base),
                    kwargs={},
                )

            # Redirect all users to the mul and copy the pow's meta onto it;
            # the now-unused pow is removed by eliminate_dead_code() below.
            node.replace_all_uses_with(mul_node, propagate_meta=True)

            changed = True
            logger.debug(f"{node.name} is replaced with {mul_node.name}")

        fx_graph.eliminate_dead_code()
        fx_graph.lint()
        gm.recompile()

        return PassResult(changed)