onnx-diagnostic 0.6.3__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +87 -77
- onnx_diagnostic/doc.py +22 -0
- onnx_diagnostic/ext_test_case.py +1 -1
- onnx_diagnostic/helpers/cache_helper.py +59 -0
- onnx_diagnostic/helpers/config_helper.py +8 -4
- onnx_diagnostic/helpers/helper.py +30 -3
- onnx_diagnostic/helpers/log_helper.py +585 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +4 -1
- onnx_diagnostic/helpers/model_builder_helper.py +54 -73
- onnx_diagnostic/helpers/torch_helper.py +18 -2
- onnx_diagnostic/reference/__init__.py +1 -0
- onnx_diagnostic/reference/ort_evaluator.py +29 -4
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +21 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +3 -0
- onnx_diagnostic/tasks/feature_extraction.py +3 -0
- onnx_diagnostic/tasks/fill_mask.py +3 -0
- onnx_diagnostic/tasks/image_classification.py +7 -1
- onnx_diagnostic/tasks/image_text_to_text.py +3 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +3 -0
- onnx_diagnostic/tasks/object_detection.py +3 -0
- onnx_diagnostic/tasks/sentence_similarity.py +3 -0
- onnx_diagnostic/tasks/summarization.py +3 -0
- onnx_diagnostic/tasks/text2text_generation.py +3 -0
- onnx_diagnostic/tasks/text_classification.py +3 -0
- onnx_diagnostic/tasks/text_generation.py +90 -43
- onnx_diagnostic/tasks/zero_shot_image_classification.py +3 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +78 -25
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +37 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +365 -17
- onnx_diagnostic/torch_models/hghub/hub_api.py +20 -4
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +209 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +3 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +23 -50
- onnx_diagnostic/torch_models/{test_helper.py → validate.py} +158 -103
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/RECORD +41 -39
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED
|
@@ -5,19 +5,18 @@ import re
|
|
|
5
5
|
import sys
|
|
6
6
|
import textwrap
|
|
7
7
|
import onnx
|
|
8
|
-
from typing import Any, List, Optional
|
|
8
|
+
from typing import Any, Dict, List, Optional, Union
|
|
9
9
|
from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction
|
|
10
|
-
from textwrap import dedent
|
|
11
10
|
|
|
12
11
|
|
|
13
12
|
def get_parser_lighten() -> ArgumentParser:
|
|
14
13
|
parser = ArgumentParser(
|
|
15
14
|
prog="lighten",
|
|
16
|
-
description=dedent(
|
|
15
|
+
description=textwrap.dedent(
|
|
16
|
+
"""
|
|
17
|
+
Removes the weights from a heavy model, stores statistics to restore
|
|
18
|
+
random weights.
|
|
17
19
|
"""
|
|
18
|
-
Removes the weights from a heavy model, stores statistics to restore
|
|
19
|
-
random weights.
|
|
20
|
-
"""
|
|
21
20
|
),
|
|
22
21
|
epilog="This is mostly used to write unit tests without adding "
|
|
23
22
|
"a big onnx file to the repository.",
|
|
@@ -70,11 +69,11 @@ def _cmd_lighten(argv: List[Any]):
|
|
|
70
69
|
def get_parser_unlighten() -> ArgumentParser:
|
|
71
70
|
parser = ArgumentParser(
|
|
72
71
|
prog="unlighten",
|
|
73
|
-
description=dedent(
|
|
72
|
+
description=textwrap.dedent(
|
|
73
|
+
"""
|
|
74
|
+
Restores random weights for a model reduces with command lighten,
|
|
75
|
+
the command expects to find a file nearby with extension '.stats'.
|
|
74
76
|
"""
|
|
75
|
-
Restores random weights for a model reduces with command lighten,
|
|
76
|
-
the command expects to find a file nearby with extension '.stats'.
|
|
77
|
-
"""
|
|
78
77
|
),
|
|
79
78
|
epilog="This is mostly used to write unit tests without adding "
|
|
80
79
|
"a big onnx file to the repository.",
|
|
@@ -120,11 +119,7 @@ def _cmd_unlighten(argv: List[Any]):
|
|
|
120
119
|
def get_parser_print() -> ArgumentParser:
|
|
121
120
|
parser = ArgumentParser(
|
|
122
121
|
prog="print",
|
|
123
|
-
description=
|
|
124
|
-
"""
|
|
125
|
-
Prints the model on the standard output.
|
|
126
|
-
"""
|
|
127
|
-
),
|
|
122
|
+
description="Prints the model on the standard output.",
|
|
128
123
|
epilog="To show a model.",
|
|
129
124
|
formatter_class=RawTextHelpFormatter,
|
|
130
125
|
)
|
|
@@ -171,11 +166,11 @@ def _cmd_print(argv: List[Any]):
|
|
|
171
166
|
def get_parser_find() -> ArgumentParser:
|
|
172
167
|
parser = ArgumentParser(
|
|
173
168
|
prog="find",
|
|
174
|
-
description=dedent(
|
|
169
|
+
description=textwrap.dedent(
|
|
170
|
+
"""
|
|
171
|
+
Look into a model and search for a set of names,
|
|
172
|
+
tells which node is consuming or producing it.
|
|
175
173
|
"""
|
|
176
|
-
Look into a model and search for a set of names,
|
|
177
|
-
tells which node is consuming or producing it.
|
|
178
|
-
"""
|
|
179
174
|
),
|
|
180
175
|
epilog="Enables Some quick validation.",
|
|
181
176
|
)
|
|
@@ -191,8 +186,8 @@ def get_parser_find() -> ArgumentParser:
|
|
|
191
186
|
"--names",
|
|
192
187
|
type=str,
|
|
193
188
|
required=False,
|
|
194
|
-
help="
|
|
195
|
-
"search for shadowing names",
|
|
189
|
+
help="Names to look at comma separated values, if 'SHADOW', "
|
|
190
|
+
"search for shadowing names.",
|
|
196
191
|
)
|
|
197
192
|
parser.add_argument(
|
|
198
193
|
"-v",
|
|
@@ -206,7 +201,7 @@ def get_parser_find() -> ArgumentParser:
|
|
|
206
201
|
"--v2",
|
|
207
202
|
default=False,
|
|
208
203
|
action=BooleanOptionalAction,
|
|
209
|
-
help="
|
|
204
|
+
help="Uses enumerate_results instead of onnx_find.",
|
|
210
205
|
)
|
|
211
206
|
return parser
|
|
212
207
|
|
|
@@ -235,12 +230,13 @@ def _cmd_find(argv: List[Any]):
|
|
|
235
230
|
def get_parser_config() -> ArgumentParser:
|
|
236
231
|
parser = ArgumentParser(
|
|
237
232
|
prog="config",
|
|
238
|
-
description=dedent(
|
|
233
|
+
description=textwrap.dedent(
|
|
234
|
+
"""
|
|
235
|
+
Prints out a configuration for a model id,
|
|
236
|
+
prints the associated task as well.
|
|
239
237
|
"""
|
|
240
|
-
Prints out a configuration for a model id,
|
|
241
|
-
prints the associated task as well.
|
|
242
|
-
"""
|
|
243
238
|
),
|
|
239
|
+
formatter_class=RawTextHelpFormatter,
|
|
244
240
|
epilog="",
|
|
245
241
|
)
|
|
246
242
|
parser.add_argument(
|
|
@@ -248,29 +244,29 @@ def get_parser_config() -> ArgumentParser:
|
|
|
248
244
|
"--mid",
|
|
249
245
|
type=str,
|
|
250
246
|
required=True,
|
|
251
|
-
help="model id, usually
|
|
247
|
+
help="model id, usually `<author>/<name>`",
|
|
252
248
|
)
|
|
253
249
|
parser.add_argument(
|
|
254
250
|
"-t",
|
|
255
251
|
"--task",
|
|
256
252
|
default=False,
|
|
257
253
|
action=BooleanOptionalAction,
|
|
258
|
-
help="
|
|
254
|
+
help="Displays the task as well.",
|
|
259
255
|
)
|
|
260
256
|
parser.add_argument(
|
|
261
257
|
"-c",
|
|
262
258
|
"--cached",
|
|
263
259
|
default=True,
|
|
264
260
|
action=BooleanOptionalAction,
|
|
265
|
-
help="
|
|
266
|
-
"mostly for unit test purposes",
|
|
261
|
+
help="Uses cached configuration, only available for some of them,\n"
|
|
262
|
+
"mostly for unit test purposes.",
|
|
267
263
|
)
|
|
268
264
|
parser.add_argument(
|
|
269
265
|
"--mop",
|
|
270
266
|
metavar="KEY=VALUE",
|
|
271
267
|
nargs="*",
|
|
272
268
|
help="Additional model options, use to change some parameters of the model, "
|
|
273
|
-
"example
|
|
269
|
+
"example:\n --mop attn_implementation=sdpa or --mop attn_implementation=eager",
|
|
274
270
|
action=_ParseDict,
|
|
275
271
|
)
|
|
276
272
|
return parser
|
|
@@ -291,6 +287,14 @@ def _cmd_config(argv: List[Any]):
|
|
|
291
287
|
print(f"task: {task_from_id(args.mid)}")
|
|
292
288
|
|
|
293
289
|
|
|
290
|
+
def _parse_json(value: str) -> Union[str, Dict[str, Any]]:
|
|
291
|
+
assert isinstance(value, str), f"value should be string but value={value!r}"
|
|
292
|
+
if value and value[0] == "{" and value[-1] == "}":
|
|
293
|
+
# a dictionary
|
|
294
|
+
return json.loads(value.replace("'", '"'))
|
|
295
|
+
return value
|
|
296
|
+
|
|
297
|
+
|
|
294
298
|
class _ParseDict(argparse.Action):
|
|
295
299
|
def __call__(self, parser, namespace, values, option_string=None):
|
|
296
300
|
d = getattr(namespace, self.dest) or {}
|
|
@@ -314,22 +318,23 @@ class _ParseDict(argparse.Action):
|
|
|
314
318
|
continue
|
|
315
319
|
except (TypeError, ValueError):
|
|
316
320
|
pass
|
|
317
|
-
d[key] = value
|
|
321
|
+
d[key] = _parse_json(value)
|
|
318
322
|
|
|
319
323
|
setattr(namespace, self.dest, d)
|
|
320
324
|
|
|
321
325
|
|
|
322
326
|
def get_parser_validate() -> ArgumentParser:
|
|
323
327
|
parser = ArgumentParser(
|
|
324
|
-
prog="
|
|
325
|
-
description=dedent(
|
|
328
|
+
prog="validate",
|
|
329
|
+
description=textwrap.dedent(
|
|
330
|
+
"""
|
|
331
|
+
Prints out dummy inputs for a particular task or a model id.
|
|
332
|
+
If both mid and task are empty, the command line displays the list
|
|
333
|
+
of supported tasks.
|
|
326
334
|
"""
|
|
327
|
-
Prints out dummy inputs for a particular task or a model id.
|
|
328
|
-
If both mid and task are empty, the command line displays the list
|
|
329
|
-
of supported tasks.
|
|
330
|
-
"""
|
|
331
335
|
),
|
|
332
336
|
epilog="If the model id is specified, one untrained version of it is instantiated.",
|
|
337
|
+
formatter_class=RawTextHelpFormatter,
|
|
333
338
|
)
|
|
334
339
|
parser.add_argument("-m", "--mid", type=str, help="model id, usually <author>/<name>")
|
|
335
340
|
parser.add_argument("-t", "--task", default=None, help="force the task to use")
|
|
@@ -340,55 +345,61 @@ def get_parser_validate() -> ArgumentParser:
|
|
|
340
345
|
"--run",
|
|
341
346
|
default=False,
|
|
342
347
|
action=BooleanOptionalAction,
|
|
343
|
-
help="
|
|
348
|
+
help="Runs the model to check it runs.",
|
|
344
349
|
)
|
|
345
350
|
parser.add_argument(
|
|
346
351
|
"-q",
|
|
347
352
|
"--quiet",
|
|
348
353
|
default=False,
|
|
349
354
|
action=BooleanOptionalAction,
|
|
350
|
-
help="
|
|
355
|
+
help="Catches exception, reports them in the summary.",
|
|
351
356
|
)
|
|
352
357
|
parser.add_argument(
|
|
353
358
|
"--patch",
|
|
354
359
|
default=True,
|
|
355
360
|
action=BooleanOptionalAction,
|
|
356
|
-
help="
|
|
361
|
+
help="Applies patches before exporting.",
|
|
357
362
|
)
|
|
358
363
|
parser.add_argument(
|
|
359
364
|
"--rewrite",
|
|
360
365
|
default=True,
|
|
361
366
|
action=BooleanOptionalAction,
|
|
362
|
-
help="
|
|
367
|
+
help="Applies rewrite before exporting.",
|
|
363
368
|
)
|
|
364
369
|
parser.add_argument(
|
|
365
370
|
"--stop-if-static",
|
|
366
371
|
default=0,
|
|
367
372
|
type=int,
|
|
368
|
-
help="
|
|
373
|
+
help="Raises an exception if a dynamic dimension becomes static.",
|
|
369
374
|
)
|
|
370
375
|
parser.add_argument(
|
|
371
376
|
"--trained",
|
|
372
377
|
default=False,
|
|
373
378
|
action=BooleanOptionalAction,
|
|
374
|
-
help="
|
|
379
|
+
help="Validates the trained model (requires downloading).",
|
|
380
|
+
)
|
|
381
|
+
parser.add_argument(
|
|
382
|
+
"--inputs2",
|
|
383
|
+
default=True,
|
|
384
|
+
action=BooleanOptionalAction,
|
|
385
|
+
help="Validates the model on a second set of inputs\n"
|
|
386
|
+
"to check the exported model supports dynamism.",
|
|
375
387
|
)
|
|
376
388
|
parser.add_argument(
|
|
377
389
|
"--runtime",
|
|
378
390
|
choices=["onnxruntime", "torch", "ref"],
|
|
379
391
|
default="onnxruntime",
|
|
380
|
-
help="onnx runtime to use, onnxruntime by default",
|
|
392
|
+
help="onnx runtime to use, `onnxruntime` by default",
|
|
381
393
|
)
|
|
382
394
|
parser.add_argument(
|
|
383
395
|
"-o",
|
|
384
396
|
"--dump-folder",
|
|
385
|
-
help="
|
|
386
|
-
"exported program, onnx...",
|
|
397
|
+
help="A folder is created to dumps statistics,\nexported program, onnx...",
|
|
387
398
|
)
|
|
388
399
|
parser.add_argument(
|
|
389
400
|
"--drop",
|
|
390
|
-
help="
|
|
391
|
-
"with comma separated values",
|
|
401
|
+
help="Drops the following inputs names, it should be a list\n"
|
|
402
|
+
"with comma separated values.",
|
|
392
403
|
)
|
|
393
404
|
parser.add_argument(
|
|
394
405
|
"--opset",
|
|
@@ -398,24 +409,25 @@ def get_parser_validate() -> ArgumentParser:
|
|
|
398
409
|
)
|
|
399
410
|
parser.add_argument(
|
|
400
411
|
"--subfolder",
|
|
401
|
-
help="
|
|
412
|
+
help="Subfolder where to find the model and the configuration.",
|
|
402
413
|
)
|
|
403
414
|
parser.add_argument(
|
|
404
415
|
"--ortfusiontype",
|
|
405
416
|
required=False,
|
|
406
|
-
help="
|
|
407
|
-
"model type or multiple values separated by `|`. `ALL` can be used
|
|
408
|
-
"to run them all",
|
|
417
|
+
help="Applies onnxruntime fusion, this parameter should contain the\n"
|
|
418
|
+
"model type or multiple values separated by `|`. `ALL` can be used\n"
|
|
419
|
+
"to run them all.",
|
|
409
420
|
)
|
|
410
421
|
parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
|
|
411
|
-
parser.add_argument("--dtype", help="
|
|
412
|
-
parser.add_argument("--device", help="
|
|
422
|
+
parser.add_argument("--dtype", help="Changes dtype if necessary.")
|
|
423
|
+
parser.add_argument("--device", help="Changes the device if necessary.")
|
|
413
424
|
parser.add_argument(
|
|
414
425
|
"--iop",
|
|
415
426
|
metavar="KEY=VALUE",
|
|
416
427
|
nargs="*",
|
|
417
|
-
help="Additional input options, use to change the default
|
|
418
|
-
"inputs use to export, example
|
|
428
|
+
help="Additional input options, use to change the default"
|
|
429
|
+
"inputs use to export, example:\n --iop cls_cache=SlidingWindowCache"
|
|
430
|
+
"\n --iop cls_cache=StaticCache",
|
|
419
431
|
action=_ParseDict,
|
|
420
432
|
)
|
|
421
433
|
parser.add_argument(
|
|
@@ -423,7 +435,8 @@ def get_parser_validate() -> ArgumentParser:
|
|
|
423
435
|
metavar="KEY=VALUE",
|
|
424
436
|
nargs="*",
|
|
425
437
|
help="Additional model options, use to change some parameters of the model, "
|
|
426
|
-
"example
|
|
438
|
+
"example:\n --mop attn_implementation=sdpa --mop attn_implementation=eager\n "
|
|
439
|
+
"--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"",
|
|
427
440
|
action=_ParseDict,
|
|
428
441
|
)
|
|
429
442
|
parser.add_argument(
|
|
@@ -440,7 +453,7 @@ def get_parser_validate() -> ArgumentParser:
|
|
|
440
453
|
|
|
441
454
|
def _cmd_validate(argv: List[Any]):
|
|
442
455
|
from .helpers import string_type
|
|
443
|
-
from .torch_models.
|
|
456
|
+
from .torch_models.validate import get_inputs_for_task, validate_model
|
|
444
457
|
from .tasks import supported_tasks
|
|
445
458
|
|
|
446
459
|
parser = get_parser_validate()
|
|
@@ -492,6 +505,7 @@ def _cmd_validate(argv: List[Any]):
|
|
|
492
505
|
runtime=args.runtime,
|
|
493
506
|
repeat=args.repeat,
|
|
494
507
|
warmup=args.warmup,
|
|
508
|
+
inputs2=args.inputs2,
|
|
495
509
|
)
|
|
496
510
|
print("")
|
|
497
511
|
print("-- summary --")
|
|
@@ -502,11 +516,7 @@ def _cmd_validate(argv: List[Any]):
|
|
|
502
516
|
def get_parser_stats() -> ArgumentParser:
|
|
503
517
|
parser = ArgumentParser(
|
|
504
518
|
prog="stats",
|
|
505
|
-
description=
|
|
506
|
-
"""
|
|
507
|
-
Prints out statistics on an ONNX model.
|
|
508
|
-
"""
|
|
509
|
-
),
|
|
519
|
+
description="Prints out statistics on an ONNX model.",
|
|
510
520
|
epilog="",
|
|
511
521
|
)
|
|
512
522
|
parser.add_argument(
|
|
@@ -553,8 +563,8 @@ def get_parser_stats() -> ArgumentParser:
|
|
|
553
563
|
required=False,
|
|
554
564
|
default="",
|
|
555
565
|
type=str,
|
|
556
|
-
help="
|
|
557
|
-
"this regular expression, empty = no filter",
|
|
566
|
+
help="Keeps only tensors whose name verifies "
|
|
567
|
+
"this regular expression, empty = no filter.",
|
|
558
568
|
)
|
|
559
569
|
return parser
|
|
560
570
|
|
|
@@ -606,17 +616,17 @@ def get_main_parser() -> ArgumentParser:
|
|
|
606
616
|
formatter_class=RawTextHelpFormatter,
|
|
607
617
|
epilog=textwrap.dedent(
|
|
608
618
|
"""
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
619
|
+
Type 'python -m onnx_diagnostic <cmd> --help'
|
|
620
|
+
to get help for a specific command.
|
|
621
|
+
|
|
622
|
+
config - prints a configuration for a model id
|
|
623
|
+
find - find node consuming or producing a result
|
|
624
|
+
lighten - makes an onnx model lighter by removing the weights,
|
|
625
|
+
unlighten - restores an onnx model produces by the previous experiment
|
|
626
|
+
print - prints the model on standard output
|
|
627
|
+
validate - validate a model
|
|
628
|
+
stats - produces statistics on a model
|
|
629
|
+
"""
|
|
620
630
|
),
|
|
621
631
|
)
|
|
622
632
|
parser.add_argument(
|
onnx_diagnostic/doc.py
CHANGED
|
@@ -2,6 +2,28 @@ from typing import Optional
|
|
|
2
2
|
import numpy as np
|
|
3
3
|
|
|
4
4
|
|
|
5
|
+
def get_latest_pypi_version(package_name="onnx-diagnostic", timeout: float = 10.0) -> str:
    """
    Returns the latest published version.

    :param package_name: name of the package to look up on PyPI
    :param timeout: number of seconds given to the HTTP request,
        without it a stalled connection would block forever
    :return: the value stored in ``info.version`` of the JSON document
        served at ``https://pypi.org/pypi/<package_name>/json``
    :raises AssertionError: if the HTTP status code is not 200
    """
    # requests is imported here so that the module can be imported without it.
    import requests

    url = f"https://pypi.org/pypi/{package_name}/json"
    response = requests.get(url, timeout=timeout)

    assert response.status_code == 200, f"Unable to retrieve the version response={response}"
    return response.json()["info"]["version"]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def update_version_package(version: str, package_name="onnx-diagnostic") -> str:
    """
    Marks *version* as a development version when it is ahead of PyPI.

    Only the first two dot-separated components of both versions are compared.

    :param version: version to check
    :param package_name: package whose latest published version is fetched
    :return: *version* unchanged if its first two components are the same
        as the published ones, ``<first two components>.dev`` otherwise
    """
    latest = get_latest_pypi_version(package_name)
    prefix_latest, prefix_version = (".".join(v.split(".")[:2]) for v in (latest, version))
    if prefix_latest == prefix_version:
        return version
    return f"{prefix_version}.dev"
|
|
25
|
+
|
|
26
|
+
|
|
5
27
|
def reset_torch_transformers(gallery_conf, fname):
|
|
6
28
|
"Resets torch dynamo for :epkg:`sphinx-gallery`."
|
|
7
29
|
import matplotlib.pyplot as plt
|
onnx_diagnostic/ext_test_case.py
CHANGED
|
@@ -1014,7 +1014,7 @@ class ExtTestCase(unittest.TestCase):
|
|
|
1014
1014
|
msg_ = "\n".join(excs)
|
|
1015
1015
|
msg = f"{msg}\n{msg_}" if msg else msg_
|
|
1016
1016
|
raise AssertionError(f"Found {len(excs)} discrepancies\n{msg}")
|
|
1017
|
-
elif expected.__class__.__name__
|
|
1017
|
+
elif expected.__class__.__name__ in ("DynamicCache", "StaticCache"):
|
|
1018
1018
|
atts = {"key_cache", "value_cache"}
|
|
1019
1019
|
self.assertEqualArrayAny(
|
|
1020
1020
|
{k: expected.__dict__.get(k, None) for k in atts},
|
|
@@ -141,6 +141,65 @@ else:
|
|
|
141
141
|
return cache
|
|
142
142
|
|
|
143
143
|
|
|
144
|
+
def make_static_cache(
    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
) -> "transformers.cache_utils.StaticCache":
    """
    Creates an instance of :class:`transformers.cache_utils.StaticCache`.

    The batch size, number of heads, cache length, head dimension, device
    and dtype are all taken from the first key tensor, the number of layers
    is the number of pairs.

    :param key_value_pairs: list of pairs of (key, values), the code reads
        dimension 0 as the batch size, dimension 1 as the number of heads,
        dimension 2 as the cache length and the last one as the head dimension
    :return: :class:`transformers.cache_utils.StaticCache`
    :raises IndexError: if *key_value_pairs* is empty
    :raises AssertionError: if a tensor does not have the shape
        of the buffer allocated by the cache

    Example:

    .. runpython::
        :showcode:

        import torch
        from onnx_diagnostic.helpers import string_type
        from onnx_diagnostic.helpers.cache_helper import make_static_cache

        n_layers = 2
        bsize, nheads, slen, dim = 2, 4, 3, 7

        past_key_values = make_static_cache(
            [
                (
                    torch.randn(bsize, nheads, slen, dim),
                    torch.randn(bsize, nheads, slen, dim),
                )
                for i in range(n_layers)
            ]
        )
        print(string_type(past_key_values, with_shape=True))
    """
    # Every dimension is inferred from the first key tensor.
    first_key = key_value_pairs[0][0]

    class _config:
        # Minimal configuration object given to StaticCache,
        # it only carries the three attributes set below.
        def __init__(self):
            self.head_dim = first_key.shape[-1]
            self.num_attention_heads = first_key.shape[1]
            self.num_hidden_layers = len(key_value_pairs)

    cache = transformers.cache_utils.StaticCache(
        _config(),
        max_batch_size=first_key.shape[0],
        device=first_key.device,
        dtype=first_key.dtype,
        max_cache_len=first_key.shape[2],
    )
    for i, pair in enumerate(key_value_pairs):
        key, value = pair[0], pair[1]
        assert cache.key_cache[i].shape == key.shape, (
            f"Shape mismatch, expected {cache.key_cache[i].shape}, "
            f"got {key.shape}"
        )
        # In-place copy into the buffer already held by the cache.
        cache.key_cache[i][:, :, :, :] = key
        assert cache.value_cache[i].shape == value.shape, (
            f"Shape mismatch, expected {cache.value_cache[i].shape}, "
            f"got {value.shape}"
        )
        cache.value_cache[i][:, :, :, :] = value
    return cache
|
|
201
|
+
|
|
202
|
+
|
|
144
203
|
def make_encoder_decoder_cache(
|
|
145
204
|
self_attention_cache: transformers.cache_utils.DynamicCache,
|
|
146
205
|
cross_attention_cache: transformers.cache_utils.DynamicCache,
|
|
@@ -34,10 +34,14 @@ def update_config(config: Any, mkwargs: Dict[str, Any]):
|
|
|
34
34
|
config._attn_implementation_autoset = False
|
|
35
35
|
continue
|
|
36
36
|
if isinstance(v, dict):
|
|
37
|
-
|
|
38
|
-
config, k
|
|
39
|
-
|
|
40
|
-
|
|
37
|
+
if not hasattr(config, k) or getattr(config, k) is None:
|
|
38
|
+
setattr(config, k, v)
|
|
39
|
+
continue
|
|
40
|
+
existing = getattr(config, k)
|
|
41
|
+
if type(existing) is dict:
|
|
42
|
+
existing.update(v)
|
|
43
|
+
else:
|
|
44
|
+
update_config(getattr(config, k), v)
|
|
41
45
|
continue
|
|
42
46
|
setattr(config, k, v)
|
|
43
47
|
|
|
@@ -558,7 +558,7 @@ def string_type(
|
|
|
558
558
|
print(f"[string_type] CACHE1:{type(obj)}")
|
|
559
559
|
return f"MambaCache(conv_states={c}, ssm_states={d})"
|
|
560
560
|
|
|
561
|
-
if obj.__class__.__name__ in
|
|
561
|
+
if obj.__class__.__name__ in {"DynamicCache", "SlidingWindowCache", "StaticCache"}:
|
|
562
562
|
kc = string_type(
|
|
563
563
|
obj.key_cache,
|
|
564
564
|
with_shape=with_shape,
|
|
@@ -857,7 +857,7 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any:
|
|
|
857
857
|
return flatten_object(list(x.values()), drop_keys=drop_keys)
|
|
858
858
|
return flatten_object(list(x.items()), drop_keys=drop_keys)
|
|
859
859
|
|
|
860
|
-
if x.__class__.__name__
|
|
860
|
+
if x.__class__.__name__ in {"DynamicCache", "StaticCache"}:
|
|
861
861
|
res = flatten_object(x.key_cache) + flatten_object(x.value_cache)
|
|
862
862
|
return tuple(res)
|
|
863
863
|
if x.__class__.__name__ == "EncoderDecoderCache":
|
|
@@ -1424,10 +1424,37 @@ def max_diff(
|
|
|
1424
1424
|
f"level={level}"
|
|
1425
1425
|
)
|
|
1426
1426
|
|
|
1427
|
+
if expected.__class__.__name__ == "StaticCache":
|
|
1428
|
+
if got.__class__.__name__ == "StaticCache":
|
|
1429
|
+
if verbose >= 6:
|
|
1430
|
+
print(f"[max_diff] StaticCache: {string_type(expected)} ? {string_type(got)}")
|
|
1431
|
+
return max_diff(
|
|
1432
|
+
[expected.key_cache, expected.value_cache],
|
|
1433
|
+
[got.key_cache, got.value_cache],
|
|
1434
|
+
verbose=verbose,
|
|
1435
|
+
hist=hist,
|
|
1436
|
+
)
|
|
1437
|
+
if isinstance(got, tuple) and len(got) == 2:
|
|
1438
|
+
return max_diff(
|
|
1439
|
+
[expected.key_cache, expected.value_cache],
|
|
1440
|
+
[got[0], got[1]],
|
|
1441
|
+
debug_info=_debug(expected.__class__.__name__),
|
|
1442
|
+
**_dkws,
|
|
1443
|
+
)
|
|
1444
|
+
raise AssertionError(
|
|
1445
|
+
f"StaticCache not fully implemented with classes "
|
|
1446
|
+
f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
|
|
1447
|
+
f"and expected={string_type(expected)}, got={string_type(got)},\n"
|
|
1448
|
+
f"level={level}"
|
|
1449
|
+
)
|
|
1450
|
+
|
|
1427
1451
|
if expected.__class__.__name__ == "SlidingWindowCache":
|
|
1428
1452
|
if got.__class__.__name__ == "SlidingWindowCache":
|
|
1429
1453
|
if verbose >= 6:
|
|
1430
|
-
print(
|
|
1454
|
+
print(
|
|
1455
|
+
f"[max_diff] SlidingWindowCache: "
|
|
1456
|
+
f"{string_type(expected)} ? {string_type(got)}"
|
|
1457
|
+
)
|
|
1431
1458
|
return max_diff(
|
|
1432
1459
|
[expected.key_cache, expected.value_cache],
|
|
1433
1460
|
[got.key_cache, got.value_cache],
|