onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +387 -12
  3. onnx_diagnostic/export/api.py +118 -5
  4. onnx_diagnostic/export/control_flow.py +214 -0
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +135 -0
  7. onnx_diagnostic/export/onnx_plug.py +396 -0
  8. onnx_diagnostic/ext_test_case.py +118 -25
  9. onnx_diagnostic/helpers/cache_helper.py +218 -204
  10. onnx_diagnostic/helpers/dot_helper.py +210 -0
  11. onnx_diagnostic/helpers/helper.py +92 -26
  12. onnx_diagnostic/helpers/log_helper.py +26 -4
  13. onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
  14. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  15. onnx_diagnostic/helpers/onnx_helper.py +115 -16
  16. onnx_diagnostic/helpers/ort_session.py +37 -11
  17. onnx_diagnostic/helpers/rt_helper.py +547 -0
  18. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  19. onnx_diagnostic/helpers/torch_helper.py +108 -6
  20. onnx_diagnostic/reference/ort_evaluator.py +233 -28
  21. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  22. onnx_diagnostic/tasks/image_text_to_text.py +5 -1
  23. onnx_diagnostic/tasks/summarization.py +72 -137
  24. onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
  26. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  34. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  35. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  36. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
  37. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  38. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  39. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  40. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  41. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
  42. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
  43. onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
  44. onnx_diagnostic/torch_models/validate.py +50 -1
  45. onnx_diagnostic/torch_onnx/sbs.py +963 -312
  46. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
  47. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
  48. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
  49. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
  50. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  51. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py
@@ -0,0 +1,173 @@
+from typing import Callable, List, Optional, Tuple
+import torch
+
+try:
+    import transformers.masking_utils  # noqa: F401
+
+    patch_masking_utils = True
+except ImportError:
+    patch_masking_utils = False
+
+
+if patch_masking_utils:
+    # Introduced in 4.52
+    from transformers.masking_utils import (
+        _ignore_causal_mask_sdpa,
+        and_masks,
+        causal_mask_function,
+        padding_mask_function,
+        prepare_padding_mask,
+    )
+
+    try:
+        # transformers>=5.0
+        from transformers.masking_utils import (
+            _ignore_bidirectional_mask_sdpa,
+            bidirectional_mask_function,
+        )
+    except ImportError:
+        _ignore_bidirectional_mask_sdpa = None
+        bidirectional_mask_function = None
+
+    def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
+        """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
+        from ...helpers import string_type
+
+        dimensions: List[Tuple[Optional[int], ...]] = [
+            (None, None, None, 0),
+            (None, None, 0, None),
+        ]
+        if bh_indices:
+            dimensions.extend([(None, 0, None, None), (0, None, None, None)])
+        # reshape
+        dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
+        dimensions = tuple(reversed(dimensions))
+        indices = tuple(shape.index(-1) for shape in dimensions)
+
+        # unsqueeze
+        udimensions = [
+            tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions
+        ]
+
+        def vector_mask_function(
+            *args, mask_function=mask_function, dimensions=dimensions, indices=indices
+        ):
+            assert len(args) == len(dimensions) == len(udimensions), (
+                f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
+                f"and udimensions={udimensions}."
+            )
+            assert len(indices) == len(args), (
+                f"Mismatch between args={string_type(args)} and indices={indices}, "
+                f"they should have the same length."
+            )
+            for a in args:
+                assert (
+                    a.ndim == 1
+                ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
+                torch._check(a.shape[0] > 0)
+
+            new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
+            # new_args = [
+            #     a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
+            #     for a, dims in zip(args, udimensions)
+            # ]
+            max_shape = tuple(args[i].shape[0] for i in indices)
+            # if _is_torchdynamo_exporting():
+            #     for a in args:
+            #         # The exporter should export with a dimension > 1
+            #         # to make sure it is dynamic.
+            #         torch._check(a.shape[0] > 1)
+            expanded_args = [a.expand(max_shape) for a in new_args]
+            return mask_function(*expanded_args)
+
+        return vector_mask_function
+
+    def patched_eager_mask(
+        batch_size: int,
+        cache_position: torch.Tensor,
+        kv_length: int,
+        kv_offset: int = 0,
+        mask_function: Callable = causal_mask_function,
+        attention_mask: Optional[torch.Tensor] = None,
+        dtype: torch.dtype = torch.float32,
+        **kwargs,
+    ) -> torch.Tensor:
+        """manual patch for function ``transformers.masking_utils.eager_mask``."""
+        # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
+        _ = kwargs.pop("allow_is_causal_skip", None)
+        _ = kwargs.pop("allow_is_bidirectional_skip", None)
+        # PATCHED: this line called the patched version of sdpa_mask
+        mask = patched_sdpa_mask_recent_torch(
+            batch_size=batch_size,
+            cache_position=cache_position,
+            kv_length=kv_length,
+            kv_offset=kv_offset,
+            mask_function=mask_function,
+            attention_mask=attention_mask,
+            allow_is_causal_skip=False,
+            allow_is_bidirectional_skip=False,
+            allow_torch_fix=False,
+            **kwargs,
+        )
+        min_dtype = torch.finfo(dtype).min
+        # PATCHED: the following line
+        # we need 0s where the tokens should be taken into account,
+        # and -inf otherwise (mask is already of boolean type)
+        # mask =
+        #   torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
+        mask = (~mask).to(dtype) * min_dtype
+        return mask
+
+    def patched_sdpa_mask_recent_torch(
+        batch_size: int,
+        cache_position: torch.Tensor,
+        kv_length: int,
+        kv_offset: int = 0,
+        mask_function: Callable = causal_mask_function,
+        attention_mask: Optional[torch.Tensor] = None,
+        local_size: Optional[int] = None,
+        allow_is_causal_skip: bool = True,
+        allow_is_bidirectional_skip: bool = False,
+        **kwargs,
+    ) -> Optional[torch.Tensor]:
+        """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
+        q_length = cache_position.shape[0]
+        padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
+        if allow_is_causal_skip and _ignore_causal_mask_sdpa(
+            padding_mask, q_length, kv_length, kv_offset, local_size
+        ):
+            return None
+        if (
+            allow_is_bidirectional_skip
+            and _ignore_bidirectional_mask_sdpa
+            and _ignore_bidirectional_mask_sdpa(padding_mask)
+        ):
+            return None
+
+        if mask_function is bidirectional_mask_function:
+            if padding_mask is not None:
+                # used for slicing without data-dependent slicing
+                mask_indices = (
+                    torch.arange(kv_length, device=cache_position.device) + kv_offset
+                )
+                return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1)
+            return torch.ones(
+                batch_size,
+                1,
+                q_length,
+                kv_length,
+                dtype=torch.bool,
+                device=cache_position.device,
+            )
+
+        kv_arange = torch.arange(kv_length, device=cache_position.device)
+        kv_arange += kv_offset
+        if padding_mask is not None:
+            mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
+        batch_arange = torch.arange(batch_size, device=cache_position.device)
+        head_arange = torch.arange(1, device=cache_position.device)
+        # PATCHED: this line calls the patched version of vmap_for_bhqkv
+        causal_mask = patched__vmap_for_bhqkv(mask_function)(
+            batch_arange, head_arange, cache_position, kv_arange
+        )
+        return causal_mask
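
A quick way to see what the broadcasting-based replacement for torch.vmap computes: the sketch below is not part of the diff. It assumes onnx-diagnostic 0.8.3 is installed and that transformers provides masking_utils (>= 4.52) so that patched__vmap_for_bhqkv is defined; the inline causal rule stands in for transformers.masking_utils.causal_mask_function and mirrors the call made at the end of patched_sdpa_mask_recent_torch.

import torch
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_masking_utils import (
    patched__vmap_for_bhqkv,
)

def causal(batch_idx, head_idx, q_idx, kv_idx):
    # causal rule: a query position may attend to key positions at or before it
    return kv_idx <= q_idx

batch_size, q_length, kv_length = 2, 4, 6
batch_arange = torch.arange(batch_size)
head_arange = torch.arange(1)
cache_position = torch.arange(q_length)  # query positions
kv_arange = torch.arange(kv_length)      # key/value positions

# Instead of vmapping the mask function over every (b, h, q, kv) index,
# the patched helper reshapes each 1D index tensor to a broadcastable shape,
# expands all of them to (batch, heads, q_length, kv_length), and calls the
# mask function once on full tensors, which torch.export can trace.
mask = patched__vmap_for_bhqkv(causal)(batch_arange, head_arange, cache_position, kv_arange)
print(mask.shape, mask.dtype)  # torch.Size([2, 1, 4, 6]) torch.bool
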
--- /dev/null
+++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py
@@ -0,0 +1,99 @@
+from typing import Optional, Tuple
+import torch
+
+try:
+    import transformers.models.qwen2_vl
+
+    patch_qwen2 = True
+except ImportError:
+    patch_qwen2 = False
+
+
+def rewrite_loop_for_square_mask(mask: torch.Tensor, seq: torch.Tensor):
+    """
+    Rewrites the loop in:
+
+    .. code-block:: python
+
+        attention_mask = torch.full(
+            [1, seq_length, seq_length], torch.finfo(q.dtype).min, dtype=q.dtype
+        )
+        for i in range(1, len(seq)):
+            attention_mask[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0
+    """
+    r = torch.arange(0, mask.shape[-1], dtype=torch.int64)
+    less0 = (r.reshape((-1, 1)) < seq.reshape((1, -1))).to(torch.int64)
+    less = less0.sum(axis=-1, keepdim=True) + 1
+    sq = less * less.T
+    look = (
+        torch.max(seq.min() == 0, less != less.max())
+        * torch.max(seq.max() == mask.shape[-1], less != less.min())
+        * less
+    )
+    filt = (sq != look**2).to(mask.dtype)
+    return mask * filt
+
+
+if patch_qwen2:
+
+    class patched_VisionAttention(torch.nn.Module):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention
+
+        def forward(
+            self,
+            hidden_states: torch.Tensor,
+            cu_seqlens: torch.Tensor,
+            rotary_pos_emb: Optional[torch.Tensor] = None,
+            position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        ) -> torch.Tensor:
+            seq_length = hidden_states.shape[0]
+            q, k, v = (
+                self.qkv(hidden_states)
+                .reshape(seq_length, 3, self.num_heads, -1)
+                .permute(1, 0, 2, 3)
+                .unbind(0)
+            )
+            if position_embeddings is None:
+                transformers.models.qwen2_vl.modeling_qwen2_vl.logger.warning_once(
+                    "The attention layers in this model are transitioning from "
+                    " computing the RoPE embeddings internally "
+                    "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
+                    "to using externally computed "
+                    "`position_embeddings` (Tuple of tensors, containing cos and sin)."
+                    " In v4.54 `rotary_pos_emb` will be "
+                    "removed and `position_embeddings` will be mandatory."
+                )
+                emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+                cos = emb.cos()
+                sin = emb.sin()
+            else:
+                cos, sin = position_embeddings
+            q, k = transformers.models.qwen2_vl.modeling_qwen2_vl.apply_rotary_pos_emb_vision(
+                q, k, cos, sin
+            )
+
+            attention_mask = torch.full(
+                [1, seq_length, seq_length],
+                torch.finfo(q.dtype).min,
+                device=q.device,
+                dtype=q.dtype,
+            )
+            # for i in range(1, len(cu_seqlens)):
+            #     attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i],
+            #                    cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+            attention_mask = rewrite_loop_for_square_mask(attention_mask, cu_seqlens)
+
+            q = q.transpose(0, 1)
+            k = k.transpose(0, 1)
+            v = v.transpose(0, 1)
+            attn_weights = torch.matmul(q, k.transpose(1, 2)) / (self.head_dim**0.5)
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.nn.functional.softmax(
+                attn_weights, dim=-1, dtype=torch.float32
+            ).to(q.dtype)
+            attn_output = torch.matmul(attn_weights, v)
+            attn_output = attn_output.transpose(0, 1)
+            attn_output = attn_output.reshape(seq_length, -1)
+            attn_output = self.proj(attn_output)
+            return attn_output
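
To see why rewrite_loop_for_square_mask can replace the data-dependent loop shown in its docstring, the short check below is a minimal sketch, not part of the diff. It assumes onnx-diagnostic 0.8.3 is installed (module path taken from the file list above) and uses a small block-diagonal example with cumulative sequence lengths like Qwen2-VL's cu_seqlens.

import torch
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2 import (
    rewrite_loop_for_square_mask,
)

seq_length = 10
# three segments: positions 0..4, 4..7, 7..10
cu_seqlens = torch.tensor([0, 4, 7, 10], dtype=torch.int64)
min_value = torch.finfo(torch.float32).min

# original loop from VisionAttention.forward: zero out the square block of each segment
expected = torch.full([1, seq_length, seq_length], min_value, dtype=torch.float32)
for i in range(1, len(cu_seqlens)):
    expected[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

# loop-free rewrite used by the patch: same mask, no data-dependent slicing
mask = torch.full([1, seq_length, seq_length], min_value, dtype=torch.float32)
got = rewrite_loop_for_square_mask(mask, cu_seqlens)

torch.testing.assert_close(got, expected)
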