onnx-diagnostic 0.8.10__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +136 -140
  3. onnx_diagnostic/ci_models/data/Blanca_Lake_Hudak.jpg +0 -0
  4. onnx_diagnostic/ci_models/data/Ice_worm_glacier.jpg +0 -0
  5. onnx_diagnostic/ci_models/data/__init__.py +0 -0
  6. onnx_diagnostic/ci_models/export_phi4_mm.py +10 -7
  7. onnx_diagnostic/export/api.py +13 -4
  8. onnx_diagnostic/export/dynamic_shapes.py +1 -1
  9. onnx_diagnostic/export/validate.py +2 -0
  10. onnx_diagnostic/ext_test_case.py +32 -15
  11. onnx_diagnostic/helpers/args_helper.py +1 -0
  12. onnx_diagnostic/helpers/bench_run.py +0 -1
  13. onnx_diagnostic/helpers/cache_helper.py +102 -36
  14. onnx_diagnostic/helpers/doc_helper.py +7 -4
  15. onnx_diagnostic/helpers/graph_helper.py +6 -6
  16. onnx_diagnostic/helpers/helper.py +39 -0
  17. onnx_diagnostic/helpers/log_helper.py +37 -14
  18. onnx_diagnostic/helpers/memory_peak.py +5 -1
  19. onnx_diagnostic/helpers/mini_onnx_builder.py +9 -14
  20. onnx_diagnostic/helpers/model_builder_helper.py +1 -1
  21. onnx_diagnostic/helpers/onnx_helper.py +283 -110
  22. onnx_diagnostic/helpers/ort_session.py +5 -2
  23. onnx_diagnostic/helpers/rt_helper.py +53 -9
  24. onnx_diagnostic/helpers/torch_helper.py +15 -11
  25. onnx_diagnostic/investigate/__init__.py +0 -0
  26. onnx_diagnostic/investigate/input_observer.py +970 -0
  27. onnx_diagnostic/reference/evaluator.py +0 -1
  28. onnx_diagnostic/reference/ort_evaluator.py +0 -1
  29. onnx_diagnostic/reference/report_results_comparison.py +9 -3
  30. onnx_diagnostic/reference/torch_evaluator.py +5 -1
  31. onnx_diagnostic/reference/torch_ops/_op_run.py +3 -5
  32. onnx_diagnostic/reference/torch_ops/sequence_ops.py +1 -1
  33. onnx_diagnostic/tasks/feature_extraction.py +0 -1
  34. onnx_diagnostic/torch_export_patches/__init__.py +0 -1
  35. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +32 -14
  36. onnx_diagnostic/torch_export_patches/patch_module.py +1 -1
  37. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +107 -6
  38. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
  39. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +13 -3
  40. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +1 -0
  41. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +70 -23
  42. onnx_diagnostic/torch_models/code_sample.py +5 -10
  43. onnx_diagnostic/torch_models/hghub/hub_data.py +2 -4
  44. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +6 -12
  45. onnx_diagnostic/torch_models/validate.py +1 -1
  46. onnx_diagnostic/torch_onnx/compare.py +0 -1
  47. onnx_diagnostic/torch_onnx/runtime_info.py +1 -1
  48. onnx_diagnostic/torch_onnx/sbs.py +1 -1
  49. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +2 -4
  50. onnx_diagnostic/typing.py +15 -0
  51. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/METADATA +2 -2
  52. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/RECORD +55 -50
  53. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/WHEEL +1 -1
  54. onnx_diagnostic/api.py +0 -15
  55. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/licenses/LICENSE.txt +0 -0
  56. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/top_level.txt +0 -0
@@ -428,6 +428,16 @@ class WrapperToExportMethodToOnnx(torch.nn.Module):
428
428
  new_kwargs[k] = v
429
429
  return new_kwargs
430
430
 
431
+ def is_empty_cache(self, cache):
432
+ if cache.__class__.__name__ == "DynamicCache" and hasattr(cache, "layers"):
433
+ if len(cache.layers) == 1 and cache.layers[0].keys is None:
434
+ return True
435
+ if len(cache.layers) == 0:
436
+ return True
437
+ if cache is None:
438
+ return True
439
+ return False
440
+
431
441
  def forward(self, *args, **kwargs):
432
442
  if not self._export_done:
433
443
  inp_args = args
