onnx-diagnostic 0.7.11__py3-none-any.whl → 0.7.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +5 -2
- onnx_diagnostic/export/dynamic_shapes.py +11 -2
- onnx_diagnostic/helpers/helper.py +11 -5
- onnx_diagnostic/helpers/log_helper.py +65 -12
- onnx_diagnostic/helpers/mini_onnx_builder.py +17 -0
- onnx_diagnostic/helpers/model_builder_helper.py +1 -0
- onnx_diagnostic/helpers/rt_helper.py +55 -37
- onnx_diagnostic/helpers/torch_helper.py +31 -7
- onnx_diagnostic/reference/torch_evaluator.py +2 -2
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/image_text_to_text.py +256 -141
- onnx_diagnostic/tasks/text_generation.py +15 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +177 -150
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +19 -1
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +40 -14
- onnx_diagnostic/torch_export_patches/patch_inputs.py +10 -6
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +116 -10
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +269 -4
- onnx_diagnostic/torch_models/hghub/hub_api.py +4 -10
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +36 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +32 -4
- onnx_diagnostic/torch_models/validate.py +337 -113
- onnx_diagnostic/torch_onnx/sbs.py +2 -1
- {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/METADATA +11 -31
- {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/RECORD +30 -28
- {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/top_level.txt +0 -0
|
@@ -39,6 +39,9 @@ def evaluation(
|
|
|
39
39
|
"export-strict",
|
|
40
40
|
"export-nostrict",
|
|
41
41
|
"export-nostrict-decall",
|
|
42
|
+
"export-strict-oblivious",
|
|
43
|
+
"export-nostrict-oblivious",
|
|
44
|
+
"export-nostrict-decall-oblivious",
|
|
42
45
|
),
|
|
43
46
|
dynamic: Tuple[bool] = (False, True),
|
|
44
47
|
cases: Optional[Union[str, Dict[str, type]]] = None,
|
|
@@ -105,9 +108,7 @@ def evaluation(
|
|
|
105
108
|
|
|
106
109
|
|
|
107
110
|
def _flatten_inputs(x: Any) -> List["torch.Tensor"]: # noqa: F821
|
|
108
|
-
"""
|
|
109
|
-
Flatten inputs.
|
|
110
|
-
"""
|
|
111
|
+
"""Flatten inputs."""
|
|
111
112
|
if x is None:
|
|
112
113
|
return x
|
|
113
114
|
import torch
|
|
@@ -173,6 +174,15 @@ def _clone(x):
|
|
|
173
174
|
raise TypeError(f"Unable to clone type {type(x)}, x={x} into numpy")
|
|
174
175
|
|
|
175
176
|
|
|
177
|
+
def _wrap_torch_export(*args, backed_size_oblivious=False, **kwargs):
|
|
178
|
+
import torch
|
|
179
|
+
|
|
180
|
+
if backed_size_oblivious:
|
|
181
|
+
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
|
|
182
|
+
return torch.export.export(*args, **kwargs)
|
|
183
|
+
return torch.export.export(*args, **kwargs)
|
|
184
|
+
|
|
185
|
+
|
|
176
186
|
def _make_exporter_export(
|
|
177
187
|
exporter: str,
|
|
178
188
|
model: "torch.nn.Module", # noqa: F821
|
|
@@ -183,76 +193,36 @@ def _make_exporter_export(
|
|
|
183
193
|
) -> Union[Dict, Callable]:
|
|
184
194
|
import torch
|
|
185
195
|
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
exported = torch.export.export(
|
|
197
|
-
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
|
|
198
|
-
)
|
|
199
|
-
except Exception as e:
|
|
200
|
-
if not quiet:
|
|
201
|
-
raise
|
|
202
|
-
return dict(error=str(e), success=0, error_step="export")
|
|
203
|
-
if verbose >= 9:
|
|
204
|
-
print("-- graph")
|
|
205
|
-
print(exported.graph)
|
|
206
|
-
return exported.module()
|
|
207
|
-
if exporter in ("export-strict-dec", "export-strict-decall"):
|
|
208
|
-
try:
|
|
209
|
-
if verbose >= 2:
|
|
210
|
-
exported = torch.export.export(
|
|
211
|
-
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
|
|
212
|
-
)
|
|
213
|
-
if verbose >= 9:
|
|
214
|
-
print("-- graph before decomposition")
|
|
215
|
-
print(exported.graph)
|
|
216
|
-
exported = (
|
|
217
|
-
exported.run_decompositions()
|
|
218
|
-
if "decall" in exporter
|
|
219
|
-
else exported.run_decompositions({})
|
|
220
|
-
)
|
|
221
|
-
else:
|
|
222
|
-
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
|
|
223
|
-
io.StringIO()
|
|
224
|
-
):
|
|
225
|
-
exported = torch.export.export(
|
|
226
|
-
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
|
|
227
|
-
)
|
|
228
|
-
if verbose >= 9:
|
|
229
|
-
print("-- graph before decomposition")
|
|
230
|
-
print(exported.graph)
|
|
231
|
-
exported = (
|
|
232
|
-
exported.run_decompositions()
|
|
233
|
-
if "decall" in exporter
|
|
234
|
-
else exported.run_decompositions({})
|
|
235
|
-
)
|
|
236
|
-
except Exception as e:
|
|
237
|
-
if not quiet:
|
|
238
|
-
raise
|
|
239
|
-
return dict(error=str(e), success=0, error_step="export")
|
|
240
|
-
if verbose >= 9:
|
|
241
|
-
print("-- graph after decomposition")
|
|
242
|
-
print(exported.graph)
|
|
243
|
-
return exported.module()
|
|
244
|
-
if exporter == "export-nostrict":
|
|
196
|
+
backed_size_oblivious = "-oblivious" in exporter
|
|
197
|
+
strict = "-nostrict" not in exporter
|
|
198
|
+
|
|
199
|
+
if exporter in (
|
|
200
|
+
"export-strict",
|
|
201
|
+
"export-strict-oblivious",
|
|
202
|
+
"export-nostrict",
|
|
203
|
+
"export-nostrict-oblivious",
|
|
204
|
+
"export-oblivious",
|
|
205
|
+
):
|
|
245
206
|
try:
|
|
246
207
|
if verbose >= 2:
|
|
247
|
-
exported =
|
|
248
|
-
model,
|
|
208
|
+
exported = _wrap_torch_export(
|
|
209
|
+
model,
|
|
210
|
+
inputs,
|
|
211
|
+
dynamic_shapes=dynamic_shapes,
|
|
212
|
+
strict=strict,
|
|
213
|
+
backed_size_oblivious=backed_size_oblivious,
|
|
249
214
|
)
|
|
250
215
|
else:
|
|
251
|
-
with
|
|
252
|
-
io.StringIO()
|
|
216
|
+
with (
|
|
217
|
+
contextlib.redirect_stdout(io.StringIO()),
|
|
218
|
+
contextlib.redirect_stderr(io.StringIO()),
|
|
253
219
|
):
|
|
254
|
-
exported =
|
|
255
|
-
model,
|
|
220
|
+
exported = _wrap_torch_export(
|
|
221
|
+
model,
|
|
222
|
+
inputs,
|
|
223
|
+
dynamic_shapes=dynamic_shapes,
|
|
224
|
+
strict=strict,
|
|
225
|
+
backed_size_oblivious=backed_size_oblivious,
|
|
256
226
|
)
|
|
257
227
|
except Exception as e:
|
|
258
228
|
if not quiet:
|
|
@@ -262,11 +232,25 @@ def _make_exporter_export(
|
|
|
262
232
|
print("-- graph")
|
|
263
233
|
print(exported.graph)
|
|
264
234
|
return exported.module()
|
|
265
|
-
|
|
235
|
+
|
|
236
|
+
if exporter in (
|
|
237
|
+
"export-strict-dec",
|
|
238
|
+
"export-strict-decall",
|
|
239
|
+
"export-strict-dec-oblivious",
|
|
240
|
+
"export-strict-decall-oblivious",
|
|
241
|
+
"export-nostrict-dec",
|
|
242
|
+
"export-nostrict-decall",
|
|
243
|
+
"export-nostrict-dec-oblivious",
|
|
244
|
+
"export-nostrict-decall-oblivious",
|
|
245
|
+
):
|
|
266
246
|
try:
|
|
267
247
|
if verbose >= 2:
|
|
268
|
-
exported =
|
|
269
|
-
model,
|
|
248
|
+
exported = _wrap_torch_export(
|
|
249
|
+
model,
|
|
250
|
+
inputs,
|
|
251
|
+
dynamic_shapes=dynamic_shapes,
|
|
252
|
+
strict=strict,
|
|
253
|
+
backed_size_oblivious=backed_size_oblivious,
|
|
270
254
|
)
|
|
271
255
|
if verbose >= 9:
|
|
272
256
|
print("-- graph before decomposition")
|
|
@@ -277,11 +261,16 @@ def _make_exporter_export(
|
|
|
277
261
|
else exported.run_decompositions({})
|
|
278
262
|
)
|
|
279
263
|
else:
|
|
280
|
-
with
|
|
281
|
-
io.StringIO()
|
|
264
|
+
with (
|
|
265
|
+
contextlib.redirect_stdout(io.StringIO()),
|
|
266
|
+
contextlib.redirect_stderr(io.StringIO()),
|
|
282
267
|
):
|
|
283
|
-
exported =
|
|
284
|
-
model,
|
|
268
|
+
exported = _wrap_torch_export(
|
|
269
|
+
model,
|
|
270
|
+
inputs,
|
|
271
|
+
dynamic_shapes=dynamic_shapes,
|
|
272
|
+
strict=strict,
|
|
273
|
+
backed_size_oblivious=backed_size_oblivious,
|
|
285
274
|
)
|
|
286
275
|
if verbose >= 9:
|
|
287
276
|
print("-- graph before decomposition")
|
|
@@ -299,6 +288,7 @@ def _make_exporter_export(
|
|
|
299
288
|
print("-- graph after decomposition")
|
|
300
289
|
print(exported.graph)
|
|
301
290
|
return exported.module()
|
|
291
|
+
|
|
302
292
|
if exporter == "export-tracing":
|
|
303
293
|
from experimental_experiment.torch_interpreter.tracing import CustomTracer
|
|
304
294
|
|
|
@@ -307,8 +297,9 @@ def _make_exporter_export(
|
|
|
307
297
|
graph = CustomTracer().trace(model)
|
|
308
298
|
mod = torch.fx.GraphModule(model, graph)
|
|
309
299
|
else:
|
|
310
|
-
with
|
|
311
|
-
io.StringIO()
|
|
300
|
+
with (
|
|
301
|
+
contextlib.redirect_stdout(io.StringIO()),
|
|
302
|
+
contextlib.redirect_stderr(io.StringIO()),
|
|
312
303
|
):
|
|
313
304
|
graph = CustomTracer().trace(model)
|
|
314
305
|
mod = torch.fx.GraphModule(model, graph)
|
|
@@ -353,8 +344,9 @@ def _make_exporter_onnx(
|
|
|
353
344
|
return_builder=True,
|
|
354
345
|
)
|
|
355
346
|
else:
|
|
356
|
-
with
|
|
357
|
-
io.StringIO()
|
|
347
|
+
with (
|
|
348
|
+
contextlib.redirect_stdout(io.StringIO()),
|
|
349
|
+
contextlib.redirect_stderr(io.StringIO()),
|
|
358
350
|
):
|
|
359
351
|
onx, builder = to_onnx(
|
|
360
352
|
model,
|
|
@@ -387,8 +379,9 @@ def _make_exporter_onnx(
|
|
|
387
379
|
report=True,
|
|
388
380
|
).model_proto
|
|
389
381
|
else:
|
|
390
|
-
with
|
|
391
|
-
io.StringIO()
|
|
382
|
+
with (
|
|
383
|
+
contextlib.redirect_stdout(io.StringIO()),
|
|
384
|
+
contextlib.redirect_stderr(io.StringIO()),
|
|
392
385
|
):
|
|
393
386
|
onx = torch.onnx.export(
|
|
394
387
|
model,
|
|
@@ -422,8 +415,9 @@ def _make_exporter_onnx(
|
|
|
422
415
|
ep.optimize()
|
|
423
416
|
onx = ep.model_proto
|
|
424
417
|
else:
|
|
425
|
-
with
|
|
426
|
-
io.StringIO()
|
|
418
|
+
with (
|
|
419
|
+
contextlib.redirect_stdout(io.StringIO()),
|
|
420
|
+
contextlib.redirect_stderr(io.StringIO()),
|
|
427
421
|
):
|
|
428
422
|
ep = torch.onnx.export(
|
|
429
423
|
model,
|
|
@@ -446,6 +440,74 @@ def _make_exporter_onnx(
|
|
|
446
440
|
raise AssertionError(f"Unexpected exporter={exporter!r}")
|
|
447
441
|
|
|
448
442
|
|
|
443
|
+
def _compares_on_one_example(
    model: Callable, inputs: Tuple[Any, ...], mod: Callable, verbose: int, quiet: bool
) -> Tuple[Any, Any, Dict]:
    """
    Runs the eager model and the callable to validate on one example
    and measures the discrepancies between both outputs.

    :param model: eager model, called with ``_clone(inputs)``
    :param inputs: positional arguments of one example
    :param mod: callable to validate, called as ``mod(*inputs)``
    :param verbose: verbosity, 5 prints the data when the absolute
        discrepancy is infinite, 9 always prints it
    :param quiet: if True, a failure is described in the returned dictionary,
        otherwise the exception is raised
    :return: ``(expected, got, results)``, *expected* and *got* are None
        when the corresponding step did not run or failed, *results*
        contains ``success`` and, on failure, ``error`` and ``error_step``
        (``eager``, ``run.0``, ``discrepancy`` or ``diff.0``)
    :raises RuntimeError: if *quiet* is False and *model* or *mod* fails
    """
    from onnx_diagnostic.helpers import max_diff, string_type

    try:
        # The eager model works on a copy, ``mod`` receives the original inputs.
        expected = model(*_clone(inputs))
    except Exception as e:
        if not quiet:
            raise RuntimeError(
                f"eager mode failed=\n{string_type(inputs, with_shape=True)} "
                f"\nmodel=\n{type(model)}"
            ) from e
        res = dict(error=str(e), success=0, error_step="eager")
        return None, None, res

    try:
        got = mod(*inputs)
    except Exception as e:
        if not quiet:
            raise RuntimeError(
                f"onnxruntime failed, feeds=\n{string_type(inputs, with_shape=True)}"
            ) from e
        res = dict(error=str(e), success=0, error_step="run.0")
        return expected, None, res

    try:
        disc = max_diff(expected, got)
    except Exception as e:
        if not quiet:
            raise
        res = dict(error=str(e), success=0, error_step="discrepancy")
        return expected, got, res

    # ``inputs`` is already a single example: the whole tuple is printed,
    # as in the error messages above, not only its first element.
    if verbose >= 5 and np.isinf(disc["abs"]):
        print("[run_exporter] comparison issues with")
        print(f"-- inputs={string_type(inputs, with_shape=True, limit=20)}")
        print(f"-- expected={string_type(expected, with_shape=True, limit=20)}")
        print(f"-- got={string_type(got, with_shape=True, limit=20)}")
    elif verbose >= 9:
        print("[run_exporter] inputs and outputs")
        print(
            f"-- inputs="
            f"{string_type(inputs, with_shape=True, with_min_max=True, limit=20)}"
        )
        print(
            f"-- expected="
            f"{string_type(expected, with_shape=True, with_min_max=True, limit=20)}"
        )
        print(f"-- got={string_type(got, with_shape=True, with_min_max=True, limit=20)}")

    # "n" and "sum" are not reported, a result without them is tolerated.
    disc.pop("n", None)
    disc.pop("sum", None)

    # ``not (x < 0.1)`` instead of ``x >= 0.1``: both comparisons are False
    # for NaN, and a NaN discrepancy must be reported as a failure.
    failed = not (disc["abs"] < 0.1)
    disc.update(
        dict(
            success=0 if failed else 1,
            model_cls=model.__class__,  # type: ignore[dict-item]
            exported=mod,  # type: ignore[dict-item]
        )
    )
    if failed:
        disc["error"] = "diff.0"
        disc["error_step"] = "diff.0"
        if verbose >= 9:
            # Second call only to let max_diff print its own details.
            max_diff(expected, got, verbose=verbose)
    return expected, got, disc
|
|
509
|
+
|
|
510
|
+
|
|
449
511
|
def run_exporter(
|
|
450
512
|
exporter: str,
|
|
451
513
|
cls_model: type,
|
|
@@ -473,6 +535,7 @@ def run_exporter(
|
|
|
473
535
|
|
|
474
536
|
model = cls_model()
|
|
475
537
|
inputs = cls_model._inputs
|
|
538
|
+
valid = getattr(cls_model, "_valid", None)
|
|
476
539
|
if isinstance(inputs, tuple):
|
|
477
540
|
inputs = [inputs]
|
|
478
541
|
if dynamic:
|
|
@@ -566,74 +629,38 @@ def run_exporter(
|
|
|
566
629
|
mod = lambda *args, names=names: sess.run(None, _make_feeds(names, args)) # noqa: E731
|
|
567
630
|
|
|
568
631
|
# we need to clone for models modifying the inputs
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
) from e
|
|
577
|
-
res = dict(error=str(e), success=0, error_step="eager")
|
|
578
|
-
res.update(base)
|
|
579
|
-
return res
|
|
580
|
-
try:
|
|
581
|
-
got = mod(*inputs[0])
|
|
582
|
-
except Exception as e:
|
|
583
|
-
if not quiet:
|
|
584
|
-
raise RuntimeError(
|
|
585
|
-
f"onnxruntime failed, feeds=\n{string_type(inputs[0], with_shape=True)} "
|
|
586
|
-
f"\nmodel=\n{pretty_onnx(onx)}"
|
|
587
|
-
) from e
|
|
588
|
-
res = dict(error=str(e), success=0, error_step="run.0")
|
|
589
|
-
res.update(base)
|
|
590
|
-
return res
|
|
591
|
-
|
|
592
|
-
base["expected"] = expected
|
|
593
|
-
base["obtained"] = got
|
|
594
|
-
|
|
595
|
-
try:
|
|
596
|
-
disc = max_diff(expected, got)
|
|
597
|
-
except Exception as e:
|
|
598
|
-
if not quiet:
|
|
599
|
-
raise
|
|
600
|
-
res = dict(error=str(e), success=0, error_step="discrepancy")
|
|
601
|
-
res.update(base)
|
|
602
|
-
return res
|
|
632
|
+
expected, got, disc = _compares_on_one_example(model, inputs[0], mod, verbose, quiet)
|
|
633
|
+
if expected is not None:
|
|
634
|
+
base["expected"] = expected
|
|
635
|
+
if got is not None:
|
|
636
|
+
base["obtained"] = got
|
|
637
|
+
disc.update(base)
|
|
638
|
+
disc["onnx"] = onx # type: ignore[dict-item]
|
|
603
639
|
|
|
604
|
-
if
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
)
|
|
629
|
-
)
|
|
630
|
-
if disc["abs"] >= 0.1:
|
|
631
|
-
disc["error"] = "diff.0"
|
|
632
|
-
disc["error_step"] = "diff.0"
|
|
633
|
-
if verbose >= 9:
|
|
634
|
-
max_diff(expected, got, verbose=verbose)
|
|
635
|
-
else:
|
|
636
|
-
disc["success"] = 1
|
|
640
|
+
if valid is not None:
|
|
641
|
+
for valid_inputs in valid:
|
|
642
|
+
expected, got, _disc = _compares_on_one_example(
|
|
643
|
+
model, valid_inputs, mod, verbose, quiet
|
|
644
|
+
)
|
|
645
|
+
if "abs" not in disc and (np.isnan(disc["abs"]) or disc["abs"] > 1e-3):
|
|
646
|
+
_disc["issue-abs"] = disc["abs"]
|
|
647
|
+
_disc["issue-rel"] = disc["rel"]
|
|
648
|
+
_disc["issue-inputs"] = string_type(
|
|
649
|
+
valid_inputs, with_shape=True, with_min_max=True
|
|
650
|
+
)
|
|
651
|
+
_disc["issue-expected"] = string_type(
|
|
652
|
+
expected, with_shape=True, with_min_max=True
|
|
653
|
+
)
|
|
654
|
+
_disc["issue-obtained"] = string_type(got, with_shape=True, with_min_max=True)
|
|
655
|
+
if not quiet:
|
|
656
|
+
raise RuntimeError(
|
|
657
|
+
f"validation failed,"
|
|
658
|
+
f"\n-- inputs=\n{string_type(_disc['issue-inputs'])} "
|
|
659
|
+
f"\n-- exporter={exporter!r}\n-- dynamic_shapes={dynamic_shapes}, "
|
|
660
|
+
f"\n-- expected={_disc['issue-expected']}"
|
|
661
|
+
f"\n-- obtained={_disc['issue-obtained']}"
|
|
662
|
+
)
|
|
663
|
+
break
|
|
637
664
|
|
|
638
665
|
if dynamic and onx is not None:
|
|
639
666
|
ds = []
|
|
@@ -861,7 +861,7 @@ class CreateFromShapeThroughFunction(torch.nn.Module):
|
|
|
861
861
|
y = torch.ones((x.shape[0], dy1))
|
|
862
862
|
return y
|
|
863
863
|
|
|
864
|
-
_inputs = [(torch.rand((4, 4)),)
|
|
864
|
+
_inputs = [(torch.rand((4, 4)),)]
|
|
865
865
|
_dynamic = {"x": {0: DIM("dx"), 1: DIM("dy")}}
|
|
866
866
|
|
|
867
867
|
|
|
@@ -881,3 +881,21 @@ class VmapPython(torch.nn.Module):
|
|
|
881
881
|
|
|
882
882
|
_inputs = [(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([0.1, 0.2, 0.3]))]
|
|
883
883
|
_dynamic = {"x": {0: DYN}, "y": {0: DYN}}
|
|
884
|
+
|
|
885
|
+
|
|
886
|
+
class ExportWithDimension0(torch.nn.Module):
    """
    Multiplies the input by a column vector ``[0, 1, ..., x.shape[1] - 1]``.
    The example used for the export has an empty first dimension.
    """

    def forward(self, x):
        # One weight per column of x, shaped as a column vector.
        column = torch.arange(x.shape[1], dtype=torch.float32).reshape((-1, 1))
        return x @ column

    # Example inputs: zero rows, three columns.
    _inputs = [(torch.empty((0, 3), dtype=torch.float32),)]
    # Both dimensions of x are declared dynamic.
    _dynamic = {"x": {0: DYN, 1: DYN}}
    # Additional inputs with a non-empty first dimension.
    _valid = [(torch.rand((2, 3), dtype=torch.float32),)]
|
|
893
|
+
|
|
894
|
+
|
|
895
|
+
class ExportWithDimension1(torch.nn.Module):
    """
    Multiplies the input by a column vector ``[0, 1, ..., x.shape[1] - 1]``.
    The example used for the export has a first dimension equal to 1.
    """

    def forward(self, x):
        # One weight per column of x, shaped as a column vector.
        column = torch.arange(x.shape[1], dtype=torch.float32).reshape((-1, 1))
        return x @ column

    # Example inputs: one row, three columns.
    _inputs = [(torch.zeros((1, 3), dtype=torch.float32),)]
    # Both dimensions of x are declared dynamic.
    _dynamic = {"x": {0: DYN, 1: DYN}}
    # Additional inputs with a first dimension different from 1.
    _valid = [(torch.rand((2, 3), dtype=torch.float32),)]
|
|
@@ -83,7 +83,7 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
|
|
|
83
83
|
continue
|
|
84
84
|
|
|
85
85
|
original = cls._PATCHED_CLASS_
|
|
86
|
-
methods = cls._PATCHES_
|
|
86
|
+
methods = [_ for _ in cls._PATCHES_ if _ is not None]
|
|
87
87
|
if verbose:
|
|
88
88
|
print(f"[patch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}")
|
|
89
89
|
|
|
@@ -254,22 +254,36 @@ def torch_export_patches(
|
|
|
254
254
|
may appear ``AssertionError: Mutating module attribute _seen_tokens during export.``.
|
|
255
255
|
It can be avoided by setting ``strict=False`` when call :func:`torch.export.export`.
|
|
256
256
|
"""
|
|
257
|
+
if verbose:
|
|
258
|
+
print(f"[torch_export_patches] patch_sympy={patch_sympy!r}")
|
|
259
|
+
print(f" . patch_torch={patch_torch!r}")
|
|
260
|
+
print(f" . patch_transformers={patch_transformers!r}")
|
|
261
|
+
print(f" . patch_diffusers={patch_diffusers!r}")
|
|
262
|
+
print(f" . catch_constraints={catch_constraints!r}")
|
|
263
|
+
print(f" . stop_if_static={stop_if_static!r}")
|
|
264
|
+
print(f" . patch={patch!r}")
|
|
265
|
+
print(f" . custom_patches={custom_patches!r}")
|
|
266
|
+
print(f"[torch_export_patches] dump_rewriting={dump_rewriting!r}")
|
|
267
|
+
|
|
257
268
|
if rewrite:
|
|
258
269
|
from .patch_module import torch_export_rewrite
|
|
259
270
|
|
|
260
|
-
with
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
271
|
+
with (
|
|
272
|
+
torch_export_rewrite(
|
|
273
|
+
rewrite=rewrite, dump_rewriting=dump_rewriting, verbose=verbose
|
|
274
|
+
),
|
|
275
|
+
torch_export_patches( # type: ignore[var-annotated]
|
|
276
|
+
patch_sympy=patch_sympy,
|
|
277
|
+
patch_torch=patch_torch,
|
|
278
|
+
patch_transformers=patch_transformers,
|
|
279
|
+
patch_diffusers=patch_diffusers,
|
|
280
|
+
catch_constraints=catch_constraints,
|
|
281
|
+
stop_if_static=stop_if_static,
|
|
282
|
+
verbose=verbose,
|
|
283
|
+
patch=patch,
|
|
284
|
+
custom_patches=custom_patches,
|
|
285
|
+
) as f,
|
|
286
|
+
):
|
|
273
287
|
try:
|
|
274
288
|
yield f
|
|
275
289
|
finally:
|
|
@@ -330,6 +344,7 @@ def torch_export_patches(
|
|
|
330
344
|
patched_infer_size,
|
|
331
345
|
patched_vmap,
|
|
332
346
|
patched__broadcast_shapes,
|
|
347
|
+
patched__constrain_user_specified_dimhint_range,
|
|
333
348
|
_catch_produce_guards_and_solve_constraints,
|
|
334
349
|
patch__check_input_constraints_for_graph,
|
|
335
350
|
)
|
|
@@ -360,6 +375,14 @@ def torch_export_patches(
|
|
|
360
375
|
torch._refs._broadcast_shapes = patched__broadcast_shapes
|
|
361
376
|
torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes
|
|
362
377
|
|
|
378
|
+
# torch._export.non_strict_utils._constrain_user_specified_dimhint_range
|
|
379
|
+
f___constrain_user_specified_dimhint_range = (
|
|
380
|
+
torch._export.non_strict_utils._constrain_user_specified_dimhint_range
|
|
381
|
+
)
|
|
382
|
+
torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
|
|
383
|
+
patched__constrain_user_specified_dimhint_range
|
|
384
|
+
)
|
|
385
|
+
|
|
363
386
|
# torch._export.non_strict_utils.produce_guards_and_solve_constraints
|
|
364
387
|
if patch_torch and catch_constraints:
|
|
365
388
|
if verbose:
|
|
@@ -558,6 +581,9 @@ def torch_export_patches(
|
|
|
558
581
|
torch._subclasses.fake_impls.infer_size = f_infer_size
|
|
559
582
|
torch._refs._broadcast_shapes = f__broadcast_shapes
|
|
560
583
|
torch._meta_registrations._broadcast_shapes = f__broadcast_shapes
|
|
584
|
+
torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
|
|
585
|
+
f___constrain_user_specified_dimhint_range
|
|
586
|
+
)
|
|
561
587
|
|
|
562
588
|
if verbose:
|
|
563
589
|
print("[torch_export_patches] restored pytorch functions")
|
|
@@ -189,19 +189,23 @@ def convert_dynamic_axes_into_dynamic_shapes(
|
|
|
189
189
|
return (), updated_kwargs, dynamic_shapes
|
|
190
190
|
|
|
191
191
|
|
|
192
|
-
def use_dyn_not_str(dynamic_shapes: Any) -> Any:
|
|
192
|
+
def use_dyn_not_str(dynamic_shapes: Any, default_value=None) -> Any:
|
|
193
193
|
"""
|
|
194
194
|
Some functions returns dynamic shapes as string.
|
|
195
195
|
This functions replaces them with ``torch.export.Dim.DYNAMIC``.
|
|
196
|
+
``default_value=torch.export.Dim.AUTO`` changes the default value.
|
|
196
197
|
"""
|
|
197
198
|
if isinstance(dynamic_shapes, list):
|
|
198
|
-
return [use_dyn_not_str(a) for a in dynamic_shapes]
|
|
199
|
+
return [use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes]
|
|
199
200
|
if isinstance(dynamic_shapes, tuple):
|
|
200
|
-
return tuple(use_dyn_not_str(a) for a in dynamic_shapes)
|
|
201
|
+
return tuple(use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes)
|
|
201
202
|
if isinstance(dynamic_shapes, dict):
|
|
202
|
-
return {
|
|
203
|
+
return {
|
|
204
|
+
k: use_dyn_not_str(v, default_value=default_value)
|
|
205
|
+
for k, v in dynamic_shapes.items()
|
|
206
|
+
}
|
|
203
207
|
if isinstance(dynamic_shapes, set):
|
|
204
|
-
return {use_dyn_not_str(a) for a in dynamic_shapes}
|
|
208
|
+
return {use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes}
|
|
205
209
|
if isinstance(dynamic_shapes, str):
|
|
206
|
-
return torch.export.Dim.DYNAMIC
|
|
210
|
+
return torch.export.Dim.DYNAMIC if default_value is None else default_value
|
|
207
211
|
return dynamic_shapes
|