onnx-diagnostic 0.6.2__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +108 -77
- onnx_diagnostic/doc.py +68 -0
- onnx_diagnostic/ext_test_case.py +1 -1
- onnx_diagnostic/helpers/cache_helper.py +59 -0
- onnx_diagnostic/helpers/config_helper.py +8 -4
- onnx_diagnostic/helpers/doc_helper.py +27 -7
- onnx_diagnostic/helpers/helper.py +30 -3
- onnx_diagnostic/helpers/log_helper.py +585 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +4 -1
- onnx_diagnostic/helpers/model_builder_helper.py +57 -73
- onnx_diagnostic/helpers/onnx_helper.py +291 -7
- onnx_diagnostic/helpers/torch_helper.py +18 -2
- onnx_diagnostic/reference/__init__.py +1 -0
- onnx_diagnostic/reference/ort_evaluator.py +29 -4
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +23 -2
- onnx_diagnostic/tasks/automatic_speech_recognition.py +3 -0
- onnx_diagnostic/tasks/feature_extraction.py +3 -0
- onnx_diagnostic/tasks/fill_mask.py +3 -0
- onnx_diagnostic/tasks/image_classification.py +7 -1
- onnx_diagnostic/tasks/image_text_to_text.py +3 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +3 -0
- onnx_diagnostic/tasks/object_detection.py +3 -0
- onnx_diagnostic/tasks/sentence_similarity.py +3 -0
- onnx_diagnostic/tasks/summarization.py +3 -0
- onnx_diagnostic/tasks/text2text_generation.py +3 -0
- onnx_diagnostic/tasks/text_classification.py +3 -0
- onnx_diagnostic/tasks/text_generation.py +90 -43
- onnx_diagnostic/tasks/zero_shot_image_classification.py +3 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +78 -25
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +37 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +1 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +365 -17
- onnx_diagnostic/torch_models/hghub/hub_api.py +20 -4
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +209 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +3 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +23 -50
- onnx_diagnostic/torch_models/{test_helper.py → validate.py} +174 -114
- {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/RECORD +44 -42
- {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/top_level.txt +0 -0
onnx_diagnostic/_command_lines_parser.py
CHANGED
```diff
@@ -5,19 +5,18 @@ import re
 import sys
 import textwrap
 import onnx
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional, Union
 from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction
-from textwrap import dedent


 def get_parser_lighten() -> ArgumentParser:
     parser = ArgumentParser(
         prog="lighten",
-        description=dedent(
+        description=textwrap.dedent(
+            """
+            Removes the weights from a heavy model, stores statistics to restore
+            random weights.
             """
-            Removes the weights from a heavy model, stores statistics to restore
-            random weights.
-            """
         ),
         epilog="This is mostly used to write unit tests without adding "
         "a big onnx file to the repository.",
@@ -70,11 +69,11 @@ def _cmd_lighten(argv: List[Any]):
 def get_parser_unlighten() -> ArgumentParser:
     parser = ArgumentParser(
         prog="unlighten",
-        description=dedent(
+        description=textwrap.dedent(
+            """
+            Restores random weights for a model reduces with command lighten,
+            the command expects to find a file nearby with extension '.stats'.
             """
-            Restores random weights for a model reduces with command lighten,
-            the command expects to find a file nearby with extension '.stats'.
-            """
         ),
         epilog="This is mostly used to write unit tests without adding "
         "a big onnx file to the repository.",
@@ -120,11 +119,7 @@ def _cmd_unlighten(argv: List[Any]):
 def get_parser_print() -> ArgumentParser:
     parser = ArgumentParser(
         prog="print",
-        description=(
-            """
-            Prints the model on the standard output.
-            """
-        ),
+        description="Prints the model on the standard output.",
         epilog="To show a model.",
         formatter_class=RawTextHelpFormatter,
     )
@@ -171,11 +166,11 @@ def _cmd_print(argv: List[Any]):
 def get_parser_find() -> ArgumentParser:
     parser = ArgumentParser(
         prog="find",
-        description=dedent(
+        description=textwrap.dedent(
+            """
+            Look into a model and search for a set of names,
+            tells which node is consuming or producing it.
             """
-            Look into a model and search for a set of names,
-            tells which node is consuming or producing it.
-            """
         ),
         epilog="Enables Some quick validation.",
     )
@@ -191,35 +186,57 @@ def get_parser_find() -> ArgumentParser:
         "--names",
         type=str,
         required=False,
-        help="
+        help="Names to look at comma separated values, if 'SHADOW', "
+        "search for shadowing names.",
     )
     parser.add_argument(
         "-v",
         "--verbose",
         default=0,
+        type=int,
         required=False,
         help="verbosity",
     )
+    parser.add_argument(
+        "--v2",
+        default=False,
+        action=BooleanOptionalAction,
+        help="Uses enumerate_results instead of onnx_find.",
+    )
     return parser


 def _cmd_find(argv: List[Any]):
-    from .helpers.onnx_helper import onnx_find
+    from .helpers.onnx_helper import onnx_find, enumerate_results, shadowing_names

     parser = get_parser_find()
     args = parser.parse_args(argv[1:])
-
+    if args.names == "SHADOW":
+        onx = onnx.load(args.input, load_external_data=False)
+        s, ps = shadowing_names(onx)[:2]
+        print(f"shadowing names: {s}")
+        print(f"post-shadowing names: {ps}")
+    elif args.v2:
+        onx = onnx.load(args.input, load_external_data=False)
+        res = list(
+            enumerate_results(onx, name=set(args.names.split(",")), verbose=args.verbose)
+        )
+        if not args.verbose:
+            print("\n".join(map(str, res)))
+    else:
+        onnx_find(args.input, verbose=args.verbose, watch=set(args.names.split(",")))


 def get_parser_config() -> ArgumentParser:
     parser = ArgumentParser(
         prog="config",
-        description=dedent(
+        description=textwrap.dedent(
+            """
+            Prints out a configuration for a model id,
+            prints the associated task as well.
             """
-            Prints out a configuration for a model id,
-            prints the associated task as well.
-            """
         ),
+        formatter_class=RawTextHelpFormatter,
         epilog="",
     )
     parser.add_argument(
```
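The `SHADOW` and `--v2` paths above are built on two helpers added to `onnx_helper.py` in this release (its entry in the file list grows by +291 lines). A minimal sketch of driving them directly; only the call shapes visible in this hunk are assumed, `model.onnx` is a placeholder, and the exact type of the items yielded by `enumerate_results` is not shown in the diff:

```python
import onnx
from onnx_diagnostic.helpers.onnx_helper import enumerate_results, shadowing_names

onx = onnx.load("model.onnx", load_external_data=False)

# What `find --names SHADOW` prints: names reused while an earlier result of
# the same name is still alive, plus the post-shadowing set.
s, ps = shadowing_names(onx)[:2]
print("shadowing names:", s)
print("post-shadowing names:", ps)

# What `find --v2 --names hidden_states` walks: the nodes producing or
# consuming each watched name.
for r in enumerate_results(onx, name={"hidden_states"}, verbose=0):
    print(r)
```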
```diff
@@ -227,29 +244,29 @@ def get_parser_config() -> ArgumentParser:
         "--mid",
         type=str,
         required=True,
-        help="model id, usually
+        help="model id, usually `<author>/<name>`",
     )
     parser.add_argument(
         "-t",
         "--task",
         default=False,
         action=BooleanOptionalAction,
-        help="
+        help="Displays the task as well.",
     )
     parser.add_argument(
         "-c",
         "--cached",
         default=True,
         action=BooleanOptionalAction,
-        help="
-        "mostly for unit test purposes",
+        help="Uses cached configuration, only available for some of them,\n"
+        "mostly for unit test purposes.",
     )
     parser.add_argument(
         "--mop",
         metavar="KEY=VALUE",
         nargs="*",
         help="Additional model options, use to change some parameters of the model, "
-        "example
+        "example:\n --mop attn_implementation=sdpa or --mop attn_implementation=eager",
         action=_ParseDict,
     )
     return parser
@@ -270,6 +287,14 @@ def _cmd_config(argv: List[Any]):
     print(f"task: {task_from_id(args.mid)}")


+def _parse_json(value: str) -> Union[str, Dict[str, Any]]:
+    assert isinstance(value, str), f"value should be string but value={value!r}"
+    if value and value[0] == "{" and value[-1] == "}":
+        # a dictionary
+        return json.loads(value.replace("'", '"'))
+    return value
+
+
 class _ParseDict(argparse.Action):
     def __call__(self, parser, namespace, values, option_string=None):
         d = getattr(namespace, self.dest) or {}
@@ -293,22 +318,23 @@ class _ParseDict(argparse.Action):
                 continue
             except (TypeError, ValueError):
                 pass
-            d[key] = value
+            d[key] = _parse_json(value)

         setattr(namespace, self.dest, d)


 def get_parser_validate() -> ArgumentParser:
     parser = ArgumentParser(
-        prog="
-        description=dedent(
+        prog="validate",
+        description=textwrap.dedent(
+            """
+            Prints out dummy inputs for a particular task or a model id.
+            If both mid and task are empty, the command line displays the list
+            of supported tasks.
             """
-            Prints out dummy inputs for a particular task or a model id.
-            If both mid and task are empty, the command line displays the list
-            of supported tasks.
-            """
         ),
         epilog="If the model id is specified, one untrained version of it is instantiated.",
+        formatter_class=RawTextHelpFormatter,
     )
     parser.add_argument("-m", "--mid", type=str, help="model id, usually <author>/<name>")
     parser.add_argument("-t", "--task", default=None, help="force the task to use")
```
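The new `_parse_json` hook means `--mop` and `--iop` values written as dict literals are decoded into dictionaries before they reach the model options. A behavior sketch that copies the function body shown above:

```python
import json

def _parse_json(value):
    # Mirrors the helper above: a value shaped like a dict literal is decoded
    # with json.loads after swapping single quotes for double quotes.
    if value and value[0] == "{" and value[-1] == "}":
        return json.loads(value.replace("'", '"'))
    return value

# e.g. --mop "rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}"
print(_parse_json("{'rope_type': 'dynamic', 'factor': 10.0}"))
# -> {'rope_type': 'dynamic', 'factor': 10.0}
print(_parse_json("sdpa"))  # plain values pass through unchanged
```

The quote swap is a shell convenience, not a general JSON parser: it assumes the value embeds no quotes of its own.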
```diff
@@ -319,55 +345,61 @@ def get_parser_validate() -> ArgumentParser:
         "--run",
         default=False,
         action=BooleanOptionalAction,
-        help="
+        help="Runs the model to check it runs.",
     )
     parser.add_argument(
         "-q",
         "--quiet",
         default=False,
         action=BooleanOptionalAction,
-        help="
+        help="Catches exception, reports them in the summary.",
     )
     parser.add_argument(
         "--patch",
         default=True,
         action=BooleanOptionalAction,
-        help="
+        help="Applies patches before exporting.",
     )
     parser.add_argument(
         "--rewrite",
         default=True,
         action=BooleanOptionalAction,
-        help="
+        help="Applies rewrite before exporting.",
     )
     parser.add_argument(
         "--stop-if-static",
         default=0,
         type=int,
-        help="
+        help="Raises an exception if a dynamic dimension becomes static.",
     )
     parser.add_argument(
         "--trained",
         default=False,
         action=BooleanOptionalAction,
-        help="
+        help="Validates the trained model (requires downloading).",
+    )
+    parser.add_argument(
+        "--inputs2",
+        default=True,
+        action=BooleanOptionalAction,
+        help="Validates the model on a second set of inputs\n"
+        "to check the exported model supports dynamism.",
     )
     parser.add_argument(
         "--runtime",
         choices=["onnxruntime", "torch", "ref"],
         default="onnxruntime",
-        help="onnx runtime to use, onnxruntime by default",
+        help="onnx runtime to use, `onnxruntime` by default",
     )
     parser.add_argument(
         "-o",
         "--dump-folder",
-        help="
-        "exported program, onnx...",
+        help="A folder is created to dumps statistics,\nexported program, onnx...",
     )
     parser.add_argument(
         "--drop",
-        help="
-        "with comma separated values",
+        help="Drops the following inputs names, it should be a list\n"
+        "with comma separated values.",
     )
     parser.add_argument(
         "--opset",
@@ -377,24 +409,25 @@ def get_parser_validate() -> ArgumentParser:
     )
     parser.add_argument(
         "--subfolder",
-        help="
+        help="Subfolder where to find the model and the configuration.",
     )
     parser.add_argument(
         "--ortfusiontype",
         required=False,
-        help="
-        "model type or multiple values separated by `|`. `ALL` can be used
-        "to run them all",
+        help="Applies onnxruntime fusion, this parameter should contain the\n"
+        "model type or multiple values separated by `|`. `ALL` can be used\n"
+        "to run them all.",
     )
     parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
-    parser.add_argument("--dtype", help="
-    parser.add_argument("--device", help="
+    parser.add_argument("--dtype", help="Changes dtype if necessary.")
+    parser.add_argument("--device", help="Changes the device if necessary.")
     parser.add_argument(
         "--iop",
         metavar="KEY=VALUE",
         nargs="*",
-        help="Additional input options, use to change the default
-        "inputs use to export, example
+        help="Additional input options, use to change the default"
+        "inputs use to export, example:\n --iop cls_cache=SlidingWindowCache"
+        "\n --iop cls_cache=StaticCache",
        action=_ParseDict,
    )
    parser.add_argument(
@@ -402,7 +435,8 @@ def get_parser_validate() -> ArgumentParser:
         metavar="KEY=VALUE",
         nargs="*",
         help="Additional model options, use to change some parameters of the model, "
-        "example
+        "example:\n --mop attn_implementation=sdpa --mop attn_implementation=eager\n "
+        "--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"",
         action=_ParseDict,
     )
     parser.add_argument(
@@ -419,7 +453,7 @@ def get_parser_validate() -> ArgumentParser:

 def _cmd_validate(argv: List[Any]):
     from .helpers import string_type
-    from .torch_models.test_helper import get_inputs_for_task, validate_model
+    from .torch_models.validate import get_inputs_for_task, validate_model
     from .tasks import supported_tasks

     parser = get_parser_validate()
@@ -471,6 +505,7 @@ def _cmd_validate(argv: List[Any]):
         runtime=args.runtime,
         repeat=args.repeat,
         warmup=args.warmup,
+        inputs2=args.inputs2,
     )
     print("")
     print("-- summary --")
@@ -481,11 +516,7 @@ def _cmd_validate(argv: List[Any]):
 def get_parser_stats() -> ArgumentParser:
     parser = ArgumentParser(
         prog="stats",
-        description=(
-            """
-            Prints out statistics on an ONNX model.
-            """
-        ),
+        description="Prints out statistics on an ONNX model.",
         epilog="",
     )
     parser.add_argument(
@@ -532,8 +563,8 @@ def get_parser_stats() -> ArgumentParser:
         required=False,
         default="",
         type=str,
-        help="
-        "this regular expression, empty = no filter",
+        help="Keeps only tensors whose name verifies "
+        "this regular expression, empty = no filter.",
     )
     return parser
@@ -585,17 +616,17 @@ def get_main_parser() -> ArgumentParser:
         formatter_class=RawTextHelpFormatter,
         epilog=textwrap.dedent(
             """
-
-
-
-
-
-
-
-
-
-
-
+            Type 'python -m onnx_diagnostic <cmd> --help'
+            to get help for a specific command.
+
+            config - prints a configuration for a model id
+            find - find node consuming or producing a result
+            lighten - makes an onnx model lighter by removing the weights,
+            unlighten - restores an onnx model produces by the previous experiment
+            print - prints the model on standard output
+            validate - validate a model
+            stats - produces statistics on a model
+            """
         ),
     )
     parser.add_argument(
```
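The `validate` command now imports from the renamed module (`torch_models/{test_helper.py → validate.py}` in the file list) and forwards the new `--inputs2` flag. A hedged sketch of the equivalent programmatic call; only the keyword arguments visible in the hunks above are used, the model id as first argument is an assumption, and the full signature of `validate_model` is not part of this diff:

```python
from onnx_diagnostic.torch_models.validate import validate_model  # was .test_helper

result = validate_model(
    "<author>/<name>",      # assumed first parameter: the model id
    runtime="onnxruntime",  # one of onnxruntime / torch / ref, per --runtime
    repeat=2,
    warmup=1,
    inputs2=True,           # new in 0.7.0: re-runs the exported model on a
                            # second input set to check dynamic dimensions
)
```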
onnx_diagnostic/doc.py
CHANGED
```diff
@@ -1,3 +1,29 @@
+from typing import Optional
+import numpy as np
+
+
+def get_latest_pypi_version(package_name="onnx-diagnostic") -> str:
+    """Returns the latest published version."""
+    import requests
+
+    url = f"https://pypi.org/pypi/{package_name}/json"
+    response = requests.get(url)
+
+    assert response.status_code == 200, f"Unable to retrieve the version response={response}"
+    data = response.json()
+    version = data["info"]["version"]
+    return version
+
+
+def update_version_package(version: str, package_name="onnx-diagnostic") -> str:
+    "Adds dev if the major version is different from the latest published one."
+    released = get_latest_pypi_version(package_name)
+    shorten_r = ".".join(released.split(".")[:2])
+    shorten_v = ".".join(version.split(".")[:2])
+    return version if shorten_r == shorten_v else f"{shorten_v}.dev"
+
+
 def reset_torch_transformers(gallery_conf, fname):
     "Resets torch dynamo for :epkg:`sphinx-gallery`."
     import matplotlib.pyplot as plt
```
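Both version helpers are shown in full above, so the following usage sketch only assumes network access to pypi.org; the printed values are illustrative:

```python
from onnx_diagnostic.doc import get_latest_pypi_version, update_version_package

print(get_latest_pypi_version())        # e.g. "0.7.0"
# A ".dev" suffix is appended when the local major.minor differs from the
# latest published one:
print(update_version_package("0.8.0"))  # -> "0.8.dev" while 0.7.x is current
print(update_version_package("0.7.1"))  # -> "0.7.1", same minor as released
```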
```diff
@@ -30,3 +56,45 @@ def plot_legend(
     ax.grid(False)
     ax.set_axis_off()
     return ax
+
+
+def rotate_align(ax, angle=15, align="right"):
+    """Rotates x-label and align them to thr right. Returns ax."""
+    for label in ax.get_xticklabels():
+        label.set_rotation(angle)
+        label.set_horizontalalignment(align)
+    return ax
+
+
+def save_fig(ax, name: str):
+    """Applies ``tight_layout`` and saves the figures. Returns ax."""
+    import matplotlib.pyplot as plt
+
+    plt.tight_layout()
+    fig = ax.get_figure()
+    fig.savefig(name)
+    return ax
+
+
+def title(ax: "plt.axes", title: str) -> "plt.axes":  # noqa: F821
+    "Adds a title to axes and returns them."
+    ax.set_title(title)
+    return ax
+
+
+def plot_histogram(
+    tensor: np.ndarray,
+    ax: Optional["plt.axes"] = None,  # noqa: F821
+    bins: int = 30,
+    color: str = "orange",
+    alpha: float = 0.7,
+) -> "plt.axes":  # noqa: F821
+    "Computes the distribution for a tensor."
+    if ax is None:
+        import matplotlib.pyplot as plt
+
+        ax = plt.gca()
+    ax.cla()
+    ax.hist(tensor, bins=30, color="orange", alpha=0.7)
+    ax.set_yscale("log")
+    return ax
```
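Usage sketch for the new plotting helpers, composed exactly as defined above (matplotlib and numpy required, file name illustrative):

```python
import numpy as np
from onnx_diagnostic.doc import plot_histogram, rotate_align, save_fig, title

# log-scale histogram of a tensor's values
ax = plot_histogram(np.random.randn(10_000))
title(ax, "weight distribution")
rotate_align(ax)          # tilts the x labels by 15 degrees, right-aligned
save_fig(ax, "hist.png")  # tight_layout + savefig
```

Note that in this version `plot_histogram` accepts `bins`, `color`, and `alpha` but its `ax.hist` call hardcodes `bins=30, color="orange", alpha=0.7`, so those parameters have no effect yet.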
onnx_diagnostic/ext_test_case.py
CHANGED
```diff
@@ -1014,7 +1014,7 @@ class ExtTestCase(unittest.TestCase):
             msg_ = "\n".join(excs)
             msg = f"{msg}\n{msg_}" if msg else msg_
             raise AssertionError(f"Found {len(excs)} discrepancies\n{msg}")
-        elif expected.__class__.__name__
+        elif expected.__class__.__name__ in ("DynamicCache", "StaticCache"):
             atts = {"key_cache", "value_cache"}
             self.assertEqualArrayAny(
                 {k: expected.__dict__.get(k, None) for k in atts},
```
onnx_diagnostic/helpers/cache_helper.py
CHANGED

```diff
@@ -141,6 +141,65 @@ else:
         return cache


+def make_static_cache(
+    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+) -> transformers.cache_utils.DynamicCache:
+    """
+    Creates an instance of :class:`transformers.cache_utils.StaticCache`.
+    :param key_value_pairs: list of pairs of (key, values)
+    :return: :class:`transformers.cache_utils.StaticCache`
+
+    Example:
+
+    .. runpython::
+        :showcode:
+
+        import torch
+        from onnx_diagnostic.helpers import string_type
+        from onnx_diagnostic.helpers.cache_helper import make_static_cache
+
+        n_layers = 2
+        bsize, nheads, slen, dim = 2, 4, 3, 7
+
+        past_key_values = make_static_cache(
+            [
+                (
+                    torch.randn(bsize, nheads, slen, dim),
+                    torch.randn(bsize, nheads, slen, dim),
+                )
+                for i in range(n_layers)
+            ]
+        )
+        print(string_type(past_key_values, with_shape=True))
+    """
+
+    class _config:
+        def __init__(self):
+            self.head_dim = key_value_pairs[0][0].shape[-1]
+            self.num_attention_heads = key_value_pairs[0][0].shape[1]
+            self.num_hidden_layers = len(key_value_pairs)
+
+    cache = transformers.cache_utils.StaticCache(
+        _config(),
+        max_batch_size=key_value_pairs[0][0].shape[0],
+        device=key_value_pairs[0][0].device,
+        dtype=key_value_pairs[0][0].dtype,
+        max_cache_len=key_value_pairs[0][0].shape[2],
+    )
+    for i in range(len(key_value_pairs)):
+        assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, (
+            f"Shape mismatch, expected {cache.key_cache[i].shape}, "
+            f"got {key_value_pairs[i][0].shape}"
+        )
+        cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+        assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, (
+            f"Shape mismatch, expected {cache.value_cache[i].shape}, "
+            f"got {key_value_pairs[i][1].shape}"
+        )
+        cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+    return cache
+
+
 def make_encoder_decoder_cache(
     self_attention_cache: transformers.cache_utils.DynamicCache,
     cross_attention_cache: transformers.cache_utils.DynamicCache,
```
onnx_diagnostic/helpers/config_helper.py
CHANGED

```diff
@@ -34,10 +34,14 @@ def update_config(config: Any, mkwargs: Dict[str, Any]):
             config._attn_implementation_autoset = False
             continue
         if isinstance(v, dict):
-
-                config, k
-
-
+            if not hasattr(config, k) or getattr(config, k) is None:
+                setattr(config, k, v)
+                continue
+            existing = getattr(config, k)
+            if type(existing) is dict:
+                existing.update(v)
+            else:
+                update_config(getattr(config, k), v)
             continue
         setattr(config, k, v)
```
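The rewritten branch gives dict values merge semantics: an existing dict attribute is updated in place, a missing or `None` attribute is set outright, and anything else recurses into the sub-config. A small sketch of the new behavior; the config class is a stand-in for a `transformers` configuration:

```python
from onnx_diagnostic.helpers.config_helper import update_config

class _Cfg:  # stand-in for a transformers PretrainedConfig
    rope_scaling = {"rope_type": "default"}

cfg = _Cfg()
update_config(cfg, {"rope_scaling": {"rope_type": "dynamic", "factor": 10.0}})
# The existing dict is updated in place rather than replaced:
assert cfg.rope_scaling == {"rope_type": "dynamic", "factor": 10.0}
```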
onnx_diagnostic/helpers/doc_helper.py
CHANGED

```diff
@@ -1,4 +1,5 @@
-
+import os
+from typing import Dict, List, Optional, Tuple
 import onnx
 import onnx.helper as oh
 import torch
@@ -6,6 +7,17 @@ from ..reference.torch_ops import OpRunKernel, OpRunTensor
 from .torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
 from .ort_session import InferenceSessionForTorch

+_SAVED: List[str] = []
+_SAVE_OPTIMIZED_MODEL_ = int(os.environ.get("DUMP_ONNX", "0"))
+
+
+def _get_model_name(op_name: str, provider: str) -> Optional[str]:
+    if _SAVE_OPTIMIZED_MODEL_:
+        name = f"dump_doc_layer_norm_{provider}_{len(_SAVED)}.onnx"
+        _SAVED.append(name)
+        return name
+    return None
+

 class LayerNormalizationOrt(OpRunKernel):
     "LayerNormalization with onnxruntime"
@@ -13,14 +25,14 @@ class LayerNormalizationOrt(OpRunKernel):
     @classmethod
     def device_dependent(cls) -> bool:
         "Needs device."
-        return
+        return True

     def __init__(
         self,
         node: onnx.NodeProto,
         version=None,
         device: Optional[torch.device] = None,
-        verbose=0,
+        verbose: int = 0,
     ):
         super().__init__(node, version, verbose=verbose)
         self.axis = self.get_attribute_int(node, "axis", -1)
@@ -70,7 +82,11 @@ class LayerNormalizationOrt(OpRunKernel):
         )
         provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
         self._provider = provider
-        return InferenceSessionForTorch(
+        return InferenceSessionForTorch(
+            layer_model,
+            optimized_model_filepath=_get_model_name("layer_norm", provider),
+            providers=[provider],
+        )

     def run(self, x, scale, bias=None):
         itype = torch_dtype_to_onnx_dtype(x.dtype)
@@ -94,14 +110,14 @@ class MatMulOrt(OpRunKernel):
     @classmethod
     def device_dependent(cls) -> bool:
         "Needs device."
-        return
+        return True

     def __init__(
         self,
         node: onnx.NodeProto,
         version=None,
         device: Optional[torch.device] = None,
-        verbose=0,
+        verbose: int = 0,
     ):
         super().__init__(node, version, verbose=verbose)
         self.device = device
@@ -127,7 +143,11 @@ class MatMulOrt(OpRunKernel):
         )
         provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
         self._provider = provider
-        return InferenceSessionForTorch(
+        return InferenceSessionForTorch(
+            model,
+            optimized_model_filepath=_get_model_name("matmul", provider),
+            providers=[provider],
+        )

     def run(self, a, b):
         itype = torch_dtype_to_onnx_dtype(a.dtype)
```
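`_SAVE_OPTIMIZED_MODEL_` is read once at import time, so the `DUMP_ONNX` variable must be set before `doc_helper` is imported. A sketch:

```python
# DUMP_ONNX is read once at module import, so set it first.
import os
os.environ["DUMP_ONNX"] = "1"

from onnx_diagnostic.helpers.doc_helper import LayerNormalizationOrt, MatMulOrt  # noqa: E402

# Any session created by these kernels now also writes its onnxruntime-optimized
# graph to dump_doc_layer_norm_<provider>_<n>.onnx in the working directory.
```

Note that `_get_model_name` ignores its `op_name` argument: the file-name pattern hardcodes `layer_norm`, so MatMul dumps share the same prefix.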