@@ -443,6 +453,7 @@ class WrapperToExportMethodToOnnx(torch.nn.Module):
443
453
  if v is not None
444
454
  and (not self.skip_kwargs_names or k not in self.skip_kwargs_names)
445
455
  and not isinstance(v, (bool, int, float))
456
+ and not self.is_empty_cache(v)
446
457
  }
447
458
  )
448
459
  inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))
@@ -509,12 +520,10 @@ class WrapperToExportMethodToOnnx(torch.nn.Module):
509
520
  simple_sig = inspect.Signature(params, return_annotation=inspect._empty)
510
521
  args = str(simple_sig)[1:-1]
511
522
  calls_args = ", ".join(f"{p}={p}" for p in simple_sig.parameters)
512
- src = textwrap.dedent(
513
- f"""
523
+ src = textwrap.dedent(f"""
514
524
  def f(self, {args}):
515
525
  return self._method_call({calls_args})
516
- """
517
- )
526
+ """)
518
527
  self._method_src = src
519
528
  ns = {}
520
529
  try:
@@ -834,7 +834,7 @@ class ModelInputs:
834
834
  """Guesses the dynamic shapes for one argument."""
835
835
  if len(objs) == 0:
836
836
  return None
837
- set_types = set(type(o) for o in objs)
837
+ set_types = set(type(o) for o in objs if o is not None)
838
838
  assert (
839
839
  len(set_types) == 1
840
840
  ), f"Unexpected variety of input type {set_types}{msg() if msg else ''})"
@@ -80,6 +80,7 @@ def compare_modules(
80
80
  )
81
81
  got = modep(*_get(args), **_get(kwargs))
82
82
  if verbose:
83
+ # pyrefly: ignore[unbound-name]
83
84
  d = time.perf_counter() - begin
84
85
  print(f"[compare_modules] done in {d} with output={string_type(got, with_shape=True)}")
85
86
  if mod:
@@ -89,6 +90,7 @@ def compare_modules(
89
90
  expected = mod(*_get(args), **_get(kwargs))
90
91
  diff = max_diff(expected, got)
91
92
  if verbose:
93
+ # pyrefly: ignore[unbound-name]
92
94
  d = time.perf_counter() - begin
93
95
  print(
94
96
  f"[compare_modules] done in {d} with "
@@ -780,7 +780,7 @@ class ExtTestCase(unittest.TestCase):
780
780
 
781
781
  @property
782
782
  def verbose(self) -> int:
783
- "Returns the the value of environment variable ``VERBOSE``."
783
+ "Returns the value of environment variable ``VERBOSE``."
784
784
  return int(os.environ.get("VERBOSE", "0"))
785
785
 
786
786
  @classmethod
@@ -1028,6 +1028,19 @@ class ExtTestCase(unittest.TestCase):
1028
1028
  rtol=rtol,
1029
1029
  msg=msg,
1030
1030
  )
1031
+ elif expected.__class__.__name__ == "BaseModelOutputWithPooling":
1032
+ if expected.__class__.__name__ == value.__class__.__name__:
1033
+ self.assertEqual(len(expected), len(value), msg=msg)
1034
+ self.assertEqual(list(expected), list(value), msg=msg) # checks the order
1035
+ self.assertEqualAny(
1036
+ {k: v for k, v in expected.items()}, # noqa: C416
1037
+ {k: v for k, v in value.items()}, # noqa: C416
1038
+ atol=atol,
1039
+ rtol=rtol,
1040
+ msg=msg,
1041
+ )
1042
+ else:
1043
+ self.assertEqualArray(expected.last_hidden_state, value)
1031
1044
  elif isinstance(expected, (tuple, list, dict)):
1032
1045
  self.assertIsInstance(value, type(expected), msg=msg)
1033
1046
  self.assertEqual(len(expected), len(value), msg=msg)
@@ -1043,24 +1056,28 @@ class ExtTestCase(unittest.TestCase):
1043
1056
  "SlidingWindowCache",
1044
1057
  "HybridCache",
1045
1058
  ):
1059
+ from .helpers.cache_helper import CacheKeyValue
1060
+
1046
1061
  self.assertEqual(type(expected), type(value), msg=msg)
