onnx-diagnostic 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +2 -2
- onnx_diagnostic/_command_lines_parser.py +21 -1
- onnx_diagnostic/export/dynamic_shapes.py +14 -5
- onnx_diagnostic/ext_test_case.py +12 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +24 -0
- onnx_diagnostic/helpers/model_builder_helper.py +333 -0
- onnx_diagnostic/helpers/rt_helper.py +65 -1
- onnx_diagnostic/torch_export_patches/eval/__init__.py +621 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +896 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +34 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +6 -1
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +25 -19
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +91 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +29 -1
- onnx_diagnostic/torch_models/test_helper.py +110 -7
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.0.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.0.dist-info}/RECORD +21 -17
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.0.dist-info}/WHEEL +1 -1
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.0.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.0.dist-info}/top_level.txt +0 -0
|
@@ -107,7 +107,7 @@ def torch_export_patches(
|
|
|
107
107
|
) -> Callable:
|
|
108
108
|
"""
|
|
109
109
|
Tries to bypass some situations :func:`torch.export.export` does not support.
|
|
110
|
-
See also :ref:`l-patches-explained`.
|
|
110
|
+
See also :ref:`l-patches-explained` and :ref:`l-patch-coverage`.
|
|
111
111
|
|
|
112
112
|
:param patch_sympy: fix missing method ``name`` for IntegerConstant
|
|
113
113
|
:param patch_torch: patches :epkg:`torch` with supported implementation
|
|
@@ -140,6 +140,7 @@ def torch_export_patches(
|
|
|
140
140
|
* ``torch.jit.isinstance``
|
|
141
141
|
* ``torch._dynamo.mark_static_address``
|
|
142
142
|
* ``torch._subclasses.fake_impls.infer_size``
|
|
143
|
+
* ``torch.vmap``
|
|
143
144
|
* fix missing method ``name`` for ``sympy.S.IntegerConstant``
|
|
144
145
|
* ``AttentionMaskConverter._make_causal_mask``
|
|
145
146
|
* Serialization of ``MambaCache`` (in :epkg:`transformers`)
|
|
@@ -251,6 +252,7 @@ def torch_export_patches(
|
|
|
251
252
|
if patch_torch:
|
|
252
253
|
from .patches.patch_torch import (
|
|
253
254
|
patched_infer_size,
|
|
255
|
+
patched_vmap,
|
|
254
256
|
patched__broadcast_shapes,
|
|
255
257
|
_catch_produce_guards_and_solve_constraints,
|
|
256
258
|
patch__check_input_constraints_for_graph,
|
|
@@ -261,6 +263,10 @@ def torch_export_patches(
|
|
|
261
263
|
print(f"[torch_export_patches] stop_if_static={stop_if_static!r}")
|
|
262
264
|
print("[torch_export_patches] patch pytorch")
|
|
263
265
|
|
|
266
|
+
# torch.vmap
|
|
267
|
+
f_vmap = torch.vmap
|
|
268
|
+
torch.vmap = patched_vmap
|
|
269
|
+
|
|
264
270
|
# torch.jit.isinstance
|
|
265
271
|
f_jit_isinstance = torch.jit.isinstance
|
|
266
272
|
torch.jit.isinstance = isinstance
|
|
@@ -328,6 +334,11 @@ def torch_export_patches(
|
|
|
328
334
|
####################
|
|
329
335
|
|
|
330
336
|
if patch_transformers:
|
|
337
|
+
try:
|
|
338
|
+
import transformers.masking_utils as masking_utils
|
|
339
|
+
except ImportError:
|
|
340
|
+
masking_utils = None
|
|
341
|
+
|
|
331
342
|
if verbose:
|
|
332
343
|
import transformers
|
|
333
344
|
|
|
@@ -339,6 +350,15 @@ def torch_export_patches(
|
|
|
339
350
|
patch_transformers_list, verbose=verbose
|
|
340
351
|
)
|
|
341
352
|
|
|
353
|
+
if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"):
|
|
354
|
+
if verbose:
|
|
355
|
+
print(
|
|
356
|
+
"[torch_export_patches] patches "
|
|
357
|
+
"transformers.masking_utils._vmap_for_bhqkv"
|
|
358
|
+
)
|
|
359
|
+
f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
|
|
360
|
+
masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
|
|
361
|
+
|
|
342
362
|
if custom_patches:
|
|
343
363
|
if verbose:
|
|
344
364
|
print("[torch_export_patches] applies custom patches")
|
|
@@ -381,6 +401,7 @@ def torch_export_patches(
|
|
|
381
401
|
|
|
382
402
|
if patch_torch:
|
|
383
403
|
# this should disappear when torch.jit is removed
|
|
404
|
+
torch.vmap = f_vmap
|
|
384
405
|
torch.jit.isinstance = f_jit_isinstance
|
|
385
406
|
torch._dynamo.mark_static_address = f_mark_static_address
|
|
386
407
|
# tracked by https://github.com/pytorch/pytorch/issues/143495
|
|
@@ -430,12 +451,24 @@ def torch_export_patches(
|
|
|
430
451
|
##############
|
|
431
452
|
|
|
432
453
|
if patch_transformers:
|
|
454
|
+
try:
|
|
455
|
+
import transformers.masking_utils as masking_utils
|
|
456
|
+
except ImportError:
|
|
457
|
+
masking_utils = None
|
|
433
458
|
if verbose:
|
|
434
459
|
print("[torch_export_patches] unpatch transformers")
|
|
435
460
|
unpatch_module_or_classes(
|
|
436
461
|
patch_transformers_list, revert_patches_info, verbose=verbose
|
|
437
462
|
)
|
|
438
463
|
|
|
464
|
+
if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"):
|
|
465
|
+
if verbose:
|
|
466
|
+
print(
|
|
467
|
+
"[torch_export_patches] unpatch "
|
|
468
|
+
"transformers.masking_utils._vmap_for_bhqkv"
|
|
469
|
+
)
|
|
470
|
+
masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
|
|
471
|
+
|
|
439
472
|
########
|
|
440
473
|
# caches
|
|
441
474
|
########
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import pprint
|
|
2
|
-
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
|
2
|
+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
|
3
3
|
import packaging.version as pv
|
|
4
4
|
import optree
|
|
5
5
|
import torch
|
|
@@ -133,6 +133,11 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
|
|
|
133
133
|
# To avoid doing it multiple times.
|
|
134
134
|
PATCH_OF_PATCHES.add(BaseModelOutput)
|
|
135
135
|
|
|
136
|
+
return serialization_functions(verbose=verbose)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def serialization_functions(verbose: int = 0) -> Dict[str, Union[Callable, int]]:
|
|
140
|
+
"""Returns the list of serialization functions."""
|
|
136
141
|
return dict(
|
|
137
142
|
DynamicCache=register_class_serialization(
|
|
138
143
|
DynamicCache,
|
|
@@ -19,6 +19,28 @@ def ast_or_into_bitor(node: "ast.Node") -> "ast.Node":
|
|
|
19
19
|
return new_node
|
|
20
20
|
|
|
21
21
|
|
|
22
|
+
def _rewrite_bart_encoder_layer():
    """
    Returns the rewriting specifications for ``BartEncoderLayer.forward``
    and ``PLBartEncoderLayer.forward``.

    Every specification is a dictionary with the keys ``filter_node``
    (keeps ``ast.If`` nodes whose test is not a bare name),
    ``pre_rewriter`` (``ast_or_into_bitor``) and ``function``
    (the method to rewrite).
    """
    import transformers

    # Both specifications share the same node filter and pre-rewriter,
    # only the rewritten function differs.
    shared = {
        "filter_node": (
            lambda node: isinstance(node, ast.If) and not isinstance(node.test, ast.Name)
        ),
        "pre_rewriter": ast_or_into_bitor,
    }
    forwards = (
        transformers.models.bart.modeling_bart.BartEncoderLayer.forward,
        transformers.models.plbart.modeling_plbart.PLBartEncoderLayer.forward,
    )
    return [{**shared, "function": forward} for forward in forwards]
|
|
42
|
+
|
|
43
|
+
|
|
22
44
|
def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]:
|
|
23
45
|
"""
|
|
24
46
|
Returns a known list of methods or functions to rewrite because of control flow
|
|
@@ -30,11 +52,12 @@ def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]:
|
|
|
30
52
|
.. runpython::
|
|
31
53
|
:showcode:
|
|
32
54
|
|
|
55
|
+
import pprint
|
|
33
56
|
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
|
|
34
57
|
code_needing_rewriting,
|
|
35
58
|
)
|
|
36
59
|
|
|
37
|
-
|
|
60
|
+
pprint.pprint(code_needing_rewriting("BartForConditionalGeneration"))
|
|
38
61
|
"""
|
|
39
62
|
if cls_name in {
|
|
40
63
|
"BartEncoderLayer",
|
|
@@ -42,22 +65,5 @@ def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]:
|
|
|
42
65
|
"PLBartEncoderLayer",
|
|
43
66
|
"PLBartForConditionalGeneration",
|
|
44
67
|
}:
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
bd = dict(
|
|
48
|
-
filter_node=(
|
|
49
|
-
lambda node: isinstance(node, ast.If) and not isinstance(node.test, ast.Name)
|
|
50
|
-
),
|
|
51
|
-
pre_rewriter=ast_or_into_bitor,
|
|
52
|
-
)
|
|
53
|
-
|
|
54
|
-
def _add(f):
|
|
55
|
-
g = bd.copy()
|
|
56
|
-
g["function"] = f
|
|
57
|
-
return g
|
|
58
|
-
|
|
59
|
-
return [
|
|
60
|
-
_add(transformers.models.bart.modeling_bart.BartEncoderLayer.forward),
|
|
61
|
-
_add(transformers.models.plbart.modeling_plbart.PLBartEncoderLayer.forward),
|
|
62
|
-
]
|
|
68
|
+
return _rewrite_bart_encoder_layer()
|
|
63
69
|
return None
|
|
@@ -370,3 +370,94 @@ class patched_ShapeEnv:
|
|
|
370
370
|
# RuntimeWarning,
|
|
371
371
|
# stacklevel=0,
|
|
372
372
|
# )
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def patched_vmap(func, in_dims=0, out_dims=0):
|
|
376
|
+
"""
|
|
377
|
+
Python implementation of :func:`torch.vmap`.
|
|
378
|
+
The implementation raises an issue when it is being exported with
|
|
379
|
+
:func:`torch.export.export` when the function is called with
|
|
380
|
+
non tensors arguments and the batch size is dynamic.
|
|
381
|
+
"""
|
|
382
|
+
from ...helpers import string_type
|
|
383
|
+
|
|
384
|
+
def wrapped(*args):
|
|
385
|
+
assert all(not isinstance(a, dict) for a in args), (
|
|
386
|
+
f"dictionaries are not implemented in "
|
|
387
|
+
f"args={string_type(args, with_shape=True)}"
|
|
388
|
+
)
|
|
389
|
+
|
|
390
|
+
in_dims_ = (
|
|
391
|
+
([in_dims] * len(args))
|
|
392
|
+
if not isinstance(in_dims, (list, tuple))
|
|
393
|
+
else list(in_dims)
|
|
394
|
+
)
|
|
395
|
+
assert len(in_dims_) == len(args), (
|
|
396
|
+
f"Mismtch between in_dims={in_dims_} and "
|
|
397
|
+
f"args={string_type(args, with_shape=True)}"
|
|
398
|
+
)
|
|
399
|
+
|
|
400
|
+
batch_size = None
|
|
401
|
+
batched_args = []
|
|
402
|
+
for arg, in_dim in zip(args, in_dims_):
|
|
403
|
+
if in_dim is None:
|
|
404
|
+
batched_args.append(arg)
|
|
405
|
+
continue
|
|
406
|
+
|
|
407
|
+
assert batch_size is None or batch_size == arg.size(in_dim), (
|
|
408
|
+
f"Unable to continue, batch_size={batch_size}, in_dim={in_dim}, "
|
|
409
|
+
f"arg.size(in_dim)={arg.size(in_dim)}"
|
|
410
|
+
)
|
|
411
|
+
if batch_size is None:
|
|
412
|
+
batch_size = arg.size(in_dim)
|
|
413
|
+
arg = arg.movedim(in_dim, 0)
|
|
414
|
+
batched_args.append(arg)
|
|
415
|
+
|
|
416
|
+
if all(isinstance(a, torch.Tensor) for a in args) and isinstance(
|
|
417
|
+
batch_size, torch.SymInt
|
|
418
|
+
):
|
|
419
|
+
batched_tensors = [
|
|
420
|
+
(
|
|
421
|
+
arg
|
|
422
|
+
if (isinstance(arg, torch.Tensor) and in_dim is not None)
|
|
423
|
+
else arg.unsqueeze(0).expand((batch_size, *arg.shape))
|
|
424
|
+
)
|
|
425
|
+
for arg, in_dim in zip(batched_args, in_dims_)
|
|
426
|
+
]
|
|
427
|
+
results = torch.ops.higher_order.scan(func, [], batched_tensors, [])
|
|
428
|
+
stacked = results[0]
|
|
429
|
+
if out_dims != 0:
|
|
430
|
+
return stacked.movedim(0, out_dims)
|
|
431
|
+
return stacked
|
|
432
|
+
|
|
433
|
+
else:
|
|
434
|
+
torch._check(
|
|
435
|
+
not isinstance(batch_size, torch.SymInt),
|
|
436
|
+
lambda: (
|
|
437
|
+
f"patched_vmap supports dynamic batch_size only if all argument "
|
|
438
|
+
f"are tensors but types are {[type(a) for a in args]}"
|
|
439
|
+
),
|
|
440
|
+
)
|
|
441
|
+
batched_tensors = [
|
|
442
|
+
(
|
|
443
|
+
(None, arg)
|
|
444
|
+
if (isinstance(arg, torch.Tensor) and in_dim is not None)
|
|
445
|
+
else (arg, arg)
|
|
446
|
+
)
|
|
447
|
+
for arg, in_dim in zip(batched_args, in_dims_)
|
|
448
|
+
]
|
|
449
|
+
|
|
450
|
+
results = []
|
|
451
|
+
for i in range(batch_size):
|
|
452
|
+
input_slice = [v if v is not None else arg[i] for v, arg in batched_tensors]
|
|
453
|
+
result = func(*input_slice)
|
|
454
|
+
results.append(result)
|
|
455
|
+
|
|
456
|
+
if isinstance(results[0], torch.Tensor):
|
|
457
|
+
stacked = torch.stack(results)
|
|
458
|
+
if out_dims != 0:
|
|
459
|
+
return stacked.movedim(0, out_dims)
|
|
460
|
+
return stacked
|
|
461
|
+
return results
|
|
462
|
+
|
|
463
|
+
return wrapped
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import inspect
|
|
2
2
|
from dataclasses import dataclass
|
|
3
|
-
from typing import Any, Dict, List, Optional, Tuple
|
|
3
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
4
4
|
import torch
|
|
5
5
|
import transformers
|
|
6
6
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
|
@@ -9,6 +9,34 @@ from ...ext_test_case import has_transformers
|
|
|
9
9
|
from ...helpers.torch_helper import is_torchdynamo_exporting
|
|
10
10
|
|
|
11
11
|
|
|
12
|
+
def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
    """
    Patch for function ``transformers.masking_utils._vmap_for_bhqkv``.

    Instead of nesting :func:`torch.vmap`, every mapped index tensor is reshaped
    so that it spans its own axis, then all of them are expanded to a common
    shape and ``mask_function`` is called once on the expanded tensors.

    :param mask_function: function ``(batch_idx, head_idx, q_idx, kv_idx)``,
        assumed to work elementwise on the expanded tensors
    :param bh_indices: if True, the four indices are mapped and expanded to
        shape ``(len(batch_idx), len(head_idx), len(q_idx), len(kv_idx))``;
        if False, only ``q_idx`` and ``kv_idx`` are mapped (expanded to
        ``(len(q_idx), len(kv_idx))``) and ``batch_idx``, ``head_idx`` are
        passed unchanged
    :return: the vectorized mask function, called with the four indices
        (or only the mapped ones)
    """

    n_mapped = 4 if bh_indices else 2
    # Reshape target of the k-th mapped index: -1 on axis k, 1 elsewhere.
    dimensions = tuple(
        tuple(-1 if axis == k else 1 for axis in range(n_mapped)) for k in range(n_mapped)
    )

    def _describe(args):
        # string_type is only needed to build error messages, import it lazily.
        from ...helpers import string_type

        return string_type(args)

    def vector_mask_function(*args, mask_function=mask_function, dimensions=dimensions):
        assert len(args) in (len(dimensions), 4), (
            f"Mismatch between args={_describe(args)} and dimensions={dimensions}"
        )
        # The leading (unmapped) indices are forwarded as they are,
        # the trailing ones are broadcast against each other.
        n_unchanged = len(args) - len(dimensions)
        unchanged = args[:n_unchanged]
        mapped = args[n_unchanged:]
        max_shape = tuple(a.shape[0] for a in mapped)
        expanded_args = [
            a.reshape(shape).expand(max_shape) for a, shape in zip(mapped, dimensions)
        ]
        return mask_function(*unchanged, *expanded_args)

    return vector_mask_function
|
|
38
|
+
|
|
39
|
+
|
|
12
40
|
def _patch_make_causal_mask(
|
|
13
41
|
input_ids_shape: torch.Size,
|
|
14
42
|
dtype: torch.dtype,
|
|
@@ -345,8 +345,32 @@ def validate_model(
|
|
|
345
345
|
)
|
|
346
346
|
),
|
|
347
347
|
)
|
|
348
|
+
|
|
349
|
+
if exporter == "modelbuilder":
|
|
350
|
+
# Models used with ModelBuilder do not like batch size > 1.
|
|
351
|
+
# Let's change that.
|
|
352
|
+
for k in ["inputs", "inputs2"]:
|
|
353
|
+
if k not in data:
|
|
354
|
+
continue
|
|
355
|
+
if verbose:
|
|
356
|
+
print(f"[validate_model] set batch=1 for data[{k!r}]")
|
|
357
|
+
print(f"[validate_model] batch=1 === {string_type(data[k], with_shape=True)}")
|
|
358
|
+
cpl = CoupleInputsDynamicShapes(
|
|
359
|
+
tuple(), data[k], dynamic_shapes=data["dynamic_shapes"]
|
|
360
|
+
)
|
|
361
|
+
data[k] = cpl.change_dynamic_dimensions(
|
|
362
|
+
desired_values=dict(batch=1), only_desired=True
|
|
363
|
+
)
|
|
364
|
+
if verbose:
|
|
365
|
+
print(f"[validate_model] batch=1 --> {string_type(data[k], with_shape=True)}")
|
|
366
|
+
|
|
348
367
|
data["input_options"] = iop
|
|
349
368
|
data["model_options"] = mop
|
|
369
|
+
data["model_dump_folder"] = dump_folder
|
|
370
|
+
if dtype:
|
|
371
|
+
data["model_dtype"] = dtype if isinstance(dtype, str) else str(dtype)
|
|
372
|
+
if device:
|
|
373
|
+
data["model_device"] = str(device)
|
|
350
374
|
if opset:
|
|
351
375
|
data["model_opset"] = opset
|
|
352
376
|
if "rewrite" in data:
|
|
@@ -555,6 +579,16 @@ def validate_model(
|
|
|
555
579
|
begin = time.perf_counter()
|
|
556
580
|
if isinstance(epo, onnx.model_container.ModelContainer):
|
|
557
581
|
epo.save(onnx_filename, all_tensors_to_one_file=True)
|
|
582
|
+
elif isinstance(epo, onnx.ModelProto):
|
|
583
|
+
if os.path.exists(f"{onnx_filename}.data"):
|
|
584
|
+
os.remove(f"{onnx_filename}.data")
|
|
585
|
+
onnx.save(
|
|
586
|
+
epo,
|
|
587
|
+
onnx_filename,
|
|
588
|
+
save_as_external_data=True,
|
|
589
|
+
all_tensors_to_one_file=True,
|
|
590
|
+
location=f"{os.path.split(onnx_filename)[-1]}.data",
|
|
591
|
+
)
|
|
558
592
|
else:
|
|
559
593
|
epo.save(onnx_filename, external_data=True)
|
|
560
594
|
duration = time.perf_counter() - begin
|
|
@@ -572,7 +606,8 @@ def validate_model(
|
|
|
572
606
|
print("[validate_model] done (dump)")
|
|
573
607
|
|
|
574
608
|
if not exporter or (
|
|
575
|
-
not exporter.startswith(("onnx-", "custom-"))
|
|
609
|
+
not exporter.startswith(("onnx-", "custom-"))
|
|
610
|
+
and exporter not in ("custom", "modelbuilder")
|
|
576
611
|
):
|
|
577
612
|
if verbose:
|
|
578
613
|
print("[validate_model] -- done (final)")
|
|
@@ -704,6 +739,16 @@ def call_exporter(
|
|
|
704
739
|
dump_folder=dump_folder,
|
|
705
740
|
)
|
|
706
741
|
return summary, data
|
|
742
|
+
if exporter == "modelbuilder":
|
|
743
|
+
# torch export
|
|
744
|
+
summary, data = call_torch_export_model_builder(
|
|
745
|
+
exporter=exporter,
|
|
746
|
+
data=data,
|
|
747
|
+
quiet=quiet,
|
|
748
|
+
verbose=verbose,
|
|
749
|
+
optimization=optimization,
|
|
750
|
+
)
|
|
751
|
+
return summary, data
|
|
707
752
|
raise NotImplementedError(
|
|
708
753
|
f"export with {exporter!r} and optimization={optimization!r} not implemented yet"
|
|
709
754
|
)
|
|
@@ -871,6 +916,8 @@ def validate_onnx_model(
|
|
|
871
916
|
if input_data_key in data:
|
|
872
917
|
source = data[input_data_key]
|
|
873
918
|
if not os.path.exists(source):
|
|
919
|
+
if verbose:
|
|
920
|
+
print(f"[validate_onnx_model] missing {source!r}")
|
|
874
921
|
summary[_mk("ERR_onnx_missing")] = f"FileNotFoundError({source!r})"
|
|
875
922
|
return summary, data
|
|
876
923
|
summary[input_data_key] = source
|
|
@@ -911,12 +958,7 @@ def validate_onnx_model(
|
|
|
911
958
|
if verbose:
|
|
912
959
|
print("[validate_onnx_model] -- make_feeds...")
|
|
913
960
|
print(f"[validate_onnx_model] inputs={string_type(data['inputs'], with_shape=True)}")
|
|
914
|
-
feeds = make_feeds(
|
|
915
|
-
[i.name for i in sess.get_inputs()],
|
|
916
|
-
data["inputs"],
|
|
917
|
-
use_numpy=True,
|
|
918
|
-
check_flatten=False,
|
|
919
|
-
)
|
|
961
|
+
feeds = make_feeds(sess, data["inputs"], use_numpy=True, check_flatten=False)
|
|
920
962
|
if verbose:
|
|
921
963
|
print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}")
|
|
922
964
|
summary[_mk("onnx_ort_inputs")] = string_type(feeds, with_shape=True)
|
|
@@ -1085,6 +1127,67 @@ def call_torch_export_onnx(
|
|
|
1085
1127
|
return summary, data
|
|
1086
1128
|
|
|
1087
1129
|
|
|
1130
|
+
def call_torch_export_model_builder(
    data: Dict[str, Any],
    exporter: str,
    quiet: bool = False,
    verbose: int = 0,
    optimization: Optional[str] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Exports a model into onnx with :epkg:`ModelBuilder`.

    :param data: dictionary with all the necessary inputs, the dictionary must
        contain keys ``model``, ``configuration`` and a non empty
        ``model_dump_folder``; ``model_dtype`` (default ``"fp32"``) and
        ``model_device`` (default ``"cpu"``) are optional
    :param exporter: exporter to call (not used by this function)
    :param quiet: catch exception or not
    :param verbose: verbosity
    :param optimization: optimization to do, only None or ``""`` is supported
    :return: two dictionaries, one with some metrics,
        another one with whatever the function produces,
        the exported model is stored in ``data["onnx_program"]``
    """
    from ..helpers.model_builder_helper import create_model_builder, save_model_builder

    assert optimization in (
        None,
        "",
    ), f"unexpected value for optimization={optimization}, none is available"
    precision = data.get("model_dtype", "fp32")
    provider = data.get("model_device", "cpu")
    dump_folder = data.get("model_dump_folder", "")
    assert dump_folder, "dump_folder cannot be empty with ModelBuilder"
    cache_dir = os.path.join(dump_folder, "cache_mb")
    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(cache_dir, exist_ok=True)
    summary: Dict[str, Any] = {}

    # Read eagerly (outside the quiet wrapper), a missing key raises immediately.
    model, config = data["model"], data["configuration"]

    def _export():
        return save_model_builder(
            create_model_builder(
                config,
                model,
                precision=precision,
                execution_provider=provider,
                cache_dir=cache_dir,
            )
        )

    epo = _quiet_or_not_quiet(quiet, "export_model_builder", summary, data, _export)
    if "ERR_export_model_builder" in summary:
        return summary, data

    assert epo is not None, "no onnx export was found"
    if verbose:
        print("[call_torch_export_model_builder] done (export)")
    data["onnx_program"] = epo
    return summary, data
