onnx-diagnostic 0.6.2__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +108 -77
  3. onnx_diagnostic/doc.py +68 -0
  4. onnx_diagnostic/ext_test_case.py +1 -1
  5. onnx_diagnostic/helpers/cache_helper.py +59 -0
  6. onnx_diagnostic/helpers/config_helper.py +8 -4
  7. onnx_diagnostic/helpers/doc_helper.py +27 -7
  8. onnx_diagnostic/helpers/helper.py +30 -3
  9. onnx_diagnostic/helpers/log_helper.py +585 -0
  10. onnx_diagnostic/helpers/mini_onnx_builder.py +4 -1
  11. onnx_diagnostic/helpers/model_builder_helper.py +57 -73
  12. onnx_diagnostic/helpers/onnx_helper.py +291 -7
  13. onnx_diagnostic/helpers/torch_helper.py +18 -2
  14. onnx_diagnostic/reference/__init__.py +1 -0
  15. onnx_diagnostic/reference/ort_evaluator.py +29 -4
  16. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  17. onnx_diagnostic/reference/torch_evaluator.py +23 -2
  18. onnx_diagnostic/tasks/automatic_speech_recognition.py +3 -0
  19. onnx_diagnostic/tasks/feature_extraction.py +3 -0
  20. onnx_diagnostic/tasks/fill_mask.py +3 -0
  21. onnx_diagnostic/tasks/image_classification.py +7 -1
  22. onnx_diagnostic/tasks/image_text_to_text.py +3 -0
  23. onnx_diagnostic/tasks/mixture_of_expert.py +3 -0
  24. onnx_diagnostic/tasks/object_detection.py +3 -0
  25. onnx_diagnostic/tasks/sentence_similarity.py +3 -0
  26. onnx_diagnostic/tasks/summarization.py +3 -0
  27. onnx_diagnostic/tasks/text2text_generation.py +3 -0
  28. onnx_diagnostic/tasks/text_classification.py +3 -0
  29. onnx_diagnostic/tasks/text_generation.py +90 -43
  30. onnx_diagnostic/tasks/zero_shot_image_classification.py +3 -0
  31. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +78 -25
  32. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +37 -0
  33. onnx_diagnostic/torch_export_patches/patch_module_helper.py +1 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +365 -17
  35. onnx_diagnostic/torch_models/hghub/hub_api.py +20 -4
  36. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +209 -0
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +3 -0
  38. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +23 -50
  39. onnx_diagnostic/torch_models/{test_helper.py → validate.py} +174 -114
  40. {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/METADATA +2 -2
  41. {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/RECORD +44 -42
  42. {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/WHEEL +0 -0
  43. {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
  44. {onnx_diagnostic-0.6.2.dist-info → onnx_diagnostic-0.7.0.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py CHANGED
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """
 
-__version__ = "0.6.2"
+__version__ = "0.7.0"
 __author__ = "Xavier Dupré"
onnx_diagnostic/_command_lines_parser.py CHANGED
@@ -5,19 +5,18 @@ import re
 import sys
 import textwrap
 import onnx
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional, Union
 from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction
-from textwrap import dedent
 
 
 def get_parser_lighten() -> ArgumentParser:
     parser = ArgumentParser(
         prog="lighten",
-        description=dedent(
+        description=textwrap.dedent(
+            """
+            Removes the weights from a heavy model, stores statistics to restore
+            random weights.
             """
-            Removes the weights from a heavy model, stores statistics to restore
-            random weights.
-            """
         ),
         epilog="This is mostly used to write unit tests without adding "
        "a big onnx file to the repository.",
@@ -70,11 +69,11 @@ def _cmd_lighten(argv: List[Any]):
 def get_parser_unlighten() -> ArgumentParser:
     parser = ArgumentParser(
         prog="unlighten",
-        description=dedent(
+        description=textwrap.dedent(
+            """
+            Restores random weights for a model reduces with command lighten,
+            the command expects to find a file nearby with extension '.stats'.
             """
-            Restores random weights for a model reduces with command lighten,
-            the command expects to find a file nearby with extension '.stats'.
-            """
         ),
         epilog="This is mostly used to write unit tests without adding "
         "a big onnx file to the repository.",
@@ -120,11 +119,7 @@ def _cmd_unlighten(argv: List[Any]):
 def get_parser_print() -> ArgumentParser:
     parser = ArgumentParser(
         prog="print",
-        description=dedent(
-            """
-            Prints the model on the standard output.
-            """
-        ),
+        description="Prints the model on the standard output.",
         epilog="To show a model.",
         formatter_class=RawTextHelpFormatter,
     )
@@ -171,11 +166,11 @@ def _cmd_print(argv: List[Any]):
 def get_parser_find() -> ArgumentParser:
     parser = ArgumentParser(
         prog="find",
-        description=dedent(
+        description=textwrap.dedent(
+            """
+            Look into a model and search for a set of names,
+            tells which node is consuming or producing it.
             """
-            Look into a model and search for a set of names,
-            tells which node is consuming or producing it.
-            """
         ),
         epilog="Enables Some quick validation.",
     )
@@ -191,35 +186,57 @@ def get_parser_find() -> ArgumentParser:
         "--names",
         type=str,
         required=False,
-        help="names to look at comma separated values",
+        help="Names to look at comma separated values, if 'SHADOW', "
+        "search for shadowing names.",
     )
     parser.add_argument(
         "-v",
         "--verbose",
         default=0,
+        type=int,
         required=False,
         help="verbosity",
     )
+    parser.add_argument(
+        "--v2",
+        default=False,
+        action=BooleanOptionalAction,
+        help="Uses enumerate_results instead of onnx_find.",
+    )
     return parser
 
 
 def _cmd_find(argv: List[Any]):
-    from .helpers.onnx_helper import onnx_find
+    from .helpers.onnx_helper import onnx_find, enumerate_results, shadowing_names
 
     parser = get_parser_find()
     args = parser.parse_args(argv[1:])
-    onnx_find(args.input, verbose=args.verbose, watch=set(args.names.split(",")))
+    if args.names == "SHADOW":
+        onx = onnx.load(args.input, load_external_data=False)
+        s, ps = shadowing_names(onx)[:2]
+        print(f"shadowing names: {s}")
+        print(f"post-shadowing names: {ps}")
+    elif args.v2:
+        onx = onnx.load(args.input, load_external_data=False)
+        res = list(
+            enumerate_results(onx, name=set(args.names.split(",")), verbose=args.verbose)
+        )
+        if not args.verbose:
+            print("\n".join(map(str, res)))
+    else:
+        onnx_find(args.input, verbose=args.verbose, watch=set(args.names.split(",")))
 
 
 def get_parser_config() -> ArgumentParser:
     parser = ArgumentParser(
         prog="config",
-        description=dedent(
+        description=textwrap.dedent(
+            """
+            Prints out a configuration for a model id,
+            prints the associated task as well.
             """
-            Prints out a configuration for a model id,
-            prints the associated task as well.
-            """
         ),
+        formatter_class=RawTextHelpFormatter,
         epilog="",
     )
     parser.add_argument(
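
The `find` command gains two modes: `--names SHADOW` reports shadowing names, and `--v2` walks results with `enumerate_results`. The same helpers are importable directly; a minimal sketch, assuming a local `model.onnx`:

    import onnx
    from onnx_diagnostic.helpers.onnx_helper import shadowing_names

    onx = onnx.load("model.onnx", load_external_data=False)  # hypothetical file
    shadow, post_shadow = shadowing_names(onx)[:2]  # same call as in _cmd_find
    print(f"shadowing names: {shadow}")
    print(f"post-shadowing names: {post_shadow}")
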
@@ -227,29 +244,29 @@ def get_parser_config() -> ArgumentParser:
         "--mid",
         type=str,
         required=True,
-        help="model id, usually <author>/<name>",
+        help="model id, usually `<author>/<name>`",
     )
     parser.add_argument(
         "-t",
         "--task",
         default=False,
         action=BooleanOptionalAction,
-        help="displays the task as well",
+        help="Displays the task as well.",
     )
     parser.add_argument(
         "-c",
         "--cached",
         default=True,
         action=BooleanOptionalAction,
-        help="uses cached configuration, only available for some of them, "
-        "mostly for unit test purposes",
+        help="Uses cached configuration, only available for some of them,\n"
+        "mostly for unit test purposes.",
     )
     parser.add_argument(
         "--mop",
         metavar="KEY=VALUE",
         nargs="*",
         help="Additional model options, use to change some parameters of the model, "
-        "example: --mop attn_implementation=eager",
+        "example:\n --mop attn_implementation=sdpa or --mop attn_implementation=eager",
         action=_ParseDict,
     )
     return parser
@@ -270,6 +287,14 @@ def _cmd_config(argv: List[Any]):
     print(f"task: {task_from_id(args.mid)}")
 
 
+def _parse_json(value: str) -> Union[str, Dict[str, Any]]:
+    assert isinstance(value, str), f"value should be string but value={value!r}"
+    if value and value[0] == "{" and value[-1] == "}":
+        # a dictionary
+        return json.loads(value.replace("'", '"'))
+    return value
+
+
 class _ParseDict(argparse.Action):
     def __call__(self, parser, namespace, values, option_string=None):
         d = getattr(namespace, self.dest) or {}
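
The new `_parse_json` lets `--mop` and `--iop` values carry dictionaries: a value wrapped in braces is decoded with `json.loads` after single quotes are swapped for double quotes. The logic replayed standalone:

    import json

    def parse(value: str):
        # mirrors _parse_json above
        if value and value[0] == "{" and value[-1] == "}":
            return json.loads(value.replace("'", '"'))
        return value

    print(parse("{'rope_type': 'dynamic', 'factor': 10.0}"))
    # -> {'rope_type': 'dynamic', 'factor': 10.0}
    print(parse("eager"))  # plain strings pass through unchanged
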
@@ -293,22 +318,23 @@ class _ParseDict(argparse.Action):
                 continue
             except (TypeError, ValueError):
                 pass
-            d[key] = value
+            d[key] = _parse_json(value)
 
         setattr(namespace, self.dest, d)
 
 
 def get_parser_validate() -> ArgumentParser:
     parser = ArgumentParser(
-        prog="test",
-        description=dedent(
+        prog="validate",
+        description=textwrap.dedent(
+            """
+            Prints out dummy inputs for a particular task or a model id.
+            If both mid and task are empty, the command line displays the list
+            of supported tasks.
             """
-            Prints out dummy inputs for a particular task or a model id.
-            If both mid and task are empty, the command line displays the list
-            of supported tasks.
-            """
         ),
         epilog="If the model id is specified, one untrained version of it is instantiated.",
+        formatter_class=RawTextHelpFormatter,
     )
     parser.add_argument("-m", "--mid", type=str, help="model id, usually <author>/<name>")
     parser.add_argument("-t", "--task", default=None, help="force the task to use")
@@ -319,55 +345,61 @@ def get_parser_validate() -> ArgumentParser:
         "--run",
         default=False,
         action=BooleanOptionalAction,
-        help="runs the model to check it runs",
+        help="Runs the model to check it runs.",
     )
     parser.add_argument(
         "-q",
         "--quiet",
         default=False,
         action=BooleanOptionalAction,
-        help="catches exception, report them in the summary",
+        help="Catches exception, reports them in the summary.",
     )
     parser.add_argument(
         "--patch",
         default=True,
         action=BooleanOptionalAction,
-        help="applies patches before exporting",
+        help="Applies patches before exporting.",
     )
     parser.add_argument(
         "--rewrite",
         default=True,
         action=BooleanOptionalAction,
-        help="applies rewrite before exporting",
+        help="Applies rewrite before exporting.",
     )
     parser.add_argument(
         "--stop-if-static",
         default=0,
         type=int,
-        help="raises an exception if a dynamic dimension becomes static",
+        help="Raises an exception if a dynamic dimension becomes static.",
    )
     parser.add_argument(
         "--trained",
         default=False,
         action=BooleanOptionalAction,
-        help="validate the trained model (requires downloading)",
+        help="Validates the trained model (requires downloading).",
+    )
+    parser.add_argument(
+        "--inputs2",
+        default=True,
+        action=BooleanOptionalAction,
+        help="Validates the model on a second set of inputs\n"
+        "to check the exported model supports dynamism.",
     )
     parser.add_argument(
         "--runtime",
         choices=["onnxruntime", "torch", "ref"],
         default="onnxruntime",
-        help="onnx runtime to use, onnxruntime by default",
+        help="onnx runtime to use, `onnxruntime` by default",
     )
     parser.add_argument(
         "-o",
         "--dump-folder",
-        help="if not empty, a folder is created to dumps statistics, "
-        "exported program, onnx...",
+        help="A folder is created to dumps statistics,\nexported program, onnx...",
     )
     parser.add_argument(
         "--drop",
-        help="drops the following inputs names, it should be a list "
-        "with comma separated values",
+        help="Drops the following inputs names, it should be a list\n"
+        "with comma separated values.",
     )
     parser.add_argument(
         "--opset",
@@ -377,24 +409,25 @@ )
     )
     parser.add_argument(
         "--subfolder",
-        help="subfolder where to find the model and the configuration",
+        help="Subfolder where to find the model and the configuration.",
     )
     parser.add_argument(
         "--ortfusiontype",
         required=False,
-        help="applies onnxruntime fusion, this parameter should contain the "
-        "model type or multiple values separated by `|`. `ALL` can be used "
-        "to run them all",
+        help="Applies onnxruntime fusion, this parameter should contain the\n"
+        "model type or multiple values separated by `|`. `ALL` can be used\n"
+        "to run them all.",
     )
     parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
-    parser.add_argument("--dtype", help="changes dtype if necessary")
-    parser.add_argument("--device", help="changes the device if necessary")
+    parser.add_argument("--dtype", help="Changes dtype if necessary.")
+    parser.add_argument("--device", help="Changes the device if necessary.")
     parser.add_argument(
         "--iop",
         metavar="KEY=VALUE",
         nargs="*",
-        help="Additional input options, use to change the default "
-        "inputs use to export, example: --iop cls_cache=SlidingWindowCache",
+        help="Additional input options, use to change the default"
+        "inputs use to export, example:\n --iop cls_cache=SlidingWindowCache"
+        "\n --iop cls_cache=StaticCache",
         action=_ParseDict,
     )
     parser.add_argument(
@@ -402,7 +435,8 @@
         metavar="KEY=VALUE",
         nargs="*",
         help="Additional model options, use to change some parameters of the model, "
-        "example: --mop attn_implementation=eager",
+        "example:\n --mop attn_implementation=sdpa --mop attn_implementation=eager\n "
+        "--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"",
         action=_ParseDict,
     )
     parser.add_argument(
@@ -419,7 +453,7 @@
 
 def _cmd_validate(argv: List[Any]):
     from .helpers import string_type
-    from .torch_models.test_helper import get_inputs_for_task, validate_model
+    from .torch_models.validate import get_inputs_for_task, validate_model
     from .tasks import supported_tasks
 
     parser = get_parser_validate()
@@ -471,6 +505,7 @@
         runtime=args.runtime,
         repeat=args.repeat,
         warmup=args.warmup,
+        inputs2=args.inputs2,
     )
     print("")
     print("-- summary --")
@@ -481,11 +516,7 @@
 def get_parser_stats() -> ArgumentParser:
     parser = ArgumentParser(
         prog="stats",
-        description=dedent(
-            """
-            Prints out statistics on an ONNX model.
-            """
-        ),
+        description="Prints out statistics on an ONNX model.",
         epilog="",
     )
     parser.add_argument(
@@ -532,8 +563,8 @@
         required=False,
         default="",
         type=str,
-        help="keeps only tensors whose name verifies "
-        "this regular expression, empty = no filter",
+        help="Keeps only tensors whose name verifies "
+        "this regular expression, empty = no filter.",
     )
     return parser
 
@@ -585,17 +616,17 @@ def get_main_parser() -> ArgumentParser:
         formatter_class=RawTextHelpFormatter,
         epilog=textwrap.dedent(
             """
-    Type 'python -m onnx_diagnostic <cmd> --help'
-    to get help for a specific command.
-
-    config    - prints a configuration for a model id
-    find      - find node consuming or producing a result
-    lighten   - makes an onnx model lighter by removing the weights,
-    unlighten - restores an onnx model produces by the previous experiment
-    print     - prints the model on standard output
-    validate  - validate a model
-    stats     - produces statistics on a model
-    """
+            Type 'python -m onnx_diagnostic <cmd> --help'
+            to get help for a specific command.
+
+            config    - prints a configuration for a model id
+            find      - find node consuming or producing a result
+            lighten   - makes an onnx model lighter by removing the weights,
+            unlighten - restores an onnx model produces by the previous experiment
+            print     - prints the model on standard output
+            validate  - validate a model
+            stats     - produces statistics on a model
+            """
         ),
     )
     parser.add_argument(
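
Since the subcommand was renamed from `test` to `validate`, its parser can be exercised directly; all the flags below are defined in the hunks above, the model id string is arbitrary:

    from onnx_diagnostic._command_lines_parser import get_parser_validate

    parser = get_parser_validate()
    args = parser.parse_args(["-m", "arnir0/Tiny-LLM", "--run", "--no-inputs2"])
    print(args.mid, args.run, args.inputs2)  # arnir0/Tiny-LLM True False
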
onnx_diagnostic/doc.py CHANGED
@@ -1,3 +1,29 @@
+from typing import Optional
+import numpy as np
+
+
+def get_latest_pypi_version(package_name="onnx-diagnostic") -> str:
+    """Returns the latest published version."""
+
+    import requests
+
+    url = f"https://pypi.org/pypi/{package_name}/json"
+    response = requests.get(url)
+
+    assert response.status_code == 200, f"Unable to retrieve the version response={response}"
+    data = response.json()
+    version = data["info"]["version"]
+    return version
+
+
+def update_version_package(version: str, package_name="onnx-diagnostic") -> str:
+    "Adds dev if the major version is different from the latest published one."
+    released = get_latest_pypi_version(package_name)
+    shorten_r = ".".join(released.split(".")[:2])
+    shorten_v = ".".join(version.split(".")[:2])
+    return version if shorten_r == shorten_v else f"{shorten_v}.dev"
+
+
 def reset_torch_transformers(gallery_conf, fname):
     "Resets torch dynamo for :epkg:`sphinx-gallery`."
     import matplotlib.pyplot as plt
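
`update_version_package` compares only the `major.minor` prefix; the suffix logic without the network call:

    released, version = "0.6.2", "0.7.0"           # example values
    shorten_r = ".".join(released.split(".")[:2])  # "0.6"
    shorten_v = ".".join(version.split(".")[:2])   # "0.7"
    print(version if shorten_r == shorten_v else f"{shorten_v}.dev")  # 0.7.dev
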
@@ -30,3 +56,45 @@ def plot_legend(
     ax.grid(False)
     ax.set_axis_off()
     return ax
+
+
+def rotate_align(ax, angle=15, align="right"):
+    """Rotates x-label and align them to thr right. Returns ax."""
+    for label in ax.get_xticklabels():
+        label.set_rotation(angle)
+        label.set_horizontalalignment(align)
+    return ax
+
+
+def save_fig(ax, name: str):
+    """Applies ``tight_layout`` and saves the figures. Returns ax."""
+    import matplotlib.pyplot as plt
+
+    plt.tight_layout()
+    fig = ax.get_figure()
+    fig.savefig(name)
+    return ax
+
+
+def title(ax: "plt.axes", title: str) -> "plt.axes":  # noqa: F821
+    "Adds a title to axes and returns them."
+    ax.set_title(title)
+    return ax
+
+
+def plot_histogram(
+    tensor: np.ndarray,
+    ax: Optional["plt.axes"] = None,  # noqa: F821
+    bins: int = 30,
+    color: str = "orange",
+    alpha: float = 0.7,
+) -> "plt.axes":  # noqa: F821
+    "Computes the distribution for a tensor."
+    if ax is None:
+        import matplotlib.pyplot as plt
+
+        ax = plt.gca()
+    ax.cla()
+    ax.hist(tensor, bins=30, color="orange", alpha=0.7)
+    ax.set_yscale("log")
+    return ax
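
The new helpers all return the axes, so they chain; a small usage sketch (note that, as written, `plot_histogram` passes fixed `bins=30, color="orange"` to `hist`, so its keyword arguments are not forwarded):

    import numpy as np
    from onnx_diagnostic.doc import plot_histogram, rotate_align, save_fig, title

    ax = plot_histogram(np.random.randn(1000))
    save_fig(rotate_align(title(ax, "weight distribution")), "hist.png")
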
onnx_diagnostic/ext_test_case.py CHANGED
@@ -1014,7 +1014,7 @@ class ExtTestCase(unittest.TestCase):
             msg_ = "\n".join(excs)
             msg = f"{msg}\n{msg_}" if msg else msg_
             raise AssertionError(f"Found {len(excs)} discrepancies\n{msg}")
-        elif expected.__class__.__name__ == "DynamicCache":
+        elif expected.__class__.__name__ in ("DynamicCache", "StaticCache"):
             atts = {"key_cache", "value_cache"}
             self.assertEqualArrayAny(
                 {k: expected.__dict__.get(k, None) for k in atts},
onnx_diagnostic/helpers/cache_helper.py CHANGED
@@ -141,6 +141,65 @@ else:
         return cache
 
 
+def make_static_cache(
+    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+) -> transformers.cache_utils.DynamicCache:
+    """
+    Creates an instance of :class:`transformers.cache_utils.StaticCache`.
+    :param key_value_pairs: list of pairs of (key, values)
+    :return: :class:`transformers.cache_utils.StaticCache`
+
+    Example:
+
+    .. runpython::
+        :showcode:
+
+        import torch
+        from onnx_diagnostic.helpers import string_type
+        from onnx_diagnostic.helpers.cache_helper import make_static_cache
+
+        n_layers = 2
+        bsize, nheads, slen, dim = 2, 4, 3, 7
+
+        past_key_values = make_static_cache(
+            [
+                (
+                    torch.randn(bsize, nheads, slen, dim),
+                    torch.randn(bsize, nheads, slen, dim),
+                )
+                for i in range(n_layers)
+            ]
+        )
+        print(string_type(past_key_values, with_shape=True))
+    """
+
+    class _config:
+        def __init__(self):
+            self.head_dim = key_value_pairs[0][0].shape[-1]
+            self.num_attention_heads = key_value_pairs[0][0].shape[1]
+            self.num_hidden_layers = len(key_value_pairs)
+
+    cache = transformers.cache_utils.StaticCache(
+        _config(),
+        max_batch_size=key_value_pairs[0][0].shape[0],
+        device=key_value_pairs[0][0].device,
+        dtype=key_value_pairs[0][0].dtype,
+        max_cache_len=key_value_pairs[0][0].shape[2],
+    )
+    for i in range(len(key_value_pairs)):
+        assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, (
+            f"Shape mismatch, expected {cache.key_cache[i].shape}, "
+            f"got {key_value_pairs[i][0].shape}"
+        )
+        cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+        assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, (
+            f"Shape mismatch, expected {cache.value_cache[i].shape}, "
+            f"got {key_value_pairs[i][1].shape}"
+        )
+        cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+    return cache
+
+
 def make_encoder_decoder_cache(
     self_attention_cache: transformers.cache_utils.DynamicCache,
     cross_attention_cache: transformers.cache_utils.DynamicCache,
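
`make_static_cache` presumably backs the new `--iop cls_cache=StaticCache` option; since `max_cache_len` is taken from the first pair, every pair must share one shape. A direct call mirroring the embedded docstring example:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_static_cache

    pairs = [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
    cache = make_static_cache(pairs)
    print(type(cache).__name__)  # StaticCache
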
onnx_diagnostic/helpers/config_helper.py CHANGED
@@ -34,10 +34,14 @@ def update_config(config: Any, mkwargs: Dict[str, Any]):
             config._attn_implementation_autoset = False
             continue
         if isinstance(v, dict):
-            assert hasattr(
-                config, k
-            ), f"missing attribute {k!r} in config={config}, cannot update it with {v}"
-            update_config(getattr(config, k), v)
+            if not hasattr(config, k) or getattr(config, k) is None:
+                setattr(config, k, v)
+                continue
+            existing = getattr(config, k)
+            if type(existing) is dict:
+                existing.update(v)
+            else:
+                update_config(getattr(config, k), v)
             continue
         setattr(config, k, v)
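
`update_config` no longer asserts when a dict value targets a missing or `None` attribute: it sets the attribute, and merges into attributes that already are plain dicts. A minimal sketch with an illustrative config object:

    from types import SimpleNamespace
    from onnx_diagnostic.helpers.config_helper import update_config

    config = SimpleNamespace(rope_scaling=None)
    update_config(config, {"rope_scaling": {"rope_type": "dynamic", "factor": 10.0}})
    print(config.rope_scaling)  # the dict is set instead of raising an assert
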
onnx_diagnostic/helpers/doc_helper.py CHANGED
@@ -1,4 +1,5 @@
-from typing import Dict, Optional, Tuple
+import os
+from typing import Dict, List, Optional, Tuple
 import onnx
 import onnx.helper as oh
 import torch
@@ -6,6 +7,17 @@ from ..reference.torch_ops import OpRunKernel, OpRunTensor
 from .torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
 from .ort_session import InferenceSessionForTorch
 
+_SAVED: List[str] = []
+_SAVE_OPTIMIZED_MODEL_ = int(os.environ.get("DUMP_ONNX", "0"))
+
+
+def _get_model_name(op_name: str, provider: str) -> Optional[str]:
+    if _SAVE_OPTIMIZED_MODEL_:
+        name = f"dump_doc_layer_norm_{provider}_{len(_SAVED)}.onnx"
+        _SAVED.append(name)
+        return name
+    return None
+
 
 class LayerNormalizationOrt(OpRunKernel):
     "LayerNormalization with onnxruntime"
@@ -13,14 +25,14 @@
     @classmethod
     def device_dependent(cls) -> bool:
         "Needs device."
-        return False
+        return True
 
     def __init__(
         self,
         node: onnx.NodeProto,
         version=None,
         device: Optional[torch.device] = None,
-        verbose=0,
+        verbose: int = 0,
     ):
         super().__init__(node, version, verbose=verbose)
         self.axis = self.get_attribute_int(node, "axis", -1)
@@ -70,7 +82,11 @@
         )
         provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
         self._provider = provider
-        return InferenceSessionForTorch(layer_model, providers=[provider])
+        return InferenceSessionForTorch(
+            layer_model,
+            optimized_model_filepath=_get_model_name("layer_norm", provider),
+            providers=[provider],
+        )
 
     def run(self, x, scale, bias=None):
         itype = torch_dtype_to_onnx_dtype(x.dtype)
@@ -94,14 +110,14 @@ class MatMulOrt(OpRunKernel):
     @classmethod
     def device_dependent(cls) -> bool:
         "Needs device."
-        return False
+        return True
 
     def __init__(
         self,
         node: onnx.NodeProto,
         version=None,
         device: Optional[torch.device] = None,
-        verbose=0,
+        verbose: int = 0,
     ):
         super().__init__(node, version, verbose=verbose)
         self.device = device
@@ -127,7 +143,11 @@
         )
         provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
         self._provider = provider
-        return InferenceSessionForTorch(model, providers=[provider])
+        return InferenceSessionForTorch(
+            model,
+            optimized_model_filepath=_get_model_name("matmul", provider),
+            providers=[provider],
+        )
 
     def run(self, a, b):
         itype = torch_dtype_to_onnx_dtype(a.dtype)