onnx-diagnostic 0.8.10__py3-none-any.whl → 0.8.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +136 -140
  3. onnx_diagnostic/ci_models/export_phi4_mm.py +2 -4
  4. onnx_diagnostic/export/api.py +2 -4
  5. onnx_diagnostic/export/validate.py +2 -0
  6. onnx_diagnostic/ext_test_case.py +32 -15
  7. onnx_diagnostic/helpers/args_helper.py +1 -0
  8. onnx_diagnostic/helpers/bench_run.py +0 -1
  9. onnx_diagnostic/helpers/cache_helper.py +6 -6
  10. onnx_diagnostic/helpers/doc_helper.py +7 -4
  11. onnx_diagnostic/helpers/graph_helper.py +6 -6
  12. onnx_diagnostic/helpers/log_helper.py +37 -14
  13. onnx_diagnostic/helpers/memory_peak.py +5 -1
  14. onnx_diagnostic/helpers/mini_onnx_builder.py +9 -14
  15. onnx_diagnostic/helpers/model_builder_helper.py +1 -1
  16. onnx_diagnostic/helpers/onnx_helper.py +283 -110
  17. onnx_diagnostic/helpers/ort_session.py +0 -1
  18. onnx_diagnostic/helpers/torch_helper.py +8 -9
  19. onnx_diagnostic/investigate/__init__.py +0 -0
  20. onnx_diagnostic/investigate/input_observer.py +329 -0
  21. onnx_diagnostic/reference/evaluator.py +0 -1
  22. onnx_diagnostic/reference/ort_evaluator.py +0 -1
  23. onnx_diagnostic/reference/report_results_comparison.py +9 -3
  24. onnx_diagnostic/reference/torch_evaluator.py +5 -1
  25. onnx_diagnostic/reference/torch_ops/_op_run.py +3 -5
  26. onnx_diagnostic/reference/torch_ops/sequence_ops.py +1 -1
  27. onnx_diagnostic/tasks/feature_extraction.py +0 -1
  28. onnx_diagnostic/torch_export_patches/__init__.py +0 -1
  29. onnx_diagnostic/torch_export_patches/patch_module.py +1 -1
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
  31. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +44 -23
  32. onnx_diagnostic/torch_models/code_sample.py +5 -10
  33. onnx_diagnostic/torch_models/hghub/hub_data.py +2 -4
  34. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +6 -12
  35. onnx_diagnostic/torch_models/validate.py +1 -1
  36. onnx_diagnostic/torch_onnx/compare.py +0 -1
  37. onnx_diagnostic/torch_onnx/runtime_info.py +1 -1
  38. onnx_diagnostic/torch_onnx/sbs.py +1 -1
  39. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +2 -4
  40. onnx_diagnostic/typing.py +15 -0
  41. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/METADATA +1 -1
  42. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/RECORD +45 -43
  43. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/WHEEL +1 -1
  44. onnx_diagnostic/api.py +0 -15
  45. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/licenses/LICENSE.txt +0 -0
  46. {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/top_level.txt +0 -0
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.8.10"
6
+ __version__ = "0.8.11"
7
7
  __author__ = "Xavier Dupré"
@@ -14,11 +14,9 @@ from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction
14
14
  def get_parser_dot() -> ArgumentParser:
15
15
  parser = ArgumentParser(
16
16
  prog="dot",
17
- description=textwrap.dedent(
18
- """
17
+ description=textwrap.dedent("""
19
18
  Converts a model into a dot file dot can draw into a graph.
20
- """
21
- ),
19
+ """),
22
20
  )
23
21
  parser.add_argument("input", type=str, help="onnx model to lighten")
24
22
  parser.add_argument(
@@ -85,12 +83,10 @@ def _cmd_dot(argv: List[Any]):
85
83
  def get_parser_lighten() -> ArgumentParser:
86
84
  parser = ArgumentParser(
87
85
  prog="lighten",
88
- description=textwrap.dedent(
89
- """
86
+ description=textwrap.dedent("""
90
87
  Removes the weights from a heavy model, stores statistics to restore
91
88
  random weights.
92
- """
93
- ),
89
+ """),
94
90
  epilog="This is mostly used to write unit tests without adding "
95
91
  "a big onnx file to the repository.",
96
92
  )
@@ -142,12 +138,10 @@ def _cmd_lighten(argv: List[Any]):
142
138
  def get_parser_unlighten() -> ArgumentParser:
143
139
  parser = ArgumentParser(
144
140
  prog="unlighten",
145
- description=textwrap.dedent(
146
- """
141
+ description=textwrap.dedent("""
147
142
  Restores random weights for a model reduces with command lighten,
148
143
  the command expects to find a file nearby with extension '.stats'.
149
- """
150
- ),
144
+ """),
151
145
  epilog="This is mostly used to write unit tests without adding "
152
146
  "a big onnx file to the repository.",
153
147
  )
@@ -200,8 +194,7 @@ def get_parser_print() -> ArgumentParser:
200
194
  "fmt",
201
195
  choices=["dot", "pretty", "printer", "raw", "shape", "text"],
202
196
  default="pretty",
203
- help=textwrap.dedent(
204
- """
197
+ help=textwrap.dedent("""
205
198
  Prints out a model on the standard output.
206
199
 
207
200
  dot - converts the graph into dot
@@ -211,10 +204,7 @@ def get_parser_print() -> ArgumentParser:
211
204
  shape - prints every node node with input and output shapes
212
205
  text - uses GraphRendering
213
206
 
214
- """.strip(
215
- "\n"
216
- )
217
- ),
207
+ """.strip("\n")),
218
208
  )
219
209
  parser.add_argument("input", type=str, help="onnx model to load")
220
210
  return parser
@@ -251,12 +241,10 @@ def _cmd_print(argv: List[Any]):
251
241
  def get_parser_find() -> ArgumentParser:
252
242
  parser = ArgumentParser(
253
243
  prog="find",
254
- description=textwrap.dedent(
255
- """
244
+ description=textwrap.dedent("""
256
245
  Look into a model and search for a set of names,
257
246
  tells which node is consuming or producing it.
258
- """
259
- ),
247
+ """),
260
248
  epilog="Enables Some quick validation.",
261
249
  )
262
250
  parser.add_argument(
@@ -315,12 +303,10 @@ def _cmd_find(argv: List[Any]):
315
303
  def get_parser_config() -> ArgumentParser:
316
304
  parser = ArgumentParser(
317
305
  prog="config",
318
- description=textwrap.dedent(
319
- """
306
+ description=textwrap.dedent("""
320
307
  Prints out a configuration for a model id,
321
308
  prints the associated task as well.
322
- """
323
- ),
309
+ """),
324
310
  formatter_class=RawTextHelpFormatter,
325
311
  epilog="",
326
312
  )
@@ -470,8 +456,7 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
470
456
  Creates a script to export a model for a particular task given the model id.
471
457
  """
472
458
  ),
473
- epilog=textwrap.dedent(
474
- f"""
459
+ epilog=textwrap.dedent(f"""
475
460
  If the model id is specified, one untrained version of it is instantiated.
476
461
  Examples:
477
462
 
@@ -500,8 +485,7 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
500
485
 
501
486
  pyinstrument -m onnx_diagnostic {name} ...
502
487
  pyinstrument -r html -o profile.html -m onnx_diagnostic {name} ...
503
- """
504
- ),
488
+ """),
505
489
  formatter_class=RawTextHelpFormatter,
506
490
  )
507
491
  parser.add_argument("-m", "--mid", type=str, help="model id, usually <author>/<name>")
@@ -527,17 +511,13 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
527
511
  default=True,
528
512
  action=_BoolOrParseDictPatch,
529
513
  nargs="*",
530
- help=textwrap.dedent(
531
- """
514
+ help=textwrap.dedent("""
532
515
  Applies patches before exporting, it can be a boolean
533
516
  to enable to disable the patches or be more finetuned
534
517
  (default is True). It is possible to disable patch for torch
535
518
  by adding:
536
519
  --patch "patch_sympy=False" --patch "patch_torch=False"
537
- """.strip(
538
- "\n"
539
- )
540
- ),
520
+ """.strip("\n")),
541
521
  )
542
522
  parser.add_argument(
543
523
  "--rewrite",
@@ -567,16 +547,12 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
567
547
  "--inputs2",
568
548
  default=1,
569
549
  type=int,
570
- help=textwrap.dedent(
571
- """
550
+ help=textwrap.dedent("""
572
551
  Validates or exports the model on a second set of inputs
573
552
  to check the exported model supports dynamism. The values is used
574
553
  as an increment to the first set of inputs. A high value may trick
575
554
  a different behavior in the model and missed by the exporter.
576
- """.strip(
577
- "\n"
578
- )
579
- ),
555
+ """.strip("\n")),
580
556
  )
581
557
  parser.add_argument(
582
558
  "--runtime",
@@ -609,15 +585,11 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
609
585
  parser.add_argument(
610
586
  "--ortfusiontype",
611
587
  required=False,
612
- help=textwrap.dedent(
613
- """
588
+ help=textwrap.dedent("""
614
589
  Applies onnxruntime fusion, this parameter should contain the
615
590
  model type or multiple values separated by `|`. `ALL` can be used
616
591
  to run them all.
617
- """.strip(
618
- "\n"
619
- )
620
- ),
592
+ """.strip("\n")),
621
593
  )
622
594
  parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
623
595
  parser.add_argument("--dtype", help="Changes dtype if necessary.")
@@ -626,32 +598,24 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
626
598
  "--iop",
627
599
  metavar="KEY=VALUE",
628
600
  nargs="*",
629
- help=textwrap.dedent(
630
- """
601
+ help=textwrap.dedent("""
631
602
  Additional input options, used to change the default
632
603
  inputs use to export. Examples:
633
604
  --iop cls_cache=SlidingWindowCache
634
605
  --iop cls_cache=StaticCache
635
- """.strip(
636
- "\n"
637
- )
638
- ),
606
+ """.strip("\n")),
639
607
  action=_ParseDict,
640
608
  )
641
609
  parser.add_argument(
642
610
  "--mop",
643
611
  metavar="KEY=VALUE",
644
612
  nargs="*",
645
- help=textwrap.dedent(
646
- """
613
+ help=textwrap.dedent("""
647
614
  Additional model options, used to change some parameters
648
615
  of the model. Example:
649
616
  --mop attn_implementation=sdpa --mop attn_implementation=eager"
650
617
  --mop "rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}"
651
- """.strip(
652
- "\n"
653
- )
654
- ),
618
+ """.strip("\n")),
655
619
  action=_ParseDict,
656
620
  )
657
621
  if name == "validate":
@@ -683,42 +647,32 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
683
647
  parser.add_argument(
684
648
  "--quiet-input-sets",
685
649
  default="",
686
- help=textwrap.dedent(
687
- """
650
+ help=textwrap.dedent("""
688
651
  Avoids raising an exception when an input sets does not work with
689
652
  the exported model. Example:
690
653
  --quiet-input-sets=inputs,inputs22
691
- """.strip(
692
- "\n"
693
- )
694
- ),
654
+ """.strip("\n")),
695
655
  )
696
656
  parser.add_argument(
697
657
  "--expop",
698
658
  metavar="KEY=VALUE",
699
659
  nargs="*",
700
- help=textwrap.dedent(
701
- """
660
+ help=textwrap.dedent("""
702
661
  Additional exporter options, use to change some parameters
703
662
  of the model. Examples:
704
663
  --expop report=True
705
664
  --expop report=True --expop verify=True
706
- """.strip(
707
- "\n"
708
- )
709
- ),
665
+ """.strip("\n")),
710
666
  action=_ParseDict,
711
667
  )
712
668
  parser.add_argument(
713
669
  "--save-ep",
714
670
  default="",
715
- help=textwrap.dedent(
716
- """
671
+ help=textwrap.dedent("""
717
672
  saves the exported program with torch.export.save
718
673
  and the inputs sets with torch.save,
719
674
  then command line sbs can be used to look for discrepancies.
720
- """
721
- ),
675
+ """),
722
676
  )
723
677
 
724
678
  return parser
@@ -1003,18 +957,15 @@ class _ParseNamedDict(argparse.Action):
1003
957
  def get_parser_agg() -> ArgumentParser:
1004
958
  parser = ArgumentParser(
1005
959
  prog="agg",
1006
- description=textwrap.dedent(
1007
- """
960
+ description=textwrap.dedent("""
1008
961
  Aggregates statistics coming from benchmarks.
1009
962
  Every run is a row. Every row is indexed by some keys,
1010
963
  and produces values. Every row has a date.
1011
964
  The data can come any csv files produces by benchmarks,
1012
965
  it can concatenates many csv files, or csv files inside zip files.
1013
966
  It produces an excel file with many tabs, one per view.
1014
- """
1015
- ),
1016
- epilog=textwrap.dedent(
1017
- """
967
+ """),
968
+ epilog=textwrap.dedent("""
1018
969
  examples:
1019
970
 
1020
971
  python -m onnx_diagnostic agg test_agg.xlsx raw/*.zip -v 1
@@ -1025,8 +976,7 @@ def get_parser_agg() -> ArgumentParser:
1025
976
 
1026
977
  python -m onnx_diagnostic agg history.xlsx raw/*.csv -v 1 --no-raw \\
1027
978
  --no-recent
1028
- """
1029
- ),
979
+ """),
1030
980
  formatter_class=RawTextHelpFormatter,
1031
981
  )
1032
982
  parser.add_argument("output", help="output excel file")
@@ -1104,15 +1054,13 @@ def get_parser_agg() -> ArgumentParser:
1104
1054
  "--views",
1105
1055
  default="agg-suite,agg-all,disc,speedup,time,time_export,err,cmd,"
1106
1056
  "bucket-speedup,raw-short,counts,peak-gpu,onnx",
1107
- help=textwrap.dedent(
1108
- """
1057
+ help=textwrap.dedent("""
1109
1058
  Views to add to the output files. Each view becomes a tab.
1110
1059
  A view is defined by its name, among
1111
1060
  agg-suite, agg-all, disc, speedup, time, time_export, err,
1112
1061
  cmd, bucket-speedup, raw-short, counts, peak-gpu, onnx.
1113
1062
  Their definition is part of class CubeLogsPerformance.
1114
- """
1115
- ),
1063
+ """),
1116
1064
  )
1117
1065
  parser.add_argument(
1118
1066
  "--csv",
@@ -1134,14 +1082,12 @@ def get_parser_agg() -> ArgumentParser:
1134
1082
  )
1135
1083
  parser.add_argument(
1136
1084
  "--sbs",
1137
- help=textwrap.dedent(
1138
- """
1085
+ help=textwrap.dedent("""
1139
1086
  Defines an exporter to compare to another, there must be at least
1140
1087
  two arguments defined with --sbs. Example:
1141
1088
  --sbs dynamo:exporter=onnx-dynamo,opt=ir,attn_impl=eager
1142
1089
  --sbs custom:exporter=custom,opt=default,attn_impl=eager
1143
- """
1144
- ),
1090
+ """),
1145
1091
  action=_ParseNamedDict,
1146
1092
  )
1147
1093
  return parser
@@ -1205,18 +1151,15 @@ def _cmd_agg(argv: List[Any]):
1205
1151
  def get_parser_sbs() -> ArgumentParser:
1206
1152
  parser = ArgumentParser(
1207
1153
  prog="side-by-side (sbs)",
1208
- description=textwrap.dedent(
1209
- """
1154
+ description=textwrap.dedent("""
1210
1155
  Compares the intermediate outputs between the exported program and
1211
1156
  the exported onnx model. It assumes some names are common.
1212
1157
  The execution of the exported program and the onnx model
1213
1158
  are done in parallel. The device is the one used to store the
1214
1159
  model and the inputs.
1215
1160
  Where do discrepancies start? This function tries to answer that question.
1216
- """
1217
- ),
1218
- epilog=textwrap.dedent(
1219
- """
1161
+ """),
1162
+ epilog=textwrap.dedent("""
1220
1163
  The command line expects the following files to be saved with
1221
1164
  the following function. inputs is a dictionary of the input of the model.
1222
1165
 
@@ -1231,8 +1174,7 @@ def get_parser_sbs() -> ArgumentParser:
1231
1174
  model (through the exported program) and its onnx conversion.
1232
1175
  This functionality dumps everything it can to disk
1233
1176
  so that it be replayed in a separate process.
1234
- """
1235
- ),
1177
+ """),
1236
1178
  )
1237
1179
  parser.add_argument(
1238
1180
  "-i",
@@ -1246,12 +1188,10 @@ def get_parser_sbs() -> ArgumentParser:
1246
1188
  "--ep",
1247
1189
  type=str,
1248
1190
  required=True,
1249
- help=textwrap.dedent(
1250
- """
1191
+ help=textwrap.dedent("""
1251
1192
  exported program saved with torch.export.save,
1252
1193
  input sets saved with torch.save,
1253
- """
1254
- ),
1194
+ """),
1255
1195
  )
1256
1196
  parser.add_argument(
1257
1197
  "-m",
@@ -1311,25 +1251,21 @@ def get_parser_sbs() -> ArgumentParser:
1311
1251
  "--second-run",
1312
1252
  action=BooleanOptionalAction,
1313
1253
  default=False,
1314
- help=textwrap.dedent(
1315
- """
1254
+ help=textwrap.dedent("""
1316
1255
  Tries to run all onnx nodes with torch results produced by the exported
1317
1256
  program. It then measures the discrepancies again. It can be used
1318
1257
  to identify kernel introduces discrepancies from other just propagating them.
1319
- """
1320
- ),
1258
+ """),
1321
1259
  )
1322
1260
  parser.add_argument(
1323
1261
  "--reset",
1324
1262
  required=False,
1325
1263
  default="",
1326
- help=textwrap.dedent(
1327
- """
1264
+ help=textwrap.dedent("""
1328
1265
  List of result names separated by a comma. For those results,
1329
1266
  the side-by-side will take torch results instead of onnx results
1330
1267
  to compute the rest of the onnx model.
1331
- """
1332
- ),
1268
+ """),
1333
1269
  )
1334
1270
  parser.add_argument(
1335
1271
  "-s",
@@ -1365,14 +1301,12 @@ def get_parser_sbs() -> ArgumentParser:
1365
1301
  "--replay-prefix-model",
1366
1302
  action=BooleanOptionalAction,
1367
1303
  default=False,
1368
- help=textwrap.dedent(
1369
- """
1304
+ help=textwrap.dedent("""
1370
1305
  There are two ways to recompute an intermediate output, the first one is to "
1371
1306
  produce the minimal model between torch and onnx.
1372
1307
  The second one is to dump onnx models from the inputs
1373
1308
  to the considered intermediate results. This enables the second one.
1374
- """
1375
- ),
1309
+ """),
1376
1310
  )
1377
1311
 
1378
1312
  return parser
@@ -1511,20 +1445,16 @@ def _cmd_sbs(argv: List[Any]):
1511
1445
  def get_parser_compare() -> ArgumentParser:
1512
1446
  parser = ArgumentParser(
1513
1447
  prog="compare",
1514
- description=textwrap.dedent(
1515
- """
1448
+ description=textwrap.dedent("""
1516
1449
  Compares two onnx models by aligning the nodes between both models.
1517
1450
  This is done through an edit distance.
1518
- """
1519
- ),
1520
- epilog=textwrap.dedent(
1521
- """
1451
+ """),
1452
+ epilog=textwrap.dedent("""
1522
1453
  Each element (initializer, input, node, output) of the model
1523
1454
  is converted into an observation. Then it defines a distance between
1524
1455
  two elements. And finally, it finds the best alignment with
1525
1456
  an edit distance.
1526
- """
1527
- ),
1457
+ """),
1528
1458
  )
1529
1459
  parser.add_argument("model1", type=str, help="first model to compare")
1530
1460
  parser.add_argument("model2", type=str, help="second model to compare")
@@ -1551,15 +1481,12 @@ def get_parser_optimize() -> ArgumentParser:
1551
1481
  parser = ArgumentParser(
1552
1482
  prog="optimize",
1553
1483
  formatter_class=RawTextHelpFormatter,
1554
- description=textwrap.dedent(
1555
- """
1484
+ description=textwrap.dedent("""
1556
1485
  Optimizes an onnx model by fusing nodes. It looks for patterns in the graphs
1557
1486
  and replaces them by the corresponding nodes. It also does basic optimization
1558
1487
  such as removing identity nodes or unused nodes.
1559
- """
1560
- ),
1561
- epilog=textwrap.dedent(
1562
- """
1488
+ """),
1489
+ epilog=textwrap.dedent("""
1563
1490
  The goal is to make the model faster.
1564
1491
  Argument patterns defines the patterns to apply or the set of patterns.
1565
1492
  It is possible to show statistics or to remove a particular pattern.
@@ -1575,8 +1502,7 @@ def get_parser_optimize() -> ArgumentParser:
1575
1502
  - PATTERN=<pattern1,pattern2,...>: increase verbosity for specific
1576
1503
  patterns to understand why one pattern was not applied,
1577
1504
  this shows which line is rejecting a pattern if it seems one pattern was missed
1578
- """
1579
- ),
1505
+ """),
1580
1506
  )
1581
1507
  parser.add_argument(
1582
1508
  "algorithm",
@@ -1608,13 +1534,11 @@ def get_parser_optimize() -> ArgumentParser:
1608
1534
  parser.add_argument(
1609
1535
  "--processor",
1610
1536
  default="",
1611
- help=textwrap.dedent(
1612
- """
1537
+ help=textwrap.dedent("""
1613
1538
  optimization for a specific processor, CPU, CUDA or both CPU,CUDA,
1614
1539
  some operators are only available in one processor, it might be not used
1615
1540
  with all
1616
- """
1617
- ).strip("\n"),
1541
+ """).strip("\n"),
1618
1542
  )
1619
1543
  parser.add_argument(
1620
1544
  "--remove-shape-info",
@@ -1648,6 +1572,76 @@ def _cmd_optimize(argv: List[Any]):
1648
1572
  )
1649
1573
 
1650
1574
 
1575
+ def get_parser_partition() -> ArgumentParser:
1576
+ parser = ArgumentParser(
1577
+ prog="partition",
1578
+ formatter_class=RawTextHelpFormatter,
1579
+ description=textwrap.dedent("""
1580
+ Partitions an onnx model by moving nodes into local functions.
1581
+ Exporters may add metadata to the onnx nodes telling which part
1582
+ of the model it comes from (namespace, source, ...).
1583
+ This nodes are moved into local functions.
1584
+ """),
1585
+ epilog=textwrap.dedent("""
1586
+ The regular may match the following values,
1587
+ 'model.layers.0.forward', 'model.layers.1.forward', ...
1588
+ A local function will be created for each distinct layer.
1589
+ """),
1590
+ )
1591
+ parser.add_argument("input", help="input model")
1592
+ parser.add_argument("output", help="output model")
1593
+ parser.add_argument(
1594
+ "-r",
1595
+ "--regex",
1596
+ default=".*[.]layers[.][0-9]+[.]forward$",
1597
+ help=textwrap.dedent("""
1598
+ merges all nodes sharing the same value in node metadata,
1599
+ these values must match the regular expression specified by
1600
+ this parameter, the default value matches what transformers
1601
+ usually to define a layer
1602
+ """).strip("\n"),
1603
+ )
1604
+ parser.add_argument(
1605
+ "-p",
1606
+ "--meta-prefix",
1607
+ default="namespace,source[",
1608
+ help="allowed prefixes for keys in the metadata",
1609
+ )
1610
+ parser.add_argument(
1611
+ "-v",
1612
+ "--verbose",
1613
+ default=0,
1614
+ required=False,
1615
+ type=int,
1616
+ help="verbosity",
1617
+ )
1618
+ return parser
1619
+
1620
+
1621
+ def _cmd_partition(argv: List[Any]):
1622
+ from .helpers.onnx_helper import make_model_with_local_functions
1623
+
1624
+ parser = get_parser_partition()
1625
+ args = parser.parse_args(argv[1:])
1626
+
1627
+ if args.verbose:
1628
+ print(f"-- load {args.input!r}")
1629
+ onx = onnx.load(args.input, load_external_data=False)
1630
+ if args.verbose:
1631
+ print("-- partition")
1632
+ onx2 = make_model_with_local_functions(
1633
+ onx,
1634
+ regex=args.regex,
1635
+ metadata_key_prefix=tuple(args.meta_prefix.split(",")),
1636
+ verbose=args.verbose,
1637
+ )
1638
+ if args.verbose:
1639
+ print(f"-- save into {args.output!r}")
1640
+ onnx.save(onx2, args.output)
1641
+ if args.verbose:
1642
+ print("-- done")
1643
+
1644
+
1651
1645
  #############
1652
1646
  # main parser
1653
1647
  #############
@@ -1658,8 +1652,7 @@ def get_main_parser() -> ArgumentParser:
1658
1652
  prog="onnx_diagnostic",
1659
1653
  description="onnx_diagnostic main command line.\n",
1660
1654
  formatter_class=RawTextHelpFormatter,
1661
- epilog=textwrap.dedent(
1662
- """
1655
+ epilog=textwrap.dedent("""
1663
1656
  Type 'python -m onnx_diagnostic <cmd> --help'
1664
1657
  to get help for a specific command.
1665
1658
 
@@ -1670,13 +1663,13 @@ def get_main_parser() -> ArgumentParser:
1670
1663
  find - find node consuming or producing a result
1671
1664
  lighten - makes an onnx model lighter by removing the weights
1672
1665
  optimize - optimizes an onnx model
1666
+ partition - partition a model, each partition appears as local function
1673
1667
  print - prints the model on standard output
1674
1668
  sbs - compares an exported program and a onnx model
1675
1669
  stats - produces statistics on a model
1676
1670
  unlighten - restores an onnx model produces by the previous experiment
1677
1671
  validate - validate a model (knowing its model id on HuggginFace Hub)
1678
- """
1679
- ),
1672
+ """),
1680
1673
  )
1681
1674
  parser.add_argument(
1682
1675
  "cmd",
@@ -1688,6 +1681,7 @@ def get_main_parser() -> ArgumentParser:
1688
1681
  "find",
1689
1682
  "lighten",
1690
1683
  "optimize",
1684
+ "partition",
1691
1685
  "print",
1692
1686
  "sbs",
1693
1687
  "stats",
@@ -1709,6 +1703,7 @@ def main(argv: Optional[List[Any]] = None):
1709
1703
  find=_cmd_find,
1710
1704
  lighten=_cmd_lighten,
1711
1705
  optimize=_cmd_optimize,
1706
+ partition=_cmd_partition,
1712
1707
  print=_cmd_print,
1713
1708
  sbs=_cmd_sbs,
1714
1709
  stats=_cmd_stats,
@@ -1736,6 +1731,7 @@ def main(argv: Optional[List[Any]] = None):
1736
1731
  find=get_parser_find,
1737
1732
  lighten=get_parser_lighten,
1738
1733
  optimize=get_parser_optimize,
1734
+ partition=get_parser_partition,
1739
1735
  print=get_parser_print,
1740
1736
  sbs=get_parser_sbs,
1741
1737
  stats=get_parser_stats,
@@ -1033,16 +1033,14 @@ def main(
1033
1033
  if __name__ == "__main__":
1034
1034
  parser = get_parser(
1035
1035
  "qwen25",
1036
- epilog=textwrap.dedent(
1037
- r"""
1036
+ epilog=textwrap.dedent(r"""
1038
1037
  Tested command lines::
1039
1038
 
1040
1039
  python -m onnx_diagnostic.ci_models.export_phi4_mm \
1041
1040
  -m microsoft/Phi-4-multimodal-instruct \
1042
1041
  --device cuda --dtype float16 --exporter custom \
1043
1042
  --pretrained --second-input --part vision
1044
- """
1045
- ),
1043
+ """),
1046
1044
  )
1047
1045
  args = parser.parse_args(sys.argv[1:])
1048
1046
  main(
@@ -509,12 +509,10 @@ class WrapperToExportMethodToOnnx(torch.nn.Module):
509
509
  simple_sig = inspect.Signature(params, return_annotation=inspect._empty)
510
510
  args = str(simple_sig)[1:-1]
511
511
  calls_args = ", ".join(f"{p}={p}" for p in simple_sig.parameters)
512
- src = textwrap.dedent(
513
- f"""
512
+ src = textwrap.dedent(f"""
514
513
  def f(self, {args}):
515
514
  return self._method_call({calls_args})
516
- """
517
- )
515
+ """)
518
516
  self._method_src = src
519
517
  ns = {}
520
518
  try:
@@ -80,6 +80,7 @@ def compare_modules(
80
80
  )
81
81
  got = modep(*_get(args), **_get(kwargs))
82
82
  if verbose:
83
+ # pyrefly: ignore[unbound-name]
83
84
  d = time.perf_counter() - begin
84
85
  print(f"[compare_modules] done in {d} with output={string_type(got, with_shape=True)}")
85
86
  if mod:
@@ -89,6 +90,7 @@ def compare_modules(
89
90
  expected = mod(*_get(args), **_get(kwargs))
90
91
  diff = max_diff(expected, got)
91
92
  if verbose:
93
+ # pyrefly: ignore[unbound-name]
92
94
  d = time.perf_counter() - begin
93
95
  print(
94
96
  f"[compare_modules] done in {d} with "