onnx-diagnostic 0.8.10__py3-none-any.whl → 0.9.0__py3-none-any.whl
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +136 -140
- onnx_diagnostic/ci_models/data/Blanca_Lake_Hudak.jpg +0 -0
- onnx_diagnostic/ci_models/data/Ice_worm_glacier.jpg +0 -0
- onnx_diagnostic/ci_models/data/__init__.py +0 -0
- onnx_diagnostic/ci_models/export_phi4_mm.py +10 -7
- onnx_diagnostic/export/api.py +13 -4
- onnx_diagnostic/export/dynamic_shapes.py +1 -1
- onnx_diagnostic/export/validate.py +2 -0
- onnx_diagnostic/ext_test_case.py +32 -15
- onnx_diagnostic/helpers/args_helper.py +1 -0
- onnx_diagnostic/helpers/bench_run.py +0 -1
- onnx_diagnostic/helpers/cache_helper.py +102 -36
- onnx_diagnostic/helpers/doc_helper.py +7 -4
- onnx_diagnostic/helpers/graph_helper.py +6 -6
- onnx_diagnostic/helpers/helper.py +39 -0
- onnx_diagnostic/helpers/log_helper.py +37 -14
- onnx_diagnostic/helpers/memory_peak.py +5 -1
- onnx_diagnostic/helpers/mini_onnx_builder.py +9 -14
- onnx_diagnostic/helpers/model_builder_helper.py +1 -1
- onnx_diagnostic/helpers/onnx_helper.py +283 -110
- onnx_diagnostic/helpers/ort_session.py +5 -2
- onnx_diagnostic/helpers/rt_helper.py +53 -9
- onnx_diagnostic/helpers/torch_helper.py +15 -11
- onnx_diagnostic/investigate/__init__.py +0 -0
- onnx_diagnostic/investigate/input_observer.py +970 -0
- onnx_diagnostic/reference/evaluator.py +0 -1
- onnx_diagnostic/reference/ort_evaluator.py +0 -1
- onnx_diagnostic/reference/report_results_comparison.py +9 -3
- onnx_diagnostic/reference/torch_evaluator.py +5 -1
- onnx_diagnostic/reference/torch_ops/_op_run.py +3 -5
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +1 -1
- onnx_diagnostic/tasks/feature_extraction.py +0 -1
- onnx_diagnostic/torch_export_patches/__init__.py +0 -1
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +32 -14
- onnx_diagnostic/torch_export_patches/patch_module.py +1 -1
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +107 -6
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +13 -3
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +1 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +70 -23
- onnx_diagnostic/torch_models/code_sample.py +5 -10
- onnx_diagnostic/torch_models/hghub/hub_data.py +2 -4
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +6 -12
- onnx_diagnostic/torch_models/validate.py +1 -1
- onnx_diagnostic/torch_onnx/compare.py +0 -1
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -1
- onnx_diagnostic/torch_onnx/sbs.py +1 -1
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +2 -4
- onnx_diagnostic/typing.py +15 -0
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/RECORD +55 -50
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/WHEEL +1 -1
- onnx_diagnostic/api.py +0 -15
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/top_level.txt +0 -0
@@ -34,7 +34,6 @@ from ..helpers.torch_helper import to_tensor
 from .report_results_comparison import ReportResultComparison
 from .evaluator import ExtendedReferenceEvaluator
 
-
 PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
 Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]
 
@@ -1,5 +1,4 @@
-from typing import Any, Dict, List, Tuple, Union
-
+from typing import Any, Dict, List, Set, Tuple, Union
 
 ReportKeyNameType = Union[str, Tuple[str, int, str]]
 ReportKeyValueType = Tuple[int, Tuple[int, ...]]
@@ -14,6 +13,7 @@ class ReportResultComparison:
     :param tensors: tensor
     """
 
+    # pyrefly: ignore[unknown-name]
     def __init__(self, tensors: Dict[ReportKeyNameType, "torch.Tensor"]):  # noqa: F821
         from ..helpers.onnx_helper import dtype_to_tensor_dtype
         from ..helpers import max_diff, string_type
@@ -25,7 +25,9 @@ class ReportResultComparison:
         self.max_diff = max_diff
         self.tensors = tensors
         self._build_mapping()
+        self.unique_run_names: Set[str] = set()
 
+    # pyrefly: ignore[unknown-name]
     def key(self, tensor: "torch.Tensor") -> ReportKeyValueType:  # noqa: F821
         "Returns a key for a tensor, (onnx dtype, shape)."
         return self.dtype_to_tensor_dtype(tensor.dtype), tuple(map(int, tensor.shape))
@@ -59,12 +61,15 @@ class ReportResultComparison:
         for k, v in self.value.items():
             (i_run, run_name), ref_name = k
             d = dict(run_index=i_run, run_name=run_name, ref_name=ref_name)
+            # pyrefly: ignore[no-matching-overload]
             d.update(v)
             rows.append(d)
         return rows
 
     def report(
-        self,
+        self,
+        # pyrefly: ignore[unknown-name]
+        outputs: Dict[str, "torch.Tensor"],  # noqa: F821
     ) -> List[Tuple[Tuple[int, str], ReportKeyNameType, Dict[str, Union[float, str]]]]:
         """
         For every tensor in outputs, compares it to every tensor held by
@@ -79,6 +84,7 @@ class ReportResultComparison:
             key = self.key(tensor)
             if key not in self.mapping:
                 continue
+            # pyrefly: ignore[unknown-name]
             cache: Dict["torch.device", "torch.Tensor"] = {}  # noqa: F821, UP037
             for held_key in self.mapping[key]:
                 t2 = self.tensors[held_key]
@@ -63,7 +63,7 @@ class TorchOnnxEvaluator:
     * `functions`: local functions
 
     The class is not multithreaded. `runtime_info` gets updated
-    by the
+    by the class. The list of available kernels is returned by function
     :func:`onnx_diagnostic.reference.torch_evaluator.get_kernels`.
     Example:
 
@@ -494,8 +494,10 @@ class TorchOnnxEvaluator:
             r = self.runtime_info[k]
             r.set_value(
                 torch_ops.OpRunTensor(
+                    # pyrefly: ignore[missing-attribute]
                     v.to(self.CUDA) if not r.is_shape and self.on_cuda else v,
                     is_constant=False,
+                    # pyrefly: ignore[missing-attribute]
                     may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
                 )
             )
@@ -524,6 +526,7 @@ class TorchOnnxEvaluator:
                 f"for kernel {type(kernel)}."
             )
             for name, t in zip(kernel.output, res):
+                # pyrefly: ignore[bad-argument-type]
                 self.runtime_info[name].set_value(t)
             if self.verbose:
                 for name in kernel.output:
@@ -644,6 +647,7 @@ class TorchOnnxEvaluator:
                 f"for kernel {type(kernel)}."
            )
             for name, t in zip(kernel.output, res):
+                # pyrefly: ignore[bad-argument-type]
                 self.runtime_info[name].set_value(t)
         else:
             assert isinstance(
@@ -1,7 +1,7 @@
 from typing import Any, Dict, List, Optional, Union, Tuple
 import onnx
 import torch
-from ...
+from ...typing import TensorLike
 from ...helpers import string_type
 from ...helpers.torch_helper import to_tensor
 
@@ -149,7 +149,7 @@ class OpRunSequence(OpRunValue):
     ) -> "OpRunSequence":
         "Inserts a value at a given position."
         assert isinstance(tensor, OpRunTensor), f"Unexpected type {type(tensor)} for tensor"
-        new_seq = OpRunSequence()
+        new_seq = OpRunSequence()  # type: ignore[abstract]
         seq = self.sequence.copy()
         new_seq.sequence = seq
         if position is None:
@@ -314,9 +314,7 @@ class OpRunKernel:
 
 
 class OpRunFunction(OpRunKernel):
-    """
-    Defines a kernel based on a local functions.
-    """
+    """Defines a kernel based on a local functions."""
 
     def __init__(
         self,
@@ -562,6 +562,7 @@ def _patch_transformers(
                 "[torch_export_patches] patches "
                 "transformers.masking_utils.sdpa_mask_recent_torch"
             )
+
         f_transformers_sdpa_mask_recent_torch = masking_utils.sdpa_mask_recent_torch
         masking_utils.sdpa_mask_recent_torch = (
             patch_transformers_list.patched_sdpa_mask_recent_torch
@@ -574,7 +575,9 @@ def _patch_transformers(
         )
     if masking_utils.sdpa_mask == f_transformers_sdpa_mask_recent_torch:
         if verbose:
-            print(
+            print(
+                "[torch_export_patches] patches transformers.masking_utils.sdpa_mask (1)"
+            )
         f_transformers_sdpa_mask = masking_utils.sdpa_mask
         masking_utils.sdpa_mask = patch_transformers_list.patched_sdpa_mask_recent_torch
         if patch_details:
@@ -583,8 +586,23 @@ def _patch_transformers(
                 f_transformers_sdpa_mask,
                 patch_transformers_list.patched_sdpa_mask_recent_torch,
             )
-
-
+
+    if (  # vmap
+        masking_utils
+        and patch_transformers_list.patch_masking_utils
+        and hasattr(masking_utils, "sdpa_mask")
+        and f_transformers_sdpa_mask is None
+    ):
+        if verbose:
+            print("[torch_export_patches] patches transformers.masking_utils.sdpa_mask (3)")
+        f_transformers_sdpa_mask = masking_utils.sdpa_mask
+        masking_utils.sdpa_mask = patch_transformers_list.patched_sdpa_mask
+        if patch_details:
+            patch_details.append(
+                "transformers",
+                f_transformers_sdpa_mask,
+                patch_transformers_list.patched_sdpa_mask,
+            )
 
     if (  # eager_mask
         masking_utils
@@ -742,17 +760,17 @@ def _unpatch_transformers(
                 "transformers.masking_utils.sdpa_mask_recent_torch"
             )
 
-
-
-
-
-
-
-
-
-
-
-
+    if f_transformers_sdpa_mask is not None:
+        assert f_transformers_sdpa_mask.__name__ in (
+            "sdpa_mask",
+            "sdpa_mask_recent_torch",
+        ), (
+            f"corrupted function 'sdpa_mask', its name is "
+            f"{f_transformers_sdpa_mask.__name__!r}"
+        )
+        masking_utils.sdpa_mask = f_transformers_sdpa_mask
+        if verbose:
+            print("[torch_export_patches] restored transformers.masking_utils.sdpa_mask")
 
     if (  # eager_mask
         masking_utils
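
The restore path above follows the usual save/patch/restore discipline: keep a reference to the original callable, install the patched one, and on unpatch check by `__name__` that the saved function is still one of the expected originals before putting it back. A minimal standalone sketch of that pattern (the `module` namespace and function names below are illustrative stand-ins, not part of onnx-diagnostic):

    import types

    def sdpa_mask(*args, **kwargs):  # stands in for transformers.masking_utils.sdpa_mask
        return "original"

    def patched_sdpa_mask(*args, **kwargs):  # stands in for the replacement
        return "patched"

    module = types.SimpleNamespace(sdpa_mask=sdpa_mask)

    # patch: remember the original, install the replacement
    f_saved = module.sdpa_mask
    module.sdpa_mask = patched_sdpa_mask
    assert module.sdpa_mask() == "patched"

    # unpatch: verify the saved callable still looks like the original, then restore it
    assert f_saved.__name__ in ("sdpa_mask", "sdpa_mask_recent_torch"), (
        f"corrupted function 'sdpa_mask', its name is {f_saved.__name__!r}"
    )
    module.sdpa_mask = f_saved
    assert module.sdpa_mask() == "original"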
@@ -36,6 +36,26 @@ if patch_masking_utils:
     _ignore_bidirectional_mask_sdpa = None
     bidirectional_mask_function = None
 
+    try:
+        from transformers.masking_utils import _non_vmap_expansion_sdpa
+    except ImportError:
+
+        def _non_vmap_expansion_sdpa(
+            batch_indices: torch.Tensor,
+            head_indices: torch.Tensor,
+            q_indices: torch.Tensor,
+            kv_indices: torch.Tensor,
+        ):
+            """
+            https://github.com/huggingface/optimum-onnx/blob/
+            c123e8f4fab61b54a8e0e31ce74462bcacca576e/optimum/exporters/onnx/model_patcher.py#L362-L365
+            """
+            batch_indices = batch_indices[:, None, None, None]
+            head_indices = head_indices[None, :, None, None]
+            q_indices = q_indices[None, None, :, None]
+            kv_indices = kv_indices[None, None, None, :]
+            return batch_indices, head_indices, q_indices, kv_indices
+
     def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
         """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
         from ...helpers import string_type
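
The `_non_vmap_expansion_sdpa` fallback above reshapes the four index vectors so that ordinary broadcasting, rather than `torch.vmap`, evaluates the mask function for every (batch, head, query, key) combination. A small self-contained illustration with a causal mask function (the names here are illustrative):

    import torch

    def causal(batch_idx, head_idx, q_idx, kv_idx):
        # a key position is visible only if it is at or before the query position
        return kv_idx <= q_idx

    def non_vmap_expansion(batch_indices, head_indices, q_indices, kv_indices):
        # same reshaping as _non_vmap_expansion_sdpa: one broadcast axis per index set
        return (
            batch_indices[:, None, None, None],
            head_indices[None, :, None, None],
            q_indices[None, None, :, None],
            kv_indices[None, None, None, :],
        )

    batch_arange = torch.arange(2)      # batch_size = 2
    head_arange = torch.arange(1)
    cache_position = torch.arange(3)    # q_length = 3
    kv_arange = torch.arange(4)         # kv_length = 4

    mask = causal(*non_vmap_expansion(batch_arange, head_arange, cache_position, kv_arange))
    # the mask function ignored the batch/head axes, so expand them as patched_sdpa_mask does
    mask = mask.expand(2, -1, 3, 4)
    print(mask.shape)  # torch.Size([2, 1, 3, 4])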
@@ -146,12 +166,13 @@ if patch_masking_utils:
             padding_mask, q_length, kv_length, kv_offset, local_size
         ):
             return None
-        if
-
-
-
-
-
+        if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa:
+            # transformers<=5.0: 1 parameter, 3 for transformers>5.0
+            n_parameters = len(inspect.signature(_ignore_bidirectional_mask_sdpa).parameters)
+            if _ignore_bidirectional_mask_sdpa(
+                *[padding_mask, kv_length, kv_offset][:n_parameters]
+            ):
+                return None
 
         if mask_function is bidirectional_mask_function:
             if padding_mask is not None:
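
The `n_parameters` trick above lets the same patched code call `_ignore_bidirectional_mask_sdpa` whether the installed transformers expects one argument or three: it slices the candidate argument list to the arity reported by `inspect.signature`. A tiny standalone sketch of that dispatch (the two stand-in functions below are illustrative, not the real transformers signatures):

    import inspect

    def ignore_bidirectional_old(padding_mask):  # one-parameter variant
        return padding_mask is None

    def ignore_bidirectional_new(padding_mask, kv_length, kv_offset):  # three-parameter variant
        return padding_mask is None and kv_offset == 0

    padding_mask, kv_length, kv_offset = None, 8, 0
    for fn in (ignore_bidirectional_old, ignore_bidirectional_new):
        n_parameters = len(inspect.signature(fn).parameters)
        # pass only as many positional arguments as the installed version accepts
        print(fn.__name__, fn(*[padding_mask, kv_length, kv_offset][:n_parameters]))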
@@ -180,3 +201,83 @@ if patch_masking_utils:
             batch_arange, head_arange, cache_position, kv_arange
         )
         return causal_mask
+
+    def patched_sdpa_mask(
+        batch_size: int,
+        cache_position: torch.Tensor,
+        kv_length: int,
+        kv_offset: int = 0,
+        mask_function: Callable = causal_mask_function,
+        attention_mask: torch.Tensor | None = None,
+        local_size: int | None = None,
+        allow_is_causal_skip: bool = True,
+        allow_is_bidirectional_skip: bool = False,
+        allow_torch_fix: bool = True,
+        use_vmap: bool = False,
+        **kwargs,
+    ) -> torch.Tensor | None:
+        """manual patch for function ``transformers.masking_utils.sdpa_mask``."""
+        q_length = cache_position.shape[0]
+
+        # Potentially pad the 2D mask
+        padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
+
+        # Under specific conditions, we can avoid materializing the mask
+        # 1. Causal masks can rely on the `is_causal` argument
+        # 2. Bidirectional do not need any further processing (no bias)
+        if allow_is_causal_skip and _ignore_causal_mask_sdpa(
+            padding_mask, q_length, kv_length, kv_offset, local_size
+        ):
+            return None
+        if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(
+            padding_mask, kv_length, local_size
+        ):
+            return None
+
+        # Potentially add the padding 2D mask
+        if padding_mask is not None:
+            mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
+
+        batch_arange = torch.arange(batch_size, device=cache_position.device)
+        head_arange = torch.arange(1, device=cache_position.device)
+        # Similar to `kv_arange = torch.arange(start=kv_offset,
+        # end=kv_offset + kv_length, device=cache_position.device)`
+        # but without data-dependent slicing (i.e. torch.compile friendly)
+        kv_arange = torch.arange(kv_length, device=cache_position.device) + kv_offset
+
+        # Actual mask creation
+        # Option 1: Fast non-vmap mask creation (default)
+        # PATCHED
+        use_vmap = False
+        if not use_vmap:
+            # Apply mask function element-wise through broadcasting
+            attention_mask = mask_function(
+                *_non_vmap_expansion_sdpa(batch_arange, head_arange, cache_position, kv_arange)
+            )
+            # Expand the mask to match batch size
+            # and query length if they weren't used in the mask function
+            attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length)
+
+        # Option 2: Vmap mask creation (torch>=2.6 and custom patterns)
+        # elif _is_torch_greater_or_equal_than_2_6:
+        # This creates the 4D mask easily.
+        # Note that we need this context manager as vmap cannot handle slicing a tensor from
+        # scalar tensor (it internally calls `.item()` which vmap does not allow,
+        # but this context works around it
+        # We don't need to add an offset to the mask_function either,
+        # as we vmap directly the correct indices for k and kv indices
+        # with TransformGetItemToIndex():
+        #     attention_mask = _vmap_expansion_sdpa(mask_function)(
+        #         batch_arange, head_arange, cache_position, kv_arange
+        #     )
+
+        # Option 3: Error out since it indicates that the user did something custom,
+        # which they shouldn't have (torch<2.6)
+        else:
+            raise ValueError(
+                "The vmap functionality for mask creation "
+                "is only supported from torch>=2.6. "
+                "Please update your torch version or use "
+                "`use_vmap=False` with index-based masks."
+            )
+        return attention_mask
@@ -7,10 +7,10 @@ import transformers
 
 def patched__compute_dynamic_ntk_parameters(
     config: Optional[transformers.PretrainedConfig] = None,
-    device: Optional[
+    device: Optional[torch.device] = None,
     seq_len: Optional[int] = None,
     **rope_kwargs,
-) -> Tuple[
+) -> Tuple[torch.Tensor, float]:
     """
     manual patch:
     ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``
@@ -188,6 +188,11 @@ def patched__broadcast_shapes(*_shapes):
     return common_shape
 
 
+def value_ranges_is_positive(value_ranges: torch.utils._sympy.value_ranges.ValueRanges):
+    """Tells if an interval is equivalent to a positive or null integer."""
+    return value_ranges.lower == 0 and value_ranges.upper > 4623372036854775806
+
+
 class patched_ShapeEnv:
 
     def _check_frozen(
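
`value_ranges_is_positive` treats an interval whose lower bound is 0 and whose upper bound is a huge integer (on the order of the int64 range) as "any non-negative integer", which the assertion in the next hunk uses to accept two effectively unbounded ranges. A stand-in sketch with a plain dataclass instead of torch's internal `ValueRanges` (purely illustrative):

    from dataclasses import dataclass

    @dataclass
    class Interval:  # stand-in for torch.utils._sympy.value_ranges.ValueRanges
        lower: int
        upper: int

    def value_ranges_is_positive(value_ranges: Interval) -> bool:
        # same test as the patch: lower bound 0 and an effectively unbounded upper bound
        return value_ranges.lower == 0 and value_ranges.upper > 4623372036854775806

    print(value_ranges_is_positive(Interval(0, 9223372036854775806)))  # True
    print(value_ranges_is_positive(Interval(0, 4096)))                 # False
    print(value_ranges_is_positive(Interval(1, 9223372036854775806)))  # False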
@@ -281,7 +286,10 @@ class patched_ShapeEnv:
         )
         self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a])
         tgt_bound = self.bound_sympy(tgt)
-        assert
+        assert (
+            value_ranges_is_positive(tgt_bound)
+            and value_ranges_is_positive(src_bound)
+        ) or tgt_bound.issubset(
             src_bound
         ), f"{tgt_bound=} not a subset of {src_bound=}"
 
@@ -524,8 +532,10 @@ class patched_ShapeEnv:
 
         transmute_into_runtime_assert = False
 
-        backed_var_to_val =
-        self
+        backed_var_to_val = (
+            self.backed_var_to_val
+            if hasattr(self, "backed_var_to_val")
+            else self.var_to_val
         )
         concrete_val = None
         if not (expr.free_symbols <= backed_var_to_val.keys()):
@@ -1,6 +1,7 @@
 import itertools
 from typing import Any, Callable, List, Set, Tuple
 import torch
+import transformers.cache_utils
 from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
 
 try:
@@ -22,22 +23,63 @@ from transformers.modeling_outputs import BaseModelOutput
 from ...helpers.cache_helper import make_dynamic_cache, make_static_cache, CacheKeyValue
 from . import make_serialization_function_for_dataclass
 
-
 SUPPORTED_DATACLASSES: Set[type] = set()
+
 WRONG_REGISTRATIONS = {
     DynamicCache: "4.50",
     BaseModelOutput: None,
 }
 
+SHORTEN_LAYER_NAMES = {
+    "DynamicLayer": "D",
+    "DynamicSlidingWindowLayer": "W",
+    "StaticLayer": "S",
+    "StaticSlidingWindowLayer": "X",
+    "D": "DynamicLayer",
+    "W": "DynamicSlidingWindowLayer",
+    "S": "StaticLayer",
+    "X": "StaticSlidingWindowLayer",
+}
+
+KWARGS_LAYER_NAMES = {
+    "DynamicLayer": lambda layer: "",
+    "DynamicSlidingWindowLayer": lambda layer: str(layer.sliding_window),
+    "StaticLayer": lambda layer: "",
+    "StaticSlidingWindowLayer": lambda layer: str(layer.sliding_window),
+}
+
+PARSE_LAYER_NAMES = {
+    "DynamicLayer": lambda skw: {},
+    "DynamicSlidingWindowLayer": lambda skw: dict(sliding_window=int(skw[1:])),
+    "StaticLayer": lambda skw: {},
+    "StaticSlidingWindowLayer": lambda skw: dict(sliding_window=int(skw[1:])),
+}
+
 
 def _flatten_key_value_cache(cache: Cache) -> Tuple[List[Any], torch.utils._pytree.Context]:
     ca = CacheKeyValue(cache)
     flat = list(itertools.chain.from_iterable(zip(ca.key_cache, ca.value_cache)))
-
-
-
+    unique = set(ca.cls_layers) if ca.cls_layers else None
+    if (
+        cache.__class__.__name__ != "DynamicCache"
+        or unique is None
+        or (len(unique) == 1 and unique.pop().__name__ == "DynamicLayer")
+    ):
+        keys = list(
+            itertools.chain.from_iterable(
+                (f"key_{i}", f"value_{i}") for i in range(len(ca.key_cache))
+            )
         )
-
+        return flat, keys
+
+    keys = []
+    for i in range(len(ca.key_cache)):
+        letter = SHORTEN_LAYER_NAMES[ca.cls_layers[i].__name__]
+        if hasattr(cache, "layers"):
+            kwargs = KWARGS_LAYER_NAMES[ca.cls_layers[i].__name__](cache.layers[i])
+        else:
+            kwargs = ""
+        keys.extend([f"key_{letter}{kwargs}_{i}", f"value_{letter}{kwargs}_{i}"])
     return flat, keys
 
 
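
With the tables above, a DynamicCache that mixes layer types no longer serializes its tensors under plain `key_i`/`value_i` names: each layer contributes a one-letter class code plus, for sliding-window layers, the window size. A hedged sketch of the resulting pytree keys (the layer descriptions below are made up for illustration and use dicts instead of real layer objects):

    SHORTEN = {"DynamicLayer": "D", "DynamicSlidingWindowLayer": "W"}
    KWARGS = {
        "DynamicLayer": lambda layer: "",
        "DynamicSlidingWindowLayer": lambda layer: str(layer["sliding_window"]),
    }

    # one full-attention layer followed by one sliding-window layer
    layers = [
        {"cls": "DynamicLayer"},
        {"cls": "DynamicSlidingWindowLayer", "sliding_window": 128},
    ]

    keys = []
    for i, layer in enumerate(layers):
        letter = SHORTEN[layer["cls"]]
        kwargs = KWARGS[layer["cls"]](layer)
        keys.extend([f"key_{letter}{kwargs}_{i}", f"value_{letter}{kwargs}_{i}"])

    print(keys)  # ['key_D_0', 'value_D_0', 'key_W128_1', 'value_W128_1']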
@@ -55,7 +97,26 @@ def _unflatten_cache(
     output_type=None,
 ) -> DynamicCache:
     """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
-
+    expected = list(
+        itertools.chain.from_iterable(
+            (f"key_{i}", f"value_{i}") for i in range(len(values) // 2)
+        )
+    )
+    if expected == context:
+        res = make_cache(list(zip(values[::2], values[1::2])))
+    else:
+        cls_layer_names = [SHORTEN_LAYER_NAMES[name.split("_")[1][0]] for name in context][::2]
+        cls_kwargs = [
+            PARSE_LAYER_NAMES[SHORTEN_LAYER_NAMES[name.split("_")[1][0]]](name.split("_")[1])
+            for name in context
+        ][::2]
+        cls_layers = [
+            getattr(transformers.cache_utils, cls_name) for cls_name in cls_layer_names
+        ]
+        res = make_cache(
+            list(zip(values[::2], values[1::2])), cls_layers=cls_layers, cls_kwargs=cls_kwargs
+        )
+
     assert output_type is None or isinstance(
         res, output_type
     ), f"Type mismatch between {output_type} (expected) and {type(res)}"
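
Going the other way, `_unflatten_cache` recovers the layer classes and their keyword arguments from the key names stored in the pytree context. A standalone sketch of that decoding, reusing the split logic from the diff (the `context` list is the one produced in the previous sketch):

    SHORTEN = {"D": "DynamicLayer", "W": "DynamicSlidingWindowLayer"}
    PARSE = {
        "DynamicLayer": lambda skw: {},
        "DynamicSlidingWindowLayer": lambda skw: dict(sliding_window=int(skw[1:])),
    }

    context = ["key_D_0", "value_D_0", "key_W128_1", "value_W128_1"]

    # "key_W128_1".split("_")[1] -> "W128": first char selects the class, the rest its kwargs
    cls_layer_names = [SHORTEN[name.split("_")[1][0]] for name in context][::2]
    cls_kwargs = [PARSE[SHORTEN[name.split("_")[1][0]]](name.split("_")[1]) for name in context][::2]

    print(cls_layer_names)  # ['DynamicLayer', 'DynamicSlidingWindowLayer']
    print(cls_kwargs)       # [{}, {'sliding_window': 128}]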
@@ -71,14 +132,6 @@ def flatten_dynamic_cache(
     dynamic_cache: DynamicCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
-    assert (
-        not hasattr(dynamic_cache, "layers")
-        or not dynamic_cache.layers
-        or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers)
-    ), (
-        f"The serialization does not work yet on other layers "
-        f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}"
-    )
     return _flatten_key_value_cache(dynamic_cache)
 
 
|
@@ -86,14 +139,6 @@ def flatten_with_keys_dynamic_cache(
|
|
|
86
139
|
dynamic_cache: DynamicCache,
|
|
87
140
|
) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
|
|
88
141
|
"""Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
|
|
89
|
-
assert (
|
|
90
|
-
not hasattr(dynamic_cache, "layers")
|
|
91
|
-
or not dynamic_cache.layers
|
|
92
|
-
or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers)
|
|
93
|
-
), (
|
|
94
|
-
f"The serialization does not work yet on other layers "
|
|
95
|
-
f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}"
|
|
96
|
-
)
|
|
97
142
|
return _flatten_with_keys_cache(dynamic_cache)
|
|
98
143
|
|
|
99
144
|
|
|
@@ -161,7 +206,9 @@
 ) -> StaticCache:
     """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
     return _unflatten_cache(
-        lambda *args: make_static_cache(
+        lambda *args, **kwargs: make_static_cache(
+            *args, max_cache_len=values[0].shape[2], **kwargs
+        ),
         values,
         context,
         output_type=output_type,
@@ -8,11 +8,9 @@ from .hghub.model_inputs import _preprocess_model_id
 from .hghub import get_untrained_model_with_inputs
 from .validate import filter_inputs, make_patch_kwargs
 
-
 CODE_SAMPLES = {
     "imports": "from typing import Any\nimport torch",
-    "get_model_with_inputs": textwrap.dedent(
-        """
+    "get_model_with_inputs": textwrap.dedent("""
     def get_model_with_inputs(
         model_id:str,
         subfolder: str | None = None,
@@ -57,8 +55,7 @@ CODE_SAMPLES = {
     if device:
         data["model"] = data["model"].to(device)
     return data["model"]
-    """
-    ),
+    """),
 }
 
 
@@ -198,7 +195,7 @@ def code_sample(
         this is not always possible
     :param use_pretrained: use the trained model, not the untrained one
     :param optimization: optimization to apply to the exported model,
-        depend on the
+        depend on the exporter
     :param quiet: if quiet, catches exception if any issue
     :param patch: applies patches (``patch_transformers=True, path_diffusers=True``)
         if True before exporting
@@ -326,11 +323,9 @@ def code_sample(
         imports,
         cache_import,
         CODE_SAMPLES["get_model_with_inputs"],
-        textwrap.dedent(
-            f"""
+        textwrap.dedent(f"""
         model = get_model_with_inputs({model_args})
-        """
-        ),
+        """),
         f"inputs = {input_code}",
         exporter_code,
     ]
@@ -10,8 +10,7 @@ __data_arch_values__ = {
     "ResNetForImageClassification": dict(image_size=224),
 }
 
-__data_arch__ = textwrap.dedent(
-    """
+__data_arch__ = textwrap.dedent("""
     architecture,task
     ASTModel,feature-extraction
     AutoencoderKL,image-to-image
@@ -166,8 +165,7 @@ __data_arch__ = textwrap.dedent(
     YolosModel,image-feature-extraction
     Alibaba-NLP/gte-large-en-v1.5,sentence-similarity
     emilyalsentzer/Bio_ClinicalBERT,fill-mask
-    nvidia/Cosmos-Predict2-2B-Video2World//transformer,image-to-video"""
-    )
+    nvidia/Cosmos-Predict2-2B-Video2World//transformer,image-to-video""")
 
 __data_tasks__ = [
     "audio-classification",