onnx-diagnostic 0.7.12__py3-none-any.whl → 0.7.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29):
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +7 -2
  3. onnx_diagnostic/export/dynamic_shapes.py +11 -2
  4. onnx_diagnostic/helpers/helper.py +11 -5
  5. onnx_diagnostic/helpers/log_helper.py +53 -17
  6. onnx_diagnostic/helpers/mini_onnx_builder.py +17 -0
  7. onnx_diagnostic/helpers/model_builder_helper.py +1 -0
  8. onnx_diagnostic/helpers/rt_helper.py +2 -1
  9. onnx_diagnostic/helpers/torch_helper.py +31 -7
  10. onnx_diagnostic/reference/torch_evaluator.py +2 -2
  11. onnx_diagnostic/tasks/data/__init__.py +13 -0
  12. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  13. onnx_diagnostic/tasks/image_text_to_text.py +256 -141
  14. onnx_diagnostic/tasks/text_generation.py +30 -0
  15. onnx_diagnostic/torch_export_patches/eval/__init__.py +184 -151
  16. onnx_diagnostic/torch_export_patches/eval/model_cases.py +20 -5
  17. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +52 -20
  18. onnx_diagnostic/torch_export_patches/patch_inputs.py +10 -6
  19. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +540 -10
  20. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +269 -4
  21. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +36 -0
  22. onnx_diagnostic/torch_models/hghub/model_inputs.py +55 -5
  23. onnx_diagnostic/torch_models/validate.py +116 -50
  24. onnx_diagnostic/torch_onnx/sbs.py +2 -1
  25. {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/METADATA +11 -31
  26. {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/RECORD +29 -27
  27. {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/WHEEL +0 -0
  28. {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/licenses/LICENSE.txt +0 -0
  29. {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/top_level.txt +0 -0
@@ -39,6 +39,9 @@ def evaluation(
39
39
  "export-strict",
40
40
  "export-nostrict",
41
41
  "export-nostrict-decall",
42
+ "export-strict-oblivious",
43
+ "export-nostrict-oblivious",
44
+ "export-nostrict-decall-oblivious",
42
45
  ),
43
46
  dynamic: Tuple[bool] = (False, True),
44
47
  cases: Optional[Union[str, Dict[str, type]]] = None,
@@ -105,9 +108,7 @@ def evaluation(
105
108
 
106
109
 
107
110
  def _flatten_inputs(x: Any) -> List["torch.Tensor"]: # noqa: F821
108
- """
109
- Flatten inputs.
110
- """
111
+ """Flatten inputs."""
111
112
  if x is None:
112
113
  return x
113
114
  import torch
@@ -173,6 +174,15 @@ def _clone(x):
173
174
  raise TypeError(f"Unable to clone type {type(x)}, x={x} into numpy")
174
175
 
175
176
 
177
+ def _wrap_torch_export(*args, backed_size_oblivious=False, **kwargs):
178
+ import torch
179
+
180
+ if backed_size_oblivious:
181
+ with torch.fx.experimental._config.patch(backed_size_oblivious=True):
182
+ return torch.export.export(*args, **kwargs)
183
+ return torch.export.export(*args, **kwargs)
184
+
185
+
176
186
  def _make_exporter_export(
177
187
  exporter: str,
178
188
  model: "torch.nn.Module", # noqa: F821
@@ -183,76 +193,36 @@ def _make_exporter_export(
183
193
  ) -> Union[Dict, Callable]:
184
194
  import torch
185
195
 
186
- if exporter == "export-strict":
187
- try:
188
- if verbose >= 2:
189
- exported = torch.export.export(
190
- model, inputs, dynamic_shapes=dynamic_shapes, strict=True
191
- )
192
- else:
193
- with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
194
- io.StringIO()
195
- ):
196
- exported = torch.export.export(
197
- model, inputs, dynamic_shapes=dynamic_shapes, strict=True
198
- )
199
- except Exception as e:
200
- if not quiet:
201
- raise
202
- return dict(error=str(e), success=0, error_step="export")
203
- if verbose >= 9:
204
- print("-- graph")
205
- print(exported.graph)
206
- return exported.module()
207
- if exporter in ("export-strict-dec", "export-strict-decall"):
208
- try:
209
- if verbose >= 2:
210
- exported = torch.export.export(
211
- model, inputs, dynamic_shapes=dynamic_shapes, strict=True
212
- )
213
- if verbose >= 9:
214
- print("-- graph before decomposition")
215
- print(exported.graph)
216
- exported = (
217
- exported.run_decompositions()
218
- if "decall" in exporter
219
- else exported.run_decompositions({})
220
- )
221
- else:
222
- with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
223
- io.StringIO()
224
- ):
225
- exported = torch.export.export(
226
- model, inputs, dynamic_shapes=dynamic_shapes, strict=True
227
- )
228
- if verbose >= 9:
229
- print("-- graph before decomposition")
230
- print(exported.graph)
231
- exported = (
232
- exported.run_decompositions()
233
- if "decall" in exporter
234
- else exported.run_decompositions({})
235
- )
236
- except Exception as e:
237
- if not quiet:
238
- raise
239
- return dict(error=str(e), success=0, error_step="export")
240
- if verbose >= 9:
241
- print("-- graph after decomposition")
242
- print(exported.graph)
243
- return exported.module()
244
- if exporter == "export-nostrict":
196
+ backed_size_oblivious = "-oblivious" in exporter
197
+ strict = "-nostrict" not in exporter
198
+
199
+ if exporter in (
200
+ "export-strict",
201
+ "export-strict-oblivious",
202
+ "export-nostrict",
203
+ "export-nostrict-oblivious",
204
+ "export-oblivious",
205
+ ):
245
206
  try:
246
207
  if verbose >= 2:
247
- exported = torch.export.export(
248
- model, inputs, dynamic_shapes=dynamic_shapes, strict=False
208
+ exported = _wrap_torch_export(
209
+ model,
210
+ inputs,
211
+ dynamic_shapes=dynamic_shapes,
212
+ strict=strict,
213
+ backed_size_oblivious=backed_size_oblivious,
249
214
  )
250
215
  else:
251
- with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
252
- io.StringIO()
216
+ with (
217
+ contextlib.redirect_stdout(io.StringIO()),
218
+ contextlib.redirect_stderr(io.StringIO()),
253
219
  ):
254
- exported = torch.export.export(
255
- model, inputs, dynamic_shapes=dynamic_shapes, strict=False
220
+ exported = _wrap_torch_export(
221
+ model,
222
+ inputs,
223
+ dynamic_shapes=dynamic_shapes,
224
+ strict=strict,
225
+ backed_size_oblivious=backed_size_oblivious,
256
226
  )
257
227
  except Exception as e:
258
228
  if not quiet:
@@ -262,11 +232,25 @@ def _make_exporter_export(
262
232
  print("-- graph")
263
233
  print(exported.graph)
264
234
  return exported.module()
265
- if exporter in ("export-nostrict-dec", "export-nostrict-decall"):
235
+
236
+ if exporter in (
237
+ "export-strict-dec",
238
+ "export-strict-decall",
239
+ "export-strict-dec-oblivious",
240
+ "export-strict-decall-oblivious",
241
+ "export-nostrict-dec",
242
+ "export-nostrict-decall",
243
+ "export-nostrict-dec-oblivious",
244
+ "export-nostrict-decall-oblivious",
245
+ ):
266
246
  try:
267
247
  if verbose >= 2:
268
- exported = torch.export.export(
269
- model, inputs, dynamic_shapes=dynamic_shapes, strict=False
248
+ exported = _wrap_torch_export(
249
+ model,
250
+ inputs,
251
+ dynamic_shapes=dynamic_shapes,
252
+ strict=strict,
253
+ backed_size_oblivious=backed_size_oblivious,
270
254
  )
271
255
  if verbose >= 9:
272
256
  print("-- graph before decomposition")
@@ -277,11 +261,16 @@ def _make_exporter_export(
277
261
  else exported.run_decompositions({})
278
262
  )
279
263
  else:
280
- with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
281
- io.StringIO()
264
+ with (
265
+ contextlib.redirect_stdout(io.StringIO()),
266
+ contextlib.redirect_stderr(io.StringIO()),
282
267
  ):
283
- exported = torch.export.export(
284
- model, inputs, dynamic_shapes=dynamic_shapes, strict=False
268
+ exported = _wrap_torch_export(
269
+ model,
270
+ inputs,
271
+ dynamic_shapes=dynamic_shapes,
272
+ strict=strict,
273
+ backed_size_oblivious=backed_size_oblivious,
285
274
  )
286
275
  if verbose >= 9:
287
276
  print("-- graph before decomposition")
@@ -299,6 +288,7 @@ def _make_exporter_export(
299
288
  print("-- graph after decomposition")
300
289
  print(exported.graph)
301
290
  return exported.module()
291
+
302
292
  if exporter == "export-tracing":
303
293
  from experimental_experiment.torch_interpreter.tracing import CustomTracer
304
294
 
@@ -307,8 +297,9 @@ def _make_exporter_export(
307
297
  graph = CustomTracer().trace(model)
308
298
  mod = torch.fx.GraphModule(model, graph)
309
299
  else:
310
- with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
311
- io.StringIO()
300
+ with (
301
+ contextlib.redirect_stdout(io.StringIO()),
302
+ contextlib.redirect_stderr(io.StringIO()),
312
303
  ):
313
304
  graph = CustomTracer().trace(model)
314
305
  mod = torch.fx.GraphModule(model, graph)
@@ -353,8 +344,9 @@ def _make_exporter_onnx(
353
344
  return_builder=True,
354
345
  )
355
346
  else:
356
- with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
357
- io.StringIO()
347
+ with (
348
+ contextlib.redirect_stdout(io.StringIO()),
349
+ contextlib.redirect_stderr(io.StringIO()),
358
350
  ):
359
351
  onx, builder = to_onnx(
360
352
  model,
@@ -387,8 +379,9 @@ def _make_exporter_onnx(
387
379
  report=True,
388
380
  ).model_proto
389
381
  else:
390
- with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
391
- io.StringIO()
382
+ with (
383
+ contextlib.redirect_stdout(io.StringIO()),
384
+ contextlib.redirect_stderr(io.StringIO()),
392
385
  ):
393
386
  onx = torch.onnx.export(
394
387
  model,
@@ -422,8 +415,9 @@ def _make_exporter_onnx(
422
415
  ep.optimize()
423
416
  onx = ep.model_proto
424
417
  else:
425
- with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
426
- io.StringIO()
418
+ with (
419
+ contextlib.redirect_stdout(io.StringIO()),
420
+ contextlib.redirect_stderr(io.StringIO()),
427
421
  ):
428
422
  ep = torch.onnx.export(
429
423
  model,
@@ -446,6 +440,74 @@ def _make_exporter_onnx(
446
440
  raise AssertionError(f"Unexpected exporter={exporter!r}")
447
441
 
448
442
 
443
def _compares_on_one_example(
    model: Callable, inputs: Tuple[Any, ...], mod: Callable, verbose: int, quiet: bool
) -> Tuple[Any, Any, Dict]:
    """
    Runs the original model and the exported one on one set of inputs
    and measures the discrepancies between both outputs.

    :param model: original model, called with a clone of *inputs*
    :param inputs: positional inputs of a single example
    :param mod: exported model (or any callable), called with *inputs*
    :param verbose: verbosity, inputs and outputs are printed when
        it is high enough
    :param quiet: if True, exceptions are caught and reported in the
        returned dictionary (keys ``error``, ``error_step``),
        otherwise they are raised
    :return: tuple ``(expected, got, results)``, *expected* or *got* is None
        if the corresponding step failed, *results* contains ``success``
        (0 or 1) and, when the comparison ran, the discrepancies
    """
    from onnx_diagnostic.helpers import max_diff, string_type

    try:
        # the inputs are cloned because the model may modify them
        expected = model(*_clone(inputs))
    except Exception as e:
        if not quiet:
            raise RuntimeError(
                f"eager mode failed=\n{string_type(inputs, with_shape=True)} "
                f"\nmodel=\n{type(model)}"
            ) from e
        return None, None, dict(error=str(e), success=0, error_step="eager")

    try:
        got = mod(*inputs)
    except Exception as e:
        if not quiet:
            raise RuntimeError(
                f"onnxruntime failed, feeds=\n{string_type(inputs, with_shape=True)}"
            ) from e
        return expected, None, dict(error=str(e), success=0, error_step="run.0")

    try:
        disc = max_diff(expected, got)
    except Exception as e:
        if not quiet:
            raise
        return expected, got, dict(error=str(e), success=0, error_step="discrepancy")

    # *inputs* is already a single example: the whole tuple is displayed,
    # not only its first element.
    if verbose >= 5 and np.isinf(disc["abs"]):
        print("[run_exporter] comparison issues with")
        print(f"-- inputs={string_type(inputs, with_shape=True, limit=20)}")
        print(f"-- expected={string_type(expected, with_shape=True, limit=20)}")
        print(f"-- got={string_type(got, with_shape=True, limit=20)}")
    elif verbose >= 9:
        print("[run_exporter] inputs and outputs")
        print(
            f"-- inputs="
            f"{string_type(inputs, with_shape=True, with_min_max=True, limit=20)}"
        )
        print(
            f"-- expected="
            f"{string_type(expected, with_shape=True, with_min_max=True, limit=20)}"
        )
        print(f"-- got={string_type(got, with_shape=True, with_min_max=True, limit=20)}")
    del disc["n"]
    del disc["sum"]

    # A NaN discrepancy is neither ``< 0.1`` nor ``>= 0.1``: the success is
    # decided by a single test so that NaN is reported as a failure.
    success = bool(disc["abs"] < 0.1)
    disc.update(
        dict(
            success=1 if success else 0,
            model_cls=model.__class__,  # type: ignore[dict-item]
            exported=mod,  # type: ignore[dict-item]
        )
    )
    if not success:
        disc["error"] = "diff.0"
        disc["error_step"] = "diff.0"
        if verbose >= 9:
            max_diff(expected, got, verbose=verbose)
    return expected, got, disc
509
+
510
+
449
511
  def run_exporter(
450
512
  exporter: str,
451
513
  cls_model: type,
@@ -473,6 +535,7 @@ def run_exporter(
473
535
 
474
536
  model = cls_model()
475
537
  inputs = cls_model._inputs
538
+ valid = getattr(cls_model, "_valid", None)
476
539
  if isinstance(inputs, tuple):
477
540
  inputs = [inputs]
478
541
  if dynamic:
@@ -566,74 +629,38 @@ def run_exporter(
566
629
  mod = lambda *args, names=names: sess.run(None, _make_feeds(names, args)) # noqa: E731
567
630
 
568
631
  # we need to clone for models modifying the inputs
569
- try:
570
- expected = model(*_clone(inputs[0]))
571
- except Exception as e:
572
- if not quiet:
573
- raise RuntimeError(
574
- f"eager mode failed=\n{string_type(inputs[0], with_shape=True)} "
575
- f"\nmodel=\n{type(model)}"
576
- ) from e
577
- res = dict(error=str(e), success=0, error_step="eager")
578
- res.update(base)
579
- return res
580
- try:
581
- got = mod(*inputs[0])
582
- except Exception as e:
583
- if not quiet:
584
- raise RuntimeError(
585
- f"onnxruntime failed, feeds=\n{string_type(inputs[0], with_shape=True)} "
586
- f"\nmodel=\n{pretty_onnx(onx)}"
587
- ) from e
588
- res = dict(error=str(e), success=0, error_step="run.0")
589
- res.update(base)
590
- return res
591
-
592
- base["expected"] = expected
593
- base["obtained"] = got
594
-
595
- try:
596
- disc = max_diff(expected, got)
597
- except Exception as e:
598
- if not quiet:
599
- raise
600
- res = dict(error=str(e), success=0, error_step="discrepancy")
601
- res.update(base)
602
- return res
632
+ expected, got, disc = _compares_on_one_example(model, inputs[0], mod, verbose, quiet)
633
+ if expected is not None:
634
+ base["expected"] = expected
635
+ if got is not None:
636
+ base["obtained"] = got
637
+ disc.update(base)
638
+ disc["onnx"] = onx # type: ignore[dict-item]
603
639
 
604
- if verbose >= 5 and np.isinf(disc["abs"]):
605
- print("[run_exporter] comparison issues with")
606
- print(f"-- inputs={string_type(inputs[0], with_shape=True, limit=20)}")
607
- print(f"-- expected={string_type(expected, with_shape=True, limit=20)}")
608
- print(f"-- got={string_type(got, with_shape=True, limit=20)}")
609
- elif verbose >= 9:
610
- print("[run_exporter] inputs and outputs")
611
- print(
612
- f"-- inputs="
613
- f"{string_type(inputs[0], with_shape=True, with_min_max=True, limit=20)}"
614
- )
615
- print(
616
- f"-- expected="
617
- f"{string_type(expected, with_shape=True, with_min_max=True, limit=20)}"
618
- )
619
- print(f"-- got={string_type(got, with_shape=True, with_min_max=True, limit=20)}")
620
- del disc["n"]
621
- del disc["sum"]
622
- disc.update(
623
- dict(
624
- success=1 if disc["abs"] < 0.1 else 0,
625
- model_cls=model.__class__,
626
- exported=mod, # type: ignore[dict-item]
627
- onnx=onx, # type: ignore[dict-item]
628
- )
629
- )
630
- if disc["abs"] >= 0.1:
631
- disc["error"] = "diff.0"
632
- disc["error_step"] = "diff.0"
633
- if verbose >= 9:
634
- max_diff(expected, got, verbose=verbose)
635
- else:
636
- disc["success"] = 1
640
+ if valid is not None:
641
+ for valid_inputs in valid:
642
+ expected, got, _disc = _compares_on_one_example(
643
+ model, valid_inputs, mod, verbose, quiet
644
+ )
645
+ if "abs" not in disc and (np.isnan(disc["abs"]) or disc["abs"] > 1e-3):
646
+ _disc["issue-abs"] = disc["abs"]
647
+ _disc["issue-rel"] = disc["rel"]
648
+ _disc["issue-inputs"] = string_type(
649
+ valid_inputs, with_shape=True, with_min_max=True
650
+ )
651
+ _disc["issue-expected"] = string_type(
652
+ expected, with_shape=True, with_min_max=True
653
+ )
654
+ _disc["issue-obtained"] = string_type(got, with_shape=True, with_min_max=True)
655
+ if not quiet:
656
+ raise RuntimeError(
657
+ f"validation failed,"
658
+ f"\n-- inputs=\n{string_type(_disc['issue-inputs'])} "
659
+ f"\n-- exporter={exporter!r}\n-- dynamic_shapes={dynamic_shapes}, "
660
+ f"\n-- expected={_disc['issue-expected']}"
661
+ f"\n-- obtained={_disc['issue-obtained']}"
662
+ )
663
+ break
637
664
 
638
665
  if dynamic and onx is not None:
639
666
  ds = []
@@ -649,7 +676,13 @@ def run_exporter(
649
676
 
650
677
  if dynamic and len(inputs) > 1:
651
678
  for index, i in enumerate(inputs):
652
- expected = model(*_clone(i))
679
+ if quiet:
680
+ try:
681
+ expected = model(*_clone(i))
682
+ except Exception as e:
683
+ return dict(error=str(e), success=0, error_step=f"run0.{index}")
684
+ else:
685
+ expected = model(*_clone(i))
653
686
  try:
654
687
  got = mod(*i)
655
688
  except Exception as e:
@@ -353,12 +353,9 @@ class ControlFlowCondNonZero(torch.nn.Module):
353
353
 
354
354
 
355
355
  class ControlFlowCondIdentity_153832(torch.nn.Module):
356
- """
357
- `#153832 <https://github.com/pytorch/pytorch/issues/153832>`_
358
- """
356
+ """`#153832 <https://github.com/pytorch/pytorch/issues/153832>`_"""
359
357
 
360
358
  def forward(self, x, y):
361
-
362
359
  def branch_cond_then_1(x):
363
360
  x = torch.abs(x) + 1
364
361
  return x
@@ -861,7 +858,7 @@ class CreateFromShapeThroughFunction(torch.nn.Module):
861
858
  y = torch.ones((x.shape[0], dy1))
862
859
  return y
863
860
 
864
- _inputs = [(torch.rand((4, 4)),), (torch.rand((5, 5)),)]
861
+ _inputs = [(torch.rand((4, 4)),)]
865
862
  _dynamic = {"x": {0: DIM("dx"), 1: DIM("dy")}}
866
863
 
867
864
 
@@ -881,3 +878,21 @@ class VmapPython(torch.nn.Module):
881
878
 
882
879
  _inputs = [(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([0.1, 0.2, 0.3]))]
883
880
  _dynamic = {"x": {0: DYN}, "y": {0: DYN}}
881
+
882
+
883
class ExportWithDimension0(torch.nn.Module):
    """
    Multiplies ``x`` by the column vector ``arange(x.shape[1])``.
    The example used to export has an empty first dimension.
    """

    def forward(self, x):
        weights = torch.arange(x.shape[1], dtype=torch.float32)
        return x @ weights.reshape((-1, 1))

    # inputs used to export: the first dimension is 0
    _inputs = [(torch.empty((0, 3), dtype=torch.float32),)]
    # both dimensions of x are dynamic
    _dynamic = {"x": {0: DYN, 1: DYN}}
    # additional inputs, the first dimension is 2
    _valid = [(torch.rand((2, 3), dtype=torch.float32),)]
890
+
891
+
892
class ExportWithDimension1(torch.nn.Module):
    """
    Multiplies ``x`` by the column vector ``arange(x.shape[1])``.
    The example used to export has a first dimension equal to 1.
    """

    def forward(self, x):
        weights = torch.arange(x.shape[1], dtype=torch.float32)
        return x @ weights.reshape((-1, 1))

    # inputs used to export: the first dimension is 1
    _inputs = [(torch.zeros((1, 3), dtype=torch.float32),)]
    # both dimensions of x are dynamic
    _dynamic = {"x": {0: DYN, 1: DYN}}
    # additional inputs, the first dimension is 2
    _valid = [(torch.rand((2, 3), dtype=torch.float32),)]
@@ -83,7 +83,7 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
83
83
  continue
84
84
 
85
85
  original = cls._PATCHED_CLASS_
86
- methods = cls._PATCHES_
86
+ methods = [_ for _ in cls._PATCHES_ if _ is not None]
87
87
  if verbose:
88
88
  print(f"[patch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}")
89
89
 
@@ -268,19 +268,22 @@ def torch_export_patches(
268
268
  if rewrite:
269
269
  from .patch_module import torch_export_rewrite
270
270
 
271
- with torch_export_rewrite(
272
- rewrite=rewrite, dump_rewriting=dump_rewriting, verbose=verbose
273
- ), torch_export_patches( # type: ignore[var-annotated]
274
- patch_sympy=patch_sympy,
275
- patch_torch=patch_torch,
276
- patch_transformers=patch_transformers,
277
- patch_diffusers=patch_diffusers,
278
- catch_constraints=catch_constraints,
279
- stop_if_static=stop_if_static,
280
- verbose=verbose,
281
- patch=patch,
282
- custom_patches=custom_patches,
283
- ) as f:
271
+ with (
272
+ torch_export_rewrite(
273
+ rewrite=rewrite, dump_rewriting=dump_rewriting, verbose=verbose
274
+ ),
275
+ torch_export_patches( # type: ignore[var-annotated]
276
+ patch_sympy=patch_sympy,
277
+ patch_torch=patch_torch,
278
+ patch_transformers=patch_transformers,
279
+ patch_diffusers=patch_diffusers,
280
+ catch_constraints=catch_constraints,
281
+ stop_if_static=stop_if_static,
282
+ verbose=verbose,
283
+ patch=patch,
284
+ custom_patches=custom_patches,
285
+ ) as f,
286
+ ):
284
287
  try:
285
288
  yield f
286
289
  finally:
@@ -337,12 +340,17 @@ def torch_export_patches(
337
340
  ###############
338
341
 
339
342
  if patch_torch:
343
+ from torch.fx.experimental.symbolic_shapes import ShapeEnv
340
344
  from .patches.patch_torch import (
341
345
  patched_infer_size,
342
346
  patched_vmap,
343
347
  patched__broadcast_shapes,
348
+ patched__constrain_user_specified_dimhint_range,
344
349
  _catch_produce_guards_and_solve_constraints,
345
350
  patch__check_input_constraints_for_graph,
351
+ patched__broadcast_in_dim_meta,
352
+ patched__maybe_broadcast,
353
+ patched_ShapeEnv,
346
354
  )
347
355
 
348
356
  if verbose:
@@ -371,6 +379,28 @@ def torch_export_patches(
371
379
  torch._refs._broadcast_shapes = patched__broadcast_shapes
372
380
  torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes
373
381
 
382
+ # torch._export.non_strict_utils._constrain_user_specified_dimhint_range
383
+ f___constrain_user_specified_dimhint_range = (
384
+ torch._export.non_strict_utils._constrain_user_specified_dimhint_range
385
+ )
386
+ torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
387
+ patched__constrain_user_specified_dimhint_range
388
+ )
389
+
390
+ # torch._prims._broadcast_in_dim_meta
391
+ f_broadcast_in_dim = torch._prims.broadcast_in_dim
392
+ f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
393
+ torch._prims._broadcast_in_dim_meta = patched__broadcast_in_dim_meta
394
+ torch._prims.broadcast_in_dim = patched__broadcast_in_dim_meta
395
+
396
+ # torch._refs._maybe_broadcast
397
+ f__maybe_broadcast = torch._refs._maybe_broadcast
398
+ torch._refs._maybe_broadcast = patched__maybe_broadcast
399
+
400
+ # ShapeEnv
401
+ f_shape_env__evaluate_expr = ShapeEnv._evaluate_expr
402
+ ShapeEnv._evaluate_expr = patched_ShapeEnv._evaluate_expr
403
+
374
404
  # torch._export.non_strict_utils.produce_guards_and_solve_constraints
375
405
  if patch_torch and catch_constraints:
376
406
  if verbose:
@@ -393,9 +423,6 @@ def torch_export_patches(
393
423
  )
394
424
 
395
425
  if stop_if_static:
396
- from torch.fx.experimental.symbolic_shapes import ShapeEnv
397
- from .patches.patch_torch import patched_ShapeEnv
398
-
399
426
  ShapeEnv._log_guard_remember = ShapeEnv._log_guard
400
427
 
401
428
  if verbose:
@@ -569,6 +596,13 @@ def torch_export_patches(
569
596
  torch._subclasses.fake_impls.infer_size = f_infer_size
570
597
  torch._refs._broadcast_shapes = f__broadcast_shapes
571
598
  torch._meta_registrations._broadcast_shapes = f__broadcast_shapes
599
+ torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
600
+ f___constrain_user_specified_dimhint_range
601
+ )
602
+ torch._prims._broadcast_in_dim_meta = f__broadcast_in_dim_meta
603
+ torch._prims.broadcast_in_dim = f_broadcast_in_dim
604
+ torch._refs._maybe_broadcast = f__maybe_broadcast
605
+ ShapeEnv._evaluate_expr = f_shape_env__evaluate_expr
572
606
 
573
607
  if verbose:
574
608
  print("[torch_export_patches] restored pytorch functions")
@@ -708,9 +742,7 @@ def torch_export_patches(
708
742
 
709
743
 
710
744
  def replacement_before_exporting(args: Any) -> Any:
711
- """
712
- Does replacements on the given inputs if needed.
713
- """
745
+ """Does replacements on the given inputs if needed."""
714
746
  if args is None:
715
747
  return None
716
748
  if isinstance(args, (int, float)):
@@ -189,19 +189,23 @@ def convert_dynamic_axes_into_dynamic_shapes(
189
189
  return (), updated_kwargs, dynamic_shapes
190
190
 
191
191
 
192
- def use_dyn_not_str(dynamic_shapes: Any) -> Any:
192
+ def use_dyn_not_str(dynamic_shapes: Any, default_value=None) -> Any:
193
193
  """
194
194
  Some functions returns dynamic shapes as string.
195
195
  This functions replaces them with ``torch.export.Dim.DYNAMIC``.
196
+ ``default_value=torch.export.Dim.AUTO`` changes the default value.
196
197
  """
197
198
  if isinstance(dynamic_shapes, list):
198
- return [use_dyn_not_str(a) for a in dynamic_shapes]
199
+ return [use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes]
199
200
  if isinstance(dynamic_shapes, tuple):
200
- return tuple(use_dyn_not_str(a) for a in dynamic_shapes)
201
+ return tuple(use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes)
201
202
  if isinstance(dynamic_shapes, dict):
202
- return {k: use_dyn_not_str(v) for k, v in dynamic_shapes.items()}
203
+ return {
204
+ k: use_dyn_not_str(v, default_value=default_value)
205
+ for k, v in dynamic_shapes.items()
206
+ }
203
207
  if isinstance(dynamic_shapes, set):
204
- return {use_dyn_not_str(a) for a in dynamic_shapes}
208
+ return {use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes}
205
209
  if isinstance(dynamic_shapes, str):
206
- return torch.export.Dim.DYNAMIC
210
+ return torch.export.Dim.DYNAMIC if default_value is None else default_value
207
211
  return dynamic_shapes