onnx-diagnostic 0.7.3-py3-none-any.whl → 0.7.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +82 -12
- onnx_diagnostic/export/shape_helper.py +71 -0
- onnx_diagnostic/helpers/_log_helper.py +461 -0
- onnx_diagnostic/helpers/cache_helper.py +11 -1
- onnx_diagnostic/helpers/log_helper.py +404 -315
- onnx_diagnostic/reference/ops/op_cast_like.py +12 -8
- onnx_diagnostic/tasks/automatic_speech_recognition.py +6 -2
- onnx_diagnostic/tasks/feature_extraction.py +92 -7
- onnx_diagnostic/tasks/fill_mask.py +6 -2
- onnx_diagnostic/tasks/image_classification.py +7 -3
- onnx_diagnostic/tasks/image_text_to_text.py +6 -2
- onnx_diagnostic/tasks/mixture_of_expert.py +1 -1
- onnx_diagnostic/tasks/object_detection.py +7 -3
- onnx_diagnostic/tasks/sentence_similarity.py +6 -2
- onnx_diagnostic/tasks/summarization.py +6 -2
- onnx_diagnostic/tasks/text2text_generation.py +8 -4
- onnx_diagnostic/tasks/text_classification.py +6 -2
- onnx_diagnostic/tasks/text_generation.py +5 -3
- onnx_diagnostic/tasks/text_to_image.py +6 -2
- onnx_diagnostic/tasks/zero_shot_image_classification.py +6 -2
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +63 -7
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +188 -51
- onnx_diagnostic/torch_models/hghub/model_inputs.py +6 -1
- onnx_diagnostic/torch_models/validate.py +49 -10
- {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/RECORD +30 -29
- {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/top_level.txt +0 -0

onnx_diagnostic/reference/ops/op_cast_like.py
@@ -1,13 +1,17 @@
 from onnx.onnx_pb import TensorProto
 from onnx.reference.op_run import OpRun
-
-
-
-
-
-
-
-
+try:
+    from onnx.reference.ops.op_cast import (
+        bfloat16,
+        cast_to,
+        float8e4m3fn,
+        float8e4m3fnuz,
+        float8e5m2,
+        float8e5m2fnuz,
+    )
+except ImportError:
+    from onnx.reference.ops.op_cast import cast_to
 from ...helpers.onnx_helper import np_dtype_to_tensor_dtype
 
 
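The new guarded import appears to account for onnx builds whose `onnx.reference.ops.op_cast` does not expose the bfloat16/float8 helpers. A minimal sketch of the same pattern, assuming only `cast_to` is available everywhere; the `HAS_EXTENDED_TYPES` flag is a hypothetical name added here for illustration:

```python
# Sketch of the optional-dtype import guard; HAS_EXTENDED_TYPES is a
# hypothetical name, not part of onnx or onnx-diagnostic.
try:
    from onnx.reference.ops.op_cast import (
        bfloat16,
        cast_to,
        float8e4m3fn,
        float8e4m3fnuz,
        float8e5m2,
        float8e5m2fnuz,
    )

    HAS_EXTENDED_TYPES = True
except ImportError:
    # Older onnx releases only provide cast_to.
    from onnx.reference.ops.op_cast import cast_to

    HAS_EXTENDED_TYPES = False
```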

onnx_diagnostic/tasks/automatic_speech_recognition.py
@@ -33,7 +33,7 @@ def get_inputs(
     head_dim: int,
     batch_size: int = 2,
     sequence_length: int = 30,
-    add_second_input:
+    add_second_input: int = 1,
     **kwargs,  # unused
 ):
     """
@@ -132,6 +132,9 @@ def get_inputs(
     )
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
@@ -144,7 +147,8 @@ def get_inputs(
             decoder_layers=decoder_layers,
             head_dim=head_dim,
             batch_size=batch_size + 1,
-            sequence_length=sequence_length +
+            sequence_length=sequence_length + add_second_input,
+            add_second_input=0,
             **kwargs,
         )["inputs"]
     return res
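The same change recurs in every task module below: `add_second_input` is now an integer offset rather than a bare flag, the generator asserts it is positive, and the nested call passes `add_second_input=0` so only one extra input set is built. A simplified, self-contained sketch of that pattern (the `make_inputs` helper is hypothetical, not the library API):

```python
# Hypothetical, stripped-down version of the add_second_input pattern used by
# the task-specific get_inputs functions in this release.
def make_inputs(batch_size: int = 2, sequence_length: int = 30, add_second_input: int = 1):
    inputs = dict(batch_size=batch_size, sequence_length=sequence_length)
    res = dict(inputs=inputs)
    if add_second_input:
        assert (
            add_second_input > 0
        ), f"Not implemented for add_second_input={add_second_input}."
        res["inputs2"] = make_inputs(
            batch_size=batch_size + 1,
            # the offset controls how much longer the second sequence is
            sequence_length=sequence_length + add_second_input,
            add_second_input=0,  # stop the recursion after one level
        )["inputs"]
    return res


print(make_inputs(add_second_input=2))
# {'inputs': {'batch_size': 2, 'sequence_length': 30},
#  'inputs2': {'batch_size': 3, 'sequence_length': 32}}
```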

onnx_diagnostic/tasks/feature_extraction.py
@@ -1,17 +1,15 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
 from ..helpers.config_helper import update_config, check_hasattr
+from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
 
 __TASK__ = "feature-extraction"
 
 
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
-    check_hasattr(config, "
-    kwargs = dict(
-        num_hidden_layers=min(config.num_hidden_layers, 2),
-        num_attention_heads=min(config.num_attention_heads, 4),
-    )
+    check_hasattr(config, "num_hidden_layers")
+    kwargs = dict(num_hidden_layers=min(config.num_hidden_layers, 2))
     update_config(config, kwargs)
     return kwargs
 
@@ -22,7 +20,13 @@ def get_inputs(
     batch_size: int,
     sequence_length: int,
     dummy_max_token_id: int,
-
+    sequence_length2: int = 3,
+    decoder_attention_heads: Optional[int] = None,
+    encoder_attention_heads: Optional[int] = None,
+    encoder_ffn_dim: Optional[int] = None,
+    decoder_ffn_dim: Optional[int] = None,
+    num_hidden_layers: Optional[int] = None,
+    add_second_input: int = 1,
     **kwargs,  # unused
 ):
     """
@@ -50,14 +54,84 @@ def get_inputs(
         ),
         attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
     )
+    if (
+        encoder_attention_heads
+        and decoder_attention_heads
+        and encoder_ffn_dim
+        and decoder_ffn_dim
+        and num_hidden_layers
+    ):
+        inputs["past_key_values"] = make_encoder_decoder_cache(
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size,
+                            encoder_attention_heads,
+                            sequence_length,
+                            encoder_ffn_dim,
+                        ),
+                        torch.randn(
+                            batch_size,
+                            encoder_attention_heads,
+                            sequence_length,
+                            encoder_ffn_dim,
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size,
+                            decoder_attention_heads,
+                            sequence_length2,
+                            decoder_ffn_dim,
+                        ),
+                        torch.randn(
+                            batch_size,
+                            decoder_attention_heads,
+                            sequence_length2,
+                            decoder_ffn_dim,
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+        )
+        cache_length = "cache_length_key"
+        cache_length2 = "cache_length_val"
+        shapes["past_key_values"] = [  # type: ignore[assignment]
+            [
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            ],
+            [
+                [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
+            ],
+        ]
+
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
             batch_size=batch_size + 1,
-            sequence_length=sequence_length +
+            sequence_length=sequence_length + add_second_input,
             dummy_max_token_id=dummy_max_token_id,
+            sequence_length2=sequence_length2,
+            decoder_attention_heads=decoder_attention_heads,
+            encoder_attention_heads=encoder_attention_heads,
+            encoder_ffn_dim=encoder_ffn_dim,
+            decoder_ffn_dim=decoder_ffn_dim,
+            num_hidden_layers=num_hidden_layers,
+            add_second_input=0,
             **kwargs,
         )["inputs"]
     return res
@@ -76,4 +150,15 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         sequence_length=30,
         dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
     )
+    for att in [
+        "decoder_attention_heads",
+        "encoder_attention_heads",
+        "encoder_ffn_dim",
+        "decoder_ffn_dim",
+        "num_hidden_layers",
+    ]:
+        if hasattr(config, att):
+            kwargs[att] = getattr(config, att)
+    kwargs["decoder_ffn_dim"] = kwargs["encoder_ffn_dim"] = 64
+    print(kwargs)
     return kwargs, get_inputs
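feature-extraction now also feeds an encoder-decoder `past_key_values` built from the package's cache helpers. A small usage sketch, assuming (as the diff suggests) that `make_dynamic_cache` takes a list of `(key, value)` tensor pairs and `make_encoder_decoder_cache` combines two such caches; the dimensions are illustrative only:

```python
import torch
from onnx_diagnostic.helpers.cache_helper import (
    make_dynamic_cache,
    make_encoder_decoder_cache,
)

# Illustrative dimensions, not taken from a real configuration.
num_hidden_layers, batch_size, heads, seq, seq2, dim = 2, 2, 4, 30, 3, 64

past_key_values = make_encoder_decoder_cache(
    make_dynamic_cache(
        [
            (torch.randn(batch_size, heads, seq, dim), torch.randn(batch_size, heads, seq, dim))
            for _ in range(num_hidden_layers)
        ]
    ),
    make_dynamic_cache(
        [
            (torch.randn(batch_size, heads, seq2, dim), torch.randn(batch_size, heads, seq2, dim))
            for _ in range(num_hidden_layers)
        ]
    ),
)
```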

onnx_diagnostic/tasks/fill_mask.py
@@ -22,7 +22,7 @@ def get_inputs(
     batch_size: int,
     sequence_length: int,
     dummy_max_token_id: int,
-    add_second_input:
+    add_second_input: int = 1,
     **kwargs,  # unused
 ):
     """
@@ -54,12 +54,16 @@ def get_inputs(
     )
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
             batch_size=batch_size + 1,
-            sequence_length=sequence_length +
+            sequence_length=sequence_length + add_second_input,
             dummy_max_token_id=dummy_max_token_id,
+            add_second_input=0,
             **kwargs,
         )["inputs"]
     return res

onnx_diagnostic/tasks/image_classification.py
@@ -34,7 +34,7 @@ def get_inputs(
     input_channels: int,
     batch_size: int = 2,
     dynamic_rope: bool = False,
-    add_second_input:
+    add_second_input: int = 1,
     **kwargs,  # unused
 ):
     """
@@ -75,14 +75,18 @@ def get_inputs(
         shapes["interpolate_pos_encoding"] = None  # type: ignore[assignment]
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
-            input_width=input_width +
-            input_height=input_height +
+            input_width=input_width + add_second_input,
+            input_height=input_height + add_second_input,
             input_channels=input_channels,
             batch_size=batch_size + 1,
             dynamic_rope=dynamic_rope,
+            add_second_input=0,
             **kwargs,
         )["inputs"]
     return res

onnx_diagnostic/tasks/image_text_to_text.py
@@ -32,7 +32,7 @@ def get_inputs(
     sequence_length2: int = 3,
     n_images: int = 2,
     dynamic_rope: bool = False,
-    add_second_input:
+    add_second_input: int = 1,
     **kwargs,  # unused
 ):
     """
@@ -105,6 +105,9 @@ def get_inputs(
     )
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
@@ -116,10 +119,11 @@ def get_inputs(
             height=height,
             num_channels=num_channels,
             batch_size=batch_size + 1,
-            sequence_length=sequence_length +
+            sequence_length=sequence_length + add_second_input,
             sequence_length2=sequence_length2 + 1,
             n_images=n_images + 1,
             dynamic_rope=dynamic_rope,
+            add_second_input=0,
             **kwargs,
         )["inputs"]
     return res

onnx_diagnostic/tasks/object_detection.py
@@ -27,7 +27,7 @@ def get_inputs(
     input_channels: int,
     batch_size: int = 2,
     dynamic_rope: bool = False,
-    add_second_input:
+    add_second_input: int = 1,
     **kwargs,  # unused
 ):
     """
@@ -65,14 +65,18 @@ def get_inputs(
     )
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
-            input_width=input_width +
-            input_height=input_height +
+            input_width=input_width + add_second_input,
+            input_height=input_height + add_second_input,
             input_channels=input_channels,
             batch_size=batch_size + 1,
             dynamic_rope=dynamic_rope,
+            add_second_input=0,
             **kwargs,
         )["inputs"]
     return res

onnx_diagnostic/tasks/sentence_similarity.py
@@ -22,7 +22,7 @@ def get_inputs(
     batch_size: int,
     sequence_length: int,
     dummy_max_token_id: int,
-    add_second_input:
+    add_second_input: int = 1,
     **kwargs,  # unused
 ):
     """
@@ -54,12 +54,16 @@ def get_inputs(
     )
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
             batch_size=batch_size + 1,
-            sequence_length=sequence_length +
+            sequence_length=sequence_length + add_second_input,
             dummy_max_token_id=dummy_max_token_id,
+            add_second_input=0,
             **kwargs,
         )["inputs"]
     return res

onnx_diagnostic/tasks/summarization.py
@@ -29,7 +29,7 @@ def get_inputs(
     batch_size: int = 2,
     sequence_length: int = 30,
     sequence_length2: int = 3,
-    add_second_input:
+    add_second_input: int = 1,
     **kwargs,  # unused
 ):
     """
@@ -144,6 +144,9 @@ def get_inputs(
     )
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
@@ -154,8 +157,9 @@ def get_inputs(
             head_dim_encoder=head_dim_encoder,
             head_dim_decoder=head_dim_decoder,
             batch_size=batch_size + 1,
-            sequence_length=sequence_length +
+            sequence_length=sequence_length + add_second_input,
             sequence_length2=sequence_length2 + 1,
+            add_second_input=0,
             **kwargs,
         )["inputs"]
     return res

onnx_diagnostic/tasks/text2text_generation.py
@@ -30,7 +30,7 @@ def get_inputs(
     batch_size: int = 2,
     sequence_length: int = 30,
     sequence_length2: int = 3,
-    add_second_input:
+    add_second_input: int = 1,
     **kwargs,  # unused
 ):
     """
@@ -69,8 +69,8 @@ def get_inputs(
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
-    cache_length = "cache_length_key"
-    cache_length2 = "cache_length_val"
+    cache_length = "cache_length_key"
+    cache_length2 = "cache_length_val"
 
     shapes = {
         "input_ids": {0: batch, 1: seq_length},
@@ -149,6 +149,9 @@ def get_inputs(
     )
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
@@ -160,8 +163,9 @@ def get_inputs(
             head_dim_decoder=head_dim_decoder,
             encoder_dim=encoder_dim,
             batch_size=batch_size + 1,
-            sequence_length=sequence_length +
+            sequence_length=sequence_length + add_second_input,
             sequence_length2=sequence_length2 + 1,
+            add_second_input=0,
             **kwargs,
         )["inputs"]
     return res

onnx_diagnostic/tasks/text_classification.py
@@ -22,7 +22,7 @@ def get_inputs(
     batch_size: int,
     sequence_length: int,
     dummy_max_token_id: int,
-    add_second_input:
+    add_second_input: int = 1,
     **kwargs,  # unused
 ):
     """
@@ -54,12 +54,16 @@ def get_inputs(
     )
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
             batch_size=batch_size + 1,
-            sequence_length=sequence_length +
+            sequence_length=sequence_length + add_second_input,
             dummy_max_token_id=dummy_max_token_id,
+            add_second_input=0,
             **kwargs,
         )["inputs"]
     return res

onnx_diagnostic/tasks/text_generation.py
@@ -72,7 +72,7 @@ def get_inputs(
     num_key_value_heads: Optional[int] = None,
     head_dim: Optional[int] = None,
     cls_cache: Optional[Union[type, str]] = None,
-    add_second_input:
+    add_second_input: int = 1,
     **kwargs,  # unused
 ):
     """
@@ -260,13 +260,15 @@ def get_inputs(
             config=config,
             dummy_max_token_id=dummy_max_token_id,
             num_hidden_layers=num_hidden_layers,
-            batch_size=batch_size + 1,
+            batch_size=(batch_size + 1) if add_second_input > 0 else 1,
             sequence_length=sequence_length + 1,
-            sequence_length2=sequence_length2
+            sequence_length2=sequence_length2
+            + (add_second_input if add_second_input > 0 else -add_second_input),
             dynamic_rope=dynamic_rope,
             num_key_value_heads=num_key_value_heads,
             head_dim=head_dim,
             cls_cache=cls_cache,
+            add_second_input=0,
             **kwargs,
         )["inputs"]
     return res
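text-generation treats the offset slightly differently from the other tasks: a positive `add_second_input` grows the batch by one, while the second `sequence_length2` is shifted by the absolute value of the offset, so a negative value can request a second input set with `batch_size=1`. A small sketch of that arithmetic (hypothetical helper name, not the library API):

```python
# Hypothetical helper reproducing the arithmetic used for the second input set.
def second_input_dims(batch_size: int, sequence_length2: int, add_second_input: int):
    batch2 = (batch_size + 1) if add_second_input > 0 else 1
    seq2 = sequence_length2 + (
        add_second_input if add_second_input > 0 else -add_second_input
    )
    return batch2, seq2


print(second_input_dims(2, 3, 1))   # (3, 4)
print(second_input_dims(2, 3, -2))  # (1, 5)
```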

onnx_diagnostic/tasks/text_to_image.py
@@ -25,7 +25,7 @@ def get_inputs(
     in_channels: int,
     sample_size: int,
     cross_attention_dim: int,
-    add_second_input:
+    add_second_input: int = 1,
     **kwargs,  # unused
 ):
     """
@@ -58,15 +58,19 @@ def get_inputs(
     )
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
             batch_size=batch_size + 1,
             sequence_length=sequence_length,
-            cache_length=cache_length +
+            cache_length=cache_length + add_second_input,
             in_channels=in_channels,
             sample_size=sample_size,
             cross_attention_dim=cross_attention_dim,
+            add_second_input=0,
             **kwargs,
         )["inputs"]
     return res

onnx_diagnostic/tasks/zero_shot_image_classification.py
@@ -34,7 +34,7 @@ def get_inputs(
     input_height: int = 224,
     input_channels: int = 3,
     batch_size_image=3,
-    add_second_input:
+    add_second_input: int = 1,
     **kwargs,  # unused
 ):
     """
@@ -87,16 +87,20 @@ def get_inputs(
     )
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
             dummy_max_token_id=dummy_max_token_id,
             batch_size=batch_size + 1,
-            sequence_length=sequence_length +
+            sequence_length=sequence_length + add_second_input,
             input_width=input_width,
             input_height=input_height,
             input_channels=input_channels,
             batch_size_image=batch_size_image + 1,
+            add_second_input=0,
             **kwargs,
         )["inputs"]
     return res

onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -16,6 +16,8 @@ def get_function(name: str) -> Tuple[type, Callable]:
     module_name = ".".join(spl[:-1])
     fname = spl[-1]
     mod = importlib.import_module(module_name)
+    if not hasattr(mod, fname):
+        return None, None
     return mod, getattr(mod, fname)
 
 
@@ -33,12 +35,16 @@ def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
         doc = v.__doc__.lstrip()
         if doc.startswith("manual patch"):
             continue
-        reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]")
+        reg = re.compile("[\\[]patch:([a-z_A-Z.]+)[\\]]")
         fall = reg.findall(doc)
         assert (
            len(fall) == 1
        ), f"Unable to find patching information for {v} in \n{doc}"
        fmod, f = get_function(fall[0])
+        if fmod is None and f is None:
+            # The function does not exist in this version of transformers.
+            # No patch is needed.
+            continue
         to_patch.append({"module": fmod, "function": f, "patch": v})
 
     name = mod.__name__
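`get_patches` resolves the function named in a patch's `[patch:...]` docstring marker; with this change `get_function` returns `(None, None)` when that target is missing from the installed transformers, and the patch is skipped instead of raising. A self-contained sketch of the lookup, using `math.tanh` as a hypothetical patch target:

```python
import importlib
import re
from typing import Callable, Optional, Tuple


def get_function(name: str) -> Tuple[Optional[object], Optional[Callable]]:
    # "pkg.module.func" -> (module, function), or (None, None) when the
    # attribute does not exist in the installed version.
    spl = name.split(".")
    mod = importlib.import_module(".".join(spl[:-1]))
    fname = spl[-1]
    if not hasattr(mod, fname):
        return None, None
    return mod, getattr(mod, fname)


def patched_example(x):
    "[patch:math.tanh]"  # hypothetical target, only used for this sketch
    return x


reg = re.compile("[\\[]patch:([a-z_A-Z.]+)[\\]]")
target = reg.findall(patched_example.__doc__)[0]  # 'math.tanh'
print(get_function(target))                       # (<module 'math' ...>, <built-in function tanh>)
print(get_function("math.does_not_exist"))        # (None, None)
```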
@@ -420,7 +426,11 @@ def torch_export_patches(
             patch_transformers_list, verbose=verbose
         )
 
-        if
+        if (
+            masking_utils
+            and patch_transformers_list.patch_masking_utils
+            and hasattr(masking_utils, "_vmap_for_bhqkv")
+        ):
             if verbose:
                 print(
                     "[torch_export_patches] patches "
@@ -429,6 +439,27 @@ def torch_export_patches(
             f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
             masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
 
+        if (
+            masking_utils
+            and patch_transformers_list.patch_masking_utils
+            and hasattr(masking_utils, "eager_mask")
+        ):
+            if verbose:
+                print(
+                    "[torch_export_patches] patches "
+                    "transformers.masking_utils.eager_mask"
+                )
+            f_transformers_eager_mask = masking_utils.eager_mask
+            masking_utils.eager_mask = patch_transformers_list.patched_eager_mask
+            if (
+                "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+                and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
+                == f_transformers_eager_mask
+            ):
+                masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
+                    patch_transformers_list.patched_eager_mask
+                )
+
     if custom_patches:
         if verbose:
             print("[torch_export_patches] applies custom patches")
@@ -511,7 +542,7 @@ def torch_export_patches(
 
     if custom_patches:
         if verbose:
-            print("[torch_export_patches]
+            print("[torch_export_patches] unpatches custom patches")
         unpatch_module_or_classes(
             custom_patches, revert_custom_patches_info, verbose=verbose
         )
@@ -526,18 +557,43 @@ def torch_export_patches(
         except ImportError:
             masking_utils = None
         if verbose:
-            print("[torch_export_patches]
+            print("[torch_export_patches] unpatches transformers")
         unpatch_module_or_classes(
             patch_transformers_list, revert_patches_info, verbose=verbose
         )
 
-        if
+        if (
+            masking_utils
+            and patch_transformers_list.patch_masking_utils
+            and hasattr(masking_utils, "_vmap_for_bhqkv")
+        ):
+            masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
             if verbose:
                 print(
-                    "[torch_export_patches]
+                    "[torch_export_patches] restored "
                     "transformers.masking_utils._vmap_for_bhqkv"
                 )
-
+
+        if (
+            masking_utils
+            and patch_transformers_list.patch_masking_utils
+            and hasattr(masking_utils, "eager_mask")
+        ):
+            f_transformers_eager_mask = masking_utils.eager_mask
+            masking_utils.eager_mask = f_transformers_eager_mask
+            if (
+                "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+                and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
+                == patch_transformers_list.patched_eager_mask
+            ):
+                masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
+                    f_transformers_eager_mask
+                )
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.masking_utils.eager_mask"
+                )
 
     ########
     # caches