1047
- atts = ["key_cache", "value_cache"]
1048
- self.assertEqualAny(
1049
- {k: expected.__dict__.get(k, None) for k in atts},
1050
- {k: value.__dict__.get(k, None) for k in atts},
1051
- atol=atol,
1052
- rtol=rtol,
1053
- )
1062
+ self.assertEqualAny(CacheKeyValue(expected), CacheKeyValue(value))
1054
1063
  elif expected.__class__.__name__ == "StaticCache":
1064
+ from .helpers.cache_helper import CacheKeyValue
1065
+
1055
1066
  self.assertEqual(type(expected), type(value), msg=msg)
1056
1067
  self.assertEqual(expected.max_cache_len, value.max_cache_len)
1057
- atts = ["key_cache", "value_cache"]
1058
- self.assertEqualAny(
1059
- {k: expected.__dict__.get(k, None) for k in atts},
1060
- {k: value.__dict__.get(k, None) for k in atts},
1061
- atol=atol,
1062
- rtol=rtol,
1063
- )
1068
+ self.assertEqualAny(CacheKeyValue(expected), CacheKeyValue(value))
1069
+ elif expected.__class__.__name__ == "CacheKeyValue":
1070
+ self.assertEqual(type(expected), type(value), msg=msg)
1071
+ if expected.cls_layers is None:
1072
+ self.assertEqual(expected.cls_layers, value.cls_layers)
1073
+ else:
1074
+ self.assertEqualAny(
1075
+ [cls.__name__ for cls in expected.cls_layers],
1076
+ [cls.__name__ for cls in value.cls_layers],
1077
+ msg=msg,
1078
+ )
1079
+ self.assertEqualAny(expected.key_cache, value.key_cache, msg=msg)
1080
+ self.assertEqualAny(expected.value_cache, value.value_cache, msg=msg)
1064
1081
  elif expected.__class__.__name__ == "EncoderDecoderCache":
1065
1082
  self.assertEqual(type(expected), type(value), msg=msg)
1066
1083
  atts = ["self_attention_cache", "cross_attention_cache"]
@@ -105,6 +105,7 @@ def get_parsed_args(
105
105
  default=tries,
106
106
  )
107
107
  for k, v in kwargs.items():
