onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43):
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +387 -12
  3. onnx_diagnostic/export/api.py +91 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +3 -3
  7. onnx_diagnostic/export/onnx_plug.py +396 -0
  8. onnx_diagnostic/ext_test_case.py +92 -23
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +210 -0
  11. onnx_diagnostic/helpers/helper.py +90 -26
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +103 -1
  15. onnx_diagnostic/helpers/ort_session.py +37 -11
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +103 -6
  18. onnx_diagnostic/reference/ort_evaluator.py +233 -28
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/validate.py +50 -1
  37. onnx_diagnostic/torch_onnx/sbs.py +963 -312
  38. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
  39. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
  40. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +43 -24
  41. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
@@ -1,42 +1,18 @@
1
1
  import contextlib
2
- import inspect
3
- from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
4
- import onnx
5
- import onnx.helper as oh
2
+ from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
6
3
  import torch
7
- from torch._higher_order_ops.utils import materialize_as_graph
8
- from torch._higher_order_ops.utils import check_input_alias_and_mutation_return_outputs
9
- from .api import to_onnx
4
+ from torch._higher_order_ops.utils import (
5
+ materialize_as_graph,
6
+ check_input_alias_and_mutation_return_outputs,
7
+ # _maybe_reenter_make_fx,
8
+ )
10
9
 
11
10
  _TEST_EXPORT = False
12
- _REGISTERED_SCHEMA = {} # type: ignore[var-annotated]
13
- _DISPATCHER = None
14
-
15
-
16
- def create_global_dispatcher():
17
- global _DISPATCHER
18
-
19
- if not _DISPATCHER:
20
- from experimental_experiment.torch_interpreter import Dispatcher
21
-
22
- class ControlFlowDispatcher(Dispatcher):
23
- def __init__(self):
24
- super().__init__({})
25
-
26
- def register(self, aten_name: str, converter: Callable):
27
- assert aten_name not in self.registered_functions, (
28
- f"Name {aten_name!r} is already registered in "
29
- f"{sorted(self.registered_functions)}"
30
- )
31
- self.registered_functions[aten_name] = converter
32
-
33
- _DISPATCHER = ControlFlowDispatcher()
34
- return _DISPATCHER
35
11
 
36
12
 
37
13
  @contextlib.contextmanager
38
14
  def enable_code_export_control_flow():
39
- """Enables the code means to be exported."""
15
+ """Enables the code meant to be exported."""
40
16
  global _TEST_EXPORT
41
17
  old = _TEST_EXPORT
42
18
  _TEST_EXPORT = True
