tico 0.1.0.dev250629__py3-none-any.whl → 0.1.0.dev250701__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
tico/__init__.py CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
  from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2

  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
- __version__ = "0.1.0.dev250629"
+ __version__ = "0.1.0.dev250701"

  MINIMUM_SUPPORTED_VERSION = "2.5.0"
  SECURE_TORCH_VERSION = "2.6.0"
tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py CHANGED
@@ -29,10 +29,12 @@ from tico.utils.passes import PassBase, PassResult
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
  from tico.utils.utils import quant_min_max, set_new_meta_val
  from tico.utils.validate_args_kwargs import (
+ AddTensorArgs,
  BmmArgs,
  LinearArgs,
  MulTensorArgs,
  PermuteArgs,
+ ReluArgs,
  ReshapeArgs,
  )

@@ -77,7 +79,7 @@ def _u8_to_i16(qparam: QuantParam) -> QuantParam:
  max_ = u8_scale * (255 - u8_zerop)
  min_ = u8_scale * (-u8_zerop)

- abs_max = max([max_, min_], key=abs)
+ abs_max = abs(max([max_, min_], key=abs))
  s16_scale = abs_max / 32767
  s16_zerop = 0
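Note on the abs() wrapper: max([max_, min_], key=abs) returns the element with the largest magnitude, and whenever the uint8 zero-point exceeds 127 that element is the negative min_, which would have made the derived int16 scale negative. A standalone sketch of the conversion (not the tico helper itself):

    def u8_to_s16_scale(u8_scale: float, u8_zerop: int) -> float:
        # Real-valued range covered by the asymmetric uint8 parameters.
        max_ = u8_scale * (255 - u8_zerop)
        min_ = u8_scale * (-u8_zerop)
        # Largest magnitude; abs() keeps the result positive even when min_ dominates.
        abs_max = abs(max([max_, min_], key=abs))
        return abs_max / 32767

    assert u8_to_s16_scale(0.1, 200) > 0  # without abs(), this would be -20 / 32767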
 
@@ -210,6 +212,42 @@ class InsertQuantizeOnDtypeMismatch(PassBase):
  logger.debug(
  f"quantize_per_tensor.default is inserted after {node.name}."
  )
+ else:
+ raise NotYetSupportedError(
+ f"Unsupported dtype: From {qparam_dtype(inp)} to {qparam_dtype(node)}"
+ )
+
+ elif node.target == torch.ops.aten.add.Tensor:
+ add_args = AddTensorArgs(*node.args, **node.kwargs)
+ x = add_args.input
+ y = add_args.other
+
+ if not isinstance(x, torch.fx.Node):
+ continue
+ if not isinstance(y, torch.fx.Node):
+ continue
+
+ if QPARAM_KEY not in x.meta:
+ continue
+ if QPARAM_KEY not in y.meta:
+ continue
+ if QPARAM_KEY not in node.meta:
+ continue
+
+ if qparam_dtype(x) == qparam_dtype(node):
+ continue
+
+ if qparam_dtype(x) != qparam_dtype(y):
+ continue
+
+ if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
+ quantize = _insert_quantize_op_after(node)
+
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+ node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+ logger.debug(
+ f"quantize_per_tensor.default is inserted after {node.name}."
+ )
  else:
  raise NotYetSupportedError("Unsupported dtype")

@@ -335,6 +373,30 @@ class InsertQuantizeOnDtypeMismatch(PassBase):
  else:
  raise NotYetSupportedError("Unsupported dtype")

+ elif node.target == torch.ops.aten.relu.default:
+ relu_args = ReluArgs(*node.args, **node.kwargs)
+ inp = relu_args.input
+
+ if QPARAM_KEY not in inp.meta:
+ continue
+
+ if QPARAM_KEY not in node.meta:
+ continue
+
+ if qparam_dtype(inp) == qparam_dtype(node):
+ continue
+
+ if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
+ quantize = _insert_quantize_op_after(node)
+
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
+ node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
+ logger.debug(
+ f"quantize_per_tensor.default is inserted after {node.name}."
+ )
+ else:
+ raise NotYetSupportedError("Unsupported dtype")
+
  # TODO Support more ops.

  graph.eliminate_dead_code()
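Both new branches (aten.add.Tensor and aten.relu.default) follow the pass's existing recipe: when an op's inputs carry int16 qparams but the op itself is annotated uint8, a quantize op holding the original uint8 qparam is inserted right after it and the op's own qparam is widened via _u8_to_i16. A hedged, plain torch.fx sketch of that insertion step, not tico's _insert_quantize_op_after; it assumes the quantized_decomposed ops are registered:

    import torch
    import torch.ao.quantization.fx._decomposed  # noqa: F401  registers torch.ops.quantized_decomposed

    def insert_quantize_after(graph: torch.fx.Graph, node: torch.fx.Node,
                              scale: float, zero_point: int) -> torch.fx.Node:
        # Create a quantize_per_tensor call right after `node`.
        with graph.inserting_after(node):
            q = graph.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default,
                args=(node, scale, zero_point, 0, 255, torch.uint8),
            )
        # Reroute every other consumer of `node` to the new quantize node.
        node.replace_all_uses_with(q, delete_user_cb=lambda user: user is not q)
        return q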
tico/passes/legalize_predefined_layout_operators.py CHANGED
@@ -30,6 +30,7 @@ from tico.utils.utils import is_target_node
  from tico.utils.validate_args_kwargs import (
  AvgPool2dArgs,
  Conv2DArgs,
+ ConvTranspose2DArgs,
  DequantizePerChannelArgs,
  DequantizePerTensorArgs,
  InstanceNormArgs,
@@ -37,7 +38,9 @@ from tico.utils.validate_args_kwargs import (
  )


- def get_permute_weight_input(conv_args: Conv2DArgs) -> torch.fx.Node:
+ def get_permute_weight_input(
+ conv_args: Conv2DArgs | ConvTranspose2DArgs,
+ ) -> torch.fx.Node:
  """
  Retrieves the weight input for the permute operation.

@@ -194,6 +197,85 @@ class LegalizePreDefinedLayoutOperators(PassBase):
  modified = True
  return modified

+ def legalize_conv_transpose2d(self, exported_program, node) -> bool:
+ logger = logging.getLogger(__name__)
+ modified = False
+
+ graph_module = exported_program.graph_module
+ graph = graph_module.graph
+
+ args = ConvTranspose2DArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
+ input = args.input
+ padding = args.padding
+ groups = args.groups
+ dilation = args.dilation
+
+ input_shape = extract_shape(input)
+ if not (len(input_shape) == 4):
+ raise NotYetSupportedError(
+ f"Only support 4D input tensor: node's input shape: {input_shape}"
+ )
+
+ if groups != 1:
+ raise NotYetSupportedError(
+ f"Only support groups=1: node's groups: {groups}"
+ )
+
+ if dilation != [1, 1]:
+ raise NotYetSupportedError(
+ f"Only support dilation=[1, 1]: node's dilation: {dilation}"
+ )
+
+ NCHW_to_NHWC = [0, 2, 3, 1]
+ # input permute
+ with graph.inserting_after(input):
+ input_permute = create_node(
+ graph,
+ torch.ops.aten.permute.default,
+ args=(input, NCHW_to_NHWC),
+ origin=input,
+ )
+ node.update_arg(node.args.index(input), input_permute)
+
+ # weight permute
+ weight = get_permute_weight_input(args)
+ with graph.inserting_after(weight):
+ perm = [1, 2, 3, 0] # IOHW_to_OHWI
+ weight_permute = create_node(
+ graph,
+ torch.ops.aten.permute.default,
+ args=(weight, perm),
+ origin=weight,
+ )
+ if args.weight.target in [
+ torch.ops.quantized_decomposed.dequantize_per_channel.default,
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+ ]:
+ dq = args.weight
+ dq.update_arg(dq.args.index(weight), weight_permute)
+ # Need to update dq.meta["val"] in FillMetaVal pass.
+ del dq.meta["val"]
+ else:
+ node.update_arg(node.args.index(weight), weight_permute)
+
+ with graph.inserting_before(node):
+ legalized_op = torch.ops.circle_custom.transpose_conv
+ circle_op = create_node(
+ graph, legalized_op, args=node.args, kwargs=node.kwargs, origin=node
+ )
+ # output permute
+ NHWC_to_NCHW = [0, 3, 1, 2]
+ conv_out_permute = create_node(
+ graph,
+ torch.ops.aten.permute.default,
+ args=(circle_op, NHWC_to_NCHW),
+ )
+ node.replace_all_uses_with(conv_out_permute, propagate_meta=True)
+
+ logger.debug(f"{node.name} is replaced with {circle_op.name}")
+ modified = True
+ return modified
+
  def legalize_instance_norm(self, exported_program, node) -> bool:
  logger = logging.getLogger(__name__)
  modified = False
@@ -365,6 +447,7 @@ class LegalizePreDefinedLayoutOperators(PassBase):
  target_to_legalize_func = {
  torch.ops.aten.conv2d.default: self.legalize_conv2d,
  torch.ops.aten.conv2d.padding: self.legalize_conv2d,
+ torch.ops.aten.conv_transpose2d.input: self.legalize_conv_transpose2d,
  torch.ops.aten.max_pool2d_with_indices.default: self.legalize_max_pool2d_with_indices,
  torch.ops.aten.avg_pool2d.default: self.legalize_avg_pool2d,
  torch.ops.aten.instance_norm.default: self.legalize_instance_norm,
tico/serialize/operators/op_conv2d.py CHANGED
@@ -122,7 +122,7 @@ class Conv2dVisitor(NodeVisitor):

  if is_valid_padding(padding):
  conv2d_padding_type = VALID
- elif is_same_padding(padding, input_shape, output_shape):
+ elif is_same_padding(padding, input_shape, output_shape) and stride == [1, 1]:
  conv2d_padding_type = SAME
  else:
  assert isinstance(padding, list) and len(padding) == 2
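A note on the added `stride == [1, 1]` condition: in Circle, SAME padding implies output H/W = ceil(input H/W / stride), while `is_same_padding` only tests that the input and output H/W are equal. The two statements coincide only at stride 1, so for larger strides the visitor now falls through to VALID padding plus an explicit pad operator (the same pattern the new transpose-conv visitor below uses). A quick PyTorch shape check of the stride-1 case the test is meant to capture (hedged illustration, not tico code):

    import torch

    x = torch.randn(1, 3, 32, 32)
    w = torch.randn(8, 3, 3, 3)  # 3x3 kernel
    y = torch.nn.functional.conv2d(x, w, padding=[1, 1], stride=1)
    print(y.shape)               # (1, 8, 32, 32): H/W preserved -> SAME is a faithful classification at stride 1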
tico/serialize/operators/op_transpose_conv.py ADDED
@@ -0,0 +1,165 @@
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict, List, Optional, TYPE_CHECKING
+
+ if TYPE_CHECKING:
+ import torch._ops
+ import torch.fx
+ import torch
+ from circle_schema import circle
+
+ from tico.serialize.circle_mapping import (
+ circle_legalize_dtype_to,
+ extract_circle_dtype,
+ extract_shape,
+ )
+ from tico.serialize.operators.hashable_opcode import OpCode
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+ from tico.serialize.quant_param import QPARAM_KEY, QuantParam
+ from tico.utils.define import define_pad_node
+ from tico.utils.padding import is_same_padding, is_valid_padding, SAME, VALID
+ from tico.utils.validate_args_kwargs import ConvTranspose2DArgs
+
+
+ @register_node_visitor
+ class TransposeConvVisitor(NodeVisitor):
+ target: List[torch._ops.OpOverload] = [
+ torch.ops.circle_custom.transpose_conv,
+ ]
+
+ def define_transpose_conv_node(
+ self, padding: int, stride: List, inputs: List, outputs: List
+ ) -> circle.Operator.OperatorT:
+ def set_transpose_conv_option(operator, stride):
+ operator.builtinOptionsType = (
+ circle.BuiltinOptions.BuiltinOptions.TransposeConvOptions
+ )
+ option = circle.TransposeConvOptions.TransposeConvOptionsT()
+ option.padding = padding
+ option.strideH = stride[0]
+ option.strideW = stride[1]
+ option.fusedActivationFunction = (
+ circle.ActivationFunctionType.ActivationFunctionType.NONE
+ )
+ operator.builtinOptions = option
+
+ transpose_conv_op_index = get_op_index(
+ circle.BuiltinOperator.BuiltinOperator.TRANSPOSE_CONV, self._op_codes
+ )
+ operator = create_builtin_operator(
+ self.graph, transpose_conv_op_index, inputs, outputs
+ )
+ set_transpose_conv_option(operator, stride)
+ return operator
+
+ def __init__(self, op_codes: Dict[OpCode, int], graph):
+ super().__init__(op_codes, graph)
+
+ def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT:
+ args = ConvTranspose2DArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
+
+ input_ = args.input
+ weight = args.weight
+ bias = args.bias
+ stride = args.stride
+ padding = args.padding
+ output_padding = args.output_padding
+ groups = args.groups
+ dilation = args.dilation
+
+ assert groups == 1, "Only support group 1"
+
+ input_dtype: int = extract_circle_dtype(input_)
+ input_shape = list(extract_shape(input_))
+ assert len(input_shape) == 4, len(input_shape)
+ output_shape = extract_shape(node)
+ assert len(output_shape) == 4, len(output_shape)
+
+ conv_input: torch.fx.Node | circle.Tensor.TensorT = input_
+ weight_shape = list(extract_shape(weight))
+
+ if is_valid_padding(padding):
+ tconv2d_padding_type = VALID
+ elif is_same_padding(padding, input_shape, output_shape) and stride == [1, 1]:
+ tconv2d_padding_type = SAME
+ else:
+ assert isinstance(padding, list) and len(padding) == 2
+
+ tconv2d_padding_type = VALID
+
+ # Padding is not valid or same, so we use valid padding and add padding operator before tconv2d operator.
+ # When data_format is "NHWC", padding should be [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
+ paddings = torch.tensor(
+ [
+ [0, 0],
+ [padding[0], padding[0]],
+ [padding[1], padding[1]],
+ [0, 0],
+ ],
+ dtype=torch.int32,
+ )
+ pad_output_shape = [
+ input_shape[0],
+ input_shape[1],
+ input_shape[2],
+ input_shape[3],
+ ]
+ # Add (pad_top+pad_bottom) to pad_output_shape_h
+ pad_output_shape[1] += padding[0] * 2
+ # Add (pad_left+pad_right) to pad_output_shape_w
+ pad_output_shape[2] += padding[1] * 2
+ # create padded output tensor
+ input_qparam: Optional[QuantParam] = (
+ input_.meta[QPARAM_KEY] if QPARAM_KEY in input_.meta else None
+ )
+ pad_output = self.graph.add_tensor_from_scratch(
+ prefix=f"{node.name}_input_pad_output",
+ shape=pad_output_shape,
+ dtype=input_dtype,
+ qparam=input_qparam,
+ source_node=node,
+ )
+ # CirclePad
+ pad_operator = define_pad_node(
+ self.graph, self._op_codes, [input_, paddings], [pad_output]
+ )
+ self.graph.add_operator(pad_operator)
+ conv_input = pad_output
+
+ if bias is None:
+ # luci-interpreter can't run conv without bias. Let's add a zero vector for bias.
+ assert len(weight_shape) == 4
+ out_channel = weight_shape[0]
+ bias = [0.0] * out_channel # type: ignore[assignment]
+
+ # First argument is the output shape of tconv.
+ assert output_shape[0] == input_shape[0]
+ assert output_shape[3] == weight_shape[0]
+ tconv_output = circle_legalize_dtype_to(output_shape, dtype=torch.int32)
+
+ tconv_output_tensor = self.graph.add_const_tensor(
+ tconv_output, source_node=node
+ )
+
+ # TConv2D
+ tconv2d_operator = self.define_transpose_conv_node(
+ tconv2d_padding_type, # 'SAME'(0) or 'VALID'(1)
+ stride,
+ [tconv_output_tensor, weight, conv_input, bias],
+ [node],
+ )
+
+ return tconv2d_operator
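With the decomposition skip, the legalization pass, and this visitor in place, a plain nn.ConvTranspose2d (groups=1, dilation=[1, 1]) should now export end to end. A hedged usage sketch: `tico.convert` matches the `convert()` signature in this diff, while the `.save()` call on the returned CircleModel is assumed from the package's public API.

    import torch
    import tico

    class Up(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.deconv = torch.nn.ConvTranspose2d(8, 4, kernel_size=2, stride=2)

        def forward(self, x):
            return self.deconv(x)

    model = Up().eval()  # eval() matters; see the training-mode checks added in convert.py below
    example_inputs = (torch.randn(1, 8, 16, 16),)
    circle_model = tico.convert(model, example_inputs)
    circle_model.save("up.circle")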
tico/utils/convert.py CHANGED
@@ -100,6 +100,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
  torch.ops.aten.conv2d.padding,
  torch.ops.aten.conv1d.default,
  torch.ops.aten.conv1d.padding,
+ torch.ops.aten.conv_transpose2d.input,
  torch.ops.aten.instance_norm.default,
  torch.ops.aten._safe_softmax.default,
  torch.ops.aten.relu6.default, # Do not decompose to hardtanh
@@ -116,6 +117,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
  torch.ops.aten.conv2d.padding,
  torch.ops.aten.conv1d.default,
  torch.ops.aten.conv1d.padding,
+ torch.ops.aten.conv_transpose2d.input,
  torch.ops.aten.instance_norm.default,
  torch.ops.aten._safe_softmax.default,
  torch.ops.aten.relu6.default, # Do not decompose to hardtanh
@@ -174,7 +176,7 @@ def check_training_ops(exported_program: ExportedProgram):

  if found:
  raise RuntimeError(
- f"Detected training-mode ops {sorted(found)}. Call `model.eval()` before export."
+ f"Detected training-mode ops {found}. Call `model.eval()` before export."
  )


@@ -186,7 +188,6 @@ def convert_exported_module_to_circle(
  logger.debug("Input ExportedProgram (must be core aten)")
  logger.debug(exported_program)

- check_training_ops(exported_program)
  # PRE-EDGE PASSES
  #
  # Here are the passes that run before to_edge() conversion.
@@ -275,6 +276,7 @@
  quantize_graph.run(exported_program)

  check_unsupported_target(exported_program)
+ check_training_ops(exported_program)
  circle_program = build_circle(exported_program)

  return circle_program
@@ -287,6 +289,12 @@
  strict: bool = True,
  config: CompileConfigBase = get_default_config(),
  ) -> CircleModel:
+ if hasattr(mod, "training") and mod.training:
+ logger = logging.getLogger(__name__)
+ logger.fatal(
+ "Your model is in TRAINING MODE. PLEASE CHECK IF YOU FORGOT `model.eval()`."
+ )
+
  with torch.no_grad():
  exported_program = export(mod, args, kwargs, strict=strict)

tico/utils/padding.py CHANGED
@@ -40,8 +40,8 @@ def is_same_padding(
  if isinstance(padding, list):
  assert len(padding) == 2, "Padding should be a list of length 2."

- input_HW = input_shape[1:2] # N H W C
- output_HW = output_shape[1:2] # N H W C
+ input_HW = tuple(input_shape[1:3]) # N H W C
+ output_HW = tuple(output_shape[1:3]) # N H W C
  return input_HW == output_HW

  raise InvalidArgumentError("Invalid padding.")
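Two things changed in `is_same_padding`: the slice now covers both H and W ([1:3] instead of [1:2]), and both slices are normalized with tuple(), since callers (e.g. the new TransposeConvVisitor) pass one shape as a plain list and the other as the node's torch.Size, and a list never compares equal to a tuple. Illustration:

    import torch

    input_shape = [1, 32, 32, 3]               # N H W C, passed as a list
    output_shape = torch.Size([1, 32, 16, 3])
    print(input_shape[1:2] == output_shape[1:2])                # False even though H matches (list vs tuple)
    print(tuple(input_shape[1:3]) == tuple(output_shape[1:3]))  # False, and now because W really differs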
tico/utils/register_custom_op.py CHANGED
@@ -371,6 +371,113 @@ def CircleDepthwiseConv2dPadding():
  return NHWC_output


+ def CircleTransposeConv():
+ """
+ Note that this op follows the input spec of `aten.conv_transpose2d.input` whose number
+ of arguments meets (2 <= node.args <= 8) condition.
+ [RESTRICTION]
+ Therefore, I tried to define a spec of it as transpose_conv(input, weight, *args).
+ But, custom operators in torch do not support positional-only args. So, I set
+ them as None by default.
+ """
+
+ @custom_op("circle_custom::transpose_conv", mutates_args=())
+ def transpose_conv(
+ input_: torch.Tensor,
+ weight: torch.Tensor,
+ bias: Optional[torch.Tensor] = None,
+ stride: Optional[List[int]] = None,
+ padding: Optional[List[int]] = None,
+ output_padding: Optional[List[int]] = None,
+ groups: Optional[int] = None,
+ dilation: Optional[List[int]] = None,
+ ) -> torch.Tensor:
+ """
+ Set default values.
+ Custom operators have limited types when it comes to default values.
+ So, let's set them to None in the input spec, and then set the default values here.
+ https://github.com/pytorch/pytorch/blob/6b05aafc/torch/_library/infer_schema.py#L131-L144
+ """
+ stride = [1, 1] if stride is None else stride
+ padding = [0, 0] if padding is None else padding
+ output_padding = [0, 0] if output_padding is None else output_padding
+ groups = 1 if groups is None else groups
+ dilation = [1, 1] if dilation is None else dilation
+ if groups != 1:
+ raise RuntimeError(
+ f"CircleTransposeConv only supports 1 'groups'. the node's groups: {groups}"
+ )
+
+ NHWC_to_NCHW = [0, 3, 1, 2]
+ OHWI_to_IOHW = [3, 0, 1, 2]
+ NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
+ OIHW_weight = torch.ops.aten.permute.default(weight, OHWI_to_IOHW)
+
+ args = [
+ NCHW_input,
+ OIHW_weight,
+ bias,
+ stride,
+ padding,
+ output_padding,
+ groups,
+ dilation,
+ ]
+ NCHW_output = torch.ops.aten.conv_transpose2d.input(*args)
+ NCHW_to_NHWC = [0, 2, 3, 1]
+ NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
+
+ return NHWC_output
+
+ @register_fake("circle_custom::transpose_conv")
+ def _(
+ input_: torch.Tensor,
+ weight: torch.Tensor,
+ bias: Optional[torch.Tensor] = None,
+ stride: Optional[List[int]] = None,
+ padding: Optional[List[int]] = None,
+ output_padding: Optional[List[int]] = None,
+ groups: Optional[int] = None,
+ dilation: Optional[List[int]] = None,
+ ):
+ """
+ Set default values.
+ Custom operators have limited types when it comes to default values.
+ So, let's set them to None in the input spec, and then set the default values here.
+ https://github.com/pytorch/pytorch/blob/6b05aafc/torch/_library/infer_schema.py#L131-L144
+ """
+ stride = [1, 1] if stride is None else stride
+ padding = [0, 0] if padding is None else padding
+ output_padding = [0, 0] if output_padding is None else output_padding
+ groups = 1 if groups is None else groups
+ dilation = [1, 1] if dilation is None else dilation
+ if groups != 1:
+ raise RuntimeError(
+ f"CircleTransposeConv only supports 1 'groups'. the node's groups: {groups}"
+ )
+
+ NHWC_to_NCHW = [0, 3, 1, 2]
+ OHWI_to_IOHW = [3, 0, 1, 2]
+ NCHW_input = torch.ops.aten.permute.default(input_, NHWC_to_NCHW)
+ OIHW_weight = torch.ops.aten.permute.default(weight, OHWI_to_IOHW)
+
+ args = [
+ NCHW_input,
+ OIHW_weight,
+ bias,
+ stride,
+ padding,
+ output_padding,
+ groups,
+ dilation,
+ ]
+ NCHW_output = torch.ops.aten.conv_transpose2d.input(*args)
+ NCHW_to_NHWC = [0, 2, 3, 1]
+ NHWC_output = torch.ops.aten.permute.default(NCHW_output, NCHW_to_NHWC)
+
+ return NHWC_output
+
+
  def CircleMaxPool2D():
  """
  Note that this op follows the input spec of `aten.max_pool2d_with_indices.default` whose number
@@ -603,6 +710,7 @@ def RegisterOps():
  CircleDepthwiseConv2dPadding()
  CircleConv2d()
  CircleConv2dPadding()
+ CircleTransposeConv()
  CircleMaxPool2D()
  CircleAvgPool2D()
  CircleInstanceNorm()
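CircleTransposeConv follows the same registration recipe as the other circle_custom ops: a real implementation registered with torch.library.custom_op plus a fake (meta) kernel registered with register_fake so torch.export can trace through it. A reduced, self-contained sketch of that recipe; the "demo" namespace and op name are illustrative only, and it assumes a PyTorch version that provides torch.library.custom_op:

    import torch
    from torch.library import custom_op, register_fake

    @custom_op("demo::nhwc_relu", mutates_args=())
    def nhwc_relu(x: torch.Tensor) -> torch.Tensor:
        # Eager implementation used at runtime.
        return torch.relu(x)

    @register_fake("demo::nhwc_relu")
    def _(x: torch.Tensor) -> torch.Tensor:
        # Shape/dtype-only kernel used while tracing and exporting.
        return torch.empty_like(x)

    print(torch.ops.demo.nhwc_relu(torch.randn(1, 4, 4, 8)).shape)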
tico/utils/validate_args_kwargs.py CHANGED
@@ -208,6 +208,45 @@ class ConstantPadNdArgs:
  value: int | float


+ @enforce_type
+ @dataclass
+ class ConvArgs:
+ """
+ convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor
+ """
+
+ input: torch.fx.Node
+ weight: torch.fx.Node
+ bias: Union[torch.fx.Node, None]
+ stride: List[int]
+ padding: List[int]
+ dilation: List[int]
+ transposed: bool
+ output_padding: List[int]
+ groups: int
+
+
+ @enforce_type
+ @dataclass
+ class ConvTranspose2DArgs:
+ """
+ conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt groups=1, SymInt[2] dilation=1) -> Tensor
+ """
+
+ input: torch.fx.Node
+ weight: torch.fx.Node
+ bias: Union[torch.fx.Node, None] = None
+ stride: List[int] = field(default_factory=lambda: [1, 1])
+ padding: List[int] = field(default_factory=lambda: [0, 0])
+ output_padding: List[int] = field(default_factory=lambda: [0, 0])
+ groups: int = 1
+ dilation: List[int] = field(default_factory=lambda: [1, 1])
+
+ def __post_init__(self):
+ assert len(self.stride) == 2, len(self.stride)
+ assert len(self.dilation) == 2, len(self.dilation)
+
+
  @enforce_type
  @dataclass
  class Conv2DArgs:
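These argument dataclasses mirror the aten schema, so `ConvTranspose2DArgs(*node.args, **node.kwargs)` fills in the schema defaults for anything the FX node omits. A standalone stand-in showing the mechanism (hypothetical class name; the tico version additionally applies @enforce_type validation):

    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class TConvArgs:  # hypothetical stand-in for ConvTranspose2DArgs
        input: object
        weight: object
        bias: Optional[object] = None
        stride: List[int] = field(default_factory=lambda: [1, 1])
        padding: List[int] = field(default_factory=lambda: [0, 0])
        output_padding: List[int] = field(default_factory=lambda: [0, 0])
        groups: int = 1
        dilation: List[int] = field(default_factory=lambda: [1, 1])

    # node.args / node.kwargs as they might appear for aten.conv_transpose2d.input
    args, kwargs = ("x_node", "w_node"), {"stride": [2, 2]}
    print(TConvArgs(*args, **kwargs))  # every field not passed gets its schema default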
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: tico
- Version: 0.1.0.dev250629
+ Version: 0.1.0.dev250701
  Summary: Convert exported Torch module to circle
  Home-page: UNKNOWN
  License: UNKNOWN
@@ -1,4 +1,4 @@
- tico/__init__.py,sha256=RfkSpTE90FwwM4kJohc9prmZ7kQXSUmZaZiUKxrn-D0,1743
+ tico/__init__.py,sha256=na5S7s-9ToqvW2qHQsFkHg9ST_gMeKGTyCnymdNnbo0,1743
  tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
  tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
  tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
@@ -51,7 +51,7 @@ tico/experimental/quantization/evaluation/executor/circle_executor.py,sha256=eCC
  tico/experimental/quantization/evaluation/executor/triv24_executor.py,sha256=sUoXl6oOO2arAKaNjOBg7HiQja145_Jv6qgY7XtR7A8,5159
  tico/experimental/quantization/passes/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
  tico/experimental/quantization/passes/fold_quant_ops.py,sha256=iaBMyO49CwVkhebMz3rjkHWfWE2LhwH6fORe7n4S6XQ,7040
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py,sha256=AbNcI7rfIwHsQna_rFuwqFdOzFAU2lIB3sMK-vns8Dc,13072
+ tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py,sha256=lQeN6VJNYQyDeucB4KpyyWvIiuhGRq7wjIeaCKdM7ck,15462
  tico/experimental/quantization/passes/propagate_qparam_backward.py,sha256=TGtyW0Z2qOTgVIasBdGRgbwH31YYd6ek7OvLTmCV614,3118
  tico/experimental/quantization/passes/propagate_qparam_forward.py,sha256=RhUHGCR2RpBO5KYkQ7Z8U5u7HEwDq2wdKHLKAJCi-5c,5138
  tico/experimental/quantization/passes/quantize_bias.py,sha256=ZQ3rETYStpW28JUbODRixbq5sDEOiIOB_qWA-Jzuu-Y,4337
@@ -79,7 +79,7 @@ tico/passes/fill_meta_val.py,sha256=Xbam6Aq90ZfWItZw1dgLIwH_q8RCiU5JodKNqkj-ink,
  tico/passes/fuse_leading_unsqueeze_reshape.py,sha256=88jwTP35yRyXOk9xdO6YW2OEfdKAws3KFRT16WQz0RI,4291
  tico/passes/fuse_redundant_reshape_to_mean.py,sha256=GhJS1ZKB6Ns4AhwcW3uUQ6q-0N-AzlD32B2EwusUJHg,3761
  tico/passes/legalize_causal_mask_value.py,sha256=xKdFwwMaSFCSQpSk8xISOAqFpZ1jIhgbBIqf7KTSGuk,4017
- tico/passes/legalize_predefined_layout_operators.py,sha256=6jd_FmXX5rbBxqp3H5MQoCnL3vY3qoAdXaXkVdfXEjI,15902
+ tico/passes/legalize_predefined_layout_operators.py,sha256=MNx7L2dAlsxSazb-F7c0onPqHleI17zAc7AzQAa9aJ4,18934
  tico/passes/lower_pow2_to_mul.py,sha256=nfJXa9ZTZMiLg6ownSyvkM4KF2z9tZW34Q3CCWI_vmQ,2402
  tico/passes/lower_to_resize_nearest_neighbor.py,sha256=N6F56Of8Aiv-KIiYLHnh33WX72W60ZVQSBEYWHdYqNQ,9005
  tico/passes/lower_to_slice.py,sha256=0qAX3WzZdyMFDW4DiO9b5JFXd4rL1-0doBT6lJvaw_I,7260
@@ -115,7 +115,7 @@ tico/serialize/operators/op_cat.py,sha256=XDYOh0XAyrM0TlxVm6Sa0OFFGrKk7aSDcGXC-h
  tico/serialize/operators/op_clamp.py,sha256=ZRAsXLGsZqJEh4wXxESEpRJkRtUuJWTDgAem6lr9_5I,4298
  tico/serialize/operators/op_clone.py,sha256=vzDYJ8TS3tc2BAyd_z8nt5VqT1inpymSseMEhd9dva0,2394
  tico/serialize/operators/op_constant_pad_nd.py,sha256=OpP4AP-d1IFcWZolNa-o9ZxzXJQkMdG9WQ66soX3s-E,2675
- tico/serialize/operators/op_conv2d.py,sha256=nC_jqzjlrUJ0L_lux_wXBqxDfq67jyroXSgrl5WoNfk,7317
+ tico/serialize/operators/op_conv2d.py,sha256=BmSCunhziD9EhXEkWwFrWkaQ_t3cIhrJJQSRLbgqmxI,7338
  tico/serialize/operators/op_copy.py,sha256=vaianLQ19-2ZQZ-MdQ07YuOPeFeo_HAx2a0Qfn7I5Kk,6122
  tico/serialize/operators/op_cos.py,sha256=N12bNyuTQIxRnD0eHRPdFVzRQPMy1NFM4iM8oQ4lYzw,2034
  tico/serialize/operators/op_cumsum.py,sha256=3fmOf1mIeCX1uhTBcSJmRGXejzLtO8UwaI1eEQDC6nA,3798
@@ -175,32 +175,33 @@ tico/serialize/operators/op_sub.py,sha256=yZskQJF0ylXVk02Uid8djPNIWDJ-0uHJar4UYh
  tico/serialize/operators/op_sum.py,sha256=B5aSwQMhyoBe2JYdE5nVQ3QeVDSzL-yuZZujsG08OdQ,2294
  tico/serialize/operators/op_tanh.py,sha256=rs7FsbQeUQ7Ak8RoQV9ymNGXHXRObojfY_SiqJiyqdA,1846
  tico/serialize/operators/op_to_copy.py,sha256=a8T0uPMavMO_md1a-4_0dlvDHyZS_xew0qB6xjf69rI,3934
+ tico/serialize/operators/op_transpose_conv.py,sha256=ZpVNWH58FqbQahExeZIpW51w3VtfGg49JlHCDSOv7wg,6370
  tico/serialize/operators/op_unsqueeze.py,sha256=ZHhfVXSWEiwb2VDYX5uhxbGQyzZjKT7CrbBpVGxVHBU,2310
  tico/serialize/operators/op_view.py,sha256=5EMww-ve17Vm9XPuV03Tn7vJsjpU2J8U4d_FOrlm9_o,2546
  tico/serialize/operators/op_where.py,sha256=doE81GSwygrPBm3JIfN9w7kKXxeIYKxgk0eoY22QIcg,2845
  tico/serialize/operators/utils.py,sha256=lXGpEJW1h8U_-gfc6EWjvvSiq3yJ9P-v1v3EMRT_pSk,2954
  tico/utils/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
- tico/utils/convert.py,sha256=fJRltUkBALia6N7lKcMMvmNBU-5DYyztnNZQjcgjvXU,12452
+ tico/utils/convert.py,sha256=5C8Z2ia2XN4k3XgtJrFZYJSEejoeMllyr8YW6gwu9mw,12763
  tico/utils/define.py,sha256=Ypgp7YffM4pgPl4Zh6TmogSn1OxGBMRw_e09qYGflZk,1467
  tico/utils/diff_graph.py,sha256=_eDGGPDPYQD4b--MXX0DLoVgSt_wLfNPt47UlolLLR4,5272
  tico/utils/errors.py,sha256=f3csJjgbXG9W1aHhqEcou008Aor19W57X8oT5Hx8w1M,954
  tico/utils/graph.py,sha256=Y6aODsnc_-9l61oanknb7K1jqJ8B35iPypOKkM0Qkk0,9149
  tico/utils/logging.py,sha256=IlbBWscsaHidI0dNqro1HEXAbIcbkR3BD5ukLy2m95k,1286
  tico/utils/model.py,sha256=Uqc92AnJXQ2pbvctS2z2F3Ku3yNrwXZ9O33hZVis7is,1250
- tico/utils/padding.py,sha256=GGO27VbaOvtaMYLDrSaKv7uxjeet566aMJD0PyYeMvQ,1484
+ tico/utils/padding.py,sha256=jNMX2KFoZ3c6HTlMU8BAwG3Fyrqpq4F3ytKP13Pg4ps,1498
  tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
- tico/utils/register_custom_op.py,sha256=qheG1WqtkUaG1SnHrrKQ7-fE4IZRETApCsfMkjDKcfs,23240
+ tico/utils/register_custom_op.py,sha256=3-Yl6iYmx1qQA2igNHt4hYhQhQMkdPb7gF50LIY8yvc,27350
  tico/utils/serialize.py,sha256=AQXMBOLu-Kg2Rn-qbqsAtHndjZAZIavlKA0QFgJREHM,1420
  tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
  tico/utils/utils.py,sha256=fnbZ2RLH6-J-wqb32O4qsR1ce4BJU0wYNrk84QXa6_E,13158
- tico/utils/validate_args_kwargs.py,sha256=ifzO4ikubDPU2iXRBPF8KeyubW23cjxBThOslLAcTrg,25368
+ tico/utils/validate_args_kwargs.py,sha256=cJAK6aqdzK3_Xccu6K1FQ32WGdmwWA_SqJ--TPavIuk,26614
  tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
  tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
  tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
  tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
- tico-0.1.0.dev250629.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
- tico-0.1.0.dev250629.dist-info/METADATA,sha256=DSkGeB1pQ2OlIkRm5lh4m4GpX5e5tuEk9rsU0joBLWE,8846
- tico-0.1.0.dev250629.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
- tico-0.1.0.dev250629.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
- tico-0.1.0.dev250629.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
- tico-0.1.0.dev250629.dist-info/RECORD,,
+ tico-0.1.0.dev250701.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
+ tico-0.1.0.dev250701.dist-info/METADATA,sha256=FKpHfi_Mwf8vsk5athXrwRHTOulyZP0H45y_zUJ6uqM,8846
+ tico-0.1.0.dev250701.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+ tico-0.1.0.dev250701.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
+ tico-0.1.0.dev250701.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
+ tico-0.1.0.dev250701.dist-info/RECORD,,