108
+ assert isinstance(v, tuple) # type
108
109
  parser.add_argument(
109
110
  f"--{k}",
110
111
  help=f"{v[1]}, default is {v[0]}",
@@ -11,7 +11,6 @@ from argparse import Namespace
11
11
  from datetime import datetime
12
12
  from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
13
13
 
14
-
15
14
  _DEFAULT_STRING_LIMIT = 2000
16
15
 
17
16
 
@@ -4,6 +4,19 @@ import torch
4
4
  import transformers
5
5
  import transformers.cache_utils
6
6
 
7
+ KWARGS_LAYER = {}
8
+ if hasattr(transformers.cache_utils, "DynamicSlidingWindowLayer"):
9
+ KWARGS_LAYER.update(
10
+ {
11
+ transformers.cache_utils.DynamicSlidingWindowLayer: lambda tensor: {
12
+ "sliding_window": tensor.shape[2]
13
+ },
14
+ transformers.cache_utils.StaticSlidingWindowLayer: lambda tensor: {
15
+ "sliding_window": tensor.shape[2]
16
+ },
17
+ }
18
+ )
19
+
7
20
 
8
21
  class CacheKeyValue:
9
22
  """
@@ -90,7 +103,7 @@ def flatten_unflatten_for_dynamic_shapes(
90
103
  the context gives the dictionary keys but it is not expressed
91
104
  in the dynamic shapes, these specifications seems to be different
92
105
  for the strict and non strict mode. It also preserves tuple.
93
- :param change_function: to modifies the tensor in the structure itself,
106
+ :param change_function: to modify the tensor in the structure itself,
94
107
  like replace them by a shape
95
108
  :return: the serialized object
96
109
  """
@@ -110,7 +123,7 @@ def flatten_unflatten_for_dynamic_shapes(
110
123
  start = end
111
124
  if use_dict:
112
125
  if spec.type is dict:
113
- # This a dictionary.
126
+ # This is a dictionary.
114
127
  return dict(zip(spec.context, subtrees))
115
128
  if spec.type is tuple:
116
129
  return tuple(subtrees)
@@ -185,6 +198,7 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
185
198
  def make_dynamic_cache(
186
199
  key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
187
200
  cls_layers: Optional[Union[str, List[type]]] = None,
201
+ cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None,
188
202
  ) -> transformers.cache_utils.DynamicCache:
189
203
  """
190
204
  Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
@@ -194,6 +208,8 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
194
208
  :param cls_layers: to select the appropriate class to use on each layer,
195
209
  if specified, sliding_window is ignored, it can be a string
196
210
  if all layers are expected to follow the same class
211
+ :param cls_kwargs: arguments used to build a specific layer,
212
+ such as ``sliding_window`` for ``DynamicSlidingWindowLayer``
197
213
  :return: :class:`transformers.cache_utils.DynamicCache`
198
214
 
199
215
  Example:
@@ -224,49 +240,70 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
224
240
  are supported.
225
241
  """
226
242
  key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
227
- cls_kwargs = {}
228
243
  if isinstance(cls_layers, str):
229
244
  assert hasattr(
230
245
  transformers.cache_utils, cls_layers
231
- ), f"Unable to find class {cls_layers!r} in transformers.cache_utils"
232
- cls_layer = getattr(transformers.cache_utils, cls_layers)
233
- if cls_layers == "DynamicSlidingWindowLayer":
234
- cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2]
235
- assert isinstance(
236
- cls_kwargs["sliding_window"], int
237
- ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}"
238
- elif cls_layers is not None:
239
- unique = set(cls_layers)
240
- assert len(unique) == 1, f"Not implemented when cls_layers={cls_layers}"
241
- cls_layer = unique.pop()
242
- if (
243
- hasattr(transformers.cache_utils, "DynamicSlidingWindowLayer")
244
- and cls_layer == transformers.cache_utils.DynamicSlidingWindowLayer
245
- ):
246
- from .helper import string_type
247
-
248
- assert key_value_pairs and key_value_pairs[0], (
249
- f"not implemented for key_value_pairs="
250
- f"{string_type(key_value_pairs, with_shape=True)}"
251
- )
252
- cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2]
253
- assert isinstance(
254
- cls_kwargs["sliding_window"], int
255
- ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}"
246
+ ), f"Missing layer class {cls_layers!r}"
247
+ cls_layers = getattr(transformers.cache_utils, cls_layers)
248
+ if cls_layers and not isinstance(cls_layers, list):
249
+ cls_layers = [cls_layers for _ in key_value_pairs] # type: ignore[misc]
250
+ if cls_layers is not None and isinstance(cls_layers, list):
251
+ assert len(cls_layers) == len(key_value_pairs), (
252
+ f"Length mismatch {len(key_value_pairs)} expected but "
253
+ f"{len(cls_layers)} layer types are given."
254
+ )
255
+ if cls_kwargs is None:
256
+ cls_kwargs = [{} for _kv in key_value_pairs] # type: ignore[assignment]
257
+ assert len(cls_layers) == len(cls_kwargs), (
258
+ f"Length mismatch {len(cls_kwargs)} expected but "
259
+ f"{len(cls_layers)} layer types are given, "
260
+ f"cls_layers={cls_layers}, cls_kwargs={cls_kwargs}"
261
+ )
262
+ cls_layer = None
263
+ assert (
264
+ key_value_pairs and key_value_pairs[0]
265
+ ), f"not implemented for type(key_value_pairs[0])={type(key_value_pairs[0])}"
266
+ for kv, clsy, kws in zip(key_value_pairs, cls_layers, cls_kwargs):
267
+ default_values = KWARGS_LAYER.get(clsy, lambda tensor: {})(kv[0])
268
+ for k, v in default_values.items():
269
+ if k not in kws:
270
+ kws[k] = v # type: ignore[index]
256
271
  else:
272
+ assert cls_kwargs is None, "cls_layers must be a list if cls_kwargs is specified"
273
+ assert (
274
+ cls_layers is None
275
+ ), f"cls_layers must be list or a string but it is {cls_layers}"
276
+ cls_kwargs = {}
257
277
  cls_layer = (
258
278
  transformers.cache_utils.DynamicLayer
259
279
  if hasattr(transformers.cache_utils, "DynamicLayer")
260
280
  else None
261
281
  )
262
282
 
