onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +387 -12
- onnx_diagnostic/export/api.py +118 -5
- onnx_diagnostic/export/control_flow.py +214 -0
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +135 -0
- onnx_diagnostic/export/onnx_plug.py +396 -0
- onnx_diagnostic/ext_test_case.py +118 -25
- onnx_diagnostic/helpers/cache_helper.py +218 -204
- onnx_diagnostic/helpers/dot_helper.py +210 -0
- onnx_diagnostic/helpers/helper.py +92 -26
- onnx_diagnostic/helpers/log_helper.py +26 -4
- onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +115 -16
- onnx_diagnostic/helpers/ort_session.py +37 -11
- onnx_diagnostic/helpers/rt_helper.py +547 -0
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +108 -6
- onnx_diagnostic/reference/ort_evaluator.py +233 -28
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/image_text_to_text.py +5 -1
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
- onnx_diagnostic/torch_models/validate.py +50 -1
- onnx_diagnostic/torch_onnx/sbs.py +963 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Optional
|
|
3
|
+
import torch
|
|
4
|
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
|
5
|
+
from .patch_helper import _has_transformers
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _patch_make_causal_mask(
|
|
9
|
+
input_ids_shape: torch.Size,
|
|
10
|
+
dtype: torch.dtype,
|
|
11
|
+
device: torch.device,
|
|
12
|
+
past_key_values_length: int = 0,
|
|
13
|
+
sliding_window: Optional[int] = None,
|
|
14
|
+
):
|
|
15
|
+
"""Patched method."""
|
|
16
|
+
bsz, tgt_len = input_ids_shape
|
|
17
|
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
|
18
|
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
|
19
|
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
|
20
|
+
|
|
21
|
+
mask = mask.to(dtype)
|
|
22
|
+
|
|
23
|
+
if past_key_values_length > 0:
|
|
24
|
+
mask = torch.cat(
|
|
25
|
+
[
|
|
26
|
+
torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device),
|
|
27
|
+
mask,
|
|
28
|
+
],
|
|
29
|
+
dim=-1,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
if sliding_window is not None:
|
|
33
|
+
diagonal = past_key_values_length - sliding_window - 1
|
|
34
|
+
|
|
35
|
+
context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
|
|
36
|
+
# PATCHED: removed if is_torchdynamo_compiling(): mask = mask.clone()
|
|
37
|
+
# and used masked_fill instead of masked_fill_
|
|
38
|
+
# In this case, the current implementation of torch fails (17/12/2024).
|
|
39
|
+
# Try model Phi-3.5-Mini-Instruct.
|
|
40
|
+
mask = mask.masked_fill(context_mask, torch.finfo(dtype).min)
|
|
41
|
+
|
|
42
|
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class patched_AttentionMaskConverter:
    """
    Patches
    ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
    """

    # This method was fixed in 4.51 at least.
    # Nothing is patched when ``_has_transformers("4.48.3")`` is true.
    _PATCHES_ = [] if _has_transformers("4.48.3") else ["_make_causal_mask"]
    _PATCHED_CLASS_ = AttentionMaskConverter

    @staticmethod
    def _make_causal_mask(*args, **kwargs):
        """
        Patched method.

        The expected arguments are, in this order, ``input_ids_shape``, ``dtype``,
        ``device``, ``past_key_values_length`` (default 0) and ``sliding_window``
        (default None). They are all forwarded to ``_patch_make_causal_mask``
        as keyword arguments.

        This static method may be called with ``AttentionMaskConverter._make_causal_mask``
        or ``self._make_causal_mask``. That changes the arguments it receives:
        when the first positional argument is neither a tuple nor a ``torch.Size``,
        it is considered to be the instance and it is dropped.
        That should not matter but...
        The patch should be implemented in another way. static methods do not play well
        with a simple replacement.
        Fortunately, this patch does not seem to be needed anymore with transformers>=4.48.3.
        """
        if args:
            # A leading argument which is not a shape is skipped.
            skipped = 0 if isinstance(args[0], (tuple, torch.Size)) else 1
            names = (
                "input_ids_shape",
                "dtype",
                "device",
                "past_key_values_length",
                "sliding_window",
            )
            # Positional values override keyword values with the same name,
            # indexing ``names`` fails if too many positional arguments are given.
            kwargs.update({names[pos]: value for pos, value in enumerate(args[skipped:])})
        return _patch_make_causal_mask(**kwargs)
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
from typing import List, Optional, Tuple
|
|
2
|
+
import packaging.version as pv
|
|
3
|
+
import torch
|
|
4
|
+
import transformers
|
|
5
|
+
from .patch_helper import _has_transformers
|
|
6
|
+
|
|
7
|
+
# Tells ``patched_DynamicLayer.lazy_initialization`` whether attribute
# ``is_initialized`` must be set
# (presumably true for transformers>=4.56.99, to be confirmed in ``patch_helper``).
patch_is_initialized = _has_transformers("4.56.99")
# ``DynamicCache`` is only patched with transformers older than 4.51.
patch_DynamicCache = pv.Version(transformers.__version__) < pv.Version("4.51")

