onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +387 -12
- onnx_diagnostic/export/api.py +118 -5
- onnx_diagnostic/export/control_flow.py +214 -0
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +135 -0
- onnx_diagnostic/export/onnx_plug.py +396 -0
- onnx_diagnostic/ext_test_case.py +118 -25
- onnx_diagnostic/helpers/cache_helper.py +218 -204
- onnx_diagnostic/helpers/dot_helper.py +210 -0
- onnx_diagnostic/helpers/helper.py +92 -26
- onnx_diagnostic/helpers/log_helper.py +26 -4
- onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +115 -16
- onnx_diagnostic/helpers/ort_session.py +37 -11
- onnx_diagnostic/helpers/rt_helper.py +547 -0
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +108 -6
- onnx_diagnostic/reference/ort_evaluator.py +233 -28
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/image_text_to_text.py +5 -1
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
- onnx_diagnostic/torch_models/validate.py +50 -1
- onnx_diagnostic/torch_onnx/sbs.py +963 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0

onnx_diagnostic/helpers/onnx_helper.py

@@ -3,7 +3,7 @@ import json
 import os
 import sys
 import warnings
-from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 import numpy as np
 import numpy.typing as npt
 import onnx
@@ -15,6 +15,7 @@ from onnx import (
     GraphProto,
     ModelProto,
     NodeProto,
+    OperatorSetIdProto,
     TensorProto,
     ValueInfoProto,
     load as onnx_load,
@@ -671,21 +672,18 @@ def np_dtype_to_tensor_dtype(dt: np.dtype) -> int:  # noqa: F821
     try:
         return oh.np_dtype_to_tensor_dtype(dt)
     except ValueError:
-
-
-
-
-        if
-
-
-
-
-
-
-
-        return TensorProto.FLOAT8E5M2
-        if dt == ml_dtypes.float8_e5m2fnuz:
-            return TensorProto.FLOAT8E5M2FNUZ
+        import ml_dtypes
+
+        if dt == ml_dtypes.bfloat16:
+            return TensorProto.BFLOAT16
+        if dt == ml_dtypes.float8_e4m3fn:
+            return TensorProto.FLOAT8E4M3FN
+        if dt == ml_dtypes.float8_e4m3fnuz:
+            return TensorProto.FLOAT8E4M3FNUZ
+        if dt == ml_dtypes.float8_e5m2:
+            return TensorProto.FLOAT8E5M2
+        if dt == ml_dtypes.float8_e5m2fnuz:
+            return TensorProto.FLOAT8E5M2FNUZ
     if dt == np.float32:
         return TensorProto.FLOAT
     if dt == np.float16:
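
The fallback now imports `ml_dtypes` lazily and maps its extended float types onto the matching `TensorProto` enums. A minimal usage sketch, assuming `ml_dtypes` is installed and that the helper lives in `onnx_diagnostic.helpers.onnx_helper` as the file list above indicates:

import numpy as np
import ml_dtypes
from onnx import TensorProto
from onnx_diagnostic.helpers.onnx_helper import np_dtype_to_tensor_dtype

# bfloat16 is not a plain numpy dtype; when onnx's own converter rejects it,
# the ml_dtypes branch shown in the hunk above resolves it instead
assert np_dtype_to_tensor_dtype(np.dtype(ml_dtypes.bfloat16)) == TensorProto.BFLOAT16
assert np_dtype_to_tensor_dtype(np.dtype(np.float32)) == TensorProto.FLOAT
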
@@ -1198,3 +1196,104 @@ def shadowing_names(
     existing |= not_empty
     created |= not_empty
     return shadow, post_shadow, created
+
+
+def extract_subset_of_nodes(
+    model: ModelProto,
+    name: str,
+    node_index: Optional[int] = None,
+    cut_points: Optional[Set[str]] = None,
+) -> List[NodeProto]:
+    """
+    Extracts the minimal subgraph which can produce the output ``name``
+    knowing ``cut_points``.
+
+    :param model: original model
+    :param name: result name
+    :param node_index: index of the node producing ``name`` if known, otherwise searched for
+    :param cut_points: the known results; defaults to the graph inputs and initializers
+    :return: minimal list of nodes
+    """
+    if node_index is None:
+        for i, node in enumerate(model.graph.node):
+            if name in node.output:
+                node_index = i
+                break
+    assert (
+        node_index is not None
+        and node_index < len(model.graph.node)
+        and name in model.graph.node[node_index].output
+    ), f"node_index is still empty or wrong for result {name!r}"
+    if cut_points is None:
+        cut_points = {n.name for n in model.graph.input} | {
+            n.name for n in model.graph.initializer
+        }
+    elif model.graph.initializer:
+        cut_points = cut_points | {n.name for n in model.graph.initializer}
+
+    node = model.graph.node[node_index]
+    selected = {node_index}
+    current_node_index = node_index
+    current_input_index = 0
+    intermediate = {name}
+    inputs = set(k for k in node.input if k)
+    while not (inputs <= cut_points) and current_node_index >= 0:
+        node = model.graph.node[current_node_index]
+        if current_input_index == 0:
+            needs = [o for o in node.output if o in intermediate and o not in cut_points]
+            if needs:
+                selected.add(current_node_index)
+            else:
+                current_node_index -= 1
+                continue
+        res = node.input[current_input_index]
+        if res not in cut_points:
+            intermediate.add(res)
+        current_input_index += 1
+        if current_input_index >= len(node.input):
+            current_node_index -= 1
+            current_input_index = 0
+
+    return [model.graph.node[i] for i in sorted(selected)]
+
+
+def make_submodel(
+    nodes: List[NodeProto],
+    ir_version: int,
+    opset_imports: List[OperatorSetIdProto],
+    output_names: List[str],
+    type_rank_fn: Callable[[str], Tuple[int, int]],
+) -> ModelProto:
+    """
+    Creates a model with the given list of nodes.
+    It computes the minimal list of inputs needed for this model.
+    The function assumes the nodes are sorted.
+    It does not yet handle subgraphs.
+
+    :param nodes: list of nodes
+    :param ir_version: ir version
+    :param opset_imports: opset imports
+    :param output_names: desired outputs
+    :param type_rank_fn: function returning the type and the rank of a result
+    :return: model proto
+    """
+
+    def _mkv_(name, itype, irank):
+        return oh.make_tensor_value_info(name, itype, [f"{name}_d{i}" for i in range(irank)])
+
+    not_known: Set[str] = set()
+    for node in nodes[::-1]:
+        not_known -= set(node.output)
+        not_known |= set(node.input)
+
+    model = oh.make_model(
+        oh.make_graph(
+            nodes,
+            "submodel",
+            [_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known)],
+            [_mkv_(n, *type_rank_fn(n)) for n in sorted(output_names)],
+        ),
+        ir_version=ir_version,
+        opset_imports=opset_imports,
+    )
+    return model
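
Together the two new helpers slice a model around one intermediate result: `extract_subset_of_nodes` walks backwards from the node producing `name` until every needed input is a known cut point, and `make_submodel` wraps the selected nodes with freshly typed inputs. A hedged end-to-end sketch (the toy model and the constant `type_rank_fn` are illustrative assumptions, not part of the package):

import numpy as np
import onnx.helper as oh
from onnx import TensorProto, checker, numpy_helper
from onnx_diagnostic.helpers.onnx_helper import extract_subset_of_nodes, make_submodel

# toy model: y = (x + 1) * (x + 1), plus a Neg node unrelated to "y"
model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Add", ["x", "one"], ["t"]),
            oh.make_node("Mul", ["t", "t"], ["y"]),
            oh.make_node("Neg", ["x"], ["unused"]),
        ],
        "demo",
        [oh.make_tensor_value_info("x", TensorProto.FLOAT, ["N"])],
        [oh.make_tensor_value_info("y", TensorProto.FLOAT, ["N"])],
        [numpy_helper.from_array(np.array([1.0], dtype=np.float32), name="one")],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
)

# walks backwards from the Mul node producing "y"; the unrelated Neg is dropped
nodes = extract_subset_of_nodes(model, "y")

# every result in this toy model is a rank-1 float, so type_rank_fn is constant
sub = make_submodel(
    nodes,
    ir_version=model.ir_version,
    opset_imports=list(model.opset_import),
    output_names=["y"],
    type_rank_fn=lambda name: (TensorProto.FLOAT, 1),
)
checker.check_model(sub)  # "x" and "one" become the submodel inputs
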

onnx_diagnostic/helpers/ort_session.py

@@ -108,7 +108,10 @@ class _InferenceSession:
                 session_options,
                 providers=providers,
             )
-        except
+        except (
+            onnxruntime.capi.onnxruntime_pybind11_state.Fail,
+            onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
+        ) as e:
             if isinstance(sess, onnx.ModelProto):
                 debug_path = "_debug_InferenceSession_last_failure.onnx"
                 onnx.save(
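
The broadened clause keeps the existing debugging behaviour: when session creation fails, the ModelProto is dumped to disk before the error propagates. A self-contained sketch of that pattern (the helper name is hypothetical; the exception types and debug file name come from the hunk above):

import onnx
import onnxruntime

def make_session_or_dump(model: onnx.ModelProto, providers=("CPUExecutionProvider",)):
    """Creates an InferenceSession, saving the model on failure for inspection."""
    try:
        return onnxruntime.InferenceSession(
            model.SerializeToString(), providers=list(providers)
        )
    except (
        onnxruntime.capi.onnxruntime_pybind11_state.Fail,
        onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
    ) as e:
        # same debug file name as in the hunk above
        onnx.save(model, "_debug_InferenceSession_last_failure.onnx")
        raise e
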
@@ -134,7 +137,13 @@ class _InferenceSession:
 
         self.sess = sess
         self.input_names = [i.name for i in sess.get_inputs()]
+        assert (
+            "" not in self.input_names
+        ), f"Input name cannot be empty but input_names={self.input_names}"
         self.output_names = [i.name for i in sess.get_outputs()]
+        assert (
+            "" not in self.output_names
+        ), f"Output name cannot be empty but output_names={self.output_names}"
         self.input_shapes = [i.shape for i in sess.get_inputs()]
         self.output_shapes = [i.shape for i in sess.get_outputs()]
         self.input_types = [i.type for i in sess.get_inputs()]
@@ -338,6 +347,7 @@ class InferenceSessionForTorch(_InferenceSession):
     :param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
     :param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions`
     :param use_training_api: use onnxruntime-training API
+    :param cpu_outputs: if True, force the outputs to be on CPU
     """
 
     def __init__(
@@ -353,6 +363,7 @@ class InferenceSessionForTorch(_InferenceSession):
         optimized_model_filepath: Optional[str] = None,
         disable_aot_function_inlining: Optional[bool] = None,
         use_training_api: Optional[bool] = None,
+        cpu_outputs: bool = False,
     ):
         super().__init__(
             sess,
@@ -367,6 +378,7 @@ class InferenceSessionForTorch(_InferenceSession):
             disable_aot_function_inlining=disable_aot_function_inlining,
             use_training_api=use_training_api,
         )
+        self.cpu_outputs = cpu_outputs
 
     def _get_ortvalues_from_torch_tensors(
         self, tensors: Tuple[torch.Tensor, ...], n_outputs: int
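
The new attribute simply records whether outputs should be pinned to CPU; it is consumed when output devices are chosen in the next hunk. A short construction sketch (the model path is a placeholder and the remaining constructor arguments keep their defaults):

from onnx_diagnostic.helpers.ort_session import InferenceSessionForTorch

# keep outputs on CPU even when the inputs (and the run itself) live on CUDA
wrapped = InferenceSessionForTorch("model.onnx", cpu_outputs=True)
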
@@ -490,23 +502,37 @@ class InferenceSessionForTorch(_InferenceSession):
         feeds is a dictionary of :class:`torch.Tensor`.
         The output device is CPU even if the outputs are on CUDA.
         """
-
+        input_names = []
+        values = ORTC.OrtValueVector()
+        device = -1
         for k, v in feeds.items():
+            assert k != "", f"Input cannot be empty but feeds names={list(feeds)}"
+            device = max(device, v.get_device())
             assert hasattr(v, "__dlpack__"), f"class {type(v)} should be serialized"
             if not v.is_contiguous():
                 v = v.contiguous()
             if v.dtype == torch.bool:
-
-
-                new_feeds[k] = ORTC.OrtValue.ortvalue_from_numpy_with_onnx_type(
-                    v.detach().numpy(), onnx.TensorProto.BOOL
-                )
+                v = v.to(torch.uint8)
+                v = ORTC.OrtValue.from_dlpack(v.__dlpack__(), True)
             else:
-
+                v = ORTC.OrtValue.from_dlpack(v.detach().__dlpack__(), False)
+            input_names.append(k)
+            values.push_back(v)
         if self.nvtx:
-            self.torch.cuda.nvtx.range_push("
-
-
+            self.torch.cuda.nvtx.range_push("run_with_ortvaluevector")
+
+        # ort_outputs = self.sess._sess.run_with_ort_values(
+        #     new_feeds, output_names or self.output_names, self.run_options
+        # )
+        ort_outputs = ORTC.OrtValueVector()
+        out_names = output_names or self.output_names
+        self.sess._sess.run_with_ortvaluevector(
+            self.run_options,
+            input_names,
+            values,
+            out_names,
+            ort_outputs,
+            [DEVICES[-1 if self.cpu_outputs else device] for o in out_names],
         )
         if self.nvtx:
             self.torch.cuda.nvtx.range_pop()