tico 0.1.0.dev250717__py3-none-any.whl → 0.1.0.dev250721__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +1 -1
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +303 -287
- tico/serialize/circle_graph.py +18 -0
- tico/serialize/operators/op_softmax.py +7 -14
- tico/utils/convert.py +7 -1
- tico/utils/dtype.py +20 -0
- tico/utils/torch_compat.py +52 -0
- tico/utils/utils.py +0 -4
- {tico-0.1.0.dev250717.dist-info → tico-0.1.0.dev250721.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250717.dist-info → tico-0.1.0.dev250721.dist-info}/RECORD +14 -12
- {tico-0.1.0.dev250717.dist-info → tico-0.1.0.dev250721.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250717.dist-info → tico-0.1.0.dev250721.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250717.dist-info → tico-0.1.0.dev250721.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250717.dist-info → tico-0.1.0.dev250721.dist-info}/top_level.txt +0 -0
tico/__init__.py
CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
|
|
21
21
|
from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
|
22
22
|
|
23
23
|
# THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
|
24
|
-
__version__ = "0.1.0.
|
24
|
+
__version__ = "0.1.0.dev250721"
|
25
25
|
|
26
26
|
MINIMUM_SUPPORTED_VERSION = "2.5.0"
|
27
27
|
SECURE_TORCH_VERSION = "2.6.0"
|
@@ -18,6 +18,9 @@ if TYPE_CHECKING:
|
|
18
18
|
import torch.fx
|
19
19
|
import copy
|
20
20
|
|
21
|
+
from collections import defaultdict
|
22
|
+
from typing import Any, Callable
|
23
|
+
|
21
24
|
import torch
|
22
25
|
from torch.export import ExportedProgram
|
23
26
|
|
@@ -92,6 +95,302 @@ def _u8_to_i16(qparam: QuantParam) -> QuantParam:
|
|
92
95
|
return new_qparam
|
93
96
|
|
94
97
|
|
98
|
+
def _insert_quantize_op_before(node, inp):
|
99
|
+
graph = node.graph
|
100
|
+
qparam: QuantParam = node.meta[QPARAM_KEY]
|
101
|
+
assert qparam.scale is not None
|
102
|
+
assert qparam.zero_point is not None
|
103
|
+
scale = qparam.scale[0]
|
104
|
+
zerop = qparam.zero_point[0]
|
105
|
+
min_, max_ = quant_min_max(qparam.dtype)
|
106
|
+
dtype = getattr(torch, qparam.dtype)
|
107
|
+
|
108
|
+
with graph.inserting_before(node):
|
109
|
+
q_args = (inp, scale, zerop, min_, max_, dtype)
|
110
|
+
quantize = create_node(
|
111
|
+
graph,
|
112
|
+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
113
|
+
args=q_args,
|
114
|
+
origin=node,
|
115
|
+
)
|
116
|
+
quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
|
117
|
+
set_new_meta_val(quantize)
|
118
|
+
|
119
|
+
node.replace_input_with(inp, quantize)
|
120
|
+
|
121
|
+
return quantize
|
122
|
+
|
123
|
+
|
124
|
+
def _insert_quantize_op_after(node):
|
125
|
+
graph = node.graph
|
126
|
+
qparam: QuantParam = node.meta[QPARAM_KEY]
|
127
|
+
assert qparam.scale is not None
|
128
|
+
assert qparam.zero_point is not None
|
129
|
+
scale = qparam.scale[0]
|
130
|
+
zerop = qparam.zero_point[0]
|
131
|
+
min_, max_ = quant_min_max(qparam.dtype)
|
132
|
+
dtype = getattr(torch, qparam.dtype)
|
133
|
+
with graph.inserting_after(node):
|
134
|
+
q_args = (node, scale, zerop, min_, max_, dtype)
|
135
|
+
quantize = create_node(
|
136
|
+
graph,
|
137
|
+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
138
|
+
args=q_args,
|
139
|
+
)
|
140
|
+
|
141
|
+
node.replace_all_uses_with(quantize, propagate_meta=True)
|
142
|
+
quantize.replace_input_with(quantize, node)
|
143
|
+
|
144
|
+
quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
|
145
|
+
|
146
|
+
return quantize
|
147
|
+
|
148
|
+
|
149
|
+
def _linear_handler(node, logger):
|
150
|
+
lin_args = LinearArgs(*node.args, **node.kwargs)
|
151
|
+
inp = lin_args.input
|
152
|
+
|
153
|
+
if QPARAM_KEY not in inp.meta:
|
154
|
+
return
|
155
|
+
|
156
|
+
if QPARAM_KEY not in node.meta:
|
157
|
+
return
|
158
|
+
|
159
|
+
if qparam_dtype(inp) == qparam_dtype(node):
|
160
|
+
return
|
161
|
+
|
162
|
+
if qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
|
163
|
+
quantize = _insert_quantize_op_after(node)
|
164
|
+
|
165
|
+
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
166
|
+
|
167
|
+
# Update node's qparam from i16 to u8
|
168
|
+
# NOTE This would severely degrade accuracy. It is
|
169
|
+
# important to mitigate this accuracy drop in backend.
|
170
|
+
node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
|
171
|
+
logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
|
172
|
+
else:
|
173
|
+
raise NotYetSupportedError(
|
174
|
+
f"Unsupported dtype: From {qparam_dtype(inp)} to {qparam_dtype(node)}"
|
175
|
+
)
|
176
|
+
|
177
|
+
|
178
|
+
def _add_handler(node, logger):
|
179
|
+
add_args = AddTensorArgs(*node.args, **node.kwargs)
|
180
|
+
x = add_args.input
|
181
|
+
y = add_args.other
|
182
|
+
|
183
|
+
if not isinstance(x, torch.fx.Node):
|
184
|
+
return
|
185
|
+
if not isinstance(y, torch.fx.Node):
|
186
|
+
return
|
187
|
+
|
188
|
+
if QPARAM_KEY not in x.meta:
|
189
|
+
return
|
190
|
+
if QPARAM_KEY not in y.meta:
|
191
|
+
return
|
192
|
+
if QPARAM_KEY not in node.meta:
|
193
|
+
return
|
194
|
+
|
195
|
+
if qparam_dtype(x) == qparam_dtype(node):
|
196
|
+
return
|
197
|
+
|
198
|
+
if qparam_dtype(x) != qparam_dtype(y):
|
199
|
+
return
|
200
|
+
|
201
|
+
if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
|
202
|
+
quantize = _insert_quantize_op_after(node)
|
203
|
+
|
204
|
+
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
205
|
+
node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
|
206
|
+
logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
|
207
|
+
else:
|
208
|
+
raise NotYetSupportedError("Unsupported dtype")
|
209
|
+
|
210
|
+
|
211
|
+
def _mul_handler(node, logger):
|
212
|
+
mul_args = MulTensorArgs(*node.args, **node.kwargs)
|
213
|
+
x = mul_args.input
|
214
|
+
y = mul_args.other
|
215
|
+
|
216
|
+
if not isinstance(x, torch.fx.Node):
|
217
|
+
return
|
218
|
+
if not isinstance(y, torch.fx.Node):
|
219
|
+
return
|
220
|
+
|
221
|
+
if QPARAM_KEY not in x.meta:
|
222
|
+
return
|
223
|
+
if QPARAM_KEY not in y.meta:
|
224
|
+
return
|
225
|
+
if QPARAM_KEY not in node.meta:
|
226
|
+
return
|
227
|
+
|
228
|
+
if qparam_dtype(x) == qparam_dtype(node):
|
229
|
+
return
|
230
|
+
|
231
|
+
if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
|
232
|
+
quantize = _insert_quantize_op_after(node)
|
233
|
+
|
234
|
+
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
235
|
+
node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
|
236
|
+
logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
|
237
|
+
else:
|
238
|
+
raise NotYetSupportedError("Unsupported dtype")
|
239
|
+
|
240
|
+
|
241
|
+
def _cat_handler(node, logger):
|
242
|
+
cat_args = CatArgs(*node.args, **node.kwargs)
|
243
|
+
tensors = cat_args.tensors
|
244
|
+
|
245
|
+
if any(QPARAM_KEY not in x.meta for x in tensors):
|
246
|
+
return
|
247
|
+
|
248
|
+
if QPARAM_KEY not in node.meta:
|
249
|
+
return
|
250
|
+
|
251
|
+
assert len(tensors) > 0
|
252
|
+
in_dtype = qparam_dtype(tensors[0])
|
253
|
+
if in_dtype == qparam_dtype(node):
|
254
|
+
return
|
255
|
+
|
256
|
+
if any(qparam_dtype(x) != in_dtype for x in tensors):
|
257
|
+
return
|
258
|
+
|
259
|
+
if in_dtype == "int16" and qparam_dtype(node) == "uint8":
|
260
|
+
quantize = _insert_quantize_op_after(node)
|
261
|
+
|
262
|
+
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
263
|
+
node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
|
264
|
+
logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
|
265
|
+
else:
|
266
|
+
raise NotYetSupportedError("Unsupported dtype")
|
267
|
+
|
268
|
+
|
269
|
+
def _bmm_handler(node, logger):
|
270
|
+
bmm_args = BmmArgs(*node.args, **node.kwargs)
|
271
|
+
x = bmm_args.input
|
272
|
+
y = bmm_args.mat2
|
273
|
+
|
274
|
+
if QPARAM_KEY not in x.meta:
|
275
|
+
return
|
276
|
+
if QPARAM_KEY not in y.meta:
|
277
|
+
return
|
278
|
+
if QPARAM_KEY not in node.meta:
|
279
|
+
return
|
280
|
+
|
281
|
+
if qparam_dtype(x) == qparam_dtype(node):
|
282
|
+
return
|
283
|
+
|
284
|
+
if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
|
285
|
+
quantize = _insert_quantize_op_after(node)
|
286
|
+
|
287
|
+
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
288
|
+
node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
|
289
|
+
logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
|
290
|
+
elif qparam_dtype(x) == "uint8" and qparam_dtype(node) == "int16":
|
291
|
+
quantize = _insert_quantize_op_after(node)
|
292
|
+
|
293
|
+
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
294
|
+
node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
|
295
|
+
logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
|
296
|
+
else:
|
297
|
+
raise NotYetSupportedError("Unsupported dtype")
|
298
|
+
|
299
|
+
|
300
|
+
def _permute_handler(node, logger):
|
301
|
+
per_args = PermuteArgs(*node.args, **node.kwargs)
|
302
|
+
inp = per_args.input
|
303
|
+
|
304
|
+
if QPARAM_KEY not in inp.meta:
|
305
|
+
return
|
306
|
+
|
307
|
+
if QPARAM_KEY not in node.meta:
|
308
|
+
return
|
309
|
+
|
310
|
+
if qparam_dtype(inp) == qparam_dtype(node):
|
311
|
+
return
|
312
|
+
|
313
|
+
if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
|
314
|
+
# A new Quantize Op (s16 to u8) is inserted before (not after)
|
315
|
+
# permute Op to reduce tensor size ealier
|
316
|
+
quantize = _insert_quantize_op_before(node, inp)
|
317
|
+
|
318
|
+
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
319
|
+
logger.debug(f"quantize_per_tensor.default is inserted before {node.name}.")
|
320
|
+
elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
|
321
|
+
quantize = _insert_quantize_op_after(node)
|
322
|
+
|
323
|
+
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
324
|
+
node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
|
325
|
+
logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
|
326
|
+
else:
|
327
|
+
raise NotYetSupportedError("Unsupported dtype")
|
328
|
+
|
329
|
+
|
330
|
+
def _reshape_handler(node, logger):
|
331
|
+
reshape_args = ReshapeArgs(*node.args, **node.kwargs)
|
332
|
+
inp = reshape_args.input
|
333
|
+
|
334
|
+
if QPARAM_KEY not in inp.meta:
|
335
|
+
return
|
336
|
+
|
337
|
+
if QPARAM_KEY not in node.meta:
|
338
|
+
return
|
339
|
+
|
340
|
+
if qparam_dtype(inp) == qparam_dtype(node):
|
341
|
+
return
|
342
|
+
|
343
|
+
if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
|
344
|
+
# A new Quantize Op (s16 to u8) is inserted before (not after)
|
345
|
+
# reshape Op to reduce tensor size ealier
|
346
|
+
quantize = _insert_quantize_op_before(node, inp)
|
347
|
+
|
348
|
+
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
349
|
+
logger.debug(f"quantize_per_tensor.default is inserted before {node.name}.")
|
350
|
+
elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
|
351
|
+
quantize = _insert_quantize_op_after(node)
|
352
|
+
|
353
|
+
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
354
|
+
node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
|
355
|
+
logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
|
356
|
+
else:
|
357
|
+
raise NotYetSupportedError("Unsupported dtype")
|
358
|
+
|
359
|
+
|
360
|
+
def _relu_handler(node, logger):
|
361
|
+
relu_args = ReluArgs(*node.args, **node.kwargs)
|
362
|
+
inp = relu_args.input
|
363
|
+
|
364
|
+
if QPARAM_KEY not in inp.meta:
|
365
|
+
return
|
366
|
+
|
367
|
+
if QPARAM_KEY not in node.meta:
|
368
|
+
return
|
369
|
+
|
370
|
+
if qparam_dtype(inp) == qparam_dtype(node):
|
371
|
+
return
|
372
|
+
|
373
|
+
if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
|
374
|
+
quantize = _insert_quantize_op_after(node)
|
375
|
+
|
376
|
+
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
377
|
+
node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
|
378
|
+
logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
|
379
|
+
else:
|
380
|
+
raise NotYetSupportedError("Unsupported dtype")
|
381
|
+
|
382
|
+
|
383
|
+
_op_handler: defaultdict[Any, Any | None] = defaultdict(lambda: None)
|
384
|
+
_op_handler[torch.ops.aten.linear.default] = _linear_handler
|
385
|
+
_op_handler[torch.ops.aten.add.Tensor] = _add_handler
|
386
|
+
_op_handler[torch.ops.aten.mul.Tensor] = _mul_handler
|
387
|
+
_op_handler[torch.ops.aten.cat.default] = _cat_handler
|
388
|
+
_op_handler[torch.ops.aten.bmm.default] = _bmm_handler
|
389
|
+
_op_handler[torch.ops.aten.permute.default] = _permute_handler
|
390
|
+
_op_handler[torch.ops.aten.reshape.default] = _reshape_handler
|
391
|
+
_op_handler[torch.ops.aten.relu.default] = _relu_handler
|
392
|
+
|
393
|
+
|
95
394
|
@trace_graph_diff_on_pass
|
96
395
|
class InsertQuantizeOnDtypeMismatch(PassBase):
|
97
396
|
"""
|
@@ -138,296 +437,13 @@ class InsertQuantizeOnDtypeMismatch(PassBase):
|
|
138
437
|
graph_module = exported_program.graph_module
|
139
438
|
graph: torch.fx.Graph = graph_module.graph
|
140
439
|
|
141
|
-
def _insert_quantize_op_before(node, inp):
|
142
|
-
qparam: QuantParam = node.meta[QPARAM_KEY]
|
143
|
-
assert qparam.scale is not None
|
144
|
-
assert qparam.zero_point is not None
|
145
|
-
scale = qparam.scale[0]
|
146
|
-
zerop = qparam.zero_point[0]
|
147
|
-
min_, max_ = quant_min_max(qparam.dtype)
|
148
|
-
dtype = getattr(torch, qparam.dtype)
|
149
|
-
|
150
|
-
with graph.inserting_before(node):
|
151
|
-
q_args = (inp, scale, zerop, min_, max_, dtype)
|
152
|
-
quantize = create_node(
|
153
|
-
graph,
|
154
|
-
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
155
|
-
args=q_args,
|
156
|
-
origin=node,
|
157
|
-
)
|
158
|
-
quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
|
159
|
-
set_new_meta_val(quantize)
|
160
|
-
|
161
|
-
node.replace_input_with(inp, quantize)
|
162
|
-
|
163
|
-
return quantize
|
164
|
-
|
165
|
-
def _insert_quantize_op_after(node):
|
166
|
-
qparam: QuantParam = node.meta[QPARAM_KEY]
|
167
|
-
assert qparam.scale is not None
|
168
|
-
assert qparam.zero_point is not None
|
169
|
-
scale = qparam.scale[0]
|
170
|
-
zerop = qparam.zero_point[0]
|
171
|
-
min_, max_ = quant_min_max(qparam.dtype)
|
172
|
-
dtype = getattr(torch, qparam.dtype)
|
173
|
-
with graph.inserting_after(node):
|
174
|
-
q_args = (node, scale, zerop, min_, max_, dtype)
|
175
|
-
quantize = create_node(
|
176
|
-
graph,
|
177
|
-
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
178
|
-
args=q_args,
|
179
|
-
)
|
180
|
-
|
181
|
-
node.replace_all_uses_with(quantize, propagate_meta=True)
|
182
|
-
quantize.replace_input_with(quantize, node)
|
183
|
-
|
184
|
-
quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
|
185
|
-
|
186
|
-
return quantize
|
187
|
-
|
188
440
|
for node in graph.nodes:
|
189
441
|
if node.op != "call_function":
|
190
442
|
continue
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
if QPARAM_KEY not in inp.meta:
|
196
|
-
continue
|
197
|
-
|
198
|
-
if QPARAM_KEY not in node.meta:
|
199
|
-
continue
|
200
|
-
|
201
|
-
if qparam_dtype(inp) == qparam_dtype(node):
|
202
|
-
continue
|
203
|
-
|
204
|
-
if qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
|
205
|
-
quantize = _insert_quantize_op_after(node)
|
206
|
-
|
207
|
-
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
208
|
-
|
209
|
-
# Update node's qparam from i16 to u8
|
210
|
-
# NOTE This would severely degrade accuracy. It is
|
211
|
-
# important to mitigate this accuracy drop in backend.
|
212
|
-
node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
|
213
|
-
logger.debug(
|
214
|
-
f"quantize_per_tensor.default is inserted after {node.name}."
|
215
|
-
)
|
216
|
-
else:
|
217
|
-
raise NotYetSupportedError(
|
218
|
-
f"Unsupported dtype: From {qparam_dtype(inp)} to {qparam_dtype(node)}"
|
219
|
-
)
|
220
|
-
|
221
|
-
elif node.target == torch.ops.aten.add.Tensor:
|
222
|
-
add_args = AddTensorArgs(*node.args, **node.kwargs)
|
223
|
-
x = add_args.input
|
224
|
-
y = add_args.other
|
225
|
-
|
226
|
-
if not isinstance(x, torch.fx.Node):
|
227
|
-
continue
|
228
|
-
if not isinstance(y, torch.fx.Node):
|
229
|
-
continue
|
230
|
-
|
231
|
-
if QPARAM_KEY not in x.meta:
|
232
|
-
continue
|
233
|
-
if QPARAM_KEY not in y.meta:
|
234
|
-
continue
|
235
|
-
if QPARAM_KEY not in node.meta:
|
236
|
-
continue
|
237
|
-
|
238
|
-
if qparam_dtype(x) == qparam_dtype(node):
|
239
|
-
continue
|
240
|
-
|
241
|
-
if qparam_dtype(x) != qparam_dtype(y):
|
242
|
-
continue
|
243
|
-
|
244
|
-
if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
|
245
|
-
quantize = _insert_quantize_op_after(node)
|
246
|
-
|
247
|
-
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
248
|
-
node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
|
249
|
-
logger.debug(
|
250
|
-
f"quantize_per_tensor.default is inserted after {node.name}."
|
251
|
-
)
|
252
|
-
else:
|
253
|
-
raise NotYetSupportedError("Unsupported dtype")
|
254
|
-
|
255
|
-
elif node.target == torch.ops.aten.mul.Tensor:
|
256
|
-
mul_args = MulTensorArgs(*node.args, **node.kwargs)
|
257
|
-
x = mul_args.input
|
258
|
-
y = mul_args.other
|
259
|
-
|
260
|
-
if not isinstance(x, torch.fx.Node):
|
261
|
-
continue
|
262
|
-
if not isinstance(y, torch.fx.Node):
|
263
|
-
continue
|
264
|
-
|
265
|
-
if QPARAM_KEY not in x.meta:
|
266
|
-
continue
|
267
|
-
if QPARAM_KEY not in y.meta:
|
268
|
-
continue
|
269
|
-
if QPARAM_KEY not in node.meta:
|
270
|
-
continue
|
271
|
-
|
272
|
-
if qparam_dtype(x) == qparam_dtype(node):
|
273
|
-
continue
|
274
|
-
|
275
|
-
if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
|
276
|
-
quantize = _insert_quantize_op_after(node)
|
277
|
-
|
278
|
-
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
279
|
-
node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
|
280
|
-
logger.debug(
|
281
|
-
f"quantize_per_tensor.default is inserted after {node.name}."
|
282
|
-
)
|
283
|
-
else:
|
284
|
-
raise NotYetSupportedError("Unsupported dtype")
|
285
|
-
|
286
|
-
elif node.target == torch.ops.aten.cat.default:
|
287
|
-
cat_args = CatArgs(*node.args, **node.kwargs)
|
288
|
-
tensors = cat_args.tensors
|
289
|
-
|
290
|
-
if any(QPARAM_KEY not in x.meta for x in tensors):
|
291
|
-
continue
|
292
|
-
|
293
|
-
if QPARAM_KEY not in node.meta:
|
294
|
-
continue
|
295
|
-
|
296
|
-
assert len(tensors) > 0
|
297
|
-
in_dtype = qparam_dtype(tensors[0])
|
298
|
-
if in_dtype == qparam_dtype(node):
|
299
|
-
continue
|
300
|
-
|
301
|
-
if any(qparam_dtype(x) != in_dtype for x in tensors):
|
302
|
-
continue
|
303
|
-
|
304
|
-
if in_dtype == "int16" and qparam_dtype(node) == "uint8":
|
305
|
-
quantize = _insert_quantize_op_after(node)
|
306
|
-
|
307
|
-
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
308
|
-
node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
|
309
|
-
logger.debug(
|
310
|
-
f"quantize_per_tensor.default is inserted after {node.name}."
|
311
|
-
)
|
312
|
-
else:
|
313
|
-
raise NotYetSupportedError("Unsupported dtype")
|
314
|
-
|
315
|
-
elif node.target == torch.ops.aten.bmm.default:
|
316
|
-
bmm_args = BmmArgs(*node.args, **node.kwargs)
|
317
|
-
x = bmm_args.input
|
318
|
-
y = bmm_args.mat2
|
319
|
-
|
320
|
-
if QPARAM_KEY not in x.meta:
|
321
|
-
continue
|
322
|
-
if QPARAM_KEY not in y.meta:
|
323
|
-
continue
|
324
|
-
if QPARAM_KEY not in node.meta:
|
325
|
-
continue
|
326
|
-
|
327
|
-
if qparam_dtype(x) == qparam_dtype(node):
|
328
|
-
continue
|
329
|
-
|
330
|
-
if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
|
331
|
-
quantize = _insert_quantize_op_after(node)
|
332
|
-
|
333
|
-
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
334
|
-
node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
|
335
|
-
logger.debug(
|
336
|
-
f"quantize_per_tensor.default is inserted after {node.name}."
|
337
|
-
)
|
338
|
-
else:
|
339
|
-
raise NotYetSupportedError("Unsupported dtype")
|
340
|
-
|
341
|
-
elif node.target == torch.ops.aten.permute.default:
|
342
|
-
per_args = PermuteArgs(*node.args, **node.kwargs)
|
343
|
-
inp = per_args.input
|
344
|
-
|
345
|
-
if QPARAM_KEY not in inp.meta:
|
346
|
-
continue
|
347
|
-
|
348
|
-
if QPARAM_KEY not in node.meta:
|
349
|
-
continue
|
350
|
-
|
351
|
-
if qparam_dtype(inp) == qparam_dtype(node):
|
352
|
-
continue
|
353
|
-
|
354
|
-
if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
|
355
|
-
# A new Quantize Op (s16 to u8) is inserted before (not after)
|
356
|
-
# permute Op to reduce tensor size ealier
|
357
|
-
quantize = _insert_quantize_op_before(node, inp)
|
358
|
-
|
359
|
-
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
360
|
-
logger.debug(
|
361
|
-
f"quantize_per_tensor.default is inserted before {node.name}."
|
362
|
-
)
|
363
|
-
elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
|
364
|
-
quantize = _insert_quantize_op_after(node)
|
365
|
-
|
366
|
-
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
367
|
-
node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
|
368
|
-
logger.debug(
|
369
|
-
f"quantize_per_tensor.default is inserted after {node.name}."
|
370
|
-
)
|
371
|
-
else:
|
372
|
-
raise NotYetSupportedError("Unsupported dtype")
|
373
|
-
elif node.target == torch.ops.aten.reshape.default:
|
374
|
-
reshape_args = ReshapeArgs(*node.args, **node.kwargs)
|
375
|
-
inp = reshape_args.input
|
376
|
-
|
377
|
-
if QPARAM_KEY not in inp.meta:
|
378
|
-
continue
|
379
|
-
|
380
|
-
if QPARAM_KEY not in node.meta:
|
381
|
-
continue
|
382
|
-
|
383
|
-
if qparam_dtype(inp) == qparam_dtype(node):
|
384
|
-
continue
|
385
|
-
|
386
|
-
if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
|
387
|
-
# A new Quantize Op (s16 to u8) is inserted before (not after)
|
388
|
-
# reshape Op to reduce tensor size ealier
|
389
|
-
quantize = _insert_quantize_op_before(node, inp)
|
390
|
-
|
391
|
-
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
392
|
-
logger.debug(
|
393
|
-
f"quantize_per_tensor.default is inserted before {node.name}."
|
394
|
-
)
|
395
|
-
elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
|
396
|
-
quantize = _insert_quantize_op_after(node)
|
397
|
-
|
398
|
-
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
399
|
-
node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
|
400
|
-
logger.debug(
|
401
|
-
f"quantize_per_tensor.default is inserted after {node.name}."
|
402
|
-
)
|
403
|
-
else:
|
404
|
-
raise NotYetSupportedError("Unsupported dtype")
|
405
|
-
|
406
|
-
elif node.target == torch.ops.aten.relu.default:
|
407
|
-
relu_args = ReluArgs(*node.args, **node.kwargs)
|
408
|
-
inp = relu_args.input
|
409
|
-
|
410
|
-
if QPARAM_KEY not in inp.meta:
|
411
|
-
continue
|
412
|
-
|
413
|
-
if QPARAM_KEY not in node.meta:
|
414
|
-
continue
|
415
|
-
|
416
|
-
if qparam_dtype(inp) == qparam_dtype(node):
|
417
|
-
continue
|
418
|
-
|
419
|
-
if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
|
420
|
-
quantize = _insert_quantize_op_after(node)
|
421
|
-
|
422
|
-
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
|
423
|
-
node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
|
424
|
-
logger.debug(
|
425
|
-
f"quantize_per_tensor.default is inserted after {node.name}."
|
426
|
-
)
|
427
|
-
else:
|
428
|
-
raise NotYetSupportedError("Unsupported dtype")
|
429
|
-
|
430
|
-
# TODO Support more ops.
|
443
|
+
|
444
|
+
handler = _op_handler[node.target]
|
445
|
+
if handler is not None:
|
446
|
+
handler(node, logger)
|
431
447
|
|
432
448
|
graph.eliminate_dead_code()
|
433
449
|
graph.lint()
|
tico/serialize/circle_graph.py
CHANGED
@@ -152,6 +152,15 @@ class CircleSubgraph(circle.SubGraph.SubGraphT):
|
|
152
152
|
assert node.meta.get("val") is not None
|
153
153
|
tensor.type = extract_circle_dtype(node)
|
154
154
|
tensor.shape = list(extract_shape(node))
|
155
|
+
|
156
|
+
# Handle dynamic shape
|
157
|
+
if any(isinstance(s, torch.SymInt) for s in tensor.shape):
|
158
|
+
tensor.shapeSignature = tensor.shape.copy()
|
159
|
+
for idx, s in enumerate(tensor.shape):
|
160
|
+
if isinstance(s, torch.SymInt):
|
161
|
+
tensor.shape[idx] = 1
|
162
|
+
tensor.shapeSignature[idx] = -1
|
163
|
+
|
155
164
|
if QPARAM_KEY in node.meta:
|
156
165
|
tensor.quantization = to_circle_qparam(node.meta[QPARAM_KEY])
|
157
166
|
tensor.type = str_to_circle_dtype(node.meta[QPARAM_KEY].dtype)
|
@@ -241,6 +250,15 @@ class CircleSubgraph(circle.SubGraph.SubGraphT):
|
|
241
250
|
if source_node is not None:
|
242
251
|
self.name_to_node[tensor.name] = source_node
|
243
252
|
tensor.shape = shape
|
253
|
+
|
254
|
+
# Handle dynamic shape
|
255
|
+
if any(isinstance(s, torch.SymInt) for s in tensor.shape):
|
256
|
+
tensor.shapeSignature = tensor.shape.copy()
|
257
|
+
for idx, s in enumerate(tensor.shape):
|
258
|
+
if isinstance(s, torch.SymInt):
|
259
|
+
tensor.shape[idx] = 1
|
260
|
+
tensor.shapeSignature[idx] = -1
|
261
|
+
|
244
262
|
if qparam is not None:
|
245
263
|
tensor.quantization = to_circle_qparam(qparam)
|
246
264
|
tensor.type = str_to_circle_dtype(qparam.dtype)
|
@@ -24,25 +24,18 @@ from tico.serialize.operators.hashable_opcode import OpCode
|
|
24
24
|
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
25
25
|
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
26
26
|
from tico.utils.errors import NotYetSupportedError
|
27
|
-
from tico.utils.utils import HAS_TORCH_OVER_25
|
28
27
|
from tico.utils.validate_args_kwargs import SafeSoftmaxArgs, SoftmaxArgs
|
29
28
|
|
30
29
|
|
31
30
|
@register_node_visitor
|
32
31
|
class SoftMaxVisitor(NodeVisitor):
|
33
|
-
target: List[torch._ops.OpOverload] =
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
]
|
41
|
-
if HAS_TORCH_OVER_25
|
42
|
-
else [
|
43
|
-
torch.ops.aten._softmax.default,
|
44
|
-
]
|
45
|
-
)
|
32
|
+
target: List[torch._ops.OpOverload] = [
|
33
|
+
torch.ops.aten._softmax.default,
|
34
|
+
# NOTE: Let's treat _safe_softmax as normal _softmax as its usage is for training.
|
35
|
+
# In order for optimization during inference, it can be replaced to softmax.
|
36
|
+
# ref: https://github.com/pytorch/pytorch/pull/133882
|
37
|
+
torch.ops.aten._safe_softmax.default,
|
38
|
+
]
|
46
39
|
|
47
40
|
def __init__(self, op_codes: Dict[OpCode, int], graph):
|
48
41
|
super().__init__(op_codes, graph)
|
tico/utils/convert.py
CHANGED
@@ -106,6 +106,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
|
|
106
106
|
torch.ops.aten._safe_softmax.default,
|
107
107
|
torch.ops.aten.relu6.default, # Do not decompose to hardtanh
|
108
108
|
torch.ops.aten.linear.default,
|
109
|
+
torch.ops.aten.upsample_nearest2d.vec,
|
109
110
|
)
|
110
111
|
ep = ep.run_decompositions(_preserve_ops=_preserve_ops)
|
111
112
|
|
@@ -124,6 +125,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
|
|
124
125
|
torch.ops.aten.relu6.default, # Do not decompose to hardtanh
|
125
126
|
torch.ops.aten.prelu.default,
|
126
127
|
torch.ops.aten.linear.default,
|
128
|
+
torch.ops.aten.upsample_nearest2d.vec,
|
127
129
|
)
|
128
130
|
for op in _preserve_ops:
|
129
131
|
if op in _decomp_table:
|
@@ -138,6 +140,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
|
|
138
140
|
torch.__version__.startswith("2.6")
|
139
141
|
or torch.__version__.startswith("2.7")
|
140
142
|
or torch.__version__.startswith("2.8")
|
143
|
+
or torch.__version__.startswith("2.9")
|
141
144
|
):
|
142
145
|
return run_decompositions(exported_program)
|
143
146
|
else:
|
@@ -293,6 +296,7 @@ def convert(
|
|
293
296
|
mod: torch.nn.Module,
|
294
297
|
args: Tuple[Any, ...],
|
295
298
|
kwargs: Optional[Dict[str, Any]] = None,
|
299
|
+
dynamic_shapes: Optional[dict] = None,
|
296
300
|
strict: bool = True,
|
297
301
|
config: CompileConfigBase = get_default_config(),
|
298
302
|
) -> CircleModel:
|
@@ -303,7 +307,9 @@ def convert(
|
|
303
307
|
)
|
304
308
|
|
305
309
|
with torch.no_grad():
|
306
|
-
exported_program = export(
|
310
|
+
exported_program = export(
|
311
|
+
mod, args, kwargs, dynamic_shapes=dynamic_shapes, strict=strict
|
312
|
+
)
|
307
313
|
|
308
314
|
circle_binary = convert_exported_module_to_circle(exported_program, config=config)
|
309
315
|
|
tico/utils/dtype.py
ADDED
@@ -0,0 +1,20 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import torch
|
3
|
+
|
4
|
+
NUMPY_TO_TORCH_DTYPE_DICT = {
|
5
|
+
np.dtype("float32"): torch.float32,
|
6
|
+
np.dtype("float64"): torch.float64,
|
7
|
+
np.dtype("float16"): torch.float16,
|
8
|
+
np.dtype("complex64"): torch.complex64,
|
9
|
+
np.dtype("complex128"): torch.complex128,
|
10
|
+
np.dtype("int64"): torch.int64,
|
11
|
+
np.dtype("int32"): torch.int32,
|
12
|
+
np.dtype("int16"): torch.int16,
|
13
|
+
np.dtype("int8"): torch.int8,
|
14
|
+
np.dtype("uint8"): torch.uint8,
|
15
|
+
np.dtype("bool"): torch.bool,
|
16
|
+
}
|
17
|
+
|
18
|
+
|
19
|
+
def numpy_dtype_to_torch_dtype(np_dtype: np.dtype) -> torch.dtype:
|
20
|
+
return NUMPY_TO_TORCH_DTYPE_DICT[np_dtype]
|
@@ -0,0 +1,52 @@
|
|
1
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""
|
16
|
+
Runtime **capability-detection helpers** for the `torch.export` stack.
|
17
|
+
|
18
|
+
Instead of sprinkling version checks like `torch.__version__ >= "2.9"` throughout
|
19
|
+
the codebase, import these helpers once and branch on the feature you need.
|
20
|
+
|
21
|
+
Each probe executes only **once per process** thanks to `functools.lru_cache`,
|
22
|
+
so the overhead is negligible.
|
23
|
+
"""
|
24
|
+
|
25
|
+
import functools
|
26
|
+
|
27
|
+
import torch
|
28
|
+
|
29
|
+
|
30
|
+
@functools.lru_cache(maxsize=None)
|
31
|
+
def export_produces_slice() -> bool:
|
32
|
+
"""
|
33
|
+
Compile a minimal model with `torch.export.export` and inspect its FX graph
|
34
|
+
to see whether an `aten.slice.Tensor` node appears.
|
35
|
+
|
36
|
+
Returns
|
37
|
+
-------
|
38
|
+
bool
|
39
|
+
* ``True`` — downstream passes should expect redundant **slice** nodes.
|
40
|
+
* ``False`` — downstream passes should expect only a **select** node.
|
41
|
+
"""
|
42
|
+
|
43
|
+
class _Probe(torch.nn.Module):
|
44
|
+
def forward(self, x): # simple slice: keep all dims except 3rd
|
45
|
+
return x[:, :, 1]
|
46
|
+
|
47
|
+
def get_example_inputs(self):
|
48
|
+
return (torch.randn(1, 4, 4),)
|
49
|
+
|
50
|
+
m = _Probe()
|
51
|
+
ep = torch.export.export(m, m.get_example_inputs())
|
52
|
+
return any(n.target == torch.ops.aten.slice.Tensor for n in ep.graph.nodes)
|
tico/utils/utils.py
CHANGED
@@ -29,10 +29,6 @@ from torch.utils import _pytree as pytree
|
|
29
29
|
from tico.serialize.quant_param import QuantParam
|
30
30
|
|
31
31
|
|
32
|
-
HAS_TORCH_OVER_25 = Version(torch.__version__) >= Version("2.5.0")
|
33
|
-
HAS_TORCH_OVER_28_DEV = Version(torch.__version__) >= Version("2.8.0.dev")
|
34
|
-
|
35
|
-
|
36
32
|
def get_fake_mode(exported_program: ExportedProgram):
|
37
33
|
fake_mode = detect_fake_mode(
|
38
34
|
tuple(
|
@@ -1,4 +1,4 @@
|
|
1
|
-
tico/__init__.py,sha256=
|
1
|
+
tico/__init__.py,sha256=QY54qph93oHLtUUYMg8T_e_Bpn4Zmxfif9POfAGssHI,1743
|
2
2
|
tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
|
3
3
|
tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
|
4
4
|
tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
|
@@ -51,7 +51,7 @@ tico/experimental/quantization/evaluation/executor/circle_executor.py,sha256=eCC
|
|
51
51
|
tico/experimental/quantization/evaluation/executor/triv24_executor.py,sha256=sUoXl6oOO2arAKaNjOBg7HiQja145_Jv6qgY7XtR7A8,5159
|
52
52
|
tico/experimental/quantization/passes/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
53
53
|
tico/experimental/quantization/passes/fold_quant_ops.py,sha256=iaBMyO49CwVkhebMz3rjkHWfWE2LhwH6fORe7n4S6XQ,7040
|
54
|
-
tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py,sha256=
|
54
|
+
tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py,sha256=t3bnNY9Abm8CZfSWzsbvx59luylXVxqmUvCKPBVPAIE,14731
|
55
55
|
tico/experimental/quantization/passes/propagate_qparam_backward.py,sha256=TGtyW0Z2qOTgVIasBdGRgbwH31YYd6ek7OvLTmCV614,3118
|
56
56
|
tico/experimental/quantization/passes/propagate_qparam_forward.py,sha256=RhUHGCR2RpBO5KYkQ7Z8U5u7HEwDq2wdKHLKAJCi-5c,5138
|
57
57
|
tico/experimental/quantization/passes/quantize_bias.py,sha256=ZQ3rETYStpW28JUbODRixbq5sDEOiIOB_qWA-Jzuu-Y,4337
|
@@ -96,7 +96,7 @@ tico/passes/remove_redundant_to_copy.py,sha256=tKy4XKkO2l33fMxVPQ_iFkUeFvP15kbPv
|
|
96
96
|
tico/passes/restore_linear.py,sha256=xGJdNb-1CrkOKS9BnLbcblkZc6P2vVjKIi-7lRcs7Bk,4111
|
97
97
|
tico/passes/segment_index_select.py,sha256=jn0M2sdUcDyjrvxfvM40wt5644iPQMY_ud0uvptXN84,5187
|
98
98
|
tico/serialize/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
99
|
-
tico/serialize/circle_graph.py,sha256=
|
99
|
+
tico/serialize/circle_graph.py,sha256=3t78g5eKzhHKvIBJqQ-CcwbqoV-2QwAdd_8wm4W1yXw,12317
|
100
100
|
tico/serialize/circle_mapping.py,sha256=C9C3ORACQOdvBdnt5KRzlT8zao_TvzQklIxH794OhP0,5719
|
101
101
|
tico/serialize/circle_serializer.py,sha256=KRx_Azx2Je9XNYe-pZuuiSMvbXEddd8M8qDATIt7XXk,8981
|
102
102
|
tico/serialize/pack.py,sha256=5HZ9kX3x6C6CyT_FWS6FRmvx_P7Dx21orjUNQxJ2xlo,1297
|
@@ -169,7 +169,7 @@ tico/serialize/operators/op_select_copy.py,sha256=GPLN7QZmwSlA4WRbjfU6pLer3KVWzg
|
|
169
169
|
tico/serialize/operators/op_sigmoid.py,sha256=ZubbGG1yU5uvNkEmOmbjj3eq7d9mwEaJdChRgL0OjDU,2045
|
170
170
|
tico/serialize/operators/op_sin.py,sha256=MbttmHTVKhwKK6gH9Vbcbn5aAaxnQ71NdpmQAlTcojU,1827
|
171
171
|
tico/serialize/operators/op_slice.py,sha256=g0r8lj5CIxpT6ixOKqUzwKiNhoiuIFwWjbpaiCoOg6w,5259
|
172
|
-
tico/serialize/operators/op_softmax.py,sha256=
|
172
|
+
tico/serialize/operators/op_softmax.py,sha256=qwYke5zfhnSL89DZbzdr5Fc9SsJf0vI-LDZjT_NFpbc,3669
|
173
173
|
tico/serialize/operators/op_split_with_sizes.py,sha256=TgYg1cu-3BSz9SsXfAhoJbo4q5ZzFaoFArkH_obsYlU,4274
|
174
174
|
tico/serialize/operators/op_sqrt.py,sha256=9Q5jkuEPrim11XfSQHGDGVTMYk1TQhOfVqMVYD_eIrI,1871
|
175
175
|
tico/serialize/operators/op_squeeze.py,sha256=QnNwfAdTC1xBm04C9DkVs8VB5YRN-4fCsIWn189QaPg,2416
|
@@ -183,9 +183,10 @@ tico/serialize/operators/op_view.py,sha256=5EMww-ve17Vm9XPuV03Tn7vJsjpU2J8U4d_FO
|
|
183
183
|
tico/serialize/operators/op_where.py,sha256=doE81GSwygrPBm3JIfN9w7kKXxeIYKxgk0eoY22QIcg,2845
|
184
184
|
tico/serialize/operators/utils.py,sha256=lXGpEJW1h8U_-gfc6EWjvvSiq3yJ9P-v1v3EMRT_pSk,2954
|
185
185
|
tico/utils/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
186
|
-
tico/utils/convert.py,sha256=
|
186
|
+
tico/utils/convert.py,sha256=lNbbNswbKyCTK5E5i5CXkJBykWWfF5HDChhM3DLscWo,13222
|
187
187
|
tico/utils/define.py,sha256=Ypgp7YffM4pgPl4Zh6TmogSn1OxGBMRw_e09qYGflZk,1467
|
188
188
|
tico/utils/diff_graph.py,sha256=_eDGGPDPYQD4b--MXX0DLoVgSt_wLfNPt47UlolLLR4,5272
|
189
|
+
tico/utils/dtype.py,sha256=4-k1iUaHivFFXAQuDs7up6fXt5y4FqldGNokAPa3kic,603
|
189
190
|
tico/utils/errors.py,sha256=f3csJjgbXG9W1aHhqEcou008Aor19W57X8oT5Hx8w1M,954
|
190
191
|
tico/utils/graph.py,sha256=Y6aODsnc_-9l61oanknb7K1jqJ8B35iPypOKkM0Qkk0,9149
|
191
192
|
tico/utils/installed_packages.py,sha256=J0FTwnkCGs0MxRWoCMYAqiwH7Z0GWFDLV--x-IndSp4,1017
|
@@ -196,16 +197,17 @@ tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
|
|
196
197
|
tico/utils/pytree_utils.py,sha256=jrk3N6X6LiUnBCX_gM1K9nywbVAJBVnszlTAgeIeDUc,5219
|
197
198
|
tico/utils/register_custom_op.py,sha256=3-Yl6iYmx1qQA2igNHt4hYhQhQMkdPb7gF50LIY8yvc,27350
|
198
199
|
tico/utils/serialize.py,sha256=AQXMBOLu-Kg2Rn-qbqsAtHndjZAZIavlKA0QFgJREHM,1420
|
200
|
+
tico/utils/torch_compat.py,sha256=oc6PztVsXdHcQ3iaVR90wLLxrGaj6zFHWZ8K9rRS6q8,1795
|
199
201
|
tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
|
200
|
-
tico/utils/utils.py,sha256=
|
202
|
+
tico/utils/utils.py,sha256=kg4in3E0eH2nURpnt8XRPUS2Iam3HjJRadDZdyhUy0w,13014
|
201
203
|
tico/utils/validate_args_kwargs.py,sha256=3dXkNll9E9eZq-p0HjYaV4YltQESqdEHBU34k-tIg1k,26733
|
202
204
|
tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
|
203
205
|
tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
|
204
206
|
tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
|
205
207
|
tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
|
206
|
-
tico-0.1.0.
|
207
|
-
tico-0.1.0.
|
208
|
-
tico-0.1.0.
|
209
|
-
tico-0.1.0.
|
210
|
-
tico-0.1.0.
|
211
|
-
tico-0.1.0.
|
208
|
+
tico-0.1.0.dev250721.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
|
209
|
+
tico-0.1.0.dev250721.dist-info/METADATA,sha256=OwOUs1qxfuulEmUn99ocFFSGB9L0XmUmCqNhZugmsCU,8430
|
210
|
+
tico-0.1.0.dev250721.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
|
211
|
+
tico-0.1.0.dev250721.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
|
212
|
+
tico-0.1.0.dev250721.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
|
213
|
+
tico-0.1.0.dev250721.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|