onnx-diagnostic 0.6.3__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +87 -77
  3. onnx_diagnostic/doc.py +22 -0
  4. onnx_diagnostic/ext_test_case.py +1 -1
  5. onnx_diagnostic/helpers/cache_helper.py +59 -0
  6. onnx_diagnostic/helpers/config_helper.py +8 -4
  7. onnx_diagnostic/helpers/helper.py +30 -3
  8. onnx_diagnostic/helpers/log_helper.py +585 -0
  9. onnx_diagnostic/helpers/mini_onnx_builder.py +4 -1
  10. onnx_diagnostic/helpers/model_builder_helper.py +54 -73
  11. onnx_diagnostic/helpers/torch_helper.py +18 -2
  12. onnx_diagnostic/reference/__init__.py +1 -0
  13. onnx_diagnostic/reference/ort_evaluator.py +29 -4
  14. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  15. onnx_diagnostic/reference/torch_evaluator.py +21 -0
  16. onnx_diagnostic/tasks/automatic_speech_recognition.py +3 -0
  17. onnx_diagnostic/tasks/feature_extraction.py +3 -0
  18. onnx_diagnostic/tasks/fill_mask.py +3 -0
  19. onnx_diagnostic/tasks/image_classification.py +7 -1
  20. onnx_diagnostic/tasks/image_text_to_text.py +3 -0
  21. onnx_diagnostic/tasks/mixture_of_expert.py +3 -0
  22. onnx_diagnostic/tasks/object_detection.py +3 -0
  23. onnx_diagnostic/tasks/sentence_similarity.py +3 -0
  24. onnx_diagnostic/tasks/summarization.py +3 -0
  25. onnx_diagnostic/tasks/text2text_generation.py +3 -0
  26. onnx_diagnostic/tasks/text_classification.py +3 -0
  27. onnx_diagnostic/tasks/text_generation.py +90 -43
  28. onnx_diagnostic/tasks/zero_shot_image_classification.py +3 -0
  29. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +78 -25
  30. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +37 -0
  31. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +365 -17
  32. onnx_diagnostic/torch_models/hghub/hub_api.py +20 -4
  33. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +209 -0
  34. onnx_diagnostic/torch_models/hghub/model_inputs.py +3 -0
  35. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +23 -50
  36. onnx_diagnostic/torch_models/{test_helper.py → validate.py} +158 -103
  37. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/METADATA +2 -2
  38. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/RECORD +41 -39
  39. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/WHEEL +0 -0
  40. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
  41. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/top_level.txt +0 -0
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.6.3"
6
+ __version__ = "0.7.0"
7
7
  __author__ = "Xavier Dupré"
@@ -5,19 +5,18 @@ import re
5
5
  import sys
6
6
  import textwrap
7
7
  import onnx
8
- from typing import Any, List, Optional
8
+ from typing import Any, Dict, List, Optional, Union
9
9
  from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction
10
- from textwrap import dedent
11
10
 
12
11
 
13
12
  def get_parser_lighten() -> ArgumentParser:
