onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +412 -12
- onnx_diagnostic/export/api.py +111 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +12 -7
- onnx_diagnostic/export/onnx_plug.py +531 -0
- onnx_diagnostic/ext_test_case.py +163 -48
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +222 -0
- onnx_diagnostic/helpers/helper.py +108 -37
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +531 -6
- onnx_diagnostic/helpers/ort_session.py +45 -19
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +131 -8
- onnx_diagnostic/reference/ort_evaluator.py +228 -46
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +64 -2
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +969 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
from typing import Callable, List, Optional, Tuple
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
try:
|
|
5
|
+
import transformers.masking_utils # noqa: F401
|
|
6
|
+
|
|
7
|
+
patch_masking_utils = True
|
|
8
|
+
except ImportError:
|
|
9
|
+
patch_masking_utils = False
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
if patch_masking_utils:
|
|
13
|
+
# Introduced in 4.52
|
|
14
|
+
from transformers.masking_utils import (
|
|
15
|
+
_ignore_causal_mask_sdpa,
|
|
16
|
+
and_masks,
|
|
17
|
+
causal_mask_function,
|
|
18
|
+
padding_mask_function,
|
|
19
|
+
prepare_padding_mask,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
try:
|
|
23
|
+
# transformers>=5.0
|
|
24
|
+
from transformers.masking_utils import (
|
|
25
|
+
_ignore_bidirectional_mask_sdpa,
|
|
26
|
+
bidirectional_mask_function,
|
|
27
|
+
)
|
|
28
|
+
except ImportError:
|
|
29
|
+
_ignore_bidirectional_mask_sdpa = None
|
|
30
|
+
bidirectional_mask_function = None
|
|
31
|
+
|
|
32
|
+
def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
|
|
33
|
+
"""manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
|
|
34
|
+
from ...helpers import string_type
|
|
35
|
+
|
|
36
|
+
dimensions: List[Tuple[Optional[int], ...]] = [
|
|
37
|
+
(None, None, None, 0),
|
|
38
|
+
(None, None, 0, None),
|
|
39
|
+
]
|
|
40
|
+
if bh_indices:
|
|
41
|
+
dimensions.extend([(None, 0, None, None), (0, None, None, None)])
|
|
42
|
+
# reshape
|
|
43
|
+
dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
|
|
44
|
+
dimensions = tuple(reversed(dimensions))
|
|
45
|
+
indices = tuple(shape.index(-1) for shape in dimensions)
|
|
46
|
+
|
|
47
|
+
# unsqueeze
|
|
48
|
+
udimensions = [
|
|
49
|
+
tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions
|
|
50
|
+
]
|
|
51
|
+
|
|
52
|
+
def vector_mask_function(
|
|
53
|
+
*args, mask_function=mask_function, dimensions=dimensions, indices=indices
|
|
54
|
+
):
|
|
55
|
+
assert len(args) == len(dimensions) == len(udimensions), (
|
|
56
|
+
f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
|
|
57
|
+
f"and udimensions={udimensions}."
|
|
58
|
+
)
|
|
59
|
+
assert len(indices) == len(args), (
|
|
60
|
+
f"Mismatch between args={string_type(args)} and indices={indices}, "
|
|
61
|
+
f"they should have the same length."
|
|
62
|
+
)
|
|
63
|
+
for a in args:
|
|
64
|
+
assert (
|
|
65
|
+
a.ndim == 1
|
|
66
|
+
), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
|
|
67
|
+
torch._check(a.shape[0] > 0)
|
|
68
|
+
|
|
69
|
+
new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
|
|
70
|
+
# new_args = [
|
|
71
|
+
# a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
|
|
72
|
+
# for a, dims in zip(args, udimensions)
|
|
73
|
+
# ]
|
|
74
|
+
max_shape = tuple(args[i].shape[0] for i in indices)
|
|
75
|
+
# if _is_torchdynamo_exporting():
|
|
76
|
+
# for a in args:
|
|
77
|
+
# # The exporter should export with a dimension > 1
|
|
78
|
+
# # to make sure it is dynamic.
|
|
79
|
+
# torch._check(a.shape[0] > 1)
|
|
80
|
+
expanded_args = [a.expand(max_shape) for a in new_args]
|
|
81
|
+
return mask_function(*expanded_args)
|
|
82
|
+
|
|
83
|
+
return vector_mask_function
|
|
84
|
+
|
|
85
|
+
def patched_eager_mask(
    batch_size: int,
    cache_position: torch.Tensor,
    kv_length: int,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: Optional[torch.Tensor] = None,
    dtype: torch.dtype = torch.float32,
    **kwargs,
) -> torch.Tensor:
    """
    Manual patch for function ``transformers.masking_utils.eager_mask``.

    Builds the boolean mask with ``patched_sdpa_mask_recent_torch`` with every
    skip disabled, then converts it to ``dtype``. The result is 0 where the
    boolean mask is True and ``torch.finfo(dtype).min`` where it is False.

    :param dtype: floating dtype of the returned additive mask
    :param kwargs: forwarded to ``patched_sdpa_mask_recent_torch``; the flags
        ``allow_is_causal_skip``, ``allow_is_bidirectional_skip`` and
        ``allow_torch_fix`` are dropped because this function forces them
    :return: the additive mask
    """
    # These flags are forced in the call below. Drop any caller-provided value,
    # otherwise the keyword would be passed twice and raise a TypeError.
    for forced in ("allow_is_causal_skip", "allow_is_bidirectional_skip", "allow_torch_fix"):
        kwargs.pop(forced, None)
    # PATCHED: this line called the patched version of sdpa_mask
    mask = patched_sdpa_mask_recent_torch(
        batch_size=batch_size,
        cache_position=cache_position,
        kv_length=kv_length,
        kv_offset=kv_offset,
        mask_function=mask_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=False,
        allow_is_bidirectional_skip=False,
        allow_torch_fix=False,
        **kwargs,
    )
    min_dtype = torch.finfo(dtype).min
    # PATCHED: upstream uses torch.where(mask, 0, min_dtype). Since mask is
    # boolean, the same values come from a cast and a multiplication:
    # True -> 0, False -> min_dtype.
    mask = (~mask).to(dtype) * min_dtype
    return mask
|
|
120
|
+
|
|
121
|
+
def patched_sdpa_mask_recent_torch(
    batch_size: int,
    cache_position: torch.Tensor,
    kv_length: int,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: Optional[torch.Tensor] = None,
    local_size: Optional[int] = None,
    allow_is_causal_skip: bool = True,
    allow_is_bidirectional_skip: bool = False,
    **kwargs,
) -> Optional[torch.Tensor]:
    """
    Manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``.

    Returns None when a skip is allowed and the matching
    ``_ignore_*_mask_sdpa`` helper reports the mask is not needed.

    For ``bidirectional_mask_function``, returns the padding mask broadcast to
    ``(batch, 1, q_length, kv_length)``, or an all-True boolean tensor of
    shape ``(batch_size, 1, q_length, kv_length)`` when there is no padding.

    Otherwise, ``mask_function`` is combined with the padding mask (if any)
    and evaluated through ``patched__vmap_for_bhqkv``. Here ``q_length`` is
    ``cache_position.shape[0]``. Extra keyword arguments are accepted and
    ignored.
    """
    device = cache_position.device
    n_queries = cache_position.shape[0]
    pad = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)

    # Let the caller drop the mask entirely when it is not needed.
    if allow_is_causal_skip and _ignore_causal_mask_sdpa(
        pad, n_queries, kv_length, kv_offset, local_size
    ):
        return None
    # The bidirectional helper only exists in recent transformers (None otherwise).
    if (
        allow_is_bidirectional_skip
        and _ignore_bidirectional_mask_sdpa is not None
        and _ignore_bidirectional_mask_sdpa(pad)
    ):
        return None

    if mask_function is bidirectional_mask_function:
        if pad is None:
            return torch.ones(
                batch_size, 1, n_queries, kv_length, dtype=torch.bool, device=device
            )
        # Index with a tensor instead of a Python slice to avoid data-dependent slicing.
        kv_positions = torch.arange(kv_length, device=device) + kv_offset
        return pad[:, None, None, kv_positions].expand(-1, -1, n_queries, -1)

    kv_idx = torch.arange(kv_length, device=device)
    kv_idx += kv_offset
    if pad is not None:
        mask_function = and_masks(mask_function, padding_mask_function(pad))
    # PATCHED: uses the patched version of vmap_for_bhqkv
    vectorized = patched__vmap_for_bhqkv(mask_function)
    return vectorized(
        torch.arange(batch_size, device=device),
        torch.arange(1, device=device),
        cache_position,
        kv_idx,
    )
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
from typing import Optional, Tuple
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
try:
|
|
5
|
+
import transformers.models.qwen2_vl
|
|
6
|
+
|
|
7
|
+
patch_qwen2 = True
|
|
8
|
+
except ImportError:
|
|
9
|
+
patch_qwen2 = False
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def rewrite_loop_for_square_mask(mask: torch.Tensor, seq: torch.Tensor):
    """
    Rewrites the loop in:

    .. code-block:: python

        attention_mask = torch.full(
            [1, seq_length, seq_length], torch.finfo(q.dtype).min, dtype=q.dtype
        )
        for i in range(1, len(seq)):
            attention_mask[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0

    without a Python loop or data-dependent slicing. ``mask`` is multiplied
    by a 0/1 filter. The filter is 0 for every pair of positions ``(i, j)``
    inside the same segment ``[seq[k - 1], seq[k])`` and 1 otherwise.

    :param mask: tensor whose last two dimensions are ``(N, N)``
    :param seq: 1D tensor of segment boundaries. It is assumed non-empty and
        non-decreasing (like ``cu_seqlens``) and on the same device as ``mask``.
    :return: ``mask * filter``, same shape and dtype as ``mask``
    """
    n = mask.shape[-1]
    # Created on mask's device: comparing a CPU tensor with a CUDA ``seq`` fails.
    positions = torch.arange(0, n, dtype=torch.int64, device=mask.device)
    # segment[i] = 1 + number of boundaries strictly greater than positions[i].
    # It is constant inside a segment and differs between segments.
    below = (positions.reshape((-1, 1)) < seq.reshape((1, -1))).to(torch.int64)
    segment = below.sum(dim=-1, keepdim=True) + 1
    pair = segment * segment.T
    # Positions before seq.min() (largest id) or at/after seq.max() (smallest id)
    # belong to no segment. Their id becomes 0, so pair != 0 keeps the mask there.
    look = (
        torch.max(seq.min() == 0, segment != segment.max())
        * torch.max(seq.max() == n, segment != segment.min())
        * segment
    )
    # Because ids are > 0, segment[i] * segment[j] == segment[i] ** 2 exactly
    # when j lies in the same segment as i.
    filt = (pair != look**2).to(mask.dtype)
    return mask * filt
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
if patch_qwen2:

    class patched_VisionAttention(torch.nn.Module):
        """
        Patch for ``transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention``.

        Only ``forward`` is replaced (see ``_PATCHES_``). The Python loop that
        zeroes one diagonal block of the additive attention mask per segment
        of ``cu_seqlens`` is replaced by ``rewrite_loop_for_square_mask``.
        Attention is computed explicitly with matmul and softmax.
        """

        _PATCHES_ = ["forward"]
        _PATCHED_CLASS_ = transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention

        def forward(
            self,
            hidden_states: torch.Tensor,
            cu_seqlens: torch.Tensor,
            rotary_pos_emb: Optional[torch.Tensor] = None,
            position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        ) -> torch.Tensor:
            """
            :param hidden_states: tensor whose first dimension is the number of tokens
            :param cu_seqlens: segment boundaries; tokens only attend within their segment
            :param rotary_pos_emb: legacy RoPE input, only used when
                ``position_embeddings`` is None
            :param position_embeddings: ``(cos, sin)``
            :return: ``self.proj`` applied to the attention output, one row per token
            """
            qwen2_vl = transformers.models.qwen2_vl.modeling_qwen2_vl
            n_tokens = hidden_states.shape[0]

            # Split the fused projection into three tensors of shape (tokens, heads, -1).
            packed = self.qkv(hidden_states).reshape(n_tokens, 3, self.num_heads, -1)
            query, key, value = packed.permute(1, 0, 2, 3).unbind(0)

            if position_embeddings is not None:
                cos, sin = position_embeddings
            else:
                qwen2_vl.logger.warning_once(
                    "The attention layers in this model are transitioning from "
                    " computing the RoPE embeddings internally "
                    "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
                    "to using externally computed "
                    "`position_embeddings` (Tuple of tensors, containing cos and sin)."
                    " In v4.54 `rotary_pos_emb` will be "
                    "removed and `position_embeddings` will be mandatory."
                )
                angles = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
                cos, sin = angles.cos(), angles.sin()
            query, key = qwen2_vl.apply_rotary_pos_emb_vision(query, key, cos, sin)

            # Additive mask: min value everywhere, then 0 on each diagonal block
            # [cu_seqlens[i - 1], cu_seqlens[i]) (loop-free rewrite).
            additive_mask = torch.full(
                [1, n_tokens, n_tokens],
                torch.finfo(query.dtype).min,
                device=query.device,
                dtype=query.dtype,
            )
            additive_mask = rewrite_loop_for_square_mask(additive_mask, cu_seqlens)

            # Heads first: (heads, tokens, -1).
            query, key, value = (t.transpose(0, 1) for t in (query, key, value))
            scores = query @ key.transpose(1, 2) / (self.head_dim**0.5) + additive_mask
            probs = torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(
                query.dtype
            )
            context = (probs @ value).transpose(0, 1).reshape(n_tokens, -1)
            return self.proj(context)
|