onnx-diagnostic 0.7.3__py3-none-any.whl → 0.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +16 -4
- onnx_diagnostic/export/shape_helper.py +71 -0
- onnx_diagnostic/helpers/cache_helper.py +11 -1
- onnx_diagnostic/reference/ops/op_cast_like.py +12 -8
- onnx_diagnostic/tasks/automatic_speech_recognition.py +6 -2
- onnx_diagnostic/tasks/feature_extraction.py +6 -2
- onnx_diagnostic/tasks/fill_mask.py +6 -2
- onnx_diagnostic/tasks/image_classification.py +7 -3
- onnx_diagnostic/tasks/image_text_to_text.py +6 -2
- onnx_diagnostic/tasks/mixture_of_expert.py +1 -1
- onnx_diagnostic/tasks/object_detection.py +7 -3
- onnx_diagnostic/tasks/sentence_similarity.py +6 -2
- onnx_diagnostic/tasks/summarization.py +6 -2
- onnx_diagnostic/tasks/text2text_generation.py +6 -2
- onnx_diagnostic/tasks/text_classification.py +6 -2
- onnx_diagnostic/tasks/text_generation.py +5 -3
- onnx_diagnostic/tasks/text_to_image.py +6 -2
- onnx_diagnostic/tasks/zero_shot_image_classification.py +6 -2
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +56 -6
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +96 -48
- onnx_diagnostic/torch_models/hghub/model_inputs.py +1 -1
- onnx_diagnostic/torch_models/validate.py +23 -7
- {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.4.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.4.dist-info}/RECORD +28 -28
- {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.4.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.4.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.4.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED
|
@@ -349,6 +349,15 @@ def get_parser_validate() -> ArgumentParser:
|
|
|
349
349
|
python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
|
|
350
350
|
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
|
|
351
351
|
--dtype float16 --device cuda --export modelbuilder
|
|
352
|
+
|
|
353
|
+
position_ids is usually not needed, they can be removed by adding:
|
|
354
|
+
|
|
355
|
+
--drop position_ids
|
|
356
|
+
|
|
357
|
+
The behaviour may be modified compare the original configuration,
|
|
358
|
+
the following argument can be rope_scaling to dynamic:
|
|
359
|
+
|
|
360
|
+
--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\""
|
|
352
361
|
"""
|
|
353
362
|
),
|
|
354
363
|
formatter_class=RawTextHelpFormatter,
|
|
@@ -403,10 +412,12 @@ def get_parser_validate() -> ArgumentParser:
|
|
|
403
412
|
)
|
|
404
413
|
parser.add_argument(
|
|
405
414
|
"--inputs2",
|
|
406
|
-
default=
|
|
407
|
-
|
|
415
|
+
default=1,
|
|
416
|
+
type=int,
|
|
408
417
|
help="Validates the model on a second set of inputs\n"
|
|
409
|
-
"to check the exported model supports dynamism."
|
|
418
|
+
"to check the exported model supports dynamism. The values is used "
|
|
419
|
+
"as an increment to the first set of inputs. A high value may trick "
|
|
420
|
+
"a different behavior in the model and missed by the exporter.",
|
|
410
421
|
)
|
|
411
422
|
parser.add_argument(
|
|
412
423
|
"--runtime",
|
|
@@ -422,7 +433,8 @@ def get_parser_validate() -> ArgumentParser:
|
|
|
422
433
|
parser.add_argument(
|
|
423
434
|
"--drop",
|
|
424
435
|
help="Drops the following inputs names, it should be a list\n"
|
|
425
|
-
"with comma separated values
|
|
436
|
+
"with comma separated values, example:\n"
|
|
437
|
+
"--drop position_ids",
|
|
426
438
|
)
|
|
427
439
|
parser.add_argument(
|
|
428
440
|
"--opset",
|
|
@@ -30,6 +30,77 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
|
|
|
30
30
|
)
|
|
31
31
|
ds = all_dynamic_shape_from_inputs(inputs)
|
|
32
32
|
pprint.pprint(ds)
|
|
33
|
+
|
|
34
|
+
For this function to work, patches must be enabled if :epkg:`transformers`
|
|
35
|
+
does not implement the serialization functions.
|
|
36
|
+
|
|
37
|
+
.. runpython::
|
|
38
|
+
:showcode:
|
|
39
|
+
|
|
40
|
+
import pprint
|
|
41
|
+
import torch
|
|
42
|
+
from onnx_diagnostic.helpers.cache_helper import (
|
|
43
|
+
make_dynamic_cache,
|
|
44
|
+
make_encoder_decoder_cache,
|
|
45
|
+
make_mamba_cache,
|
|
46
|
+
make_sliding_window_cache,
|
|
47
|
+
make_static_cache,
|
|
48
|
+
)
|
|
49
|
+
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
|
|
50
|
+
from onnx_diagnostic.torch_export_patches import torch_export_patches
|
|
51
|
+
|
|
52
|
+
caches = [
|
|
53
|
+
make_dynamic_cache(
|
|
54
|
+
[
|
|
55
|
+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
|
|
56
|
+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
|
|
57
|
+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
|
|
58
|
+
]
|
|
59
|
+
),
|
|
60
|
+
make_encoder_decoder_cache(
|
|
61
|
+
make_dynamic_cache(
|
|
62
|
+
[
|
|
63
|
+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
|
|
64
|
+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
|
|
65
|
+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
|
|
66
|
+
]
|
|
67
|
+
),
|
|
68
|
+
make_dynamic_cache(
|
|
69
|
+
[
|
|
70
|
+
(torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
|
|
71
|
+
(torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
|
|
72
|
+
(torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
|
|
73
|
+
]
|
|
74
|
+
),
|
|
75
|
+
),
|
|
76
|
+
make_sliding_window_cache(
|
|
77
|
+
[
|
|
78
|
+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
|
|
79
|
+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
|
|
80
|
+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
|
|
81
|
+
]
|
|
82
|
+
),
|
|
83
|
+
make_static_cache(
|
|
84
|
+
[
|
|
85
|
+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
|
|
86
|
+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
|
|
87
|
+
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
|
|
88
|
+
],
|
|
89
|
+
max_cache_len=15,
|
|
90
|
+
),
|
|
91
|
+
make_mamba_cache(
|
|
92
|
+
[
|
|
93
|
+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
|
|
94
|
+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
|
|
95
|
+
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
|
|
96
|
+
]
|
|
97
|
+
),
|
|
98
|
+
]
|
|
99
|
+
|
|
100
|
+
with torch_export_patches(patch_transformers=True):
|
|
101
|
+
for cache in caches:
|
|
102
|
+
print(f"-- {cache.__class__.__name__}")
|
|
103
|
+
pprint.pprint(all_dynamic_shape_from_inputs(cache))
|
|
33
104
|
"""
|
|
34
105
|
if isinstance(dim_prefix, str):
|
|
35
106
|
prefixes: Set[str] = set()
|
|
@@ -39,11 +39,21 @@ def flatten_unflatten_for_dynamic_shapes(
|
|
|
39
39
|
subtrees.append(value)
|
|
40
40
|
start = end
|
|
41
41
|
if use_dict:
|
|
42
|
-
if spec.type is dict
|
|
42
|
+
if spec.type is dict:
|
|
43
43
|
# This is a dictionary.
|
|
44
44
|
return dict(zip(spec.context, subtrees))
|
|
45
45
|
if spec.type is tuple:
|
|
46
46
|
return tuple(subtrees)
|
|
47
|
+
if spec.type is list:
|
|
48
|
+
return list(subtrees)
|
|
49
|
+
if spec.context:
|
|
50
|
+
# This is a custom class with attributes.
|
|
51
|
+
# It is returned as a list.
|
|
52
|
+
return list(subtrees)
|
|
53
|
+
raise ValueError(
|
|
54
|
+
f"Unable to interpret spec type {spec.type} "
|
|
55
|
+
f"(type is {type(spec.type)}, context is {spec.context})."
|
|
56
|
+
)
|
|
47
57
|
# This is a list.
|
|
48
58
|
return subtrees
|
|
49
59
|
|
|
@@ -1,13 +1,17 @@
|
|
|
1
1
|
from onnx.onnx_pb import TensorProto
|
|
2
2
|
from onnx.reference.op_run import OpRun
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
3
|
+
|
|
4
|
+
try:
|
|
5
|
+
from onnx.reference.ops.op_cast import (
|
|
6
|
+
bfloat16,
|
|
7
|
+
cast_to,
|
|
8
|
+
float8e4m3fn,
|
|
9
|
+
float8e4m3fnuz,
|
|
10
|
+
float8e5m2,
|
|
11
|
+
float8e5m2fnuz,
|
|
12
|
+
)
|
|
13
|
+
except ImportError:
|
|
14
|
+
from onnx.reference.ops.op_cast import cast_to
|
|
11
15
|
from ...helpers.onnx_helper import np_dtype_to_tensor_dtype
|
|
12
16
|
|
|
13
17
|
|
|
@@ -33,7 +33,7 @@ def get_inputs(
|
|
|
33
33
|
head_dim: int,
|
|
34
34
|
batch_size: int = 2,
|
|
35
35
|
sequence_length: int = 30,
|
|
36
|
-
add_second_input:
|
|
36
|
+
add_second_input: int = 1,
|
|
37
37
|
**kwargs, # unused
|
|
38
38
|
):
|
|
39
39
|
"""
|
|
@@ -132,6 +132,9 @@ def get_inputs(
|
|
|
132
132
|
)
|
|
133
133
|
res = dict(inputs=inputs, dynamic_shapes=shapes)
|
|
134
134
|
if add_second_input:
|
|
135
|
+
assert (
|
|
136
|
+
add_second_input > 0
|
|
137
|
+
), f"Not implemented for add_second_input={add_second_input}."
|
|
135
138
|
res["inputs2"] = get_inputs(
|
|
136
139
|
model=model,
|
|
137
140
|
config=config,
|
|
@@ -144,7 +147,8 @@ def get_inputs(
|
|
|
144
147
|
decoder_layers=decoder_layers,
|
|
145
148
|
head_dim=head_dim,
|
|
146
149
|
batch_size=batch_size + 1,
|
|
147
|
-
sequence_length=sequence_length +
|
|
150
|
+
sequence_length=sequence_length + add_second_input,
|
|
151
|
+
add_second_input=0,
|
|
148
152
|
**kwargs,
|
|
149
153
|
)["inputs"]
|
|
150
154
|
return res
|
|
@@ -22,7 +22,7 @@ def get_inputs(
|
|
|
22
22
|
batch_size: int,
|
|
23
23
|
sequence_length: int,
|
|
24
24
|
dummy_max_token_id: int,
|
|
25
|
-
add_second_input:
|
|
25
|
+
add_second_input: int = 1,
|
|
26
26
|
**kwargs, # unused
|
|
27
27
|
):
|
|
28
28
|
"""
|
|
@@ -52,12 +52,16 @@ def get_inputs(
|
|
|
52
52
|
)
|
|
53
53
|
res = dict(inputs=inputs, dynamic_shapes=shapes)
|
|
54
54
|
if add_second_input:
|
|
55
|
+
assert (
|
|
56
|
+
add_second_input > 0
|
|
57
|
+
), f"Not implemented for add_second_input={add_second_input}."
|
|
55
58
|
res["inputs2"] = get_inputs(
|
|
56
59
|
model=model,
|
|
57
60
|
config=config,
|
|
58
61
|
batch_size=batch_size + 1,
|
|
59
|
-
sequence_length=sequence_length +
|
|
62
|
+
sequence_length=sequence_length + add_second_input,
|
|
60
63
|
dummy_max_token_id=dummy_max_token_id,
|
|
64
|
+
add_second_input=0,
|
|
61
65
|
**kwargs,
|
|
62
66
|
)["inputs"]
|
|
63
67
|
return res
|
|
@@ -22,7 +22,7 @@ def get_inputs(
|
|
|
22
22
|
batch_size: int,
|
|
23
23
|
sequence_length: int,
|
|
24
24
|
dummy_max_token_id: int,
|
|
25
|
-
add_second_input:
|
|
25
|
+
add_second_input: int = 1,
|
|
26
26
|
**kwargs, # unused
|
|
27
27
|
):
|
|
28
28
|
"""
|
|
@@ -54,12 +54,16 @@ def get_inputs(
|
|
|
54
54
|
)
|
|
55
55
|
res = dict(inputs=inputs, dynamic_shapes=shapes)
|
|
56
56
|
if add_second_input:
|
|
57
|
+
assert (
|
|
58
|
+
add_second_input > 0
|
|
59
|
+
), f"Not implemented for add_second_input={add_second_input}."
|
|
57
60
|
res["inputs2"] = get_inputs(
|
|
58
61
|
model=model,
|
|
59
62
|
config=config,
|
|
60
63
|
batch_size=batch_size + 1,
|
|
61
|
-
sequence_length=sequence_length +
|
|
64
|
+
sequence_length=sequence_length + add_second_input,
|
|
62
65
|
dummy_max_token_id=dummy_max_token_id,
|
|
66
|
+
add_second_input=0,
|
|
63
67
|
**kwargs,
|
|
64
68
|
)["inputs"]
|
|
65
69
|
return res
|
|
@@ -34,7 +34,7 @@ def get_inputs(
|
|
|
34
34
|
input_channels: int,
|
|
35
35
|
batch_size: int = 2,
|
|
36
36
|
dynamic_rope: bool = False,
|
|
37
|
-
add_second_input:
|
|
37
|
+
add_second_input: int = 1,
|
|
38
38
|
**kwargs, # unused
|
|
39
39
|
):
|
|
40
40
|
"""
|
|
@@ -75,14 +75,18 @@ def get_inputs(
|
|
|
75
75
|
shapes["interpolate_pos_encoding"] = None # type: ignore[assignment]
|
|
76
76
|
res = dict(inputs=inputs, dynamic_shapes=shapes)
|
|
77
77
|
if add_second_input:
|
|
78
|
+
assert (
|
|
79
|
+
add_second_input > 0
|
|
80
|
+
), f"Not implemented for add_second_input={add_second_input}."
|
|
78
81
|
res["inputs2"] = get_inputs(
|
|
79
82
|
model=model,
|
|
80
83
|
config=config,
|
|
81
|
-
input_width=input_width +
|
|
82
|
-
input_height=input_height +
|
|
84
|
+
input_width=input_width + add_second_input,
|
|
85
|
+
input_height=input_height + add_second_input,
|
|
83
86
|
input_channels=input_channels,
|
|
84
87
|
batch_size=batch_size + 1,
|
|
85
88
|
dynamic_rope=dynamic_rope,
|
|
89
|
+
add_second_input=0,
|
|
86
90
|
**kwargs,
|
|
87
91
|
)["inputs"]
|
|
88
92
|
return res
|
|
@@ -32,7 +32,7 @@ def get_inputs(
|
|
|
32
32
|
sequence_length2: int = 3,
|
|
33
33
|
n_images: int = 2,
|
|
34
34
|
dynamic_rope: bool = False,
|
|
35
|
-
add_second_input:
|
|
35
|
+
add_second_input: int = 1,
|
|
36
36
|
**kwargs, # unused
|
|
37
37
|
):
|
|
38
38
|
"""
|
|
@@ -105,6 +105,9 @@ def get_inputs(
|
|
|
105
105
|
)
|
|
106
106
|
res = dict(inputs=inputs, dynamic_shapes=shapes)
|
|
107
107
|
if add_second_input:
|
|
108
|
+
assert (
|
|
109
|
+
add_second_input > 0
|
|
110
|
+
), f"Not implemented for add_second_input={add_second_input}."
|
|
108
111
|
res["inputs2"] = get_inputs(
|
|
109
112
|
model=model,
|
|
110
113
|
config=config,
|
|
@@ -116,10 +119,11 @@ def get_inputs(
|
|
|
116
119
|
height=height,
|
|
117
120
|
num_channels=num_channels,
|
|
118
121
|
batch_size=batch_size + 1,
|
|
119
|
-
sequence_length=sequence_length +
|
|
122
|
+
sequence_length=sequence_length + add_second_input,
|
|
120
123
|
sequence_length2=sequence_length2 + 1,
|
|
121
124
|
n_images=n_images + 1,
|
|
122
125
|
dynamic_rope=dynamic_rope,
|
|
126
|
+
add_second_input=0,
|
|
123
127
|
**kwargs,
|
|
124
128
|
)["inputs"]
|
|
125
129
|
return res
|
|
@@ -27,7 +27,7 @@ def get_inputs(
|
|
|
27
27
|
input_channels: int,
|
|
28
28
|
batch_size: int = 2,
|
|
29
29
|
dynamic_rope: bool = False,
|
|
30
|
-
add_second_input:
|
|
30
|
+
add_second_input: int = 1,
|
|
31
31
|
**kwargs, # unused
|
|
32
32
|
):
|
|
33
33
|
"""
|
|
@@ -65,14 +65,18 @@ def get_inputs(
|
|
|
65
65
|
)
|
|
66
66
|
res = dict(inputs=inputs, dynamic_shapes=shapes)
|
|
67
67
|
if add_second_input:
|
|
68
|
+
assert (
|
|
69
|
+
add_second_input > 0
|
|
70
|
+
), f"Not implemented for add_second_input={add_second_input}."
|
|
68
71
|
res["inputs2"] = get_inputs(
|
|
69
72
|
model=model,
|
|
70
73
|
config=config,
|
|
71
|
-
input_width=input_width +
|
|
72
|
-
input_height=input_height +
|
|
74
|
+
input_width=input_width + add_second_input,
|
|
75
|
+
input_height=input_height + add_second_input,
|
|
73
76
|
input_channels=input_channels,
|
|
74
77
|
batch_size=batch_size + 1,
|
|
75
78
|
dynamic_rope=dynamic_rope,
|
|
79
|
+
add_second_input=0,
|
|
76
80
|
**kwargs,
|
|
77
81
|
)["inputs"]
|
|
78
82
|
return res
|
|
@@ -22,7 +22,7 @@ def get_inputs(
|
|
|
22
22
|
batch_size: int,
|
|
23
23
|
sequence_length: int,
|
|
24
24
|
dummy_max_token_id: int,
|
|
25
|
-
add_second_input:
|
|
25
|
+
add_second_input: int = 1,
|
|
26
26
|
**kwargs, # unused
|
|
27
27
|
):
|
|
28
28
|
"""
|
|
@@ -54,12 +54,16 @@ def get_inputs(
|
|
|
54
54
|
)
|
|
55
55
|
res = dict(inputs=inputs, dynamic_shapes=shapes)
|
|
56
56
|
if add_second_input:
|
|
57
|
+
assert (
|
|
58
|
+
add_second_input > 0
|
|
59
|
+
), f"Not implemented for add_second_input={add_second_input}."
|
|
57
60
|
res["inputs2"] = get_inputs(
|
|
58
61
|
model=model,
|
|
59
62
|
config=config,
|
|
60
63
|
batch_size=batch_size + 1,
|
|
61
|
-
sequence_length=sequence_length +
|
|
64
|
+
sequence_length=sequence_length + add_second_input,
|
|
62
65
|
dummy_max_token_id=dummy_max_token_id,
|
|
66
|
+
add_second_input=0,
|
|
63
67
|
**kwargs,
|
|
64
68
|
)["inputs"]
|
|
65
69
|
return res
|
|
@@ -29,7 +29,7 @@ def get_inputs(
|
|
|
29
29
|
batch_size: int = 2,
|
|
30
30
|
sequence_length: int = 30,
|
|
31
31
|
sequence_length2: int = 3,
|
|
32
|
-
add_second_input:
|
|
32
|
+
add_second_input: int = 1,
|
|
33
33
|
**kwargs, # unused
|
|
34
34
|
):
|
|
35
35
|
"""
|
|
@@ -144,6 +144,9 @@ def get_inputs(
|
|
|
144
144
|
)
|
|
145
145
|
res = dict(inputs=inputs, dynamic_shapes=shapes)
|
|
146
146
|
if add_second_input:
|
|
147
|
+
assert (
|
|
148
|
+
add_second_input > 0
|
|
149
|
+
), f"Not implemented for add_second_input={add_second_input}."
|
|
147
150
|
res["inputs2"] = get_inputs(
|
|
148
151
|
model=model,
|
|
149
152
|
config=config,
|
|
@@ -154,8 +157,9 @@ def get_inputs(
|
|
|
154
157
|
head_dim_encoder=head_dim_encoder,
|
|
155
158
|
head_dim_decoder=head_dim_decoder,
|
|
156
159
|
batch_size=batch_size + 1,
|
|
157
|
-
sequence_length=sequence_length +
|
|
160
|
+
sequence_length=sequence_length + add_second_input,
|
|
158
161
|
sequence_length2=sequence_length2 + 1,
|
|
162
|
+
add_second_input=0,
|
|
159
163
|
**kwargs,
|
|
160
164
|
)["inputs"]
|
|
161
165
|
return res
|
|
@@ -30,7 +30,7 @@ def get_inputs(
|
|
|
30
30
|
batch_size: int = 2,
|
|
31
31
|
sequence_length: int = 30,
|
|
32
32
|
sequence_length2: int = 3,
|
|
33
|
-
add_second_input:
|
|
33
|
+
add_second_input: int = 1,
|
|
34
34
|
**kwargs, # unused
|
|
35
35
|
):
|
|
36
36
|
"""
|
|
@@ -149,6 +149,9 @@ def get_inputs(
|
|
|
149
149
|
)
|
|
150
150
|
res = dict(inputs=inputs, dynamic_shapes=shapes)
|
|
151
151
|
if add_second_input:
|
|
152
|
+
assert (
|
|
153
|
+
add_second_input > 0
|
|
154
|
+
), f"Not implemented for add_second_input={add_second_input}."
|
|
152
155
|
res["inputs2"] = get_inputs(
|
|
153
156
|
model=model,
|
|
154
157
|
config=config,
|
|
@@ -160,8 +163,9 @@ def get_inputs(
|
|
|
160
163
|
head_dim_decoder=head_dim_decoder,
|
|
161
164
|
encoder_dim=encoder_dim,
|
|
162
165
|
batch_size=batch_size + 1,
|
|
163
|
-
sequence_length=sequence_length +
|
|
166
|
+
sequence_length=sequence_length + add_second_input,
|
|
164
167
|
sequence_length2=sequence_length2 + 1,
|
|
168
|
+
add_second_input=0,
|
|
165
169
|
**kwargs,
|
|
166
170
|
)["inputs"]
|
|
167
171
|
return res
|
|
@@ -22,7 +22,7 @@ def get_inputs(
|
|
|
22
22
|
batch_size: int,
|
|
23
23
|
sequence_length: int,
|
|
24
24
|
dummy_max_token_id: int,
|
|
25
|
-
add_second_input:
|
|
25
|
+
add_second_input: int = 1,
|
|
26
26
|
**kwargs, # unused
|
|
27
27
|
):
|
|
28
28
|
"""
|
|
@@ -54,12 +54,16 @@ def get_inputs(
|
|
|
54
54
|
)
|
|
55
55
|
res = dict(inputs=inputs, dynamic_shapes=shapes)
|
|
56
56
|
if add_second_input:
|
|
57
|
+
assert (
|
|
58
|
+
add_second_input > 0
|
|
59
|
+
), f"Not implemented for add_second_input={add_second_input}."
|
|
57
60
|
res["inputs2"] = get_inputs(
|
|
58
61
|
model=model,
|
|
59
62
|
config=config,
|
|
60
63
|
batch_size=batch_size + 1,
|
|
61
|
-
sequence_length=sequence_length +
|
|
64
|
+
sequence_length=sequence_length + add_second_input,
|
|
62
65
|
dummy_max_token_id=dummy_max_token_id,
|
|
66
|
+
add_second_input=0,
|
|
63
67
|
**kwargs,
|
|
64
68
|
)["inputs"]
|
|
65
69
|
return res
|
|
@@ -72,7 +72,7 @@ def get_inputs(
|
|
|
72
72
|
num_key_value_heads: Optional[int] = None,
|
|
73
73
|
head_dim: Optional[int] = None,
|
|
74
74
|
cls_cache: Optional[Union[type, str]] = None,
|
|
75
|
-
add_second_input:
|
|
75
|
+
add_second_input: int = 1,
|
|
76
76
|
**kwargs, # unused
|
|
77
77
|
):
|
|
78
78
|
"""
|
|
@@ -260,13 +260,15 @@ def get_inputs(
|
|
|
260
260
|
config=config,
|
|
261
261
|
dummy_max_token_id=dummy_max_token_id,
|
|
262
262
|
num_hidden_layers=num_hidden_layers,
|
|
263
|
-
batch_size=batch_size + 1,
|
|
263
|
+
batch_size=(batch_size + 1) if add_second_input > 0 else 1,
|
|
264
264
|
sequence_length=sequence_length + 1,
|
|
265
|
-
sequence_length2=sequence_length2
|
|
265
|
+
sequence_length2=sequence_length2
|
|
266
|
+
+ (add_second_input if add_second_input > 0 else -add_second_input),
|
|
266
267
|
dynamic_rope=dynamic_rope,
|
|
267
268
|
num_key_value_heads=num_key_value_heads,
|
|
268
269
|
head_dim=head_dim,
|
|
269
270
|
cls_cache=cls_cache,
|
|
271
|
+
add_second_input=0,
|
|
270
272
|
**kwargs,
|
|
271
273
|
)["inputs"]
|
|
272
274
|
return res
|
|
@@ -25,7 +25,7 @@ def get_inputs(
|
|
|
25
25
|
in_channels: int,
|
|
26
26
|
sample_size: int,
|
|
27
27
|
cross_attention_dim: int,
|
|
28
|
-
add_second_input:
|
|
28
|
+
add_second_input: int = 1,
|
|
29
29
|
**kwargs, # unused
|
|
30
30
|
):
|
|
31
31
|
"""
|
|
@@ -58,15 +58,19 @@ def get_inputs(
|
|
|
58
58
|
)
|
|
59
59
|
res = dict(inputs=inputs, dynamic_shapes=shapes)
|
|
60
60
|
if add_second_input:
|
|
61
|
+
assert (
|
|
62
|
+
add_second_input > 0
|
|
63
|
+
), f"Not implemented for add_second_input={add_second_input}."
|
|
61
64
|
res["inputs2"] = get_inputs(
|
|
62
65
|
model=model,
|
|
63
66
|
config=config,
|
|
64
67
|
batch_size=batch_size + 1,
|
|
65
68
|
sequence_length=sequence_length,
|
|
66
|
-
cache_length=cache_length +
|
|
69
|
+
cache_length=cache_length + add_second_input,
|
|
67
70
|
in_channels=in_channels,
|
|
68
71
|
sample_size=sample_size,
|
|
69
72
|
cross_attention_dim=cross_attention_dim,
|
|
73
|
+
add_second_input=0,
|
|
70
74
|
**kwargs,
|
|
71
75
|
)["inputs"]
|
|
72
76
|
return res
|
|
@@ -34,7 +34,7 @@ def get_inputs(
|
|
|
34
34
|
input_height: int = 224,
|
|
35
35
|
input_channels: int = 3,
|
|
36
36
|
batch_size_image=3,
|
|
37
|
-
add_second_input:
|
|
37
|
+
add_second_input: int = 1,
|
|
38
38
|
**kwargs, # unused
|
|
39
39
|
):
|
|
40
40
|
"""
|
|
@@ -87,16 +87,20 @@ def get_inputs(
|
|
|
87
87
|
)
|
|
88
88
|
res = dict(inputs=inputs, dynamic_shapes=shapes)
|
|
89
89
|
if add_second_input:
|
|
90
|
+
assert (
|
|
91
|
+
add_second_input > 0
|
|
92
|
+
), f"Not implemented for add_second_input={add_second_input}."
|
|
90
93
|
res["inputs2"] = get_inputs(
|
|
91
94
|
model=model,
|
|
92
95
|
config=config,
|
|
93
96
|
dummy_max_token_id=dummy_max_token_id,
|
|
94
97
|
batch_size=batch_size + 1,
|
|
95
|
-
sequence_length=sequence_length +
|
|
98
|
+
sequence_length=sequence_length + add_second_input,
|
|
96
99
|
input_width=input_width,
|
|
97
100
|
input_height=input_height,
|
|
98
101
|
input_channels=input_channels,
|
|
99
102
|
batch_size_image=batch_size_image + 1,
|
|
103
|
+
add_second_input=0,
|
|
100
104
|
**kwargs,
|
|
101
105
|
)["inputs"]
|
|
102
106
|
return res
|
|
@@ -420,7 +420,11 @@ def torch_export_patches(
|
|
|
420
420
|
patch_transformers_list, verbose=verbose
|
|
421
421
|
)
|
|
422
422
|
|
|
423
|
-
if
|
|
423
|
+
if (
|
|
424
|
+
masking_utils
|
|
425
|
+
and patch_transformers_list.patch_masking_utils
|
|
426
|
+
and hasattr(masking_utils, "_vmap_for_bhqkv")
|
|
427
|
+
):
|
|
424
428
|
if verbose:
|
|
425
429
|
print(
|
|
426
430
|
"[torch_export_patches] patches "
|
|
@@ -429,6 +433,27 @@ def torch_export_patches(
|
|
|
429
433
|
f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
|
|
430
434
|
masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
|
|
431
435
|
|
|
436
|
+
if (
|
|
437
|
+
masking_utils
|
|
438
|
+
and patch_transformers_list.patch_masking_utils
|
|
439
|
+
and hasattr(masking_utils, "eager_mask")
|
|
440
|
+
):
|
|
441
|
+
if verbose:
|
|
442
|
+
print(
|
|
443
|
+
"[torch_export_patches] patches "
|
|
444
|
+
"transformers.masking_utils.eager_mask"
|
|
445
|
+
)
|
|
446
|
+
f_transformers_eager_mask = masking_utils.eager_mask
|
|
447
|
+
masking_utils.eager_mask = patch_transformers_list.patched_eager_mask
|
|
448
|
+
if (
|
|
449
|
+
"eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
|
|
450
|
+
and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
|
|
451
|
+
== f_transformers_eager_mask
|
|
452
|
+
):
|
|
453
|
+
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
|
|
454
|
+
patch_transformers_list.patched_eager_mask
|
|
455
|
+
)
|
|
456
|
+
|
|
432
457
|
if custom_patches:
|
|
433
458
|
if verbose:
|
|
434
459
|
print("[torch_export_patches] applies custom patches")
|
|
@@ -511,7 +536,7 @@ def torch_export_patches(
|
|
|
511
536
|
|
|
512
537
|
if custom_patches:
|
|
513
538
|
if verbose:
|
|
514
|
-
print("[torch_export_patches]
|
|
539
|
+
print("[torch_export_patches] unpatches custom patches")
|
|
515
540
|
unpatch_module_or_classes(
|
|
516
541
|
custom_patches, revert_custom_patches_info, verbose=verbose
|
|
517
542
|
)
|
|
@@ -526,18 +551,43 @@ def torch_export_patches(
|
|
|
526
551
|
except ImportError:
|
|
527
552
|
masking_utils = None
|
|
528
553
|
if verbose:
|
|
529
|
-
print("[torch_export_patches]
|
|
554
|
+
print("[torch_export_patches] unpatches transformers")
|
|
530
555
|
unpatch_module_or_classes(
|
|
531
556
|
patch_transformers_list, revert_patches_info, verbose=verbose
|
|
532
557
|
)
|
|
533
558
|
|
|
534
|
-
if
|
|
559
|
+
if (
|
|
560
|
+
masking_utils
|
|
561
|
+
and patch_transformers_list.patch_masking_utils
|
|
562
|
+
and hasattr(masking_utils, "_vmap_for_bhqkv")
|
|
563
|
+
):
|
|
564
|
+
masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
|
|
535
565
|
if verbose:
|
|
536
566
|
print(
|
|
537
|
-
"[torch_export_patches]
|
|
567
|
+
"[torch_export_patches] restored "
|
|
538
568
|
"transformers.masking_utils._vmap_for_bhqkv"
|
|
539
569
|
)
|
|
540
|
-
|
|
570
|
+
|
|
571
|
+
if (
|
|
572
|
+
masking_utils
|
|
573
|
+
and patch_transformers_list.patch_masking_utils
|
|
574
|
+
and hasattr(masking_utils, "eager_mask")
|
|
575
|
+
):
|
|
576
|
+
f_transformers_eager_mask = masking_utils.eager_mask
|
|
577
|
+
masking_utils.eager_mask = f_transformers_eager_mask
|
|
578
|
+
if (
|
|
579
|
+
"eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
|
|
580
|
+
and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
|
|
581
|
+
== patch_transformers_list.patched_eager_mask
|
|
582
|
+
):
|
|
583
|
+
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
|
|
584
|
+
f_transformers_eager_mask
|
|
585
|
+
)
|
|
586
|
+
if verbose:
|
|
587
|
+
print(
|
|
588
|
+
"[torch_export_patches] restored "
|
|
589
|
+
"transformers.masking_utils.eager_mask"
|
|
590
|
+
)
|
|
541
591
|
|
|
542
592
|
########
|
|
543
593
|
# caches
|
|
@@ -7,59 +7,107 @@ import torch
|
|
|
7
7
|
import transformers
|
|
8
8
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
|
9
9
|
from transformers.cache_utils import StaticCache, Cache, DynamicCache
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
import transformers.masking_utils
|
|
13
|
+
|
|
14
|
+
patch_masking_utils = True
|
|
15
|
+
except ImportError:
|
|
16
|
+
patch_masking_utils = False
|
|
17
|
+
|
|
10
18
|
from ...ext_test_case import has_transformers
|
|
11
19
|
from ...helpers.torch_helper import is_torchdynamo_exporting
|
|
12
20
|
|
|
13
21
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
from
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
22
|
+
if patch_masking_utils:
|
|
23
|
+
# Introduced in 4.52
|
|
24
|
+
from transformers.masking_utils import causal_mask_function, sdpa_mask
|
|
25
|
+
|
|
26
|
+
def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
|
|
27
|
+
"""manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
|
|
28
|
+
from ...helpers import string_type
|
|
29
|
+
|
|
30
|
+
dimensions: List[Tuple[Optional[int], ...]] = [
|
|
31
|
+
(None, None, None, 0),
|
|
32
|
+
(None, None, 0, None),
|
|
33
|
+
]
|
|
34
|
+
if bh_indices:
|
|
35
|
+
dimensions.extend([(None, 0, None, None), (0, None, None, None)])
|
|
36
|
+
# reshape
|
|
37
|
+
dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
|
|
38
|
+
dimensions = tuple(reversed(dimensions))
|
|
39
|
+
indices = tuple(shape.index(-1) for shape in dimensions)
|
|
40
|
+
|
|
41
|
+
# unsqueeze
|
|
42
|
+
udimensions = [
|
|
43
|
+
tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions
|
|
44
|
+
]
|
|
45
|
+
|
|
46
|
+
def vector_mask_function(
|
|
47
|
+
*args, mask_function=mask_function, dimensions=dimensions, indices=indices
|
|
48
|
+
):
|
|
49
|
+
assert len(args) == len(dimensions) == len(udimensions), (
|
|
50
|
+
f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
|
|
51
|
+
f"and udimensions={udimensions}."
|
|
52
|
+
)
|
|
53
|
+
assert len(indices) == len(args), (
|
|
54
|
+
f"Mismatch between args={string_type(args)} and indices={indices}, "
|
|
55
|
+
f"they should have the same length."
|
|
56
|
+
)
|
|
57
|
+
for a in args:
|
|
58
|
+
assert (
|
|
59
|
+
a.ndim == 1
|
|
60
|
+
), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
|
|
61
|
+
torch._check(a.shape[0] > 0)
|
|
62
|
+
|
|
63
|
+
new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
|
|
64
|
+
# new_args = [
|
|
65
|
+
# a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
|
|
66
|
+
# for a, dims in zip(args, udimensions)
|
|
67
|
+
# ]
|
|
68
|
+
max_shape = tuple(args[i].shape[0] for i in indices)
|
|
69
|
+
# if is_torchdynamo_exporting():
|
|
70
|
+
# for a in args:
|
|
71
|
+
# # The exporter should export with a dimension > 1
|
|
72
|
+
# # to make sure it is dynamic.
|
|
73
|
+
# torch._check(a.shape[0] > 1)
|
|
74
|
+
expanded_args = [a.expand(max_shape) for a in new_args]
|
|
75
|
+
return mask_function(*expanded_args)
|
|
76
|
+
|
|
77
|
+
return vector_mask_function
|
|
78
|
+
|
|
79
|
+
def patched_eager_mask(
|
|
80
|
+
batch_size: int,
|
|
81
|
+
cache_position: torch.Tensor,
|
|
82
|
+
kv_length: int,
|
|
83
|
+
kv_offset: int = 0,
|
|
84
|
+
mask_function: Callable = causal_mask_function,
|
|
85
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
86
|
+
dtype: torch.dtype = torch.float32,
|
|
87
|
+
**kwargs,
|
|
88
|
+
) -> torch.Tensor:
|
|
89
|
+
"""manual patch for function ``transformers.masking_utils.eager_mask``."""
|
|
90
|
+
# The masks for eager attention are simply the boolean masks from sdpa, cast to 0 and -inf
|
|
91
|
+
_ = kwargs.pop("allow_is_causal_skip", None)
|
|
92
|
+
mask = sdpa_mask(
|
|
93
|
+
batch_size=batch_size,
|
|
94
|
+
cache_position=cache_position,
|
|
95
|
+
kv_length=kv_length,
|
|
96
|
+
kv_offset=kv_offset,
|
|
97
|
+
mask_function=mask_function,
|
|
98
|
+
attention_mask=attention_mask,
|
|
99
|
+
allow_is_causal_skip=False,
|
|
100
|
+
allow_torch_fix=False,
|
|
101
|
+
**kwargs,
|
|
42
102
|
)
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
# a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
|
|
52
|
-
# for a, dims in zip(args, udimensions)
|
|
53
|
-
# ]
|
|
54
|
-
max_shape = tuple(args[i].shape[0] for i in indices)
|
|
55
|
-
# if is_torchdynamo_exporting():
|
|
56
|
-
# for a in args:
|
|
57
|
-
# # The exporter should export with a dimension > 1 to make sure it is dynamic.
|
|
58
|
-
# torch._check(a.shape[0] > 1)
|
|
59
|
-
expanded_args = [a.expand(max_shape) for a in new_args]
|
|
60
|
-
return mask_function(*expanded_args)
|
|
61
|
-
|
|
62
|
-
return vector_mask_function
|
|
103
|
+
min_dtype = torch.finfo(dtype).min
|
|
104
|
+
# The patched line.
|
|
105
|
+
# we need 0s where the tokens should be taken into account,
|
|
106
|
+
# and -inf otherwise (mask is already of boolean type)
|
|
107
|
+
# mask =
|
|
108
|
+
# torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
|
|
109
|
+
mask = (~mask).to(dtype) * min_dtype
|
|
110
|
+
return mask
|
|
63
111
|
|
|
64
112
|
|
|
65
113
|
def _patch_make_causal_mask(
|
|
@@ -26,7 +26,7 @@ def get_untrained_model_with_inputs(
|
|
|
26
26
|
use_pretrained: bool = False,
|
|
27
27
|
same_as_pretrained: bool = False,
|
|
28
28
|
use_preinstalled: bool = True,
|
|
29
|
-
add_second_input:
|
|
29
|
+
add_second_input: int = 1,
|
|
30
30
|
subfolder: Optional[str] = None,
|
|
31
31
|
use_only_preinstalled: bool = False,
|
|
32
32
|
) -> Dict[str, Any]:
|
|
@@ -18,7 +18,6 @@ from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
|
|
|
18
18
|
from ..tasks import random_input_kwargs
|
|
19
19
|
from ..torch_export_patches import torch_export_patches
|
|
20
20
|
from ..torch_export_patches.patch_inputs import use_dyn_not_str
|
|
21
|
-
from ..reference import TorchOnnxEvaluator
|
|
22
21
|
from .hghub import get_untrained_model_with_inputs
|
|
23
22
|
|
|
24
23
|
|
|
@@ -157,6 +156,12 @@ def version_summary() -> Dict[str, Union[int, float, str]]:
|
|
|
157
156
|
"version_torch": torch.__version__,
|
|
158
157
|
"version_numpy": numpy.__version__,
|
|
159
158
|
}
|
|
159
|
+
try:
|
|
160
|
+
import scipy
|
|
161
|
+
|
|
162
|
+
summary["version_scipy"] = getattr(scipy, "__version__", "?")
|
|
163
|
+
except ImportError:
|
|
164
|
+
pass
|
|
160
165
|
try:
|
|
161
166
|
import transformers
|
|
162
167
|
|
|
@@ -181,6 +186,12 @@ def version_summary() -> Dict[str, Union[int, float, str]]:
|
|
|
181
186
|
summary["version_onnxruntime"] = getattr(onnxruntime, "__version__", "?")
|
|
182
187
|
except ImportError:
|
|
183
188
|
pass
|
|
189
|
+
try:
|
|
190
|
+
import onnx_ir
|
|
191
|
+
|
|
192
|
+
summary["version_onnx_ir"] = getattr(onnx_ir, "__version__", "?")
|
|
193
|
+
except ImportError:
|
|
194
|
+
pass
|
|
184
195
|
import onnx_diagnostic
|
|
185
196
|
|
|
186
197
|
summary["version_onnx_diagnostic"] = onnx_diagnostic.__version__
|
|
@@ -276,7 +287,7 @@ def validate_model(
|
|
|
276
287
|
runtime: str = "onnxruntime",
|
|
277
288
|
repeat: int = 1,
|
|
278
289
|
warmup: int = 0,
|
|
279
|
-
inputs2:
|
|
290
|
+
inputs2: int = 1,
|
|
280
291
|
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
|
|
281
292
|
"""
|
|
282
293
|
Validates a model.
|
|
@@ -325,7 +336,8 @@ def validate_model(
|
|
|
325
336
|
:param repeat: number of times to measure the model
|
|
326
337
|
:param warmup: warmup the model first
|
|
327
338
|
:param inputs2: checks that the second set of inputs is running as well,
|
|
328
|
-
this ensures that the model does support dynamism
|
|
339
|
+
this ensures that the model does support dynamism, the value is used
|
|
340
|
+
as an increment to the first set of values (added to dimensions)
|
|
329
341
|
:return: two dictionaries, one with some metrics,
|
|
330
342
|
another one with whatever the function produces
|
|
331
343
|
|
|
@@ -1054,7 +1066,7 @@ def validate_onnx_model(
|
|
|
1054
1066
|
runtime: str = "onnxruntime",
|
|
1055
1067
|
repeat: int = 1,
|
|
1056
1068
|
warmup: int = 0,
|
|
1057
|
-
inputs2:
|
|
1069
|
+
inputs2: int = 1,
|
|
1058
1070
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
1059
1071
|
"""
|
|
1060
1072
|
Verifies that an onnx model produces the same
|
|
@@ -1070,8 +1082,9 @@ def validate_onnx_model(
|
|
|
1070
1082
|
:param runtime: onnx runtime to use, onnxruntime or torch
|
|
1071
1083
|
:param repeat: run that number of times the model
|
|
1072
1084
|
:param warmup: warmup the model
|
|
1073
|
-
:param
|
|
1074
|
-
to make sure the exported model supports dynamism
|
|
1085
|
+
:param inputs2: to validate the model on the second input set
|
|
1086
|
+
to make sure the exported model supports dynamism, the value is
|
|
1087
|
+
used as an increment added to the first set of inputs (added to dimensions)
|
|
1075
1088
|
:return: two dictionaries, one with some metrics,
|
|
1076
1089
|
another one with whatever the function produces
|
|
1077
1090
|
"""
|
|
@@ -1113,6 +1126,9 @@ def validate_onnx_model(
|
|
|
1113
1126
|
f"{providers}..., flavour={flavour!r}"
|
|
1114
1127
|
)
|
|
1115
1128
|
|
|
1129
|
+
if runtime != "onnxruntime":
|
|
1130
|
+
from ..reference import TorchOnnxEvaluator
|
|
1131
|
+
|
|
1116
1132
|
cls_runtime = (
|
|
1117
1133
|
(
|
|
1118
1134
|
lambda model, providers: onnxruntime.InferenceSession(
|
|
@@ -1122,7 +1138,7 @@ def validate_onnx_model(
|
|
|
1122
1138
|
)
|
|
1123
1139
|
if runtime == "onnxruntime"
|
|
1124
1140
|
else (
|
|
1125
|
-
lambda model, providers:
|
|
1141
|
+
lambda model, providers, _cls_=TorchOnnxEvaluator: _cls_( # type: ignore[misc]
|
|
1126
1142
|
model, providers=providers, verbose=max(verbose - 1, 0)
|
|
1127
1143
|
)
|
|
1128
1144
|
)
|
|
@@ -1,17 +1,17 @@
|
|
|
1
|
-
onnx_diagnostic/__init__.py,sha256=
|
|
1
|
+
onnx_diagnostic/__init__.py,sha256=dmZNMpFkDRd7ZCC6bC2fEFqvAhHhsqqua8ZE5LbOC9s,173
|
|
2
2
|
onnx_diagnostic/__main__.py,sha256=YmyV_Aq_ianDlHyKLHMa6h8YK3ZmFPpLVHLKjM91aCk,79
|
|
3
|
-
onnx_diagnostic/_command_lines_parser.py,sha256=
|
|
3
|
+
onnx_diagnostic/_command_lines_parser.py,sha256=65oUjJ2tgPxQgIKgOAI04jhOFRnGUSNivUNDVMZ-urU,28597
|
|
4
4
|
onnx_diagnostic/api.py,sha256=BhCl_yCd78N7TlVtPOHjeYv1QBEy39TjZ647rcHqLh0,345
|
|
5
5
|
onnx_diagnostic/doc.py,sha256=t3RELgfooYnVMAi0JSpggWkQEgUsREz8NmRvn0TnLI8,2829
|
|
6
6
|
onnx_diagnostic/ext_test_case.py,sha256=Bq4vdlM0P72H1orlKJTeOBqm1YGHTK-ylAlNsBe4LeA,43438
|
|
7
7
|
onnx_diagnostic/export/__init__.py,sha256=yEIoWiOeTwBsDhyYt2fTKuhtA0Ya1J9u9ZzMTOTWaWs,101
|
|
8
8
|
onnx_diagnostic/export/dynamic_shapes.py,sha256=HYf2OEi7PmRSj8uxMD-wbdVxxejkWXTPBAkxoFeM27A,40811
|
|
9
|
-
onnx_diagnostic/export/shape_helper.py,sha256=
|
|
9
|
+
onnx_diagnostic/export/shape_helper.py,sha256=EQXHRVxwGpHRYhx8Y44Crqs640pmaIuKSwW1KJOW0IU,7501
|
|
10
10
|
onnx_diagnostic/export/validate.py,sha256=_PGUql2DJhIgGKo0WjTGUc5AgsZUx8fEs00MePy-w98,6043
|
|
11
11
|
onnx_diagnostic/helpers/__init__.py,sha256=GJ2GT7cgnlIveVUwMZhuvUwidbTJaKv8CsSIOpZDsJg,83
|
|
12
12
|
onnx_diagnostic/helpers/args_helper.py,sha256=SRWnqC7EENg09RZlA50B_PcdiIhdbgA4C3ACfzl5nMs,4419
|
|
13
13
|
onnx_diagnostic/helpers/bench_run.py,sha256=CGA6VMJZMH2gDhVueT9ypNm4PMcjGrrGFYp08nhWj9k,16539
|
|
14
|
-
onnx_diagnostic/helpers/cache_helper.py,sha256=
|
|
14
|
+
onnx_diagnostic/helpers/cache_helper.py,sha256=TeBUuGvqIMO-dsLDy7keaVt3ImZeifldwTgx6TEjBo8,11595
|
|
15
15
|
onnx_diagnostic/helpers/config_helper.py,sha256=9h1NWC9RLmu43Yf5Cz9usjMdLiyLWXMhwgE4Lg-eOU8,3889
|
|
16
16
|
onnx_diagnostic/helpers/doc_helper.py,sha256=pl5MZd3_FaE8BqQnqoBuSBxoNCFcd2OJd3eITUSku5c,5897
|
|
17
17
|
onnx_diagnostic/helpers/graph_helper.py,sha256=hevQT5a7_QuriVPQcbT5qe18n99Doyl5h3-qshx1-uk,14093
|
|
@@ -35,7 +35,7 @@ onnx_diagnostic/reference/ops/op_add_add_mul_mul.py,sha256=CXQVtgVrT066gDJFwxL4n
|
|
|
35
35
|
onnx_diagnostic/reference/ops/op_attention.py,sha256=ThALMDF53v3QeG1bohi0bvX2o90HZhGJbbAFOtwEHPE,2027
|
|
36
36
|
onnx_diagnostic/reference/ops/op_average_pool_grad.py,sha256=zMcOtjB7hWySfIIFXogcmj3xxCWwxEX_g2VKg-SOAEs,2360
|
|
37
37
|
onnx_diagnostic/reference/ops/op_bias_softmax.py,sha256=dcXsw2chxc8-puIkI0LFsBxKOJaCSovDcF1HkgboQp0,524
|
|
38
|
-
onnx_diagnostic/reference/ops/op_cast_like.py,sha256=
|
|
38
|
+
onnx_diagnostic/reference/ops/op_cast_like.py,sha256=tH2pYFPluHK9r_cdkqnk9k2q3dPzaeKgQZt_Pjw9N4o,1383
|
|
39
39
|
onnx_diagnostic/reference/ops/op_complex.py,sha256=OobDrRNYcktdCdTJzOQBesrKC8vsKxuHIi7Yev1DJrs,651
|
|
40
40
|
onnx_diagnostic/reference/ops/op_concat.py,sha256=seW71-QDKzv9QQhjhjThKip0Y3d9nkVd7Hs1A2nNQjk,519
|
|
41
41
|
onnx_diagnostic/reference/ops/op_constant_of_shape.py,sha256=3G9TRxaUoqYudnKrVHBEblo_16qMl0c9wNZqnQeNJJ4,2009
|
|
@@ -72,22 +72,22 @@ onnx_diagnostic/reference/torch_ops/sequence_ops.py,sha256=3EiVKpGfN4d1Iry4hgnr3
|
|
|
72
72
|
onnx_diagnostic/reference/torch_ops/shape_ops.py,sha256=pJrNR2UB4PlWl6cv4EDl1uGl8YTBUUMQkhJcsh5K4sA,4291
|
|
73
73
|
onnx_diagnostic/reference/torch_ops/unary_ops.py,sha256=E8Ys1eZsOTsucBKoXb1_Kl5LbBDygniDvW2BvN4IPMo,1708
|
|
74
74
|
onnx_diagnostic/tasks/__init__.py,sha256=0BYtrAnr0zKN3om71oi-OVz5wFYDp9WWIk51qWjjyCw,2450
|
|
75
|
-
onnx_diagnostic/tasks/automatic_speech_recognition.py,sha256=
|
|
76
|
-
onnx_diagnostic/tasks/feature_extraction.py,sha256=
|
|
77
|
-
onnx_diagnostic/tasks/fill_mask.py,sha256=
|
|
78
|
-
onnx_diagnostic/tasks/image_classification.py,sha256=
|
|
79
|
-
onnx_diagnostic/tasks/image_text_to_text.py,sha256
|
|
80
|
-
onnx_diagnostic/tasks/mixture_of_expert.py,sha256=
|
|
81
|
-
onnx_diagnostic/tasks/object_detection.py,sha256=
|
|
82
|
-
onnx_diagnostic/tasks/sentence_similarity.py,sha256=
|
|
83
|
-
onnx_diagnostic/tasks/summarization.py,sha256=
|
|
84
|
-
onnx_diagnostic/tasks/text2text_generation.py,sha256=
|
|
85
|
-
onnx_diagnostic/tasks/text_classification.py,sha256=
|
|
86
|
-
onnx_diagnostic/tasks/text_generation.py,sha256=
|
|
87
|
-
onnx_diagnostic/tasks/text_to_image.py,sha256=
|
|
88
|
-
onnx_diagnostic/tasks/zero_shot_image_classification.py,sha256=
|
|
75
|
+
onnx_diagnostic/tasks/automatic_speech_recognition.py,sha256=H94rxeiAjcJdECA1g95G_U9fZfpXk6dfjNKKYuvc4Qc,7130
|
|
76
|
+
onnx_diagnostic/tasks/feature_extraction.py,sha256=MptOP-1ZSIMTIJ0woSqKLR7TlB9m7kww9V8cfEgZJTY,2502
|
|
77
|
+
onnx_diagnostic/tasks/fill_mask.py,sha256=Rvrz0j_DQu-vf4CSSAZMBMXb2EuHvOCzRZwj8Cy8yfg,2620
|
|
78
|
+
onnx_diagnostic/tasks/image_classification.py,sha256=x1XfeWAOe0r_s9kU6WENoYxjfoRTp1pkwKgIveoLbUw,4627
|
|
79
|
+
onnx_diagnostic/tasks/image_text_to_text.py,sha256=-vbZMA_ruo0WR_96YMYRvoNfq1plpElBJWXC2klAf7Q,7802
|
|
80
|
+
onnx_diagnostic/tasks/mixture_of_expert.py,sha256=DgIsbwzV4smysOK83wny91k3ix1Qt2tSFXLGLoz4WOo,2796
|
|
81
|
+
onnx_diagnostic/tasks/object_detection.py,sha256=xRBH9JZxBQf0SVSTJP6d-VVCKqrw7JAeif1joHfiYng,4224
|
|
82
|
+
onnx_diagnostic/tasks/sentence_similarity.py,sha256=soL6QxLvyjtQ-3tQ3nCFxrcrk_4a8tuAjil8zYQ_pXk,2635
|
|
83
|
+
onnx_diagnostic/tasks/summarization.py,sha256=LZ8A8wl6cd8kWSc6k5vLHa_XZkm35rYkTRv8iUYtr6I,8268
|
|
84
|
+
onnx_diagnostic/tasks/text2text_generation.py,sha256=Pk-H_qX5Y-2dzk45N9jbQ73S3O_d5-D11MyUhUfUwuM,8685
|
|
85
|
+
onnx_diagnostic/tasks/text_classification.py,sha256=dO_LLbwwv0OJfIa9DqxQqAGUDuz3iIF1XkafzaYJdJw,2691
|
|
86
|
+
onnx_diagnostic/tasks/text_generation.py,sha256=tW9Gnum_eck3czNyctuUISA-Ek7pO37v5-11GC8QBW8,13124
|
|
87
|
+
onnx_diagnostic/tasks/text_to_image.py,sha256=mOS3Ruosi3hzRMxXLDN7ZkAbi7NnQb7MWwQP_okGVHs,2962
|
|
88
|
+
onnx_diagnostic/tasks/zero_shot_image_classification.py,sha256=1iqYamkq5kZNXEXsySw748ernc0O94GkwpYAIEl6Kj4,4659
|
|
89
89
|
onnx_diagnostic/torch_export_patches/__init__.py,sha256=0SaZedwznm1hQUCvXZsGZORV5vby954wEExr5faepGg,720
|
|
90
|
-
onnx_diagnostic/torch_export_patches/onnx_export_errors.py,sha256=
|
|
90
|
+
onnx_diagnostic/torch_export_patches/onnx_export_errors.py,sha256=ZsUSOnKxeq4-dP86c5dTIbHMJFy_y690vvU4yfo6tNs,23438
|
|
91
91
|
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py,sha256=wFE2fNxihAA3iua79AEB97_RBVv4wvGxwS9g4RJaSIc,10715
|
|
92
92
|
onnx_diagnostic/torch_export_patches/patch_expressions.py,sha256=vr4tt61cbDnaaaduzMj4UBZ8OUtr6GfDpIWwOYqjWzs,3213
|
|
93
93
|
onnx_diagnostic/torch_export_patches/patch_inputs.py,sha256=9b4pmyT00BwLqi7WG-gliep1RUy3gXEgW6BDnlSSA-M,7689
|
|
@@ -97,26 +97,26 @@ onnx_diagnostic/torch_export_patches/eval/__init__.py,sha256=57x62uZNA80XiWgkG8F
|
|
|
97
97
|
onnx_diagnostic/torch_export_patches/eval/model_cases.py,sha256=DTvdHPtNQh25Akv5o3D4Jxf1L1-SJ7w14tgvj8AAns8,26577
|
|
98
98
|
onnx_diagnostic/torch_export_patches/patches/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
99
99
|
onnx_diagnostic/torch_export_patches/patches/patch_torch.py,sha256=KaZ8TjDa9ATgT4HllYzzoNf_51q_yOj_GuF5NYjPCrU,18913
|
|
100
|
-
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=
|
|
100
|
+
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=nVOg69e_cXvpcP5WIW9pIHCgnF-P_Ne87mRC6ep0g-I,45847
|
|
101
101
|
onnx_diagnostic/torch_export_patches/serialization/__init__.py,sha256=BHLdRPtNAtNPAS-bPKEj3-foGSPvwAbZXrHzGGPDLEw,1876
|
|
102
102
|
onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py,sha256=drq3EH_yjcSuIWYsVeUWm8Cx6YCZFU6bP_1PLtPfY5I,945
|
|
103
103
|
onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py,sha256=9u2jkqnuyBkIF3R2sDEO0Jlkedl-cQhBNXxXXDLSEwE,8885
|
|
104
104
|
onnx_diagnostic/torch_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
105
105
|
onnx_diagnostic/torch_models/llms.py,sha256=soyg4yC87ptGoeulJhKqw5opGmuLvH1pn_ZDXZ4Jr8E,90
|
|
106
|
-
onnx_diagnostic/torch_models/validate.py,sha256=
|
|
106
|
+
onnx_diagnostic/torch_models/validate.py,sha256=dlWeRNLcQ2h3fxx07MA0NHoOPqT3-Afejgrj2Rozvck,63796
|
|
107
107
|
onnx_diagnostic/torch_models/hghub/__init__.py,sha256=vi1Q7YHdddj1soiBN42MSvJdFqe2_KUoWafHISjwOu8,58
|
|
108
108
|
onnx_diagnostic/torch_models/hghub/hub_api.py,sha256=Bvr-sTAhS6s6UCkt-KsY_7Mdai08-AQzvHrzbYCSuvk,13186
|
|
109
109
|
onnx_diagnostic/torch_models/hghub/hub_data.py,sha256=NTTDsCtIVvYnr5J3rlcq0GSGDOzOPzq9Tsnb3oVf4q8,8309
|
|
110
110
|
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py,sha256=zZvIxTbmL55x44kCj3-T5Kg3Qzm9KB_Xj-MCcU9-LuQ,268245
|
|
111
|
-
onnx_diagnostic/torch_models/hghub/model_inputs.py,sha256=
|
|
111
|
+
onnx_diagnostic/torch_models/hghub/model_inputs.py,sha256=qDw03KsLd_ZAMHBso--rUriCAZIewKFZG9n4-1zvGo8,10825
|
|
112
112
|
onnx_diagnostic/torch_models/untrained/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
113
113
|
onnx_diagnostic/torch_models/untrained/llm_phi2.py,sha256=ynBTDHJHCk44NjLT_t6OiFDBdPP0rFGPteiONDxvztw,3708
|
|
114
114
|
onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py,sha256=QXw_Bs2SzfeiQMf-tmtVl83SmVOL4-Um7Qy-f0E48QI,2507
|
|
115
115
|
onnx_diagnostic/torch_onnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
116
116
|
onnx_diagnostic/torch_onnx/runtime_info.py,sha256=1g9F_Jf9AAgYQU4stbsrFXwQl-30mWlQrFbQ7val8Ps,9268
|
|
117
117
|
onnx_diagnostic/torch_onnx/sbs.py,sha256=1EL25DeYFzlBSiFG_XjePBLvsiItRXbdDrr5-QZW2mA,16878
|
|
118
|
-
onnx_diagnostic-0.7.
|
|
119
|
-
onnx_diagnostic-0.7.
|
|
120
|
-
onnx_diagnostic-0.7.
|
|
121
|
-
onnx_diagnostic-0.7.
|
|
122
|
-
onnx_diagnostic-0.7.
|
|
118
|
+
onnx_diagnostic-0.7.4.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
|
|
119
|
+
onnx_diagnostic-0.7.4.dist-info/METADATA,sha256=6l0XAH0UYEmqkPDswFSZFGrU5hYVEDB8YByQCiBkwlI,7431
|
|
120
|
+
onnx_diagnostic-0.7.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
121
|
+
onnx_diagnostic-0.7.4.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
|
|
122
|
+
onnx_diagnostic-0.7.4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|