283
+ if cls_layer is not None:
284
+ assert isinstance(cls_kwargs, dict), (
285
+ f"one layer = one set of arguments, cls_layer={cls_layer}, "
286
+ f"cls_kwargs={cls_kwargs}"
287
+ )
288
+ cls_layers = [cls_layer for _ in key_value_pairs]
289
+ cls_kwargs = (
290
+ cls_kwargs # type: ignore[assignment]
291
+ if isinstance(cls_kwargs, list)
292
+ else [cls_kwargs for _ in key_value_pairs]
293
+ )
294
+ elif cls_layers is not None:
295
+ assert isinstance(cls_layers, list), f"Unexpected type cls_layers={cls_layers}"
296
+ assert isinstance(cls_kwargs, list), f"Unexpected type cls_kwargs={cls_kwargs}"
297
+
263
298
  if (
264
299
  key_value_pairs
265
300
  and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
266
301
  and pv.Version(transformers.__version__) >= pv.Version("4.56")
267
302
  ):
268
303
  cache = transformers.cache_utils.DynamicCache()
269
- cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs])
304
+ cache.layers.extend(
305
+ [cls_layer(**kws) for cls_layer, kws in zip(cls_layers, cls_kwargs)] # type: ignore[operator, arg-type]
306
+ )
270
307
  for i, layer in enumerate(cache.layers):
271
308
  k, v = key_value_pairs[i][0], key_value_pairs[i][1]
272
309
  layer.dtype = k.dtype
@@ -281,8 +318,25 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
281
318
  return finalize_cache(cache)
282
319
 
283
320
  cache = transformers.cache_utils.DynamicCache()
284
- if hasattr(cache, "layers") and cls_layer != transformers.cache_utils.DynamicLayer:
285
- cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs])
321
+ if hasattr(cache, "layers") and (
322
+ cls_layer is None or cls_layer != transformers.cache_utils.DynamicLayer
323
+ ):
324
+ assert isinstance(cls_layers, list) and isinstance(cls_kwargs, list), (
325
+ f"Wrong type {type(cls_layers)} for cls_layers or "
326
+ f"{type(cls_kwargs)} for cls_kwargs"
327
+ )
328
+ assert len(cls_kwargs) == len(cls_layers) and len(cls_kwargs) == len(
329
+ key_value_pairs
330
+ ), (
331
+ f"Length mismatch between len(cls_kwargs)={len(cls_kwargs)}, "
332
+ f"len(cls_layers)={len(cls_layers)}, "
333
+ f"len(key_value_pairs)={len(key_value_pairs)}, "
334
+ f"cls_kwargs={cls_kwargs}, cls_layers={cls_layers}"
335
+ )
336
+ del cache.layers[:]
337
+ cache.layers.extend(
338
+ [cls_layer(**kws) for cls_layer, kws in zip(cls_layers, cls_kwargs)] # type: ignore[operator, arg-type]
339
+ )
286
340
  for i, layer in enumerate(cache.layers):
287
341
  layer.keys, layer.values = key_value_pairs[i][0], key_value_pairs[i][1]
288
342
  layer.is_initialized = True
@@ -306,6 +360,7 @@ else:
306
360
  def make_dynamic_cache(
307
361
  key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
308
362
  cls_layers: Optional[Union[str, List[type]]] = None,
363
+ cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None,
309
364
  ) -> transformers.cache_utils.DynamicCache:
310
365
  """
311
366
  Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
@@ -337,7 +392,9 @@ else:
337
392
  )
338
393
  print(string_type(past_key_values, with_shape=True))
339
394
  """
340
- assert not cls_layers, "cls_layers cannot be used for transformers<5."
395
+ assert (
396
+ not cls_layers and not cls_kwargs
397
+ ), "cls_layers, cls_kwargs cannot be used for transformers<5."
341
398
  key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
342
399
  cache = transformers.cache_utils.DynamicCache(len(key_value_pairs)) # type: ignore
343
400
  for i, (key, value) in enumerate(key_value_pairs):
@@ -348,6 +405,7 @@ else:
348
405
  def make_static_cache(
349
406
  key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
350
407
  max_cache_len: Optional[int] = None,
408
+ cls_layers: Optional[Union[str, List[type]]] = None,
351
409
  ) -> transformers.cache_utils.DynamicCache:
352
410
  """
353
411
  Creates an instance of :class:`transformers.cache_utils.StaticCache`.
@@ -379,6 +437,9 @@ def make_static_cache(
379
437
  )
380
438
  print(string_type(past_key_values, with_shape=True))
381
439
  """
440
+ assert not cls_layers or set(cls_layers) == {
441
+ transformers.cache_utils.StaticLayer
442
+ }, f"Not implemented when cls_layers={cls_layers!r}"
382
443
  key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
