onnx-diagnostic 0.8.0__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +78 -22
  3. onnx_diagnostic/export/api.py +35 -5
  4. onnx_diagnostic/export/control_flow.py +511 -0
  5. onnx_diagnostic/export/control_flow_research.py +135 -0
  6. onnx_diagnostic/ext_test_case.py +33 -9
  7. onnx_diagnostic/helpers/cache_helper.py +217 -203
  8. onnx_diagnostic/helpers/helper.py +6 -2
  9. onnx_diagnostic/helpers/log_helper.py +39 -5
  10. onnx_diagnostic/helpers/memory_peak.py +2 -0
  11. onnx_diagnostic/helpers/mini_onnx_builder.py +55 -3
  12. onnx_diagnostic/helpers/onnx_helper.py +13 -16
  13. onnx_diagnostic/helpers/rt_helper.py +579 -15
  14. onnx_diagnostic/helpers/torch_helper.py +5 -0
  15. onnx_diagnostic/tasks/image_text_to_text.py +5 -1
  16. onnx_diagnostic/tasks/text2text_generation.py +1 -0
  17. onnx_diagnostic/tasks/text_generation.py +84 -54
  18. onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
  19. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
  20. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
  21. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -1
  22. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +563 -61
  23. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
  24. onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
  25. onnx_diagnostic/torch_models/validate.py +620 -213
  26. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/METADATA +1 -1
  27. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/RECORD +30 -28
  28. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/WHEEL +0 -0
  29. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/licenses/LICENSE.txt +0 -0
  30. {onnx_diagnostic-0.8.0.dist-info → onnx_diagnostic-0.8.2.dist-info}/top_level.txt +0 -0

onnx_diagnostic/tasks/text_generation.py
@@ -56,6 +56,74 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     return kwargs
 
 
+def _get_input_falcon_mamba(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_hidden_layers: int,
+    batch_size: int = 2,
+    sequence_length: int = 30,
+    sequence_length2: int = 3,
+    dynamic_rope: bool = False,
+    num_key_value_heads: Optional[int] = None,
+    head_dim: Optional[int] = None,
+    cls_cache: Optional[Union[type, str]] = None,
+    **kwargs,  # unused
+):
+    try:
+        from transformers.models.mamba.modeling_mamba import MambaCache
+    except ImportError:
+        from transformers.cache_utils import MambaCache
+
+    assert cls_cache in (
+        "MambaCache",
+        MambaCache,
+    ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
+
+    batch = "batch"
+    seq_length_multiple = 8
+    sequence_length = (
+        (sequence_length + seq_length_multiple) // seq_length_multiple * seq_length_multiple
+    )
+    # sequence_inc = seq_length_multiple
+    sequence_length2 = seq_length_multiple
+
+    shapes = {
+        "input_ids": {0: batch, 1: "sequence_length"},
+        "attention_mask": {
+            0: batch,
+            1: "cache+seq",  # cache_length + seq_length
+        },
+        "cache_position": {
+            0: batch,
+            1: "cache+seq",  # cache_length + seq_length
+        },
+        "cache_params": [{0: batch} for _ in range(num_hidden_layers * 2)],
+    }
+    inputs = dict(
+        input_ids=torch.randint(
+            0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2)
+        ).to(torch.int64),
+        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+            torch.int64
+        ),
+        cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
+        # .expand((batch_size, -1))
+        cache_params=make_mamba_cache(
+            [
+                (
+                    torch.randn(
+                        batch_size, kwargs["intermediate_size"], kwargs["conv_kernel"]
+                    ),
+                    torch.randn(batch_size, kwargs["intermediate_size"], kwargs["state_size"]),
+                )
+                for i in range(num_hidden_layers)
+            ]
+        ),
+    )
+    return dict(inputs=inputs, dynamic_shapes=shapes)
+
+
 def get_inputs(
     model: torch.nn.Module,
     config: Optional[Any],
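
The dictionary returned by _get_input_falcon_mamba (and by get_inputs below) pairs dummy tensors under "inputs" with a matching "dynamic_shapes" specification. A minimal sketch of how such a result could be fed to torch.export.export; the wrapper name is illustrative and not part of the package, and string dimension names such as "batch" assume a recent torch version:

    import torch

    def export_with_dummy_inputs(model, res):
        # res["inputs"] holds keyword tensors; res["dynamic_shapes"] maps each
        # input name to its dynamic dimensions, e.g. {0: "batch"}.
        return torch.export.export(
            model,
            args=(),
            kwargs=res["inputs"],
            dynamic_shapes=res["dynamic_shapes"],
        )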
@@ -68,7 +136,7 @@ def get_inputs(
     num_key_value_heads: Optional[int] = None,
     head_dim: Optional[int] = None,
     cls_cache: Optional[Union[type, str]] = None,
-    add_second_input: int = 1,
+    add_second_input: Optional[int] = None,
     **kwargs,  # unused
 ):
     """
@@ -84,6 +152,7 @@ def get_inputs(
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :param cls_cache: cache class, by default it is
         :class:`transformers.cache_utils.DynamicCache`
+    :param add_second_input: adds other kinds of inputs
     :return: dictionary
     """
     batch = "batch"
@@ -91,60 +160,20 @@ def get_inputs(
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
 
     if config is not None and config.__class__.__name__ == "FalconMambaConfig":
-        try:
-            from transformers.models.mamba.modeling_mamba import MambaCache
-        except ImportError:
-            from transformers.cache_utils import MambaCache
-
-        assert cls_cache in (
-            "MambaCache",
-            MambaCache,
-        ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
-        seq_length_multiple = 8
-        sequence_length = (
-            (sequence_length + seq_length_multiple)
-            // seq_length_multiple
-            * seq_length_multiple
-        )
-        # sequence_inc = seq_length_multiple
-        sequence_length2 = seq_length_multiple
-
-        shapes = {
-            "input_ids": {0: batch, 1: "sequence_length"},
-            "attention_mask": {
-                0: batch,
-                1: "cache+seq",  # cache_length + seq_length
-            },
-            "cache_position": {
-                0: batch,
-                1: "cache+seq",  # cache_length + seq_length
-            },
-            "cache_params": [{0: batch} for _ in range(num_hidden_layers * 2)],
-        }
-        inputs = dict(
-            input_ids=torch.randint(
-                0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2)
-            ).to(torch.int64),
-            attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
-                torch.int64
-            ),
-            cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
-            # .expand((batch_size, -1))
-            cache_params=make_mamba_cache(
-                [
-                    (
-                        torch.randn(
-                            batch_size, kwargs["intermediate_size"], kwargs["conv_kernel"]
-                        ),
-                        torch.randn(
-                            batch_size, kwargs["intermediate_size"], kwargs["state_size"]
-                        ),
-                    )
-                    for i in range(num_hidden_layers)
-                ]
-            ),
+        res = _get_input_falcon_mamba(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_hidden_layers=num_hidden_layers,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            sequence_length2=sequence_length2,
+            dynamic_rope=dynamic_rope,
+            num_key_value_heads=num_key_value_heads,
+            head_dim=head_dim,
+            cls_cache=cls_cache,
+            **kwargs,  # unused
         )
-        res = dict(inputs=inputs, dynamic_shapes=shapes)
     else:
         if head_dim is None:
             assert config, "head_dim is None, the value cannot be set without a configuration"
@@ -244,6 +273,7 @@ def get_inputs(
         )
         res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        res["inputs_prompt"] = dict(input_ids=torch.randint(1000, 30000, (1, 11)))
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
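
Besides inputs2, a second set with different shapes, 0.8.2 also returns inputs_prompt, a fixed 1x11 prompt with no cache, presumably to exercise the prefill path during validation. A hedged sketch of consuming every returned input set (the loop is illustrative, not the package's validation code):

    for key in ("inputs", "inputs2", "inputs_prompt"):
        if key in res:
            # each set exercises a different shape profile of the same model
            outputs = model(**res[key])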

onnx_diagnostic/torch_export_patches/eval/model_cases.py
@@ -570,6 +570,34 @@ class ControlFlowScanDecomposition_151564(torch.nn.Module):
     _dynamic = {"images": {0: DYN, 1: DYN}, "position": {0: DYN}}
 
 
+class ControlFlowWhileDec(torch.nn.Module):
+    def forward(self, ci, a, b):
+        def cond_fn(i, x, y):
+            return i > 0
+
+        def body_fn(i, x, y):
+            return i - 1, x + y, y - x
+
+        return torch._higher_order_ops.while_loop(cond_fn, body_fn, [ci, a, b])
+
+    _inputs = [(torch.tensor(1), torch.randn(2, 3), torch.randn(2, 3))]
+    _dynamic = {}, {0: DYN, 1: DYN}, {0: DYN}
+
+
+class ControlFlowWhileInc(torch.nn.Module):
+    def forward(self, ci, a, b):
+        def cond_fn(i, x, y):
+            return i < x.size(0)
+
+        def body_fn(i, x, y):
+            return i + 1, x + y, y - x
+
+        return torch._higher_order_ops.while_loop(cond_fn, body_fn, [ci, a, b])
+
+    _inputs = [(torch.tensor(1), torch.randn(2, 3), torch.randn(2, 3))]
+    _dynamic = {}, {0: DYN, 1: DYN}, {0: DYN}
+
+
 class SignatureInt1(torch.nn.Module):
     def __init__(self, n_dims: int = 3, n_targets: int = 1):
         super().__init__()
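
Both new cases wrap torch._higher_order_ops.while_loop, which threads the carried values through body_fn for as long as cond_fn returns a true scalar. A pure-Python reference of those semantics, as a sketch rather than the exported graph:

    import torch

    def while_loop_reference(cond_fn, body_fn, carried):
        # mirrors while_loop semantics: iterate while cond_fn(*carried) is true
        while cond_fn(*carried):
            carried = body_fn(*carried)
        return tuple(carried)

    ci, a, b = torch.tensor(3), torch.randn(2, 3), torch.randn(2, 3)
    i, x, y = while_loop_reference(
        lambda i, x, y: i > 0,                  # ControlFlowWhileDec.cond_fn
        lambda i, x, y: (i - 1, x + y, y - x),  # ControlFlowWhileDec.body_fn
        (ci, a, b),
    )
    assert i.item() == 0  # the counter was decremented to zero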

onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -32,7 +32,7 @@ def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
         v = getattr(mod, k)
         if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
             to_patch.append(v)
-        else:
+        elif v.__doc__:
             # a function
             doc = v.__doc__.lstrip()
             if doc.startswith("manual patch"):
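
The else: to elif v.__doc__: change matters because the following line calls v.__doc__.lstrip(): for a function without a docstring, __doc__ is None and .lstrip() raises AttributeError. A small sketch of the docstring convention the selection relies on (both functions below are illustrative):

    def patched_function():
        "manual patch: replaces a function in the target module."

    def helper_without_docstring():
        pass  # __doc__ is None; 0.8.2 now skips it instead of crashing

    for v in (patched_function, helper_without_docstring):
        if v.__doc__ and v.__doc__.lstrip().startswith("manual patch"):
            print(f"{v.__name__} is collected as a manual patch")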

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py
@@ -4,14 +4,18 @@ import packaging.version as pv
 import optree
 import torch
 import transformers
-from transformers.cache_utils import (
-    DynamicCache,
-    EncoderDecoderCache,
-    HybridCache,
-    SlidingWindowCache,
-    StaticCache,
-)
+from transformers.cache_utils import DynamicCache, StaticCache
 
+try:
+    from transformers.cache_utils import (
+        EncoderDecoderCache,
+        HybridCache,
+        SlidingWindowCache,
+    )
+except ImportError:
+    EncoderDecoderCache = None
+    HybridCache = None
+    SlidingWindowCache = None
 from ..helpers import string_type
 from .serialization import _lower_name_with_
 
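
Splitting the grouped import keeps DynamicCache and StaticCache, which exist in every supported transformers release, unconditional, while the cache classes that are not present in every release degrade to None. Downstream code can then skip the missing ones; a hedged sketch, with a hypothetical registration helper:

    for cls in (EncoderDecoderCache, HybridCache, SlidingWindowCache):
        if cls is None:
            continue  # not available in the installed transformers version
        register_cache_serialization(cls)  # hypothetical helper, for illustration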

onnx_diagnostic/torch_export_patches/patches/patch_torch.py
@@ -195,9 +195,12 @@ class patched_ShapeEnv:
         if self.frozen:
             self.counter["ignored_backward_guard"] += 1
             # PATCHED: raised an exception instead of logging.
+            import transformers
+
             raise AssertionError(
                 f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
-                f"this could result in accuracy problems"
+                f"this could result in accuracy problems, transformers.__version__="
+                f"{transformers.__version__!r}"
             )
 
     def _set_replacement(