|
|
1189
|
+
|
|
1190
|
+
|
|
1088
1191
|
def call_torch_export_custom(
|
|
1089
1192
|
data: Dict[str, Any],
|
|
1090
1193
|
exporter: str,
|
|
@@ -1,22 +1,24 @@
|
|
|
1
|
-
onnx_diagnostic/__init__.py,sha256=
|
|
1
|
+
onnx_diagnostic/__init__.py,sha256=5S8PigU8f0RN8fU9ddmGimcrAr1kUtpUORdZRi96mHw,173
|
|
2
2
|
onnx_diagnostic/__main__.py,sha256=YmyV_Aq_ianDlHyKLHMa6h8YK3ZmFPpLVHLKjM91aCk,79
|
|
3
|
-
onnx_diagnostic/_command_lines_parser.py,sha256=
|
|
3
|
+
onnx_diagnostic/_command_lines_parser.py,sha256=yy4upYkizwu-8M6ErGKhzwiX5fW8dWT_34EBOT7CPPQ,18632
|
|
4
4
|
onnx_diagnostic/doc.py,sha256=MTuT7Kxyvn7KEy84liQeFeqhugJrUQhjjpx21F72Uxw,926
|
|
5
|
-
onnx_diagnostic/ext_test_case.py,sha256=
|
|
5
|
+
onnx_diagnostic/ext_test_case.py,sha256=PVkneWhs-jt_nkfH06hv12WjtaBX0Rim1r_dxGPXjq0,42256
|
|
6
6
|
onnx_diagnostic/export/__init__.py,sha256=yEIoWiOeTwBsDhyYt2fTKuhtA0Ya1J9u9ZzMTOTWaWs,101
|
|
7
|
-
onnx_diagnostic/export/dynamic_shapes.py,sha256=
|
|
7
|
+
onnx_diagnostic/export/dynamic_shapes.py,sha256=EHB7VoWNx8sVetvOgE1vgC7wHtIjWDLjanhbEJNpK88,39892
|
|
8
8
|
onnx_diagnostic/export/validate.py,sha256=_PGUql2DJhIgGKo0WjTGUc5AgsZUx8fEs00MePy-w98,6043
|
|
9
9
|
onnx_diagnostic/helpers/__init__.py,sha256=GJ2GT7cgnlIveVUwMZhuvUwidbTJaKv8CsSIOpZDsJg,83
|
|
10
10
|
onnx_diagnostic/helpers/args_helper.py,sha256=7pTrw1A1wuNvLdXJdpda5spPI140FylwSmxxZTGu_4E,4389
|
|
11
11
|
onnx_diagnostic/helpers/bench_run.py,sha256=CGA6VMJZMH2gDhVueT9ypNm4PMcjGrrGFYp08nhWj9k,16539
|
|
12
12
|
onnx_diagnostic/helpers/cache_helper.py,sha256=soKjyIXa7EQgALd9PAUGIKYzXlJGoLevYiQDsxoqkQ4,8349
|
|
13
13
|
onnx_diagnostic/helpers/config_helper.py,sha256=aZATKVbZuw8L56KQpwMNcqJ3Qi5OplzS_N3ETR3hmj0,3351
|
|
14
|
-
onnx_diagnostic/helpers/
|
|
14
|
+
onnx_diagnostic/helpers/graph_helper.py,sha256=hevQT5a7_QuriVPQcbT5qe18n99Doyl5h3-qshx1-uk,14093
|
|
15
|
+
onnx_diagnostic/helpers/helper.py,sha256=h7nuAUWBvLgOq95AY9xIer4rcXhLxkAM1-u-QliAx5A,56929
|
|
15
16
|
onnx_diagnostic/helpers/memory_peak.py,sha256=OT6mz0muBbBZY0pjgW2_eCk_lOtFRo-5w4jFo2Z6Kok,6380
|
|
16
17
|
onnx_diagnostic/helpers/mini_onnx_builder.py,sha256=R1Vu4zHzN7GIUnbMVQzpkaXj8cCyyOweWOI9-TSgAHM,20966
|
|
18
|
+
onnx_diagnostic/helpers/model_builder_helper.py,sha256=wHd2qqfGNvdgjGBmXyyZhfikjyM_-ijPkLbMRkqW0Pg,12963
|
|
17
19
|
onnx_diagnostic/helpers/onnx_helper.py,sha256=chw-HB4iqGCD_16d0_BaCnreEgWYW4KeH78nh-3t2Uw,29213
|
|
18
20
|
onnx_diagnostic/helpers/ort_session.py,sha256=UgUUeUslDxEFBc6w6f3HMq_a7bn4TBlItmojqWquSj4,29281
|
|
19
|
-
onnx_diagnostic/helpers/rt_helper.py,sha256=
|
|
21
|
+
onnx_diagnostic/helpers/rt_helper.py,sha256=PbMRp0AQGKxj9B8_-oIMYPG5o86jTXJalkoBTmg4VZs,4147
|
|
20
22
|
onnx_diagnostic/helpers/torch_helper.py,sha256=83HnFGcOX8YmPDAikhxFpBydxfI3gyWPDiRHYidrH6A,31531
|
|
21
23
|
onnx_diagnostic/reference/__init__.py,sha256=0Al5kins8LlBICAsszEZ59thMwmaARBO6fMwtYpKOOQ,98
|
|
22
24
|
onnx_diagnostic/reference/evaluator.py,sha256=RzNzjFDeMe-4X51Tb22N6aagazY5ktNq-mRmPcfY5EU,8848
|
|
@@ -66,18 +68,20 @@ onnx_diagnostic/tasks/text_classification.py,sha256=OgC_G9iumzTjTNUEvMoFFNTHCD8_
|
|
|
66
68
|
onnx_diagnostic/tasks/text_generation.py,sha256=Wv8DamBHte355wXe_tAeVxG4EL20y86fu7JEmUM75to,10385
|
|
67
69
|
onnx_diagnostic/tasks/zero_shot_image_classification.py,sha256=N3cEG1Lq95wS1N_CWUUUCU5j-4Tp5eR8Ce68U8THYAk,4380
|
|
68
70
|
onnx_diagnostic/torch_export_patches/__init__.py,sha256=0SaZedwznm1hQUCvXZsGZORV5vby954wEExr5faepGg,720
|
|
69
|
-
onnx_diagnostic/torch_export_patches/onnx_export_errors.py,sha256=
|
|
70
|
-
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py,sha256=
|
|
71
|
+
onnx_diagnostic/torch_export_patches/onnx_export_errors.py,sha256=MgOQwgLf6-uCGQaiUrhVNfZQ43dCp1iWGbzLbKEVyc8,18810
|
|
72
|
+
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py,sha256=l5HvE_FfdCtgLJBmJczH6nA8jZY1725AggiHwoAa-o0,15763
|
|
71
73
|
onnx_diagnostic/torch_export_patches/patch_expressions.py,sha256=vr4tt61cbDnaaaduzMj4UBZ8OUtr6GfDpIWwOYqjWzs,3213
|
|
72
74
|
onnx_diagnostic/torch_export_patches/patch_inputs.py,sha256=9b4pmyT00BwLqi7WG-gliep1RUy3gXEgW6BDnlSSA-M,7689
|
|
73
75
|
onnx_diagnostic/torch_export_patches/patch_module.py,sha256=R2d9IHM-RwsBKDsxuBIJnEqMoxbS9gd4YWFGG2wwV5A,39881
|
|
74
|
-
onnx_diagnostic/torch_export_patches/patch_module_helper.py,sha256
|
|
76
|
+
onnx_diagnostic/torch_export_patches/patch_module_helper.py,sha256=-sFpuBnwPl61Y0KKENniMfQkL-0-3SaLn5mzgF-fP6g,1946
|
|
77
|
+
onnx_diagnostic/torch_export_patches/eval/__init__.py,sha256=V8gbjsbYJHyjLo4WyzhsDx-noBQn6bsK1lhgm8IQlzQ,20890
|
|
78
|
+
onnx_diagnostic/torch_export_patches/eval/model_cases.py,sha256=KtK4AgrFE2Mm8wo15O7x4deTGubOU9MFp3QCuac6WkM,26471
|
|
75
79
|
onnx_diagnostic/torch_export_patches/patches/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
76
|
-
onnx_diagnostic/torch_export_patches/patches/patch_torch.py,sha256=
|
|
77
|
-
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=
|
|
80
|
+
onnx_diagnostic/torch_export_patches/patches/patch_torch.py,sha256=KaZ8TjDa9ATgT4HllYzzoNf_51q_yOj_GuF5NYjPCrU,18913
|
|
81
|
+
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=QGuZ_j8juG9dlSUBSpc2T1nYwkoeFz3iP2kONx2Hmrc,23025
|
|
78
82
|
onnx_diagnostic/torch_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
79
83
|
onnx_diagnostic/torch_models/llms.py,sha256=soyg4yC87ptGoeulJhKqw5opGmuLvH1pn_ZDXZ4Jr8E,90
|
|
80
|
-
onnx_diagnostic/torch_models/test_helper.py,sha256=
|
|
84
|
+
onnx_diagnostic/torch_models/test_helper.py,sha256=SU8HwQtfUEicCpukp_prqX0ol5fu2rA06lc1GROxW38,55309
|
|
81
85
|
onnx_diagnostic/torch_models/hghub/__init__.py,sha256=vi1Q7YHdddj1soiBN42MSvJdFqe2_KUoWafHISjwOu8,58
|
|
82
86
|
onnx_diagnostic/torch_models/hghub/hub_api.py,sha256=BgM_p57Q0gT9GOhdrmOYcnbuTTzCWp80jS4OQqWwFhs,9990
|
|
83
87
|
onnx_diagnostic/torch_models/hghub/hub_data.py,sha256=885wKyZkdM-Qp5Sg6C9Ol1dxigmA8FYAko-Ys08sppo,8096
|
|
@@ -88,8 +92,8 @@ onnx_diagnostic/torch_models/untrained/llm_phi2.py,sha256=ynBTDHJHCk44NjLT_t6OiF
|
|
|
88
92
|
onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py,sha256=7N3fGvT_4Mn4NbIo0Qk57c6DMc3OXGWyvj_P41rjwSY,3513
|
|
89
93
|
onnx_diagnostic/torch_onnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
90
94
|
onnx_diagnostic/torch_onnx/sbs.py,sha256=1EL25DeYFzlBSiFG_XjePBLvsiItRXbdDrr5-QZW2mA,16878
|
|
91
|
-
onnx_diagnostic-0.
|
|
92
|
-
onnx_diagnostic-0.
|
|
93
|
-
onnx_diagnostic-0.
|
|
94
|
-
onnx_diagnostic-0.
|
|
95
|
-
onnx_diagnostic-0.
|
|
95
|
+
onnx_diagnostic-0.6.0.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
|
|
96
|
+
onnx_diagnostic-0.6.0.dist-info/METADATA,sha256=jqrQEqFmrlycVGKW_zlSrxhzMoeV-X-H4eSeGBZyK1o,6643
|
|
97
|
+
onnx_diagnostic-0.6.0.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
|
|
98
|
+
onnx_diagnostic-0.6.0.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
|
|
99
|
+
onnx_diagnostic-0.6.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|