onnx-diagnostic 0.7.11__py3-none-any.whl → 0.7.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +5 -2
  3. onnx_diagnostic/export/dynamic_shapes.py +11 -2
  4. onnx_diagnostic/helpers/helper.py +11 -5
  5. onnx_diagnostic/helpers/log_helper.py +65 -12
  6. onnx_diagnostic/helpers/mini_onnx_builder.py +17 -0
  7. onnx_diagnostic/helpers/model_builder_helper.py +1 -0
  8. onnx_diagnostic/helpers/rt_helper.py +55 -37
  9. onnx_diagnostic/helpers/torch_helper.py +31 -7
  10. onnx_diagnostic/reference/torch_evaluator.py +2 -2
  11. onnx_diagnostic/tasks/data/__init__.py +13 -0
  12. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  13. onnx_diagnostic/tasks/image_text_to_text.py +256 -141
  14. onnx_diagnostic/tasks/text_generation.py +15 -0
  15. onnx_diagnostic/torch_export_patches/eval/__init__.py +177 -150
  16. onnx_diagnostic/torch_export_patches/eval/model_cases.py +19 -1
  17. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +40 -14
  18. onnx_diagnostic/torch_export_patches/patch_inputs.py +10 -6
  19. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +116 -10
  20. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +269 -4
  21. onnx_diagnostic/torch_models/hghub/hub_api.py +4 -10
  22. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +36 -0
  23. onnx_diagnostic/torch_models/hghub/model_inputs.py +32 -4
  24. onnx_diagnostic/torch_models/validate.py +337 -113
  25. onnx_diagnostic/torch_onnx/sbs.py +2 -1
  26. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/METADATA +11 -31
  27. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/RECORD +30 -28
  28. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/WHEEL +0 -0
  29. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/licenses/LICENSE.txt +0 -0
  30. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/top_level.txt +0 -0
@@ -39,6 +39,9 @@ def evaluation(
39
39
  "export-strict",
40
40
  "export-nostrict",
41
41
  "export-nostrict-decall",
42
+ "export-strict-oblivious",
43
+ "export-nostrict-oblivious",
44
+ "export-nostrict-decall-oblivious",
42
45
  ),
43
46
  dynamic: Tuple[bool] = (False, True),
44
47
  cases: Optional[Union[str, Dict[str, type]]] = None,
@@ -105,9 +108,7 @@ def evaluation(
105
108
 
106
109
 
107
110
  def _flatten_inputs(x: Any) -> List["torch.Tensor"]: # noqa: F821
108
- """
109
- Flatten inputs.
110
- """
111
+ """Flatten inputs."""
111
112
  if x is None:
112
113
  return x
113
114
  import torch
@@ -173,6 +174,15 @@ def _clone(x):
173
174
  raise TypeError(f"Unable to clone type {type(x)}, x={x} into numpy")
174
175
 
175
176
 
177
+ def _wrap_torch_export(*args, backed_size_oblivious=False, **kwargs):
178
+ import torch
179
+
180
+ if backed_size_oblivious:
181
+ with torch.fx.experimental._config.patch(backed_size_oblivious=True):
182
+ return torch.export.export(*args, **kwargs)
183
+ return torch.export.export(*args, **kwargs)
184
+
185
+
176
186
  def _make_exporter_export(
177
187
  exporter: str,
178
188
  model: "torch.nn.Module", # noqa: F821
@@ -183,76 +193,36 @@ def _make_exporter_export(
183
193
  ) -> Union[Dict, Callable]:
184
194
  import torch
185
195
 
186
- if exporter == "export-strict":
187
- try:
188
- if verbose >= 2:
189
- exported = torch.export.export(
190
- model, inputs, dynamic_shapes=dynamic_shapes, strict=True
191
- )
192
- else:
193
- with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
194
- io.StringIO()
195
- ):
196
- exported = torch.export.export(
197
- model, inputs, dynamic_shapes=dynamic_shapes, strict=True
198
- )
199
- except Exception as e:
200
- if not quiet:
201
- raise
202
- return dict(error=str(e), success=0, error_step="export")
203
- if verbose >= 9:
204
- print("-- graph")
205
- print(exported.graph)
206
- return exported.module()
207
- if exporter in ("export-strict-dec", "export-strict-decall"):
208
- try:
209
- if verbose >= 2:
210
- exported = torch.export.export(
211
- model, inputs, dynamic_shapes=dynamic_shapes, strict=True
212
- )
213
- if verbose >= 9:
214
- print("-- graph before decomposition")
215
- print(exported.graph)
216
- exported = (
217
- exported.run_decompositions()
218
- if "decall" in exporter
219
- else exported.run_decompositions({})
220
- )
221
- else:
222
- with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
223
- io.StringIO()
224
- ):
225
- exported = torch.export.export(
226
- model, inputs, dynamic_shapes=dynamic_shapes, strict=True
227
- )
228
- if verbose >= 9:
229
- print("-- graph before decomposition")
230
- print(exported.graph)
231
- exported = (
232
- exported.run_decompositions()
233
- if "decall" in exporter
234
- else exported.run_decompositions({})
235
- )
236
- except Exception as e:
237
- if not quiet:
238
- raise
239
- return dict(error=str(e), success=0, error_step="export")
240
- if verbose >= 9:
241
- print("-- graph after decomposition")
242
- print(exported.graph)
243
- return exported.module()
244
- if exporter == "export-nostrict":
196
+ backed_size_oblivious = "-oblivious" in exporter
197
+ strict = "-nostrict" not in exporter
198
+
199
+ if exporter in (
200
+ "export-strict",
201
+ "export-strict-oblivious",
202
+ "export-nostrict",
203
+ "export-nostrict-oblivious",
204
+ "export-oblivious",
205
+ ):
245
206
  try:
246
207
  if verbose >= 2:
247
- exported = torch.export.export(
248
- model, inputs, dynamic_shapes=dynamic_shapes, strict=False
208
+ exported = _wrap_torch_export(
209
+ model,
210
+ inputs,
211
+ dynamic_shapes=dynamic_shapes,
212
+ strict=strict,
213
+ backed_size_oblivious=backed_size_oblivious,
249
214
  )
250
215
  else:
251
- with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
252
- io.StringIO()
216
+ with (
217
+ contextlib.redirect_stdout(io.StringIO()),
218
+ contextlib.redirect_stderr(io.StringIO()),
253
219
  ):
254
- exported = torch.export.export(
255
- model, inputs, dynamic_shapes=dynamic_shapes, strict=False
220
+ exported = _wrap_torch_export(
221
+ model,
222
+ inputs,
223
+ dynamic_shapes=dynamic_shapes,
224
+ strict=strict,
225
+ backed_size_oblivious=backed_size_oblivious,
256
226
  )
257
227
  except Exception as e:
258
228
  if not quiet:
@@ -262,11 +232,25 @@ def _make_exporter_export(
262
232
  print("-- graph")
263
233
  print(exported.graph)
264
234
  return exported.module()
265
- if exporter in ("export-nostrict-dec", "export-nostrict-decall"):
235
+
236
+ if exporter in (
237
+ "export-strict-dec",
238
+ "export-strict-decall",
239
+ "export-strict-dec-oblivious",
240
+ "export-strict-decall-oblivious",
241
+ "export-nostrict-dec",
242
+ "export-nostrict-decall",
243
+ "export-nostrict-dec-oblivious",
244
+ "export-nostrict-decall-oblivious",
245
+ ):
266
246
  try:
267
247
  if verbose >= 2:
268
- exported = torch.export.export(
269
- model, inputs, dynamic_shapes=dynamic_shapes, strict=False
248
+ exported = _wrap_torch_export(
249
+ model,
250
+ inputs,
251
+ dynamic_shapes=dynamic_shapes,
252
+ strict=strict,
253
+ backed_size_oblivious=backed_size_oblivious,
270
254
  )
271
255
  if verbose >= 9:
272
256
  print("-- graph before decomposition")
@@ -277,11 +261,16 @@ def _make_exporter_export(
277
261
  else exported.run_decompositions({})
278
262
  )
279
263
  else:
280
- with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
281
- io.StringIO()
264
+ with (
265
+ contextlib.redirect_stdout(io.StringIO()),
266
+ contextlib.redirect_stderr(io.StringIO()),
282
267
  ):
283
- exported = torch.export.export(
284
- model, inputs, dynamic_shapes=dynamic_shapes, strict=False
268
+ exported = _wrap_torch_export(
269
+ model,
270
+ inputs,
271
+ dynamic_shapes=dynamic_shapes,
272
+ strict=strict,
273
+ backed_size_oblivious=backed_size_oblivious,
285
274
  )
286
275
  if verbose >= 9:
287
276
  print("-- graph before decomposition")
@@ -299,6 +288,7 @@ def _make_exporter_export(
299
288
  print("-- graph after decomposition")
300
289
  print(exported.graph)
301
290
  return exported.module()
291
+
302
292
  if exporter == "export-tracing":
303
293
  from experimental_experiment.torch_interpreter.tracing import CustomTracer
304
294
 
@@ -307,8 +297,9 @@ def _make_exporter_export(
307
297
  graph = CustomTracer().trace(model)
308
298
  mod = torch.fx.GraphModule(model, graph)
309
299
  else:
310
- with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
311
- io.StringIO()
300
+ with (
301
+ contextlib.redirect_stdout(io.StringIO()),
302
+ contextlib.redirect_stderr(io.StringIO()),
312
303
  ):
313
304
  graph = CustomTracer().trace(model)
314
305
  mod = torch.fx.GraphModule(model, graph)
@@ -353,8 +344,9 @@ def _make_exporter_onnx(
353
344
  return_builder=True,
354
345
  )
355
346
  else:
356
- with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
357
- io.StringIO()
347
+ with (
348
+ contextlib.redirect_stdout(io.StringIO()),
349
+ contextlib.redirect_stderr(io.StringIO()),
358
350
  ):
359
351
  onx, builder = to_onnx(
360
352
  model,
@@ -387,8 +379,9 @@ def _make_exporter_onnx(
387
379
  report=True,
388
380
  ).model_proto
389
381
  else:
390
- with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
391
- io.StringIO()
382
+ with (
383
+ contextlib.redirect_stdout(io.StringIO()),
384
+ contextlib.redirect_stderr(io.StringIO()),
392
385
  ):
393
386
  onx = torch.onnx.export(
394
387
  model,
@@ -422,8 +415,9 @@ def _make_exporter_onnx(
422
415
  ep.optimize()
423
416
  onx = ep.model_proto
424
417
  else:
425
- with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
426
- io.StringIO()
418
+ with (
419
+ contextlib.redirect_stdout(io.StringIO()),
420
+ contextlib.redirect_stderr(io.StringIO()),
427
421
  ):
428
422
  ep = torch.onnx.export(
429
423
  model,
@@ -446,6 +440,74 @@ def _make_exporter_onnx(
446
440
  raise AssertionError(f"Unexpected exporter={exporter!r}")
447
441
 
448
442
 
443
+ def _compares_on_one_example(
444
+ model: Callable, inputs: Tuple[Any, ...], mod: Callable, verbose: int, quiet: bool
445
+ ) -> Tuple[Any, Any, Dict]:
446
+ from onnx_diagnostic.helpers import max_diff, string_type
447
+
448
+ try:
449
+ expected = model(*_clone(inputs))
450
+ except Exception as e:
451
+ if not quiet:
452
+ raise RuntimeError(
453
+ f"eager mode failed=\n{string_type(inputs, with_shape=True)} "
454
+ f"\nmodel=\n{type(model)}"
455
+ ) from e
456
+ res = dict(error=str(e), success=0, error_step="eager")
457
+ return None, None, res
458
+ try:
459
+ got = mod(*inputs)
460
+ except Exception as e:
461
+ if not quiet:
462
+ raise RuntimeError(
463
+ f"onnxruntime failed, feeds=\n{string_type(inputs, with_shape=True)}"
464
+ ) from e
465
+ res = dict(error=str(e), success=0, error_step="run.0")
466
+ return expected, None, res
467
+
468
+ try:
469
+ disc = max_diff(expected, got)
470
+ except Exception as e:
471
+ if not quiet:
472
+ raise
473
+ res = dict(error=str(e), success=0, error_step="discrepancy")
474
+ return expected, got, res
475
+
476
+ if verbose >= 5 and np.isinf(disc["abs"]):
477
+ print("[run_exporter] comparison issues with")
478
+ print(f"-- inputs={string_type(inputs[0], with_shape=True, limit=20)}")
479
+ print(f"-- expected={string_type(expected, with_shape=True, limit=20)}")
480
+ print(f"-- got={string_type(got, with_shape=True, limit=20)}")
481
+ elif verbose >= 9:
482
+ print("[run_exporter] inputs and outputs")
483
+ print(
484
+ f"-- inputs="
485
+ f"{string_type(inputs[0], with_shape=True, with_min_max=True, limit=20)}"
486
+ )
487
+ print(
488
+ f"-- expected="
489
+ f"{string_type(expected, with_shape=True, with_min_max=True, limit=20)}"
490
+ )
491
+ print(f"-- got={string_type(got, with_shape=True, with_min_max=True, limit=20)}")
492
+ del disc["n"]
493
+ del disc["sum"]
494
+ disc.update(
495
+ dict(
496
+ success=1 if disc["abs"] < 0.1 else 0,
497
+ model_cls=model.__class__, # type: ignore[dict-item]
498
+ exported=mod, # type: ignore[dict-item]
499
+ )
500
+ )
501
+ if disc["abs"] >= 0.1:
502
+ disc["error"] = "diff.0"
503
+ disc["error_step"] = "diff.0"
504
+ if verbose >= 9:
505
+ max_diff(expected, got, verbose=verbose)
506
+ else:
507
+ disc["success"] = 1
508
+ return expected, got, disc
509
+
510
+
449
511
  def run_exporter(
450
512
  exporter: str,
451
513
  cls_model: type,
@@ -473,6 +535,7 @@ def run_exporter(
473
535
 
474
536
  model = cls_model()
475
537
  inputs = cls_model._inputs
538
+ valid = getattr(cls_model, "_valid", None)
476
539
  if isinstance(inputs, tuple):
477
540
  inputs = [inputs]
478
541
  if dynamic:
@@ -566,74 +629,38 @@ def run_exporter(
566
629
  mod = lambda *args, names=names: sess.run(None, _make_feeds(names, args)) # noqa: E731
567
630
 
568
631
  # we need to clone for models modifying the inputs
569
- try:
570
- expected = model(*_clone(inputs[0]))
571
- except Exception as e:
572
- if not quiet:
573
- raise RuntimeError(
574
- f"eager mode failed=\n{string_type(inputs[0], with_shape=True)} "
575
- f"\nmodel=\n{type(model)}"
576
- ) from e
577
- res = dict(error=str(e), success=0, error_step="eager")
578
- res.update(base)
579
- return res
580
- try:
581
- got = mod(*inputs[0])
582
- except Exception as e:
583
- if not quiet:
584
- raise RuntimeError(
585
- f"onnxruntime failed, feeds=\n{string_type(inputs[0], with_shape=True)} "
586
- f"\nmodel=\n{pretty_onnx(onx)}"
587
- ) from e
588
- res = dict(error=str(e), success=0, error_step="run.0")
589
- res.update(base)
590
- return res
591
-
592
- base["expected"] = expected
593
- base["obtained"] = got
594
-
595
- try:
596
- disc = max_diff(expected, got)
597
- except Exception as e:
598
- if not quiet:
599
- raise
600
- res = dict(error=str(e), success=0, error_step="discrepancy")
601
- res.update(base)
602
- return res
632
+ expected, got, disc = _compares_on_one_example(model, inputs[0], mod, verbose, quiet)
633
+ if expected is not None:
634
+ base["expected"] = expected
635
+ if got is not None:
636
+ base["obtained"] = got
637
+ disc.update(base)
638
+ disc["onnx"] = onx # type: ignore[dict-item]
603
639
 
604
- if verbose >= 5 and np.isinf(disc["abs"]):
605
- print("[run_exporter] comparison issues with")
606
- print(f"-- inputs={string_type(inputs[0], with_shape=True, limit=20)}")
607
- print(f"-- expected={string_type(expected, with_shape=True, limit=20)}")
608
- print(f"-- got={string_type(got, with_shape=True, limit=20)}")
609
- elif verbose >= 9:
610
- print("[run_exporter] inputs and outputs")
611
- print(
612
- f"-- inputs="
613
- f"{string_type(inputs[0], with_shape=True, with_min_max=True, limit=20)}"
614
- )
615
- print(
616
- f"-- expected="
617
- f"{string_type(expected, with_shape=True, with_min_max=True, limit=20)}"
618
- )
619
- print(f"-- got={string_type(got, with_shape=True, with_min_max=True, limit=20)}")
620
- del disc["n"]
621
- del disc["sum"]
622
- disc.update(
623
- dict(
624
- success=1 if disc["abs"] < 0.1 else 0,
625
- model_cls=model.__class__,
626
- exported=mod, # type: ignore[dict-item]
627
- onnx=onx, # type: ignore[dict-item]
628
- )
629
- )
630
- if disc["abs"] >= 0.1:
631
- disc["error"] = "diff.0"
632
- disc["error_step"] = "diff.0"
633
- if verbose >= 9:
634
- max_diff(expected, got, verbose=verbose)
635
- else:
636
- disc["success"] = 1
640
+ if valid is not None:
641
+ for valid_inputs in valid:
642
+ expected, got, _disc = _compares_on_one_example(
643
+ model, valid_inputs, mod, verbose, quiet
644
+ )
645
+ if "abs" not in disc and (np.isnan(disc["abs"]) or disc["abs"] > 1e-3):
646
+ _disc["issue-abs"] = disc["abs"]
647
+ _disc["issue-rel"] = disc["rel"]
648
+ _disc["issue-inputs"] = string_type(
649
+ valid_inputs, with_shape=True, with_min_max=True
650
+ )
651
+ _disc["issue-expected"] = string_type(
652
+ expected, with_shape=True, with_min_max=True
653
+ )
654
+ _disc["issue-obtained"] = string_type(got, with_shape=True, with_min_max=True)
655
+ if not quiet:
656
+ raise RuntimeError(
657
+ f"validation failed,"
658
+ f"\n-- inputs=\n{string_type(_disc['issue-inputs'])} "
659
+ f"\n-- exporter={exporter!r}\n-- dynamic_shapes={dynamic_shapes}, "
660
+ f"\n-- expected={_disc['issue-expected']}"
661
+ f"\n-- obtained={_disc['issue-obtained']}"
662
+ )
663
+ break
637
664
 
638
665
  if dynamic and onx is not None:
639
666
  ds = []
@@ -861,7 +861,7 @@ class CreateFromShapeThroughFunction(torch.nn.Module):
861
861
  y = torch.ones((x.shape[0], dy1))
862
862
  return y
863
863
 
864
- _inputs = [(torch.rand((4, 4)),), (torch.rand((5, 5)),)]
864
+ _inputs = [(torch.rand((4, 4)),)]
865
865
  _dynamic = {"x": {0: DIM("dx"), 1: DIM("dy")}}
866
866
 
867
867
 
@@ -881,3 +881,21 @@ class VmapPython(torch.nn.Module):
881
881
 
882
882
  _inputs = [(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([0.1, 0.2, 0.3]))]
883
883
  _dynamic = {"x": {0: DYN}, "y": {0: DYN}}
884
+
885
+
886
+ class ExportWithDimension0(torch.nn.Module):
887
+ def forward(self, x):
888
+ return x @ torch.arange(x.shape[1], dtype=torch.float32).reshape((-1, 1))
889
+
890
+ _inputs = [(torch.empty((0, 3), dtype=torch.float32),)]
891
+ _dynamic = {"x": {0: DYN, 1: DYN}}
892
+ _valid = [(torch.rand((2, 3), dtype=torch.float32),)]
893
+
894
+
895
+ class ExportWithDimension1(torch.nn.Module):
896
+ def forward(self, x):
897
+ return x @ torch.arange(x.shape[1], dtype=torch.float32).reshape((-1, 1))
898
+
899
+ _inputs = [(torch.zeros((1, 3), dtype=torch.float32),)]
900
+ _dynamic = {"x": {0: DYN, 1: DYN}}
901
+ _valid = [(torch.rand((2, 3), dtype=torch.float32),)]
@@ -83,7 +83,7 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
83
83
  continue
84
84
 
85
85
  original = cls._PATCHED_CLASS_
86
- methods = cls._PATCHES_
86
+ methods = [_ for _ in cls._PATCHES_ if _ is not None]
87
87
  if verbose:
88
88
  print(f"[patch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}")
89
89
 
@@ -254,22 +254,36 @@ def torch_export_patches(
254
254
  may appear ``AssertionError: Mutating module attribute _seen_tokens during export.``.
255
255
  It can be avoided by setting ``strict=False`` when call :func:`torch.export.export`.
256
256
  """
257
+ if verbose:
258
+ print(f"[torch_export_patches] patch_sympy={patch_sympy!r}")
259
+ print(f" . patch_torch={patch_torch!r}")
260
+ print(f" . patch_transformers={patch_transformers!r}")
261
+ print(f" . patch_diffusers={patch_diffusers!r}")
262
+ print(f" . catch_constraints={catch_constraints!r}")
263
+ print(f" . stop_if_static={stop_if_static!r}")
264
+ print(f" . patch={patch!r}")
265
+ print(f" . custom_patches={custom_patches!r}")
266
+ print(f"[torch_export_patches] dump_rewriting={dump_rewriting!r}")
267
+
257
268
  if rewrite:
258
269
  from .patch_module import torch_export_rewrite
259
270
 
260
- with torch_export_rewrite(
261
- rewrite=rewrite, dump_rewriting=dump_rewriting, verbose=verbose
262
- ), torch_export_patches( # type: ignore[var-annotated]
263
- patch_sympy=patch_sympy,
264
- patch_torch=patch_torch,
265
- patch_transformers=patch_transformers,
266
- patch_diffusers=patch_diffusers,
267
- catch_constraints=catch_constraints,
268
- stop_if_static=stop_if_static,
269
- verbose=verbose,
270
- patch=patch,
271
- custom_patches=custom_patches,
272
- ) as f:
271
+ with (
272
+ torch_export_rewrite(
273
+ rewrite=rewrite, dump_rewriting=dump_rewriting, verbose=verbose
274
+ ),
275
+ torch_export_patches( # type: ignore[var-annotated]
276
+ patch_sympy=patch_sympy,
277
+ patch_torch=patch_torch,
278
+ patch_transformers=patch_transformers,
279
+ patch_diffusers=patch_diffusers,
280
+ catch_constraints=catch_constraints,
281
+ stop_if_static=stop_if_static,
282
+ verbose=verbose,
283
+ patch=patch,
284
+ custom_patches=custom_patches,
285
+ ) as f,
286
+ ):
273
287
  try:
274
288
  yield f
275
289
  finally:
@@ -330,6 +344,7 @@ def torch_export_patches(
330
344
  patched_infer_size,
331
345
  patched_vmap,
332
346
  patched__broadcast_shapes,
347
+ patched__constrain_user_specified_dimhint_range,
333
348
  _catch_produce_guards_and_solve_constraints,
334
349
  patch__check_input_constraints_for_graph,
335
350
  )
@@ -360,6 +375,14 @@ def torch_export_patches(
360
375
  torch._refs._broadcast_shapes = patched__broadcast_shapes
361
376
  torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes
362
377
 
378
+ # torch._export.non_strict_utils._constrain_user_specified_dimhint_range
379
+ f___constrain_user_specified_dimhint_range = (
380
+ torch._export.non_strict_utils._constrain_user_specified_dimhint_range
381
+ )
382
+ torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
383
+ patched__constrain_user_specified_dimhint_range
384
+ )
385
+
363
386
  # torch._export.non_strict_utils.produce_guards_and_solve_constraints
364
387
  if patch_torch and catch_constraints:
365
388
  if verbose:
@@ -558,6 +581,9 @@ def torch_export_patches(
558
581
  torch._subclasses.fake_impls.infer_size = f_infer_size
559
582
  torch._refs._broadcast_shapes = f__broadcast_shapes
560
583
  torch._meta_registrations._broadcast_shapes = f__broadcast_shapes
584
+ torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
585
+ f___constrain_user_specified_dimhint_range
586
+ )
561
587
 
562
588
  if verbose:
563
589
  print("[torch_export_patches] restored pytorch functions")
@@ -189,19 +189,23 @@ def convert_dynamic_axes_into_dynamic_shapes(
189
189
  return (), updated_kwargs, dynamic_shapes
190
190
 
191
191
 
192
- def use_dyn_not_str(dynamic_shapes: Any) -> Any:
192
+ def use_dyn_not_str(dynamic_shapes: Any, default_value=None) -> Any:
193
193
  """
194
194
  Some functions returns dynamic shapes as string.
195
195
  This functions replaces them with ``torch.export.Dim.DYNAMIC``.
196
+ ``default_value=torch.export.Dim.AUTO`` changes the default value.
196
197
  """
197
198
  if isinstance(dynamic_shapes, list):
198
- return [use_dyn_not_str(a) for a in dynamic_shapes]
199
+ return [use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes]
199
200
  if isinstance(dynamic_shapes, tuple):
200
- return tuple(use_dyn_not_str(a) for a in dynamic_shapes)
201
+ return tuple(use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes)
201
202
  if isinstance(dynamic_shapes, dict):
202
- return {k: use_dyn_not_str(v) for k, v in dynamic_shapes.items()}
203
+ return {
204
+ k: use_dyn_not_str(v, default_value=default_value)
205
+ for k, v in dynamic_shapes.items()
206
+ }
203
207
  if isinstance(dynamic_shapes, set):
204
- return {use_dyn_not_str(a) for a in dynamic_shapes}
208
+ return {use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes}
205
209
  if isinstance(dynamic_shapes, str):
206
- return torch.export.Dim.DYNAMIC
210
+ return torch.export.Dim.DYNAMIC if default_value is None else default_value
207
211
  return dynamic_shapes