383
444
 
384
445
  class _config:
@@ -583,13 +644,9 @@ if hasattr(transformers.cache_utils, "SlidingWindowCache"):
583
644
  )
584
645
  return finalize_cache(cache)
585
646
 
586
- def get_make_hybrid_cache():
587
- return make_sliding_window_cache
588
-
589
647
  else:
590
648
  make_sliding_window_cache = None # type: ignore[assignment]
591
649
 
592
-
593
650
  if hasattr(transformers.cache_utils, "HybridCache"):
594
651
 
595
652
  def make_hybrid_cache(
@@ -775,4 +832,13 @@ def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_
775
832
  # This is used to expand the cache when it does not contains enough layers.
776
833
  # This is needed since transformers>4.55.3
777
834
  cache.layer_class_to_replicate = cache.layers[0].__class__
835
+ assert (
836
+ not hasattr(cache, "layers")
837
+ or len(cache.layers) != 1
838
+ or cache.layers[0].keys is not None
839
+ ), (
840
+ f"Size mismatch between {len(cache.layers)=}, "
841
+ f"first key={cache.layers[0].keys}, " # type: ignore[attr-defined]
842
+ f"first value={cache.layers[0].values}" # type: ignore[attr-defined]
843
+ )
778
844
  return cache
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import Dict, List, Optional, Tuple
2
+ from typing import Any, Dict, List, Optional, Tuple
3
3
  import onnx
4
4
  import onnx.helper as oh
5
5
  import torch
@@ -46,10 +46,10 @@ class LayerNormalizationOrt(OpRunKernel):
46
46
  f"This kernel implementation only work when only one output "
47
47
  f"is required but {node.output} were."
48
48
  )
49
- self._cache: Dict[Tuple[int, int], onnx.ModelProto] = {}
49
+ self._cache: Dict[Tuple[int, int], Any] = {}
50
50
  self.is_cpu = torch.device("cpu") == self.device
51
51
 
52
- def _make_model(self, itype: int, rank: int, has_bias: bool) -> onnx.ModelProto:
52
+ def _make_model(self, itype: int, rank: int, has_bias: bool) -> Any:
53
53
  shape = [*["d{i}" for i in range(rank - 1)], "last"]
54
54
  layer_model = oh.make_model(
55
55
  oh.make_graph(
@@ -88,6 +88,7 @@ class LayerNormalizationOrt(OpRunKernel):
88
88
  providers=[provider],
89
89
  )
90
90
 
91
+ # pyrefly: ignore[bad-override]
91
92
  def run(self, x, scale, bias=None):
92
93
  itype = torch_dtype_to_onnx_dtype(x.dtype)
93
94
  rank = len(x.shape)
@@ -124,7 +125,7 @@ class MatMulOrt(OpRunKernel):
124
125
  self._cache: Dict[Tuple[int, int, int], onnx.ModelProto] = {}
125
126
  self.is_cpu = torch.device("cpu") == self.device
126
127
 
127
- def _make_model(self, itype: int, ranka: int, rankb: int) -> onnx.ModelProto:
128
+ def _make_model(self, itype: int, ranka: int, rankb: int) -> Any:
128
129
  shapea = ["a{i}" for i in range(ranka)]
129
130
  shapeb = ["b{i}" for i in range(rankb)]
130
131
  shapec = ["c{i}" for i in range(max(ranka, rankb))]
@@ -149,6 +150,7 @@ class MatMulOrt(OpRunKernel):
149
150
  providers=[provider],
150
151
  )
151
152
 
153
+ # pyrefly: ignore[bad-override]
152
154
  def run(self, a, b):
153
155
  itype = torch_dtype_to_onnx_dtype(a.dtype)
154
156
  ranka, rankb = len(a.shape), len(b.shape)
@@ -159,5 +161,6 @@ class MatMulOrt(OpRunKernel):
159
161
  if self.verbose:
160
162
  print(f"[MatMulOrt] running on {self._provider!r}")
161
163
  feeds = dict(A=a.tensor, B=b.tensor)
164
+ # pyrefly: ignore[missing-attribute]
162
165
  got = sess.run(None, feeds)[0]
163
166
  return OpRunTensor(got)
@@ -36,7 +36,7 @@ class GraphRendering:
36
36
  :return: computation order
