tico 0.1.0.dev250411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. tico/__init__.py +31 -0
  2. tico/config/__init__.py +4 -0
  3. tico/config/base.py +37 -0
  4. tico/config/factory.py +41 -0
  5. tico/config/v1.py +35 -0
  6. tico/experimental/__init__.py +1 -0
  7. tico/experimental/quantization/__init__.py +1 -0
  8. tico/experimental/quantization/algorithm/__init__.py +1 -0
  9. tico/experimental/quantization/algorithm/gptq/__init__.py +1 -0
  10. tico/experimental/quantization/algorithm/gptq/gptq.py +172 -0
  11. tico/experimental/quantization/algorithm/gptq/quant.py +153 -0
  12. tico/experimental/quantization/algorithm/gptq/quantizer.py +225 -0
  13. tico/experimental/quantization/algorithm/gptq/utils.py +65 -0
  14. tico/experimental/quantization/algorithm/pt2e/__init__.py +1 -0
  15. tico/experimental/quantization/algorithm/pt2e/annotation/__init__.py +1 -0
  16. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +215 -0
  17. tico/experimental/quantization/algorithm/pt2e/annotation/config.py +26 -0
  18. tico/experimental/quantization/algorithm/pt2e/annotation/op/__init__.py +21 -0
  19. tico/experimental/quantization/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +65 -0
  20. tico/experimental/quantization/algorithm/pt2e/annotation/op/add.py +57 -0
  21. tico/experimental/quantization/algorithm/pt2e/annotation/op/conv2d.py +92 -0
  22. tico/experimental/quantization/algorithm/pt2e/annotation/op/div.py +57 -0
  23. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +94 -0
  24. tico/experimental/quantization/algorithm/pt2e/annotation/op/mean.py +53 -0
  25. tico/experimental/quantization/algorithm/pt2e/annotation/op/mul.py +57 -0
  26. tico/experimental/quantization/algorithm/pt2e/annotation/op/relu6.py +53 -0
  27. tico/experimental/quantization/algorithm/pt2e/annotation/op/rsqrt.py +53 -0
  28. tico/experimental/quantization/algorithm/pt2e/annotation/op/sub.py +57 -0
  29. tico/experimental/quantization/algorithm/pt2e/annotation/spec.py +47 -0
  30. tico/experimental/quantization/algorithm/pt2e/annotation/utils.py +88 -0
  31. tico/experimental/quantization/algorithm/pt2e/quantizer.py +78 -0
  32. tico/experimental/quantization/algorithm/pt2e/transformation/__init__.py +1 -0
  33. tico/experimental/quantization/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +58 -0
  34. tico/experimental/quantization/algorithm/pt2e/utils.py +138 -0
  35. tico/experimental/quantization/algorithm/smoothquant/__init__.py +1 -0
  36. tico/experimental/quantization/algorithm/smoothquant/observer.py +78 -0
  37. tico/experimental/quantization/algorithm/smoothquant/quantizer.py +81 -0
  38. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +164 -0
  39. tico/experimental/quantization/config.py +68 -0
  40. tico/experimental/quantization/evaluation/__init__.py +1 -0
  41. tico/experimental/quantization/evaluation/backend.py +20 -0
  42. tico/experimental/quantization/evaluation/evaluate.py +223 -0
  43. tico/experimental/quantization/evaluation/executor/__init__.py +1 -0
  44. tico/experimental/quantization/evaluation/executor/backend_executor.py +54 -0
  45. tico/experimental/quantization/evaluation/executor/circle_executor.py +75 -0
  46. tico/experimental/quantization/evaluation/executor/triv24_executor.py +128 -0
  47. tico/experimental/quantization/evaluation/metric.py +109 -0
  48. tico/experimental/quantization/evaluation/utils.py +185 -0
  49. tico/experimental/quantization/passes/__init__.py +1 -0
  50. tico/experimental/quantization/passes/fold_quant_ops.py +97 -0
  51. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +289 -0
  52. tico/experimental/quantization/passes/propagate_qparam_backward.py +91 -0
  53. tico/experimental/quantization/passes/propagate_qparam_forward.py +141 -0
  54. tico/experimental/quantization/passes/remove_weight_dequant_op.py +168 -0
  55. tico/experimental/quantization/public_interface.py +108 -0
  56. tico/experimental/quantization/quantizer.py +71 -0
  57. tico/interpreter/__init__.py +1 -0
  58. tico/interpreter/infer.py +116 -0
  59. tico/interpreter/interpreter.py +93 -0
  60. tico/passes/__init__.py +1 -0
  61. tico/passes/cast_aten_where_arg_type.py +185 -0
  62. tico/passes/cast_mixed_type_args.py +186 -0
  63. tico/passes/const_prop_pass.py +307 -0
  64. tico/passes/convert_conv1d_to_conv2d.py +151 -0
  65. tico/passes/convert_layout_op_to_reshape.py +84 -0
  66. tico/passes/convert_repeat_to_expand_copy.py +90 -0
  67. tico/passes/convert_to_relu6.py +180 -0
  68. tico/passes/decompose_addmm.py +127 -0
  69. tico/passes/decompose_batch_norm.py +198 -0
  70. tico/passes/decompose_fake_quantize.py +126 -0
  71. tico/passes/decompose_fake_quantize_tensor_qparams.py +270 -0
  72. tico/passes/decompose_group_norm.py +258 -0
  73. tico/passes/decompose_grouped_conv2d.py +202 -0
  74. tico/passes/decompose_slice_scatter.py +167 -0
  75. tico/passes/extract_dtype_kwargs.py +121 -0
  76. tico/passes/fill_meta_val.py +57 -0
  77. tico/passes/fuse_redundant_reshape_to_mean.py +102 -0
  78. tico/passes/legalize_causal_mask_value.py +113 -0
  79. tico/passes/legalize_predefined_layout_operators.py +383 -0
  80. tico/passes/lower_pow2_to_mul.py +75 -0
  81. tico/passes/lower_to_resize_nearest_neighbor.py +249 -0
  82. tico/passes/lower_to_slice.py +112 -0
  83. tico/passes/merge_consecutive_cat.py +82 -0
  84. tico/passes/ops.py +75 -0
  85. tico/passes/remove_nop.py +85 -0
  86. tico/passes/remove_redundant_assert_nodes.py +50 -0
  87. tico/passes/remove_redundant_expand.py +70 -0
  88. tico/passes/remove_redundant_permute.py +102 -0
  89. tico/passes/remove_redundant_reshape.py +431 -0
  90. tico/passes/remove_redundant_slice.py +64 -0
  91. tico/passes/remove_redundant_to_copy.py +84 -0
  92. tico/passes/restore_linear.py +113 -0
  93. tico/passes/segment_index_select.py +143 -0
  94. tico/pt2_to_circle.py +101 -0
  95. tico/serialize/__init__.py +1 -0
  96. tico/serialize/circle_graph.py +264 -0
  97. tico/serialize/circle_mapping.py +177 -0
  98. tico/serialize/circle_serializer.py +232 -0
  99. tico/serialize/operators/__init__.py +28 -0
  100. tico/serialize/operators/hashable_opcode.py +43 -0
  101. tico/serialize/operators/node_visitor.py +80 -0
  102. tico/serialize/operators/op_add.py +69 -0
  103. tico/serialize/operators/op_alias_copy.py +64 -0
  104. tico/serialize/operators/op_any.py +142 -0
  105. tico/serialize/operators/op_arange_start_step.py +61 -0
  106. tico/serialize/operators/op_argmax.py +62 -0
  107. tico/serialize/operators/op_avg_pool2d.py +112 -0
  108. tico/serialize/operators/op_bmm.py +62 -0
  109. tico/serialize/operators/op_cat.py +66 -0
  110. tico/serialize/operators/op_clamp.py +123 -0
  111. tico/serialize/operators/op_clone.py +71 -0
  112. tico/serialize/operators/op_constant_pad_nd.py +72 -0
  113. tico/serialize/operators/op_conv2d.py +181 -0
  114. tico/serialize/operators/op_copy.py +162 -0
  115. tico/serialize/operators/op_cos.py +59 -0
  116. tico/serialize/operators/op_cumsum.py +92 -0
  117. tico/serialize/operators/op_depthwise_conv2d.py +198 -0
  118. tico/serialize/operators/op_dequantize_per_channel.py +82 -0
  119. tico/serialize/operators/op_dequantize_per_tensor.py +64 -0
  120. tico/serialize/operators/op_div.py +62 -0
  121. tico/serialize/operators/op_embedding.py +60 -0
  122. tico/serialize/operators/op_eq.py +64 -0
  123. tico/serialize/operators/op_exp.py +60 -0
  124. tico/serialize/operators/op_expand.py +91 -0
  125. tico/serialize/operators/op_full.py +48 -0
  126. tico/serialize/operators/op_full_like.py +55 -0
  127. tico/serialize/operators/op_ge.py +54 -0
  128. tico/serialize/operators/op_gelu.py +59 -0
  129. tico/serialize/operators/op_gt.py +54 -0
  130. tico/serialize/operators/op_index.py +82 -0
  131. tico/serialize/operators/op_index_select.py +64 -0
  132. tico/serialize/operators/op_instance_norm.py +91 -0
  133. tico/serialize/operators/op_linear.py +70 -0
  134. tico/serialize/operators/op_log.py +53 -0
  135. tico/serialize/operators/op_log1p.py +83 -0
  136. tico/serialize/operators/op_logical_and.py +63 -0
  137. tico/serialize/operators/op_logical_not.py +62 -0
  138. tico/serialize/operators/op_lt.py +61 -0
  139. tico/serialize/operators/op_max_pool2d_with_indices.py +140 -0
  140. tico/serialize/operators/op_maximum.py +53 -0
  141. tico/serialize/operators/op_mean.py +66 -0
  142. tico/serialize/operators/op_minimum.py +53 -0
  143. tico/serialize/operators/op_mm.py +174 -0
  144. tico/serialize/operators/op_mul.py +99 -0
  145. tico/serialize/operators/op_ne.py +54 -0
  146. tico/serialize/operators/op_neg.py +59 -0
  147. tico/serialize/operators/op_permute.py +65 -0
  148. tico/serialize/operators/op_pow.py +138 -0
  149. tico/serialize/operators/op_prelu.py +54 -0
  150. tico/serialize/operators/op_quantize_per_tensor.py +79 -0
  151. tico/serialize/operators/op_reciprocal.py +64 -0
  152. tico/serialize/operators/op_relu.py +53 -0
  153. tico/serialize/operators/op_relu6.py +52 -0
  154. tico/serialize/operators/op_repeat.py +99 -0
  155. tico/serialize/operators/op_reshape.py +73 -0
  156. tico/serialize/operators/op_resize_nearest_neighbor.py +70 -0
  157. tico/serialize/operators/op_rsqrt.py +53 -0
  158. tico/serialize/operators/op_scalar_tensor.py +51 -0
  159. tico/serialize/operators/op_select_copy.py +65 -0
  160. tico/serialize/operators/op_sigmoid.py +56 -0
  161. tico/serialize/operators/op_sin.py +53 -0
  162. tico/serialize/operators/op_slice.py +155 -0
  163. tico/serialize/operators/op_softmax.py +100 -0
  164. tico/serialize/operators/op_split_with_sizes.py +96 -0
  165. tico/serialize/operators/op_sqrt.py +55 -0
  166. tico/serialize/operators/op_squeeze.py +73 -0
  167. tico/serialize/operators/op_sub.py +71 -0
  168. tico/serialize/operators/op_sum.py +63 -0
  169. tico/serialize/operators/op_tanh.py +54 -0
  170. tico/serialize/operators/op_to_copy.py +105 -0
  171. tico/serialize/operators/op_unsqueeze.py +66 -0
  172. tico/serialize/operators/op_view.py +74 -0
  173. tico/serialize/operators/op_where.py +82 -0
  174. tico/serialize/operators/utils.py +51 -0
  175. tico/serialize/pack.py +35 -0
  176. tico/serialize/quant_param.py +42 -0
  177. tico/utils/__init__.py +1 -0
  178. tico/utils/convert.py +292 -0
  179. tico/utils/define.py +35 -0
  180. tico/utils/diff_graph.py +181 -0
  181. tico/utils/errors.py +35 -0
  182. tico/utils/graph.py +200 -0
  183. tico/utils/logging.py +45 -0
  184. tico/utils/model.py +37 -0
  185. tico/utils/padding.py +47 -0
  186. tico/utils/passes.py +76 -0
  187. tico/utils/register_custom_op.py +562 -0
  188. tico/utils/trace_decorators.py +101 -0
  189. tico/utils/utils.py +314 -0
  190. tico/utils/validate_args_kwargs.py +1114 -0
  191. tico-0.1.0.dev250411.dist-info/LICENSE +241 -0
  192. tico-0.1.0.dev250411.dist-info/METADATA +17 -0
  193. tico-0.1.0.dev250411.dist-info/RECORD +196 -0
  194. tico-0.1.0.dev250411.dist-info/WHEEL +5 -0
  195. tico-0.1.0.dev250411.dist-info/entry_points.txt +3 -0
  196. tico-0.1.0.dev250411.dist-info/top_level.txt +1 -0
tico/utils/register_custom_op.py
@@ -0,0 +1,562 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import List, Optional
+
+ import torch
+ from torch._subclasses.fake_tensor import FakeTensor
+ from torch.library import custom_op, register_fake
+
+
+ # Note that the operator assumes the input tensor is in NHWC format.
+ def CircleResizeNearestNeighbor():
+     @custom_op("circle_custom::resize_nearest_neighbor", mutates_args=())
+     def resize_nearest_neighbor(input_: torch.Tensor, size: List[int]) -> torch.Tensor:
+         # `size` holds the target spatial size (H, W); the input is NHWC.
+         input_size = input_.size()
+         H = input_size[1]
+         W = input_size[2]
+         H_scale_factor = size[0] / H
+         W_scale_factor = size[1] / W
+         if H_scale_factor != W_scale_factor:
+             raise RuntimeError("The scale factors of H and W must be the same.")
+         # `interpolate` expects NCHW, so permute in and out around the call.
+         NHWC_to_NCHW = [0, 3, 1, 2]
+         NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
+         NCHW_output = torch.nn.functional.interpolate(
+             NCHW_input, scale_factor=H_scale_factor, mode="nearest"
+         )
+         NCHW_to_NHWC = [0, 2, 3, 1]
+         return torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
+
+     @register_fake("circle_custom::resize_nearest_neighbor")
+     def _(input_: torch.Tensor, size: List[int]):
+         shape = list(input_.size())
+         new_shape = [shape[0]] + list(size) + [shape[3]]
+         result = torch.empty(new_shape, dtype=input_.dtype)
+         return result
+
+
+ def CircleConv2d():
+     """
+     Note that this op follows the input spec of `aten.conv2d.default`, whose number
+     of arguments satisfies the (2 <= len(node.args) <= 7) condition.
+
+     [RESTRICTION]
+     Therefore, I tried to define the spec of conv2d as conv2d(input, weight, *args).
+     But custom operators in torch do not support positional-only args, so I set
+     them to None by default.
+     """
+
+     @custom_op("circle_custom::conv2d", mutates_args=())
+     def conv2d(
+         input_: torch.Tensor,
+         weight: torch.Tensor,
+         bias: Optional[torch.Tensor] = None,
+         stride: Optional[List[int]] = None,
+         padding: Optional[List[int]] = None,
+         dilation: Optional[List[int]] = None,
+         groups: Optional[int] = None,
+     ) -> torch.Tensor:
+         """
+         Set default values.
+
+         Custom operators support only a limited set of types for default values.
+         So, let's set the arguments to None in the input spec, and then replace
+         them with the actual defaults here.
+         https://github.com/pytorch/pytorch/blob/6b05aafc/torch/_library/infer_schema.py#L131-L144
+         """
+         stride = [1, 1] if stride is None else stride
+         padding = [0, 0] if padding is None else padding
+         dilation = [1, 1] if dilation is None else dilation
+         groups = 1 if groups is None else groups
+
+         if groups != 1:
+             raise RuntimeError(
+                 f"CircleConv2d only supports groups=1 (got {groups})."
+             )
+
+         NHWC_to_NCHW = [0, 3, 1, 2]
+         OHWI_to_OIHW = [0, 3, 1, 2]
+         NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
+         OIHW_weight = torch.ops.aten.permute.default(weight, OHWI_to_OIHW)
+
+         args = [NCHW_input, OIHW_weight, bias, stride, padding, dilation, groups]
+         NCHW_output = torch.ops.aten.conv2d.default(*args)
+         NCHW_to_NHWC = [0, 2, 3, 1]
+         NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
+
+         return NHWC_output
+
+     @register_fake("circle_custom::conv2d")
+     def _(
+         input_: torch.Tensor,
+         weight: torch.Tensor,
+         bias: Optional[torch.Tensor] = None,
+         stride: Optional[List[int]] = None,
+         padding: Optional[List[int]] = None,
+         dilation: Optional[List[int]] = None,
+         groups: Optional[int] = None,
+     ):
+         """
+         Set default values.
+
+         Custom operators support only a limited set of types for default values.
+         So, let's set the arguments to None in the input spec, and then replace
+         them with the actual defaults here.
+         https://github.com/pytorch/pytorch/blob/6b05aafc/torch/_library/infer_schema.py#L131-L144
+         """
+         stride = [1, 1] if stride is None else stride
+         padding = [0, 0] if padding is None else padding
+         dilation = [1, 1] if dilation is None else dilation
+         groups = 1 if groups is None else groups
+         if groups != 1:
+             raise RuntimeError(
+                 f"CircleConv2d only supports groups=1 (got {groups})."
+             )
+
+         NHWC_to_NCHW = [0, 3, 1, 2]
+         OHWI_to_OIHW = [0, 3, 1, 2]
+         NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
+         OIHW_weight = torch.ops.aten.permute.default(weight, OHWI_to_OIHW)
+
+         args = [NCHW_input, OIHW_weight, bias, stride, padding, dilation, groups]
+         NCHW_output = torch.ops.aten.conv2d.default(*args)
+         NCHW_to_NHWC = [0, 2, 3, 1]
+         NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
+
+         return NHWC_output
+
+
+ def CircleConv2dPadding():
+     """
+     Almost the same as `CircleConv2d`, except that the `padding` argument is a string.
+
+     Q) Why create another custom op rather than make `CircleConv2d` cover multiple padding types?
+     A) `padding` with the Optional[Union[List[int], str]] type is not allowed in torch.
+     """
+
+     @custom_op("circle_custom::conv2d.padding", mutates_args=())
+     def conv2d_padding(
+         input_: torch.Tensor,
+         weight: torch.Tensor,
+         bias: Optional[torch.Tensor] = None,
+         stride: Optional[List[int]] = None,
+         padding: Optional[str] = None,
+         dilation: Optional[List[int]] = None,
+         groups: Optional[int] = None,
+     ) -> torch.Tensor:
+         """
+         Set default values.
+
+         Custom operators support only a limited set of types for default values.
+         So, let's set the arguments to None in the input spec, and then replace
+         them with the actual defaults here.
+         https://github.com/pytorch/pytorch/blob/6b05aafc/torch/_library/infer_schema.py#L131-L144
+         """
+         stride = [1, 1] if stride is None else stride
+         padding = "valid" if padding is None else padding
+         dilation = [1, 1] if dilation is None else dilation
+         groups = 1 if groups is None else groups
+         if groups != 1:
+             raise RuntimeError(
+                 f"CircleConv2dPadding only supports groups=1 (got {groups})."
+             )
+
+         NHWC_to_NCHW = [0, 3, 1, 2]
+         OHWI_to_OIHW = [0, 3, 1, 2]
+         NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
+         OIHW_weight = torch.ops.aten.permute.default(weight, OHWI_to_OIHW)
+
+         args = [NCHW_input, OIHW_weight, bias, stride, padding, dilation, groups]
+         NCHW_output = torch.ops.aten.conv2d.padding(*args)
+         NCHW_to_NHWC = [0, 2, 3, 1]
+         NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
+
+         return NHWC_output
+
+     @register_fake("circle_custom::conv2d.padding")
+     def _(
+         input_: torch.Tensor,
+         weight: torch.Tensor,
+         bias: Optional[torch.Tensor] = None,
+         stride: Optional[List[int]] = None,
+         padding: Optional[str] = None,
+         dilation: Optional[List[int]] = None,
+         groups: Optional[int] = None,
+     ):
+         """
+         Set default values.
+
+         Custom operators support only a limited set of types for default values.
+         So, let's set the arguments to None in the input spec, and then replace
+         them with the actual defaults here.
+         https://github.com/pytorch/pytorch/blob/6b05aafc/torch/_library/infer_schema.py#L131-L144
+         """
+         stride = [1, 1] if stride is None else stride
+         padding = "valid" if padding is None else padding
+         dilation = [1, 1] if dilation is None else dilation
+         groups = 1 if groups is None else groups
+         if groups != 1:
+             raise RuntimeError(
+                 f"CircleConv2dPadding only supports groups=1 (got {groups})."
+             )
+
+         NHWC_to_NCHW = [0, 3, 1, 2]
+         OHWI_to_OIHW = [0, 3, 1, 2]
+         NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
+         OIHW_weight = torch.ops.aten.permute.default(weight, OHWI_to_OIHW)
+
+         args = [NCHW_input, OIHW_weight, bias, stride, padding, dilation, groups]
+         NCHW_output = torch.ops.aten.conv2d.padding(*args)
+         NCHW_to_NHWC = [0, 2, 3, 1]
+         NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
+
+         return NHWC_output
+
+
+ def CircleDepthwiseConv2d():
+     """
+     Note that this op follows the input spec of `aten.conv2d.default`, whose number
+     of arguments satisfies the (2 <= len(node.args) <= 7) condition.
+
+     [RESTRICTION]
+     Therefore, I tried to define the spec of conv2d as conv2d(input, weight, *args).
+     But custom operators in torch do not support positional-only args, so I set
+     them to None by default.
+     """
+
+     @custom_op("circle_custom::depthwise_conv2d", mutates_args=())
+     def depthwise_conv2d(
+         input_: torch.Tensor,
+         weight: torch.Tensor,
+         bias: Optional[torch.Tensor] = None,
+         stride: Optional[List[int]] = None,
+         padding: Optional[List[int]] = None,
+         dilation: Optional[List[int]] = None,
+         groups: Optional[int] = None,
+     ) -> torch.Tensor:
+         """
+         Set default values.
+
+         Custom operators support only a limited set of types for default values.
+         So, let's set the arguments to None in the input spec, and then replace
+         them with the actual defaults here.
+         https://github.com/pytorch/pytorch/blob/6b05aafc/torch/_library/infer_schema.py#L131-L144
+         """
+         stride = [1, 1] if stride is None else stride
+         padding = [0, 0] if padding is None else padding
+         dilation = [1, 1] if dilation is None else dilation
+
+         # Depthwise convolution requires groups > 1.
+         assert groups and groups > 1
+
+         NHWC_to_NCHW = [0, 3, 1, 2]
+         OHW1_to_1OHW = [3, 0, 1, 2]
+         NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
+         _1OHW_weight = torch.ops.aten.permute.default(weight, OHW1_to_1OHW)
+
+         args = [NCHW_input, _1OHW_weight, bias, stride, padding, dilation, groups]
+         NCHW_output = torch.ops.aten.conv2d.default(*args)
+         NCHW_to_NHWC = [0, 2, 3, 1]
+         NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
+
+         return NHWC_output
+
+     @register_fake("circle_custom::depthwise_conv2d")
+     def _(
+         input_: torch.Tensor,
+         weight: torch.Tensor,
+         bias: Optional[torch.Tensor] = None,
+         stride: Optional[List[int]] = None,
+         padding: Optional[List[int]] = None,
+         dilation: Optional[List[int]] = None,
+         groups: Optional[int] = None,
+     ):
+         """
+         Set default values.
+
+         Custom operators support only a limited set of types for default values.
+         So, let's set the arguments to None in the input spec, and then replace
+         them with the actual defaults here.
+         https://github.com/pytorch/pytorch/blob/6b05aafc/torch/_library/infer_schema.py#L131-L144
+         """
+         stride = [1, 1] if stride is None else stride
+         padding = [0, 0] if padding is None else padding
+         dilation = [1, 1] if dilation is None else dilation
+
+         # Depthwise convolution requires groups > 1.
+         assert groups and groups > 1
+
+         NHWC_to_NCHW = [0, 3, 1, 2]
+         OHW1_to_1OHW = [3, 0, 1, 2]
+         NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
+         _1OHW_weight = torch.ops.aten.permute.default(weight, OHW1_to_1OHW)
+
+         args = [NCHW_input, _1OHW_weight, bias, stride, padding, dilation, groups]
+         NCHW_output = torch.ops.aten.conv2d.default(*args)
+         NCHW_to_NHWC = [0, 2, 3, 1]
+         NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
+
+         return NHWC_output
+
+
+ def CircleDepthwiseConv2dPadding():
+     @custom_op("circle_custom::depthwise_conv2d.padding", mutates_args=())
+     def depthwise_conv2d_padding(
+         input_: torch.Tensor,
+         weight: torch.Tensor,
+         bias: Optional[torch.Tensor] = None,
+         stride: Optional[List[int]] = None,
+         padding: Optional[str] = None,
+         dilation: Optional[List[int]] = None,
+         groups: Optional[int] = None,
+     ) -> torch.Tensor:
+         """
+         Set default values.
+
+         Custom operators support only a limited set of types for default values.
+         So, let's set the arguments to None in the input spec, and then replace
+         them with the actual defaults here.
+         https://github.com/pytorch/pytorch/blob/6b05aafc/torch/_library/infer_schema.py#L131-L144
+         """
+         stride = [1, 1] if stride is None else stride
+         padding = "valid" if padding is None else padding
+         dilation = [1, 1] if dilation is None else dilation
+
+         # Depthwise convolution requires groups > 1.
+         assert groups and groups > 1
+
+         NHWC_to_NCHW = [0, 3, 1, 2]
+         OHW1_to_1OHW = [3, 0, 1, 2]
+         NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
+         _1OHW_weight = torch.ops.aten.permute.default(weight, OHW1_to_1OHW)
+
+         args = [NCHW_input, _1OHW_weight, bias, stride, padding, dilation, groups]
+         NCHW_output = torch.ops.aten.conv2d.padding(*args)
+         NCHW_to_NHWC = [0, 2, 3, 1]
+         NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
+
+         return NHWC_output
+
+     @register_fake("circle_custom::depthwise_conv2d.padding")
+     def _(
+         input_: torch.Tensor,
+         weight: torch.Tensor,
+         bias: Optional[torch.Tensor] = None,
+         stride: Optional[List[int]] = None,
+         padding: Optional[str] = None,
+         dilation: Optional[List[int]] = None,
+         groups: Optional[int] = None,
+     ):
+         """
+         Set default values.
+
+         Custom operators support only a limited set of types for default values.
+         So, let's set the arguments to None in the input spec, and then replace
+         them with the actual defaults here.
+         https://github.com/pytorch/pytorch/blob/6b05aafc/torch/_library/infer_schema.py#L131-L144
+         """
+         stride = [1, 1] if stride is None else stride
+         padding = "valid" if padding is None else padding
+         dilation = [1, 1] if dilation is None else dilation
+
+         # Depthwise convolution requires groups > 1.
+         assert groups and groups > 1
+
+         NHWC_to_NCHW = [0, 3, 1, 2]
+         OHW1_to_1OHW = [3, 0, 1, 2]
+         NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
+         _1OHW_weight = torch.ops.aten.permute.default(weight, OHW1_to_1OHW)
+
+         args = [NCHW_input, _1OHW_weight, bias, stride, padding, dilation, groups]
+         NCHW_output = torch.ops.aten.conv2d.padding(*args)
+         NCHW_to_NHWC = [0, 2, 3, 1]
+         NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
+
+         return NHWC_output
+
+
+ def CircleMaxPool2D():
+     """
+     Note that this op follows the input spec of `aten.max_pool2d_with_indices.default`,
+     whose number of arguments satisfies the (3 <= len(node.args) <= 6) condition.
+
+     [RESTRICTION]
+     Custom operators in torch do not support positional-only args, so I set
+     them to None by default.
+     """
+
+     @custom_op("circle_custom::maxpool2d", mutates_args=())
+     def maxpool2d(
+         input_: torch.Tensor,
+         kernel_size: List[int],
+         stride: List[int],
+         padding: Optional[List[int]] = None,
+         dilation: Optional[List[int]] = None,
+         ceil_mode: Optional[bool] = None,
+     ) -> torch.Tensor:
+         """
+         Set default values.
+
+         Custom operators support only a limited set of types for default values.
+         So, let's set the arguments to None in the input spec, and then replace
+         them with the actual defaults here.
+         https://github.com/pytorch/pytorch/blob/6b05aafc/torch/_library/infer_schema.py#L131-L144
+         """
+         padding = [0, 0] if padding is None else padding
+         dilation = [1, 1] if dilation is None else dilation
+         ceil_mode = False if ceil_mode is None else ceil_mode
+
+         NHWC_to_NCHW = [0, 3, 1, 2]
+         NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
+
+         args = [NCHW_input, kernel_size, stride, padding, dilation, ceil_mode]
+         NCHW_output = torch.ops.aten.max_pool2d_with_indices.default(*args)
+         NCHW_to_NHWC = [0, 2, 3, 1]
+         # Use only the first output (the pooled values); the indices are dropped.
+         NHWC_output = torch.ops.aten.permute.default(NCHW_output[0], NCHW_to_NHWC)
+
+         return NHWC_output
+
+     @register_fake("circle_custom::maxpool2d")
+     def _(
+         input_: torch.Tensor,
+         kernel_size: List[int],
+         stride: List[int],
+         padding: Optional[List[int]] = None,
+         dilation: Optional[List[int]] = None,
+         ceil_mode: Optional[bool] = None,
+     ):
+         """
+         Set default values.
+
+         Custom operators support only a limited set of types for default values.
+         So, let's set the arguments to None in the input spec, and then replace
+         them with the actual defaults here.
+         https://github.com/pytorch/pytorch/blob/6b05aafc/torch/_library/infer_schema.py#L131-L144
+         """
+         padding = [0, 0] if padding is None else padding
+         dilation = [1, 1] if dilation is None else dilation
+         ceil_mode = False if ceil_mode is None else ceil_mode
+
+         NHWC_to_NCHW = [0, 3, 1, 2]
+         NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
+
+         args = [NCHW_input, kernel_size, stride, padding, dilation, ceil_mode]
+         NCHW_output = torch.ops.aten.max_pool2d_with_indices.default(*args)
+         NCHW_to_NHWC = [0, 2, 3, 1]
+         # Use only the first output (the pooled values); the indices are dropped.
+         NHWC_output = torch.ops.aten.permute.default(NCHW_output[0], NCHW_to_NHWC)
+
+         return NHWC_output
+
+
+ def CircleAvgPool2D():
+     @custom_op("circle_custom::avgpool2d", mutates_args=())
+     def avgpool2d(
+         input_: torch.Tensor,
+         kernel_size: List[int],
+         stride: List[int],
+         padding: Optional[List[int]] = None,
+         ceil_mode: Optional[bool] = None,
+         count_include_pad: Optional[bool] = None,
+         divisor_override: Optional[int] = None,
+     ) -> torch.Tensor:
+         padding = [0, 0] if padding is None else padding
+         ceil_mode = False if ceil_mode is None else ceil_mode
+         count_include_pad = True if count_include_pad is None else count_include_pad
+         # `divisor_override` needs no replacement; None is already its default.
+
+         NHWC_to_NCHW = [0, 3, 1, 2]
+         NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
+
+         args = [
+             NCHW_input,
+             kernel_size,
+             stride,
+             padding,
+             ceil_mode,
+             count_include_pad,
+             divisor_override,
+         ]
+         NCHW_output = torch.ops.aten.avg_pool2d.default(*args)
+         NCHW_to_NHWC = [0, 2, 3, 1]
+         NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
+
+         return NHWC_output
+
+     @register_fake("circle_custom::avgpool2d")
+     def _(
+         input_: torch.Tensor,
+         kernel_size: List[int],
+         stride: List[int],
+         padding: Optional[List[int]] = None,
+         ceil_mode: Optional[bool] = None,
+         count_include_pad: Optional[bool] = None,
+         divisor_override: Optional[int] = None,
+     ):
+         padding = [0, 0] if padding is None else padding
+         ceil_mode = False if ceil_mode is None else ceil_mode
+         count_include_pad = True if count_include_pad is None else count_include_pad
+         # `divisor_override` needs no replacement; None is already its default.
+
+         NHWC_to_NCHW = [0, 3, 1, 2]
+         NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
+
+         args = [
+             NCHW_input,
+             kernel_size,
+             stride,
+             padding,
+             ceil_mode,
+             count_include_pad,
+             divisor_override,
+         ]
+         NCHW_output = torch.ops.aten.avg_pool2d.default(*args)
+         NCHW_to_NHWC = [0, 2, 3, 1]
+         NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
+
+         return NHWC_output
+
+
+ def CircleInstanceNorm():
+     @custom_op("circle_custom::instance_norm", mutates_args=())
+     def instance_norm(
+         input_: torch.Tensor,
+         weight: Optional[torch.Tensor] = None,
+         bias: Optional[torch.Tensor] = None,
+         running_mean: Optional[torch.Tensor] = None,
+         running_var: Optional[torch.Tensor] = None,
+         use_input_stats: bool = False,
+         momentum: float = 0.1,
+         eps: float = 1e-05,
+         cudnn_enabled: bool = False,
+     ) -> torch.Tensor:
+         NHWC_to_NCHW = [0, 3, 1, 2]
+         NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
+
+         args = [NCHW_input, weight, bias, None, None, False, momentum, eps, False]
+         NCHW_output = torch.ops.aten.instance_norm.default(*args)
+         NCHW_to_NHWC = [0, 2, 3, 1]
+         NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
+
+         return NHWC_output
+
+     @register_fake("circle_custom::instance_norm")
+     def _(
+         input: FakeTensor,
+         weight: Optional[FakeTensor] = None,
+         bias: Optional[FakeTensor] = None,
+         running_mean: Optional[FakeTensor] = None,
+         running_var: Optional[FakeTensor] = None,
+         use_input_stats: bool = False,
+         momentum: float = 0.1,
+         eps: float = 1e-05,
+         cudnn_enabled: bool = False,
+     ):
+         # shape is preserved
+         return input.new_empty(input.size())
+
+
+ # Add custom ops to the torch namespace
+ def RegisterOps():
+     CircleResizeNearestNeighbor()
+     CircleDepthwiseConv2d()
+     CircleDepthwiseConv2dPadding()
+     CircleConv2d()
+     CircleConv2dPadding()
+     CircleMaxPool2D()
+     CircleAvgPool2D()
+     CircleInstanceNorm()
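
Usage note (editor's sketch, not part of the package diff): once `RegisterOps()` has been called, the ops above are reachable under the `torch.ops.circle_custom` namespace and operate on NHWC tensors. A minimal example, assuming default stride/padding/dilation and groups=1:

    import torch

    from tico.utils.register_custom_op import RegisterOps

    RegisterOps()  # registers the circle_custom ops into torch.ops

    x = torch.randn(1, 8, 8, 3)   # NHWC input: N=1, H=8, W=8, C=3
    w = torch.randn(16, 3, 3, 3)  # OHWI weight: O=16, kH=3, kW=3, I=3
    y = torch.ops.circle_custom.conv2d(x, w)
    print(y.shape)  # torch.Size([1, 6, 6, 16]) -- NHWC output
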
tico/utils/trace_decorators.py
@@ -0,0 +1,101 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from functools import wraps
+
+ import torch
+ from torch.export import ExportedProgram
+
+ from tico.utils.diff_graph import capture, capture_const, log, log_const
+ from tico.utils.passes import PassBase
+
+
+ def trace_const_diff_on_pass(cls):
+     """Decorator for a PassBase subclass to trace the const diff"""
+
+     assert issubclass(cls, PassBase), type(cls)
+
+     def _call_traced(fn):
+         @wraps(fn)
+         def wrapped(*args):
+             _, exported_program = args
+             assert isinstance(exported_program, ExportedProgram)
+             graph_module = exported_program.graph_module
+             assert isinstance(graph_module, torch.fx.GraphModule), type(graph_module)
+             capture_const(exported_program)
+             ret = fn(*args)
+             log_const(exported_program, title=str(cls.__name__), recapture=False)
+             return ret
+
+         return wrapped
+
+     # Replace the `call` method with its traced version.
+     for key, val in vars(cls).items():
+         if key == "call":
+             setattr(cls, key, _call_traced(val))
+     return cls
+
+
+ def trace_graph_diff_on_pass(cls):
+     """Decorator for a PassBase subclass to trace the graph diff"""
+
+     assert issubclass(cls, PassBase), type(cls)
+
+     def _call_traced(fn):
+         @wraps(fn)
+         def wrapped(*args):
+             _, exported_program = args
+             assert isinstance(exported_program, ExportedProgram)
+             graph_module = exported_program.graph_module
+             assert isinstance(graph_module, torch.fx.GraphModule), type(graph_module)
+             capture(graph_module.graph)
+             ret = fn(*args)
+             log(graph_module.graph, title=str(cls.__name__), recapture=False)
+             return ret
+
+         return wrapped
+
+     # Replace the `call` method with its traced version.
+     for key, val in vars(cls).items():
+         if key == "call":
+             setattr(cls, key, _call_traced(val))
+     return cls
+
+
+ def trace_const_diff_on_func(fn):
+     """Decorator for a function to trace the const diff"""
+
+     @wraps(fn)
+     def wrapped(ep: torch.export.ExportedProgram):
+         assert isinstance(ep, torch.export.ExportedProgram)
+         capture_const(ep)
+         ret = fn(ep)
+         log_const(ret, title=str(fn.__name__), recapture=False)
+         return ret
+
+     return wrapped
+
+
+ def trace_graph_diff_on_func(fn):
+     """Decorator for a function to trace the graph diff"""
+
+     @wraps(fn)
+     def wrapped(ep: torch.export.ExportedProgram):
+         assert isinstance(ep, torch.export.ExportedProgram)
+         capture(ep.graph)
+         ret = fn(ep)
+         log(ret.graph, title=str(fn.__name__), recapture=False)
+         return ret
+
+     return wrapped
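
Usage note (editor's sketch): the `*_on_pass` decorators wrap a pass's `call(self, exported_program)` method so that the graph (or its constants) is captured before the pass runs and the diff is logged afterwards under the pass's class name. A hypothetical application (the pass name and body are illustrative only):

    import torch

    from tico.utils.passes import PassBase
    from tico.utils.trace_decorators import trace_graph_diff_on_pass

    @trace_graph_diff_on_pass
    class MyCleanupPass(PassBase):  # hypothetical pass
        def call(self, exported_program: torch.export.ExportedProgram):
            # Mutate exported_program.graph_module.graph here; the decorator
            # captures the graph beforehand and logs the diff afterwards.
            ...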