onnx_diagnostic-0.7.8-py3-none-any.whl → onnx_diagnostic-0.7.10-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +2 -2
- onnx_diagnostic/helpers/_log_helper.py +4 -2
- onnx_diagnostic/helpers/cache_helper.py +4 -4
- onnx_diagnostic/helpers/helper.py +8 -0
- onnx_diagnostic/helpers/log_helper.py +7 -1
- onnx_diagnostic/helpers/model_builder_helper.py +5 -0
- onnx_diagnostic/helpers/onnx_helper.py +1 -1
- onnx_diagnostic/helpers/torch_helper.py +14 -4
- onnx_diagnostic/reference/ops/op_scan.py +5 -5
- onnx_diagnostic/reference/ort_evaluator.py +2 -2
- onnx_diagnostic/tasks/automatic_speech_recognition.py +1 -1
- onnx_diagnostic/tasks/feature_extraction.py +1 -1
- onnx_diagnostic/tasks/fill_mask.py +1 -1
- onnx_diagnostic/tasks/image_text_to_text.py +2 -2
- onnx_diagnostic/tasks/sentence_similarity.py +1 -1
- onnx_diagnostic/tasks/summarization.py +1 -1
- onnx_diagnostic/tasks/text2text_generation.py +1 -1
- onnx_diagnostic/tasks/text_classification.py +1 -1
- onnx_diagnostic/tasks/text_generation.py +1 -1
- onnx_diagnostic/tasks/zero_shot_image_classification.py +1 -1
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +3 -3
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +98 -4
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -1
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +37 -2
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +0 -1
- onnx_diagnostic/torch_models/hghub/hub_data.py +2 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +139 -126
- onnx_diagnostic/torch_models/hghub/model_specific.py +49 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +11 -3
- onnx_diagnostic/torch_models/validate.py +44 -4
- onnx_diagnostic/torch_onnx/sbs.py +1 -1
- {onnx_diagnostic-0.7.8.dist-info → onnx_diagnostic-0.7.10.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.7.8.dist-info → onnx_diagnostic-0.7.10.dist-info}/RECORD +38 -37
- {onnx_diagnostic-0.7.8.dist-info → onnx_diagnostic-0.7.10.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.8.dist-info → onnx_diagnostic-0.7.10.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.8.dist-info → onnx_diagnostic-0.7.10.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED

onnx_diagnostic/_command_lines_parser.py
CHANGED

```diff
@@ -850,13 +850,13 @@ def get_parser_agg() -> ArgumentParser:
         "--filter-in",
         default="",
         help="adds a filter to filter in data, syntax is\n"
-        '``"<column1>:<value1>;<value2
+        '``"<column1>:<value1>;<value2>//<column2>:<value3>"`` ...',
     )
     parser.add_argument(
         "--filter-out",
         default="",
         help="adds a filter to filter out data, syntax is\n"
-        '``"<column1>:<value1>;<value2
+        '``"<column1>:<value1>;<value2>//<column2>:<value3>"`` ...',
     )
     parser.add_argument(
         "--sbs",
```
onnx_diagnostic/helpers/_log_helper.py
CHANGED

```diff
@@ -118,9 +118,11 @@ def filter_data(
     if isinstance(fmt, str):
         cols = fmt.split("//")
         for c in cols:
-            assert ":" in c, f"Unexpected value {c!r} in fmt={fmt!r}"
+            assert ":" in c, f"Unexpected value {c!r} in fmt={fmt!r}, cols={cols!r}"
             spl = c.split(":")
-            assert
+            assert (
+                len(spl) == 2
+            ), f"Unexpected value {c!r} in fmt={fmt!r}, spl={spl}, cols={cols}"
             name, fil = spl
             cond[name] = set(fil.split(";"))
     return cond
```
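A minimal sketch of the filter mini-language these tightened assertions enforce: columns are separated by `//`, a column and its values by `:`, and alternative values by `;` (the helper name is hypothetical, mirroring `filter_data`'s parsing):

```python
def parse_filter(fmt: str) -> dict:
    """Hypothetical helper mirroring filter_data's parsing."""
    cond = {}
    for c in fmt.split("//"):
        assert ":" in c, f"Unexpected value {c!r} in fmt={fmt!r}"
        spl = c.split(":")
        assert len(spl) == 2, f"Unexpected value {c!r} in fmt={fmt!r}, spl={spl}"
        name, fil = spl
        cond[name] = set(fil.split(";"))
    return cond

assert parse_filter("model:phi2;llama//device:cuda") == {
    "model": {"phi2", "llama"},
    "device": {"cuda"},
}
```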
onnx_diagnostic/helpers/cache_helper.py
CHANGED

```diff
@@ -270,7 +270,7 @@ def make_static_cache(
             self.num_attention_heads = key_value_pairs[0][0].shape[1]
             self.num_hidden_layers = len(key_value_pairs)

-        def get_text_config(self):
+        def get_text_config(self, *args, **kwargs):
             return self

     assert max_cache_len is not None, (
@@ -366,7 +366,7 @@ def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -
             self.num_hidden_layers = len(key_value_pairs)
             self.dtype = dtype

-        def get_text_config(self):
+        def get_text_config(self, *args, **kwargs):
             return self

     cache = MambaCache(
@@ -409,7 +409,7 @@ def make_sliding_window_cache(
             self.num_hidden_layers = len(key_value_pairs)
             self.sliding_window = key_value_pairs[0][0].shape[2]

-        def get_text_config(self):
+        def get_text_config(self, *args, **kwargs):
             return self

     cache = transformers.cache_utils.SlidingWindowCache(
@@ -577,7 +577,7 @@ def make_hybrid_cache(
         sliding_window = _sliding_window
         num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3

-        def get_text_config(self):
+        def get_text_config(self, *args, **kwargs):
             return self

     if layer_types:
```
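The four `get_text_config` edits widen the signature so these ad-hoc config objects keep working when newer transformers releases forward arguments (for instance `decoder=True`). A minimal sketch of the duck-typing involved, with an invented class name:

```python
class _FakeTextConfig:
    # Invented stand-in mirroring the ad-hoc configs built above.
    num_hidden_layers = 2

    def get_text_config(self, *args, **kwargs):
        # *args/**kwargs absorb arguments such as decoder=True that newer
        # transformers releases may pass; the object is its own text config.
        return self

cfg = _FakeTextConfig()
assert cfg.get_text_config(decoder=True) is cfg
```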
onnx_diagnostic/helpers/helper.py
CHANGED

```diff
@@ -774,6 +774,14 @@ def string_type(
         return f"{obj.__class__.__name__}(**{s})"
     if obj.__class__.__name__ in {"TorchModelContainer", "InferenceSession"}:
         return f"{obj.__class__.__name__}(...)"
+    if obj.__class__.__name__ == "Results":
+        import ultralytics
+
+        assert isinstance(
+            obj, ultralytics.engine.results.Results
+        ), f"Unexpected type={type(obj)}"
+        return f"ultralytics.{obj.__class__.__name__}(...)"
+
     if verbose:
         print(f"[string_type] END:{type(obj)}")
     raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}")
```
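For context, `string_type` renders nested inputs as compact type-and-shape strings; a hedged usage sketch (the exact rendering shown in the comment is indicative, not guaranteed):

```python
import torch
from onnx_diagnostic.helpers import string_type

inputs = (torch.rand(2, 8), [torch.rand(2, 8), torch.rand(2, 8)])
# Prints a compact summary such as "(T1s2x8,#2[T1s2x8,T1s2x8])".
print(string_type(inputs, with_shape=True))
```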
onnx_diagnostic/helpers/log_helper.py
CHANGED

```diff
@@ -1,5 +1,6 @@
 import enum
 import io
+import os
 import pprint
 import re
 import warnings
@@ -270,6 +271,10 @@ class CubePlot:
     def _to_images_bar(
         self, verbose: int = 0, merge: bool = True, title_suffix: Optional[str] = None
     ) -> List[bytes]:
+        """
+        Environment variable ``FIGSIZEH`` can be set to increase the
+        graph height. Default is 1.0.
+        """
         assert merge, f"merge={merge} not implemented yet"
         import matplotlib.pyplot as plt

@@ -279,7 +284,8 @@ class CubePlot:
         n_cols = 3
         nn = df.shape[1] // n_cols
         nn += int(df.shape[1] % n_cols != 0)
-
+        ratio = float(os.environ.get("FIGSIZEH", "1"))
+        fig, axs = plt.subplots(nn, n_cols, figsize=(6 * n_cols, nn * df.shape[0] / 3 * ratio))
         pos = 0
         imgs = []
         for c in self._make_loop(df.columns, verbose):
```
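Usage note for the new knob: `FIGSIZEH` is read from `os.environ` at plot time, so it can be set in the shell (for example `FIGSIZEH=2 python ...`) or from Python before plotting; `1.0` keeps the previous height. A small sketch:

```python
import os

# Stretch the bar charts produced by CubePlot._to_images_bar by 2x in height.
os.environ["FIGSIZEH"] = "2"
```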
onnx_diagnostic/helpers/model_builder_helper.py
CHANGED

```diff
@@ -201,10 +201,12 @@ def create_model_builder(
     arch_map = {
         "ChatGLMForConditionalGeneration": builder.ChatGLMModel,
         "ChatGLMModel": builder.ChatGLMModel,
+        "Ernie4_5_ForCausalLM": builder.ErnieModel,
         "GemmaForCausalLM": builder.Gemma2Model,
         "Gemma3ForCausalLM": builder.Gemma3Model,
         "Gemma3ForConditionalGeneration": builder.Gemma3Model,
         "GraniteForCausalLM": builder.GraniteModel,
+        "GptOssForCausalLM": builder.GPTOSSModel,
         "LlamaForCausalLM": builder.LlamaModel,
         "MistralForCausalLM": builder.MistralModel,
         "NemotronForCausalLM": builder.NemotronModel,
@@ -235,6 +237,7 @@ def create_model_builder(
         "Phi4MMForCausalLM": builder.Phi4MMModel,
         "Qwen2ForCausalLM": builder.QwenModel,
         "Qwen3ForCausalLM": builder.Qwen3Model,
+        "SmolLM3ForCausalLM": builder.SmolLM3Model,
     }

     assert config.architectures[0] in arch_map, (
@@ -276,6 +279,8 @@ def create_model_builder(
         for key in text_config:
             if not hasattr(config, key):
                 setattr(config, key, getattr(text_config, key))
+    elif config.architectures[0] == "GptOssForCausalLM":
+        delattr(config, "quantization_config")
     elif (
         config.architectures[0] == "PhiMoEForCausalLM"
         and config.max_position_embeddings != config.original_max_position_embeddings
```
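The architecture table above is a plain dictionary dispatch keyed on `config.architectures[0]`; a reduced sketch of the lookup and its failure mode (the builder class names here are invented stand-ins):

```python
class _Builder:
    # Invented stand-ins for the builder classes referenced above.
    class LlamaModel: ...
    class SmolLM3Model: ...

arch_map = {
    "LlamaForCausalLM": _Builder.LlamaModel,
    "SmolLM3ForCausalLM": _Builder.SmolLM3Model,
}

architecture = "SmolLM3ForCausalLM"
assert architecture in arch_map, f"{architecture!r} is not supported"
model_cls = arch_map[architecture]
```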
onnx_diagnostic/helpers/onnx_helper.py
CHANGED

```diff
@@ -1186,7 +1186,7 @@ def shadowing_names(
         shadow |= set(i.name for i in g.input) & shadow_context
         shadow |= set(i.name for i in g.initializer) & shadow_context
         shadow |= set(i.name for i in g.sparse_initializer) & shadow_context
-        s, ps, c = shadowing_names(
+        s, _ps, c = shadowing_names(
            g.node, verbose=verbose, existing=existing, shadow_context=existing
         )
         shadow |= s
```
onnx_diagnostic/helpers/torch_helper.py
CHANGED

```diff
@@ -543,7 +543,7 @@ def dummy_llm(
         )

         def forward(self, x):
-            B, T, C = x.shape
+            _B, T, C = x.shape

             query = self.query(x)
             key = self.key(x)
@@ -721,9 +721,10 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
         return {to_any(t, to_value) for t in value}
     if type(value) is dict:
         return {k: to_any(t, to_value) for k, t in value.items()}
-    if value.__class__.__name__
+    if value.__class__.__name__ in {"DynamicCache", "HybridCache"}:
+        make = dict(DynamicCache=make_dynamic_cache, HybridCache=make_hybrid_cache)
         cc = CacheKeyValue(value)
-        return
+        return make[value.__class__.__name__](  # type: ignore[operator]
             list(
                 zip(
                     [t.to(to_value) if t is not None else t for t in cc.key_cache],
@@ -822,6 +823,15 @@ def torch_deepcopy(value: Any) -> Any:
         new_args = torch_deepcopy(args)
         return torch.utils._pytree.tree_unflatten(new_args, spec)

+    if value.__class__.__name__ == "Results":
+        import copy
+        import ultralytics
+
+        assert isinstance(
+            value, ultralytics.engine.results.Results
+        ), f"Unexpected type={type(value)}"
+        return copy.deepcopy(value)
+
     # We should have a code using serialization, deserialization assuming a model
     # cannot be exported without them.
     raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
@@ -856,7 +866,7 @@ def torch_tensor_size(value: Any) -> Any:
     if value.__class__.__name__ == "MambaCache":
         return torch_tensor_size(value.conv_states) + torch_tensor_size(value.ssm_states)
     if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
-        args, spec = torch.utils._pytree.tree_flatten(value)
+        args, _spec = torch.utils._pytree.tree_flatten(value)
         return sum(torch_tensor_size(a) for a in args)

     # We should have a code using serialization, deserialization assuming a model
```
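With the widened branch in `to_any`, hybrid caches are now rebuilt the same way dynamic caches are when casting dtype or device. A hedged usage sketch, assuming `make_dynamic_cache` lives in `cache_helper` alongside its siblings and takes a list of (key, value) tensor pairs:

```python
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.helpers.torch_helper import to_any

cache = make_dynamic_cache([(torch.rand(1, 4, 3, 8), torch.rand(1, 4, 3, 8))])
# Rebuilds the cache with every stored tensor cast to float16.
cache16 = to_any(cache, torch.float16)
```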
onnx_diagnostic/reference/ops/op_scan.py
CHANGED

```diff
@@ -26,11 +26,11 @@ class Scan(_Scan):
     ):
         (
             num_loop_state_vars,
-            num_scan_outputs,
-            output_directions,
-            max_dir_out,
-            output_axes,
-            max_axe_out,
+            _num_scan_outputs,
+            _output_directions,
+            _max_dir_out,
+            _output_axes,
+            _max_axe_out,
             state_names_in,
             state_names_out,
             scan_names_in,
```
onnx_diagnostic/reference/ort_evaluator.py
CHANGED

```diff
@@ -562,7 +562,7 @@ class OnnxruntimeEvaluator:
         if key in self._cache:
             sess = self._cache[key][1]
         else:
-            self._cache[key] =
+            self._cache[key] = _onx, sess = self._get_sess_if(node, name, inputs, results)

         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
         feeds = {name: results[name] for name in sess.input_names}
@@ -616,7 +616,7 @@ class OnnxruntimeEvaluator:
         if key in self._cache:
             sess = self._cache[key][1]
         else:
-            self._cache[key] =
+            self._cache[key] = _onx, sess = self._get_sess_scan(node, name, inputs, results)

         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
         feeds = {name: results[name] for name in sess.input_names}
```
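The restored lines use Python's chained assignment: the tuple returned by `_get_sess_if` / `_get_sess_scan` is bound both to the cache entry and to `(_onx, sess)` in one statement. A self-contained illustration:

```python
def build():
    # Stand-in for _get_sess_if / _get_sess_scan returning (onnx proto, session).
    return "model-proto", "session"

cache = {}
cache["key"] = _onx, sess = build()

assert cache["key"] == ("model-proto", "session")
assert sess == "session"
```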
onnx_diagnostic/tasks/automatic_speech_recognition.py
CHANGED

```diff
@@ -47,7 +47,7 @@ def get_inputs(
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
-    batch =
+    batch = "batch"
     seq_length = "sequence_length"
     shapes = {
         "input_ids": {0: batch, 1: seq_length},
```
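`batch = "batch"` relies on the string form of dynamic-dimension specifications that recent `torch.export` versions accept in `dynamic_shapes`, in place of explicit `torch.export.Dim` objects; the same one-line change repeats across the task modules below. A minimal sketch (module and shapes invented):

```python
import torch

class M(torch.nn.Module):
    def forward(self, input_ids):
        return input_ids * 2

ep = torch.export.export(
    M(),
    (torch.zeros(2, 8, dtype=torch.int64),),
    # Strings name dynamic dimensions directly (torch >= 2.6, an assumption).
    dynamic_shapes={"input_ids": {0: "batch", 1: "sequence_length"}},
)
```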
onnx_diagnostic/tasks/feature_extraction.py
CHANGED

```diff
@@ -42,7 +42,7 @@ def get_inputs(
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
-    batch =
+    batch = "batch"
     seq_length = "sequence_length"
     shapes = {
         "input_ids": {0: batch, 1: seq_length},
```
onnx_diagnostic/tasks/image_text_to_text.py
CHANGED

```diff
@@ -107,7 +107,7 @@ def _get_inputs_gemma3(
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
-    batch =
+    batch = "batch"
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     # cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)

@@ -230,7 +230,7 @@ def get_inputs(
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
-    batch =
+    batch = "batch"
     batch_img = torch.export.Dim("batch_img", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
```
onnx_diagnostic/tasks/sentence_similarity.py
CHANGED

```diff
@@ -42,7 +42,7 @@ def get_inputs(
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
-    batch =
+    batch = "batch"
     seq_length = "seq_length"
     shapes = {
         "input_ids": {0: batch, 1: seq_length},
```
onnx_diagnostic/tasks/summarization.py
CHANGED

```diff
@@ -70,7 +70,7 @@ def get_inputs(
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
-    batch =
+    batch = "batch"
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = "cache_length_key"  # torch.export.Dim("cache_length", min=1, max=4096)
     cache_length2 = "cache_length_val"  # torch.export.Dim("cache_length2", min=1, max=4096)
```
onnx_diagnostic/tasks/text2text_generation.py
CHANGED

```diff
@@ -72,7 +72,7 @@ def get_inputs(
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
-    batch =
+    batch = "batch"
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = "cache_length_key"
     cache_length2 = "cache_length_val"
```
onnx_diagnostic/tasks/text_classification.py
CHANGED

```diff
@@ -42,7 +42,7 @@ def get_inputs(
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
-    batch =
+    batch = "batch"
     seq_length = "seq_length"  # torch.export.Dim("sequence_length", min=1, max=1024)
     shapes = {
         "input_ids": {0: batch, 1: seq_length},
```
onnx_diagnostic/tasks/text_generation.py
CHANGED

```diff
@@ -83,7 +83,7 @@ def get_inputs(
         :class:`transformers.cache_utils.DynamicCache`
     :return: dictionary
     """
-    batch =
+    batch = "batch"
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
    cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)

```
onnx_diagnostic/tasks/zero_shot_image_classification.py
CHANGED

```diff
@@ -65,7 +65,7 @@ def get_inputs(
         input_width, int
     ), f"Unexpected type for input_height {type(input_height)}{config}"

-    batch =
+    batch = "batch"
    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     shapes = {
         "input_ids": {0: batch, 1: seq_length},
```
onnx_diagnostic/torch_export_patches/eval/model_cases.py
CHANGED

```diff
@@ -384,7 +384,7 @@ class ControlFlowScan(torch.nn.Module):

     def forward(self, x):
         init = torch.zeros_like(x[0])
-        carry, out = torch.ops.higher_order.scan(
+        carry, _out = torch.ops.higher_order.scan(
             ControlFlowScan.add, [init], [x], additional_inputs=[]
         )
         return carry
@@ -429,7 +429,7 @@ class ControlFlowScanCDist(torch.nn.Module):
         return [carry.clone(), rd]

     def forward(self, x):
-        carry, out = torch.ops.higher_order.scan(
+        _carry, out = torch.ops.higher_order.scan(
             ControlFlowScanCDist.dist,
             [x],
             [x],
@@ -483,7 +483,7 @@ class ControlFlowScanCDistXY(torch.nn.Module):
         return [y.clone(), rd]

     def forward(self, x, y):
-        carry, out = torch.ops.higher_order.scan(
+        _carry, out = torch.ops.higher_order.scan(
             ControlFlowScanCDistXY.dist,
             [y],
             [x],
```
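For reference on why one half of each pair is now underscore-named: `torch.ops.higher_order.scan` returns the final carry plus the stacked per-step outputs, and each case uses only one of the two. A cumulative-sum sketch following the call convention used in the cases above (eager behaviour may vary across torch versions):

```python
import torch

def cumsum_step(carry, x):
    next_carry = carry + x
    return [next_carry, next_carry.clone()]

init = torch.zeros(3)
xs = torch.ones(5, 3)
carry, out = torch.ops.higher_order.scan(cumsum_step, [init], [xs], additional_inputs=[])
# carry: tensor([5., 5., 5.]), the final running sum.
# out: shape (5, 3), one row per scanned step.
```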
onnx_diagnostic/torch_export_patches/onnx_export_errors.py
CHANGED

```diff
@@ -439,6 +439,28 @@ def torch_export_patches(
         f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
         masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv

+        if verbose:
+            print(
+                "[torch_export_patches] patches "
+                "transformers.masking_utils.sdpa_mask_recent_torch"
+            )
+        f_transformers_sdpa_mask_recent_torch = masking_utils.sdpa_mask_recent_torch
+        masking_utils.sdpa_mask_recent_torch = (
+            patch_transformers_list.patched_sdpa_mask_recent_torch
+        )
+        if masking_utils.sdpa_mask == f_transformers_sdpa_mask_recent_torch:
+            if verbose:
+                print(
+                    "[torch_export_patches] patches "
+                    "transformers.masking_utils.sdpa_mask"
+                )
+            f_transformers_sdpa_mask = masking_utils.sdpa_mask
+            masking_utils.sdpa_mask = (
+                patch_transformers_list.patched_sdpa_mask_recent_torch
+            )
+        else:
+            f_transformers_sdpa_mask = None
+
     if (
         masking_utils
         and patch_transformers_list.patch_masking_utils
@@ -456,10 +478,37 @@ def torch_export_patches(
         and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
         == f_transformers_eager_mask
     ):
+        if verbose:
+            print(
+                "[torch_export_patches] patches "
+                "transformers.masking_utils.eager_mask "
+                "in ALL_MASK_ATTENTION_FUNCTIONS"
+            )
         masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
             patch_transformers_list.patched_eager_mask
         )

+    if (
+        masking_utils
+        and patch_transformers_list.patch_masking_utils
+        and hasattr(masking_utils, "sdpa_mask")
+        and f_transformers_sdpa_mask is not None
+    ):
+        if verbose:
+            print(
+                "[torch_export_patches] patches "
+                "transformers.masking_utils.sdpa_mask "
+                "in ALL_MASK_ATTENTION_FUNCTIONS"
+            )
+        if (
+            "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+            and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
+            == f_transformers_sdpa_mask
+        ):
+            masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = (
+                patch_transformers_list.patched_sdpa_mask_recent_torch
+            )
+
     if custom_patches:
         if verbose:
             print("[torch_export_patches] applies custom patches")
@@ -568,12 +617,31 @@ def torch_export_patches(
         and hasattr(masking_utils, "_vmap_for_bhqkv")
     ):
         masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
+
         if verbose:
             print(
                 "[torch_export_patches] restored "
                 "transformers.masking_utils._vmap_for_bhqkv"
             )

+        masking_utils.sdpa_mask_recent_torch = (
+            f_transformers_sdpa_mask_recent_torch
+        )
+
+        if verbose:
+            print(
+                "[torch_export_patches] restored "
+                "transformers.masking_utils.sdpa_mask_recent_torch"
+            )
+
+        if f_transformers_sdpa_mask is not None:
+            masking_utils.sdpa_mask = f_transformers_sdpa_mask
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.masking_utils.sdpa_mask"
+                )
+
     if (
         masking_utils
         and patch_transformers_list.patch_masking_utils
@@ -581,6 +649,11 @@ def torch_export_patches(
     ):
         f_transformers_eager_mask = masking_utils.eager_mask
         masking_utils.eager_mask = f_transformers_eager_mask
+        if verbose:
+            print(
+                "[torch_export_patches] restored "
+                "transformers.masking_utils.eager_mask"
+            )
         if (
             "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
             and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
@@ -589,11 +662,32 @@ def torch_export_patches(
             masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
                 f_transformers_eager_mask
             )
-
-
-
-
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.masking_utils.eager_mask "
+                    "in ALL_MASK_ATTENTION_FUNCTIONS"
+                )
+
+    if (
+        masking_utils
+        and patch_transformers_list.patch_masking_utils
+        and hasattr(masking_utils, "sdpa_mask")
+    ):
+        if (
+            "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+            and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
+            == patch_transformers_list.patched_sdpa_mask_recent_torch
+        ):
+            masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = (
+                f_transformers_sdpa_mask
             )
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.masking_utils.sdpa_mask "
+                    "in ALL_MASK_ATTENTION_FUNCTIONS"
+                )

     ########
     # caches
```
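Every block in this file follows the same save/patch/restore discipline. Stripped of the transformers specifics, the pattern looks like this sketch (a generic context manager, not the library's API):

```python
import contextlib

@contextlib.contextmanager
def patched(module, name, replacement, verbose=0):
    """Generic sketch of the save/patch/restore discipline used above."""
    original = getattr(module, name)
    setattr(module, name, replacement)
    if verbose:
        print(f"[patched] patches {module.__name__}.{name}")
    try:
        yield original
    finally:
        setattr(module, name, original)
        if verbose:
            print(f"[patched] restored {module.__name__}.{name}")
```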
onnx_diagnostic/torch_export_patches/patches/patch_torch.py
CHANGED

```diff
@@ -205,7 +205,10 @@ class patched_ShapeEnv:
         # Precondition: a == tgt
         assert isinstance(a, sympy.Symbol)

-        if
+        if (
+            getattr(self, "allow_complex_guards_as_runtime_asserts", False)
+            or getattr(self, "prefer_deferred_runtime_asserts_over_guards", False)
+        ) and not _is_supported_equivalence(tgt):
             # continuing leads to placeholder shapes
             # having complex expressions that we can't resolve
             return
```
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
CHANGED

```diff
@@ -37,7 +37,13 @@ from ...helpers.torch_helper import is_torchdynamo_exporting

 if patch_masking_utils:
     # Introduced in 4.52
-    from transformers.masking_utils import
+    from transformers.masking_utils import (
+        causal_mask_function,
+        padding_mask_function,
+        and_masks,
+        _ignore_causal_mask_sdpa,
+        prepare_padding_mask,
+    )

     def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
         """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
@@ -105,7 +111,7 @@ if patch_masking_utils:
         """manual patch for function ``transformers.masking_utils.eager_mask``."""
         # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
         _ = kwargs.pop("allow_is_causal_skip", None)
-        mask =
+        mask = patched_sdpa_mask_recent_torch(
             batch_size=batch_size,
             cache_position=cache_position,
             kv_length=kv_length,
@@ -125,6 +131,35 @@ if patch_masking_utils:
         mask = (~mask).to(dtype) * min_dtype
         return mask

+    def patched_sdpa_mask_recent_torch(
+        batch_size: int,
+        cache_position: torch.Tensor,
+        kv_length: int,
+        kv_offset: int = 0,
+        mask_function: Callable = causal_mask_function,
+        attention_mask: Optional[torch.Tensor] = None,
+        local_size: Optional[int] = None,
+        allow_is_causal_skip: bool = True,
+        **kwargs,
+    ) -> Optional[torch.Tensor]:
+        """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
+        q_length = cache_position.shape[0]
+        padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
+        if allow_is_causal_skip and _ignore_causal_mask_sdpa(
+            padding_mask, q_length, kv_length, kv_offset, local_size
+        ):
+            return None
+        kv_arange = torch.arange(kv_length, device=cache_position.device)
+        kv_arange += kv_offset
+        if padding_mask is not None:
+            mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
+        batch_arange = torch.arange(batch_size, device=cache_position.device)
+        head_arange = torch.arange(1, device=cache_position.device)
+        causal_mask = patched__vmap_for_bhqkv(mask_function)(
+            batch_arange, head_arange, cache_position, kv_arange
+        )
+        return causal_mask
+

 if patch_parse_processor_args:

```
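The patched mask builder replaces transformers' `vmap`-based construction with index expansion that `torch.export` traces more readily; the core broadcasting idea in isolation (dimension conventions assumed from the code above):

```python
import torch

q_len, kv_len, kv_offset = 4, 6, 2
cache_position = torch.arange(kv_offset, kv_offset + q_len)  # query positions
kv_arange = torch.arange(kv_len)
# (q_len, 1) >= (1, kv_len) broadcasts to the boolean causal mask.
causal = cache_position[:, None] >= kv_arange[None, :]
# Expand to (batch, heads, q_len, kv_len) as attention layers expect.
mask = causal[None, None, :, :].expand(1, 1, q_len, kv_len)
```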
onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py
CHANGED

```diff
@@ -218,7 +218,6 @@ def unflatten_sliding_window_cache(
     values: List[Any], context: torch.utils._pytree.Context, output_type=None
 ) -> SlidingWindowCache:
     """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects."""
-    key_cache, value_cache = values
     return make_sliding_window_cache(list(zip(values[0], values[1])))
```
onnx_diagnostic/torch_models/hghub/hub_data.py
CHANGED

```diff
@@ -11,6 +11,7 @@ __data_arch__ = textwrap.dedent(
     """
     architecture,task
     ASTModel,feature-extraction
+    AutoencoderKL,image-to-image
     AlbertModel,feature-extraction
     BeitForImageClassification,image-classification
     BartForConditionalGeneration,summarization
@@ -154,6 +155,7 @@ __data_arch__ = textwrap.dedent(
     Wav2Vec2ForCTC,automatic-speech-recognition
     YolosForObjectDetection,object-detection
     YolosModel,image-feature-extraction
+    Alibaba-NLP/gte-large-en-v1.5,sentence-similarity
     emilyalsentzer/Bio_ClinicalBERT,fill-mask"""
 )
```