37
37
  """
38
38
  assert not ({"If", "Scan", "Loop", "SequenceMap"} & set(n.op_type for n in nodes)), (
39
- f"This algorithme is not yet implemented if the sequence contains "
39
+ f"This algorithm is not yet implemented if the sequence contains "
40
40
  f"a control flow, types={sorted(set(n.op_type for n in nodes))}"
41
41
  )
42
42
  number = {e: start - 1 for e in (existing or [])} # noqa: C420
@@ -131,14 +131,14 @@ class GraphRendering:
131
131
  @property
132
132
  def nodes(self) -> List[onnx.NodeProto]:
133
133
  "Returns the list of nodes"
134
- return (
134
+ return list(
135
135
  self.proto.graph.node
136
136
  if isinstance(self.proto, onnx.ModelProto)
137
137
  else self.proto.node
138
138
  )
139
139
 
140
140
  @property
141
- def start_names(self) -> List[onnx.NodeProto]:
141
+ def start_names(self) -> List[str]:
142
142
  "Returns the list of known names, inputs and initializer"
143
143
  graph = self.proto.graph if isinstance(self.proto, onnx.ModelProto) else self.proto
144
144
  input_names = (
@@ -151,7 +151,7 @@ class GraphRendering:
151
151
  if isinstance(graph, onnx.FunctionProto)
152
152
  else [
153
153
  *[i.name for i in graph.initializer],
154
- *[i.name for i in graph.sparse_initializer],
154
+ *[i.values.name for i in graph.sparse_initializer],
155
155
  ]
156
156
  )
157
157
  return [*input_names, *init_names]
@@ -159,7 +159,7 @@ class GraphRendering:
159
159
  @property
160
160
  def input_names(self) -> List[str]:
161
161
  "Returns the list of input names."
162
- return (
162
+ return list(
163
163
  self.proto.input
164
164
  if isinstance(self.proto, onnx.FunctionProto)
165
165
  else [
@@ -173,7 +173,7 @@ class GraphRendering:
173
173
  @property
174
174
  def output_names(self) -> List[str]:
175
175
  "Returns the list of output names."
176
- return (
176
+ return list(
177
177
  self.proto.output
178
178
  if isinstance(self.proto, onnx.FunctionProto)
179
179
  else [
@@ -574,6 +574,32 @@ def string_type(
574
574
  print(f"[string_type] CACHE1:{type(obj)}")
575
575
  return f"MambaCache(conv_states={c}, ssm_states={d})"
576
576
 
577
+ if (
578
+ obj.__class__.__name__ in {"DynamicCache"}
579
+ and hasattr(obj, "layers")
580
+ and any(lay.__class__.__name__ != "DynamicLayer" for lay in obj.layers)
581
+ ):
582
+ slay = []
583
+ for lay in obj.layers:
584
+ skeys = string_type(
585
+ lay.keys,
586
+ with_shape=with_shape,
587
+ with_min_max=with_min_max,
588
+ with_device=with_device,
589
+ limit=limit,
590
+ verbose=verbose,
591
+ )
592
+ svalues = string_type(
593
+ lay.keys,
594
+ with_shape=with_shape,
595
+ with_min_max=with_min_max,
596
+ with_device=with_device,
597
+ limit=limit,
598
+ verbose=verbose,
599
+ )
600
+ slay.append(f"{lay.__class__.__name__}({skeys}, {svalues})")
601
+ return f"{obj.__class__.__name__}({', '.join(slay)})"
602
+
577
603
  if obj.__class__.__name__ in {
578
604
  "DynamicCache",
579
605
  "SlidingWindowCache",
@@ -829,6 +855,19 @@ def string_type(
829
855
  return f"{obj}"
830
856
  if obj.__class__.__name__ == "FakeTensorContext":
831
857
  return "FakeTensorContext(...)"
858
+ if obj.__class__.__name__ == "Chat":
859
+ import transformers.utils.chat_template_utils as ctu
860
+
861
+ assert isinstance(obj, ctu.Chat), f"unexpected type {type(obj)}"
862
+ msg = string_type(
863
+ obj.messages,
864
+ with_shape=with_shape,
865
+ with_min_max=with_min_max,
866
+ with_device=with_device,
867
+ limit=limit,
868
+ verbose=verbose,
869
+ )
870
+ return f"Chat({msg})"
832
871
 
833
872
  if verbose:
834
873
  print(f"[string_type] END:{type(obj)}")