onnx-diagnostic 0.6.3__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +281 -80
  3. onnx_diagnostic/doc.py +22 -0
  4. onnx_diagnostic/export/dynamic_shapes.py +48 -20
  5. onnx_diagnostic/export/shape_helper.py +126 -0
  6. onnx_diagnostic/ext_test_case.py +1 -1
  7. onnx_diagnostic/helpers/cache_helper.py +78 -8
  8. onnx_diagnostic/helpers/config_helper.py +8 -4
  9. onnx_diagnostic/helpers/helper.py +30 -3
  10. onnx_diagnostic/helpers/log_helper.py +1744 -0
  11. onnx_diagnostic/helpers/mini_onnx_builder.py +4 -1
  12. onnx_diagnostic/helpers/model_builder_helper.py +54 -73
  13. onnx_diagnostic/helpers/torch_helper.py +18 -2
  14. onnx_diagnostic/reference/__init__.py +1 -0
  15. onnx_diagnostic/reference/ort_evaluator.py +29 -4
  16. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  17. onnx_diagnostic/reference/torch_evaluator.py +21 -0
  18. onnx_diagnostic/tasks/automatic_speech_recognition.py +3 -0
  19. onnx_diagnostic/tasks/feature_extraction.py +3 -0
  20. onnx_diagnostic/tasks/fill_mask.py +3 -0
  21. onnx_diagnostic/tasks/image_classification.py +7 -1
  22. onnx_diagnostic/tasks/image_text_to_text.py +72 -18
  23. onnx_diagnostic/tasks/mixture_of_expert.py +3 -0
  24. onnx_diagnostic/tasks/object_detection.py +3 -0
  25. onnx_diagnostic/tasks/sentence_similarity.py +3 -0
  26. onnx_diagnostic/tasks/summarization.py +3 -0
  27. onnx_diagnostic/tasks/text2text_generation.py +3 -0
  28. onnx_diagnostic/tasks/text_classification.py +3 -0
  29. onnx_diagnostic/tasks/text_generation.py +90 -43
  30. onnx_diagnostic/tasks/zero_shot_image_classification.py +3 -0
  31. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +78 -25
  32. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +37 -0
  33. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +365 -17
  34. onnx_diagnostic/torch_models/hghub/hub_api.py +81 -8
  35. onnx_diagnostic/torch_models/hghub/hub_data.py +6 -2
  36. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +209 -0
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +58 -14
  38. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +23 -50
  39. onnx_diagnostic/torch_models/{test_helper.py → validate.py} +166 -106
  40. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/METADATA +2 -2
  41. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/RECORD +44 -41
  42. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/WHEEL +0 -0
  43. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
  44. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/top_level.txt +0 -0
@@ -630,9 +630,12 @@ class ModelInputs:
630
630
  method_name: str = "forward",
631
631
  name: str = "main",
632
632
  ):
