onnx-diagnostic 0.8.7__py3-none-any.whl → 0.8.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/ci_models/export_phi4_mm.py +1 -1
- onnx_diagnostic/doc.py +258 -8
- onnx_diagnostic/export/api.py +755 -5
- onnx_diagnostic/export/dynamic_shapes.py +61 -4
- onnx_diagnostic/export/shape_helper.py +1 -8
- onnx_diagnostic/helpers/cache_helper.py +98 -21
- onnx_diagnostic/helpers/fake_tensor_helper.py +26 -5
- onnx_diagnostic/helpers/helper.py +36 -6
- onnx_diagnostic/helpers/onnx_helper.py +7 -0
- onnx_diagnostic/helpers/ort_session.py +5 -0
- onnx_diagnostic/helpers/rt_helper.py +14 -1
- onnx_diagnostic/helpers/torch_helper.py +22 -9
- onnx_diagnostic/tasks/image_text_to_text.py +8 -5
- onnx_diagnostic/tasks/text_generation.py +17 -17
- onnx_diagnostic/torch_export_patches/eval/__init__.py +1 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +62 -38
- onnx_diagnostic/torch_export_patches/patch_details.py +3 -3
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +14 -5
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -9
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +42 -30
- onnx_diagnostic/torch_models/validate.py +48 -0
- {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/METADATA +3 -1
- {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/RECORD +28 -28
- {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/top_level.txt +0 -0
onnx_diagnostic/tasks/image_text_to_text.py
@@ -1,7 +1,7 @@
 import itertools
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.cache_helper import make_dynamic_cache,
+from ..helpers.cache_helper import make_dynamic_cache, get_make_hybrid_cache
 from ..helpers.config_helper import (
     update_config,
     check_hasattr,
@@ -172,10 +172,10 @@ def _get_inputs_gemma3(
     assert expected & set(
         dummies
     ), f"Unable to find expected inputs {expected} in loaded inputs {set(dummies)}"
-    assert sequence_length == dummies["input_ids"].shape[-1], (
-        f"sequence_length={sequence_length} != {dummies['input_ids'].shape[-1]} for "
-        f"model class {model.__class__.__name__}"
-    )
+    # assert sequence_length == dummies["input_ids"].shape[-1], (
+    #     f"sequence_length={sequence_length} != {dummies['input_ids'].shape[-1]} for "
+    #     f"model class {model.__class__.__name__}"
+    # )
     assert batch_size == dummies["input_ids"].shape[0], (
         f"batch_size={batch_size} != {dummies['input_ids'].shape[0]} for "
         f"model class {model.__class__.__name__}"
@@ -200,6 +200,9 @@ def _get_inputs_gemma3(

     _check_()

+    make_hybrid_cache = get_make_hybrid_cache()
+    assert make_hybrid_cache is not None, "not implemented when make_hybrid_cache is missing"
+
     inputs = dict(
         input_ids=dummies["input_ids"],
         token_type_ids=dummies["token_type_ids"],
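
The two added lines follow a lookup-then-assert convention: the factory for HybridCache is resolved at call time rather than import time, so this module still imports on transformers versions that do not provide HybridCache. A hedged sketch of the shape such a getter presumably has (the real one lives in onnx_diagnostic.helpers.cache_helper; the body below is an assumption inferred from the assert above, not the library's code):

    def get_make_hybrid_cache():
        # Assumed behavior: return the factory when the installed
        # transformers provides HybridCache, None otherwise.
        try:
            from transformers.cache_utils import HybridCache  # noqa: F401
        except ImportError:
            return None
        from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache
        return make_hybrid_cache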

onnx_diagnostic/tasks/text_generation.py
@@ -1,11 +1,6 @@
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 import torch
-from ..helpers.cache_helper import (
-    make_dynamic_cache,
-    make_mamba_cache,
-    make_sliding_window_cache,
-    make_static_cache,
-)
+from ..helpers.cache_helper import make_dynamic_cache, make_mamba_cache, make_static_cache
 from ..helpers.config_helper import (
     update_config,
     check_hasattr,
@@ -187,17 +182,22 @@ def get_inputs(
         if cls_cache is None or isinstance(cls_cache, str)
         else cls_cache.__name__
     )
-
-
-
-
-
-
-
-
-
-
-
+    if cache_name == "DynamicSlidingWindowCache":
+        from ..helpers.cache_helper import make_sliding_window_cache
+
+        make_cache = make_sliding_window_cache
+        is_static = False
+    else:
+        make_caches = {
+            "DynamicCache": make_dynamic_cache,
+            "StaticCache": make_static_cache,
+        }
+        assert cache_name is None or cache_name in make_caches, (
+            f"Unable to handle cls_cache={cache_name!r}, it should be in "
+            f"{sorted(make_caches)}"
+        )
+        make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name]  # type: ignore[assignment]
+        is_static = cache_name == "StaticCache"

     if is_static:
         # static
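
The branch keys off the cache's class name, which the surrounding context lines derive from cls_cache whether it was passed as a string or as a class. A toy version of that normalization (resolve_cache_name is a name introduced here for illustration; it is not part of onnx-diagnostic):

    def resolve_cache_name(cls_cache):
        # cls_cache may be None, a class-name string, or a class object.
        return (
            cls_cache
            if cls_cache is None or isinstance(cls_cache, str)
            else cls_cache.__name__
        )

    print(resolve_cache_name(None))           # None -> defaults to DynamicCache
    print(resolve_cache_name("StaticCache"))  # 'StaticCache'
    print(resolve_cache_name(dict))           # 'dict' (any class works the same way)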

onnx_diagnostic/torch_export_patches/eval/__init__.py
@@ -521,7 +521,7 @@ def run_exporter(
     :param exporter: exporter
     :param cls_model: model class to create
     :param inputs: list of inputs to try
-    :param dynamic: use dynamic
+    :param dynamic: use dynamic shapes or not
     :param quiet: raise exception or not
     :param verbose: verbosity
     :return: results

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py
@@ -7,15 +7,9 @@ import transformers
 from transformers.cache_utils import DynamicCache, StaticCache

 try:
-    from transformers.cache_utils import (
-        EncoderDecoderCache,
-        HybridCache,
-        SlidingWindowCache,
-    )
+    from transformers.cache_utils import EncoderDecoderCache
 except ImportError:
     EncoderDecoderCache = None
-    HybridCache = None
-    SlidingWindowCache = None
 from ..helpers import string_type
 from .serialization import _lower_name_with_

@@ -36,6 +30,24 @@ def get_mamba_cache_cls() -> type:
         return None


+def get_hybrid_cache_cls() -> type:
+    try:
+        from transformers.cache_utils import HybridCache
+
+        return HybridCache
+    except ImportError:
+        return None
+
+
+def get_sliding_window_cache_cls() -> type:
+    try:
+        from transformers.cache_utils import SlidingWindowCache
+
+        return SlidingWindowCache
+    except ImportError:
+        return None
+
+
 def register_class_serialization(
     cls,
     f_flatten: Callable,
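
The two new getters repeat one pattern: probe an optional import when the class is first needed and return None when the installed transformers is too old to provide it. A generic version of that pattern, as a hedged sketch (optional_class is a name introduced for illustration, not part of onnx-diagnostic):

    import importlib

    def optional_class(module: str, name: str):
        # Return the named class when the installed package provides it,
        # None otherwise -- the same contract as the getters above.
        try:
            return getattr(importlib.import_module(module), name)
        except (ImportError, AttributeError):
            return None

    HybridCache = optional_class("transformers.cache_utils", "HybridCache")
    print("HybridCache available:", HybridCache is not None)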
@@ -179,18 +191,9 @@ def serialization_functions(
         flatten_dynamic_cache,
         unflatten_dynamic_cache,
         flatten_with_keys_dynamic_cache,
-        flatten_hybrid_cache,
-        unflatten_hybrid_cache,
-        flatten_with_keys_hybrid_cache,
-        flatten_mamba_cache,
-        unflatten_mamba_cache,
-        flatten_with_keys_mamba_cache,
         flatten_encoder_decoder_cache,
         unflatten_encoder_decoder_cache,
         flatten_with_keys_encoder_decoder_cache,
-        flatten_sliding_window_cache,
-        unflatten_sliding_window_cache,
-        flatten_with_keys_sliding_window_cache,
         flatten_static_cache,
         unflatten_static_cache,
         flatten_with_keys_static_cache,
@@ -208,14 +211,6 @@ def serialization_functions(
             # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
             verbose=verbose,
         ),
-        HybridCache: lambda verbose=verbose: register_class_serialization(
-            HybridCache,
-            flatten_hybrid_cache,
-            unflatten_hybrid_cache,
-            flatten_with_keys_hybrid_cache,
-            # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
-            verbose=verbose,
-        ),
         EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
             EncoderDecoderCache,
             flatten_encoder_decoder_cache,
@@ -223,13 +218,6 @@ def serialization_functions(
             flatten_with_keys_encoder_decoder_cache,
             verbose=verbose,
         ),
-        SlidingWindowCache: lambda verbose=verbose: register_class_serialization(
-            SlidingWindowCache,
-            flatten_sliding_window_cache,
-            unflatten_sliding_window_cache,
-            flatten_with_keys_sliding_window_cache,
-            verbose=verbose,
-        ),
         StaticCache: lambda verbose=verbose: register_class_serialization(
             StaticCache,
             flatten_static_cache,
@@ -240,6 +228,12 @@ def serialization_functions(
     }
     MambaCache = get_mamba_cache_cls()
     if MambaCache:
+        from .serialization.transformers_impl import (
+            flatten_mamba_cache,
+            unflatten_mamba_cache,
+            flatten_with_keys_mamba_cache,
+        )
+
         transformers_classes[MambaCache] = (
             lambda verbose=verbose: register_class_serialization(
                 MambaCache,
@@ -249,6 +243,42 @@ def serialization_functions(
                 verbose=verbose,
             )
         )
+    HybridCache = get_hybrid_cache_cls()
+    if HybridCache:
+        from .serialization.transformers_impl import (
+            flatten_hybrid_cache,
+            unflatten_hybrid_cache,
+            flatten_with_keys_hybrid_cache,
+        )
+
+        transformers_classes[HybridCache] = (
+            lambda verbose=verbose: register_class_serialization(
+                HybridCache,
+                flatten_hybrid_cache,
+                unflatten_hybrid_cache,
+                flatten_with_keys_hybrid_cache,
+                verbose=verbose,
+            )
+        )
+
+    SlidingWindowCache = get_sliding_window_cache_cls()
+    if SlidingWindowCache:
+        from .serialization.transformers_impl import (
+            flatten_sliding_window_cache,
+            unflatten_sliding_window_cache,
+            flatten_with_keys_sliding_window_cache,
+        )
+
+        transformers_classes[SlidingWindowCache] = (
+            lambda verbose=verbose: register_class_serialization(
+                SlidingWindowCache,
+                flatten_sliding_window_cache,
+                unflatten_sliding_window_cache,
+                flatten_with_keys_sliding_window_cache,
+                verbose=verbose,
+            )
+        )
+
     classes.update(transformers_classes)

     if patch_diffusers:
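
All of these entries funnel into register_class_serialization(cls, f_flatten, ...), which ultimately has to make the cache classes visible to torch's pytree machinery so torch.export can flatten them into tensors and rebuild them. A minimal, self-contained sketch of that underlying mechanism on a toy class (this calls torch.utils._pytree directly and is an assumption about the mechanics, not the library's own registration code):

    import torch
    import torch.utils._pytree as pytree

    class ToyCache:
        def __init__(self, keys, values):
            self.keys, self.values = keys, values

    pytree.register_pytree_node(
        ToyCache,
        # flatten: return the tensor children and a (here empty) context
        lambda c: ([c.keys, c.values], None),
        # unflatten: rebuild the instance from children and context
        lambda children, context: ToyCache(*children),
        serialized_type_name="ToyCache",
    )

    cache = ToyCache(torch.zeros(2, 3), torch.ones(2, 3))
    flat, spec = pytree.tree_flatten(cache)
    restored = pytree.tree_unflatten(flat, spec)
    assert torch.equal(restored.keys, cache.keys)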
@@ -303,13 +333,7 @@ def unregister_class_serialization(cls: type, verbose: int = 0):


 def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
-
-    MambaCache = get_mamba_cache_cls()
-    cls_ensemble = (
-        {DynamicCache, EncoderDecoderCache}
-        | set(undo)
-        | ({MambaCache} if MambaCache else set())
-    )
+    cls_ensemble = {DynamicCache, EncoderDecoderCache} | set(undo)
     for cls in cls_ensemble:
         if undo.get(cls.__name__, False):
             unregister_class_serialization(cls, verbose)

onnx_diagnostic/torch_export_patches/patch_details.py
@@ -191,7 +191,7 @@ class PatchDetails:
        ep = torch.export.export(
            model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
        )
-        patches = details.
+        patches = details.patches_involved_in_graph(ep.graph)
        report = details.make_report(patches, format="rst")
        print(report)
    """
@@ -235,7 +235,7 @@ class PatchDetails:
         """Returns the data for a dataframe."""
         return [p.to_dict() for p in self.patched]

-    def
+    def patches_involved_in_graph(
         self, graph: "torch.fx.Graph"  # noqa: F821
     ) -> List[Tuple[PatchInfo, List["torch.fx.Node"]]]:  # noqa: F821
         """
@@ -322,7 +322,7 @@ class PatchDetails:
         """
         Creates a report based on the involved patches.

-        :param patches: from method :meth:`
+        :param patches: from method :meth:`patches_involved_in_graph`
         :param format: format of the report
         :return: report
         """

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py
@@ -22,13 +22,22 @@ if patch_DynamicLayer:
         _PATCHES_ = ["lazy_initialization"]
         _PATCHED_CLASS_ = DynamicLayer

-        def lazy_initialization(
+        def lazy_initialization(
+            self, key_states: torch.Tensor, value_states: torch.Tensor = None
+        ):
             self.dtype, self.device = key_states.dtype, key_states.device
-
-
+            assert (
+                hasattr(key_states, "shape") and key_states is not None
+            ), f"Attribute 'shape' is wrong for type {type(key_states)}"
+            like = torch.narrow(key_states, dim=-2, start=0, length=0)
             # PATCHED: used a tensor with an empty shape and not en empty list to initialize
-
-
+            if isinstance(key_states, torch._subclasses.fake_tensor.FakeTensor):
+                with key_states.fake_mode:
+                    self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device)
+                    self.values = torch.empty_like(like, dtype=self.dtype, device=self.device)
+            else:
+                self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device)
+                self.values = torch.empty_like(like, dtype=self.dtype, device=self.device)
             if patch_is_initialized:
                 self.is_initialized = True
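
The patched initializer builds its empty key/value buffers from a zero-length narrow of the incoming states, which keeps every other dimension plus dtype and device intact. A quick standalone illustration of that trick (the tensor shape here is made up):

    import torch

    k = torch.rand(2, 4, 16, 8)  # hypothetical (batch, heads, seq, head_dim)
    like = torch.narrow(k, dim=-2, start=0, length=0)
    print(like.shape)              # torch.Size([2, 4, 0, 8]): empty along seq only
    keys = torch.empty_like(like)  # empty cache buffer with matching dtype/device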

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py
@@ -214,7 +214,7 @@ def patched_dynamic_rope_update(rope_forward):
             cond,
             (lambda x, y: x.clone()),
             (lambda x, y: y.clone()),
-            [long_inv_freq, original_inv_freq],
+            [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
         )
         setattr(self, f"{prefix}inv_freq", inv_freq)
         # if seq_len > original_max_position_embeddings:
@@ -293,7 +293,7 @@ def patched_dynamic_rope_update(rope_forward):
             cond,
             (lambda x, y: x.clone()),
             (lambda x, y: y.clone()),
-            [long_inv_freq, original_inv_freq],
+            [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
         )
         setattr(self, f"{prefix}inv_freq", inv_freq)

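
Both hunks apply the same fix: the operands feeding torch.cond are cast to a common dtype, because both branches must return tensors with matching metadata when the graph is traced for export. A minimal sketch of the constraint, following the call shape used in the patch (tensors and dtypes here are hypothetical):

    import torch

    cond = torch.tensor(True)
    a = torch.rand(4, dtype=torch.float64)
    b = torch.rand(4, dtype=torch.float32)

    # With mismatched dtypes the two branches disagree on output metadata,
    # which tracing rejects; casting one operand aligns them.
    out = torch.cond(
        cond, (lambda x, y: x.clone()), (lambda x, y: y.clone()), [a.to(b.dtype), b]
    )
    print(out.dtype)  # torch.float32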

onnx_diagnostic/torch_export_patches/patches/patch_torch.py
@@ -541,14 +541,17 @@ class patched_ShapeEnv:
         # oblivious_var_to_val will be defined iff we have sizes
         # with DimDynamic.OBLIVIOUS_SIZE type.
         # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
+        var_to_val = getattr(
+            self,
+            "unbacked_var_to_val",
+            getattr(self, "oblivious_var_to_val", False),
+        )
         if (
-            self.oblivious_var_to_val
-            and not (
-                correct_hint := orig_expr.xreplace(self.oblivious_var_to_val)
-            ).free_symbols
+            var_to_val
+            and not (correct_hint := orig_expr.xreplace(var_to_val)).free_symbols
             and not (
                 counterfactual_hint := orig_expr.xreplace(
-                    {k: max(2, v) for k, v in self.oblivious_var_to_val.items()}
+                    {k: max(2, v) for k, v in var_to_val.items()}
                 )
             ).free_symbols
             and correct_hint == counterfactual_hint
@@ -571,11 +574,11 @@ class patched_ShapeEnv:
         # and if they pass we add a runtime assertions and continue.
         if (
             not ok
-            and self.oblivious_var_to_val
+            and var_to_val
             and not (
-                unsound_result := orig_expr.xreplace(
-                    self.oblivious_var_to_val
-                )
+                unsound_result := orig_expr.xreplace(var_to_val).xreplace(
+                    var_to_val
+                )
             ).free_symbols
         ):
             # pyrefly: ignore  # unbound-name
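
The hint machinery in both hunks leans on sympy substitution: xreplace swaps candidate values into the symbolic size expression, and free_symbols tells whether anything symbolic is left. The check in isolation, on a toy expression with the standard sympy API:

    import sympy

    s = sympy.Symbol("s")
    expr = s + 1
    hint = expr.xreplace({s: 3})
    print(hint, hint.free_symbols)  # 4 set() -> fully evaluated, usable as a hint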

onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py
@@ -1,13 +1,7 @@
 import itertools
 from typing import Any, Callable, List, Set, Tuple
 import torch
-from transformers.cache_utils import (
-    Cache,
-    DynamicCache,
-    EncoderDecoderCache,
-    HybridCache,
-    StaticCache,
-)
+from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache

 try:
     from transformers.cache_utils import SlidingWindowCache
@@ -15,18 +9,17 @@ except ImportError:
     SlidingWindowCache = None


+try:
+    from transformers.cache_utils import HybridCache
+except ImportError:
+    HybridCache = None
+
 try:
     from transformers.models.mamba.modeling_mamba import MambaCache
 except ImportError:
     from transformers.cache_utils import MambaCache
 from transformers.modeling_outputs import BaseModelOutput
-from ...helpers.cache_helper import (
-    make_dynamic_cache,
-    make_hybrid_cache,
-    make_sliding_window_cache,
-    make_static_cache,
-    CacheKeyValue,
-)
+from ...helpers.cache_helper import make_dynamic_cache, make_static_cache, CacheKeyValue
 from . import make_serialization_function_for_dataclass

@@ -78,6 +71,14 @@ def flatten_dynamic_cache(
     dynamic_cache: DynamicCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
+    assert (
+        not hasattr(dynamic_cache, "layers")
+        or not dynamic_cache.layers
+        or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers)
+    ), (
+        f"The serialization does not work yet on other layers "
+        f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}"
+    )
     return _flatten_key_value_cache(dynamic_cache)
@@ -85,6 +86,14 @@ def flatten_with_keys_dynamic_cache(
     dynamic_cache: DynamicCache,
 ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
+    assert (
+        not hasattr(dynamic_cache, "layers")
+        or not dynamic_cache.layers
+        or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers)
+    ), (
+        f"The serialization does not work yet on other layers "
+        f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}"
+    )
     return _flatten_with_keys_cache(dynamic_cache)
@@ -99,26 +108,27 @@ def unflatten_dynamic_cache(
 # HybridCache
 #############

+if HybridCache:

-def flatten_hybrid_cache(
-    cache: HybridCache,
-) -> Tuple[List[Any], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects."""
-    return _flatten_key_value_cache(cache)
-
+    def flatten_hybrid_cache(
+        cache: HybridCache,
+    ) -> Tuple[List[Any], torch.utils._pytree.Context]:
+        """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects."""
+        return _flatten_key_value_cache(cache)

-def flatten_with_keys_hybrid_cache(
-    cache: HybridCache,
-) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
-    """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects."""
-    return _flatten_with_keys_cache(cache)
+    def flatten_with_keys_hybrid_cache(
+        cache: HybridCache,
+    ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+        """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects."""
+        return _flatten_with_keys_cache(cache)

+    def unflatten_hybrid_cache(
+        values: List[Any], context: torch.utils._pytree.Context, output_type=None
+    ) -> HybridCache:
+        """Restores a :class:`transformers.cache_utils.HybridCache` from python objects."""
+        from ...helpers.cache_helper import make_hybrid_cache

-def unflatten_hybrid_cache(
-    values: List[Any], context: torch.utils._pytree.Context, output_type=None
-) -> HybridCache:
-    """Restores a :class:`transformers.cache_utils.HybridCache` from python objects."""
-    return _unflatten_cache(make_hybrid_cache, values, context, output_type=output_type)
+        return _unflatten_cache(make_hybrid_cache, values, context, output_type=output_type)


 #############
@@ -190,6 +200,8 @@ if SlidingWindowCache:
         Restores a :class:`transformers.cache_utils.SlidingWindowCache`
         from python objects.
         """
+        from ...helpers.cache_helper import make_sliding_window_cache
+
         return _unflatten_cache(
             make_sliding_window_cache, values, context, output_type=output_type
         )

onnx_diagnostic/torch_models/validate.py
@@ -1771,6 +1771,10 @@ def validate_onnx_model(
     if os.environ.get("DUMPORTOPT", "") in ("1", "true", "True"):
         opts = onnxruntime.SessionOptions()
         opts.optimized_model_filepath = f"{data['onnx_filename']}.rtopt.onnx"
+        opts.add_session_config_entry(
+            "session.optimized_model_external_initializers_file_name",
+            f"{os.path.split(data['onnx_filename'])[0]}.rtopt.data",
+        )
         if verbose:
             print(
                 f"[validate_onnx_model] saved optimized onnxruntime "
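
For context, these two session options work together: optimized_model_filepath asks onnxruntime to dump the graph after its own optimizations, and the session config entry redirects large initializers of that dump into a separate external-data file. A standalone sketch of the combination (model paths here are hypothetical):

    import onnxruntime

    opts = onnxruntime.SessionOptions()
    opts.optimized_model_filepath = "model.rtopt.onnx"
    opts.add_session_config_entry(
        "session.optimized_model_external_initializers_file_name",
        "model.rtopt.data",
    )
    # Creating the session triggers optimization and writes both files.
    sess = onnxruntime.InferenceSession(
        "model.onnx", opts, providers=["CPUExecutionProvider"]
    )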
@@ -2326,6 +2330,7 @@ def call_torch_export_custom(
         "custom-dec",
         "custom-decall",
         "custom-fake",
+        "custom-tracing",
     }
     assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
     assert "model" in data, f"model is missing from data: {sorted(data)}"
@@ -2338,11 +2343,16 @@ def call_torch_export_custom(
         f"Options strict cannot be specified in the exporter name {exporter!r} "
         f"and in the options {exporter_options}"
     )
+    assert ("-tracing" not in exporter) or ("tracing" not in exporter_options), (
+        f"Options tracing cannot be specified in the exporter name {exporter!r} "
+        f"and in the options {exporter_options}"
+    )
     summary: Dict[str, Union[str, int, float]] = {}
     strict = "-strict" in exporter or exporter_options.pop("strict", False)
     args, kwargs = split_args_kwargs(data["inputs_export"])
     ds = data.get("dynamic_shapes", None)
     fake = "-fake" in exporter or exporter_options.pop("fake", False)
+    tracing = "-tracing" in exporter or exporter_options.pop("tracing", False)
     if fake:
         from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions

@@ -2366,6 +2376,7 @@ def call_torch_export_custom(
     summary["export_exporter"] = exporter
     summary["export_optimization"] = optimization or ""
     summary["export_strict"] = strict
+    summary["export_tracing"] = tracing
     summary["export_fake"] = fake
     summary["export_args"] = string_type(args, with_shape=True)
     summary["export_kwargs"] = string_type(kwargs, with_shape=True)
@@ -2388,6 +2399,7 @@ def call_torch_export_custom(
         )
     )
     large_model = bool(exporter_options.pop("large_model", True))
+    exporter_options.pop("tracing", False)
     return_optimize_report = bool(exporter_options.pop("return_optimize_report", True))
     export_modules_as_functions = bool(
         exporter_options.pop("export_modules_as_functions", False)
@@ -2401,6 +2413,7 @@ def call_torch_export_custom(
     summary["export_external_threshold"] = str(external_threshold)

     export_options = ExportOptions(
+        tracing=tracing,
         strict=strict,
         decomposition_table=decomposition_table,
         save_ep=(
@@ -2445,6 +2458,41 @@ def call_torch_export_custom(
             )
         ),
     )
+    if "optimization" in opt_stats and dump_folder:
+        import pandas
+
+        pattern_stats = []
+        for k, v in opt_stats.items():
+            if "time" in k:
+                pattern_stats.append(dict(level="main", pattern=k, time_in=v))
+        pattern_stats.extend(
+            [{**obs, "level": "detailed"} for obs in opt_stats["optimization"]]
+        )
+        stat_filename = os.path.join(dump_folder, "optimization_stats.xlsx")
+        df = pandas.DataFrame(pattern_stats)
+        df.to_excel(stat_filename, index=False)
+        cols = [
+            c
+            for c in [
+                "level",
+                "pattern",
+                "time_in",
+                "iteration",
+                "inlined",
+                "removed",
+                "added",
+                "instances",
+                "changed",
+                "scale",
+            ]
+            if c in df.columns
+        ]
+        agg = {k: "sum" for k in cols if k not in ("level", "pattern")}
+        agg.update(dict(iteration="max", instances="mean"))
+        agg = {k: v for k, v in agg.items() if k in df.columns}
+        stat_filename = os.path.join(dump_folder, "optimization_stats.agg.xlsx")
+        df[cols].groupby(["level", "pattern"]).agg(agg).to_excel(stat_filename)
+
     if "ERR_export_onnx_c" in summary:
         return summary, data
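
The aggregation at the end of that hunk groups the per-pattern rows and reduces each numeric column (sum by default, max for iteration, mean for instances). A toy run of the same groupby/agg on made-up rows shows the shape of the resulting sheet:

    import pandas

    df = pandas.DataFrame(
        [
            dict(level="detailed", pattern="FusionA", time_in=0.3, iteration=1, instances=2),
            dict(level="detailed", pattern="FusionA", time_in=0.2, iteration=2, instances=4),
            dict(level="main", pattern="time_total", time_in=1.5),
        ]
    )
    agg = dict(time_in="sum", iteration="max", instances="mean")
    print(df.groupby(["level", "pattern"]).agg(agg))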

onnx_diagnostic-0.8.9.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: onnx-diagnostic
-Version: 0.8.7
+Version: 0.8.9
 Summary: Tools to help converting pytorch models into ONNX.
 Home-page: https://github.com/sdpython/onnx-diagnostic
 Author: Xavier Dupré
@@ -90,6 +90,8 @@ Enlightening Examples

 * `Export microsoft/phi-2
   <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_phi2.html>`_
+* `Export a LLM through method generate (with Tiny-LLM)
+  <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_llm_method_generate.html>`_

 **Torch Export**
