tico 0.1.0.dev250626__py3-none-any.whl → 0.1.0.dev250630__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tico/__init__.py CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
21
21
  from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
22
22
 
23
23
  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
24
- __version__ = "0.1.0.dev250626"
24
+ __version__ = "0.1.0.dev250630"
25
25
 
26
26
  MINIMUM_SUPPORTED_VERSION = "2.5.0"
27
27
  SECURE_TORCH_VERSION = "2.6.0"
@@ -29,10 +29,12 @@ from tico.utils.passes import PassBase, PassResult
29
29
  from tico.utils.trace_decorators import trace_graph_diff_on_pass
30
30
  from tico.utils.utils import quant_min_max, set_new_meta_val
31
31
  from tico.utils.validate_args_kwargs import (
32
+ AddTensorArgs,
32
33
  BmmArgs,
33
34
  LinearArgs,
34
35
  MulTensorArgs,
35
36
  PermuteArgs,
37
+ ReluArgs,
36
38
  ReshapeArgs,
37
39
  )
38
40
 
@@ -77,7 +79,7 @@ def _u8_to_i16(qparam: QuantParam) -> QuantParam:
77
79
  max_ = u8_scale * (255 - u8_zerop)
78
80
  min_ = u8_scale * (-u8_zerop)
79
81
 
80
- abs_max = max([max_, min_], key=abs)
82
+ abs_max = abs(max([max_, min_], key=abs))
81
83
  s16_scale = abs_max / 32767
82
84
  s16_zerop = 0
83
85
 
@@ -210,6 +212,42 @@ class InsertQuantizeOnDtypeMismatch(PassBase):
210
212
  logger.debug(
211
213
  f"quantize_per_tensor.default is inserted after {node.name}."
212
214
  )
215
+ else:
216
+ raise NotYetSupportedError(
217
+ f"Unsupported dtype: From {qparam_dtype(inp)} to {qparam_dtype(node)}"
218
+ )
219
+
220
+ elif node.target == torch.ops.aten.add.Tensor:
221
+ add_args = AddTensorArgs(*node.args, **node.kwargs)
222
+ x = add_args.input
223
+ y = add_args.other
224
+
225
+ if not isinstance(x, torch.fx.Node):
226
+ continue
227
+ if not isinstance(y, torch.fx.Node):
228
+ continue
229
+
230
+ if QPARAM_KEY not in x.meta:
231
+ continue
232
+ if QPARAM_KEY not in y.meta:
233
+ continue
234
+ if QPARAM_KEY not in node.meta:
235
+ continue
236
+
237
+ if qparam_dtype(x) == qparam_dtype(node):
238
+ continue
239
+
240
+ if qparam_dtype(x) != qparam_dtype(y):
241
+ continue
242
+
243
+ if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
244
+ quantize = _insert_quantize_op_after(node)
245
+
246
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
247
+ node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
248
+ logger.debug(
249
+ f"quantize_per_tensor.default is inserted after {node.name}."
250
+ )
213
251
  else:
214
252
  raise NotYetSupportedError("Unsupported dtype")
215
253
 
@@ -335,6 +373,30 @@ class InsertQuantizeOnDtypeMismatch(PassBase):
335
373
  else:
336
374
  raise NotYetSupportedError("Unsupported dtype")
337
375
 
376
+ elif node.target == torch.ops.aten.relu.default:
377
+ relu_args = ReluArgs(*node.args, **node.kwargs)
378
+ inp = relu_args.input
379
+
380
+ if QPARAM_KEY not in inp.meta:
381
+ continue
382
+
383
+ if QPARAM_KEY not in node.meta:
384
+ continue
385
+
386
+ if qparam_dtype(inp) == qparam_dtype(node):
387
+ continue
388
+
389
+ if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
390
+ quantize = _insert_quantize_op_after(node)
391
+
392
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
393
+ node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
394
+ logger.debug(
395
+ f"quantize_per_tensor.default is inserted after {node.name}."
396
+ )
397
+ else:
398
+ raise NotYetSupportedError("Unsupported dtype")
399
+
338
400
  # TODO Support more ops.
339
401
 
340
402
  graph.eliminate_dead_code()
@@ -122,7 +122,7 @@ class Conv2dVisitor(NodeVisitor):
122
122
 
123
123
  if is_valid_padding(padding):
124
124
  conv2d_padding_type = VALID
125
- elif is_same_padding(padding, input_shape, output_shape):
125
+ elif is_same_padding(padding, input_shape, output_shape) and stride == [1, 1]:
126
126
  conv2d_padding_type = SAME
127
127
  else:
128
128
  assert isinstance(padding, list) and len(padding) == 2
