tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
@@ -0,0 +1,113 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.passes import ops
23
+
24
+ from tico.utils import logging
25
+ from tico.utils.passes import PassBase, PassResult
26
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
27
+ from tico.utils.validate_args_kwargs import AddTensorArgs
28
+
29
+
30
@trace_graph_diff_on_pass
class LegalizeCausalMaskValue(PassBase):
    """
    Replace `-inf`-like entries of additive attention masks with a finite value.

    For every `aten.add` node that has a lifted tensor constant as one of its
    operands, the constant is treated as a mask when every element is either 0
    or smaller than -1.0e38. The very negative elements of such a constant are
    then overwritten with `mask_value` (default: -120) in
    `exported_program.constants`.

    This pass should be enabled only when
      1. The model will be quantized later (e.g., by circle-quantizer).
      2. Softmax kernel of our backend does not support masking.
      3. `Add with -inf` is used only for masking.

    Args:
        enabled: When False (default), `call` does nothing.
        mask_value: Finite value written in place of the `-inf`-like entries.
    """

    def __init__(self, enabled: bool = False, mask_value: float = -120):
        super().__init__()
        self.enabled = enabled
        self.mask_value = mask_value

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Rewrite mask constants in place.

        Always returns `PassResult(False)`: only constant data is changed, and
        the original implementation deliberately reports "not modified" so that
        the pass runs only once.
        """
        if not self.enabled:
            return PassResult(False)

        new_mask = self.mask_value
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph

        # Use the same mapping for the membership test and for the lookup
        # below. Testing membership against a different collection could let a
        # name through that is not a key here and fail with a KeyError.
        lifted_constants = (
            exported_program.graph_signature.inputs_to_lifted_tensor_constants
        )

        # WHY Use -1.e+38, not -float('inf') or torch.finfo(torch.float32).min?
        #
        # torch.finfo(torch.float32).min is -3.4028234663852886e+38 but it changes
        # while processed in const prop or other passes.
        # Therefore, use a rounded value and compare to know it's very large
        # negative number.
        fp32_minus_inf_rounded = -1.0e38

        for node in graph.nodes:
            if node.op != "call_function":
                continue

            if node.target not in ops.aten.add:
                continue

            assert len(node.args) == 2

            args = AddTensorArgs(*node.args, **node.kwargs)

            # The first operand (input, then other) that is a lifted tensor
            # constant is taken as the mask candidate.
            mask_node = None
            for operand in (args.input, args.other):
                if (
                    isinstance(operand, torch.fx.Node)
                    and operand.name in lifted_constants
                ):
                    mask_node = operand
                    break
            if mask_node is None:
                continue

            mask_node_name = lifted_constants[mask_node.name]
            mask_data = exported_program.constants[mask_node_name]

            is_minus_inf = mask_data < fp32_minus_inf_rounded
            # Only touch constants that look like a mask: nothing but zeros and
            # `-inf`-like values.
            if torch.all(torch.logical_or(mask_data == 0, is_minus_inf)):
                exported_program.constants[mask_node_name] = torch.where(
                    is_minus_inf,
                    torch.tensor(new_mask, dtype=mask_data.dtype),
                    mask_data,
                )
                logger.debug(
                    f"{mask_node.name}'s mask data are changed from '-inf' to {new_mask}"
                )

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        # Intentionally report "not modified" (see docstring): to run only once.
        return PassResult(False)
@@ -0,0 +1,383 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from types import NoneType
16
+ from typing import Optional, TYPE_CHECKING
17
+
18
+ if TYPE_CHECKING:
19
+ import torch.fx
20
+ import torch
21
+ from torch.export import ExportedProgram
22
+
23
+ from tico.serialize.circle_graph import extract_shape
24
+ from tico.utils import logging
25
+ from tico.utils.errors import NotYetSupportedError
26
+ from tico.utils.passes import PassBase, PassResult
27
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
28
+ from tico.utils.validate_args_kwargs import (
29
+ AvgPool2dArgs,
30
+ Conv2DArgs,
31
+ DequantizePerChannelArgs,
32
+ DequantizePerTensorArgs,
33
+ InstanceNormArgs,
34
+ MaxPool2dWithIndicesArgs,
35
+ )
36
+
37
+
38
def get_permute_weight_input(conv_args: Conv2DArgs) -> torch.fx.Node:
    """
    Retrieves the weight input for the permute operation.

    This function extracts the weight tensor from the given convolution arguments.

    If the weight node is not a `quantized_decomposed.dequantize_per_channel` /
    `dequantize_per_tensor` call, it is returned directly.
    If the weight node is one of those Dequantize operations, the function
    returns the input of the Dequantize node (i.e., the original quantized weight).

    Args:
        conv_args: Parsed arguments of the convolution node.

    Returns:
        The node that a weight permute should be applied to.
    """
    weight = conv_args.weight

    dq_args: Optional[DequantizePerChannelArgs | DequantizePerTensorArgs] = None
    # NOTE `**weight.kwargs` (not `*weight.kwargs`): a single star would unpack
    # the dict KEYS as extra positional arguments instead of forwarding the
    # keyword arguments by name.
    if weight.target == torch.ops.quantized_decomposed.dequantize_per_channel.default:
        dq_args = DequantizePerChannelArgs(*weight.args, **weight.kwargs)  # type: ignore[arg-type]
    elif weight.target == torch.ops.quantized_decomposed.dequantize_per_tensor.default:
        dq_args = DequantizePerTensorArgs(*weight.args, **weight.kwargs)  # type: ignore[arg-type]

    # Falls back to the weight itself when no Dequantize node was matched.
    return getattr(dq_args, "input", weight)
57
+
58
+
59
@trace_graph_diff_on_pass
class LegalizePreDefinedLayoutOperators(PassBase):
    """
    Pytorch basically assumes NCHW memory format. But, Circle assumes NHWC. Specifically, some operators have kernels only for NHWC memory format.
    So, we need to permute the dimensions accordingly.

    NOTE. This pass DOES NOT CHANGE node.kwargs["memory_format"]. It changes memory formats by inserting `aten.permute` operators.

    [1] aten.conv2d with group = 1 (circle_custom.conv2d)

        [BEFORE PASS]
          Input[NCHW] ------------------- aten.conv2d[NCHW] ---- OUTPUT[NCHW]
          Weight[NCHW] - (aten.dequantize) ---/
          Bias --------- (aten.dequantize) --/

        [AFTER PASS]
          Input[NCHW] ---- aten.permute(NCHW_to_NHWC) ---------- circle_custom.conv2d[NHWC] ---- aten.permute(NHWC_to_NCHW) ---- OUTPUT[NCHW]
          Weight[NCHW] - (aten.dequantize) - aten.permute(NCHW_to_NHWC) ---/
          Bias --------- (aten.dequantize) -------------------------------/

    [2] aten.conv2d with group == Input[C] (circle_custom.depthwise_conv2d)

        NOTE: Weight layout is CNHW (IOHW)

        [BEFORE PASS]
          Input[NCHW] -------------- aten.conv2d[NCHW] ---- OUTPUT[NCHW]
          Weight[CNHW] - (aten.dequantize) --/
          Bias ----------(aten.dequantize) -/

        [AFTER PASS]
          Input[NCHW] ---- aten.permute(NCHW_to_NHWC) ---- circle_custom.depthwise_conv2d[NHWC] ---- aten.permute(NHWC_to_NCHW) ---- OUTPUT[NCHW]
          Weight[CNHW] - (aten.dequantize) - aten.permute(CNHW_to_NHWC) ---/
          Bias ----------(aten.dequantize) -------------------------------/

    The same input/output permute wrapping is applied to
    aten.max_pool2d_with_indices, aten.avg_pool2d and aten.instance_norm
    (see `call` for the dispatch table).
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _permute_input_to_nhwc(graph, node, input_node) -> None:
        """Insert `permute(NCHW_to_NHWC)` right after `input_node` and feed it to `node`."""
        with graph.inserting_after(input_node):
            input_permute = graph.call_function(
                torch.ops.aten.permute.default,
                args=(input_node, [0, 2, 3, 1]),  # NCHW_to_NHWC
            )
            node.update_arg(node.args.index(input_node), input_permute)

    @staticmethod
    def _insert_legalized_op(graph, node, legalized_op):
        """
        Insert `legalized_op` followed by `permute(NHWC_to_NCHW)` before `node`.

        The new op reuses `node.args` / `node.kwargs`, so every argument of
        `node` must already point to its permuted producer when this is called.

        Returns:
            (legalized op node, output permute node). Meta is not set on the
            permute so that a later `replace_all_uses_with(...,
            propagate_meta=True)` can propagate the replaced node's meta.
        """
        with graph.inserting_before(node):
            circle_op = graph.call_function(
                legalized_op,
                args=node.args,
                kwargs=node.kwargs,
            )
            out_permute = graph.call_function(
                torch.ops.aten.permute.default,
                args=(circle_op, [0, 3, 1, 2]),  # NHWC_to_NCHW
            )
        return circle_op, out_permute

    def legalize_conv2d(self, exported_program, node) -> bool:
        """
        Replace an aten.conv2d node with circle_custom.(depthwise_)conv2d.

        Raises:
            NotYetSupportedError: If the input is not 4D, if `groups` is neither
                1 nor the number of input channels, or if `padding` is neither a
                list nor a str.
        """
        logger = logging.getLogger(__name__)
        graph = exported_program.graph_module.graph

        # conv2d (Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor
        # conv2d.padding (Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, str padding="valid", SymInt[2] dilation=1, SymInt groups=1) -> Tensor
        args = Conv2DArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        input_ = args.input
        padding = args.padding
        groups = args.groups

        input_shape = extract_shape(input_)
        if len(input_shape) != 4:
            raise NotYetSupportedError(
                f"Only support 4D input tensor: node's input shape: {input_shape}"
            )

        if not (groups == 1 or groups == input_shape[1]):
            raise NotYetSupportedError(
                f"Only support groups=1 or groups=input_channels: node's groups: {groups}, input channels: {input_shape[1]}"
            )
        # groups == 1 takes precedence, so a single-channel input with
        # groups == 1 is handled as a regular convolution.
        is_depthwise = not (groups == 1)

        # Choose the target op BEFORE touching the graph so that an unsupported
        # padding type is reported cleanly instead of failing half-way through
        # with an unbound variable.
        if isinstance(padding, list):
            if is_depthwise:
                legalized_op = torch.ops.circle_custom.depthwise_conv2d
            else:
                legalized_op = torch.ops.circle_custom.conv2d
        elif isinstance(padding, str):
            if is_depthwise:
                legalized_op = torch.ops.circle_custom.depthwise_conv2d.padding
            else:
                legalized_op = torch.ops.circle_custom.conv2d.padding
        else:
            raise NotYetSupportedError(
                f"Only support list or str padding: node's padding type: {type(padding)}"
            )

        # input permute
        self._permute_input_to_nhwc(graph, node, input_)

        # weight permute
        weight = get_permute_weight_input(args)
        if is_depthwise:
            # circle_custom.depthwise_conv2d
            perm = [1, 2, 3, 0]  # O1HW_to_1HWO
        else:
            # circle_custom.conv2d
            perm = [0, 2, 3, 1]  # OIHW_to_OHWI
        with graph.inserting_after(weight):
            weight_permute = graph.call_function(
                torch.ops.aten.permute.default,
                args=(weight, perm),
            )
            if args.weight.target in [
                torch.ops.quantized_decomposed.dequantize_per_channel.default,
                torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            ]:
                # The permute goes between the quantized weight and its
                # Dequantize node, so it is the Dequantize node that is rewired.
                dq = args.weight
                dq.update_arg(dq.args.index(weight), weight_permute)
                # Need to update dq.meta["val"] in FillMetaVal pass.
                del dq.meta["val"]
            else:
                node.update_arg(node.args.index(weight), weight_permute)

        circle_op, conv_out_permute = self._insert_legalized_op(
            graph, node, legalized_op
        )
        node.replace_all_uses_with(conv_out_permute, propagate_meta=True)

        logger.debug(f"{node.name} is replaced with {circle_op.name}")
        return True

    def legalize_instance_norm(self, exported_program, node) -> bool:
        """
        Replace an aten.instance_norm node with circle_custom.instance_norm.

        Raises:
            NotYetSupportedError: Unless `use_input_stats` is set, the running
                statistics are None, and both weight and bias are given.
        """
        logger = logging.getLogger(__name__)
        graph = exported_program.graph_module.graph

        # instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor
        args = InstanceNormArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        if not args.use_input_stats:
            raise NotYetSupportedError("Only support use_input_stats is True.")
        if args.running_mean is not None:
            raise NotYetSupportedError("Only support running_mean=None")
        if args.running_var is not None:
            raise NotYetSupportedError("Only support running_var=None")

        if args.weight is None:
            # TODO Support weight=None
            raise NotYetSupportedError("Only support weight is not None.")
        if args.bias is None:
            # TODO Support bias=None
            raise NotYetSupportedError("Only support bias is not None.")

        # input permute
        self._permute_input_to_nhwc(graph, node, args.input)

        # circle instnorm + output permute
        circle_instnorm, instnorm_out_permute = self._insert_legalized_op(
            graph, node, torch.ops.circle_custom.instance_norm
        )
        node.replace_all_uses_with(instnorm_out_permute, propagate_meta=True)

        logger.debug(f"{node.name} is replaced with {circle_instnorm.name}")
        return True

    def legalize_max_pool2d_with_indices(self, exported_program, node) -> bool:
        """
        Replace an aten.max_pool2d_with_indices node with circle_custom.maxpool2d.

        The single user of `node` is replaced by the output permute.

        Raises:
            NotYetSupportedError: If `ceil_mode` is set or `node` does not have
                exactly one user.
        """
        logger = logging.getLogger(__name__)
        graph = exported_program.graph_module.graph

        # max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
        args = MaxPool2dWithIndicesArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        if args.ceil_mode:
            raise NotYetSupportedError("Only support non-ceil model.")
        if len(node.users) != 1:
            raise NotYetSupportedError(
                "Only support maxpool2d with 'return_indices=False'."
            )

        # input permute
        self._permute_input_to_nhwc(graph, node, args.input)

        circle_maxpool2d, maxpool_out_permute = self._insert_legalized_op(
            graph, node, torch.ops.circle_custom.maxpool2d
        )
        # The custom op is consumed directly, so it is the sole user of `node`
        # (checked above) that gets replaced, and its meta that is propagated.
        get_item = next(iter(node.users))
        get_item.replace_all_uses_with(maxpool_out_permute, propagate_meta=True)

        logger.debug(f"{node.name} is replaced with {circle_maxpool2d.name}")
        return True

    def legalize_avg_pool2d(self, exported_program, node) -> bool:
        """
        Replace an aten.avg_pool2d node with circle_custom.avgpool2d.

        Raises:
            NotYetSupportedError: If `ceil_mode` is set, `count_include_pad` is
                False, or `divisor_override` is given.
        """
        logger = logging.getLogger(__name__)
        graph = exported_program.graph_module.graph

        # avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
        args = AvgPool2dArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
        if args.ceil_mode:
            raise NotYetSupportedError("Only support non-ceil model.")
        if not args.count_include_pad:
            # NOTE count_include_pad = False can be partially supported with SAME padding in circle.
            raise NotYetSupportedError(
                "For the case that the count_include_pad is False is not yet supported."
            )
        if args.divisor_override is not None:
            raise NotYetSupportedError(
                "For the case that the divisor_override is not None is not yet supported."
            )

        # input permute
        self._permute_input_to_nhwc(graph, node, args.input)

        circle_avgpool2d, avgpool_out_permute = self._insert_legalized_op(
            graph, node, torch.ops.circle_custom.avgpool2d
        )
        node.replace_all_uses_with(avgpool_out_permute, propagate_meta=True)

        logger.debug(f"{node.name} is replaced with {circle_avgpool2d.name}")
        return True

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Legalize every supported call_function node and report whether any was changed."""
        target_to_legalize_func = {
            torch.ops.aten.conv2d.default: self.legalize_conv2d,
            torch.ops.aten.conv2d.padding: self.legalize_conv2d,
            torch.ops.aten.max_pool2d_with_indices.default: self.legalize_max_pool2d_with_indices,
            torch.ops.aten.avg_pool2d.default: self.legalize_avg_pool2d,
            torch.ops.aten.instance_norm.default: self.legalize_instance_norm,
        }

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        for node in graph.nodes:
            if node.op != "call_function":
                continue

            if node.target not in target_to_legalize_func:
                continue
            modified |= target_to_legalize_func[node.target](exported_program, node)

        # The replaced aten nodes have no users left; drop them here.
        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)
@@ -0,0 +1,75 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch.fx
19
+ import torch
20
+ from torch.export import ExportedProgram
21
+
22
+ from tico.utils import logging
23
+ from tico.utils.passes import PassBase, PassResult
24
+ from tico.utils.trace_decorators import trace_graph_diff_on_pass
25
+
26
+
27
@trace_graph_diff_on_pass
class LowerPow2ToMul(PassBase):
    """
    Lower `aten.pow.Tensor_Scalar` nodes whose exponent equals 2 to `aten.mul.Tensor`.

    E.g. `Pow(in_, 2)` -> `Mul(in_, in_)`
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Rewrite every matching pow node and report whether the graph changed."""
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph

        changed = False
        for node in graph.nodes:
            is_pow_tensor_scalar = (
                node.op == "call_function"
                and node.target == torch.ops.aten.pow.Tensor_Scalar
            )
            if not is_pow_tensor_scalar:
                continue

            arg_count = len(node.args)
            assert arg_count == 2, arg_count
            base, exponent = node.args
            assert isinstance(base, torch.fx.Node), type(base)

            # Only squares are lowered; any other exponent is left untouched.
            if exponent != 2:
                continue

            # The multiplication uses the base for both operands.
            with graph.inserting_after(node):
                square = graph.call_function(
                    torch.ops.aten.mul.Tensor,
                    args=(base, base),
                    kwargs={},
                )

            node.replace_all_uses_with(square, propagate_meta=True)

            changed = True
            logger.debug(f"{node.name} is replaced with {square.name}")

        # The replaced pow nodes have no users left; drop them here.
        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(changed)