onnx-diagnostic 0.7.16__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +78 -22
  3. onnx_diagnostic/export/api.py +124 -0
  4. onnx_diagnostic/export/dynamic_shapes.py +2 -1
  5. onnx_diagnostic/export/shape_helper.py +47 -70
  6. onnx_diagnostic/ext_test_case.py +11 -0
  7. onnx_diagnostic/helpers/cache_helper.py +38 -7
  8. onnx_diagnostic/helpers/fake_tensor_helper.py +224 -104
  9. onnx_diagnostic/helpers/helper.py +27 -33
  10. onnx_diagnostic/helpers/log_helper.py +109 -5
  11. onnx_diagnostic/helpers/memory_peak.py +2 -0
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +1 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +132 -2
  14. onnx_diagnostic/helpers/onnx_helper.py +1 -1
  15. onnx_diagnostic/helpers/ort_session.py +4 -0
  16. onnx_diagnostic/helpers/rt_helper.py +393 -43
  17. onnx_diagnostic/helpers/torch_helper.py +20 -1
  18. onnx_diagnostic/tasks/__init__.py +7 -0
  19. onnx_diagnostic/tasks/automatic_speech_recognition.py +2 -8
  20. onnx_diagnostic/tasks/feature_extraction.py +2 -8
  21. onnx_diagnostic/tasks/image_text_to_text.py +10 -8
  22. onnx_diagnostic/tasks/summarization.py +2 -8
  23. onnx_diagnostic/tasks/text2text_generation.py +3 -8
  24. onnx_diagnostic/tasks/text_generation.py +86 -65
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +718 -438
  26. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  27. onnx_diagnostic/torch_export_patches/patch_inputs.py +1 -1
  28. onnx_diagnostic/torch_export_patches/patch_module.py +9 -36
  29. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -6
  30. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +162 -24
  31. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +140 -104
  32. onnx_diagnostic/torch_models/untrained/llm_phi2.py +1 -4
  33. onnx_diagnostic/torch_models/validate.py +626 -228
  34. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/METADATA +1 -1
  35. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/RECORD +38 -36
  36. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/WHEEL +0 -0
  37. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/licenses/LICENSE.txt +0 -0
  38. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/top_level.txt +0 -0
onnx_diagnostic/tasks/text_generation.py
@@ -56,6 +56,74 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     return kwargs


+def _get_input_falcon_mamba(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_hidden_layers: int,
+    batch_size: int = 2,
+    sequence_length: int = 30,
+    sequence_length2: int = 3,
+    dynamic_rope: bool = False,
+    num_key_value_heads: Optional[int] = None,
+    head_dim: Optional[int] = None,
+    cls_cache: Optional[Union[type, str]] = None,
+    **kwargs,  # unused
+):
+    try:
+        from transformers.models.mamba.modeling_mamba import MambaCache
+    except ImportError:
+        from transformers.cache_utils import MambaCache
+
+    assert cls_cache in (
+        "MambaCache",
+        MambaCache,
+    ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
+
+    batch = "batch"
+    seq_length_multiple = 8
+    sequence_length = (
+        (sequence_length + seq_length_multiple) // seq_length_multiple * seq_length_multiple
+    )
+    # sequence_inc = seq_length_multiple
+    sequence_length2 = seq_length_multiple
+
+    shapes = {
+        "input_ids": {0: batch, 1: "sequence_length"},
+        "attention_mask": {
+            0: batch,
+            1: "cache+seq",  # cache_length + seq_length
+        },
+        "cache_position": {
+            0: batch,
+            1: "cache+seq",  # cache_length + seq_length
+        },
+        "cache_params": [{0: batch} for _ in range(num_hidden_layers * 2)],
+    }
+    inputs = dict(
+        input_ids=torch.randint(
+            0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2)
+        ).to(torch.int64),
+        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+            torch.int64
+        ),
+        cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
+        # .expand((batch_size, -1))
+        cache_params=make_mamba_cache(
+            [
+                (
+                    torch.randn(
+                        batch_size, kwargs["intermediate_size"], kwargs["conv_kernel"]
+                    ),
+                    torch.randn(batch_size, kwargs["intermediate_size"], kwargs["state_size"]),
+                )
+                for i in range(num_hidden_layers)
+            ]
+        ),
+    )
+    return dict(inputs=inputs, dynamic_shapes=shapes)
+
+
 def get_inputs(
     model: torch.nn.Module,
     config: Optional[Any],
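The new helper keeps the same return contract as the inline code it replaces: a dictionary with an "inputs" entry (the dummy tensors) and a "dynamic_shapes" entry (one spec per input). A minimal sketch of how such a pair is typically consumed, assuming a plain torch.export.export call; the wrapper name below is illustrative, not part of the package:

    import torch

    def export_with_generated_inputs(model: torch.nn.Module, res: dict):
        # res is the dictionary returned by _get_input_falcon_mamba / get_inputs:
        # {"inputs": {...}, "dynamic_shapes": {...}}. String dimension names such as
        # "batch" or "cache+seq" may need converting to torch.export.Dim depending on
        # the torch version, and cache objects generally require the pytree
        # registrations provided by onnx_diagnostic.torch_export_patches.
        return torch.export.export(
            model,
            (),  # no positional arguments, everything is passed as keywords
            kwargs=res["inputs"],
            dynamic_shapes=res["dynamic_shapes"],
        )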
@@ -68,7 +136,7 @@ def get_inputs(
     num_key_value_heads: Optional[int] = None,
     head_dim: Optional[int] = None,
     cls_cache: Optional[Union[type, str]] = None,
-    add_second_input: int = 1,
+    add_second_input: Optional[int] = None,
     **kwargs,  # unused
 ):
     """
@@ -84,6 +152,7 @@ def get_inputs(
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :param cls_cache: cache class, by default it is
         :class:`transformers.cache_utils.DynamicCache`
+    :param add_second_input: adds other kinds of inputs
     :return: dictionary
     """
     batch = "batch"
@@ -91,63 +160,20 @@ def get_inputs(
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)

     if config is not None and config.__class__.__name__ == "FalconMambaConfig":
-        try:
-            from transformers.models.mamba.modeling_mamba import MambaCache
-        except ImportError:
-            from transformers.cache_utils import MambaCache
-
-        assert cls_cache in (
-            "MambaCache",
-            MambaCache,
-        ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
-        seq_length_multiple = 8
-        sequence_length = (
-            (sequence_length + seq_length_multiple)
-            // seq_length_multiple
-            * seq_length_multiple
-        )
-        # sequence_inc = seq_length_multiple
-        sequence_length2 = seq_length_multiple
-
-        shapes = {
-            "input_ids": {0: batch, 1: "sequence_length"},
-            "attention_mask": {
-                0: batch,
-                1: "cache+seq",  # cache_length + seq_length
-            },
-            "cache_position": {
-                0: batch,
-                1: "cache+seq",  # cache_length + seq_length
-            },
-            "cache_params": [
-                [{0: batch} for _ in range(num_hidden_layers)],
-                [{0: batch} for _ in range(num_hidden_layers)],
-            ],
-        }
-        inputs = dict(
-            input_ids=torch.randint(
-                0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2)
-            ).to(torch.int64),
-            attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
-                torch.int64
-            ),
-            cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
-            # .expand((batch_size, -1))
-            cache_params=make_mamba_cache(
-                [
-                    (
-                        torch.randn(
-                            batch_size, kwargs["intermediate_size"], kwargs["conv_kernel"]
-                        ),
-                        torch.randn(
-                            batch_size, kwargs["intermediate_size"], kwargs["state_size"]
-                        ),
-                    )
-                    for i in range(num_hidden_layers)
-                ]
-            ),
+        res = _get_input_falcon_mamba(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_hidden_layers=num_hidden_layers,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            sequence_length2=sequence_length2,
+            dynamic_rope=dynamic_rope,
+            num_key_value_heads=num_key_value_heads,
+            head_dim=head_dim,
+            cls_cache=cls_cache,
+            **kwargs,  # unused
         )
-        res = dict(inputs=inputs, dynamic_shapes=shapes)
     else:
         if head_dim is None:
             assert config, "head_dim is None, the value cannot be set without a configuration"
@@ -176,12 +202,7 @@ def get_inputs(
             "input_ids": {0: batch, 1: seq_length},
             "attention_mask": {0: batch, 2: "seq"},
             "cache_position": {0: "seq"},
-            "past_key_values": [
-                # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-                # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-                [{0: batch} for _ in range(num_hidden_layers)],
-                [{0: batch} for _ in range(num_hidden_layers)],
-            ],
+            "past_key_values": [{0: batch} for _ in range(num_hidden_layers * 2)],
         }
         inputs = dict(
             input_ids=torch.randint(
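This hunk (and the next one) replaces the nested two-list layout of the past_key_values dynamic shapes with a single flat list of num_hidden_layers * 2 entries, one spec per cached tensor. A small worked example of what the two layouts evaluate to, using an illustrative num_hidden_layers=2:

    # New flat layout: 2 layers x 2 cached tensors per layer = 4 specs.
    num_hidden_layers = 2
    batch = "batch"
    flat = [{0: batch} for _ in range(num_hidden_layers * 2)]
    # -> [{0: 'batch'}, {0: 'batch'}, {0: 'batch'}, {0: 'batch'}]

    # Former 0.7.16 layout: the same specs grouped into two nested lists.
    nested = [
        [{0: batch} for _ in range(num_hidden_layers)],
        [{0: batch} for _ in range(num_hidden_layers)],
    ]
    # -> [[{0: 'batch'}, {0: 'batch'}], [{0: 'batch'}, {0: 'batch'}]]
    # The next hunk applies the same flattening to the variant that also marks
    # dimension 2 as "cache_length".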
@@ -222,8 +243,7 @@ def get_inputs(
             },
             "position_ids": {0: batch, 1: seq_length},
             "past_key_values": [
-                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                {0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)
             ],
         }

@@ -253,6 +273,7 @@ def get_inputs(
         )
         res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        res["inputs_prompt"] = dict(input_ids=torch.randint(1000, 30000, (1, 11)))
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
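The new "inputs_prompt" entry is shown in full above: a single prompt of eleven random token ids, with no attention mask and no cache. A short sketch of what it contains; the usage note at the end is an assumption, not something the diff shows:

    import torch

    # Same expression as in the diff: batch of 1, 11 token ids drawn from [1000, 30000).
    inputs_prompt = dict(input_ids=torch.randint(1000, 30000, (1, 11)))
    print(inputs_prompt["input_ids"].shape)  # torch.Size([1, 11])
    print(inputs_prompt["input_ids"].dtype)  # torch.int64 (randint default)
    # Presumably used, next to res["inputs"] and res["inputs2"], as an extra input
    # set to exercise a prompt-only call such as model(**inputs_prompt).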