onnx-diagnostic 0.6.3__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +281 -80
- onnx_diagnostic/doc.py +22 -0
- onnx_diagnostic/export/dynamic_shapes.py +48 -20
- onnx_diagnostic/export/shape_helper.py +126 -0
- onnx_diagnostic/ext_test_case.py +1 -1
- onnx_diagnostic/helpers/cache_helper.py +78 -8
- onnx_diagnostic/helpers/config_helper.py +8 -4
- onnx_diagnostic/helpers/helper.py +30 -3
- onnx_diagnostic/helpers/log_helper.py +1744 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +4 -1
- onnx_diagnostic/helpers/model_builder_helper.py +54 -73
- onnx_diagnostic/helpers/torch_helper.py +18 -2
- onnx_diagnostic/reference/__init__.py +1 -0
- onnx_diagnostic/reference/ort_evaluator.py +29 -4
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +21 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +3 -0
- onnx_diagnostic/tasks/feature_extraction.py +3 -0
- onnx_diagnostic/tasks/fill_mask.py +3 -0
- onnx_diagnostic/tasks/image_classification.py +7 -1
- onnx_diagnostic/tasks/image_text_to_text.py +72 -18
- onnx_diagnostic/tasks/mixture_of_expert.py +3 -0
- onnx_diagnostic/tasks/object_detection.py +3 -0
- onnx_diagnostic/tasks/sentence_similarity.py +3 -0
- onnx_diagnostic/tasks/summarization.py +3 -0
- onnx_diagnostic/tasks/text2text_generation.py +3 -0
- onnx_diagnostic/tasks/text_classification.py +3 -0
- onnx_diagnostic/tasks/text_generation.py +90 -43
- onnx_diagnostic/tasks/zero_shot_image_classification.py +3 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +78 -25
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +37 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +365 -17
- onnx_diagnostic/torch_models/hghub/hub_api.py +81 -8
- onnx_diagnostic/torch_models/hghub/hub_data.py +6 -2
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +209 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +58 -14
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +23 -50
- onnx_diagnostic/torch_models/{test_helper.py → validate.py} +166 -106
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/RECORD +44 -41
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/top_level.txt +0 -0

@@ -52,6 +52,9 @@ def get_inputs(
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :return: dictionary
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)

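The same three-line guard appears in most task modules of 0.7.1: tasks that do not yet support the new `cls_cache` option reject it up front. A minimal, self-contained sketch of that pattern follows; the stub name and signature are illustrative, not the package's API.

```python
# Minimal sketch of the cls_cache guard added to the tasks' get_inputs functions;
# get_inputs_stub is a made-up name used only for illustration.
def get_inputs_stub(model, config, **kwargs):
    assert (
        "cls_cache" not in kwargs
    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
    return {}

get_inputs_stub(None, None)  # ok
# get_inputs_stub(None, None, cls_cache="StaticCache")  # would raise AssertionError
```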
@@ -93,10 +96,10 @@ def get_inputs(
                 for i in range(num_hidden_layers)
             ]
         ),
-
+        pixel_values=torch.ones((batch_size, n_images, num_channels, width, height)).to(
            torch.int64
        ),
-
+        image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
            torch.int64
        ),
    )

@@ -129,16 +132,30 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     If the configuration is None, the function selects typical dimensions.
     """
     if config is not None:
-
-
-
-
-
-
-
-
-
-
+        if hasattr(config, "text_config"):
+            check_hasattr(
+                config.text_config,
+                "vocab_size",
+                "hidden_size",
+                "num_attention_heads",
+                ("num_key_value_heads", "num_attention_heads"),
+                "intermediate_size",
+                "hidden_size",
+            )
+            check_hasattr(config, "vision_config")
+            text_config = True
+        else:
+            check_hasattr(
+                config,
+                "vocab_size",
+                "hidden_size",
+                "num_attention_heads",
+                ("num_key_value_heads", "num_attention_heads"),
+                "intermediate_size",
+                "hidden_size",
+                "vision_config",
+            )
+            text_config = False
         check_hasattr(config.vision_config, "image_size", "num_channels")
     kwargs = dict(
         batch_size=2,

@@ -147,17 +164,54 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         head_dim=(
             16
             if config is None
-            else getattr(
+            else getattr(
+                config,
+                "head_dim",
+                (config.text_config.hidden_size if text_config else config.hidden_size)
+                // (
+                    config.text_config.num_attention_heads
+                    if text_config
+                    else config.num_attention_heads
+                ),
+            )
+        ),
+        dummy_max_token_id=(
+            31999
+            if config is None
+            else (config.text_config.vocab_size if text_config else config.vocab_size) - 1
+        ),
+        num_hidden_layers=(
+            4
+            if config is None
+            else (
+                config.text_config.num_hidden_layers
+                if text_config
+                else config.num_hidden_layers
+            )
         ),
-        dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
-        num_hidden_layers=4 if config is None else config.num_hidden_layers,
         num_key_value_heads=(
             8
             if config is None
-            else
+            else (
+                _pick(config.text_config, "num_key_value_heads", "num_attention_heads")
+                if text_config
+                else _pick(config, "num_key_value_heads", "num_attention_heads")
+            )
+        ),
+        intermediate_size=(
+            1024
+            if config is None
+            else (
+                config.text_config.intermediate_size
+                if text_config
+                else config.intermediate_size
+            )
+        ),
+        hidden_size=(
+            512
+            if config is None
+            else (config.text_config.hidden_size if text_config else config.hidden_size)
         ),
-        intermediate_size=1024 if config is None else config.intermediate_size,
-        hidden_size=512 if config is None else config.hidden_size,
         width=224 if config is None else config.vision_config.image_size,
         height=224 if config is None else config.vision_config.image_size,
         num_channels=3 if config is None else config.vision_config.num_channels,

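In 0.7.1, `random_input_kwargs` also looks for a nested `text_config` (as found on vision-language configurations) and reads sizes from it, falling back to top-level attributes otherwise, with `_pick` choosing the first attribute that exists. The sketch below mimics that fallback with a stand-in helper; the real `_pick` and `check_hasattr` in `onnx_diagnostic.helpers.config_helper` may behave differently in detail.

```python
# Stand-in for the first-existing-attribute fallback suggested by
# _pick(config, "num_key_value_heads", "num_attention_heads") in the hunk above.
def pick_first_attr(config, *names):
    for name in names:
        if hasattr(config, name):
            return getattr(config, name)
    raise AttributeError(f"none of {names} found on {type(config).__name__}")

class TextCfg:  # stands in for config.text_config on a vision-language model
    hidden_size = 512
    num_attention_heads = 8  # no num_key_value_heads, so fall back to this

print(pick_first_attr(TextCfg, "num_key_value_heads", "num_attention_heads"))  # 8
```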
@@ -61,6 +61,9 @@ def get_inputs(
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :return: dictionary
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     assert not add_second_input, "add_second_input=True not yet implemented"
     raise NotImplementedError(f"get_inputs not yet implemented for task {__TASK__!r}.")
 

@@ -41,6 +41,9 @@ def get_inputs(
     :param input_height: input height
     :return: dictionary
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     assert isinstance(
         input_width, int
     ), f"Unexpected type for input_width {type(input_width)}{config}"

@@ -35,6 +35,9 @@ def get_inputs(
         token_type_ids:T7s1x13[0,0:A0.0],
         attention_mask:T7s1x13[1,1:A1.0])
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"
     shapes = {

@@ -62,6 +62,9 @@ def get_inputs(
         decoder_input_ids:T7s1x1,
         encoder_outputs:dict(last_hidden_state:T1s1x16x512)
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = "cache_length_key"  # torch.export.Dim("cache_length", min=1, max=4096)

@@ -64,6 +64,9 @@ def get_inputs(
         decoder_input_ids:T7s1x1,
         encoder_outputs:dict(last_hidden_state:T1s1x16x512)
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = "cache_length_key"  # torch.export.Dim("cache_length", min=1, max=4096)

@@ -35,6 +35,9 @@ def get_inputs(
         token_type_ids:T7s1x13[0,0:A0.0],
         attention_mask:T7s1x13[1,1:A1.0])
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("sequence_length", min=1, max=1024)
     shapes = {

@@ -5,6 +5,7 @@ from ..helpers.cache_helper import (
     make_dynamic_cache,
     make_mamba_cache,
     make_sliding_window_cache,
+    make_static_cache,
 )
 from ..helpers.config_helper import update_config, check_hasattr, _pick
 

@@ -151,52 +152,98 @@ def get_inputs(
         assert config, "head_dim is None, the value cannot be set without a configuration"
         head_dim = config.hidden_size // config.num_attention_heads
 
-
-
-
-
-
-
-        "
-
-
-        },
-        "past_key_values": [
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-        ],
+    cache_name = (
+        cls_cache
+        if cls_cache is None or isinstance(cls_cache, str)
+        else cls_cache.__name__
+    )
+    make_caches = {
+        "DynamicCache": make_dynamic_cache,
+        "SlidingWindowCache": make_sliding_window_cache,
+        "StaticCache": make_static_cache,
     }
-
-
-
-        if cls_cache in ("SlidingWindowCache", transformers.cache_utils.SlidingWindowCache)
-        else make_dynamic_cache
+    assert cache_name is None or cache_name in make_caches, (
+        f"Unable to handle cls_cache={cache_name!r}, it should be in "
+        f"{sorted(make_caches)}"
     )
+    make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name]
+    is_static = cache_name == "StaticCache"
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if is_static:
+        # static
+        shapes = {
+            "input_ids": {0: batch, 1: seq_length},
+            "attention_mask": {0: batch, 2: "seq"},
+            "cache_position": {0: "seq"},
+            "past_key_values": [
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            ],
+        }
+        inputs = dict(
+            input_ids=torch.randint(
+                0, dummy_max_token_id, (batch_size, sequence_length2)
+            ).to(torch.int64),
+            attention_mask=torch.ones(
+                (batch_size, num_key_value_heads, sequence_length2, head_dim)
+            ).to(torch.bool),
+            cache_position=torch.arange(sequence_length2).to(torch.int64),
+            past_key_values=make_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+        )
+    else:
+        # dynamic
+        shapes = {
+            "input_ids": {0: batch, 1: seq_length},
+            "attention_mask": {
+                0: batch,
+                1: "cache+seq",  # cache_length + seq_length
+            },
+            "position_ids": {
+                0: batch,
+                1: "cache+seq",  # cache_length + seq_length
+            },
+            "past_key_values": [
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            ],
+        }
+
+        inputs = dict(
+            input_ids=torch.randint(
+                0, dummy_max_token_id, (batch_size, sequence_length2)
+            ).to(torch.int64),
+            attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+                torch.int64
+            ),
+            position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
+            .to(torch.int64)
+            .expand((batch_size, -1)),
+            past_key_values=make_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+        )
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
         res["inputs2"] = get_inputs(

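The text-generation task now routes cache creation through a name-to-factory table (`DynamicCache`, `SlidingWindowCache`, `StaticCache`), and each factory receives a list of `(key, value)` tensor pairs, one per layer. A small usage sketch, assuming a compatible transformers install; the list-of-pairs calling convention is taken from the hunk above, other details may vary.

```python
# Sketch of driving the new cache selection with arbitrary shapes.
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_static_cache

num_hidden_layers, batch, heads, seq, head_dim = 2, 1, 4, 3, 16
key_value_pairs = [
    (
        torch.randn(batch, heads, seq, head_dim),
        torch.randn(batch, heads, seq, head_dim),
    )
    for _ in range(num_hidden_layers)
]

make_caches = {"DynamicCache": make_dynamic_cache, "StaticCache": make_static_cache}
past_key_values = make_caches["DynamicCache"](key_value_pairs)
```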
@@ -55,6 +55,9 @@ def get_inputs(
         # attention_mask:T7s2x7
         # pixel_values:T1s2x3x224x224
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     assert isinstance(
         input_width, int
     ), f"Unexpected type for input_width {type(input_width)}{config}"

@@ -1,5 +1,8 @@
+import functools
+import importlib
 import contextlib
-
+import re
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from .onnx_export_serialization import (
     register_cache_serialization,
     unregister_cache_serialization,

@@ -7,6 +10,41 @@ from .onnx_export_serialization import (
 from .patches import patch_transformers as patch_transformers_list
 
 
+def get_function(name: str) -> Tuple[type, Callable]:
+    """Returns the module and the function based on its name."""
+    spl = name.split(".")
+    module_name = ".".join(spl[:-1])
+    fname = spl[-1]
+    mod = importlib.import_module(module_name)
+    return mod, getattr(mod, fname)
+
+
+@functools.lru_cache
+def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
+    """Returns the list of patches to make for a specific module."""
+    to_patch = []
+    for k in dir(mod):
+        if k.startswith("patched_"):
+            v = getattr(mod, k)
+            if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
+                to_patch.append(v)
+            else:
+                # a function
+                doc = v.__doc__.lstrip()
+                if doc.startswith("manual patch"):
+                    continue
+                reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]")
+                fall = reg.findall(doc)
+                assert (
+                    len(fall) == 1
+                ), f"Unable to find patching information for {v} in \n{doc}"
+                fmod, f = get_function(fall[0])
+                to_patch.append({"module": fmod, "function": f, "patch": v})
+
+    name = mod.__name__
+    return name, to_patch
+
+
 def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
     """
     Applies all patches defined in classes prefixed by ``patched_``

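`get_patches` distinguishes class patches (objects exposing `_PATCHED_CLASS_` and `_PATCHES_`) from function patches, which declare their target inside their docstring with a `[patch:module.function]` tag that the regular expression above extracts. A sketch of that convention; the target path here is invented for illustration.

```python
# The patched_* prefix and the [patch:...] docstring tag come from the hunk above;
# some_package.some_module.rotate_half is a made-up target used only as an example.
import re

def patched_rotate_half(x):
    """[patch:some_package.some_module.rotate_half]

    Export-friendly replacement; the tag tells get_patches which function to swap.
    """
    return x

reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]")
print(reg.findall(patched_rotate_half.__doc__.lstrip()))
# ['some_package.some_module.rotate_half']
```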
@@ -23,16 +61,21 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
         to_patch = mod
         name = "list"
     else:
-        to_patch =
-        for k in dir(mod):
-            if k.startswith("patched_"):
-                v = getattr(mod, k)
-                if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
-                    to_patch.append(v)
-        name = mod.__name__
+        name, to_patch = get_patches(mod, verbose)
 
     res = {}
     for cls in to_patch:
+        if isinstance(cls, dict):
+            # a function
+            keep = {}
+            original = cls["module"]
+            f = cls["function"]
+            res[f] = f
+            if verbose:
+                print(f"[patch_module_or_classes] function: {original.__name__}.{f.__name__}")
+            setattr(original, f.__name__, cls["patch"])
+            continue
+
         original = cls._PATCHED_CLASS_
         methods = cls._PATCHES_
         if verbose:

@@ -57,26 +100,36 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo
         to_patch = mod
         name = "list"
     else:
-        to_patch =
-
-
-
-                if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
-                    to_patch.append(v)
-        name = mod.__name__
-    set_patch = set(to_patch)
+        name, to_patch = get_patches(mod, verbose)
+
+    set_patch_cls = {i for i in to_patch if not isinstance(i, dict)}
+    dict_patch_fct = {i["function"]: i for i in to_patch if isinstance(i, dict)}
 
     for cls, methods in info.items():
-
+        if cls in set_patch_cls:
+            if verbose:
+                print(
+                    f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}"
+                )
+            original = cls._PATCHED_CLASS_
+            for n, v in methods.items():
+                if v is None:
+                    # The method did not exist. We remove it.
+                    delattr(original, n)
+                else:
+                    setattr(original, n, v)
+            continue
+        assert cls in dict_patch_fct, (
+            f"No patch registered for {cls} in {mod} "
+            f"(found {set_patch_cls} and {set(dict_patch_fct)})"
+        )
+        patch = dict_patch_fct[cls]
         if verbose:
-            print(
-
-
-
-
-                    delattr(original, n)
-                else:
-                    setattr(original, n, v)
+            print(
+                f"[unpatch_module_or_classes] function "
+                f"{patch['module'].__name__}.{cls.__name__}"
+            )
+        setattr(patch["module"], cls.__name__, patch["function"])
 
 
 @contextlib.contextmanager

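Class patches follow the convention read by `patch_module_or_classes` above: a `patched_*` class names its target in `_PATCHED_CLASS_` and lists the methods to copy in `_PATCHES_`, and unpatching restores (or deletes) those attributes. Below is a minimal sketch of the method-copy step only; the target class is illustrative, and the real helpers also record the originals so they can be restored later.

```python
# Illustrative target and replacement; only the _PATCHED_CLASS_/_PATCHES_
# convention is taken from the diff above.
class Target:
    def forward(self, x):
        return x

class patched_Target:
    _PATCHED_CLASS_ = Target
    _PATCHES_ = ["forward"]

    def forward(self, x):
        return x * 2  # replacement applied while exporting

for name in patched_Target._PATCHES_:  # what the patch step effectively does
    setattr(patched_Target._PATCHED_CLASS_, name, getattr(patched_Target, name))

print(Target().forward(3))  # 6
```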
@@ -9,9 +9,11 @@ from transformers.cache_utils import (
     MambaCache,
     EncoderDecoderCache,
     SlidingWindowCache,
+    StaticCache,
 )
 from transformers.modeling_outputs import BaseModelOutput
 from ..helpers import string_type
+from ..helpers.cache_helper import make_static_cache
 
 
 PATCH_OF_PATCHES: Set[Any] = set()

@@ -175,6 +177,13 @@ def serialization_functions(verbose: int = 0) -> Dict[str, Union[Callable, int]]
             flatten_with_keys_sliding_window_cache,
             verbose=verbose,
         ),
+        StaticCache=register_class_serialization(
+            StaticCache,
+            flatten_static_cache,
+            unflatten_static_cache,
+            flatten_with_keys_static_cache,
+            verbose=verbose,
+        ),
     )
 
 

@@ -309,6 +318,34 @@ def unflatten_dynamic_cache(
     return cache
 
 
+#############
+# StaticCache
+#############
+
+
+def flatten_static_cache(
+    cache: StaticCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+    flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def flatten_with_keys_static_cache(
+    cache: StaticCache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+    values, context = flatten_static_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+def unflatten_static_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> StaticCache:
+    """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
+    return make_static_cache(list(zip(values[0], values[1])))
+
+
 ####################
 # SlidingWindowCache
 ####################
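The three `*_static_cache` functions plug `StaticCache` into torch's pytree machinery (flatten to tensors plus a rebuild context), which is what lets `torch.export` trace through cache objects. The toy registration below shows the underlying mechanism with a recent PyTorch; `ToyCache` is invented, and `register_class_serialization` presumably performs the equivalent wiring for `StaticCache`.

```python
# Toy pytree registration illustrating the flatten/unflatten contract; requires a
# PyTorch version exposing torch.utils._pytree.register_pytree_node (2.2+).
from typing import Any, List
import torch

class ToyCache:
    def __init__(self, key_cache: List[torch.Tensor], value_cache: List[torch.Tensor]):
        self.key_cache = key_cache
        self.value_cache = value_cache

def _flatten(cache: ToyCache):
    # children first, then a context describing how to rebuild the container
    return [cache.key_cache, cache.value_cache], ["key_cache", "value_cache"]

def _unflatten(values: List[Any], context) -> ToyCache:
    return ToyCache(values[0], values[1])

torch.utils._pytree.register_pytree_node(ToyCache, _flatten, _unflatten)

cache = ToyCache([torch.zeros(1, 2)], [torch.ones(1, 2)])
leaves, spec = torch.utils._pytree.tree_flatten(cache)
print(type(torch.utils._pytree.tree_unflatten(leaves, spec)).__name__, len(leaves))
# ToyCache 2
```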