tico/utils/convert.py CHANGED
@@ -162,6 +162,22 @@ def check_unsupported_target(exported_program: ExportedProgram):
162
162
  raise NotYetSupportedError("NOT SUPPORTED OPERATOR IN GRAPH MODULE")
163
163
 
164
164
 
165
+ def check_training_ops(exported_program: ExportedProgram):
166
+ TRAINING_OPS = {
167
+ torch.ops.aten.dropout.default,
168
+ torch.ops.aten.native_dropout.default,
169
+ }
170
+ found = set()
171
+ for node in exported_program.graph.nodes:
172
+ if node.op == "call_function" and node.target in TRAINING_OPS:
173
+ found.add(node.target)
174
+
175
+ if found:
176
+ raise RuntimeError(
177
+ f"Detected training-mode ops {found}. Call `model.eval()` before export."
178
+ )
179
+
180
+
165
181
  def convert_exported_module_to_circle(
166
182
  exported_program: ExportedProgram,
167
183
  config: CompileConfigBase = get_default_config(),
@@ -258,6 +274,7 @@ def convert_exported_module_to_circle(
258
274
  quantize_graph.run(exported_program)
259
275
 
260
276
  check_unsupported_target(exported_program)
277
+ check_training_ops(exported_program)
261
278
  circle_program = build_circle(exported_program)
262
279
 
263
280
  return circle_program
tico/utils/padding.py CHANGED
@@ -40,8 +40,8 @@ def is_same_padding(
40
40
  if isinstance(padding, list):
41
41
  assert len(padding) == 2, "Padding should be a list of length 2."
42
42
 
43
- input_HW = input_shape[1:2] # N H W C
44
- output_HW = output_shape[1:2] # N H W C
43
+ input_HW = tuple(input_shape[1:3]) # N H W C
44
+ output_HW = tuple(output_shape[1:3]) # N H W C
45
45
  return input_HW == output_HW
46
46
 
47
47
  raise InvalidArgumentError("Invalid padding.")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tico
3
- Version: 0.1.0.dev250626
3
+ Version: 0.1.0.dev250630
4
4
  Summary: Convert exported Torch module to circle
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- tico/__init__.py,sha256=V3kK40Sfe76REBha3SgBJMRTeeZYJFErkkr3yo2nCE0,1743
1
+ tico/__init__.py,sha256=gNG4gqEHE73hnLBf47wz6Gj7RXqTV5DoQ8Bpss9Az84,1743
2
2
  tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
3
3
  tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
4
4
  tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
@@ -51,7 +51,7 @@ tico/experimental/quantization/evaluation/executor/circle_executor.py,sha256=eCC
51
51
  tico/experimental/quantization/evaluation/executor/triv24_executor.py,sha256=sUoXl6oOO2arAKaNjOBg7HiQja145_Jv6qgY7XtR7A8,5159
52
52
  tico/experimental/quantization/passes/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
53
53
  tico/experimental/quantization/passes/fold_quant_ops.py,sha256=iaBMyO49CwVkhebMz3rjkHWfWE2LhwH6fORe7n4S6XQ,7040
54
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py,sha256=AbNcI7rfIwHsQna_rFuwqFdOzFAU2lIB3sMK-vns8Dc,13072
54
+ tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py,sha256=lQeN6VJNYQyDeucB4KpyyWvIiuhGRq7wjIeaCKdM7ck,15462
55
55
  tico/experimental/quantization/passes/propagate_qparam_backward.py,sha256=TGtyW0Z2qOTgVIasBdGRgbwH31YYd6ek7OvLTmCV614,3118
56
56
  tico/experimental/quantization/passes/propagate_qparam_forward.py,sha256=RhUHGCR2RpBO5KYkQ7Z8U5u7HEwDq2wdKHLKAJCi-5c,5138
57
57
  tico/experimental/quantization/passes/quantize_bias.py,sha256=ZQ3rETYStpW28JUbODRixbq5sDEOiIOB_qWA-Jzuu-Y,4337
@@ -115,7 +115,7 @@ tico/serialize/operators/op_cat.py,sha256=XDYOh0XAyrM0TlxVm6Sa0OFFGrKk7aSDcGXC-h
115
115
  tico/serialize/operators/op_clamp.py,sha256=ZRAsXLGsZqJEh4wXxESEpRJkRtUuJWTDgAem6lr9_5I,4298
116
116
  tico/serialize/operators/op_clone.py,sha256=vzDYJ8TS3tc2BAyd_z8nt5VqT1inpymSseMEhd9dva0,2394
117
117
  tico/serialize/operators/op_constant_pad_nd.py,sha256=OpP4AP-d1IFcWZolNa-o9ZxzXJQkMdG9WQ66soX3s-E,2675
118
- tico/serialize/operators/op_conv2d.py,sha256=nC_jqzjlrUJ0L_lux_wXBqxDfq67jyroXSgrl5WoNfk,7317
118
+ tico/serialize/operators/op_conv2d.py,sha256=BmSCunhziD9EhXEkWwFrWkaQ_t3cIhrJJQSRLbgqmxI,7338
119
119
  tico/serialize/operators/op_copy.py,sha256=vaianLQ19-2ZQZ-MdQ07YuOPeFeo_HAx2a0Qfn7I5Kk,6122
120
120
  tico/serialize/operators/op_cos.py,sha256=N12bNyuTQIxRnD0eHRPdFVzRQPMy1NFM4iM8oQ4lYzw,2034
121
121
  tico/serialize/operators/op_cumsum.py,sha256=3fmOf1mIeCX1uhTBcSJmRGXejzLtO8UwaI1eEQDC6nA,3798
@@ -180,14 +180,14 @@ tico/serialize/operators/op_view.py,sha256=5EMww-ve17Vm9XPuV03Tn7vJsjpU2J8U4d_FO
180
180
  tico/serialize/operators/op_where.py,sha256=doE81GSwygrPBm3JIfN9w7kKXxeIYKxgk0eoY22QIcg,2845
181
181
  tico/serialize/operators/utils.py,sha256=lXGpEJW1h8U_-gfc6EWjvvSiq3yJ9P-v1v3EMRT_pSk,2954
182
182
  tico/utils/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
183
- tico/utils/convert.py,sha256=gh_EcWwjZZ5IyQKtazfiZ1_FvAMXe0leu5nPN-0OcHg,11919
183
+ tico/utils/convert.py,sha256=crjCVDLoGBSSpI0EAq-cr_oL68CcntOBy361XVcaPzU,12444
184
184
  tico/utils/define.py,sha256=Ypgp7YffM4pgPl4Zh6TmogSn1OxGBMRw_e09qYGflZk,1467
185
185
  tico/utils/diff_graph.py,sha256=_eDGGPDPYQD4b--MXX0DLoVgSt_wLfNPt47UlolLLR4,5272
186
186
  tico/utils/errors.py,sha256=f3csJjgbXG9W1aHhqEcou008Aor19W57X8oT5Hx8w1M,954
187
187
  tico/utils/graph.py,sha256=Y6aODsnc_-9l61oanknb7K1jqJ8B35iPypOKkM0Qkk0,9149
188
188
  tico/utils/logging.py,sha256=IlbBWscsaHidI0dNqro1HEXAbIcbkR3BD5ukLy2m95k,1286
189
189
  tico/utils/model.py,sha256=Uqc92AnJXQ2pbvctS2z2F3Ku3yNrwXZ9O33hZVis7is,1250
190
- tico/utils/padding.py,sha256=GGO27VbaOvtaMYLDrSaKv7uxjeet566aMJD0PyYeMvQ,1484
190
+ tico/utils/padding.py,sha256=jNMX2KFoZ3c6HTlMU8BAwG3Fyrqpq4F3ytKP13Pg4ps,1498
191
191
  tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
192
192
  tico/utils/register_custom_op.py,sha256=qheG1WqtkUaG1SnHrrKQ7-fE4IZRETApCsfMkjDKcfs,23240
193
193
  tico/utils/serialize.py,sha256=AQXMBOLu-Kg2Rn-qbqsAtHndjZAZIavlKA0QFgJREHM,1420
@@ -198,9 +198,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
198
198
  tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
199
199
  tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
200
200
  tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
201
- tico-0.1.0.dev250626.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
202
- tico-0.1.0.dev250626.dist-info/METADATA,sha256=WnpZWIQvH2xOtEbyhhWKZNI56zg_fJ3ublFR7M-t128,8846
203
- tico-0.1.0.dev250626.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
204
- tico-0.1.0.dev250626.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
205
- tico-0.1.0.dev250626.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
206
- tico-0.1.0.dev250626.dist-info/RECORD,,
201
+ tico-0.1.0.dev250630.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
202
+ tico-0.1.0.dev250630.dist-info/METADATA,sha256=SM3Z2qgcIkj7qUL2DKzlp4F47pwbr-3dR1ZTL-gtdMc,8846
203
+ tico-0.1.0.dev250630.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
204
+ tico-0.1.0.dev250630.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
205
+ tico-0.1.0.dev250630.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
206
+ tico-0.1.0.dev250630.dist-info/RECORD,,