onnx-diagnostic 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +411 -0
  4. onnx_diagnostic/doc.py +4 -4
  5. onnx_diagnostic/export/__init__.py +1 -1
  6. onnx_diagnostic/export/dynamic_shapes.py +433 -22
  7. onnx_diagnostic/ext_test_case.py +86 -29
  8. onnx_diagnostic/helpers/__init__.py +1 -0
  9. onnx_diagnostic/helpers/bench_run.py +450 -0
  10. onnx_diagnostic/{cache_helpers.py → helpers/cache_helper.py} +41 -5
  11. onnx_diagnostic/{helpers.py → helpers/helper.py} +136 -659
  12. onnx_diagnostic/helpers/memory_peak.py +249 -0
  13. onnx_diagnostic/helpers/onnx_helper.py +921 -0
  14. onnx_diagnostic/{ort_session.py → helpers/ort_session.py} +42 -3
  15. onnx_diagnostic/{torch_test_helper.py → helpers/torch_test_helper.py} +138 -55
  16. onnx_diagnostic/reference/ops/op_cast_like.py +1 -1
  17. onnx_diagnostic/reference/ort_evaluator.py +7 -2
  18. onnx_diagnostic/torch_export_patches/__init__.py +107 -0
  19. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +137 -33
  20. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +13 -2
  21. onnx_diagnostic/torch_export_patches/patch_inputs.py +174 -0
  22. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -2
  23. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +4 -4
  24. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  25. onnx_diagnostic/torch_models/hghub/hub_api.py +234 -0
  26. onnx_diagnostic/torch_models/hghub/hub_data.py +195 -0
  27. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +3259 -0
  28. onnx_diagnostic/torch_models/hghub/model_inputs.py +727 -0
  29. onnx_diagnostic/torch_models/test_helper.py +827 -0
  30. onnx_diagnostic/torch_models/untrained/llm_phi2.py +3 -4
  31. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +3 -4
  32. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  33. onnx_diagnostic/torch_onnx/sbs.py +439 -0
  34. {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.3.0.dist-info}/METADATA +2 -2
  35. {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.3.0.dist-info}/RECORD +39 -25
  36. onnx_diagnostic/onnx_tools.py +0 -260
  37. /onnx_diagnostic/{args.py → helpers/args_helper.py} +0 -0
  38. {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.3.0.dist-info}/WHEEL +0 -0
  39. {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.3.0.dist-info}/licenses/LICENSE.txt +0 -0
  40. {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.3.0.dist-info}/top_level.txt +0 -0
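
Several helper modules move under onnx_diagnostic.helpers in this release (items 10, 11, 14, 15 and 37 above). A hedged migration sketch based only on the renames listed here; make_dynamic_cache and is_torchdynamo_exporting are the symbols visible in the hunks below, any other name would be an assumption:

    # 0.2.2 import paths (left-hand side of the renames above)
    # from onnx_diagnostic.cache_helpers import make_dynamic_cache
    # from onnx_diagnostic.torch_test_helper import is_torchdynamo_exporting

    # 0.3.0 import paths (right-hand side of the renames above)
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
    from onnx_diagnostic.helpers.torch_test_helper import is_torchdynamo_exporting
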
onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -1,6 +1,6 @@
 import contextlib
 import pprint
-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, List, Optional, Set
 from .onnx_export_serialization import (
     flatten_with_keys_dynamic_cache,
     flatten_dynamic_cache,
@@ -12,27 +12,36 @@ from .onnx_export_serialization import (
 from .patches import patch_transformers as patch_transformers_list


-def patch_module(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
+def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
     """
     Applies all patches defined in classes prefixed by ``patched_``
     ``cls._PATCHED_CLASS_`` defines the class to patch,
     ``cls._PATCHES_`` defines the method to patch.
-    The returns information needs to be sent to :func:`unpatch_module`
+    The returns information needs to be sent to :func:`unpatch_module_or_classes`
     to revert the changes.
+
+    :param mod: module of list of clsses to patch
+    :param verbose: verbosity
+    :return: patch info
     """
-    to_patch = []
-    for k in dir(mod):
-        if k.startswith("patched_"):
-            v = getattr(mod, k)
-            if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
-                to_patch.append(v)
+    if isinstance(mod, list):
+        to_patch = mod
+        name = "list"
+    else:
+        to_patch = []
+        for k in dir(mod):
+            if k.startswith("patched_"):
+                v = getattr(mod, k)
+                if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
+                    to_patch.append(v)
+        name = mod.__name__

     res = {}
     for cls in to_patch:
         original = cls._PATCHED_CLASS_
         methods = cls._PATCHES_
         if verbose:
-            print(f"[patch_module] {mod.__name__} - {cls.__name__}: {', '.join(methods)}")
+            print(f"[patch_module_or_classes] {name} - {cls.__name__}: {', '.join(methods)}")

         keep = {n: getattr(original, n, None) for n in methods}
         for n in methods:
@@ -42,20 +51,30 @@ def patch_module(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
     return res


-def unpatch_module(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0):
-    """Reverts modification made by :func:`patch_module`."""
-    to_patch = []
-    for k in dir(mod):
-        if k.startswith("patched_"):
-            v = getattr(mod, k)
-            if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
-                to_patch.append(v)
+def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0):
+    """
+    Reverts modification made by :func:`patch_module_or_classes`.
+
+    :param mod: module of list of clsses to patch
+    :param verbose: verbosity
+    """
+    if isinstance(mod, list):
+        to_patch = mod
+        name = "list"
+    else:
+        to_patch = []
+        for k in dir(mod):
+            if k.startswith("patched_"):
+                v = getattr(mod, k)
+                if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
+                    to_patch.append(v)
+        name = mod.__name__
     set_patch = set(to_patch)

     for cls, methods in info.items():
         assert cls in set_patch, f"No patch registered for {cls} in {mod} (found {set_patch})"
         if verbose:
-            print(f"[unpatch_module] {mod.__name__} - {cls.__name__}: {', '.join(methods)}")
+            print(f"[unpatch_module_or_classes] {name} - {cls.__name__}: {', '.join(methods)}")
         original = cls._PATCHED_CLASS_
         for n, v in methods.items():
             if v is None:
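
The rename also widens the accepted input: patch_module_or_classes and unpatch_module_or_classes now take either a module containing patched_* classes or a plain list of such classes. A minimal sketch of the round trip, assuming the loop elided above assigns each method listed in _PATCHES_ onto _PATCHED_CLASS_ (the target class below is hypothetical):

    from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
        patch_module_or_classes,
        unpatch_module_or_classes,
    )

    class Target:  # hypothetical class to patch
        def greet(self):
            return "original"

    class patched_Target:  # follows the _PATCHED_CLASS_/_PATCHES_ convention
        _PATCHED_CLASS_ = Target
        _PATCHES_ = ["greet"]

        def greet(self):
            return "patched"

    info = patch_module_or_classes([patched_Target], verbose=1)   # list form, reported as "list"
    print(Target().greet())                                       # expected: "patched"
    unpatch_module_or_classes([patched_Target], info, verbose=1)  # restores the original method
    print(Target().greet())                                       # expected: "original"
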
@@ -65,9 +84,14 @@ def unpatch_module(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0
             setattr(original, n, v)


+PATCH_OF_PATCHES: Set[Any] = set()
+
+
 def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
     # Cache serialization: to be moved into appropriate packages
     import torch
+    import transformers
+    import packaging.version as pv

     try:
         from transformers.cache_utils import DynamicCache
@@ -100,7 +124,40 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
         flatten_with_keys_fn=flatten_with_keys_mamba_cache,
     )

-    # DynamicCache
+    # DynamicCache serialization is different in transformers and does not
+    # play way with torch.export.export.
+    # see test test_export_dynamic_cache_cat with NOBYPASS=1
+    # :: NOBYBASS=1 python _unittests/ut_torch_export_patches/test_dynamic_class.py -k e_c
+    # This is caused by this line:
+    # torch.fx._pytree.register_pytree_flatten_spec(
+    #   DynamicCache, _flatten_dynamic_cache_for_fx)
+    # so we remove it anyway
+    if (
+        DynamicCache in torch.fx._pytree.SUPPORTED_NODES
+        and not PATCH_OF_PATCHES
+        # and pv.Version(torch.__version__) < pv.Version("2.7")
+        and pv.Version(transformers.__version__) >= pv.Version("4.50")
+    ):
+        if verbose:
+            print(
+                "[_register_cache_serialization] DynamicCache "
+                "is unregistered and registered first."
+            )
+        _unregister(DynamicCache)
+        torch.utils._pytree.register_pytree_node(
+            DynamicCache,
+            flatten_dynamic_cache,
+            unflatten_dynamic_cache,
+            serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
+            flatten_with_keys_fn=flatten_with_keys_dynamic_cache,
+        )
+        if pv.Version(torch.__version__) < pv.Version("2.7"):
+            torch.fx._pytree.register_pytree_flatten_spec(
+                DynamicCache, lambda x, _: [x.key_cache, x.value_cache]
+            )
+        # To avoid doing it multiple times.
+        PATCH_OF_PATCHES.add(DynamicCache)
+
     unregistered_dynamic_cache = True
     if DynamicCache is not None and DynamicCache in torch.utils._pytree.SUPPORTED_NODES:
         if verbose > 1:
@@ -116,12 +173,13 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
         serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
         flatten_with_keys_fn=flatten_with_keys_dynamic_cache,
     )
-    torch.fx._pytree.register_pytree_flatten_spec(
-        DynamicCache, lambda x, _: [x.key_cache, x.value_cache]
-    )
+    if pv.Version(torch.__version__) < pv.Version("2.7"):
+        torch.fx._pytree.register_pytree_flatten_spec(
+            DynamicCache, lambda x, _: [x.key_cache, x.value_cache]
+        )

     # check
-    from ..cache_helpers import make_dynamic_cache
+    from ..helpers.cache_helper import make_dynamic_cache

     cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
     values, spec = torch.utils._pytree.tree_flatten(cache)
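
The point of the (re-)registration above is that torch.utils._pytree can then traverse a DynamicCache, which is what torch.export.export relies on. A sketch of the same round-trip check the diff performs, assuming register_additional_serialization_functions (next hunk) is a context manager and is re-exported by the package __init__:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
    from onnx_diagnostic.torch_export_patches import register_additional_serialization_functions

    # assumed to behave as a context manager wrapping _register_cache_serialization
    with register_additional_serialization_functions(verbose=1):
        # one (key, value) pair -> a DynamicCache with a single layer
        cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
        values, spec = torch.utils._pytree.tree_flatten(cache)
        restored = torch.utils._pytree.tree_unflatten(values, spec)
        print(type(restored).__name__, len(values))  # DynamicCache and its flattened tensors
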
@@ -180,7 +238,7 @@ def _unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
 def register_additional_serialization_functions(
     patch_transformers: bool = False, verbose: int = 0
 ) -> Callable:
-    """The necessary modification to run the fx Graph."""
+    """The necessary modifications to run the fx Graph."""
     fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x)
     done = _register_cache_serialization(verbose=verbose)
     try:
@@ -195,9 +253,10 @@ def bypass_export_some_errors(
     patch_torch: bool = True,
     patch_transformers: bool = False,
     catch_constraints: bool = True,
-    stop_if_static: bool = False,
+    stop_if_static: int = 0,
     verbose: int = 0,
     patch: bool = True,
+    custom_patches: Optional[List[type["torch.nn.Module"]]] = None,  # noqa: F821
 ) -> Callable:
     """
     Tries to bypass some situations :func:`torch.export.export` does not support.
@@ -211,9 +270,14 @@ def bypass_export_some_errors(
         can be put to stop at that stage.
     :param stop_if_static: see example :ref:`l-plot-export-locale-issue`,
         to stop the export as soon as an issue is detected with dynamic shapes
-        and show a stack trace indicating the exact location of the issue
+        and show a stack trace indicating the exact location of the issue,
+        ``if stop_if_static > 1``, more methods are replace to catch more
+        issues
     :param patch: if False, disable all patches except the registration of
         serialization function
+    :param custom_patches: to apply custom patches,
+        every patched class must define static attributes
+        ``_PATCHES_``, ``_PATCHED_CLASS_``
     :param verbose: to show which patches is applied

     The list of available patches.
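
A hedged usage sketch of the new parameters, assuming bypass_export_some_errors is used as a context manager (the surrounding code suggests it, but that part is not in this diff) and is re-exported by the package __init__; the model and the custom patch below are hypothetical and only follow the _PATCHES_/_PATCHED_CLASS_ convention described in the docstring:

    import torch
    from onnx_diagnostic.torch_export_patches import bypass_export_some_errors

    class Model(torch.nn.Module):  # hypothetical model
        def forward(self, x):
            return x.cos() + 1

    class patched_Model:  # custom patch applied only while exporting
        _PATCHED_CLASS_ = Model
        _PATCHES_ = ["forward"]

        def forward(self, x):
            return x.cos() + 1  # replacement body traced instead of the original

    x = torch.randn(2, 3)
    with bypass_export_some_errors(
        patch_transformers=True,
        stop_if_static=2,              # >1 also swaps ShapeEnv._check_frozen (patch_torch.py below)
        custom_patches=[patched_Model],
        verbose=1,
    ):
        ep = torch.export.export(
            Model(), (x,), dynamic_shapes={"x": {0: torch.export.Dim("batch")}}
        )
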
@@ -301,6 +365,7 @@ def bypass_export_some_errors(
         f_sympy_name = getattr(sympy.core.numbers.IntegerConstant, "name", None)

         if verbose:
+            print(f"[bypass_export_some_errors] sympy.__version__={sympy.__version__!r}")
             print("[bypass_export_some_errors] patch sympy")

         sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}"
@@ -318,6 +383,8 @@ def bypass_export_some_errors(
         )

         if verbose:
+            print(f"[bypass_export_some_errors] torch.__version__={torch.__version__!r}")
+            print(f"[bypass_export_some_errors] stop_if_static={stop_if_static!r}")
             print("[bypass_export_some_errors] patch pytorch")

         # torch.jit.isinstance
@@ -359,23 +426,46 @@ def bypass_export_some_errors(
             )

     if stop_if_static:
+        from torch.fx.experimental.symbolic_shapes import ShapeEnv
+        from .patches.patch_torch import patched_ShapeEnv
+
         if verbose:
             print(
                 "[bypass_export_some_errors] assert when a dynamic dimension turns static"
             )
-
-        from torch.fx.experimental.symbolic_shapes import ShapeEnv
-        from .patches.patch_torch import patched_ShapeEnv
+            print("[bypass_export_some_errors] replaces ShapeEnv._set_replacement")

         f_shape_env__set_replacement = ShapeEnv._set_replacement
         ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement

+        if stop_if_static > 1:
+            if verbose:
+                print("[bypass_export_some_errors] replaces ShapeEnv._check_frozen")
+            f_shape_env__check_frozen = ShapeEnv._check_frozen
+            ShapeEnv._check_frozen = patched_ShapeEnv._check_frozen
+
     ####################
     # patch transformers
     ####################

     if patch_transformers:
-        revert_patches_info = patch_module(patch_transformers_list, verbose=verbose)
+        if verbose:
+            import transformers
+
+            print(
+                f"[bypass_export_some_errors] transformers.__version__="
+                f"{transformers.__version__!r}"
+            )
+        revert_patches_info = patch_module_or_classes(
+            patch_transformers_list, verbose=verbose
+        )
+
+    if custom_patches:
+        if verbose:
+            print("[bypass_export_some_errors] applies custom patches")
+        revert_custom_patches_info = patch_module_or_classes(
+            custom_patches, verbose=verbose
+        )

     ########
     # export
@@ -397,7 +487,6 @@ def bypass_export_some_errors(
         print("[bypass_export_some_errors] remove patches")

     if patch_sympy:
-
         # tracked by https://github.com/pytorch/pytorch/issues/143494
         if f_sympy_name:
             sympy.core.numbers.IntegerConstant.name = f_sympy_name
@@ -428,6 +517,10 @@ def bypass_export_some_errors(
             print("[bypass_export_some_errors] restored ShapeEnv._set_replacement")

         ShapeEnv._set_replacement = f_shape_env__set_replacement
+        if stop_if_static > 1:
+            if verbose:
+                print("[bypass_export_some_errors] restored ShapeEnv._check_frozen")
+            ShapeEnv._check_frozen = f_shape_env__check_frozen

     if catch_constraints:
         # to catch or skip dynamic_shapes issues
@@ -440,12 +533,23 @@ def bypass_export_some_errors(
         if verbose:
             print("[bypass_export_some_errors] restored shape constraints")

+    if custom_patches:
+        if verbose:
+            print("[bypass_export_some_errors] unpatch custom patches")
+        unpatch_module_or_classes(
+            custom_patches, revert_custom_patches_info, verbose=verbose
+        )
+
     ##############
     # transformers
     ##############

     if patch_transformers:
-        unpatch_module(patch_transformers_list, revert_patches_info, verbose=verbose)
+        if verbose:
+            print("[bypass_export_some_errors] unpatch transformers")
+        unpatch_module_or_classes(
+            patch_transformers_list, revert_patches_info, verbose=verbose
+        )

     ########
     # caches

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py
@@ -97,6 +97,10 @@ def flatten_dynamic_cache(
     dynamic_cache: transformers.cache_utils.DynamicCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
+    import transformers.cache_utils
+
+    if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"):
+        return transformers.cache_utils._flatten_dynamic_cache(dynamic_cache)
     flat = [
         (k, getattr(dynamic_cache, k))
         for k in ["key_cache", "value_cache"]
@@ -111,7 +115,10 @@ def flatten_with_keys_dynamic_cache(d: Dict[Any, Any]) -> Tuple[
 ]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
     import torch
+    import transformers.cache_utils

+    if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"):
+        return transformers.cache_utils._flatten_with_keys_dynamic_cache(d)
     values, context = flatten_dynamic_cache(d)
     return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context

@@ -122,9 +129,13 @@ def unflatten_dynamic_cache(
     output_type=None,
 ) -> transformers.cache_utils.DynamicCache:
     """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
-    from transformers.cache_utils import DynamicCache
+    import transformers.cache_utils
+
+    if hasattr(transformers.cache_utils, "_unflatten_dynamic_cache"):
+        assert output_type is None, f"output_type={output_type} not supported"
+        return transformers.cache_utils._unflatten_dynamic_cache(values, context)

-    cache = DynamicCache()
+    cache = transformers.cache_utils.DynamicCache()
     values = dict(zip(context, values))
     for k, v in values.items():
         setattr(cache, k, v)
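
When the installed transformers does not provide its own _flatten/_unflatten helpers, the fallback path above still maps a DynamicCache onto its key_cache and value_cache lists. A sketch of what the pair of functions produce and consume on that fallback path (with a recent transformers the calls may delegate and the context may differ):

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
    from onnx_diagnostic.torch_export_patches.onnx_export_serialization import (
        flatten_dynamic_cache,
        unflatten_dynamic_cache,
    )

    cache = make_dynamic_cache([(torch.ones(2, 4, 3, 8), torch.zeros(2, 4, 3, 8))])
    values, context = flatten_dynamic_cache(cache)
    print(context)                 # fallback path: ['key_cache', 'value_cache']
    rebuilt = unflatten_dynamic_cache(values, context)
    print(type(rebuilt).__name__)  # DynamicCache
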

onnx_diagnostic/torch_export_patches/patch_inputs.py (new file)
@@ -0,0 +1,174 @@
+import inspect
+from typing import Any, Dict, Optional, Tuple
+import torch
+import transformers
+from ..helpers import string_type
+from ..helpers.cache_helper import make_dynamic_cache
+
+
+def _process_cache(k: str, v):
+    assert k != "position_ids" or isinstance(
+        k, torch.Tensor
+    ), f"Unexpected type for parameter {k!r} {string_type(v, with_shape=True)}"
+    if (
+        isinstance(v, list)
+        and all(isinstance(i, tuple) for i in v)
+        and set(len(t) for t in v) == {2}
+    ):
+        # A dynamicCache
+        cache = make_dynamic_cache(v)
+        return cache
+    if isinstance(v, torch.Tensor):
+        return v
+    raise NotImplementedError(
+        f"Unable to process parameter {k!r} with v={string_type(v,with_shape=True)}"
+    )
+
+
+def _make_shape(subset: Dict, cls: type, value: Any) -> Any:
+    if cls is transformers.cache_utils.DynamicCache:
+        assert subset, "DynamicCache cannot be empty"
+        values = set(map(str, subset.values()))
+        assert len(values) == 1, (
+            f"Inconsistencies in subset={subset}, found={values}, "
+            f"it cannot be a {cls}, value={string_type(value)}"
+        )
+        cache_length = len(value.key_cache)
+        for v in subset.values():
+            axes = v
+            break
+        new_shape = [[axes for i in range(cache_length)], [axes for i in range(cache_length)]]
+        return new_shape
+    if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
+        raise NotImplementedError(
+            f"_make_shape not implemented for registered class={cls}, "
+            f"subset={subset}, value={string_type(value)}"
+        )
+    raise NotImplementedError(
+        f"_make_shape not implemented for cls={cls}, "
+        f"subset={subset}, value={string_type(value)}"
+    )
+
+
+def convert_dynamic_axes_into_dynamic_shapes(
+    model: torch.nn.Module,
+    args: Optional[Tuple[Any, ...]] = None,
+    kwargs: Optional[Dict[str, Any]] = None,
+    dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None,
+    prefix_mapping: Optional[Dict[str, str]] = None,
+    verbose: int = 0,
+) -> Tuple[Tuple[Any, ...], Dict[str, Any], Dict[str, Any]]:
+    """
+    Converts the input from an export to something :func:`torch.export.export` can handle.
+
+    :param model: model to convert (used to extract the signature)
+    :param args: positional arguments
+    :param kwargs: named arguments
+    :param dynamic_axes: dynamic axes
+    :param prefix_mapping: prefix mapping
+    :param verbose: verbosity
+    :return: (args, kwargs, dynamic shapes)
+    """
+    new_kwargs = {}
+    if args:
+        assert hasattr(model, "forward"), f"Missing method 'forward' for {model!r}"
+        plus = 0 if isinstance(model, torch.nn.Module) else 1
+        print(
+            f"[convert_dynamic_axes_into_dynamic_shapes] "
+            f"mapping args to kwargs for model="
+            f"{model if plus else model.__class__.__name__}"
+        )
+        pars = inspect.signature(model.forward).parameters
+        assert len(pars) >= len(
+            args
+        ), f"Length mismatch, len(args)={len(args)}, pars={list(pars)}"
+
+        for i, p in enumerate(pars):
+            if i < plus:
+                continue
+            if i - plus >= len(args):
+                break
+            if verbose:
+                print(
+                    f"[convert_dynamic_axes_into_dynamic_shapes] mapping args[{i-plus}] "
+                    f"to {p!r} ({string_type(args[i-plus])})"
+                )
+            new_kwargs[p] = args[i - plus]
+
+    if kwargs:
+        for k, v in kwargs.items():
+            assert k not in new_kwargs, f"Argument {k!r} from kwargs already present in args."
+            new_kwargs[k] = v
+
+    # process
+    updated_kwargs = {}
+    changes = {}
+    for k, v in new_kwargs.items():
+        if isinstance(v, torch.Tensor):
+            updated_kwargs[k] = v
+            continue
+        if isinstance(v, list):
+            # cache?
+            updated_kwargs[k] = _process_cache(k, v)
+            if type(updated_kwargs[k]) is not type(v):
+                # A cache was introduced.
+                if verbose:
+                    print(
+                        f"[convert_dynamic_axes_into_dynamic_shapes] parameter "
+                        f"{k!r} was changed into {type(updated_kwargs[k])}"
+                    )
+                changes[k] = type(updated_kwargs[k])
+            continue
+        raise NotImplementedError(
+            f"Unexpected type {type(v)} for parameter {k!r} "
+            f"({string_type(v, with_shape=True)})"
+        )
+
+    # process dynamic axes
+    if changes:
+        dynamic_shapes = {}
+        done = set()
+        for k, v in dynamic_axes.items():
+            if k not in changes and k in updated_kwargs and isinstance(v, dict):
+                dynamic_shapes[k] = v
+                continue
+            if "." in k:
+                # something like present.0.key
+                prefix = k.split(".")[0]
+                if prefix in done:
+                    continue
+                args_prefix = (
+                    prefix_mapping[prefix]
+                    if prefix_mapping and prefix in prefix_mapping
+                    else prefix
+                )
+                if args_prefix in updated_kwargs and args_prefix in changes:
+                    # A cache.
+                    cls = changes[args_prefix]
+                    dynamic_shapes[args_prefix] = _make_shape(
+                        {
+                            _: __
+                            for _, __ in dynamic_axes.items()
+                            if _.startswith(f"{prefix}.")
+                        },
+                        cls,
+                        updated_kwargs[args_prefix],
+                    )
+                    done.add(prefix)
+                    continue
+            if k not in updated_kwargs:
+                # dynamic axes not in the given inputs, should be raise an exception?
+                if verbose:
+                    print(
+                        f"[convert_dynamic_axes_into_dynamic_shapes] dropping axes "
+                        f"{k!r}-{v!r}, not found in {set(updated_kwargs)}"
+                    )
+                continue
+            raise NotImplementedError(
+                f"Unable to process dynamic axes {k!r}, axes={v}, "
+                f"value={string_type(updated_kwargs[k], with_shape=True)}, "
+                f"dynamic axes={dynamic_axes}, "
+                f"updated_kwargs={string_type(updated_kwargs, with_shape=True)}"
+            )
+
+    return (), updated_kwargs, dynamic_shapes
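
convert_dynamic_axes_into_dynamic_shapes bridges torch.onnx.export-style dynamic_axes to the dynamic_shapes structure expected by torch.export.export. A hedged sketch with a hypothetical model; note that the code above only builds dynamic_shapes when at least one input was converted into a cache, so the example feeds a list of (key, value) pairs that becomes a DynamicCache:

    import torch
    from onnx_diagnostic.torch_export_patches.patch_inputs import (
        convert_dynamic_axes_into_dynamic_shapes,
    )

    class TinyModel(torch.nn.Module):  # hypothetical model
        def forward(self, input_ids, past_key_values=None):
            return input_ids.to(torch.float32).sum(dim=-1, keepdim=True)

    cache_axes = {0: "batch", 2: "cache_length"}
    dynamic_axes = {  # legacy torch.onnx.export-style axes, names are illustrative
        "input_ids": {0: "batch", 1: "seq"},
        "past_key_values.0.key": cache_axes,
        "past_key_values.0.value": cache_axes,
    }
    kwargs = {
        "input_ids": torch.randint(0, 10, (2, 8)),
        # one (key, value) pair per layer; the list is converted into a DynamicCache
        "past_key_values": [(torch.rand(2, 4, 3, 7), torch.rand(2, 4, 3, 7))],
    }

    _, new_kwargs, dynamic_shapes = convert_dynamic_axes_into_dynamic_shapes(
        TinyModel(), kwargs=kwargs, dynamic_axes=dynamic_axes, verbose=1
    )
    print(dynamic_shapes)
    # roughly: {'input_ids': {0: 'batch', 1: 'seq'},
    #           'past_key_values': [[cache_axes], [cache_axes]]}
    # string axis names may still need to be mapped to torch.export.Dim,
    # depending on the torch version.
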

onnx_diagnostic/torch_export_patches/patches/patch_torch.py
@@ -131,7 +131,7 @@ def patched__broadcast_shapes(*_shapes):
         assert isinstance(shape, Sequence)

     # Computes common shape
-    common_shape: List[Union[int, torch.SymInt]] = [
+    common_shape = [  # List[Union[int, torch.SymInt]]
         1,
     ] * reduce(max, (len(shape) for shape in shapes))
     for _arg_idx, shape in enumerate(shapes):
@@ -150,6 +150,16 @@ def patched__broadcast_shapes(*_shapes):

 class patched_ShapeEnv:

+    def _check_frozen(
+        self, expr: "sympy.Basic", concrete_val: "sympy.Basic"  # noqa: F821
+    ) -> None:
+        if self.frozen:
+            self.counter["ignored_backward_guard"] += 1
+            raise AssertionError(
+                f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
+                f"this could result in accuracy problems."
+            )
+
     def _set_replacement(
         self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str  # noqa: F821
     ) -> None:
@@ -314,7 +324,7 @@ class patched_ShapeEnv:
         # )
         # self.log.debug("SPECIALIZATION", stack_info=True)
         assert msg != "range_refined_to_singleton", (
-            f"A dynamic dimension becomes static! "
+            f"patched_ShapeEnv: A dynamic dimension becomes static! "
             f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
         )
         # log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)
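
These two methods are only swapped in temporarily by bypass_export_some_errors (stop_if_static=1 replaces _set_replacement, stop_if_static>1 also replaces _check_frozen) and restored afterwards. A sketch of the swap itself, equivalent to what the context manager does; in practice prefer the context manager, which also restores the originals for you:

    from torch.fx.experimental.symbolic_shapes import ShapeEnv
    from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_ShapeEnv

    keep_set_replacement = ShapeEnv._set_replacement
    keep_check_frozen = ShapeEnv._check_frozen
    ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement
    ShapeEnv._check_frozen = patched_ShapeEnv._check_frozen
    try:
        # run torch.export.export here: a dimension silently turning static
        # now raises AssertionError with the exact location in the stack trace
        ...
    finally:
        ShapeEnv._set_replacement = keep_set_replacement
        ShapeEnv._check_frozen = keep_check_frozen
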

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -4,9 +4,9 @@ from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple
 import torch
 import transformers
-import transformers.modeling_attn_mask_utils
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.cache_utils import StaticCache, Cache, DynamicCache
-from ...torch_test_helper import is_torchdynamo_exporting
+from ...helpers.torch_test_helper import is_torchdynamo_exporting


 def _patch_make_causal_mask(
@@ -54,7 +54,7 @@ if sys.version_info[:2] <= (3, 11):
         """

         _PATCHES_ = ["_make_causal_mask"]
-        _PATCHED_CLASS_ = transformers.modeling_attn_mask_utils.AttentionMaskConverter
+        _PATCHED_CLASS_ = AttentionMaskConverter

         @staticmethod
         def _make_causal_mask(
@@ -79,7 +79,7 @@ else:
         """

         _PATCHES_ = ["_make_causal_mask"]
-        _PATCHED_CLASS_ = transformers.modeling_attn_mask_utils.AttentionMaskConverter
+        _PATCHED_CLASS_ = AttentionMaskConverter

         @staticmethod
         def _make_causal_mask(

onnx_diagnostic/torch_models/hghub/__init__.py (new file)
@@ -0,0 +1 @@
+from .model_inputs import get_untrained_model_with_inputs
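
This new package __init__ re-exports the main entry point of the hghub subpackage (items 24-28 in the file list). A hedged sketch of how it is presumably used; the model id is only an example and the dictionary keys are assumptions, not confirmed by this diff:

    from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

    data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")  # example model id
    model = data["model"]    # assumed key: untrained model built from the cached config
    inputs = data["inputs"]  # assumed key: sample inputs matching the model signature
    model(**inputs)          # smoke test before exporting
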