onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
Files changed (51)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +387 -12
  3. onnx_diagnostic/export/api.py +118 -5
  4. onnx_diagnostic/export/control_flow.py +214 -0
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +135 -0
  7. onnx_diagnostic/export/onnx_plug.py +396 -0
  8. onnx_diagnostic/ext_test_case.py +118 -25
  9. onnx_diagnostic/helpers/cache_helper.py +218 -204
  10. onnx_diagnostic/helpers/dot_helper.py +210 -0
  11. onnx_diagnostic/helpers/helper.py +92 -26
  12. onnx_diagnostic/helpers/log_helper.py +26 -4
  13. onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
  14. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  15. onnx_diagnostic/helpers/onnx_helper.py +115 -16
  16. onnx_diagnostic/helpers/ort_session.py +37 -11
  17. onnx_diagnostic/helpers/rt_helper.py +547 -0
  18. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  19. onnx_diagnostic/helpers/torch_helper.py +108 -6
  20. onnx_diagnostic/reference/ort_evaluator.py +233 -28
  21. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  22. onnx_diagnostic/tasks/image_text_to_text.py +5 -1
  23. onnx_diagnostic/tasks/summarization.py +72 -137
  24. onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
  26. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  34. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  35. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  36. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
  37. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  38. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  39. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  40. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  41. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
  42. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
  43. onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
  44. onnx_diagnostic/torch_models/validate.py +50 -1
  45. onnx_diagnostic/torch_onnx/sbs.py +963 -312
  46. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
  47. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
  48. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
  49. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
  50. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  51. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
--- a/onnx_diagnostic/helpers/onnx_helper.py
+++ b/onnx_diagnostic/helpers/onnx_helper.py
@@ -3,7 +3,7 @@ import json
 import os
 import sys
 import warnings
-from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 import numpy as np
 import numpy.typing as npt
 import onnx
@@ -15,6 +15,7 @@ from onnx import (
     GraphProto,
     ModelProto,
     NodeProto,
+    OperatorSetIdProto,
     TensorProto,
     ValueInfoProto,
     load as onnx_load,
@@ -671,21 +672,18 @@ def np_dtype_to_tensor_dtype(dt: np.dtype) -> int:  # noqa: F821
     try:
         return oh.np_dtype_to_tensor_dtype(dt)
     except ValueError:
-        try:
-            import ml_dtypes
-        except ImportError:
-            ml_dtypes = None  # type: ignore
-        if ml_dtypes is not None:
-            if dt == ml_dtypes.bfloat16:
-                return TensorProto.BFLOAT16
-            if dt == ml_dtypes.float8_e4m3fn:
-                return TensorProto.FLOAT8E4M3FN
-            if dt == ml_dtypes.float8_e4m3fnuz:
-                return TensorProto.FLOAT8E4M3FNUZ
-            if dt == ml_dtypes.float8_e5m2:
-                return TensorProto.FLOAT8E5M2
-            if dt == ml_dtypes.float8_e5m2fnuz:
-                return TensorProto.FLOAT8E5M2FNUZ
+        import ml_dtypes
+
+        if dt == ml_dtypes.bfloat16:
+            return TensorProto.BFLOAT16
+        if dt == ml_dtypes.float8_e4m3fn:
+            return TensorProto.FLOAT8E4M3FN
+        if dt == ml_dtypes.float8_e4m3fnuz:
+            return TensorProto.FLOAT8E4M3FNUZ
+        if dt == ml_dtypes.float8_e5m2:
+            return TensorProto.FLOAT8E5M2
+        if dt == ml_dtypes.float8_e5m2fnuz:
+            return TensorProto.FLOAT8E5M2FNUZ
     if dt == np.float32:
         return TensorProto.FLOAT
     if dt == np.float16:
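
Note on this hunk: the fallback no longer guards the import, so ml_dtypes becomes a hard dependency of the non-standard-dtype path. A minimal sketch of the resulting mapping, assuming the function is exposed by onnx_diagnostic.helpers.onnx_helper as the file list above suggests:

import numpy as np
import ml_dtypes
from onnx import TensorProto
from onnx_diagnostic.helpers.onnx_helper import np_dtype_to_tensor_dtype

# Regular numpy dtypes are resolved by onnx.helper directly.
assert np_dtype_to_tensor_dtype(np.dtype("float32")) == TensorProto.FLOAT
# Extended dtypes fall through to the ml_dtypes comparisons added here;
# as of 0.8.3 a missing ml_dtypes raises ImportError instead of
# silently skipping these checks.
assert np_dtype_to_tensor_dtype(np.dtype(ml_dtypes.bfloat16)) == TensorProto.BFLOAT16
assert np_dtype_to_tensor_dtype(np.dtype(ml_dtypes.float8_e5m2)) == TensorProto.FLOAT8E5M2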
@@ -1198,3 +1196,104 @@ def shadowing_names(
     existing |= not_empty
     created |= not_empty
     return shadow, post_shadow, created
+
+
+def extract_subset_of_nodes(
+    model: ModelProto,
+    name: str,
+    node_index: Optional[int] = None,
+    cut_points: Optional[Set[str]] = None,
+) -> List[NodeProto]:
+    """
+    Extracts the minimal subgraph which can produce the output ``name``
+    knowing ``cut_points``.
+
+    :param model: original model
+    :param name: result name
+    :param node_index: node index if it is known, otherwise the function searches for it
+    :param cut_points: known results to cut at, defaults to the model inputs and initializers
+    :return: minimal list of nodes
+    """
+    if node_index is None:
+        for i, node in enumerate(model.graph.node):
+            if name in node.output:
+                node_index = i
+                break
+    assert (
+        node_index is not None
+        and node_index < len(model.graph.node)
+        and name in model.graph.node[node_index].output
+    ), f"node_index is still empty or wrong for result {name!r}"
+    if cut_points is None:
+        cut_points = {n.name for n in model.graph.input} | {
+            n.name for n in model.graph.initializer
+        }
+    elif model.graph.initializer:
+        cut_points = cut_points | {n.name for n in model.graph.initializer}
+
+    node = model.graph.node[node_index]
+    selected = {node_index}
+    current_node_index = node_index
+    current_input_index = 0
+    intermediate = {name}
+    inputs = set(k for k in node.input if k)
+    while not (inputs <= cut_points) and current_node_index >= 0:
+        node = model.graph.node[current_node_index]
+        if current_input_index == 0:
+            needs = [o for o in node.output if o in intermediate and o not in cut_points]
+            if needs:
+                selected.add(current_node_index)
+            else:
+                current_node_index -= 1
+                continue
+        res = node.input[current_input_index]
+        if res not in cut_points:
+            intermediate.add(res)
+        current_input_index += 1
+        if current_input_index >= len(node.input):
+            current_node_index -= 1
+            current_input_index = 0
+
+    return [model.graph.node[i] for i in sorted(selected)]
+
+
+def make_submodel(
+    nodes: List[NodeProto],
+    ir_version: int,
+    opset_imports: List[OperatorSetIdProto],
+    output_names: List[str],
+    type_rank_fn: Callable[[str], Tuple[int, int]],
+) -> ModelProto:
+    """
+    Creates a model with the given list of nodes.
+    It computes the minimal list of inputs needed for this model.
+    The function assumes the nodes are sorted.
+    It does not handle subgraphs yet.
+
+    :param nodes: list of nodes
+    :param ir_version: ir version
+    :param opset_imports: opset imports
+    :param output_names: desired outputs
+    :param type_rank_fn: function returning the type and the rank of a result
+    :return: model proto
+    """
+
+    def _mkv_(name, itype, irank):
+        return oh.make_tensor_value_info(name, itype, [f"{name}_d{i}" for i in range(irank)])
+
+    not_known: Set[str] = set()
+    for node in nodes[::-1]:
+        not_known -= set(node.output)
+        not_known |= set(node.input)
+
+    model = oh.make_model(
+        oh.make_graph(
+            nodes,
+            "submodel",
+            [_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known)],
+            [_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names)],
+        ),
+        ir_version=ir_version,
+        opset_imports=opset_imports,
+    )
+    return model
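
The two new helpers are designed to compose: extract_subset_of_nodes walks backwards from a result name to collect the nodes that produce it, and make_submodel wraps that node list into a standalone ModelProto, inferring the required inputs. A hypothetical end-to-end sketch, assuming both helpers are exposed by onnx_diagnostic.helpers.onnx_helper per the file list; the toy graph, tensor names, and the constant type_rank_fn are illustrative, not code from the package:

import onnx
import onnx.helper as oh
from onnx import TensorProto
from onnx_diagnostic.helpers.onnx_helper import extract_subset_of_nodes, make_submodel

# Toy model: X -> Abs -> t1 -> Neg -> t2 -> Exp -> Y (hypothetical example).
model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Abs", ["X"], ["t1"]),
            oh.make_node("Neg", ["t1"], ["t2"]),
            oh.make_node("Exp", ["t2"], ["Y"]),
        ],
        "toy",
        [oh.make_tensor_value_info("X", TensorProto.FLOAT, ["d0"])],
        [oh.make_tensor_value_info("Y", TensorProto.FLOAT, ["d0"])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
)

# Minimal set of nodes producing "t2", cutting at the graph inputs:
# here the Abs and Neg nodes, Exp is dropped.
nodes = extract_subset_of_nodes(model, "t2")

# Rebuild a standalone model around those nodes; every result in this toy
# graph is float32 of rank 1, so the type/rank callback is a constant.
sub = make_submodel(
    nodes,
    ir_version=model.ir_version,
    opset_imports=list(model.opset_import),
    output_names=["t2"],
    type_rank_fn=lambda name: (TensorProto.FLOAT, 1),
)
onnx.checker.check_model(sub)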
--- a/onnx_diagnostic/helpers/ort_session.py
+++ b/onnx_diagnostic/helpers/ort_session.py
@@ -108,7 +108,10 @@ class _InferenceSession:
                 session_options,
                 providers=providers,
             )
-        except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
+        except (
+            onnxruntime.capi.onnxruntime_pybind11_state.Fail,
+            onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
+        ) as e:
             if isinstance(sess, onnx.ModelProto):
                 debug_path = "_debug_InferenceSession_last_failure.onnx"
                 onnx.save(
@@ -134,7 +137,13 @@ class _InferenceSession:
 
         self.sess = sess
         self.input_names = [i.name for i in sess.get_inputs()]
+        assert (
+            "" not in self.input_names
+        ), f"Input name cannot be empty but input_names={self.input_names}"
         self.output_names = [i.name for i in sess.get_outputs()]
+        assert (
+            "" not in self.output_names
+        ), f"Output name cannot be empty but output_names={self.output_names}"
         self.input_shapes = [i.shape for i in sess.get_inputs()]
         self.output_shapes = [i.shape for i in sess.get_outputs()]
         self.input_types = [i.type for i in sess.get_inputs()]
@@ -338,6 +347,7 @@ class InferenceSessionForTorch(_InferenceSession):
     :param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
     :param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions`
     :param use_training_api: use the onnxruntime-training API
+    :param cpu_outputs: if True, force the outputs to be on CPU
     """
 
     def __init__(
@@ -353,6 +363,7 @@ class InferenceSessionForTorch(_InferenceSession):
         optimized_model_filepath: Optional[str] = None,
         disable_aot_function_inlining: Optional[bool] = None,
         use_training_api: Optional[bool] = None,
+        cpu_outputs: bool = False,
     ):
         super().__init__(
             sess,
@@ -367,6 +378,7 @@ class InferenceSessionForTorch(_InferenceSession):
             disable_aot_function_inlining=disable_aot_function_inlining,
             use_training_api=use_training_api,
         )
+        self.cpu_outputs = cpu_outputs
 
     def _get_ortvalues_from_torch_tensors(
         self, tensors: Tuple[torch.Tensor, ...], n_outputs: int
@@ -490,23 +502,37 @@ class InferenceSessionForTorch(_InferenceSession):
         feeds is a dictionary of :class:`torch.Tensor`.
         The output device is CPU even if the outputs are on CUDA.
         """
-        new_feeds = {}
+        input_names = []
+        values = ORTC.OrtValueVector()
+        device = -1
         for k, v in feeds.items():
+            assert k != "", f"Input cannot be empty but feeds names={list(feeds)}"
+            device = max(device, v.get_device())
             assert hasattr(v, "__dlpack__"), f"class {type(v)} should be serialized"
             if not v.is_contiguous():
                 v = v.contiguous()
             if v.dtype == torch.bool:
-                # It does not work with dlpack
-                # unless onnxruntime updates the version it is using.
-                new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
-                    v.detach().numpy(), onnx.TensorProto.BOOL
-                )
+                v = v.to(torch.uint8)
+                v = ORTC.OrtValue.from_dlpack(v.__dlpack__(), True)
             else:
-                new_feeds[k] = ORTC.OrtValue.from_dlpack(v.__dlpack__(), False)
+                v = ORTC.OrtValue.from_dlpack(v.detach().__dlpack__(), False)
+            input_names.append(k)
+            values.push_back(v)
         if self.nvtx:
-            self.torch.cuda.nvtx.range_push("run_with_ort_values")
-        ort_outputs = self.sess._sess.run_with_ort_values(
-            new_feeds, output_names or self.output_names, self.run_options
+            self.torch.cuda.nvtx.range_push("run_with_ortvaluevector")
+
+        # ort_outputs = self.sess._sess.run_with_ort_values(
+        #     new_feeds, output_names or self.output_names, self.run_options
+        # )
+        ort_outputs = ORTC.OrtValueVector()
+        out_names = output_names or self.output_names
+        self.sess._sess.run_with_ortvaluevector(
+            self.run_options,
+            input_names,
+            values,
+            out_names,
+            ort_outputs,
+            [DEVICES[-1 if self.cpu_outputs else device] for o in out_names],
         )
         if self.nvtx:
             self.torch.cuda.nvtx.range_pop()
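
Taken together, the ort_session.py hunks switch the torch path from run_with_ort_values to run_with_ortvaluevector: feeds are converted through dlpack into an OrtValueVector (bool tensors are first viewed as uint8), and the per-output device list lets cpu_outputs=True pin results to CPU. A hypothetical usage sketch; the model path, feed names, and the run(output_names, feeds) call shape are assumptions inferred from the body above, not a documented API:

import torch
from onnx_diagnostic.helpers.ort_session import InferenceSessionForTorch

# "model.onnx" and the feed names below are placeholders for a real model.
session = InferenceSessionForTorch("model.onnx", cpu_outputs=True)

feeds = {
    "input_ids": torch.tensor([[1, 2, 3]], dtype=torch.int64),
    # Bool tensors are now cast to uint8 and passed through dlpack
    # instead of the numpy round-trip used in 0.8.1.
    "attention_mask": torch.tensor([[True, True, False]]),
}
# Empty feed names now fail fast thanks to the added asserts;
# with cpu_outputs=True the outputs are forced onto CPU.
outputs = session.run(None, feeds)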