633
- assert isinstance(model, torch.nn.Module) or inspect.ismodule(
634
- model
635
- ), f"unexpected type for model={type(model)}, it must be a torch.nn.Module"
633
+ assert (
634
+ model is None or isinstance(model, torch.nn.Module) or inspect.ismodule(model)
635
+ ), (
636
+ f"unexpected type for model={type(model)}, "
637
+ f"it must be a torch.nn.Module or None"
638
+ )
636
639
  assert name, (
637
640
  f"name={name!r} cannot be empty this string is used to "
638
641
  f"display meaningful error messages"
@@ -641,26 +644,42 @@ class ModelInputs:
641
644
  self.model = model
642
645
  self.level = level
643
646
  self.method_name = method_name
644
- self.forward = getattr(model, method_name)
645
- self.signature = inspect.signature(self.forward)
647
+ self.forward = getattr(model, method_name) if model is not None else None
648
+ self.signature = inspect.signature(self.forward) if self.forward else None
646
649
 
647
650
  # information about the signature
648
- self.forward_parameter_names = set(
649
- p.name
650
- for p in self.signature.parameters.values()
651
- if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
651
+ self.forward_parameter_names = (
652
+ set(
653
+ p.name
654
+ for p in self.signature.parameters.values()
655
+ if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
656
+ )
657
+ if self.signature
658
+ else None
659
+ )
660
+ self.forward_ordered_parameter_names = (
661
+ list(self.signature.parameters) if self.signature else None
662
+ )
663
+ self.forward_positioned_parameter_names = (
664
+ [
665
+ p.name
666
+ for p in self.signature.parameters.values()
667
+ if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
668
+ ]
669
+ if self.signature
670
+ else None
671
+ )
672
+ names = (
673
+ [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL]
674
+ if self.signature
675
+ else None
652
676
  )
653
- self.forward_ordered_parameter_names = list(self.signature.parameters)
654
- self.forward_positioned_parameter_names = [
655
- p.name
656
- for p in self.signature.parameters.values()
657
- if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
658
- ]
659
- names = [
660
- p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL
661
- ]
662
677
  self.forward_args = names[0] if names else None
663
- names = [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_KEYWORD]
678
+ names = (
679
+ [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_KEYWORD]
680
+ if self.signature
681
+ else None
682
+ )
664
683
  self.forward_kwargs = names[0] if names else None
665
684
  self.forward_custom_op_schema = None
666
685
  self.forward_need_serialization = False
@@ -711,6 +730,7 @@ class ModelInputs:
711
730
  @property
712
731
  def true_model_name(self) -> str:
713
732
  "Returns class name or module name."
733
+ assert self.model is not None, "model was None when the class was initialized."
714
734
  return (
715
735
  self.model.__class__.__name__
716
736
  if isinstance(self.model, torch.nn.Module)
@@ -942,7 +962,7 @@ class ModelInputs:
942
962
  )
943
963
  )
944
964
  names = s2.pop()
945
- for name in names:
965
+ for i, name in enumerate(names):
946
966
  assert name not in {"_diag", "verbose"}, (
947
967
  f"{self.full_name}: unexpected parameter {name!r}, names={names}"
948
968
  f"\ninputs[0]={string_type(self.inputs[0], with_shape=True)}"
@@ -968,6 +988,14 @@ class ModelInputs:
968
988
  with the corresponding dynamic shapes.
969
989
  *kwargs*, *dynamic_shapes* are modified inplace.
970
990
  """
991
+ assert (
992
+ self.signature is not None
993
+ and self.forward_parameter_names is not None
994
+ and self.forward_ordered_parameter_names is not None
995
+ ), (
996
+ "model was None when the class was initialized, "
997
+ "cannot move args to kwargs without the signature."
998
+ )
971
999
  sig = self.signature
972
1000
  arg_dyn, kw_dyn = dynamic_shapes
973
1001
  for i, p in enumerate(sig.parameters):
@@ -0,0 +1,126 @@
1
+ from typing import Any, Dict, List, Set, Tuple, Union
2
+ from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
3
+ from .dynamic_shapes import ModelInputs
4
+
5
+
6
+ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
7
+ """
8
+ Returns the dynamic shapes for the given inputs.
9
+ All dimensions are considered as dynamic.
10
+ ``dim_prefix`` can be a string (the function uses it as a prefix),
11
+ or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``.
12
+
13
+ .. runpython::
14
+ :showcode:
15
+
16
+ import pprint
17
+ import torch
18
+ from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
19
+ from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
20
+
21
+ bsize, nheads, slen, dim = 2, 1, 30, 96
22
+ inputs = dict(
23
+ input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
24
+ attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
25
+ position_ids=torch.arange(3, dtype=torch.int64),
26
+ past_key_values=make_dynamic_cache(
27
+ [(torch.randn(bsize, nheads, slen, dim),
28
+ torch.randn(bsize, nheads, slen, dim))]
29
+ ),
30
+ )
31
+ ds = all_dynamic_shape_from_inputs(inputs)
32
+ pprint.pprint(ds)
33
+ """
34
+ if isinstance(dim_prefix, str):
35
+ prefixes: Set[str] = set()
36
+
37
+ def tensor_to_shape(tensor):
38
+ n = len(prefixes)
39
+ p = f"{dim_prefix}_{n}"
40
+ prefixes.add(p)
41
+ return {i: f"{p}_{i}" for i in range(tensor.ndim)}
42
+
43
+ else:
44
+
45
+ def tensor_to_shape(tensor):
46
+ return {i: dim_prefix for i in range(tensor.ndim)} # noqa: C420
47
+
48
+ return flatten_unflatten_for_dynamic_shapes(
49
+ inputs, change_function=tensor_to_shape, use_dict=True
50
+ )
51
+
52
+
53
+ def guess_dynamic_shapes_from_inputs(
54
+ inputs: List[Any], auto: Union[bool, str] = False
55
+ ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
56
+ """
57
+ Guesses which dimensions are dynamic from a set of inputs.
58
+ Every dimension whose value differs across the input sets
59
+ becomes dynamic. Every dimension not changing remains static.
60
+
61
+ :param inputs: a list of input sets
62
+ :param auto: True for ``torch.export.Dim.AUTO``,
63
+ False for ``torch.export.Dim.DYNAMIC``,
64
+ a string to get a unique string for every dynamic dimension
65
+ :return: args and kwargs
66
+
67
+ .. runpython::
68
+ :showcode:
69
+
70
+ import pprint
71
+ import torch
72
+ from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
73
+ from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs
74
+
75
+ bsize, nheads, slen, dim = 2, 1, 30, 96
76
+ inputs1 = dict(
77
+ input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
78
+ attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
79
+ position_ids=torch.arange(3, dtype=torch.int64),
80
+ past_key_values=make_dynamic_cache(
81
+ [
82
+ (
83
+ torch.randn(bsize, nheads, slen, dim),
84
+ torch.randn(bsize, nheads, slen, dim),
85
+ ),
86
+ ]
87
+ ),
88
+ )
89
+ bsize, nheads, slen, dim = 3, 1, 33, 96
90
+ inputs2 = dict(
91
+ input_ids=torch.randint(15, size=(3, 4), dtype=torch.int64),
92
+ attention_mask=torch.randint(1, size=(3, 34), dtype=torch.int64),
93
+ position_ids=torch.arange(4, dtype=torch.int64),
94
+ past_key_values=make_dynamic_cache(
95
+ [
96
+ (
97
+ torch.randn(bsize, nheads, slen, dim),
98
+ torch.randn(bsize, nheads, slen, dim),
99
+ ),
100
+ ]
101
+ ),
102
+ )
103
+ ds = guess_dynamic_shapes_from_inputs([inputs1, inputs2], auto="d")
104
+ pprint.pprint(ds)
105
+
106
+ This function returns something equivalent to the class
107
+ :class:`torch.export.dynamic_shapes.AdditionalInputs` but that
108
+ one needs a model.
109
+
110
+ .. runpython::
111
+ :showcode:
112
+
113
+ import pprint
114
+ import torch
115
+ from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
116
+ from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs
117
+ from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
118
+
119
+ data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
120
+ ds = torch.export.dynamic_shapes.AdditionalInputs()
121
+ ds.add((), data["inputs"])
122
+ ds.add((), data["inputs2"])
123
+ pprint.pprint(ds.dynamic_shapes(data["model"], (), data["inputs"]))
124
+ """
125
+ mi = ModelInputs(None, inputs)
126
+ return mi.guess_dynamic_shapes(auto=auto)
@@ -1014,7 +1014,7 @@ class ExtTestCase(unittest.TestCase):
1014
1014
  msg_ = "\n".join(excs)
1015
1015
  msg = f"{msg}\n{msg_}" if msg else msg_
1016
1016
  raise AssertionError(f"Found {len(excs)} discrepancies\n{msg}")
1017
- elif expected.__class__.__name__ == "DynamicCache":
1017
+ elif expected.__class__.__name__ in ("DynamicCache", "StaticCache"):
1018
1018
  atts = {"key_cache", "value_cache"}
1019
1019
  self.assertEqualArrayAny(
1020
1020
  {k: expected.__dict__.get(k, None) for k in atts},
@@ -1,11 +1,15 @@
1
- from typing import Any, List, Tuple
1
+ from typing import Any, Callable, List, Optional, Tuple
2
2
  import packaging.version as pv
3
3
  import torch
4
4
  import transformers
5
5
  import transformers.cache_utils
6
6
 
7
7
 
8
- def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> Any:
8
+ def flatten_unflatten_for_dynamic_shapes(
9
+ obj: Any,
10
+ use_dict: bool = False,
11
+ change_function: Optional[Callable[[torch.Tensor], Any]] = None,
12
+ ) -> Any:
9
13
  """
10
14
  Returns the object in a different structure similar to what
11
15
  the definition of the dynamic shapes should use.
@@ -15,11 +19,13 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> An
15
19
  :func:`torch.export.export` only considers the values,
16
20
  the context gives the dictionary keys but it is not expressed
17
21
  in the dynamic shapes, these specifications seems to be different
18
- for the strict and non strict mode.
22
+ for the strict and non strict mode. It also preserves tuples.
23
+ :param change_function: modifies the tensors inside the structure,
24
+ for example, replaces them by a shape
19
25
  :return: the serialized object
20
26
  """
21
27
  if isinstance(obj, torch.Tensor):
22
- return obj
28
+ return change_function(obj) if change_function else obj
23
29
  flat, spec = torch.utils._pytree.tree_flatten(obj)
24
30
  start = 0
25
31
  end = 0
@@ -27,12 +33,17 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> An
27
33
  for subspec in spec.children_specs:
28
34
  end += subspec.num_leaves
29
35
  value = subspec.unflatten(flat[start:end])
30
- value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
36
+ value = flatten_unflatten_for_dynamic_shapes(
37
+ value, use_dict=use_dict, change_function=change_function
38
+ )
31
39
  subtrees.append(value)
32
40
  start = end
33
- if use_dict and (spec.type is dict or spec.context):
34
- # This a dictionary.
35
- return dict(zip(spec.context, subtrees))
41
+ if use_dict:
42
+ if spec.type is dict or spec.context:
43
+ # This a dictionary.
44
+ return dict(zip(spec.context, subtrees))
45
+ if spec.type is tuple:
46
+ return tuple(subtrees)
36
47
  # This is a list.
37
48
  return subtrees
38
49
 
@@ -141,6 +152,65 @@ else:
141
152
  return cache
142
153
 
143
154
 
155
+ def make_static_cache(
156
+ key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
157
+ ) -> transformers.cache_utils.DynamicCache:
158
+ """
159
+ Creates an instance of :class:`transformers.cache_utils.StaticCache`.
160
+ :param key_value_pairs: list of (key, value) pairs
161
+ :return: :class:`transformers.cache_utils.StaticCache`
162
+
163
+ Example:
164
+
165
+ .. runpython::
166
+ :showcode:
167
+
168
+ import torch
169
+ from onnx_diagnostic.helpers import string_type
170
+ from onnx_diagnostic.helpers.cache_helper import make_static_cache
171
+
172
+ n_layers = 2
173
+ bsize, nheads, slen, dim = 2, 4, 3, 7
174
+
175
+ past_key_values = make_static_cache(
176
+ [
177
+ (
178
+ torch.randn(bsize, nheads, slen, dim),
179
+ torch.randn(bsize, nheads, slen, dim),
180
+ )
181
+ for i in range(n_layers)
182
+ ]
183
+ )
184
+ print(string_type(past_key_values, with_shape=True))
185
+ """
186
+
187
+ class _config:
188
+ def __init__(self):
189
+ self.head_dim = key_value_pairs[0][0].shape[-1]
190
+ self.num_attention_heads = key_value_pairs[0][0].shape[1]
191
+ self.num_hidden_layers = len(key_value_pairs)
192
+
193
+ cache = transformers.cache_utils.StaticCache(
194
+ _config(),
195
+ max_batch_size=key_value_pairs[0][0].shape[0],
196
+ device=key_value_pairs[0][0].device,
197
+ dtype=key_value_pairs[0][0].dtype,
198
+ max_cache_len=key_value_pairs[0][0].shape[2],
199
+ )
200
+ for i in range(len(key_value_pairs)):
201
+ assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, (
202
+ f"Shape mismatch, expected {cache.key_cache[i].shape}, "
203
+ f"got {key_value_pairs[i][0].shape}"
204
+ )
205
+ cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
206
+ assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, (
207
+ f"Shape mismatch, expected {cache.value_cache[i].shape}, "
208
+ f"got {key_value_pairs[i][1].shape}"
209
+ )
210
+ cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
211
+ return cache
212
+
213
+
144
214
  def make_encoder_decoder_cache(
145
215
  self_attention_cache: transformers.cache_utils.DynamicCache,
146
216
  cross_attention_cache: transformers.cache_utils.DynamicCache,
@@ -34,10 +34,14 @@ def update_config(config: Any, mkwargs: Dict[str, Any]):
34
34
  config._attn_implementation_autoset = False
35
35
  continue
36
36
  if isinstance(v, dict):
37
- assert hasattr(
38
- config, k
39
- ), f"missing attribute {k!r} in config={config}, cannot update it with {v}"
40
- update_config(getattr(config, k), v)
37
+ if not hasattr(config, k) or getattr(config, k) is None:
38
+ setattr(config, k, v)
39
+ continue
40
+ existing = getattr(config, k)
41
+ if type(existing) is dict:
42
+ existing.update(v)
43
+ else:
44
+ update_config(getattr(config, k), v)
41
45
  continue
42
46
  setattr(config, k, v)
43
47
 
@@ -558,7 +558,7 @@ def string_type(
558
558
  print(f"[string_type] CACHE1:{type(obj)}")
559
559
  return f"MambaCache(conv_states={c}, ssm_states={d})"
560
560
 
561
- if obj.__class__.__name__ in ("DynamicCache", "SlidingWindowCache"):
561
+ if obj.__class__.__name__ in {"DynamicCache", "SlidingWindowCache", "StaticCache"}:
562
562
  kc = string_type(
563
563
  obj.key_cache,
564
564
  with_shape=with_shape,
@@ -857,7 +857,7 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any:
857
857
  return flatten_object(list(x.values()), drop_keys=drop_keys)
858
858
  return flatten_object(list(x.items()), drop_keys=drop_keys)
859
859
 
860
- if x.__class__.__name__ == "DynamicCache":
860
+ if x.__class__.__name__ in {"DynamicCache", "StaticCache"}:
861
861
  res = flatten_object(x.key_cache) + flatten_object(x.value_cache)
862
862
  return tuple(res)
863
863
  if x.__class__.__name__ == "EncoderDecoderCache":
@@ -1424,10 +1424,37 @@ def max_diff(
1424
1424
  f"level={level}"
1425
1425
  )
1426
1426
 
1427
+ if expected.__class__.__name__ == "StaticCache":
1428
+ if got.__class__.__name__ == "StaticCache":
1429
+ if verbose >= 6:
1430
+ print(f"[max_diff] StaticCache: {string_type(expected)} ? {string_type(got)}")
1431
+ return max_diff(
1432
+ [expected.key_cache, expected.value_cache],
1433
+ [got.key_cache, got.value_cache],
1434
+ verbose=verbose,
1435
+ hist=hist,
1436
+ )
1437
+ if isinstance(got, tuple) and len(got) == 2:
1438
+ return max_diff(
1439
+ [expected.key_cache, expected.value_cache],
1440
+ [got[0], got[1]],
1441
+ debug_info=_debug(expected.__class__.__name__),
1442
+ **_dkws,
1443
+ )
1444
+ raise AssertionError(
1445
+ f"StaticCache not fully implemented with classes "
1446
+ f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
1447
+ f"and expected={string_type(expected)}, got={string_type(got)},\n"
1448
+ f"level={level}"
1449
+ )
1450
+
1427
1451
  if expected.__class__.__name__ == "SlidingWindowCache":
1428
1452
  if got.__class__.__name__ == "SlidingWindowCache":
1429
1453
  if verbose >= 6:
1430
- print(f"[max_diff] DynamicCache: {string_type(expected)} ? {string_type(got)}")
1454
+ print(
1455
+ f"[max_diff] SlidingWindowCache: "
1456
+ f"{string_type(expected)} ? {string_type(got)}"
1457
+ )
1431
1458
  return max_diff(
1432
1459
  [expected.key_cache, expected.value_cache],
1433
1460
  [got.key_cache, got.value_cache],