onnx-diagnostic 0.7.12__py3-none-any.whl → 0.7.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +7 -2
- onnx_diagnostic/export/dynamic_shapes.py +11 -2
- onnx_diagnostic/helpers/helper.py +11 -5
- onnx_diagnostic/helpers/log_helper.py +53 -17
- onnx_diagnostic/helpers/mini_onnx_builder.py +17 -0
- onnx_diagnostic/helpers/model_builder_helper.py +1 -0
- onnx_diagnostic/helpers/rt_helper.py +2 -1
- onnx_diagnostic/helpers/torch_helper.py +31 -7
- onnx_diagnostic/reference/torch_evaluator.py +2 -2
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/image_text_to_text.py +256 -141
- onnx_diagnostic/tasks/text_generation.py +30 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +184 -151
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +20 -5
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +52 -20
- onnx_diagnostic/torch_export_patches/patch_inputs.py +10 -6
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +540 -10
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +269 -4
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +36 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +55 -5
- onnx_diagnostic/torch_models/validate.py +116 -50
- onnx_diagnostic/torch_onnx/sbs.py +2 -1
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/METADATA +11 -31
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/RECORD +29 -27
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/top_level.txt +0 -0
|
@@ -39,6 +39,9 @@ def evaluation(
|
|
|
39
39
|
"export-strict",
|
|
40
40
|
"export-nostrict",
|
|
41
41
|
"export-nostrict-decall",
|
|
42
|
+
"export-strict-oblivious",
|
|
43
|
+
"export-nostrict-oblivious",
|
|
44
|
+
"export-nostrict-decall-oblivious",
|
|
42
45
|
),
|
|
43
46
|
dynamic: Tuple[bool] = (False, True),
|
|
44
47
|
cases: Optional[Union[str, Dict[str, type]]] = None,
|
|
@@ -105,9 +108,7 @@ def evaluation(
|
|
|
105
108
|
|
|
106
109
|
|
|
107
110
|
def _flatten_inputs(x: Any) -> List["torch.Tensor"]: # noqa: F821
|
|
108
|
-
"""
|
|
109
|
-
Flatten inputs.
|
|
110
|
-
"""
|
|
111
|
+
"""Flatten inputs."""
|
|
111
112
|
if x is None:
|
|
112
113
|
return x
|
|
113
114
|
import torch
|
|
@@ -173,6 +174,15 @@ def _clone(x):
|
|
|
173
174
|
raise TypeError(f"Unable to clone type {type(x)}, x={x} into numpy")
|
|
174
175
|
|
|
175
176
|
|
|
177
|
+
def _wrap_torch_export(*args, backed_size_oblivious=False, **kwargs):
|
|
178
|
+
import torch
|
|
179
|
+
|
|
180
|
+
if backed_size_oblivious:
|
|
181
|
+
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
|
|
182
|
+
return torch.export.export(*args, **kwargs)
|
|
183
|
+
return torch.export.export(*args, **kwargs)
|
|
184
|
+
|
|
185
|
+
|
|
176
186
|
def _make_exporter_export(
|
|
177
187
|
exporter: str,
|
|
178
188
|
model: "torch.nn.Module", # noqa: F821
|
|
@@ -183,76 +193,36 @@ def _make_exporter_export(
|
|
|
183
193
|
) -> Union[Dict, Callable]:
|
|
184
194
|
import torch
|
|
185
195
|
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
exported = torch.export.export(
|
|
197
|
-
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
|
|
198
|
-
)
|
|
199
|
-
except Exception as e:
|
|
200
|
-
if not quiet:
|
|
201
|
-
raise
|
|
202
|
-
return dict(error=str(e), success=0, error_step="export")
|
|
203
|
-
if verbose >= 9:
|
|
204
|
-
print("-- graph")
|
|
205
|
-
print(exported.graph)
|
|
206
|
-
return exported.module()
|
|
207
|
-
if exporter in ("export-strict-dec", "export-strict-decall"):
|
|
208
|
-
try:
|
|
209
|
-
if verbose >= 2:
|
|
210
|
-
exported = torch.export.export(
|
|
211
|
-
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
|
|
212
|
-
)
|
|
213
|
-
if verbose >= 9:
|
|
214
|
-
print("-- graph before decomposition")
|
|
215
|
-
print(exported.graph)
|
|
216
|
-
exported = (
|
|
217
|
-
exported.run_decompositions()
|
|
218
|
-
if "decall" in exporter
|
|
219
|
-
else exported.run_decompositions({})
|
|
220
|
-
)
|
|
221
|
-
else:
|
|
222
|
-
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
|
|
223
|
-
io.StringIO()
|
|
224
|
-
):
|
|
225
|
-
exported = torch.export.export(
|
|
226
|
-
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
|
|
227
|
-
)
|
|
228
|
-
if verbose >= 9:
|
|
229
|
-
print("-- graph before decomposition")
|
|
230
|
-
print(exported.graph)
|
|
231
|
-
exported = (
|
|
232
|
-
exported.run_decompositions()
|
|
233
|
-
if "decall" in exporter
|
|
234
|
-
else exported.run_decompositions({})
|
|
235
|
-
)
|
|
236
|
-
except Exception as e:
|
|
237
|
-
if not quiet:
|
|
238
|
-
raise
|
|
239
|
-
return dict(error=str(e), success=0, error_step="export")
|
|
240
|
-
if verbose >= 9:
|
|
241
|
-
print("-- graph after decomposition")
|
|
242
|
-
print(exported.graph)
|
|
243
|
-
return exported.module()
|
|
244
|
-
if exporter == "export-nostrict":
|
|
196
|
+
backed_size_oblivious = "-oblivious" in exporter
|
|
197
|
+
strict = "-nostrict" not in exporter
|
|
198
|
+
|
|
199
|
+
if exporter in (
|
|
200
|
+
"export-strict",
|
|
201
|
+
"export-strict-oblivious",
|
|
202
|
+
"export-nostrict",
|
|
203
|
+
"export-nostrict-oblivious",
|
|
204
|
+
"export-oblivious",
|
|
205
|
+
):
|
|
245
206
|
try:
|
|
246
207
|
if verbose >= 2:
|
|
247
|
-
exported =
|
|
248
|
-
model,
|
|
208
|
+
exported = _wrap_torch_export(
|
|
209
|
+
model,
|
|
210
|
+
inputs,
|
|
211
|
+
dynamic_shapes=dynamic_shapes,
|
|
212
|
+
strict=strict,
|
|
213
|
+
backed_size_oblivious=backed_size_oblivious,
|
|
249
214
|
)
|
|
250
215
|
else:
|
|
251
|
-
with
|
|
252
|
-
io.StringIO()
|
|
216
|
+
with (
|
|
217
|
+
contextlib.redirect_stdout(io.StringIO()),
|
|
218
|
+
contextlib.redirect_stderr(io.StringIO()),
|
|
253
219
|
):
|
|
254
|
-
exported =
|
|
255
|
-
model,
|
|
220
|
+
exported = _wrap_torch_export(
|
|
221
|
+
model,
|
|
222
|
+
inputs,
|
|
223
|
+
dynamic_shapes=dynamic_shapes,
|
|
224
|
+
strict=strict,
|
|
225
|
+
backed_size_oblivious=backed_size_oblivious,
|
|
256
226
|
)
|
|
257
227
|
except Exception as e:
|
|
258
228
|
if not quiet:
|
|
@@ -262,11 +232,25 @@ def _make_exporter_export(
|
|
|
262
232
|
print("-- graph")
|
|
263
233
|
print(exported.graph)
|
|
264
234
|
return exported.module()
|
|
265
|
-
|
|
235
|
+
|
|
236
|
+
if exporter in (
|
|
237
|
+
"export-strict-dec",
|
|
238
|
+
"export-strict-decall",
|
|
239
|
+
"export-strict-dec-oblivious",
|
|
240
|
+
"export-strict-decall-oblivious",
|
|
241
|
+
"export-nostrict-dec",
|
|
242
|
+
"export-nostrict-decall",
|
|
243
|
+
"export-nostrict-dec-oblivious",
|
|
244
|
+
"export-nostrict-decall-oblivious",
|
|
245
|
+
):
|
|
266
246
|
try:
|
|
267
247
|
if verbose >= 2:
|
|
268
|
-
exported =
|
|
269
|
-
model,
|
|
248
|
+
exported = _wrap_torch_export(
|
|
249
|
+
model,
|
|
250
|
+
inputs,
|
|
251
|
+
dynamic_shapes=dynamic_shapes,
|
|
252
|
+
strict=strict,
|
|
253
|
+
backed_size_oblivious=backed_size_oblivious,
|
|
270
254
|
)
|
|
271
255
|
if verbose >= 9:
|
|
272
256
|
print("-- graph before decomposition")
|
|
@@ -277,11 +261,16 @@ def _make_exporter_export(
|
|
|
277
261
|
else exported.run_decompositions({})
|
|
278
262
|
)
|
|
279
263
|
else:
|
|
280
|
-
with
|
|
281
|
-
io.StringIO()
|
|
264
|
+
with (
|
|
265
|
+
contextlib.redirect_stdout(io.StringIO()),
|
|
266
|
+
contextlib.redirect_stderr(io.StringIO()),
|
|
282
267
|
):
|
|
283
|
-
exported =
|
|
284
|
-
model,
|
|
268
|
+
exported = _wrap_torch_export(
|
|
269
|
+
model,
|
|
270
|
+
inputs,
|
|
271
|
+
dynamic_shapes=dynamic_shapes,
|
|
272
|
+
strict=strict,
|
|
273
|
+
backed_size_oblivious=backed_size_oblivious,
|
|
285
274
|
)
|
|
286
275
|
if verbose >= 9:
|
|
287
276
|
print("-- graph before decomposition")
|
|
@@ -299,6 +288,7 @@ def _make_exporter_export(
|
|
|
299
288
|
print("-- graph after decomposition")
|
|
300
289
|
print(exported.graph)
|
|
301
290
|
return exported.module()
|
|
291
|
+
|
|
302
292
|
if exporter == "export-tracing":
|
|
303
293
|
from experimental_experiment.torch_interpreter.tracing import CustomTracer
|
|
304
294
|
|
|
@@ -307,8 +297,9 @@ def _make_exporter_export(
|
|
|
307
297
|
graph = CustomTracer().trace(model)
|
|
308
298
|
mod = torch.fx.GraphModule(model, graph)
|
|
309
299
|
else:
|
|
310
|
-
with
|
|
311
|
-
io.StringIO()
|
|
300
|
+
with (
|
|
301
|
+
contextlib.redirect_stdout(io.StringIO()),
|
|
302
|
+
contextlib.redirect_stderr(io.StringIO()),
|
|
312
303
|
):
|
|
313
304
|
graph = CustomTracer().trace(model)
|
|
314
305
|
mod = torch.fx.GraphModule(model, graph)
|
|
@@ -353,8 +344,9 @@ def _make_exporter_onnx(
|
|
|
353
344
|
return_builder=True,
|
|
354
345
|
)
|
|
355
346
|
else:
|
|
356
|
-
with
|
|
357
|
-
io.StringIO()
|
|
347
|
+
with (
|
|
348
|
+
contextlib.redirect_stdout(io.StringIO()),
|
|
349
|
+
contextlib.redirect_stderr(io.StringIO()),
|
|
358
350
|
):
|
|
359
351
|
onx, builder = to_onnx(
|
|
360
352
|
model,
|
|
@@ -387,8 +379,9 @@ def _make_exporter_onnx(
|
|
|
387
379
|
report=True,
|
|
388
380
|
).model_proto
|
|
389
381
|
else:
|
|
390
|
-
with
|
|
391
|
-
io.StringIO()
|
|
382
|
+
with (
|
|
383
|
+
contextlib.redirect_stdout(io.StringIO()),
|
|
384
|
+
contextlib.redirect_stderr(io.StringIO()),
|
|
392
385
|
):
|
|
393
386
|
onx = torch.onnx.export(
|
|
394
387
|
model,
|
|
@@ -422,8 +415,9 @@ def _make_exporter_onnx(
|
|
|
422
415
|
ep.optimize()
|
|
423
416
|
onx = ep.model_proto
|
|
424
417
|
else:
|
|
425
|
-
with
|
|
426
|
-
io.StringIO()
|
|
418
|
+
with (
|
|
419
|
+
contextlib.redirect_stdout(io.StringIO()),
|
|
420
|
+
contextlib.redirect_stderr(io.StringIO()),
|
|
427
421
|
):
|
|
428
422
|
ep = torch.onnx.export(
|
|
429
423
|
model,
|
|
@@ -446,6 +440,74 @@ def _make_exporter_onnx(
|
|
|
446
440
|
raise AssertionError(f"Unexpected exporter={exporter!r}")
|
|
447
441
|
|
|
448
442
|
|
|
443
|
+
def _compares_on_one_example(
|
|
444
|
+
model: Callable, inputs: Tuple[Any, ...], mod: Callable, verbose: int, quiet: bool
|
|
445
|
+
) -> Tuple[Any, Any, Dict]:
|
|
446
|
+
from onnx_diagnostic.helpers import max_diff, string_type
|
|
447
|
+
|
|
448
|
+
try:
|
|
449
|
+
expected = model(*_clone(inputs))
|
|
450
|
+
except Exception as e:
|
|
451
|
+
if not quiet:
|
|
452
|
+
raise RuntimeError(
|
|
453
|
+
f"eager mode failed=\n{string_type(inputs, with_shape=True)} "
|
|
454
|
+
f"\nmodel=\n{type(model)}"
|
|
455
|
+
) from e
|
|
456
|
+
res = dict(error=str(e), success=0, error_step="eager")
|
|
457
|
+
return None, None, res
|
|
458
|
+
try:
|
|
459
|
+
got = mod(*inputs)
|
|
460
|
+
except Exception as e:
|
|
461
|
+
if not quiet:
|
|
462
|
+
raise RuntimeError(
|
|
463
|
+
f"onnxruntime failed, feeds=\n{string_type(inputs, with_shape=True)}"
|
|
464
|
+
) from e
|
|
465
|
+
res = dict(error=str(e), success=0, error_step="run.0")
|
|
466
|
+
return expected, None, res
|
|
467
|
+
|
|
468
|
+
try:
|
|
469
|
+
disc = max_diff(expected, got)
|
|
470
|
+
except Exception as e:
|
|
471
|
+
if not quiet:
|
|
472
|
+
raise
|
|
473
|
+
res = dict(error=str(e), success=0, error_step="discrepancy")
|
|
474
|
+
return expected, got, res
|
|
475
|
+
|
|
476
|
+
if verbose >= 5 and np.isinf(disc["abs"]):
|
|
477
|
+
print("[run_exporter] comparison issues with")
|
|
478
|
+
print(f"-- inputs={string_type(inputs[0], with_shape=True, limit=20)}")
|
|
479
|
+
print(f"-- expected={string_type(expected, with_shape=True, limit=20)}")
|
|
480
|
+
print(f"-- got={string_type(got, with_shape=True, limit=20)}")
|
|
481
|
+
elif verbose >= 9:
|
|
482
|
+
print("[run_exporter] inputs and outputs")
|
|
483
|
+
print(
|
|
484
|
+
f"-- inputs="
|
|
485
|
+
f"{string_type(inputs[0], with_shape=True, with_min_max=True, limit=20)}"
|
|
486
|
+
)
|
|
487
|
+
print(
|
|
488
|
+
f"-- expected="
|
|
489
|
+
f"{string_type(expected, with_shape=True, with_min_max=True, limit=20)}"
|
|
490
|
+
)
|
|
491
|
+
print(f"-- got={string_type(got, with_shape=True, with_min_max=True, limit=20)}")
|
|
492
|
+
del disc["n"]
|
|
493
|
+
del disc["sum"]
|
|
494
|
+
disc.update(
|
|
495
|
+
dict(
|
|
496
|
+
success=1 if disc["abs"] < 0.1 else 0,
|
|
497
|
+
model_cls=model.__class__, # type: ignore[dict-item]
|
|
498
|
+
exported=mod, # type: ignore[dict-item]
|
|
499
|
+
)
|
|
500
|
+
)
|
|
501
|
+
if disc["abs"] >= 0.1:
|
|
502
|
+
disc["error"] = "diff.0"
|
|
503
|
+
disc["error_step"] = "diff.0"
|
|
504
|
+
if verbose >= 9:
|
|
505
|
+
max_diff(expected, got, verbose=verbose)
|
|
506
|
+
else:
|
|
507
|
+
disc["success"] = 1
|
|
508
|
+
return expected, got, disc
|
|
509
|
+
|
|
510
|
+
|
|
449
511
|
def run_exporter(
|
|
450
512
|
exporter: str,
|
|
451
513
|
cls_model: type,
|
|
@@ -473,6 +535,7 @@ def run_exporter(
|
|
|
473
535
|
|
|
474
536
|
model = cls_model()
|
|
475
537
|
inputs = cls_model._inputs
|
|
538
|
+
valid = getattr(cls_model, "_valid", None)
|
|
476
539
|
if isinstance(inputs, tuple):
|
|
477
540
|
inputs = [inputs]
|
|
478
541
|
if dynamic:
|
|
@@ -566,74 +629,38 @@ def run_exporter(
|
|
|
566
629
|
mod = lambda *args, names=names: sess.run(None, _make_feeds(names, args)) # noqa: E731
|
|
567
630
|
|
|
568
631
|
# we need to clone for models modifying the inputs
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
) from e
|
|
577
|
-
res = dict(error=str(e), success=0, error_step="eager")
|
|
578
|
-
res.update(base)
|
|
579
|
-
return res
|
|
580
|
-
try:
|
|
581
|
-
got = mod(*inputs[0])
|
|
582
|
-
except Exception as e:
|
|
583
|
-
if not quiet:
|
|
584
|
-
raise RuntimeError(
|
|
585
|
-
f"onnxruntime failed, feeds=\n{string_type(inputs[0], with_shape=True)} "
|
|
586
|
-
f"\nmodel=\n{pretty_onnx(onx)}"
|
|
587
|
-
) from e
|
|
588
|
-
res = dict(error=str(e), success=0, error_step="run.0")
|
|
589
|
-
res.update(base)
|
|
590
|
-
return res
|
|
591
|
-
|
|
592
|
-
base["expected"] = expected
|
|
593
|
-
base["obtained"] = got
|
|
594
|
-
|
|
595
|
-
try:
|
|
596
|
-
disc = max_diff(expected, got)
|
|
597
|
-
except Exception as e:
|
|
598
|
-
if not quiet:
|
|
599
|
-
raise
|
|
600
|
-
res = dict(error=str(e), success=0, error_step="discrepancy")
|
|
601
|
-
res.update(base)
|
|
602
|
-
return res
|
|
632
|
+
expected, got, disc = _compares_on_one_example(model, inputs[0], mod, verbose, quiet)
|
|
633
|
+
if expected is not None:
|
|
634
|
+
base["expected"] = expected
|
|
635
|
+
if got is not None:
|
|
636
|
+
base["obtained"] = got
|
|
637
|
+
disc.update(base)
|
|
638
|
+
disc["onnx"] = onx # type: ignore[dict-item]
|
|
603
639
|
|
|
604
|
-
if
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
)
|
|
629
|
-
)
|
|
630
|
-
if disc["abs"] >= 0.1:
|
|
631
|
-
disc["error"] = "diff.0"
|
|
632
|
-
disc["error_step"] = "diff.0"
|
|
633
|
-
if verbose >= 9:
|
|
634
|
-
max_diff(expected, got, verbose=verbose)
|
|
635
|
-
else:
|
|
636
|
-
disc["success"] = 1
|
|
640
|
+
if valid is not None:
|
|
641
|
+
for valid_inputs in valid:
|
|
642
|
+
expected, got, _disc = _compares_on_one_example(
|
|
643
|
+
model, valid_inputs, mod, verbose, quiet
|
|
644
|
+
)
|
|
645
|
+
if "abs" not in disc and (np.isnan(disc["abs"]) or disc["abs"] > 1e-3):
|
|
646
|
+
_disc["issue-abs"] = disc["abs"]
|
|
647
|
+
_disc["issue-rel"] = disc["rel"]
|
|
648
|
+
_disc["issue-inputs"] = string_type(
|
|
649
|
+
valid_inputs, with_shape=True, with_min_max=True
|
|
650
|
+
)
|
|
651
|
+
_disc["issue-expected"] = string_type(
|
|
652
|
+
expected, with_shape=True, with_min_max=True
|
|
653
|
+
)
|
|
654
|
+
_disc["issue-obtained"] = string_type(got, with_shape=True, with_min_max=True)
|
|
655
|
+
if not quiet:
|
|
656
|
+
raise RuntimeError(
|
|
657
|
+
f"validation failed,"
|
|
658
|
+
f"\n-- inputs=\n{string_type(_disc['issue-inputs'])} "
|
|
659
|
+
f"\n-- exporter={exporter!r}\n-- dynamic_shapes={dynamic_shapes}, "
|
|
660
|
+
f"\n-- expected={_disc['issue-expected']}"
|
|
661
|
+
f"\n-- obtained={_disc['issue-obtained']}"
|
|
662
|
+
)
|
|
663
|
+
break
|
|
637
664
|
|
|
638
665
|
if dynamic and onx is not None:
|
|
639
666
|
ds = []
|
|
@@ -649,7 +676,13 @@ def run_exporter(
|
|
|
649
676
|
|
|
650
677
|
if dynamic and len(inputs) > 1:
|
|
651
678
|
for index, i in enumerate(inputs):
|
|
652
|
-
|
|
679
|
+
if quiet:
|
|
680
|
+
try:
|
|
681
|
+
expected = model(*_clone(i))
|
|
682
|
+
except Exception as e:
|
|
683
|
+
return dict(error=str(e), success=0, error_step=f"run0.{index}")
|
|
684
|
+
else:
|
|
685
|
+
expected = model(*_clone(i))
|
|
653
686
|
try:
|
|
654
687
|
got = mod(*i)
|
|
655
688
|
except Exception as e:
|
|
@@ -353,12 +353,9 @@ class ControlFlowCondNonZero(torch.nn.Module):
|
|
|
353
353
|
|
|
354
354
|
|
|
355
355
|
class ControlFlowCondIdentity_153832(torch.nn.Module):
|
|
356
|
-
"""
|
|
357
|
-
`#153832 <https://github.com/pytorch/pytorch/issues/153832>`_
|
|
358
|
-
"""
|
|
356
|
+
"""`#153832 <https://github.com/pytorch/pytorch/issues/153832>`_"""
|
|
359
357
|
|
|
360
358
|
def forward(self, x, y):
|
|
361
|
-
|
|
362
359
|
def branch_cond_then_1(x):
|
|
363
360
|
x = torch.abs(x) + 1
|
|
364
361
|
return x
|
|
@@ -861,7 +858,7 @@ class CreateFromShapeThroughFunction(torch.nn.Module):
|
|
|
861
858
|
y = torch.ones((x.shape[0], dy1))
|
|
862
859
|
return y
|
|
863
860
|
|
|
864
|
-
_inputs = [(torch.rand((4, 4)),)
|
|
861
|
+
_inputs = [(torch.rand((4, 4)),)]
|
|
865
862
|
_dynamic = {"x": {0: DIM("dx"), 1: DIM("dy")}}
|
|
866
863
|
|
|
867
864
|
|
|
@@ -881,3 +878,21 @@ class VmapPython(torch.nn.Module):
|
|
|
881
878
|
|
|
882
879
|
_inputs = [(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([0.1, 0.2, 0.3]))]
|
|
883
880
|
_dynamic = {"x": {0: DYN}, "y": {0: DYN}}
|
|
881
|
+
|
|
882
|
+
|
|
883
|
+
class ExportWithDimension0(torch.nn.Module):
|
|
884
|
+
def forward(self, x):
|
|
885
|
+
return x @ torch.arange(x.shape[1], dtype=torch.float32).reshape((-1, 1))
|
|
886
|
+
|
|
887
|
+
_inputs = [(torch.empty((0, 3), dtype=torch.float32),)]
|
|
888
|
+
_dynamic = {"x": {0: DYN, 1: DYN}}
|
|
889
|
+
_valid = [(torch.rand((2, 3), dtype=torch.float32),)]
|
|
890
|
+
|
|
891
|
+
|
|
892
|
+
class ExportWithDimension1(torch.nn.Module):
|
|
893
|
+
def forward(self, x):
|
|
894
|
+
return x @ torch.arange(x.shape[1], dtype=torch.float32).reshape((-1, 1))
|
|
895
|
+
|
|
896
|
+
_inputs = [(torch.zeros((1, 3), dtype=torch.float32),)]
|
|
897
|
+
_dynamic = {"x": {0: DYN, 1: DYN}}
|
|
898
|
+
_valid = [(torch.rand((2, 3), dtype=torch.float32),)]
|
|
@@ -83,7 +83,7 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
|
|
|
83
83
|
continue
|
|
84
84
|
|
|
85
85
|
original = cls._PATCHED_CLASS_
|
|
86
|
-
methods = cls._PATCHES_
|
|
86
|
+
methods = [_ for _ in cls._PATCHES_ if _ is not None]
|
|
87
87
|
if verbose:
|
|
88
88
|
print(f"[patch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}")
|
|
89
89
|
|
|
@@ -268,19 +268,22 @@ def torch_export_patches(
|
|
|
268
268
|
if rewrite:
|
|
269
269
|
from .patch_module import torch_export_rewrite
|
|
270
270
|
|
|
271
|
-
with
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
271
|
+
with (
|
|
272
|
+
torch_export_rewrite(
|
|
273
|
+
rewrite=rewrite, dump_rewriting=dump_rewriting, verbose=verbose
|
|
274
|
+
),
|
|
275
|
+
torch_export_patches( # type: ignore[var-annotated]
|
|
276
|
+
patch_sympy=patch_sympy,
|
|
277
|
+
patch_torch=patch_torch,
|
|
278
|
+
patch_transformers=patch_transformers,
|
|
279
|
+
patch_diffusers=patch_diffusers,
|
|
280
|
+
catch_constraints=catch_constraints,
|
|
281
|
+
stop_if_static=stop_if_static,
|
|
282
|
+
verbose=verbose,
|
|
283
|
+
patch=patch,
|
|
284
|
+
custom_patches=custom_patches,
|
|
285
|
+
) as f,
|
|
286
|
+
):
|
|
284
287
|
try:
|
|
285
288
|
yield f
|
|
286
289
|
finally:
|
|
@@ -337,12 +340,17 @@ def torch_export_patches(
|
|
|
337
340
|
###############
|
|
338
341
|
|
|
339
342
|
if patch_torch:
|
|
343
|
+
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
|
340
344
|
from .patches.patch_torch import (
|
|
341
345
|
patched_infer_size,
|
|
342
346
|
patched_vmap,
|
|
343
347
|
patched__broadcast_shapes,
|
|
348
|
+
patched__constrain_user_specified_dimhint_range,
|
|
344
349
|
_catch_produce_guards_and_solve_constraints,
|
|
345
350
|
patch__check_input_constraints_for_graph,
|
|
351
|
+
patched__broadcast_in_dim_meta,
|
|
352
|
+
patched__maybe_broadcast,
|
|
353
|
+
patched_ShapeEnv,
|
|
346
354
|
)
|
|
347
355
|
|
|
348
356
|
if verbose:
|
|
@@ -371,6 +379,28 @@ def torch_export_patches(
|
|
|
371
379
|
torch._refs._broadcast_shapes = patched__broadcast_shapes
|
|
372
380
|
torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes
|
|
373
381
|
|
|
382
|
+
# torch._export.non_strict_utils._constrain_user_specified_dimhint_range
|
|
383
|
+
f___constrain_user_specified_dimhint_range = (
|
|
384
|
+
torch._export.non_strict_utils._constrain_user_specified_dimhint_range
|
|
385
|
+
)
|
|
386
|
+
torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
|
|
387
|
+
patched__constrain_user_specified_dimhint_range
|
|
388
|
+
)
|
|
389
|
+
|
|
390
|
+
# torch._prims._broadcast_in_dim_meta
|
|
391
|
+
f_broadcast_in_dim = torch._prims.broadcast_in_dim
|
|
392
|
+
f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
|
|
393
|
+
torch._prims._broadcast_in_dim_meta = patched__broadcast_in_dim_meta
|
|
394
|
+
torch._prims.broadcast_in_dim = patched__broadcast_in_dim_meta
|
|
395
|
+
|
|
396
|
+
# torch._refs._maybe_broadcast
|
|
397
|
+
f__maybe_broadcast = torch._refs._maybe_broadcast
|
|
398
|
+
torch._refs._maybe_broadcast = patched__maybe_broadcast
|
|
399
|
+
|
|
400
|
+
# ShapeEnv
|
|
401
|
+
f_shape_env__evaluate_expr = ShapeEnv._evaluate_expr
|
|
402
|
+
ShapeEnv._evaluate_expr = patched_ShapeEnv._evaluate_expr
|
|
403
|
+
|
|
374
404
|
# torch._export.non_strict_utils.produce_guards_and_solve_constraints
|
|
375
405
|
if patch_torch and catch_constraints:
|
|
376
406
|
if verbose:
|
|
@@ -393,9 +423,6 @@ def torch_export_patches(
|
|
|
393
423
|
)
|
|
394
424
|
|
|
395
425
|
if stop_if_static:
|
|
396
|
-
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
|
397
|
-
from .patches.patch_torch import patched_ShapeEnv
|
|
398
|
-
|
|
399
426
|
ShapeEnv._log_guard_remember = ShapeEnv._log_guard
|
|
400
427
|
|
|
401
428
|
if verbose:
|
|
@@ -569,6 +596,13 @@ def torch_export_patches(
|
|
|
569
596
|
torch._subclasses.fake_impls.infer_size = f_infer_size
|
|
570
597
|
torch._refs._broadcast_shapes = f__broadcast_shapes
|
|
571
598
|
torch._meta_registrations._broadcast_shapes = f__broadcast_shapes
|
|
599
|
+
torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
|
|
600
|
+
f___constrain_user_specified_dimhint_range
|
|
601
|
+
)
|
|
602
|
+
torch._prims._broadcast_in_dim_meta = f__broadcast_in_dim_meta
|
|
603
|
+
torch._prims.broadcast_in_dim = f_broadcast_in_dim
|
|
604
|
+
torch._refs._maybe_broadcast = f__maybe_broadcast
|
|
605
|
+
ShapeEnv._evaluate_expr = f_shape_env__evaluate_expr
|
|
572
606
|
|
|
573
607
|
if verbose:
|
|
574
608
|
print("[torch_export_patches] restored pytorch functions")
|
|
@@ -708,9 +742,7 @@ def torch_export_patches(
|
|
|
708
742
|
|
|
709
743
|
|
|
710
744
|
def replacement_before_exporting(args: Any) -> Any:
|
|
711
|
-
"""
|
|
712
|
-
Does replacements on the given inputs if needed.
|
|
713
|
-
"""
|
|
745
|
+
"""Does replacements on the given inputs if needed."""
|
|
714
746
|
if args is None:
|
|
715
747
|
return None
|
|
716
748
|
if isinstance(args, (int, float)):
|
|
@@ -189,19 +189,23 @@ def convert_dynamic_axes_into_dynamic_shapes(
|
|
|
189
189
|
return (), updated_kwargs, dynamic_shapes
|
|
190
190
|
|
|
191
191
|
|
|
192
|
-
def use_dyn_not_str(dynamic_shapes: Any) -> Any:
|
|
192
|
+
def use_dyn_not_str(dynamic_shapes: Any, default_value=None) -> Any:
|
|
193
193
|
"""
|
|
194
194
|
Some functions returns dynamic shapes as string.
|
|
195
195
|
This functions replaces them with ``torch.export.Dim.DYNAMIC``.
|
|
196
|
+
``default_value=torch.export.Dim.AUTO`` changes the default value.
|
|
196
197
|
"""
|
|
197
198
|
if isinstance(dynamic_shapes, list):
|
|
198
|
-
return [use_dyn_not_str(a) for a in dynamic_shapes]
|
|
199
|
+
return [use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes]
|
|
199
200
|
if isinstance(dynamic_shapes, tuple):
|
|
200
|
-
return tuple(use_dyn_not_str(a) for a in dynamic_shapes)
|
|
201
|
+
return tuple(use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes)
|
|
201
202
|
if isinstance(dynamic_shapes, dict):
|
|
202
|
-
return {
|
|
203
|
+
return {
|
|
204
|
+
k: use_dyn_not_str(v, default_value=default_value)
|
|
205
|
+
for k, v in dynamic_shapes.items()
|
|
206
|
+
}
|
|
203
207
|
if isinstance(dynamic_shapes, set):
|
|
204
|
-
return {use_dyn_not_str(a) for a in dynamic_shapes}
|
|
208
|
+
return {use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes}
|
|
205
209
|
if isinstance(dynamic_shapes, str):
|
|
206
|
-
return torch.export.Dim.DYNAMIC
|
|
210
|
+
return torch.export.Dim.DYNAMIC if default_value is None else default_value
|
|
207
211
|
return dynamic_shapes
|