onnx-diagnostic 0.8.3__py3-none-any.whl → 0.8.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26) hide show
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +47 -10
  3. onnx_diagnostic/export/api.py +81 -50
  4. onnx_diagnostic/export/control_flow_research.py +10 -5
  5. onnx_diagnostic/export/onnx_plug.py +250 -61
  6. onnx_diagnostic/ext_test_case.py +99 -53
  7. onnx_diagnostic/helpers/dot_helper.py +37 -25
  8. onnx_diagnostic/helpers/helper.py +44 -38
  9. onnx_diagnostic/helpers/onnx_helper.py +441 -18
  10. onnx_diagnostic/helpers/ort_session.py +8 -8
  11. onnx_diagnostic/helpers/torch_helper.py +28 -2
  12. onnx_diagnostic/reference/ort_evaluator.py +6 -29
  13. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +1 -0
  14. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +10 -1
  15. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +168 -113
  16. onnx_diagnostic/torch_models/code_sample.py +2 -1
  17. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  18. onnx_diagnostic/torch_models/validate.py +14 -1
  19. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  20. onnx_diagnostic/torch_onnx/sbs.py +11 -5
  21. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +48 -4
  22. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/METADATA +1 -1
  23. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/RECORD +26 -26
  24. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/WHEEL +0 -0
  25. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/licenses/LICENSE.txt +0 -0
  26. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/top_level.txt +0 -0
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.8.3"
6
+ __version__ = "0.8.5"
7
7
  __author__ = "Xavier Dupré"
@@ -198,15 +198,19 @@ def get_parser_print() -> ArgumentParser:
198
198
  )
