onnx-diagnostic 0.7.9__py3-none-any.whl → 0.7.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +8 -1
- onnx_diagnostic/helpers/cache_helper.py +12 -10
- onnx_diagnostic/helpers/helper.py +8 -0
- onnx_diagnostic/helpers/onnx_helper.py +1 -1
- onnx_diagnostic/helpers/torch_helper.py +14 -4
- onnx_diagnostic/reference/ops/op_scan.py +5 -5
- onnx_diagnostic/reference/ort_evaluator.py +2 -2
- onnx_diagnostic/tasks/__init__.py +4 -2
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +3 -3
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +98 -4
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +42 -2
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +0 -1
- onnx_diagnostic/torch_models/hghub/hub_api.py +69 -22
- onnx_diagnostic/torch_models/hghub/hub_data.py +5 -1
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +173 -128
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +11 -3
- onnx_diagnostic/torch_models/validate.py +146 -17
- onnx_diagnostic/torch_onnx/sbs.py +1 -1
- {onnx_diagnostic-0.7.9.dist-info → onnx_diagnostic-0.7.11.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.7.9.dist-info → onnx_diagnostic-0.7.11.dist-info}/RECORD +27 -25
- {onnx_diagnostic-0.7.9.dist-info → onnx_diagnostic-0.7.11.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.9.dist-info → onnx_diagnostic-0.7.11.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.9.dist-info → onnx_diagnostic-0.7.11.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED
onnx_diagnostic/_command_lines_parser.py
CHANGED

```diff
@@ -474,7 +474,7 @@ def get_parser_validate() -> ArgumentParser:
     )
     parser.add_argument(
         "--runtime",
-        choices=["onnxruntime", "torch", "ref"],
+        choices=["onnxruntime", "torch", "ref", "orteval", "orteval10"],
         default="onnxruntime",
         help="onnx runtime to use, `onnxruntime` by default",
     )
@@ -542,6 +542,12 @@
         "the onnx exporter should use.",
         default="",
     )
+    parser.add_argument(
+        "--ort-logs",
+        default=False,
+        action=BooleanOptionalAction,
+        help="Enables onnxruntime logging when the session is created",
+    )
     return parser
 
 
@@ -601,6 +607,7 @@ def _cmd_validate(argv: List[Any]):
         repeat=args.repeat,
         warmup=args.warmup,
        inputs2=args.inputs2,
+        ort_logs=args.ort_logs,
         output_names=(
             None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
         ),
```
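The new `--ort-logs` flag relies on `argparse.BooleanOptionalAction` (Python 3.9+), which registers both `--ort-logs` and `--no-ort-logs`. A minimal standalone sketch of the same pattern (the parser below is illustrative, not the package's actual CLI):

```python
from argparse import ArgumentParser, BooleanOptionalAction

# BooleanOptionalAction creates the paired --ort-logs / --no-ort-logs flags
# and keeps the default (False) when neither is passed.
parser = ArgumentParser("demo")
parser.add_argument(
    "--ort-logs",
    default=False,
    action=BooleanOptionalAction,
    help="Enables onnxruntime logging when the session is created",
)

print(parser.parse_args([]).ort_logs)                 # False
print(parser.parse_args(["--ort-logs"]).ort_logs)     # True
print(parser.parse_args(["--no-ort-logs"]).ort_logs)  # False
```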
onnx_diagnostic/helpers/cache_helper.py
CHANGED

```diff
@@ -4,11 +4,6 @@ import torch
 import transformers
 import transformers.cache_utils
 
-try:
-    from transformers.models.mamba.modeling_mamba import MambaCache
-except ImportError:
-    from transformers.cache_utils import MambaCache
-
 
 class CacheKeyValue:
     """
@@ -270,7 +265,7 @@ def make_static_cache(
             self.num_attention_heads = key_value_pairs[0][0].shape[1]
             self.num_hidden_layers = len(key_value_pairs)
 
-        def get_text_config(self):
+        def get_text_config(self, *args, **kwargs):
            return self
 
     assert max_cache_len is not None, (
@@ -354,8 +349,15 @@ def make_encoder_decoder_cache(
     )
 
 
-def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> "MambaCache":  # noqa: F821
+def make_mamba_cache(
+    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+) -> "MambaCache":  # noqa: F821
     "Creates a ``MambaCache``."
+    # import is moved here because this part is slow.
+    try:
+        from transformers.models.mamba.modeling_mamba import MambaCache
+    except ImportError:
+        from transformers.cache_utils import MambaCache
     dtype = key_value_pairs[0][0].dtype
 
     class _config:
@@ -366,7 +368,7 @@ def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -
             self.num_hidden_layers = len(key_value_pairs)
             self.dtype = dtype
 
-        def get_text_config(self):
+        def get_text_config(self, *args, **kwargs):
            return self
 
     cache = MambaCache(
@@ -409,7 +411,7 @@ def make_sliding_window_cache(
             self.num_hidden_layers = len(key_value_pairs)
             self.sliding_window = key_value_pairs[0][0].shape[2]
 
-        def get_text_config(self):
+        def get_text_config(self, *args, **kwargs):
            return self
 
     cache = transformers.cache_utils.SlidingWindowCache(
@@ -577,7 +579,7 @@ def make_hybrid_cache(
         sliding_window = _sliding_window
         num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3
 
-        def get_text_config(self):
+        def get_text_config(self, *args, **kwargs):
            return self
 
     if layer_types:
```
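The `MambaCache` import moves from module level into `make_mamba_cache` because importing the mamba modeling code is slow; the try/except keeps working across transformers versions that expose the class in different modules. A hedged sketch of the deferred-import pattern (assuming transformers is installed):

```python
def load_mamba_cache_class():
    # Deferred import: module import stays fast, and the cost is only paid
    # by callers that actually need MambaCache.
    try:
        from transformers.models.mamba.modeling_mamba import MambaCache
    except ImportError:
        # Fallback for transformers versions that keep the class in cache_utils.
        from transformers.cache_utils import MambaCache
    return MambaCache
```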
onnx_diagnostic/helpers/helper.py
CHANGED

```diff
@@ -774,6 +774,14 @@ string_type(
         return f"{obj.__class__.__name__}(**{s})"
     if obj.__class__.__name__ in {"TorchModelContainer", "InferenceSession"}:
         return f"{obj.__class__.__name__}(...)"
+    if obj.__class__.__name__ == "Results":
+        import ultralytics
+
+        assert isinstance(
+            obj, ultralytics.engine.results.Results
+        ), f"Unexpected type={type(obj)}"
+        return f"ultralytics.{obj.__class__.__name__}(...)"
+
     if verbose:
         print(f"[string_type] END:{type(obj)}")
     raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}")
```
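`string_type` now summarizes `ultralytics` `Results` objects without importing ultralytics unless such an object actually appears. A possible usage sketch of the helper; the exact rendering of the returned string depends on the installed version:

```python
import torch
from onnx_diagnostic.helpers import string_type

inputs = dict(
    input_ids=torch.randint(0, 100, (2, 8)),
    attention_mask=torch.ones(2, 8, dtype=torch.int64),
)
# Compact, type-oriented description of a nested structure,
# convenient for logging exporter inputs.
print(string_type(inputs, with_shape=True))
```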
onnx_diagnostic/helpers/onnx_helper.py
CHANGED

```diff
@@ -1186,7 +1186,7 @@ shadowing_names(
         shadow |= set(i.name for i in g.input) & shadow_context
         shadow |= set(i.name for i in g.initializer) & shadow_context
         shadow |= set(i.name for i in g.sparse_initializer) & shadow_context
-        s,
+        s, _ps, c = shadowing_names(
             g.node, verbose=verbose, existing=existing, shadow_context=existing
         )
         shadow |= s
```
onnx_diagnostic/helpers/torch_helper.py
CHANGED

```diff
@@ -543,7 +543,7 @@ def dummy_llm(
         )
 
     def forward(self, x):
-
+        _B, T, C = x.shape
 
         query = self.query(x)
         key = self.key(x)
@@ -721,9 +721,10 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
         return {to_any(t, to_value) for t in value}
     if type(value) is dict:
         return {k: to_any(t, to_value) for k, t in value.items()}
-    if value.__class__.__name__
+    if value.__class__.__name__ in {"DynamicCache", "HybridCache"}:
+        make = dict(DynamicCache=make_dynamic_cache, HybridCache=make_hybrid_cache)
         cc = CacheKeyValue(value)
-        return
+        return make[value.__class__.__name__](  # type: ignore[operator]
             list(
                 zip(
                     [t.to(to_value) if t is not None else t for t in cc.key_cache],
@@ -822,6 +823,15 @@ def torch_deepcopy(value: Any) -> Any:
         new_args = torch_deepcopy(args)
         return torch.utils._pytree.tree_unflatten(new_args, spec)
 
+    if value.__class__.__name__ == "Results":
+        import copy
+        import ultralytics
+
+        assert isinstance(
+            value, ultralytics.engine.results.Results
+        ), f"Unexpected type={type(value)}"
+        return copy.deepcopy(value)
+
     # We should have a code using serialization, deserialization assuming a model
     # cannot be exported without them.
     raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
@@ -856,7 +866,7 @@ def torch_tensor_size(value: Any) -> Any:
     if value.__class__.__name__ == "MambaCache":
         return torch_tensor_size(value.conv_states) + torch_tensor_size(value.ssm_states)
     if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
-        args,
+        args, _spec = torch.utils._pytree.tree_flatten(value)
         return sum(torch_tensor_size(a) for a in args)
 
     # We should have a code using serialization, deserialization assuming a model
```
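`torch_tensor_size` now unpacks `torch.utils._pytree.tree_flatten` explicitly: the call returns the flat list of leaves plus a spec that can rebuild the container, which is also what `torch_deepcopy` relies on. A small reminder sketch:

```python
import torch
import torch.utils._pytree as pytree

value = {"keys": [torch.zeros(2, 3), torch.zeros(2, 5)], "meta": 4}

# tree_flatten returns the leaves in a flat list plus a TreeSpec for the container.
leaves, spec = pytree.tree_flatten(value)
size = sum(t.numel() * t.element_size() for t in leaves if isinstance(t, torch.Tensor))
print(size)  # (6 + 10) float32 elements -> 64 bytes

# The spec rebuilds the original nested structure from (possibly copied) leaves.
rebuilt = pytree.tree_unflatten(leaves, spec)
```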
onnx_diagnostic/reference/ops/op_scan.py
CHANGED

```diff
@@ -26,11 +26,11 @@ class Scan(_Scan):
     ):
         (
             num_loop_state_vars,
-
-
-
-
-
+            _num_scan_outputs,
+            _output_directions,
+            _max_dir_out,
+            _output_axes,
+            _max_axe_out,
             state_names_in,
             state_names_out,
             scan_names_in,
```
onnx_diagnostic/reference/ort_evaluator.py
CHANGED

```diff
@@ -562,7 +562,7 @@ class OnnxruntimeEvaluator:
         if key in self._cache:
             sess = self._cache[key][1]
         else:
-            self._cache[key] =
+            self._cache[key] = _onx, sess = self._get_sess_if(node, name, inputs, results)
 
         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
         feeds = {name: results[name] for name in sess.input_names}
@@ -616,7 +616,7 @@ class OnnxruntimeEvaluator:
         if key in self._cache:
             sess = self._cache[key][1]
         else:
-            self._cache[key] =
+            self._cache[key] = _onx, sess = self._get_sess_scan(node, name, inputs, results)
 
         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
         feeds = {name: results[name] for name in sess.input_names}
```
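Both fixes use Python's chained assignment: the tuple returned by the session builder is stored in the cache and unpacked into `_onx, sess` in a single statement. A tiny illustration with placeholder names:

```python
cache = {}


def build_session(key):
    # Stand-in for _get_sess_if / _get_sess_scan: returns (onnx_model, session).
    return f"onx-{key}", f"sess-{key}"


# Chained assignment: the same tuple goes into cache["if"] and into (_onx, sess).
cache["if"] = _onx, sess = build_session("if")

assert cache["if"] == ("onx-if", "sess-if")
assert sess == "sess-if"
```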
onnx_diagnostic/tasks/__init__.py
CHANGED

```diff
@@ -5,6 +5,8 @@ from . import (
     fill_mask,
     image_classification,
     image_text_to_text,
+    image_to_video,
+    mask_generation,
     mixture_of_expert,
     object_detection,
     sentence_similarity,
@@ -14,7 +16,6 @@ from . import (
     text_to_image,
     text2text_generation,
     zero_shot_image_classification,
-    mask_generation,
 )
 
 __TASKS__ = [
@@ -23,6 +24,8 @@ __TASKS__ = [
     fill_mask,
     image_classification,
     image_text_to_text,
+    image_to_video,
+    mask_generation,
     mixture_of_expert,
     object_detection,
     sentence_similarity,
@@ -32,7 +35,6 @@ __TASKS__ = [
     text_to_image,
     text2text_generation,
     zero_shot_image_classification,
-    mask_generation,
 ]
 
 
```
onnx_diagnostic/tasks/image_to_video.py
ADDED

```diff
@@ -0,0 +1,127 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
+
+__TASK__ = "image-to-video"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    if not hasattr(config, "num_hidden_layers") and not hasattr(config, "num_layers"):
+        # We cannot reduce.
+        return {}
+    check_hasattr(config, ("num_hidden_layers", "num_layers"))
+    kwargs = {}
+    if hasattr(config, "num_layers"):
+        kwargs["num_layers"] = min(config.num_layers, nhl())
+    if hasattr(config, "num_hidden_layers"):
+        kwargs["num_hidden_layers"] = min(config.num_hidden_layers, nhl())
+
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    text_embed_dim: int,
+    latent_channels: int,
+    batch_size: int = 2,
+    image_height: int = 704,
+    image_width: int = 1280,
+    latent_frames: int = 1,
+    text_maxlen: int = 512,
+    add_second_input: int = 1,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``image-to-video``.
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    latent_height = image_height // 8
+    latent_width = image_width // 8
+    dtype = torch.float32
+
+    inputs = dict(
+        hidden_states=torch.randn(
+            batch_size,
+            latent_channels,
+            latent_frames,
+            latent_height,
+            latent_width,
+            dtype=dtype,
+        ),
+        timestep=torch.tensor([1.0] * batch_size, dtype=dtype),
+        encoder_hidden_states=torch.randn(
+            batch_size, text_maxlen, text_embed_dim, dtype=dtype
+        ),
+        padding_mask=torch.ones(1, 1, image_height, image_width, dtype=dtype),
+        fps=torch.tensor([16] * batch_size, dtype=dtype),
+        condition_mask=torch.randn(
+            batch_size, 1, latent_frames, latent_height, latent_width, dtype=dtype
+        ),
+    )
+    shapes = dict(
+        hidden_states={
+            0: "batch_size",
+            2: "latent_frames",
+            3: "latent_height",
+            4: "latent_width",
+        },
+        timestep={0: "batch_size"},
+        encoder_hidden_states={0: "batch_size"},
+        padding_mask={0: "batch_size", 2: "height", 3: "width"},
+        fps={0: "batch_size"},
+        condition_mask={
+            0: "batch_size",
+            2: "latent_frames",
+            3: "latent_height",
+            4: "latent_width",
+        },
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+
+    if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            text_embed_dim=text_embed_dim,
+            latent_channels=latent_channels,
+            batch_size=batch_size,
+            image_height=image_height,
+            image_width=image_width,
+            latent_frames=latent_frames,
+            text_maxlen=text_maxlen,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(config, "in_channels", "text_embed_dim"),
+    kwargs = dict(
+        text_embed_dim=1024 if config is None else config.text_embed_dim,
+        latent_channels=16 if config is None else config.in_channels - 1,
+        batch_size=1,
+        image_height=8 * 50,
+        image_width=8 * 80,
+        latent_frames=1,
+        text_maxlen=512,
+    )
+    return kwargs, get_inputs
```
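As with the other task modules, `random_input_kwargs` returns default dimensions plus the `get_inputs` callable, which produces `inputs`, `dynamic_shapes`, and a second input set for shape checks. A hedged usage sketch (the model argument is a placeholder, not a real image-to-video backbone):

```python
import torch
from onnx_diagnostic.tasks import image_to_video

# With config=None the helper falls back to typical dimensions
# (text_embed_dim=1024, latent_channels=16, 400x640 images, ...).
kwargs, fct = image_to_video.random_input_kwargs(None)
data = fct(model=torch.nn.Identity(), config=None, **kwargs)

print(sorted(data))  # ['dynamic_shapes', 'inputs', 'inputs2']
print(data["inputs"]["hidden_states"].shape)  # (1, 16, 1, 50, 80)
```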
onnx_diagnostic/torch_export_patches/eval/model_cases.py
CHANGED

```diff
@@ -384,7 +384,7 @@ class ControlFlowScan(torch.nn.Module):
 
     def forward(self, x):
         init = torch.zeros_like(x[0])
-        carry,
+        carry, _out = torch.ops.higher_order.scan(
             ControlFlowScan.add, [init], [x], additional_inputs=[]
         )
         return carry
@@ -429,7 +429,7 @@ class ControlFlowScanCDist(torch.nn.Module):
         return [carry.clone(), rd]
 
     def forward(self, x):
-
+        _carry, out = torch.ops.higher_order.scan(
             ControlFlowScanCDist.dist,
             [x],
             [x],
@@ -483,7 +483,7 @@ class ControlFlowScanCDistXY(torch.nn.Module):
         return [y.clone(), rd]
 
     def forward(self, x, y):
-
+        _carry, out = torch.ops.higher_order.scan(
             ControlFlowScanCDistXY.dist,
             [y],
             [x],
```
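These test cases now unpack both results of the scan higher-order operator instead of a single value. A minimal sketch following the same call pattern (a combine function returning `[new_carry, per_step_output]`, lists of carries and scanned tensors); whether it runs eagerly or only under `torch.export` depends on the PyTorch version:

```python
import torch


class CumSumScan(torch.nn.Module):
    """Cumulative sum written with the scan higher-order op, mirroring ControlFlowScan."""

    @staticmethod
    def add(carry, x):
        s = carry + x
        return [s, s]  # [new carry, per-step output]

    def forward(self, x):
        init = torch.zeros_like(x[0])
        carry, _out = torch.ops.higher_order.scan(
            CumSumScan.add, [init], [x], additional_inputs=[]
        )
        return carry
```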
onnx_diagnostic/torch_export_patches/onnx_export_errors.py
CHANGED

```diff
@@ -439,6 +439,28 @@ def torch_export_patches(
         f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
         masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
 
+        if verbose:
+            print(
+                "[torch_export_patches] patches "
+                "transformers.masking_utils.sdpa_mask_recent_torch"
+            )
+        f_transformers_sdpa_mask_recent_torch = masking_utils.sdpa_mask_recent_torch
+        masking_utils.sdpa_mask_recent_torch = (
+            patch_transformers_list.patched_sdpa_mask_recent_torch
+        )
+        if masking_utils.sdpa_mask == f_transformers_sdpa_mask_recent_torch:
+            if verbose:
+                print(
+                    "[torch_export_patches] patches "
+                    "transformers.masking_utils.sdpa_mask"
+                )
+            f_transformers_sdpa_mask = masking_utils.sdpa_mask
+            masking_utils.sdpa_mask = (
+                patch_transformers_list.patched_sdpa_mask_recent_torch
+            )
+        else:
+            f_transformers_sdpa_mask = None
+
     if (
         masking_utils
         and patch_transformers_list.patch_masking_utils
@@ -456,10 +478,37 @@ def torch_export_patches(
         and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
         == f_transformers_eager_mask
     ):
+        if verbose:
+            print(
+                "[torch_export_patches] patches "
+                "transformers.masking_utils.eager_mask "
+                "in ALL_MASK_ATTENTION_FUNCTIONS"
+            )
         masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
             patch_transformers_list.patched_eager_mask
         )
 
+    if (
+        masking_utils
+        and patch_transformers_list.patch_masking_utils
+        and hasattr(masking_utils, "sdpa_mask")
+        and f_transformers_sdpa_mask is not None
+    ):
+        if verbose:
+            print(
+                "[torch_export_patches] patches "
+                "transformers.masking_utils.sdpa_mask "
+                "in ALL_MASK_ATTENTION_FUNCTIONS"
+            )
+        if (
+            "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+            and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
+            == f_transformers_sdpa_mask
+        ):
+            masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = (
+                patch_transformers_list.patched_sdpa_mask_recent_torch
+            )
+
     if custom_patches:
         if verbose:
             print("[torch_export_patches] applies custom patches")
@@ -568,12 +617,31 @@ def torch_export_patches(
             and hasattr(masking_utils, "_vmap_for_bhqkv")
         ):
             masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
+
             if verbose:
                 print(
                     "[torch_export_patches] restored "
                     "transformers.masking_utils._vmap_for_bhqkv"
                 )
+
+            masking_utils.sdpa_mask_recent_torch = (
+                f_transformers_sdpa_mask_recent_torch
+            )
+
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.masking_utils.sdpa_mask_recent_torch"
+                )
+
+            if f_transformers_sdpa_mask is not None:
+                masking_utils.sdpa_mask = f_transformers_sdpa_mask
+                if verbose:
+                    print(
+                        "[torch_export_patches] restored "
+                        "transformers.masking_utils.sdpa_mask"
+                    )
+
         if (
             masking_utils
             and patch_transformers_list.patch_masking_utils
@@ -581,6 +649,11 @@ def torch_export_patches(
         ):
             f_transformers_eager_mask = masking_utils.eager_mask
             masking_utils.eager_mask = f_transformers_eager_mask
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.masking_utils.eager_mask"
+                )
         if (
             "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
             and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
@@ -589,11 +662,32 @@ def torch_export_patches(
             masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
                 f_transformers_eager_mask
             )
-
-
-
-
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.masking_utils.eager_mask "
+                    "in ALL_MASK_ATTENTION_FUNCTIONS"
+                )
+
+        if (
+            masking_utils
+            and patch_transformers_list.patch_masking_utils
+            and hasattr(masking_utils, "sdpa_mask")
+        ):
+            if (
+                "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+                and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
+                == patch_transformers_list.patched_sdpa_mask_recent_torch
+            ):
+                masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = (
+                    f_transformers_sdpa_mask
                 )
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.masking_utils.sdpa_mask "
+                    "in ALL_MASK_ATTENTION_FUNCTIONS"
+                )
 
     ########
     # caches
```
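The new code follows the same discipline as the existing masking patches: keep the original callable in an `f_transformers_*` variable, install the patched version (also in `ALL_MASK_ATTENTION_FUNCTIONS` when the registry still points at the original), and put everything back on exit. A generic sketch of that save/patch/restore pattern, not the package's actual implementation:

```python
import contextlib


@contextlib.contextmanager
def patch_attribute(owner, name, replacement):
    """Temporarily replaces owner.name and restores the original afterwards."""
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield original
    finally:
        # Restore only if nothing else swapped the attribute in the meantime.
        if getattr(owner, name) is replacement:
            setattr(owner, name, original)


class _Registry:
    sdpa_mask = staticmethod(lambda *args, **kwargs: "original")


def patched_sdpa_mask(*args, **kwargs):
    return "patched"


with patch_attribute(_Registry, "sdpa_mask", patched_sdpa_mask):
    assert _Registry.sdpa_mask() == "patched"
assert _Registry.sdpa_mask() == "original"
```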
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
CHANGED

```diff
@@ -35,9 +35,18 @@ except ImportError:
 from ...ext_test_case import has_transformers
 from ...helpers.torch_helper import is_torchdynamo_exporting
 
+patch_is_initialized = pv.Version(transformers.__version__) > pv.Version("4.56.99")
+
+
 if patch_masking_utils:
     # Introduced in 4.52
-    from transformers.masking_utils import
+    from transformers.masking_utils import (
+        causal_mask_function,
+        padding_mask_function,
+        and_masks,
+        _ignore_causal_mask_sdpa,
+        prepare_padding_mask,
+    )
 
     def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
         """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
@@ -105,7 +114,7 @@ if patch_masking_utils:
         """manual patch for function ``transformers.masking_utils.eager_mask``."""
         # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
         _ = kwargs.pop("allow_is_causal_skip", None)
-        mask =
+        mask = patched_sdpa_mask_recent_torch(
             batch_size=batch_size,
             cache_position=cache_position,
             kv_length=kv_length,
@@ -125,6 +134,35 @@ if patch_masking_utils:
         mask = (~mask).to(dtype) * min_dtype
         return mask
 
+    def patched_sdpa_mask_recent_torch(
+        batch_size: int,
+        cache_position: torch.Tensor,
+        kv_length: int,
+        kv_offset: int = 0,
+        mask_function: Callable = causal_mask_function,
+        attention_mask: Optional[torch.Tensor] = None,
+        local_size: Optional[int] = None,
+        allow_is_causal_skip: bool = True,
+        **kwargs,
+    ) -> Optional[torch.Tensor]:
+        """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
+        q_length = cache_position.shape[0]
+        padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
+        if allow_is_causal_skip and _ignore_causal_mask_sdpa(
+            padding_mask, q_length, kv_length, kv_offset, local_size
+        ):
+            return None
+        kv_arange = torch.arange(kv_length, device=cache_position.device)
+        kv_arange += kv_offset
+        if padding_mask is not None:
+            mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
+        batch_arange = torch.arange(batch_size, device=cache_position.device)
+        head_arange = torch.arange(1, device=cache_position.device)
+        causal_mask = patched__vmap_for_bhqkv(mask_function)(
+            batch_arange, head_arange, cache_position, kv_arange
+        )
+        return causal_mask
+
 
 if patch_parse_processor_args:
 
@@ -178,6 +216,8 @@ if patch_DynamicLayer:
             new_shape[-2] = 0
             self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
             self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
+            if patch_is_initialized:
+                self.is_initialized = True
 
 
 def _patch_make_causal_mask(
```
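`patch_is_initialized` gates the extra `DynamicLayer.is_initialized` assignment on the installed transformers version; `pv` is the module's alias for `packaging.version`, and comparing against `4.56.99` is a compact way of saying "strictly newer than the 4.56 series". A standalone sketch of the same gate:

```python
import packaging.version as pv
import transformers

# True for transformers 4.57 and later; any 4.56.x release compares below 4.56.99.
patch_is_initialized = pv.Version(transformers.__version__) > pv.Version("4.56.99")

if patch_is_initialized:
    # Newer DynamicLayer objects track whether their buffers were allocated; the
    # patch sets is_initialized = True after creating the empty key/value tensors.
    print("patched DynamicLayer will also set self.is_initialized = True")
```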
onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py
CHANGED

```diff
@@ -218,7 +218,6 @@ def unflatten_sliding_window_cache(
     values: List[Any], context: torch.utils._pytree.Context, output_type=None
 ) -> SlidingWindowCache:
     """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects."""
-    key_cache, value_cache = values
     return make_sliding_window_cache(list(zip(values[0], values[1])))
 
 
```