14
13
  parser = ArgumentParser(
15
14
  prog="lighten",
16
- description=dedent(
15
+ description=textwrap.dedent(
16
+ """
17
+ Removes the weights from a heavy model, stores statistics to restore
18
+ random weights.
17
19
  """
18
- Removes the weights from a heavy model, stores statistics to restore
19
- random weights.
20
- """
21
20
  ),
22
21
  epilog="This is mostly used to write unit tests without adding "
23
22
  "a big onnx file to the repository.",
@@ -70,11 +69,11 @@ def _cmd_lighten(argv: List[Any]):
70
69
  def get_parser_unlighten() -> ArgumentParser:
71
70
  parser = ArgumentParser(
72
71
  prog="unlighten",
73
- description=dedent(
72
+ description=textwrap.dedent(
73
+ """
74
+ Restores random weights for a model reduces with command lighten,
75
+ the command expects to find a file nearby with extension '.stats'.
74
76
  """
75
- Restores random weights for a model reduces with command lighten,
76
- the command expects to find a file nearby with extension '.stats'.
77
- """
78
77
  ),
79
78
  epilog="This is mostly used to write unit tests without adding "
80
79
  "a big onnx file to the repository.",
@@ -120,11 +119,7 @@ def _cmd_unlighten(argv: List[Any]):
120
119
  def get_parser_print() -> ArgumentParser:
121
120
  parser = ArgumentParser(
122
121
  prog="print",
123
- description=dedent(
124
- """
125
- Prints the model on the standard output.
126
- """
127
- ),
122
+ description="Prints the model on the standard output.",
128
123
  epilog="To show a model.",
129
124
  formatter_class=RawTextHelpFormatter,
130
125
  )
@@ -171,11 +166,11 @@ def _cmd_print(argv: List[Any]):
171
166
  def get_parser_find() -> ArgumentParser:
172
167
  parser = ArgumentParser(
173
168
  prog="find",
174
- description=dedent(
169
+ description=textwrap.dedent(
170
+ """
171
+ Look into a model and search for a set of names,
172
+ tells which node is consuming or producing it.
175
173
  """
176
- Look into a model and search for a set of names,
177
- tells which node is consuming or producing it.
178
- """
179
174
  ),
180
175
  epilog="Enables Some quick validation.",
181
176
  )
@@ -191,8 +186,8 @@ def get_parser_find() -> ArgumentParser:
191
186
  "--names",
192
187
  type=str,
193
188
  required=False,
194
- help="names to look at comma separated values, if 'SHADOW', "
195
- "search for shadowing names",
189
+ help="Names to look at comma separated values, if 'SHADOW', "
190
+ "search for shadowing names.",
196
191
  )
197
192
  parser.add_argument(
198
193
  "-v",
@@ -206,7 +201,7 @@ def get_parser_find() -> ArgumentParser:
206
201
  "--v2",
207
202
  default=False,
208
203
  action=BooleanOptionalAction,
209
- help="use enumerate_results instead of onnx_find",
204
+ help="Uses enumerate_results instead of onnx_find.",
210
205
  )
211
206
  return parser
212
207
 
@@ -235,12 +230,13 @@ def _cmd_find(argv: List[Any]):
235
230
  def get_parser_config() -> ArgumentParser:
236
231
  parser = ArgumentParser(
237
232
  prog="config",
238
- description=dedent(
233
+ description=textwrap.dedent(
234
+ """
235
+ Prints out a configuration for a model id,
236
+ prints the associated task as well.
239
237
  """
240
- Prints out a configuration for a model id,
241
- prints the associated task as well.
242
- """
243
238
  ),
239
+ formatter_class=RawTextHelpFormatter,
244
240
  epilog="",
245
241
  )
246
242
  parser.add_argument(
@@ -248,29 +244,29 @@ def get_parser_config() -> ArgumentParser:
248
244
  "--mid",
249
245
  type=str,
250
246
  required=True,
251
- help="model id, usually <author>/<name>",
247
+ help="model id, usually `<author>/<name>`",
252
248
  )
253
249
  parser.add_argument(
254
250
  "-t",
255
251
  "--task",
256
252
  default=False,
257
253
  action=BooleanOptionalAction,
258
- help="displays the task as well",
254
+ help="Displays the task as well.",
259
255
  )
260
256
  parser.add_argument(
261
257
  "-c",
262
258
  "--cached",
263
259
  default=True,
264
260
  action=BooleanOptionalAction,
265
- help="uses cached configuration, only available for some of them, "
266
- "mostly for unit test purposes",
261
+ help="Uses cached configuration, only available for some of them,\n"
262
+ "mostly for unit test purposes.",
267
263
  )
268
264
  parser.add_argument(
269
265
  "--mop",
270
266
  metavar="KEY=VALUE",
271
267
  nargs="*",
272
268
  help="Additional model options, use to change some parameters of the model, "
273
- "example: --mop attn_implementation=eager",
269
+ "example:\n --mop attn_implementation=sdpa or --mop attn_implementation=eager",
274
270
  action=_ParseDict,
275
271
  )
276
272
  return parser
@@ -291,6 +287,14 @@ def _cmd_config(argv: List[Any]):
291
287
  print(f"task: {task_from_id(args.mid)}")
292
288
 
293
289
 
290
+ def _parse_json(value: str) -> Union[str, Dict[str, Any]]:
291
+ assert isinstance(value, str), f"value should be string but value={value!r}"
292
+ if value and value[0] == "{" and value[-1] == "}":
293
+ # a dictionary
294
+ return json.loads(value.replace("'", '"'))
295
+ return value
296
+
297
+
294
298
  class _ParseDict(argparse.Action):
295
299
  def __call__(self, parser, namespace, values, option_string=None):
296
300
  d = getattr(namespace, self.dest) or {}
@@ -314,22 +318,23 @@ class _ParseDict(argparse.Action):
314
318
  continue
315
319
  except (TypeError, ValueError):
316
320
  pass
317
- d[key] = value
321
+ d[key] = _parse_json(value)
318
322
 
319
323
  setattr(namespace, self.dest, d)
320
324
 
321
325
 
322
326
  def get_parser_validate() -> ArgumentParser:
323
327
  parser = ArgumentParser(
324
- prog="test",
325
- description=dedent(
328
+ prog="validate",
329
+ description=textwrap.dedent(
330
+ """
331
+ Prints out dummy inputs for a particular task or a model id.
332
+ If both mid and task are empty, the command line displays the list
333
+ of supported tasks.
326
334
  """
327
- Prints out dummy inputs for a particular task or a model id.
328
- If both mid and task are empty, the command line displays the list
329
- of supported tasks.
330
- """
331
335
  ),
332
336
  epilog="If the model id is specified, one untrained version of it is instantiated.",
337
+ formatter_class=RawTextHelpFormatter,
333
338
  )
334
339
  parser.add_argument("-m", "--mid", type=str, help="model id, usually <author>/<name>")
335
340
  parser.add_argument("-t", "--task", default=None, help="force the task to use")
@@ -340,55 +345,61 @@ def get_parser_validate() -> ArgumentParser:
340
345
  "--run",
341
346
  default=False,
342
347
  action=BooleanOptionalAction,
343
- help="runs the model to check it runs",
348
+ help="Runs the model to check it runs.",
344
349
  )
345
350
  parser.add_argument(
346
351
  "-q",
347
352
  "--quiet",
348
353
  default=False,
349
354
  action=BooleanOptionalAction,
350
- help="catches exception, report them in the summary",
355
+ help="Catches exception, reports them in the summary.",
351
356
  )
352
357
  parser.add_argument(
353
358
  "--patch",
354
359
  default=True,
355
360
  action=BooleanOptionalAction,
356
- help="applies patches before exporting",
361
+ help="Applies patches before exporting.",
357
362
  )
358
363
  parser.add_argument(
359
364
  "--rewrite",
360
365
  default=True,
361
366
  action=BooleanOptionalAction,
362
- help="applies rewrite before exporting",
367
+ help="Applies rewrite before exporting.",
363
368
  )
364
369
  parser.add_argument(
365
370
  "--stop-if-static",
366
371
  default=0,
367
372
  type=int,
368
- help="raises an exception if a dynamic dimension becomes static",
373
+ help="Raises an exception if a dynamic dimension becomes static.",
369
374
  )
370
375
  parser.add_argument(
371
376
  "--trained",
372
377
  default=False,
373
378
  action=BooleanOptionalAction,
374
- help="validate the trained model (requires downloading)",
379
+ help="Validates the trained model (requires downloading).",
380
+ )
381
+ parser.add_argument(
382
+ "--inputs2",
383
+ default=True,
384
+ action=BooleanOptionalAction,
385
+ help="Validates the model on a second set of inputs\n"
386
+ "to check the exported model supports dynamism.",
375
387
  )
376
388
  parser.add_argument(
377
389
  "--runtime",
378
390
  choices=["onnxruntime", "torch", "ref"],
379
391
  default="onnxruntime",
380
- help="onnx runtime to use, onnxruntime by default",
392
+ help="onnx runtime to use, `onnxruntime` by default",
381
393
  )
382
394
  parser.add_argument(
383
395
  "-o",
384
396
  "--dump-folder",
385
- help="if not empty, a folder is created to dumps statistics, "
386
- "exported program, onnx...",
397
+ help="A folder is created to dumps statistics,\nexported program, onnx...",
387
398
  )
388
399
  parser.add_argument(
389
400
  "--drop",
390
- help="drops the following inputs names, it should be a list "
391
- "with comma separated values",
401
+ help="Drops the following inputs names, it should be a list\n"
402
+ "with comma separated values.",
392
403
  )
393
404
  parser.add_argument(
394
405
  "--opset",
@@ -398,24 +409,25 @@ def get_parser_validate() -> ArgumentParser:
398
409
  )
399
410
  parser.add_argument(
400
411
  "--subfolder",
401
- help="subfolder where to find the model and the configuration",
412
+ help="Subfolder where to find the model and the configuration.",
402
413
  )
403
414
  parser.add_argument(
404
415
  "--ortfusiontype",
405
416
  required=False,
406
- help="applies onnxruntime fusion, this parameter should contain the "
407
- "model type or multiple values separated by `|`. `ALL` can be used "
408
- "to run them all",
417
+ help="Applies onnxruntime fusion, this parameter should contain the\n"
418
+ "model type or multiple values separated by `|`. `ALL` can be used\n"
419
+ "to run them all.",
409
420
  )
410
421
  parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
411
- parser.add_argument("--dtype", help="changes dtype if necessary")
412
- parser.add_argument("--device", help="changes the device if necessary")
422
+ parser.add_argument("--dtype", help="Changes dtype if necessary.")
423
+ parser.add_argument("--device", help="Changes the device if necessary.")
413
424
  parser.add_argument(
414
425
  "--iop",
415
426
  metavar="KEY=VALUE",
416
427
  nargs="*",
417
- help="Additional input options, use to change the default "
418
- "inputs use to export, example: --iop cls_cache=SlidingWindowCache",
428
+ help="Additional input options, use to change the default"
429
+ "inputs use to export, example:\n --iop cls_cache=SlidingWindowCache"
430
+ "\n --iop cls_cache=StaticCache",
419
431
  action=_ParseDict,
420
432
  )
421
433
  parser.add_argument(
@@ -423,7 +435,8 @@ def get_parser_validate() -> ArgumentParser:
423
435
  metavar="KEY=VALUE",
424
436
  nargs="*",
425
437
  help="Additional model options, use to change some parameters of the model, "
426
- "example: --mop attn_implementation=eager",
438
+ "example:\n --mop attn_implementation=sdpa --mop attn_implementation=eager\n "
439
+ "--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"",
427
440
  action=_ParseDict,
428
441
  )
429
442
  parser.add_argument(
@@ -440,7 +453,7 @@ def get_parser_validate() -> ArgumentParser:
440
453
 
441
454
  def _cmd_validate(argv: List[Any]):
442
455
  from .helpers import string_type
443
- from .torch_models.test_helper import get_inputs_for_task, validate_model
456
+ from .torch_models.validate import get_inputs_for_task, validate_model
444
457
  from .tasks import supported_tasks
445
458
 
446
459
  parser = get_parser_validate()
@@ -492,6 +505,7 @@ def _cmd_validate(argv: List[Any]):
492
505
  runtime=args.runtime,
493
506
  repeat=args.repeat,
494
507
  warmup=args.warmup,
508
+ inputs2=args.inputs2,
495
509
  )
496
510
  print("")
497
511
  print("-- summary --")
@@ -502,11 +516,7 @@ def _cmd_validate(argv: List[Any]):
502
516
  def get_parser_stats() -> ArgumentParser:
503
517
  parser = ArgumentParser(
504
518
  prog="stats",
505
- description=dedent(
506
- """
507
- Prints out statistics on an ONNX model.
508
- """
509
- ),
519
+ description="Prints out statistics on an ONNX model.",
510
520
  epilog="",
511
521
  )
512
522
  parser.add_argument(
@@ -553,8 +563,8 @@ def get_parser_stats() -> ArgumentParser:
553
563
  required=False,
554
564
  default="",
555
565
  type=str,
556
- help="keeps only tensors whose name verifies "
557
- "this regular expression, empty = no filter",
566
+ help="Keeps only tensors whose name verifies "
567
+ "this regular expression, empty = no filter.",
558
568
  )
559
569
  return parser
560
570
 
@@ -606,17 +616,17 @@ def get_main_parser() -> ArgumentParser:
606
616
  formatter_class=RawTextHelpFormatter,
607
617
  epilog=textwrap.dedent(
608
618
  """
609
- Type 'python -m onnx_diagnostic <cmd> --help'
610
- to get help for a specific command.
611
-
612
- config - prints a configuration for a model id
613
- find - find node consuming or producing a result
614
- lighten - makes an onnx model lighter by removing the weights,
615
- unlighten - restores an onnx model produces by the previous experiment
616
- print - prints the model on standard output
617
- validate - validate a model
618
- stats - produces statistics on a model
619
- """
619
+ Type 'python -m onnx_diagnostic <cmd> --help'
620
+ to get help for a specific command.
621
+
622
+ config - prints a configuration for a model id
623
+ find - find node consuming or producing a result
624
+ lighten - makes an onnx model lighter by removing the weights,
625
+ unlighten - restores an onnx model produces by the previous experiment
626
+ print - prints the model on standard output
627
+ validate - validate a model
628
+ stats - produces statistics on a model
629
+ """
620
630
  ),
621
631
  )
622
632
  parser.add_argument(
onnx_diagnostic/doc.py CHANGED
@@ -2,6 +2,28 @@ from typing import Optional
2
2
  import numpy as np
3
3
 
4
4
 
5
+ def get_latest_pypi_version(package_name="onnx-diagnostic") -> str:
6
+ """Returns the latest published version."""
7
+
8
+ import requests
9
+
10
+ url = f"https://pypi.org/pypi/{package_name}/json"
11
+ response = requests.get(url)
12
+
13
+ assert response.status_code == 200, f"Unable to retrieve the version response={response}"
14
+ data = response.json()
15
+ version = data["info"]["version"]
16
+ return version
17
+
18
+
19
+ def update_version_package(version: str, package_name="onnx-diagnostic") -> str:
20
+ "Adds dev if the major version is different from the latest published one."
21
+ released = get_latest_pypi_version(package_name)
22
+ shorten_r = ".".join(released.split(".")[:2])
23
+ shorten_v = ".".join(version.split(".")[:2])
24
+ return version if shorten_r == shorten_v else f"{shorten_v}.dev"
25
+
26
+
5
27
  def reset_torch_transformers(gallery_conf, fname):
6
28
  "Resets torch dynamo for :epkg:`sphinx-gallery`."
7
29
  import matplotlib.pyplot as plt
@@ -1014,7 +1014,7 @@ class ExtTestCase(unittest.TestCase):
1014
1014
  msg_ = "\n".join(excs)
1015
1015
  msg = f"{msg}\n{msg_}" if msg else msg_
1016
1016
  raise AssertionError(f"Found {len(excs)} discrepancies\n{msg}")
1017
- elif expected.__class__.__name__ == "DynamicCache":
1017
+ elif expected.__class__.__name__ in ("DynamicCache", "StaticCache"):
1018
1018
  atts = {"key_cache", "value_cache"}
1019
1019
  self.assertEqualArrayAny(
1020
1020
  {k: expected.__dict__.get(k, None) for k in atts},
@@ -141,6 +141,65 @@ else:
141
141
  return cache
142
142
 
143
143
 
144
+ def make_static_cache(
145
+ key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
146
+ ) -> transformers.cache_utils.DynamicCache:
147
+ """
148
+ Creates an instance of :class:`transformers.cache_utils.StaticCache`.
149
+ :param key_value_pairs: list of pairs of (key, values)
150
+ :return: :class:`transformers.cache_utils.StaticCache`
151
+
152
+ Example:
153
+
154
+ .. runpython::
155
+ :showcode:
156
+
157
+ import torch
158
+ from onnx_diagnostic.helpers import string_type
159
+ from onnx_diagnostic.helpers.cache_helper import make_static_cache
160
+
161
+ n_layers = 2
162
+ bsize, nheads, slen, dim = 2, 4, 3, 7
163
+
164
+ past_key_values = make_static_cache(
165
+ [
166
+ (
167
+ torch.randn(bsize, nheads, slen, dim),
168
+ torch.randn(bsize, nheads, slen, dim),
169
+ )
170
+ for i in range(n_layers)
171
+ ]
172
+ )
173
+ print(string_type(past_key_values, with_shape=True))
174
+ """
175
+
176
+ class _config:
177
+ def __init__(self):
178
+ self.head_dim = key_value_pairs[0][0].shape[-1]
179
+ self.num_attention_heads = key_value_pairs[0][0].shape[1]
180
+ self.num_hidden_layers = len(key_value_pairs)
181
+
182
+ cache = transformers.cache_utils.StaticCache(
183
+ _config(),
184
+ max_batch_size=key_value_pairs[0][0].shape[0],
185
+ device=key_value_pairs[0][0].device,
186
+ dtype=key_value_pairs[0][0].dtype,
187
+ max_cache_len=key_value_pairs[0][0].shape[2],
188
+ )
189
+ for i in range(len(key_value_pairs)):
190
+ assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, (
191
+ f"Shape mismatch, expected {cache.key_cache[i].shape}, "
192
+ f"got {key_value_pairs[i][0].shape}"
193
+ )
194
+ cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
195
+ assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, (
196
+ f"Shape mismatch, expected {cache.value_cache[i].shape}, "
197
+ f"got {key_value_pairs[i][1].shape}"
198
+ )
199
+ cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
200
+ return cache
201
+
202
+
144
203
  def make_encoder_decoder_cache(
145
204
  self_attention_cache: transformers.cache_utils.DynamicCache,
146
205
  cross_attention_cache: transformers.cache_utils.DynamicCache,
@@ -34,10 +34,14 @@ def update_config(config: Any, mkwargs: Dict[str, Any]):
34
34
  config._attn_implementation_autoset = False
35
35
  continue
36
36
  if isinstance(v, dict):
37
- assert hasattr(
38
- config, k
39
- ), f"missing attribute {k!r} in config={config}, cannot update it with {v}"
40
- update_config(getattr(config, k), v)
37
+ if not hasattr(config, k) or getattr(config, k) is None:
38
+ setattr(config, k, v)
39
+ continue
40
+ existing = getattr(config, k)
41
+ if type(existing) is dict:
42
+ existing.update(v)
43
+ else:
44
+ update_config(getattr(config, k), v)
41
45
  continue
42
46
  setattr(config, k, v)
43
47
 
@@ -558,7 +558,7 @@ def string_type(
558
558
  print(f"[string_type] CACHE1:{type(obj)}")
559
559
  return f"MambaCache(conv_states={c}, ssm_states={d})"
560
560
 
561
- if obj.__class__.__name__ in ("DynamicCache", "SlidingWindowCache"):
561
+ if obj.__class__.__name__ in {"DynamicCache", "SlidingWindowCache", "StaticCache"}:
562
562
  kc = string_type(
563
563
  obj.key_cache,
564
564
  with_shape=with_shape,
@@ -857,7 +857,7 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any:
857
857
  return flatten_object(list(x.values()), drop_keys=drop_keys)
858
858
  return flatten_object(list(x.items()), drop_keys=drop_keys)
859
859
 
860
- if x.__class__.__name__ == "DynamicCache":
860
+ if x.__class__.__name__ in {"DynamicCache", "StaticCache"}:
861
861
  res = flatten_object(x.key_cache) + flatten_object(x.value_cache)
862
862
  return tuple(res)
863
863
  if x.__class__.__name__ == "EncoderDecoderCache":
@@ -1424,10 +1424,37 @@ def max_diff(
1424
1424
  f"level={level}"
1425
1425
  )
1426
1426
 
1427
+ if expected.__class__.__name__ == "StaticCache":
1428
+ if got.__class__.__name__ == "StaticCache":
1429
+ if verbose >= 6:
1430
+ print(f"[max_diff] StaticCache: {string_type(expected)} ? {string_type(got)}")
1431
+ return max_diff(
1432
+ [expected.key_cache, expected.value_cache],
1433
+ [got.key_cache, got.value_cache],
1434
+ verbose=verbose,
1435
+ hist=hist,
1436
+ )
1437
+ if isinstance(got, tuple) and len(got) == 2:
1438
+ return max_diff(
1439
+ [expected.key_cache, expected.value_cache],
1440
+ [got[0], got[1]],
1441
+ debug_info=_debug(expected.__class__.__name__),
1442
+ **_dkws,
1443
+ )
1444
+ raise AssertionError(
1445
+ f"StaticCache not fully implemented with classes "
1446
+ f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
1447
+ f"and expected={string_type(expected)}, got={string_type(got)},\n"
1448
+ f"level={level}"
1449
+ )
1450
+
1427
1451
  if expected.__class__.__name__ == "SlidingWindowCache":
1428
1452
  if got.__class__.__name__ == "SlidingWindowCache":
1429
1453
  if verbose >= 6:
1430
- print(f"[max_diff] DynamicCache: {string_type(expected)} ? {string_type(got)}")
1454
+ print(
1455
+ f"[max_diff] SlidingWindowCache: "
1456
+ f"{string_type(expected)} ? {string_type(got)}"
1457
+ )
1431
1458
  return max_diff(
1432
1459
  [expected.key_cache, expected.value_cache],
1433
1460
  [got.key_cache, got.value_cache],