tico 0.1.0.dev250626__py3-none-any.whl → 0.1.0.dev250630__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +63 -1
- tico/serialize/operators/op_conv2d.py +1 -1
- tico/utils/convert.py +17 -0
- tico/utils/padding.py +2 -2
- {tico-0.1.0.dev250626.dist-info → tico-0.1.0.dev250630.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250626.dist-info → tico-0.1.0.dev250630.dist-info}/RECORD +11 -11
- {tico-0.1.0.dev250626.dist-info → tico-0.1.0.dev250630.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250626.dist-info → tico-0.1.0.dev250630.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250626.dist-info → tico-0.1.0.dev250630.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250626.dist-info → tico-0.1.0.dev250630.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
|
|
21
21
|
from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
|
22
22
|
|
23
23
|
# THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
|
24
|
-
__version__ = "0.1.0.
|
24
|
+
__version__ = "0.1.0.dev250630"
|
25
25
|
|
26
26
|
MINIMUM_SUPPORTED_VERSION = "2.5.0"
|
27
27
|
SECURE_TORCH_VERSION = "2.6.0"
|
@@ -29,10 +29,12 @@ from tico.utils.passes import PassBase, PassResult
|
|
29
29
|
from tico.utils.trace_decorators import trace_graph_diff_on_pass
|
30
30
|
from tico.utils.utils import quant_min_max, set_new_meta_val
|
31
31
|
from tico.utils.validate_args_kwargs import (
|
32
|
+
AddTensorArgs,
|
32
33
|
BmmArgs,
|
33
34
|
LinearArgs,
|
34
35
|
MulTensorArgs,
|
35
36
|
PermuteArgs,
|
37
|
+
ReluArgs,
|
36
38
|
ReshapeArgs,
|
37
39
|
)
|
38
40
|
|
@@ -77,7 +79,7 @@ def _u8_to_i16(qparam: QuantParam) -> QuantParam:
|
|
77
79
|
max_ = u8_scale * (255 - u8_zerop)
|
78
80
|
min_ = u8_scale * (-u8_zerop)
|
79
81
|
|
80
|
-
abs_max = max([max_, min_], key=abs)
|
82
|
+
abs_max = abs(max([max_, min_], key=abs))
|
81
83
|
s16_scale = abs_max / 32767
|
82
84
|
s16_zerop = 0
|
83
85
|
|
@@ -210,6 +212,42 @@ class InsertQuantizeOnDtypeMismatch(PassBase):
|
|
210
212
|
logger.debug(
|
211
213
|
f"quantize_per_tensor.default is inserted after {node.name}."
|
212
214
|
)
|
215
|
+
else:
|
216
|
+
raise NotYetSupportedError(
|
217
|
+
f"Unsupported dtype: From {qparam_dtype(inp)} to {qparam_dtype(node)}"
|
218
|
+
)
|
219
|
+
|
220
|
+
elif node.target == torch.ops.aten.add.Tensor:
|
221
|
+
add_args = AddTensorArgs(*node.args, **node.kwargs)
|
222
|
+
x = add_args.input
|
223
|
+
y = add_args.other
|
224
|
+
|
225
|
+
if not isinstance(x, torch.fx.Node):
|
226
|
+
continue
|
227
|
+
if not isinstance(y, torch.fx.Node):
|
228
|
+
continue
|
229
|
+
|
230
|
+
if QPARAM_KEY not in x.meta:
|
231
|
+
continue
|
232
|
+
if QPARAM_KEY not in y.meta:
|
233
|
+
continue
|
234
|
+
if QPARAM_KEY not in node.meta:
|
235
|
+
continue
|
236
|
+
|
237
|
+
if qparam_dtype(x) == qparam_dtype(node):
|
238
|
+
continue
|
239
|
+
|
240
|
+
if qparam_dtype(x) != qparam_dtype(y):
|
241
|
+
continue
|
242
|
+
|
243
|
+
if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
|
244
|
+
quantize = _insert_quantize_op_after(node)
|
245
|
+
|
246
|
+
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
247
|
+
node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
|
248
|
+
logger.debug(
|
249
|
+
f"quantize_per_tensor.default is inserted after {node.name}."
|
250
|
+
)
|
213
251
|
else:
|
214
252
|
raise NotYetSupportedError("Unsupported dtype")
|
215
253
|
|
@@ -335,6 +373,30 @@ class InsertQuantizeOnDtypeMismatch(PassBase):
|
|
335
373
|
else:
|
336
374
|
raise NotYetSupportedError("Unsupported dtype")
|
337
375
|
|
376
|
+
elif node.target == torch.ops.aten.relu.default:
|
377
|
+
relu_args = ReluArgs(*node.args, **node.kwargs)
|
378
|
+
inp = relu_args.input
|
379
|
+
|
380
|
+
if QPARAM_KEY not in inp.meta:
|
381
|
+
continue
|
382
|
+
|
383
|
+
if QPARAM_KEY not in node.meta:
|
384
|
+
continue
|
385
|
+
|
386
|
+
if qparam_dtype(inp) == qparam_dtype(node):
|
387
|
+
continue
|
388
|
+
|
389
|
+
if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
|
390
|
+
quantize = _insert_quantize_op_after(node)
|
391
|
+
|
392
|
+
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
393
|
+
node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
|
394
|
+
logger.debug(
|
395
|
+
f"quantize_per_tensor.default is inserted after {node.name}."
|
396
|
+
)
|
397
|
+
else:
|
398
|
+
raise NotYetSupportedError("Unsupported dtype")
|
399
|
+
|
338
400
|
# TODO Support more ops.
|
339
401
|
|
340
402
|
graph.eliminate_dead_code()
|
@@ -122,7 +122,7 @@ class Conv2dVisitor(NodeVisitor):
|
|
122
122
|
|
123
123
|
if is_valid_padding(padding):
|
124
124
|
conv2d_padding_type = VALID
|
125
|
-
elif is_same_padding(padding, input_shape, output_shape):
|
125
|
+
elif is_same_padding(padding, input_shape, output_shape) and stride == [1, 1]:
|
126
126
|
conv2d_padding_type = SAME
|
127
127
|
else:
|
128
128
|
assert isinstance(padding, list) and len(padding) == 2
|
tico/utils/convert.py
CHANGED
@@ -162,6 +162,22 @@ def check_unsupported_target(exported_program: ExportedProgram):
|
|
162
162
|
raise NotYetSupportedError("NOT SUPPORTED OPERATOR IN GRAPH MODULE")
|
163
163
|
|
164
164
|
|
165
|
+
def check_training_ops(exported_program: ExportedProgram):
|
166
|
+
TRAINING_OPS = {
|
167
|
+
torch.ops.aten.dropout.default,
|
168
|
+
torch.ops.aten.native_dropout.default,
|
169
|
+
}
|
170
|
+
found = set()
|
171
|
+
for node in exported_program.graph.nodes:
|
172
|
+
if node.op == "call_function" and node.target in TRAINING_OPS:
|
173
|
+
found.add(node.target)
|
174
|
+
|
175
|
+
if found:
|
176
|
+
raise RuntimeError(
|
177
|
+
f"Detected training-mode ops {found}. Call `model.eval()` before export."
|
178
|
+
)
|
179
|
+
|
180
|
+
|
165
181
|
def convert_exported_module_to_circle(
|
166
182
|
exported_program: ExportedProgram,
|
167
183
|
config: CompileConfigBase = get_default_config(),
|
@@ -258,6 +274,7 @@ def convert_exported_module_to_circle(
|
|
258
274
|
quantize_graph.run(exported_program)
|
259
275
|
|
260
276
|
check_unsupported_target(exported_program)
|
277
|
+
check_training_ops(exported_program)
|
261
278
|
circle_program = build_circle(exported_program)
|
262
279
|
|
263
280
|
return circle_program
|
tico/utils/padding.py
CHANGED
@@ -40,8 +40,8 @@ def is_same_padding(
|
|
40
40
|
if isinstance(padding, list):
|
41
41
|
assert len(padding) == 2, "Padding should be a list of length 2."
|
42
42
|
|
43
|
-
input_HW = input_shape[1:
|
44
|
-
output_HW = output_shape[1:
|
43
|
+
input_HW = tuple(input_shape[1:3]) # N H W C
|
44
|
+
output_HW = tuple(output_shape[1:3]) # N H W C
|
45
45
|
return input_HW == output_HW
|
46
46
|
|
47
47
|
raise InvalidArgumentError("Invalid padding.")
|
@@ -1,4 +1,4 @@
|
|
1
|
-
tico/__init__.py,sha256=
|
1
|
+
tico/__init__.py,sha256=gNG4gqEHE73hnLBf47wz6Gj7RXqTV5DoQ8Bpss9Az84,1743
|
2
2
|
tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
|
3
3
|
tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
|
4
4
|
tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
|
@@ -51,7 +51,7 @@ tico/experimental/quantization/evaluation/executor/circle_executor.py,sha256=eCC
|
|
51
51
|
tico/experimental/quantization/evaluation/executor/triv24_executor.py,sha256=sUoXl6oOO2arAKaNjOBg7HiQja145_Jv6qgY7XtR7A8,5159
|
52
52
|
tico/experimental/quantization/passes/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
53
53
|
tico/experimental/quantization/passes/fold_quant_ops.py,sha256=iaBMyO49CwVkhebMz3rjkHWfWE2LhwH6fORe7n4S6XQ,7040
|
54
|
-
tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py,sha256=
|
54
|
+
tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py,sha256=lQeN6VJNYQyDeucB4KpyyWvIiuhGRq7wjIeaCKdM7ck,15462
|
55
55
|
tico/experimental/quantization/passes/propagate_qparam_backward.py,sha256=TGtyW0Z2qOTgVIasBdGRgbwH31YYd6ek7OvLTmCV614,3118
|
56
56
|
tico/experimental/quantization/passes/propagate_qparam_forward.py,sha256=RhUHGCR2RpBO5KYkQ7Z8U5u7HEwDq2wdKHLKAJCi-5c,5138
|
57
57
|
tico/experimental/quantization/passes/quantize_bias.py,sha256=ZQ3rETYStpW28JUbODRixbq5sDEOiIOB_qWA-Jzuu-Y,4337
|
@@ -115,7 +115,7 @@ tico/serialize/operators/op_cat.py,sha256=XDYOh0XAyrM0TlxVm6Sa0OFFGrKk7aSDcGXC-h
|
|
115
115
|
tico/serialize/operators/op_clamp.py,sha256=ZRAsXLGsZqJEh4wXxESEpRJkRtUuJWTDgAem6lr9_5I,4298
|
116
116
|
tico/serialize/operators/op_clone.py,sha256=vzDYJ8TS3tc2BAyd_z8nt5VqT1inpymSseMEhd9dva0,2394
|
117
117
|
tico/serialize/operators/op_constant_pad_nd.py,sha256=OpP4AP-d1IFcWZolNa-o9ZxzXJQkMdG9WQ66soX3s-E,2675
|
118
|
-
tico/serialize/operators/op_conv2d.py,sha256=
|
118
|
+
tico/serialize/operators/op_conv2d.py,sha256=BmSCunhziD9EhXEkWwFrWkaQ_t3cIhrJJQSRLbgqmxI,7338
|
119
119
|
tico/serialize/operators/op_copy.py,sha256=vaianLQ19-2ZQZ-MdQ07YuOPeFeo_HAx2a0Qfn7I5Kk,6122
|
120
120
|
tico/serialize/operators/op_cos.py,sha256=N12bNyuTQIxRnD0eHRPdFVzRQPMy1NFM4iM8oQ4lYzw,2034
|
121
121
|
tico/serialize/operators/op_cumsum.py,sha256=3fmOf1mIeCX1uhTBcSJmRGXejzLtO8UwaI1eEQDC6nA,3798
|
@@ -180,14 +180,14 @@ tico/serialize/operators/op_view.py,sha256=5EMww-ve17Vm9XPuV03Tn7vJsjpU2J8U4d_FO
|
|
180
180
|
tico/serialize/operators/op_where.py,sha256=doE81GSwygrPBm3JIfN9w7kKXxeIYKxgk0eoY22QIcg,2845
|
181
181
|
tico/serialize/operators/utils.py,sha256=lXGpEJW1h8U_-gfc6EWjvvSiq3yJ9P-v1v3EMRT_pSk,2954
|
182
182
|
tico/utils/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
183
|
-
tico/utils/convert.py,sha256=
|
183
|
+
tico/utils/convert.py,sha256=crjCVDLoGBSSpI0EAq-cr_oL68CcntOBy361XVcaPzU,12444
|
184
184
|
tico/utils/define.py,sha256=Ypgp7YffM4pgPl4Zh6TmogSn1OxGBMRw_e09qYGflZk,1467
|
185
185
|
tico/utils/diff_graph.py,sha256=_eDGGPDPYQD4b--MXX0DLoVgSt_wLfNPt47UlolLLR4,5272
|
186
186
|
tico/utils/errors.py,sha256=f3csJjgbXG9W1aHhqEcou008Aor19W57X8oT5Hx8w1M,954
|
187
187
|
tico/utils/graph.py,sha256=Y6aODsnc_-9l61oanknb7K1jqJ8B35iPypOKkM0Qkk0,9149
|
188
188
|
tico/utils/logging.py,sha256=IlbBWscsaHidI0dNqro1HEXAbIcbkR3BD5ukLy2m95k,1286
|
189
189
|
tico/utils/model.py,sha256=Uqc92AnJXQ2pbvctS2z2F3Ku3yNrwXZ9O33hZVis7is,1250
|
190
|
-
tico/utils/padding.py,sha256=
|
190
|
+
tico/utils/padding.py,sha256=jNMX2KFoZ3c6HTlMU8BAwG3Fyrqpq4F3ytKP13Pg4ps,1498
|
191
191
|
tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
|
192
192
|
tico/utils/register_custom_op.py,sha256=qheG1WqtkUaG1SnHrrKQ7-fE4IZRETApCsfMkjDKcfs,23240
|
193
193
|
tico/utils/serialize.py,sha256=AQXMBOLu-Kg2Rn-qbqsAtHndjZAZIavlKA0QFgJREHM,1420
|
@@ -198,9 +198,9 @@ tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
|
198
198
|
tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
|
199
199
|
tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
|
200
200
|
tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
|
201
|
-
tico-0.1.0.
|
202
|
-
tico-0.1.0.
|
203
|
-
tico-0.1.0.
|
204
|
-
tico-0.1.0.
|
205
|
-
tico-0.1.0.
|
206
|
-
tico-0.1.0.
|
201
|
+
tico-0.1.0.dev250630.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
|
202
|
+
tico-0.1.0.dev250630.dist-info/METADATA,sha256=SM3Z2qgcIkj7qUL2DKzlp4F47pwbr-3dR1ZTL-gtdMc,8846
|
203
|
+
tico-0.1.0.dev250630.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
|
204
|
+
tico-0.1.0.dev250630.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
|
205
|
+
tico-0.1.0.dev250630.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
|
206
|
+
tico-0.1.0.dev250630.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|