onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +412 -12
- onnx_diagnostic/export/api.py +111 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +12 -7
- onnx_diagnostic/export/onnx_plug.py +531 -0
- onnx_diagnostic/ext_test_case.py +163 -48
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +222 -0
- onnx_diagnostic/helpers/helper.py +108 -37
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +531 -6
- onnx_diagnostic/helpers/ort_session.py +45 -19
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +131 -8
- onnx_diagnostic/reference/ort_evaluator.py +228 -46
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +64 -2
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +969 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py
@@ -0,0 +1,177 @@
+from typing import List, Optional, Tuple
+import packaging.version as pv
+import torch
+import transformers
+from .patch_helper import _has_transformers
+
+patch_is_initialized = _has_transformers("4.56.99")
+patch_DynamicCache = pv.Version(transformers.__version__) < pv.Version("4.51")
+
+try:
+    # transformers>= 4.55.1
+    from transformers.cache_utils import DynamicLayer
+
+    patch_DynamicLayer = hasattr(DynamicLayer, "lazy_initialization")
+except ImportError:
+    patch_DynamicLayer = False
+
+
+if patch_DynamicLayer:
+
+    class patched_DynamicLayer:
+        _PATCHES_ = ["lazy_initialization"]
+        _PATCHED_CLASS_ = DynamicLayer
+
+        def lazy_initialization(self, key_states: torch.Tensor):
+            self.dtype, self.device = key_states.dtype, key_states.device
+            new_shape = list(key_states.shape)
+            new_shape[-2] = 0
+            # PATCHED: used a tensor with an empty shape and not an empty list to initialize
+            self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
+            self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
+            if patch_is_initialized:
+                self.is_initialized = True
+
+
+if patch_DynamicCache:
+    from typing import Any, Dict
+    from transformers.cache_utils import DynamicCache
+
+    class patched_DynamicCache:
+        """
+        Applies modifications implemented in PR
+        `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
+        """
+
+        _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
+        _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
+
+        def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+            """Returns the sequence length of the cached states.
+            A layer index can be optionally passed."""
+            # TODO: deprecate this function in favor of `cache_position`
+            is_empty_layer = (
+                len(self.key_cache) == 0  # no cache in any layer
+                or len(self.key_cache)
+                <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
+                or self.key_cache[layer_idx].numel() == 0  # the layer has no cache
+            )
+            layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+            return layer_seq_length
+
+        def reorder_cache(self, beam_idx: torch.LongTensor):
+            """Reorders the cache for beam search, given the selected beam indices."""
+            for layer_idx in range(len(self.key_cache)):
+                if self.key_cache[layer_idx].numel():
+                    device = self.key_cache[layer_idx].device
+                    self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
+                        0, beam_idx.to(device)
+                    )
+                if self.value_cache[layer_idx].numel():
+                    device = self.value_cache[layer_idx].device
+                    self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
+                        0, beam_idx.to(device)
+                    )
+
+        def update(
+            self,
+            key_states: torch.Tensor,
+            value_states: torch.Tensor,
+            layer_idx: int,
+            cache_kwargs: Optional[Dict[str, Any]] = None,
+        ) -> Tuple[torch.Tensor, torch.Tensor]:
+            """
+            Updates the cache with the new `key_states`
+            and `value_states` for the layer `layer_idx`.
+            Parameters:
+                key_states (`torch.Tensor`):
+                    The new key states to cache.
+                value_states (`torch.Tensor`):
+                    The new value states to cache.
+                layer_idx (`int`):
+                    The index of the layer to cache the states for.
+                cache_kwargs (`Dict[str, Any]`, `optional`):
+                    Additional arguments for the cache subclass.
+                    No additional arguments are used in `DynamicCache`.
+            Return:
+                A tuple containing the updated key and value states.
+            """
+            # Update the number of seen tokens
+            if layer_idx == 0:
+                if hasattr(self, "_seen_tokens"):
+                    self._seen_tokens += key_states.shape[-2]
+
+            # Update the cache
+            if key_states is not None:
+                if len(self.key_cache) <= layer_idx:
+                    # There may be skipped layers, fill them with empty lists
+                    for _ in range(len(self.key_cache), layer_idx):
+                        self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
+                        self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
+                    self.key_cache.append(key_states)
+                    self.value_cache.append(value_states)
+                elif not self.key_cache[
+                    layer_idx
+                ].numel():  # prefers not t.numel() to len(t) == 0 to export the model
+                    # fills previously skipped layers; checking for tensor causes errors
+                    self.key_cache[layer_idx] = key_states
+                    self.value_cache[layer_idx] = value_states
+                else:
+                    torch._check(
+                        len(self.key_cache[layer_idx].shape) == len(key_states.shape),
+                        lambda: (
+                            f"Rank mismatch len(self.key_cache[layer_idx].shape)="
+                            f"{len(self.key_cache[layer_idx].shape)}, "
+                            f"len(key_states.shape)={len(key_states.shape)}"
+                        ),
+                    )
+                    self.key_cache[layer_idx] = torch.cat(
+                        [self.key_cache[layer_idx], key_states], dim=-2
+                    )
+                    self.value_cache[layer_idx] = torch.cat(
+                        [self.value_cache[layer_idx], value_states], dim=-2
+                    )
+            return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+        def crop(self, max_length: int):
+            """Crop the past key values up to a new `max_length`
+            in terms of tokens. `max_length` can also be
+            negative to remove `max_length` tokens.
+            This is used in assisted decoding and contrastive search.
+            """
+            # In case it is negative
+            if max_length < 0:
+                max_length = self.get_seq_length() - abs(max_length)
+
+            if self.get_seq_length() <= max_length:
+                return
+
+            if hasattr(self, "_seen_tokens"):
+                self._seen_tokens = max_length
+            for idx in range(len(self.key_cache)):
+                if self.key_cache[idx].numel():
+                    self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
+                    self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
+
+        @classmethod
+        def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache:
+            """This is the opposite of the above `batch_split()` method.
+            This will be used by `stack_model_outputs` in
+            `generation.utils`"""
+            cache = cls()
+            for idx in range(len(splits[0])):
+                key_cache = [
+                    current.key_cache[idx]
+                    for current in splits
+                    if current.key_cache[idx].numel()
+                ]
+                value_cache = [
+                    current.value_cache[idx]
+                    for current in splits
+                    if current.value_cache[idx].numel()
+                ]
+                if key_cache != []:
+                    layer_keys = torch.cat(key_cache, dim=0)
+                    layer_values = torch.cat(value_cache, dim=0)
+                    cache.update(layer_keys, layer_values, idx)
+            return cache
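
Both patch classes above declare `_PATCHES_` (the attribute names to override) and `_PATCHED_CLASS_` (the `transformers` class to patch); the version checks at the top decide whether each patch is needed. The loader that consumes this convention lives elsewhere in `onnx_diagnostic.torch_export_patches` and is not part of this diff; the sketch below is only a hypothetical minimal applier, assuming the loader simply swaps the listed attributes in and out:

```python
# Hypothetical helper (not part of the package): apply/undo a patch class
# following the _PATCHES_/_PATCHED_CLASS_ convention shown above.
from typing import Any, Dict


def apply_patch(patch_cls: type) -> Dict[str, Any]:
    """Copies the attributes listed in _PATCHES_ onto _PATCHED_CLASS_
    and returns the originals so they can be restored later."""
    target = patch_cls._PATCHED_CLASS_
    saved: Dict[str, Any] = {}
    for name in patch_cls._PATCHES_:
        saved[name] = target.__dict__.get(name)
        # __dict__ keeps classmethod/staticmethod descriptors
        # (e.g. from_batch_splits) unbound when copying them over.
        setattr(target, name, patch_cls.__dict__[name])
    return saved


def undo_patch(patch_cls: type, saved: Dict[str, Any]) -> None:
    """Restores the attributes recorded by apply_patch."""
    target = patch_cls._PATCHED_CLASS_
    for name, attr in saved.items():
        if attr is None:
            delattr(target, name)
        else:
            setattr(target, name, attr)


# Usage, assuming the version checks above selected patched_DynamicCache:
# saved = apply_patch(patched_DynamicCache)
# ... run torch.export with the patched cache ...
# undo_patch(patched_DynamicCache, saved)
```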
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py
@@ -0,0 +1,54 @@
+import torch
+import transformers
+
+try:
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3Model  # noqa: F401
+
+    patch_gemma3 = True
+except ImportError:
+    patch_gemma3 = False
+
+
+if patch_gemma3:
+
+    class patched_Gemma3Model(torch.nn.Module):
+        _PATCHES_ = ["get_placeholder_mask"]
+        _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3Model
+        _PATCHED_PR_ = "https://github.com/huggingface/transformers/pull/41319"
+
+        def get_placeholder_mask(
+            self,
+            input_ids: torch.LongTensor,
+            inputs_embeds: torch.FloatTensor,
+            image_features: torch.FloatTensor,
+        ):
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(
+                        self.config.image_token_id,
+                        dtype=torch.long,
+                        device=inputs_embeds.device,
+                    )
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = special_image_mask.sum()
+            special_image_mask = (
+                special_image_mask.unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            n_image_features = image_features.shape[0] * image_features.shape[1]
+            # PATCHED: torch._check
+            # if inputs_embeds[special_image_mask].numel() != image_features.numel():
+            #     raise ValueError( ... )
+            torch._check(
+                inputs_embeds[special_image_mask].numel() == image_features.numel(),
+                lambda: (
+                    f"Image features and image tokens do not match: tokens: "
+                    f"{n_image_tokens}, features {n_image_features}"
+                ),
+            )
+            return special_image_mask