199
199
  parser.add_argument(
200
200
  "fmt",
201
- choices=["pretty", "raw", "text", "printer"],
201
+ choices=["dot", "pretty", "printer", "raw", "shape", "text"],
202
202
  default="pretty",
203
203
  help=textwrap.dedent(
204
204
  """
205
205
  Prints out a model on the standard output.
206
- raw - just prints the model with print(...)
207
- printer - onnx.printer.to_text(...)
206
+
207
+ dot - converts the graph into dot
208
208
  pretty - an improved rendering
209
+ printer - onnx.printer.to_text(...)
210
+ raw - just prints the model with print(...)
211
+ shape - prints every node with input and output shapes
209
212
  text - uses GraphRendering
213
+
210
214
  """.strip(
211
215
  "\n"
212
216
  )
@@ -232,6 +236,14 @@ def _cmd_print(argv: List[Any]):
232
236
  from .helpers.graph_helper import GraphRendering
233
237
 
234
238
  print(GraphRendering(onx).text_rendering())
239
+ elif args.fmt == "shape":
240
+ from experimental_experiment.xbuilder import GraphBuilder
241
+
242
+ print(GraphBuilder(onx).pretty_text())
243
+ elif args.fmt == "dot":
244
+ from .helpers.dot_helper import to_dot
245
+
246
+ print(to_dot(onx))
235
247
  else:
236
248
  raise ValueError(f"Unexpected value fmt={args.fmt!r}")
237
249
 
@@ -517,12 +529,12 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
517
529
  nargs="*",
518
530
  help=textwrap.dedent(
519
531
  """
520
- Applies patches before exporting, it can be a boolean
521
- to enable to disable the patches or be more finetuned
522
- (default is True). It is possible to disable patch for torch
523
- by adding:
524
- --patch "patch_sympy=False" --patch "patch_torch=False"
525
- """.strip(
532
+ Applies patches before exporting, it can be a boolean
533
+ to enable to disable the patches or be more finetuned
534
+ (default is True). It is possible to disable patch for torch
535
+ by adding:
536
+ --patch "patch_sympy=False" --patch "patch_torch=False"
537
+ """.strip(
526
538
  "\n"
527
539
  )
528
540
  ),
@@ -1286,7 +1298,13 @@ def get_parser_sbs() -> ArgumentParser:
1286
1298
  "--first",
1287
1299
  action=BooleanOptionalAction,
1288
1300
  default=False,
1289
- help="First runs the whole model.",
1301
+ help="First runs the whole model (default is False).",
1302
+ )
1303
+ parser.add_argument(
1304
+ "--sbs",
1305
+ action=BooleanOptionalAction,
1306
+ default=True,
1307
+ help="Runs the side-by-side (default is True).",
1290
1308
  )
1291
1309
  parser.add_argument(
1292
1310
  "-2",
@@ -1342,6 +1360,20 @@ def get_parser_sbs() -> ArgumentParser:
1342
1360
  default="replay",
1343
1361
  help="If the replay is triggered, this defines the folder where everything is dumped.",
1344
1362
  )
1363
+ parser.add_argument(
1364
+ "-p",
1365
+ "--replay-prefix-model",
1366
+ action=BooleanOptionalAction,
1367
+ default=False,
1368
+ help=textwrap.dedent(
1369
+ """
1370
+ There are two ways to recompute an intermediate output, the first one is to
1371
+ produce the minimal model between torch and onnx.
1372
+ The second one is to dump onnx models from the inputs
1373
+ to the considered intermediate results. This enables the second one.
1374
+ """
1375
+ ),
1376
+ )
1345
1377
 
1346
1378
  return parser
1347
1379
 
@@ -1417,6 +1449,10 @@ def _cmd_sbs(argv: List[Any]):
1417
1449
  print("-- done")
1418
1450
  del sess
1419
1451
 
1452
+ if not args.sbs:
1453
+ print("-- done")
1454
+ return
1455
+
1420
1456
  print(f"-- load onnx {args.onnx!r}")
1421
1457
  begin = time.perf_counter()
1422
1458
  onx = onnx.load(args.onnx)
@@ -1431,6 +1467,7 @@ def _cmd_sbs(argv: List[Any]):
1431
1467
  set(args.replay_op_types.split(",")) if args.replay_op_types else None
1432
1468
  ),
1433
1469
  dump_folder=args.replay_folder,
1470
+ dump_prefix_model=args.replay_prefix_model,
1434
1471
  )
1435
1472
 
1436
1473
  print("-- starts side-by-side")
@@ -3,6 +3,52 @@ import torch
3
3
  from .onnx_plug import EagerDirectReplacementWithOnnx
4
4
 
5
5
 
6
+ def get_main_dispatcher(
7
+ use_control_flow_dispatcher: bool = False,
8
+ onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
9
+ ) -> Any: # Dispatcher
10
+ """Creates a custom dispatcher for the custom exporter."""
11
+ from experimental_experiment.torch_interpreter import Dispatcher
12
+
13
+ if use_control_flow_dispatcher:
14
+ from .control_flow_onnx import create_global_dispatcher
15
+
16
+ control_flow_dispatcher = create_global_dispatcher()
17
+ else:
18
+ control_flow_dispatcher = None
19
+
20
+ class MainDispatcher(Dispatcher):
21
+ def __init__(self, previous_dispatcher=None):
22
+ super().__init__({})
23
+ self.previous_dispatcher = previous_dispatcher
24
+
25
+ @property
26
+ def supported(self):
27
+ if self.previous_dispatcher:
28
+ return set(self.registered_functions) | self.previous_dispatcher.supported
29
+ return set(self.registered_functions)
30
+
31
+ def find_function(self, name: Any):
32
+ if self.previous_dispatcher:
33
+ find = self.previous_dispatcher.find_function(name)
34
+ if find:
35
+ return find
36
+ return Dispatcher.find_function(self, name)
37
+
38
+ def find_method(self, name: Any):
39
+ if self.previous_dispatcher:
40
+ find = self.previous_dispatcher.find_method(name)
41
+ if find:
42
+ return find
43
+ return Dispatcher.find_method(self, name)
44
+
45
+ main_dispatcher = MainDispatcher(control_flow_dispatcher)
46
+ if onnx_plugs:
47
+ for plug in onnx_plugs:
48
+ main_dispatcher.registered_functions[plug.target_name] = plug.custom_converter()
49
+ return main_dispatcher
50
+
51
+
6
52
  def to_onnx(
7
53
  mod: Union["torch.nn.Module", "torch.fx.GraphModule"], # noqa: F821
8
54
  args: Optional[Sequence["torch.Tensor"]] = None, # noqa: F821
@@ -18,6 +64,7 @@ def to_onnx(
18
64
  exporter_kwargs: Optional[Dict[str, Any]] = None,
19
65
  save_ep: Optional[str] = None,
20
66
  optimize: bool = True,
67
+ optimizer_for_ort: bool = True,
21
68
  use_control_flow_dispatcher: bool = False,
22
69
  onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
23
70
  inline: bool = True,
@@ -42,6 +89,7 @@ def to_onnx(
42
89
  :param exporter_kwargs: additional parameters sent to the exporter
43
90
  :param save_ep: saves the exported program
44
91
  :param optimize: optimizes the model
92
+ :param optimizer_for_ort: optimizes the model for onnxruntime
45
93
  :param use_control_flow_dispatcher: use the dispatcher created to supported
46
94
  custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
47
95
  :param onnx_plugs: the code was modified to replace some parts with onnx translation
@@ -80,53 +128,15 @@ def to_onnx(
80
128
  options = None
81
129
  if exporter_kwargs is not None:
82
130
  options = exporter_kwargs.pop("options", None)
83
- if options is None:
84
- options = OptimizationOptions(patterns="default+onnxruntime")
85
- if onnx_plugs or use_control_flow_dispatcher:
86
- from experimental_experiment.torch_interpreter import Dispatcher
87
-
88
- if use_control_flow_dispatcher:
89
- from .control_flow_onnx import create_global_dispatcher
90
-
91
- control_flow_dispatcher = create_global_dispatcher()
92
- else:
93
- control_flow_dispatcher = None
94
-
95
- class MainDispatcher(Dispatcher):
96
- def __init__(self, previous_dispatcher=None):
97
- super().__init__({})
98
- self.previous_dispatcher = previous_dispatcher
99
-
100
- @property
101
- def supported(self):
102
- if self.previous_dispatcher:
103
- return (
104
- set(self.registered_functions) | self.previous_dispatcher.supported
105
- )
106
- return set(self.registered_functions)
107
-
108
- def find_function(self, name: Any):
109
- if self.previous_dispatcher:
110
- find = self.previous_dispatcher.find_function(name)
111
- if find:
112
- return find
113
- return Dispatcher.find_function(self, name)
114
-
115
- def find_method(self, name: Any):
116
- if self.previous_dispatcher:
117
- find = self.previous_dispatcher.find_method(name)
118
- if find:
119
- return find
120
- return Dispatcher.find_method(self, name)
121
-
122
- main_dispatcher = MainDispatcher(control_flow_dispatcher)
123
- if onnx_plugs:
124
- for plug in onnx_plugs:
125
- main_dispatcher.registered_functions[plug.target_name] = (
126
- plug.custom_converter()
127
- )
128
- else:
129
- main_dispatcher = None
131
+ if options is None and optimize:
132
+ options = OptimizationOptions(
133
+ patterns="default+onnxruntime" if optimizer_for_ort else "default"
134
+ )
135
+ main_dispatcher = (
136
+ get_main_dispatcher(use_control_flow_dispatcher, onnx_plugs)
137
+ if onnx_plugs or use_control_flow_dispatcher
138
+ else None
139
+ )
130
140
 
131
141
  return _to_onnx(
132
142
  mod,
@@ -149,11 +159,15 @@ def to_onnx(
149
159
 
150
160
  if exporter in ("dynamo", "onnx-dynamo"):
151
161
  import os
162
+ from ..helpers import flatten_object
152
163
  import onnxscript.rewriter.ort_fusions as ort_fusions
153
164
 
154
165
  assert (
155
166
  not output_dynamic_shapes
156
167
  ), f"output_dynamic_shapes not supported for exporter={exporter!r}"
168
+ assert (
169
+ optimize
170
+ ), f"torch.onnx.export always optimizes the model but optimize={optimize}"
157
171
  custom_translation_table = {}
158
172
  if onnx_plugs:
159
173
  for plug in onnx_plugs:
@@ -173,21 +187,34 @@ def to_onnx(
173
187
  custom_translation_table=custom_translation_table,
174
188
  **(exporter_kwargs or {}),
175
189
  )
176
- if not inline and optimize:
190
+ if not inline and optimize and optimizer_for_ort:
177
191
  ort_fusions.optimize_for_ort(epo.model)
178
192
 
179
193
  if onnx_plugs:
180
194
  import onnx_ir as ir
181
195
  import onnx_ir.passes.common as common_passes
182
196
 
183
- irfunctions = [ir.from_proto(plug.function_proto) for plug in onnx_plugs]
197
+ opset = (
198
+ 18
199
+ if target_opset is None
200
+ else (target_opset if isinstance(target_opset, int) else target_opset[""])
201
+ )
202
+
203
+ irfunctions = [
204
+ ir.from_proto(
205
+ plug.get_function_proto(
206
+ opset, *flatten_object((args, kwargs), drop_keys=True)
207
+ )
208
+ )
209
+ for plug in onnx_plugs
210
+ ]
184
211
  for func in irfunctions:
185
212
  epo.model.functions[func.identifier()] = func
186
213
  if inline:
187
214
  common_passes.InlinePass()(epo.model)
188
215
  common_passes.RemoveUnusedOpsetsPass()(epo.model)
189
216
 
190
- if inline and optimize:
217
+ if inline and optimize and optimizer_for_ort:
191
218
  ort_fusions.optimize_for_ort(epo.model)
192
219
  if filename:
193
220
  epo.save(filename, external_data=True)
@@ -212,6 +239,10 @@ def to_onnx(
212
239
  f"Only a specified set of inputs is supported for exporter={exporter!r}, "
213
240
  f"but it is {list(kwargs)}" # type: ignore[arg-type]
214
241
  )
242
+ assert optimizer_for_ort and optimize, (
243
+ f"ModelBuilder only produces model optimized for onnxruntime but "
244
+ f"optimizer_for_ort={optimizer_for_ort} and optimize={optimize}"
245
+ )
215
246
  flat_inputs = flatten_object(kwargs, drop_keys=True)
216
247
  first = flat_inputs[0]
217
248
  first_float = [
@@ -92,10 +92,11 @@ def simple_loop_for(
92
92
 
93
93
  from torch._higher_order_ops.utils import setup_compilation_env
94
94
 
95
- with setup_compilation_env() as backend:
96
- return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)(
97
- n_iter, body_fn, operands
98
- )
95
+ with setup_compilation_env() as _backend:
96
+ return _loop_for_op_wrapper(n_iter, body_fn, *operands)
97
+ # return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)(
98
+ # n_iter, body_fn, operands
99
+ # )
99
100
 
100
101
 
101
102
  def trace_loop_for(proxy_mode, func_overload, n_iter, body_fn, operands):
@@ -127,9 +128,13 @@ def loop_for_op_dense(n_iter, body_fn, operands):
127
128
  ), f"Dense implementation operands must be a list of tensors and ints {operands}"
128
129
  mode = _get_current_dispatch_mode()
129
130
  assert mode is None, "Mode should never be enabled for CPU/CUDA key"
130
- return _loop_for_onnx_fn(body_fn, n_iter, None, *operands)
131
+ return _loop_for_onnx_fn(body_fn, n_iter, None, operands)
131
132
 
132
133
 
133
134
  @simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
134
135
  def inner(mode, n_iter, body_fn, operands):
135
136
  return trace_loop_for(mode, simple_loop_for_op, n_iter, body_fn, operands)
137
+
138
+
139
+ simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU)
140
+ simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA)