onnx-diagnostic 0.7.5__py3-none-any.whl → 0.7.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +56 -3
  3. onnx_diagnostic/export/dynamic_shapes.py +24 -10
  4. onnx_diagnostic/export/shape_helper.py +6 -2
  5. onnx_diagnostic/ext_test_case.py +2 -0
  6. onnx_diagnostic/helpers/_log_helper.py +6 -6
  7. onnx_diagnostic/helpers/cache_helper.py +326 -18
  8. onnx_diagnostic/helpers/config_helper.py +10 -0
  9. onnx_diagnostic/helpers/helper.py +152 -11
  10. onnx_diagnostic/helpers/mini_onnx_builder.py +7 -2
  11. onnx_diagnostic/helpers/onnx_helper.py +13 -7
  12. onnx_diagnostic/helpers/torch_helper.py +33 -11
  13. onnx_diagnostic/reference/ops/op_cast_like.py +15 -11
  14. onnx_diagnostic/reference/torch_ops/__init__.py +1 -0
  15. onnx_diagnostic/reference/torch_ops/unary_ops.py +7 -0
  16. onnx_diagnostic/tasks/__init__.py +2 -0
  17. onnx_diagnostic/tasks/automatic_speech_recognition.py +6 -2
  18. onnx_diagnostic/tasks/feature_extraction.py +7 -3
  19. onnx_diagnostic/tasks/fill_mask.py +6 -2
  20. onnx_diagnostic/tasks/image_classification.py +6 -2
  21. onnx_diagnostic/tasks/image_text_to_text.py +289 -62
  22. onnx_diagnostic/tasks/mask_generation.py +143 -0
  23. onnx_diagnostic/tasks/mixture_of_expert.py +2 -2
  24. onnx_diagnostic/tasks/object_detection.py +6 -2
  25. onnx_diagnostic/tasks/sentence_similarity.py +6 -2
  26. onnx_diagnostic/tasks/summarization.py +7 -2
  27. onnx_diagnostic/tasks/text2text_generation.py +7 -2
  28. onnx_diagnostic/tasks/text_classification.py +6 -2
  29. onnx_diagnostic/tasks/text_generation.py +14 -16
  30. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +3 -3
  31. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +17 -1
  32. onnx_diagnostic/torch_export_patches/patch_inputs.py +5 -2
  33. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -4
  34. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +428 -129
  35. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +60 -41
  36. onnx_diagnostic/torch_models/hghub/hub_data.py +5 -0
  37. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +288 -0
  38. onnx_diagnostic/torch_models/validate.py +1 -0
  39. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/METADATA +2 -2
  40. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/RECORD +43 -42
  41. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/WHEEL +0 -0
  42. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/licenses/LICENSE.txt +0 -0
  43. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/top_level.txt +0 -0
onnx_diagnostic/tasks/text2text_generation.py
@@ -1,7 +1,12 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
 from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
-from ..helpers.config_helper import update_config, check_hasattr, _pick
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    _pick,
+    default_num_hidden_layers as nhl,
+)
 
 __TASK__ = "text2text-generation"
 
@@ -12,7 +17,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     if hasattr(config, "num_decoder_layers"):
        config.num_decoder_layers = min(config.num_decoder_layers, 2)
     if hasattr(config, "num_hidden_layers"):
-        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
+        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
     update_config(config, kwargs)
     return kwargs
 
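Across the task modules in this release, the hard-coded layer cap of 2 is replaced by a call to default_num_hidden_layers (imported as nhl). The helper itself is not part of this diff; the sketch below is purely illustrative, assuming it keeps 2 as the fallback and allows an environment override (neither assumption is confirmed by the diff):

    import os

    def default_num_hidden_layers() -> int:
        # Hypothetical sketch: the real helper lives in
        # onnx_diagnostic/helpers/config_helper.py and is not shown here.
        return int(os.environ.get("NUM_HIDDEN_LAYERS", "2"))
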
onnx_diagnostic/tasks/text_classification.py
@@ -1,6 +1,10 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.config_helper import update_config, check_hasattr
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
 
 __TASK__ = "text-classification"
 
@@ -9,7 +13,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     check_hasattr(config, "num_attention_heads", "num_hidden_layers")
     kwargs = dict(
-        num_hidden_layers=min(config.num_hidden_layers, 2),
+        num_hidden_layers=min(config.num_hidden_layers, nhl()),
         num_attention_heads=min(config.num_attention_heads, 4),
     )
     update_config(config, kwargs)
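
For context, a hedged usage sketch of reduce_model_config as it appears above, assuming a transformers-style config object (the exact reduced values depend on nhl(), which is not shown in this diff):

    from transformers import AutoConfig
    from onnx_diagnostic.tasks.text_classification import reduce_model_config

    config = AutoConfig.from_pretrained("bert-base-uncased")  # 12 layers, 12 heads
    kwargs = reduce_model_config(config)  # mutates config via update_config
    # kwargs holds the applied reductions, e.g.
    # {"num_hidden_layers": 2, "num_attention_heads": 4}
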
onnx_diagnostic/tasks/text_generation.py
@@ -1,13 +1,17 @@
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 import torch
-import transformers
 from ..helpers.cache_helper import (
     make_dynamic_cache,
     make_mamba_cache,
     make_sliding_window_cache,
     make_static_cache,
 )
-from ..helpers.config_helper import update_config, check_hasattr, _pick
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    _pick,
+    default_num_hidden_layers as nhl,
+)
 
 __TASK__ = "text-generation"
 
@@ -26,7 +30,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     if config.__class__.__name__ == "FalconMambaConfig":
         check_hasattr(config, "conv_kernel", "state_size", "intermediate_size")  # 4 and 8
         kwargs = dict(
-            num_hidden_layers=min(config.num_hidden_layers, 2),
+            num_hidden_layers=min(config.num_hidden_layers, nhl()),
             intermediate_size=256 if config is None else min(512, config.intermediate_size),
             hidden_size=512 if config is None else min(512, config.hidden_size),
             cls_cache="MambaCache",
@@ -38,24 +42,13 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
             head_dim=getattr(
                 config, "head_dim", config.hidden_size // config.num_attention_heads
             ),
-            num_hidden_layers=min(config.num_hidden_layers, 2),
+            num_hidden_layers=min(config.num_hidden_layers, nhl()),
             num_key_value_heads=(
                 config.num_key_value_heads
                 if hasattr(config, "num_key_value_heads")
                 else config.num_attention_heads
             ),
-            hidden_size=(
-                min(config.hidden_size, 4096 // 4)
-                if config.hidden_size % 64 == 0
-                else config.hidden_size
-            ),
         )
-        if config is None or hasattr(config, "intermediate_size"):
-            kwargs["intermediate_size"] = (
-                min(config.intermediate_size, 24576 // 4)
-                if config.intermediate_size % 4 == 0
-                else config.intermediate_size
-            )
         update_config(config, kwargs)
         return kwargs
 
@@ -95,9 +88,14 @@ def get_inputs(
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
 
     if config is not None and config.__class__.__name__ == "FalconMambaConfig":
+        try:
+            from transformers.models.mamba.modeling_mamba import MambaCache
+        except ImportError:
+            from transformers.cache_utils import MambaCache
+
         assert cls_cache in (
             "MambaCache",
-            transformers.cache_utils.MambaCache,
+            MambaCache,
         ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
         seq_length_multiple = 8
         sequence_length = (
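
The try/except import reflects MambaCache moving between modules across transformers releases: newer versions expose it from the Mamba model module, older ones from transformers.cache_utils. If more classes move, the same fallback can be factored into a helper; a hedged sketch (the helper name import_cache_class is invented here, not part of the package):

    import importlib

    def import_cache_class(name: str, *module_paths: str):
        # Try each candidate module in order; mirrors the fallback in the diff.
        for path in module_paths:
            try:
                return getattr(importlib.import_module(path), name)
            except (ImportError, AttributeError):
                continue
        raise ImportError(f"cannot import {name} from {module_paths}")

    MambaCache = import_cache_class(
        "MambaCache",
        "transformers.models.mamba.modeling_mamba",
        "transformers.cache_utils",
    )
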
onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -361,7 +361,7 @@ def torch_export_patches(
         torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes
 
     # torch._export.non_strict_utils.produce_guards_and_solve_constraints
-    if catch_constraints:
+    if patch_torch and catch_constraints:
         if verbose:
             print("[torch_export_patches] modifies shape constraints")
         f_produce_guards_and_solve_constraints = (
@@ -513,7 +513,7 @@ def torch_export_patches(
         if verbose:
             print("[torch_export_patches] restored pytorch functions")
 
-    if stop_if_static:
+    if patch_torch and stop_if_static:
         if verbose:
             print("[torch_export_patches] restored ShapeEnv._set_replacement")
 
@@ -529,7 +529,7 @@ def torch_export_patches(
             print("[torch_export_patches] restored ShapeEnv._check_frozen")
         ShapeEnv._check_frozen = f_shape_env__check_frozen
 
-    if catch_constraints:
+    if patch_torch and catch_constraints:
         # to catch or skip dynamic_shapes issues
         torch._export.non_strict_utils.produce_guards_and_solve_constraints = (
            f_produce_guards_and_solve_constraints
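
All three one-line changes gate the constraint and ShapeEnv patches, and their restores, behind patch_torch, so nothing torch-level is touched or put back when torch patching is disabled. The fix enforces apply/restore symmetry; a minimal sketch of that pattern with illustrative names (not the library's API):

    import contextlib

    @contextlib.contextmanager
    def maybe_patch(module, name, replacement, enabled=True):
        # Apply and restore under the same condition, so the restore path
        # never references a function that was never saved.
        if not enabled:
            yield
            return
        original = getattr(module, name)
        setattr(module, name, replacement)
        try:
            yield
        finally:
            setattr(module, name, original)
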
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py
@@ -6,12 +6,17 @@ import torch
 import transformers
 from transformers.cache_utils import (
     DynamicCache,
-    MambaCache,
     EncoderDecoderCache,
+    HybridCache,
     SlidingWindowCache,
     StaticCache,
 )
 
+try:
+    from transformers.models.mamba.modeling_mamba import MambaCache
+except ImportError:
+    from transformers.cache_utils import MambaCache
+
 from ..helpers import string_type
 from .serialization import _lower_name_with_
@@ -161,6 +166,9 @@ def serialization_functions(
         flatten_dynamic_cache,
         unflatten_dynamic_cache,
         flatten_with_keys_dynamic_cache,
+        flatten_hybrid_cache,
+        unflatten_hybrid_cache,
+        flatten_with_keys_hybrid_cache,
         flatten_mamba_cache,
         unflatten_mamba_cache,
         flatten_with_keys_mamba_cache,
@@ -187,6 +195,14 @@ def serialization_functions(
             # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
             verbose=verbose,
         ),
+        HybridCache: lambda verbose=verbose: register_class_serialization(
+            HybridCache,
+            flatten_hybrid_cache,
+            unflatten_hybrid_cache,
+            flatten_with_keys_hybrid_cache,
+            # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
+            verbose=verbose,
+        ),
         MambaCache: lambda verbose=verbose: register_class_serialization(
             MambaCache,
             flatten_mamba_cache,
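
HybridCache is wired into the same registration machinery as the other caches. register_class_serialization itself is not shown in this diff; on recent torch versions (flatten_with_keys_fn needs roughly torch 2.3+) it presumably reduces to a torch.utils._pytree registration along these lines, a sketch with the flatten/unflatten names taken from the diff:

    import torch.utils._pytree as pytree

    def register_cache_serialization(cls, flatten_fn, unflatten_fn, flatten_with_keys_fn):
        # Registers cls so torch.export can flatten/unflatten it like a pytree.
        pytree.register_pytree_node(
            cls,
            flatten_fn,
            unflatten_fn,
            serialized_type_name=f"{cls.__module__}.{cls.__name__}",
            flatten_with_keys_fn=flatten_with_keys_fn,
        )
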
onnx_diagnostic/torch_export_patches/patch_inputs.py
@@ -34,7 +34,7 @@ def _make_shape(subset: Dict, cls: type, value: Any) -> Any:
         f"Inconsistencies in subset={subset}, found={values}, "
         f"it cannot be a {cls}, value={string_type(value)}"
     )
-    cache_length = len(value.key_cache)
+    cache_length = len(value.layers if hasattr(value, "layers") else value.key_cache)
     for v in subset.values():
         axes = v
         break
@@ -70,6 +70,8 @@ def convert_dynamic_axes_into_dynamic_shapes(
     :param verbose: verbosity
     :return: (args, kwargs, dynamic shapes)
     """
+    from ..helpers.cache_helper import CacheKeyValue
+
     new_kwargs = {}
     if args:
         assert hasattr(model, "forward"), f"Missing method 'forward' for {model!r}"
@@ -121,7 +123,8 @@ def convert_dynamic_axes_into_dynamic_shapes(
                 changes[k] = type(updated_kwargs[k])
                 continue
             if isinstance(v, transformers.cache_utils.DynamicCache):
-                updated_kwargs[k] = [v.key_cache, v.value_cache]
+                ca = CacheKeyValue(v)
+                updated_kwargs[k] = [ca.key_cache, ca.value_cache]
                 changes[k] = type(v)
                 continue
             raise NotImplementedError(
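
Both changes in this file absorb the upstream DynamicCache refactor in which the parallel key_cache/value_cache lists were replaced by a list of per-layer objects. CacheKeyValue (implemented in onnx_diagnostic/helpers/cache_helper.py, not shown in this diff) is the adapter that hides the difference; a sketch of the role it plays, where the layer.keys/layer.values attribute names assume the new layer layout:

    class CacheKeyValueSketch:
        # Illustrative stand-in for onnx_diagnostic's CacheKeyValue.
        def __init__(self, cache):
            if hasattr(cache, "layers"):  # new-style DynamicCache
                self.key_cache = [layer.keys for layer in cache.layers]
                self.value_cache = [layer.values for layer in cache.layers]
            else:  # old-style DynamicCache
                self.key_cache = cache.key_cache
                self.value_cache = cache.value_cache
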
onnx_diagnostic/torch_export_patches/patches/patch_torch.py
@@ -27,8 +27,8 @@ def _catch_produce_guards_and_solve_constraints(
     dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
     equalities_inputs: "EqualityConstraint",  # noqa: F821
     original_signature: inspect.Signature,
-    _is_torch_jit_trace: bool = False,
     verbose: int = 0,
+    **kwargs,
 ):
     try:
         return previous_function(
@@ -37,7 +37,7 @@ def _catch_produce_guards_and_solve_constraints(
             dynamic_shapes=dynamic_shapes,
             equalities_inputs=equalities_inputs,
             original_signature=original_signature,
-            _is_torch_jit_trace=_is_torch_jit_trace,
+            **kwargs,
         )
     except Exception as e:
         if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")):
@@ -51,7 +51,7 @@ def _catch_produce_guards_and_solve_constraints(
                 f"dynamic_shapes={dynamic_shapes}\n"
                 f"equalities_inputs={equalities_inputs}\n"
                 f"original_signature={original_signature}\n"
-                f"_is_torch_jit_trace={_is_torch_jit_trace}\n"
+                f"kwargs={kwargs}\n"
                 f"exc={e}\ngm={gm}"
             )
         torch._dynamo.reset()
@@ -174,7 +174,7 @@ class patched_ShapeEnv:
             self.counter["ignored_backward_guard"] += 1
             raise AssertionError(
                 f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
-                f"this could result in accuracy problems."
+                f"this could result in accuracy problems"
             )
 
     def _set_replacement(
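
Replacing the explicit _is_torch_jit_trace parameter with **kwargs makes the wrapper tolerant of upstream signature churn: whatever keyword arguments torch passes to produce_guards_and_solve_constraints are forwarded untouched, so the patch no longer breaks when torch adds, renames, or drops a keyword. A generic sketch of the pattern:

    import functools

    def make_tolerant_wrapper(previous_function):
        @functools.wraps(previous_function)
        def wrapper(*args, **kwargs):
            # Forward unknown keywords unchanged; the wrapper does not need
            # updating each time the wrapped signature evolves upstream.
            return previous_function(*args, **kwargs)
        return wrapper
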