@@ -128,194 +104,31 @@ def make_custom_loop_for(
128
104
  :return: a name and the custom op definition, the name
129
105
  is used to cache the custom op
130
106
  """
131
- global _DISPATCHER
132
107
  assert body_gm is not None, "body_gm cannot be None"
133
108
  assert body_mutated_inputs is not None, "body_mutated_inputs cannot be None"
134
109
  assert body_outputs is not None, "body_outputs cannot be None"
110
+
135
111
  srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
136
112
  sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
137
- name = f"loop_for_{body_fn.__name__}_{id(body_fn)}_{srank}_{sred}"
138
- if name in _REGISTERED_SCHEMA:
139
- return name, _REGISTERED_SCHEMA[name][0]
140
- sig = inspect.signature(body_fn)
141
- inputs = ", ".join([f"Tensor {p}" for p in sig.parameters])
142
- schema = f"({inputs}) -> Tensor"
113
+ full_name = (
114
+ body_fn.__qualname__.replace("<locals>", "L")
115
+ .replace("<lambda>", "l")
116
+ .replace(".", "_")
117
+ )
118
+ name = f"loop_for_onnx_{full_name}_{srank}_{sred}"
119
+
120
+ schema = "(str body_fn, Tensor n_iter, Tensor[] body_inputs) -> Tensor"
143
121
  if len(body_outputs) > 1:
144
122
  schema += "[]"
145
- custom_def = torch.library.CustomOpDef("onnx_higher_ops", name, schema, body_fn)
123
+ custom_def = torch.library.CustomOpDef("onnx_higher_ops", "loop_for", schema, body_fn)
146
124
  custom_def.register_kernel("cpu")(body_fn)
147
125
 
148
- custom_def._abstract_fn = lambda *_args, _o=body_outputs: (
126
+ custom_def._abstract_fn = lambda _fn_id, *_args, _o=body_outputs: (
149
127
  tuple([torch.empty_like(s) for s in _o]) if len(_o) > 1 else torch.empty_like(_o[0])
150
128
  )
151
-
152
- def _make_onx(
153
- body_gm=body_gm, args=args, target_opset=None, verbose=0, exporter_kwargs=None
154
- ):
155
- return convert_into_onnx(
156
- body_gm,
157
- args,
158
- exporter_kwargs=exporter_kwargs,
159
- target_opset=target_opset,
160
- verbose=verbose,
161
- )
162
-
163
- to_register = (
164
- custom_def,
165
- _make_onx,
166
- (
167
- lambda g, sts, outputs, *args, bc=_make_onx, rd=reduction_dim, name=name: (
168
- convert_custom_loop_into_onnx(
169
- g,
170
- sts,
171
- outputs,
172
- *args,
173
- body_callable=bc,
174
- reduction_dim=rd,
175
- name=name,
176
- )
177
- )
178
- ),
179
- )
180
- if _DISPATCHER is None:
181
- create_global_dispatcher()
182
- assert _DISPATCHER
183
- _DISPATCHER.register(f"onnx_higher_ops::{name}", to_register[-1])
184
- _REGISTERED_SCHEMA[name] = to_register
185
129
  return name, custom_def
186
130
 
187
131
 
188
- def convert_custom_loop_into_onnx(
189
- g: Any, # "GreaphBuilder"
190
- sts: Dict[str, Any],
191
- outputs: List[str],
192
- *args: str,
193
- body_callable: Callable[..., onnx.ModelProto],
194
- reduction_dim: Optional[Sequence[int]] = None,
195
- name: str = "loop_for",
196
- ) -> Union[str, List[str]]:
197
- """
198
- Converts a custom op ``higher_ops::loop_for...`` into e sequence of node.
199
-
200
- :param g: GreaphBuilder
201
- :param sts: if not defined, torch does not know the output shapes
202
- :param outputs: output names
203
- :param args: input argument known at export time
204
- :param body: GraphProto, the loop body
205
- :param reduction_dim: the dimension to follow when aggregating the
206
- list of tensors after the loop ran
207
- :param name: to give the onnx nodes a name
208
- :return: output names
209
- """
210
- assert body_callable is not None, "body_callable cannot be None"
211
- # This should be part of a public API.
212
- body = body_callable(
213
- target_opset=g.main_opset,
214
- verbose=g.verbose,
215
- exporter_kwargs={"options": g.optimization_options},
216
- )
217
-
218
- graph = body.graph if isinstance(body, onnx.ModelProto) else body
219
- assert isinstance(
220
- graph, onnx.GraphProto
221
- ), f"Unexpected type {type(body)} for body{g.get_debug_msg()}"
222
- assert len(outputs) == 1, f"Only one outputs is expected but outputs={outputs!r}"
223
- if len(graph.output) != 1:
224
- outputs = [f"{outputs[0]}#{i}" for i in range(len(graph.output))]
225
- input_names = [i.name for i in graph.input]
226
- inputs = [
227
- *graph.input[:1],
228
- oh.make_tensor_value_info("cond_unused", onnx.TensorProto.BOOL, []),
229
- *[
230
- oh.make_tensor_sequence_value_info(
231
- f"loop_in{i}", graph.output[i].type.tensor_type.elem_type, None
232
- )
233
- for i in range(len(graph.output))
234
- ],
235
- # hidden inputs are not added
236
- ]
237
- nodes = [
238
- oh.make_node("Identity", ["cond_unused"], ["cond_out"]),
239
- *[oh.make_node("Identity", [a], [r]) for a, r in zip(args[1:], input_names[1:])],
240
- *graph.node,
241
- *[
242
- oh.make_node(
243
- "SequenceInsert",
244
- [f"loop_in{i}", graph.output[i].name],
245
- [f"loop_out{i}"],
246
- )
247
- for i in range(len(graph.output))
248
- ],
249
- ]
250
- graph_outputs = [
251
- oh.make_tensor_value_info("cond_out", onnx.TensorProto.BOOL, []),
252
- *[
253
- oh.make_tensor_sequence_value_info(
254
- f"loop_out{i}", graph.output[i].type.tensor_type.elem_type, None
255
- )
256
- for i in range(len(graph.output))
257
- ],
258
- ]
259
- graph = oh.make_graph(
260
- nodes, graph.name, inputs, graph_outputs, graph.initializer, graph.sparse_initializer
261
- )
262
-
263
- sequences = [g.op.SequenceEmpty() for _ in outputs]
264
-
265
- outloop = [g.unique_name(f"loop_for{i}") for i in range(len(sequences))]
266
-
267
- for i, s in enumerate(sequences):
268
- g.set_sequence(s, graph.output[i].type.tensor_type.elem_type)
269
- g.make_node("Loop", [args[0], "", *sequences], outloop, name=name, body=graph)
270
- for i, o in enumerate(outloop):
271
- g.set_sequence(o, graph.output[i].type.tensor_type.elem_type)
272
- _res = [
273
- g.op.ConcatFromSequence(
274
- out,
275
- outputs=[o],
276
- name=name,
277
- axis=0 if not reduction_dim or i >= len(reduction_dim) else reduction_dim[i],
278
- )
279
- for i, (out, o) in enumerate(zip(outloop, outputs))
280
- ]
281
- if not sts:
282
- for i, o in enumerate(outputs):
283
- g.set_type(o, graph.output[i].type.tensor_type.elem_type)
284
- g.set_rank(o, len(graph.output[i].type.tensor_type.shape.dims))
285
- return outputs if len(outputs) > 1 else outputs[0]
286
-
287
-
288
- def convert_into_onnx(
289
- body_gm: torch.fx.GraphModule,
290
- args: Sequence[torch.Tensor],
291
- target_opset: Optional[int] = None,
292
- verbose: int = 0,
293
- exporter_kwargs: Optional[Dict[str, Any]] = None,
294
- ) -> onnx.ModelProto:
295
- """
296
- Converts a torch.fx.GraphModule into ONNX.
297
- It returns a ModelProto.
298
-
299
- :param body_gm: a torch.fx.GraphModule
300
- :param args: arguments known at export time
301
- :param target_opset: targeted opset
302
- :param verbose: verbosity level
303
- :param exporter_kwargs: additional exporter arguments
304
- :return: a ModelProto
305
- """
306
- # This does not work with onnx-dynamo.
307
- # opset still needs to be defined
308
- container = to_onnx(
309
- body_gm,
310
- args,
311
- exporter="custom",
312
- exporter_kwargs=exporter_kwargs,
313
- target_opset=target_opset,
314
- verbose=verbose,
315
- )
316
- return container.model_proto
317
-
318
-
319
132
  def loop_for(
320
133
  n_iter: Union[torch.SymInt, torch.Tensor],
321
134
  body_fn: Callable[..., Tuple[torch.Tensor]],
@@ -340,144 +153,6 @@ def loop_for(
340
153
  one of each output, each of them is concatenated into one
341
154
  tensor along one dimension, by default, it is the first
342
155
  dimension, but it can be defined otherwise
343
-
344
- .. runpython::
345
- :showcode:
346
-
347
- import torch
348
- import onnxruntime
349
- from onnx_diagnostic.export.api import to_onnx
350
- from onnx_diagnostic.export.control_flow import loop_for
351
-
352
-
353
- class Model(torch.nn.Module):
354
- def forward(self, n_iter, x):
355
- def body(i, x):
356
- return x[: i.item() + 1].unsqueeze(1)
357
-
358
- return loop_for(n_iter, body, (x,))
359
-
360
-
361
- model = Model()
362
- n_iter = torch.tensor(4, dtype=torch.int64)
363
- x = torch.arange(10, dtype=torch.float32)
364
- expected = model(n_iter, x)
365
- print("expected:", expected)
366
-
367
- onx = to_onnx(
368
- model,
369
- (n_iter, x),
370
- dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
371
- exporter="custom",
372
- use_control_flow_dispatcher=True,
373
- ).model_proto
374
-
375
- sess = onnxruntime.InferenceSession(
376
- onx.SerializeToString(), providers=["CPUExecutionProvider"]
377
- )
378
- got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
379
- print("got:", got)
380
-
381
-
382
- # The loop is exported as a custom ops.
383
- ep = torch.export.export(
384
- model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
385
- )
386
- print(ep)
387
-
388
- Another example with two outputs:
389
-
390
- .. runpython::
391
- :showcode:
392
-
393
- import torch
394
- import onnxruntime
395
- from onnx_diagnostic.export.api import to_onnx
396
- from onnx_diagnostic.export.control_flow import loop_for
397
-
398
-
399
- class Model(torch.nn.Module):
400
- def forward(self, n_iter, x):
401
- def body(i, x):
402
- return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1
403
-
404
- two = loop_for(n_iter, body, (x,))
405
- return two[0] + two[1]
406
-
407
-
408
- model = Model()
409
- n_iter = torch.tensor(4, dtype=torch.int64)
410
- x = torch.arange(10, dtype=torch.float32)
411
- expected = model(n_iter, x)
412
- print("expected:", expected)
413
-
414
- onx = to_onnx(
415
- model,
416
- (n_iter, x),
417
- dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
418
- exporter="custom",
419
- use_control_flow_dispatcher=True,
420
- ).model_proto
421
-
422
- sess = onnxruntime.InferenceSession(
423
- onx.SerializeToString(), providers=["CPUExecutionProvider"]
424
- )
425
- got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
426
- print("got:", got)
427
-
428
-
429
- # The loop is exported as a custom ops.
430
- ep = torch.export.export(
431
- model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
432
- )
433
- print(ep)
434
-
435
- A last example with ``reduction_dim``:
436
-
437
- .. runpython::
438
- :showcode:
439
-
440
- import torch
441
- import onnxruntime
442
- from onnx_diagnostic.export.api import to_onnx
443
- from onnx_diagnostic.export.control_flow import loop_for
444
-
445
-
446
- class Model(torch.nn.Module):
447
- def forward(self, n_iter, x):
448
- def body(i, x):
449
- return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1
450
-
451
- two = loop_for(n_iter, body, (x,), reduction_dim=[0, 1])
452
- return two[0] + two[1].T
453
-
454
-
455
- model = Model()
456
- n_iter = torch.tensor(4, dtype=torch.int64)
457
- x = torch.arange(10, dtype=torch.float32)
458
- expected = model(n_iter, x)
459
- print("expected:", expected)
460
-
461
- onx = to_onnx(
462
- model,
463
- (n_iter, x),
464
- dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
465
- exporter="custom",
466
- use_control_flow_dispatcher=True,
467
- ).model_proto
468
-
469
- sess = onnxruntime.InferenceSession(
470
- onx.SerializeToString(), providers=["CPUExecutionProvider"]
471
- )
472
- got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
473
- print("got:", got)
474
-
475
-
476
- # The loop is exported as a custom ops.
477
- ep = torch.export.export(
478
- model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
479
- )
480
- print(ep)
481
156
  """
482
157
  assert args, "The function should have at least one arg."
483
158
  assert (
@@ -486,6 +161,12 @@ def loop_for(
486
161
  and len(n_iter.shape) == 0
487
162
  ), f"Only a tensor for one int64 is allowed for n_iter but it equal to {n_iter}."
488
163
  if is_exporting():
164
+ from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER
165
+
166
+ # tracer = _CURRENT_MAKE_FX_TRACER.fx_tracer
167
+ root = _CURRENT_MAKE_FX_TRACER.fx_tracer.root
168
+ # graph = _CURRENT_MAKE_FX_TRACER.fx_tracer.graph
169
+
489
170
  body_gm: torch.fx.GraphModule = materialize_as_graph(
490
171
  body_fn, (torch.tensor(0, dtype=torch.int64), *args)
491
172
  )
@@ -505,7 +186,29 @@ def loop_for(
505
186
  body_mutated_inputs=body_mutated_inputs,
506
187
  body_outputs=body_outputs,
507
188
  )
508
- fct = getattr(torch.ops.onnx_higher_ops, name)
509
- return fct(n_iter, *args)
189
+ root.register_module(name, body_gm)
190
+ # body_graph = _maybe_reenter_make_fx(body_fn)(n_iter, *args)
191
+ return torch.ops.onnx_higher_ops.loop_for(name, n_iter, args)
510
192
 
511
193
  return _loop_for_fn(n_iter, body_fn, reduction_dim, args)
194
+
195
+
196
+ """
197
+ proxy_mode.tracer.root.register_module(cond_graph_name, cond_graph)
198
+ proxy_mode.tracer.root.register_module(body_graph_name, body_graph)
199
+
200
+ args = (cond_graph, body_graph, carried_inputs, additional_inputs)
201
+
202
+ proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
203
+
204
+ out_proxy = proxy_mode.tracer.create_proxy(
205
+ "call_function", op, proxy_args, {}, name=op._name
206
+ )
207
+
208
+ out = op(
209
+ cond_graph, body_graph, unspecialized_carried_inputs, additional_inputs
210
+ )
211
+ return track_tensor_tree(
212
+ out, out_proxy, constant=None, tracer=proxy_mode.tracer
213
+ )
214
+ """