onnx-diagnostic 0.6.3__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +87 -77
  3. onnx_diagnostic/doc.py +22 -0
  4. onnx_diagnostic/ext_test_case.py +1 -1
  5. onnx_diagnostic/helpers/cache_helper.py +59 -0
  6. onnx_diagnostic/helpers/config_helper.py +8 -4
  7. onnx_diagnostic/helpers/helper.py +30 -3
  8. onnx_diagnostic/helpers/log_helper.py +585 -0
  9. onnx_diagnostic/helpers/mini_onnx_builder.py +4 -1
  10. onnx_diagnostic/helpers/model_builder_helper.py +54 -73
  11. onnx_diagnostic/helpers/torch_helper.py +18 -2
  12. onnx_diagnostic/reference/__init__.py +1 -0
  13. onnx_diagnostic/reference/ort_evaluator.py +29 -4
  14. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  15. onnx_diagnostic/reference/torch_evaluator.py +21 -0
  16. onnx_diagnostic/tasks/automatic_speech_recognition.py +3 -0
  17. onnx_diagnostic/tasks/feature_extraction.py +3 -0
  18. onnx_diagnostic/tasks/fill_mask.py +3 -0
  19. onnx_diagnostic/tasks/image_classification.py +7 -1
  20. onnx_diagnostic/tasks/image_text_to_text.py +3 -0
  21. onnx_diagnostic/tasks/mixture_of_expert.py +3 -0
  22. onnx_diagnostic/tasks/object_detection.py +3 -0
  23. onnx_diagnostic/tasks/sentence_similarity.py +3 -0
  24. onnx_diagnostic/tasks/summarization.py +3 -0
  25. onnx_diagnostic/tasks/text2text_generation.py +3 -0
  26. onnx_diagnostic/tasks/text_classification.py +3 -0
  27. onnx_diagnostic/tasks/text_generation.py +90 -43
  28. onnx_diagnostic/tasks/zero_shot_image_classification.py +3 -0
  29. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +78 -25
  30. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +37 -0
  31. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +365 -17
  32. onnx_diagnostic/torch_models/hghub/hub_api.py +20 -4
  33. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +209 -0
  34. onnx_diagnostic/torch_models/hghub/model_inputs.py +3 -0
  35. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +23 -50
  36. onnx_diagnostic/torch_models/{test_helper.py → validate.py} +158 -103
  37. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/METADATA +2 -2
  38. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/RECORD +41 -39
  39. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/WHEEL +0 -0
  40. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
  41. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/top_level.txt +0 -0
onnx_diagnostic/tasks/text_generation.py

@@ -5,6 +5,7 @@ from ..helpers.cache_helper import (
     make_dynamic_cache,
     make_mamba_cache,
     make_sliding_window_cache,
+    make_static_cache,
 )
 from ..helpers.config_helper import update_config, check_hasattr, _pick
 
@@ -151,52 +152,98 @@ def get_inputs(
         assert config, "head_dim is None, the value cannot be set without a configuration"
         head_dim = config.hidden_size // config.num_attention_heads
 
-    shapes = {
-        "input_ids": {0: batch, 1: seq_length},
-        "attention_mask": {
-            0: batch,
-            1: "cache+seq",  # cache_length + seq_length
-        },
-        "position_ids": {
-            0: batch,
-            1: "cache+seq",  # cache_length + seq_length
-        },
-        "past_key_values": [
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-        ],
+    cache_name = (
+        cls_cache
+        if cls_cache is None or isinstance(cls_cache, str)
+        else cls_cache.__name__
+    )
+    make_caches = {
+        "DynamicCache": make_dynamic_cache,
+        "SlidingWindowCache": make_sliding_window_cache,
+        "StaticCache": make_static_cache,
     }
-
-    make_cache = (
-        make_sliding_window_cache
-        if cls_cache in ("SlidingWindowCache", transformers.cache_utils.SlidingWindowCache)
-        else make_dynamic_cache
+    assert cache_name is None or cache_name in make_caches, (
+        f"Unable to handle cls_cache={cache_name!r}, it should be in "
+        f"{sorted(make_caches)}"
     )
+    make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name]
+    is_static = cache_name == "StaticCache"
 
-    inputs = dict(
-        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
-            torch.int64
-        ),
-        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
-            torch.int64
-        ),
-        position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
-        .to(torch.int64)
-        .expand((batch_size, -1)),
-        past_key_values=make_cache(
-            [
-                (
-                    torch.randn(
-                        batch_size, num_key_value_heads, sequence_length, head_dim
-                    ),
-                    torch.randn(
-                        batch_size, num_key_value_heads, sequence_length, head_dim
-                    ),
-                )
-                for i in range(num_hidden_layers)
-            ]
-        ),
-    )
+    if is_static:
+        # static
+        shapes = {
+            "input_ids": {0: batch, 1: seq_length},
+            "attention_mask": {0: batch, 2: "seq"},
+            "cache_position": {0: "seq"},
+            "past_key_values": [
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            ],
+        }
+        inputs = dict(
+            input_ids=torch.randint(
+                0, dummy_max_token_id, (batch_size, sequence_length2)
+            ).to(torch.int64),
+            attention_mask=torch.ones(
+                (batch_size, num_key_value_heads, sequence_length2, head_dim)
+            ).to(torch.bool),
+            cache_position=torch.arange(sequence_length2).to(torch.int64),
+            past_key_values=make_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+        )
+    else:
+        # dynamic
+        shapes = {
+            "input_ids": {0: batch, 1: seq_length},
+            "attention_mask": {
+                0: batch,
+                1: "cache+seq",  # cache_length + seq_length
+            },
+            "position_ids": {
+                0: batch,
+                1: "cache+seq",  # cache_length + seq_length
+            },
+            "past_key_values": [
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            ],
+        }
+
+        inputs = dict(
+            input_ids=torch.randint(
+                0, dummy_max_token_id, (batch_size, sequence_length2)
+            ).to(torch.int64),
+            attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+                torch.int64
+            ),
+            position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
+            .to(torch.int64)
+            .expand((batch_size, -1)),
+            past_key_values=make_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+        )
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
         res["inputs2"] = get_inputs(
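For orientation, a minimal sketch of how the new dispatch table can be exercised on its own, assuming make_static_cache accepts the same list of per-layer (key, value) tensor pairs as make_dynamic_cache (the shapes below are illustrative, not taken from a real model):

import torch
from onnx_diagnostic.helpers.cache_helper import (
    make_dynamic_cache,
    make_sliding_window_cache,
    make_static_cache,
)

# Mirrors the dispatch table introduced in the hunk above.
make_caches = {
    "DynamicCache": make_dynamic_cache,
    "SlidingWindowCache": make_sliding_window_cache,
    "StaticCache": make_static_cache,
}

batch, heads, seq, head_dim, layers = 2, 4, 30, 16, 2
past_key_values = make_caches["StaticCache"](
    [
        (
            torch.randn(batch, heads, seq, head_dim),
            torch.randn(batch, heads, seq, head_dim),
        )
        for _ in range(layers)
    ]
)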
onnx_diagnostic/tasks/image_text_to_text.py

@@ -55,6 +55,9 @@ def get_inputs(
     # attention_mask:T7s2x7
     # pixel_values:T1s2x3x224x224
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     assert isinstance(
         input_width, int
     ), f"Unexpected type for input_width {type(input_width)}{config}"
onnx_diagnostic/torch_export_patches/onnx_export_errors.py

@@ -1,5 +1,8 @@
+import functools
+import importlib
 import contextlib
-from typing import Any, Callable, Dict, List, Optional
+import re
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from .onnx_export_serialization import (
     register_cache_serialization,
     unregister_cache_serialization,
@@ -7,6 +10,41 @@ from .onnx_export_serialization import (
 from .patches import patch_transformers as patch_transformers_list
 
 
+def get_function(name: str) -> Tuple[type, Callable]:
+    """Returns the module and the function based on its name."""
+    spl = name.split(".")
+    module_name = ".".join(spl[:-1])
+    fname = spl[-1]
+    mod = importlib.import_module(module_name)
+    return mod, getattr(mod, fname)
+
+
+@functools.lru_cache
+def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
+    """Returns the list of patches to make for a specific module."""
+    to_patch = []
+    for k in dir(mod):
+        if k.startswith("patched_"):
+            v = getattr(mod, k)
+            if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
+                to_patch.append(v)
+            else:
+                # a function
+                doc = v.__doc__.lstrip()
+                if doc.startswith("manual patch"):
+                    continue
+                reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]")
+                fall = reg.findall(doc)
+                assert (
+                    len(fall) == 1
+                ), f"Unable to find patching information for {v} in \n{doc}"
+                fmod, f = get_function(fall[0])
+                to_patch.append({"module": fmod, "function": f, "patch": v})
+
+    name = mod.__name__
+    return name, to_patch
+
+
 def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
     """
     Applies all patches defined in classes prefixed by ``patched_``
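The [patch:...] marker that get_patches extracts with the regular expression above can be illustrated with a hypothetical patched function (the dotted target name is invented for illustration):

import re

# A hypothetical function following the convention parsed by get_patches:
# the docstring marker names the fully qualified callable being replaced.
def patched_some_function(*args, **kwargs):
    """[patch:some_package.some_module.some_function]"""
    raise NotImplementedError  # replacement behavior would go here

reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]")
print(reg.findall(patched_some_function.__doc__.lstrip()))
# ['some_package.some_module.some_function']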
@@ -23,16 +61,21 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
         to_patch = mod
         name = "list"
     else:
-        to_patch = []
-        for k in dir(mod):
-            if k.startswith("patched_"):
-                v = getattr(mod, k)
-                if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
-                    to_patch.append(v)
-        name = mod.__name__
+        name, to_patch = get_patches(mod, verbose)
 
     res = {}
     for cls in to_patch:
+        if isinstance(cls, dict):
+            # a function
+            keep = {}
+            original = cls["module"]
+            f = cls["function"]
+            res[f] = f
+            if verbose:
+                print(f"[patch_module_or_classes] function: {original.__name__}.{f.__name__}")
+            setattr(original, f.__name__, cls["patch"])
+            continue
+
         original = cls._PATCHED_CLASS_
         methods = cls._PATCHES_
         if verbose:
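The function branch amounts to swapping a module attribute with setattr while keeping the original around for restoration; a standalone sketch of that mechanism, using math.sqrt purely as a stand-in:

import math

original = math.sqrt  # keep the original so the patch can be undone

def patched_sqrt(x):
    # replacement that delegates to the saved original
    return original(abs(x))

setattr(math, "sqrt", patched_sqrt)
assert math.sqrt(-4.0) == 2.0    # the patch is in effect
setattr(math, "sqrt", original)  # restore, as unpatch_module_or_classes does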
@@ -57,26 +100,36 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo
         to_patch = mod
         name = "list"
     else:
-        to_patch = []
-        for k in dir(mod):
-            if k.startswith("patched_"):
-                v = getattr(mod, k)
-                if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
-                    to_patch.append(v)
-        name = mod.__name__
-    set_patch = set(to_patch)
+        name, to_patch = get_patches(mod, verbose)
+
+    set_patch_cls = {i for i in to_patch if not isinstance(i, dict)}
+    dict_patch_fct = {i["function"]: i for i in to_patch if isinstance(i, dict)}
 
     for cls, methods in info.items():
-        assert cls in set_patch, f"No patch registered for {cls} in {mod} (found {set_patch})"
+        if cls in set_patch_cls:
+            if verbose:
+                print(
+                    f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}"
+                )
+            original = cls._PATCHED_CLASS_
+            for n, v in methods.items():
+                if v is None:
+                    # The method did not exist. We remove it.
+                    delattr(original, n)
+                else:
+                    setattr(original, n, v)
+            continue
+        assert cls in dict_patch_fct, (
+            f"No patch registered for {cls} in {mod} "
+            f"(found {set_patch_cls} and {set(dict_patch_fct)})"
+        )
+        patch = dict_patch_fct[cls]
         if verbose:
-            print(f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}")
-        original = cls._PATCHED_CLASS_
-        for n, v in methods.items():
-            if v is None:
-                # The method did not exist. We remove it.
-                delattr(original, n)
-            else:
-                setattr(original, n, v)
+            print(
+                f"[unpatch_module_or_classes] function "
+                f"{patch['module'].__name__}.{cls.__name__}"
+            )
+        setattr(patch["module"], cls.__name__, patch["function"])
 
 
 @contextlib.contextmanager
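A hedged usage sketch of the full round trip, assuming the module layout shown elsewhere in this diff:

from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
    patch_module_or_classes,
    unpatch_module_or_classes,
)
from onnx_diagnostic.torch_export_patches.patches import patch_transformers

# Apply every class and function patch found in the module, keeping the
# returned info so the originals can be restored afterwards.
info = patch_module_or_classes(patch_transformers, verbose=1)
try:
    ...  # export the model while the patches are active
finally:
    unpatch_module_or_classes(patch_transformers, info, verbose=1)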
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

@@ -9,9 +9,11 @@ from transformers.cache_utils import (
     MambaCache,
     EncoderDecoderCache,
     SlidingWindowCache,
+    StaticCache,
 )
 from transformers.modeling_outputs import BaseModelOutput
 from ..helpers import string_type
+from ..helpers.cache_helper import make_static_cache
 
 
 PATCH_OF_PATCHES: Set[Any] = set()
@@ -175,6 +177,13 @@ def serialization_functions(verbose: int = 0) -> Dict[str, Union[Callable, int]]
             flatten_with_keys_sliding_window_cache,
             verbose=verbose,
         ),
+        StaticCache=register_class_serialization(
+            StaticCache,
+            flatten_static_cache,
+            unflatten_static_cache,
+            flatten_with_keys_static_cache,
+            verbose=verbose,
+        ),
     )
 
 
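With the new entry, calling serialization_functions is expected to register StaticCache alongside the other cache classes; a small hedged check:

from onnx_diagnostic.torch_export_patches.onnx_export_serialization import (
    serialization_functions,
)

# Calling serialization_functions registers each cache class and returns
# the mapping of what was registered.
registered = serialization_functions(verbose=1)
print(sorted(registered))  # expected to include "StaticCache"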
@@ -309,6 +318,34 @@ def unflatten_dynamic_cache(
     return cache
 
 
+#############
+# StaticCache
+#############
+
+
+def flatten_static_cache(
+    cache: StaticCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+    flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def flatten_with_keys_static_cache(
+    cache: StaticCache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+    values, context = flatten_static_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+def unflatten_static_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> StaticCache:
+    """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
+    return make_static_cache(list(zip(values[0], values[1])))
+
+
 ####################
 # SlidingWindowCache
 ####################
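A minimal round-trip sketch for the new helpers, assuming make_static_cache builds a StaticCache from per-layer (key, value) pairs, as unflatten_static_cache does above (shapes are illustrative):

import torch
from onnx_diagnostic.helpers.cache_helper import make_static_cache
from onnx_diagnostic.torch_export_patches.onnx_export_serialization import (
    flatten_static_cache,
    unflatten_static_cache,
)

# Two layers of (key, value) tensors.
pairs = [(torch.randn(2, 4, 30, 16), torch.randn(2, 4, 30, 16)) for _ in range(2)]
cache = make_static_cache(pairs)

values, context = flatten_static_cache(cache)  # -> [key_cache, value_cache], names
restored = unflatten_static_cache(values, context)
assert len(restored.key_cache) == len(cache.key_cache)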