try:
    # transformers>= 4.55.1
    from transformers.cache_utils import DynamicLayer
except ImportError:
    # The class does not exist, there is nothing to patch.
    patch_DynamicLayer = False
else:
    # The patch only applies if the method to replace exists.
    patch_DynamicLayer = hasattr(DynamicLayer, "lazy_initialization")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
if patch_DynamicLayer:

    class patched_DynamicLayer:
        """Patches ``DynamicLayer.lazy_initialization``."""

        _PATCHES_ = ["lazy_initialization"]
        _PATCHED_CLASS_ = DynamicLayer

        def lazy_initialization(self, key_states: torch.Tensor):
            """
            Initializes ``keys`` and ``values`` with empty tensors sharing
            the dtype, the device and the shape of ``key_states`` except for
            dimension ``-2`` which is set to 0.
            """
            self.dtype = key_states.dtype
            self.device = key_states.device
            empty_shape = list(key_states.shape)
            empty_shape[-2] = 0
            # PATCHED: used a tensor with an empty shape and not an empty list to initialize
            self.keys = torch.empty(empty_shape, dtype=self.dtype, device=self.device)
            self.values = torch.empty(empty_shape, dtype=self.dtype, device=self.device)
            if patch_is_initialized:
                self.is_initialized = True
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
if patch_DynamicCache:
    from typing import Any, Dict
    from transformers.cache_utils import DynamicCache

    class patched_DynamicCache:
        """
        Applies modifications implemented in PR
        `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
        """

        _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
        _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache

        def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
            """Returns the sequence length of the cached states,
            it is dimension ``-2`` of the keys cached for layer `layer_idx`
            or 0 if nothing is cached for this layer."""
            # TODO: deprecate this function in favor of `cache_position`
            n_layers = len(self.key_cache)
            if n_layers == 0:
                # no cache in any layer
                return 0
            if n_layers <= layer_idx:
                # skipped `layer_idx` and hasn't run a layer with cache after it
                return 0
            layer_keys = self.key_cache[layer_idx]
            if layer_keys.numel() == 0:
                # the layer has no cache
                return 0
            return layer_keys.shape[-2]

        def reorder_cache(self, beam_idx: torch.LongTensor):
            """Selects the rows `beam_idx` along the first dimension
            of every non empty cached tensor (beam search)."""
            for layer_idx in range(len(self.key_cache)):
                # keys first, then values, empty tensors are left untouched
                for store in (self.key_cache, self.value_cache):
                    cached = store[layer_idx]
                    if cached.numel():
                        store[layer_idx] = cached.index_select(0, beam_idx.to(cached.device))

        def update(
            self,
            key_states: torch.Tensor,
            value_states: torch.Tensor,
            layer_idx: int,
            cache_kwargs: Optional[Dict[str, Any]] = None,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            """
            Stores `key_states` and `value_states` into the cache
            of layer `layer_idx`, they are concatenated along dimension ``-2``
            to what the layer already holds.

            Parameters:
                key_states (`torch.Tensor`):
                    The new key states to cache.
                value_states (`torch.Tensor`):
                    The new value states to cache.
                layer_idx (`int`):
                    The index of the layer to cache the states for.
                cache_kwargs (`Dict[str, Any]`, `optional`):
                    Additional arguments for the cache subclass,
                    this implementation does not read them.
            Return:
                A tuple containing the updated key and value states.
            """
            # Update the number of seen tokens
            if layer_idx == 0 and hasattr(self, "_seen_tokens"):
                self._seen_tokens += key_states.shape[-2]

            # Update the cache
            if key_states is not None:
                n_layers = len(self.key_cache)
                if n_layers <= layer_idx:
                    # There may be skipped layers, fill them with empty tensors
                    for _ in range(n_layers, layer_idx):
                        self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
                        self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
                    self.key_cache.append(key_states)
                    self.value_cache.append(value_states)
                elif not self.key_cache[layer_idx].numel():
                    # prefers not t.numel() to len(t) == 0 to export the model
                    # fills previously skipped layers; checking for tensor causes errors
                    self.key_cache[layer_idx] = key_states
                    self.value_cache[layer_idx] = value_states
                else:
                    cached_keys = self.key_cache[layer_idx]
                    torch._check(
                        len(cached_keys.shape) == len(key_states.shape),
                        lambda: (
                            f"Rank mismatch len(self.key_cache[layer_idx].shape)="
                            f"{len(self.key_cache[layer_idx].shape)}, "
                            f"len(key_states.shape)={len(key_states.shape)}"
                        ),
                    )
                    self.key_cache[layer_idx] = torch.cat([cached_keys, key_states], dim=-2)
                    self.value_cache[layer_idx] = torch.cat(
                        [self.value_cache[layer_idx], value_states], dim=-2
                    )
            return self.key_cache[layer_idx], self.value_cache[layer_idx]

        def crop(self, max_length: int):
            """Keeps the first `max_length` tokens of every cached layer.
            A negative `max_length` removes ``abs(max_length)`` tokens
            from the current sequence length. Nothing happens if the cache
            is not longer than `max_length`.
            This is used in assisted decoding and contrastive search.
            """
            # In case it is negative
            if max_length < 0:
                max_length = self.get_seq_length() - abs(max_length)

            if self.get_seq_length() <= max_length:
                return

            if hasattr(self, "_seen_tokens"):
                self._seen_tokens = max_length
            for idx, layer_keys in enumerate(self.key_cache):
                if not layer_keys.numel():
                    # skipped layer, nothing to crop
                    continue
                self.key_cache[idx] = layer_keys[..., :max_length, :]
                self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

        @classmethod
        def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache:
            """Builds one cache from many by concatenating, layer by layer,
            the non empty cached tensors along the first dimension.
            This is the opposite of `batch_split()`.
            This will be used by `stack_model_outputs` in
            `generation.utils`"""
            merged = cls()
            for idx in range(len(splits[0])):
                layer_key_parts = [
                    split.key_cache[idx] for split in splits if split.key_cache[idx].numel()
                ]
                layer_value_parts = [
                    split.value_cache[idx]
                    for split in splits
                    if split.value_cache[idx].numel()
                ]
                if layer_key_parts:
                    merged.update(
                        torch.cat(layer_key_parts, dim=0),
                        torch.cat(layer_value_parts, dim=0),
                        idx,
                    )
            return merged
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import transformers
|
|
3
|
+
|
|
4
|
+
# ``patch_gemma3`` tells whether class ``Gemma3Model`` can be imported
# from the installed version of transformers.
try:
    from transformers.models.gemma3.modeling_gemma3 import Gemma3Model  # noqa: F401
except ImportError:
    patch_gemma3 = False
else:
    patch_gemma3 = True
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
if patch_gemma3:

    class patched_Gemma3Model(torch.nn.Module):
        """Patches ``Gemma3Model.get_placeholder_mask``, see ``_PATCHED_PR_``."""

        _PATCHES_ = ["get_placeholder_mask"]
        _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3Model
        _PATCHED_PR_ = "https://github.com/huggingface/transformers/pull/41319"

        def get_placeholder_mask(
            self,
            input_ids: torch.LongTensor,
            inputs_embeds: torch.FloatTensor,
            image_features: torch.FloatTensor,
        ):
            """
            Returns a boolean mask expanded to the shape of `inputs_embeds`
            telling where the image token ``self.config.image_token_id`` stands.
            The positions are found in `input_ids` if it is not None,
            otherwise by comparing `inputs_embeds` to the embedding of the image token.
            The number of selected elements must be equal to ``image_features.numel()``.
            """
            if input_ids is not None:
                token_mask = input_ids == self.config.image_token_id
            else:
                # A position matches if the whole embedding vector is equal
                # to the embedding of the image token.
                token_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(
                        self.config.image_token_id,
                        dtype=torch.long,
                        device=inputs_embeds.device,
                    )
                )
                token_mask = token_mask.all(-1)

            # Both counts are only used by the error message.
            n_image_tokens = token_mask.sum()
            expanded_mask = token_mask.unsqueeze(-1)
            expanded_mask = expanded_mask.expand_as(inputs_embeds)
            expanded_mask = expanded_mask.to(inputs_embeds.device)
            n_image_features = image_features.shape[0] * image_features.shape[1]
            # PATCHED: torch._check
            # replaces a test raising ValueError when the number of elements differs
            torch._check(
                inputs_embeds[expanded_mask].numel() == image_features.numel(),
                lambda: (
                    f"Image features and image tokens do not match: tokens: "
                    f"{n_image_tokens}, features {n_image_features}"
                ),
            )
            return expanded_mask
|