onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +387 -12
  3. onnx_diagnostic/export/api.py +118 -5
  4. onnx_diagnostic/export/control_flow.py +214 -0
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +135 -0
  7. onnx_diagnostic/export/onnx_plug.py +396 -0
  8. onnx_diagnostic/ext_test_case.py +118 -25
  9. onnx_diagnostic/helpers/cache_helper.py +218 -204
  10. onnx_diagnostic/helpers/dot_helper.py +210 -0
  11. onnx_diagnostic/helpers/helper.py +92 -26
  12. onnx_diagnostic/helpers/log_helper.py +26 -4
  13. onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
  14. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  15. onnx_diagnostic/helpers/onnx_helper.py +115 -16
  16. onnx_diagnostic/helpers/ort_session.py +37 -11
  17. onnx_diagnostic/helpers/rt_helper.py +547 -0
  18. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  19. onnx_diagnostic/helpers/torch_helper.py +108 -6
  20. onnx_diagnostic/reference/ort_evaluator.py +233 -28
  21. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  22. onnx_diagnostic/tasks/image_text_to_text.py +5 -1
  23. onnx_diagnostic/tasks/summarization.py +72 -137
  24. onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
  26. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  34. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  35. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  36. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
  37. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  38. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  39. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  40. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  41. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
  42. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
  43. onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
  44. onnx_diagnostic/torch_models/validate.py +50 -1
  45. onnx_diagnostic/torch_onnx/sbs.py +963 -312
  46. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
  47. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
  48. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
  49. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
  50. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  51. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
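
Note: the new `_patch_transformers_*.py` modules below all share one convention: each patch class declares `_PATCHES_`, the list of method names to replace, and `_PATCHED_CLASS_`, the transformers class to patch. As a rough illustration of that convention (a minimal sketch, not onnx_diagnostic's actual loader, which lives under `torch_export_patches/`), applying and reverting such a patch amounts to swapping attributes on the target class:

import contextlib

@contextlib.contextmanager
def apply_patch(patch_cls):
    # patch_cls follows the convention of the modules below:
    # _PATCHES_ names the methods to swap, _PATCHED_CLASS_ is the target class.
    target = patch_cls._PATCHED_CLASS_
    saved = {name: getattr(target, name) for name in patch_cls._PATCHES_}
    try:
        for name in patch_cls._PATCHES_:
            setattr(target, name, getattr(patch_cls, name))
        yield target
    finally:
        for name, original in saved.items():
            setattr(target, name, original)  # restore the unpatched methods

One caveat, visible in `_patch_transformers_causal_mask.py` below: restoring a plain function over a staticmethod this way is fragile, which is exactly what that file's docstring warns about.
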
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py
@@ -0,0 +1,89 @@
+ from dataclasses import dataclass
+ from typing import Optional
+ import torch
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+ from .patch_helper import _has_transformers
+
+
+ def _patch_make_causal_mask(
+     input_ids_shape: torch.Size,
+     dtype: torch.dtype,
+     device: torch.device,
+     past_key_values_length: int = 0,
+     sliding_window: Optional[int] = None,
+ ):
+     """Patched method."""
+     bsz, tgt_len = input_ids_shape
+     mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+     mask_cond = torch.arange(mask.size(-1), device=device)
+     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+
+     mask = mask.to(dtype)
+
+     if past_key_values_length > 0:
+         mask = torch.cat(
+             [
+                 torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device),
+                 mask,
+             ],
+             dim=-1,
+         )
+
+     if sliding_window is not None:
+         diagonal = past_key_values_length - sliding_window - 1
+
+         context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
+         # PATCHED: removed "if is_torchdynamo_compiling(): mask = mask.clone()"
+         # and used masked_fill instead of masked_fill_.
+         # Without this change, the current implementation of torch fails (17/12/2024).
+         # Try model Phi-3.5-Mini-Instruct.
+         mask = mask.masked_fill(context_mask, torch.finfo(dtype).min)
+
+     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+ @dataclass
+ class patched_AttentionMaskConverter:
+     """
+     Patches
+     ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
+     """
+
+     # This method was fixed in transformers 4.51 at least.
+     _PATCHES_ = ["_make_causal_mask"] if not _has_transformers("4.48.3") else []
+     _PATCHED_CLASS_ = AttentionMaskConverter
+
+     @staticmethod
+     def _make_causal_mask(
+         *args,
+         **kwargs,
+         # input_ids_shape: torch.Size,
+         # dtype: torch.dtype,
+         # device: torch.device,
+         # past_key_values_length: int = 0,
+         # sliding_window: Optional[int] = None,
+     ):
+         """
+         Patched method.
+
+         This static method may be called as ``AttentionMaskConverter._make_causal_mask``
+         or as ``self._make_causal_mask``, which changes the arguments it receives.
+         That should not matter, but the patch should be implemented
+         another way: static methods do not play well with a simple
+         replacement.
+         Fortunately, this patch no longer seems to be needed with transformers>=4.48.3.
+         """
+         if args:
+             index = 0 if isinstance(args[0], (tuple, torch.Size)) else 1
+             names = [
+                 "input_ids_shape",
+                 "dtype",
+                 "device",
+                 "past_key_values_length",
+                 "sliding_window",
+             ]
+             for i, a in enumerate(args):
+                 if i < index:
+                     continue
+                 kwargs[names[i - index]] = a
+         return _patch_make_causal_mask(**kwargs)
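
For readers checking the patched mask construction, here is a quick smoke test of `_patch_make_causal_mask` (an illustrative snippet, not part of the diff, assuming the function above is importable): with a batch of 2, a target length of 4, 3 cached tokens, and a sliding window of 2, the returned mask has the broadcastable shape transformers expects.

import torch

mask = _patch_make_causal_mask(
    torch.Size([2, 4]),        # (batch_size, target_length)
    torch.float32,
    torch.device("cpu"),
    past_key_values_length=3,  # 3 cached positions are prepended as zeros
    sliding_window=2,          # positions outside the window are filled with -inf
)
print(mask.shape)  # torch.Size([2, 1, 4, 7]) == (bsz, 1, tgt_len, tgt_len + past)
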
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py
@@ -0,0 +1,177 @@
+ from typing import List, Optional, Tuple
+ import packaging.version as pv
+ import torch
+ import transformers
+ from .patch_helper import _has_transformers
+
+ patch_is_initialized = _has_transformers("4.56.99")
+ patch_DynamicCache = pv.Version(transformers.__version__) < pv.Version("4.51")
+
+ try:
+     # transformers >= 4.55.1
+     from transformers.cache_utils import DynamicLayer
+
+     patch_DynamicLayer = hasattr(DynamicLayer, "lazy_initialization")
+ except ImportError:
+     patch_DynamicLayer = False
+
+
+ if patch_DynamicLayer:
+
+     class patched_DynamicLayer:
+         _PATCHES_ = ["lazy_initialization"]
+         _PATCHED_CLASS_ = DynamicLayer
+
+         def lazy_initialization(self, key_states: torch.Tensor):
+             self.dtype, self.device = key_states.dtype, key_states.device
+             new_shape = list(key_states.shape)
+             new_shape[-2] = 0
+             # PATCHED: initialize with a tensor with an empty cache dimension, not an empty list
+             self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
+             self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
+             if patch_is_initialized:
+                 self.is_initialized = True
+
+
+ if patch_DynamicCache:
+     from typing import Any, Dict
+     from transformers.cache_utils import DynamicCache
+
+     class patched_DynamicCache:
+         """
+         Applies modifications implemented in PR
+         `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
+         """
+
+         _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
+         _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
+
+         def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+             """Returns the sequence length of the cached states.
+             A layer index can be optionally passed."""
+             # TODO: deprecate this function in favor of `cache_position`
+             is_empty_layer = (
+                 len(self.key_cache) == 0  # no cache in any layer
+                 or len(self.key_cache)
+                 <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
+                 or self.key_cache[layer_idx].numel() == 0  # the layer has no cache
+             )
+             layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+             return layer_seq_length
+
+         def reorder_cache(self, beam_idx: torch.LongTensor):
+             """Reorders the cache for beam search, given the selected beam indices."""
+             for layer_idx in range(len(self.key_cache)):
+                 if self.key_cache[layer_idx].numel():
+                     device = self.key_cache[layer_idx].device
+                     self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
+                         0, beam_idx.to(device)
+                     )
+                 if self.value_cache[layer_idx].numel():
+                     device = self.value_cache[layer_idx].device
+                     self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
+                         0, beam_idx.to(device)
+                     )
+
+         def update(
+             self,
+             key_states: torch.Tensor,
+             value_states: torch.Tensor,
+             layer_idx: int,
+             cache_kwargs: Optional[Dict[str, Any]] = None,
+         ) -> Tuple[torch.Tensor, torch.Tensor]:
+             """
+             Updates the cache with the new `key_states`
+             and `value_states` for the layer `layer_idx`.
+             Parameters:
+                 key_states (`torch.Tensor`):
+                     The new key states to cache.
+                 value_states (`torch.Tensor`):
+                     The new value states to cache.
+                 layer_idx (`int`):
+                     The index of the layer to cache the states for.
+                 cache_kwargs (`Dict[str, Any]`, `optional`):
+                     Additional arguments for the cache subclass.
+                     No additional arguments are used in `DynamicCache`.
+             Return:
+                 A tuple containing the updated key and value states.
+             """
+             # Update the number of seen tokens
+             if layer_idx == 0:
+                 if hasattr(self, "_seen_tokens"):
+                     self._seen_tokens += key_states.shape[-2]
+
+             # Update the cache
+             if key_states is not None:
+                 if len(self.key_cache) <= layer_idx:
+                     # There may be skipped layers, fill them with empty lists
+                     for _ in range(len(self.key_cache), layer_idx):
+                         self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
+                         self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
+                     self.key_cache.append(key_states)
+                     self.value_cache.append(value_states)
+                 elif not self.key_cache[
+                     layer_idx
+                 ].numel():  # prefer not t.numel() over len(t) == 0 to export the model
+                     # fill previously skipped layers; checking for a tensor causes errors
+                     self.key_cache[layer_idx] = key_states
+                     self.value_cache[layer_idx] = value_states
+                 else:
+                     torch._check(
+                         len(self.key_cache[layer_idx].shape) == len(key_states.shape),
+                         lambda: (
+                             f"Rank mismatch len(self.key_cache[layer_idx].shape)="
+                             f"{len(self.key_cache[layer_idx].shape)}, "
+                             f"len(key_states.shape)={len(key_states.shape)}"
+                         ),
+                     )
+                     self.key_cache[layer_idx] = torch.cat(
+                         [self.key_cache[layer_idx], key_states], dim=-2
+                     )
+                     self.value_cache[layer_idx] = torch.cat(
+                         [self.value_cache[layer_idx], value_states], dim=-2
+                     )
+             return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+         def crop(self, max_length: int):
+             """Crop the past key values up to a new `max_length`
+             in terms of tokens. `max_length` can also be
+             negative to remove `max_length` tokens.
+             This is used in assisted decoding and contrastive search.
+             """
+             # In case it is negative
+             if max_length < 0:
+                 max_length = self.get_seq_length() - abs(max_length)
+
+             if self.get_seq_length() <= max_length:
+                 return
+
+             if hasattr(self, "_seen_tokens"):
+                 self._seen_tokens = max_length
+             for idx in range(len(self.key_cache)):
+                 if self.key_cache[idx].numel():
+                     self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
+                     self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
+
+         @classmethod
+         def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache:
+             """This is the opposite of the above `batch_split()` method.
+             This will be used by `stack_model_outputs` in
+             `generation.utils`"""
+             cache = cls()
+             for idx in range(len(splits[0])):
+                 key_cache = [
+                     current.key_cache[idx]
+                     for current in splits
+                     if current.key_cache[idx].numel()
+                 ]
+                 value_cache = [
+                     current.value_cache[idx]
+                     for current in splits
+                     if current.value_cache[idx].numel()
+                 ]
+                 if key_cache != []:
+                     layer_keys = torch.cat(key_cache, dim=0)
+                     layer_values = torch.cat(value_cache, dim=0)
+                     cache.update(layer_keys, layer_values, idx)
+             return cache
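
The patched `update` keeps `key_cache`/`value_cache` as lists of tensors and fills skipped layers with empty tensors, which is what lets `torch.export` trace through the `not t.numel()` checks. A minimal sketch of the behavior (assuming transformers < 4.51, so that `patched_DynamicCache` above is defined; `_FakeCache` is a hypothetical stand-in, not the real class):

import torch

class _FakeCache:
    # Hypothetical stand-in with the attributes the patched methods expect;
    # not transformers.cache_utils.DynamicCache.
    def __init__(self):
        self.key_cache, self.value_cache = [], []
        self._seen_tokens = 0

    update = patched_DynamicCache.update
    get_seq_length = patched_DynamicCache.get_seq_length

cache = _FakeCache()
k = torch.randn(1, 8, 4, 16)       # (batch, num_heads, seq_len, head_dim)
cache.update(k, k, layer_idx=2)    # layers 0 and 1 are skipped ...
print(cache.key_cache[0].numel())  # ... and filled with empty tensors -> 0
cache.update(k, k, layer_idx=2)    # a second call concatenates on dim=-2
print(cache.get_seq_length(2))     # 8
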
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py
@@ -0,0 +1,54 @@
+ import torch
+ import transformers
+
+ try:
+     from transformers.models.gemma3.modeling_gemma3 import Gemma3Model  # noqa: F401
+
+     patch_gemma3 = True
+ except ImportError:
+     patch_gemma3 = False
+
+
+ if patch_gemma3:
+
+     class patched_Gemma3Model(torch.nn.Module):
+         _PATCHES_ = ["get_placeholder_mask"]
+         _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3Model
+         _PATCHED_PR_ = "https://github.com/huggingface/transformers/pull/41319"
+
+         def get_placeholder_mask(
+             self,
+             input_ids: torch.LongTensor,
+             inputs_embeds: torch.FloatTensor,
+             image_features: torch.FloatTensor,
+         ):
+             if input_ids is None:
+                 special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                     torch.tensor(
+                         self.config.image_token_id,
+                         dtype=torch.long,
+                         device=inputs_embeds.device,
+                     )
+                 )
+                 special_image_mask = special_image_mask.all(-1)
+             else:
+                 special_image_mask = input_ids == self.config.image_token_id
+
+             n_image_tokens = special_image_mask.sum()
+             special_image_mask = (
+                 special_image_mask.unsqueeze(-1)
+                 .expand_as(inputs_embeds)
+                 .to(inputs_embeds.device)
+             )
+             n_image_features = image_features.shape[0] * image_features.shape[1]
+             # PATCHED: torch._check
+             # if inputs_embeds[special_image_mask].numel() != image_features.numel():
+             #     raise ValueError( ... )
+             torch._check(
+                 inputs_embeds[special_image_mask].numel() == image_features.numel(),
+                 lambda: (
+                     f"Image features and image tokens do not match: tokens: "
+                     f"{n_image_tokens}, features {n_image_features}"
+                 ),
+             )
+             return special_image_mask
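
All three files above rely on the same export-friendly pattern: a data-dependent `if ...: raise ValueError(...)` is replaced by `torch._check(condition, lambda: message)`, which `torch.export` records as a deferred runtime assertion instead of trying to trace a Python branch. Distilled from the Gemma3 patch (illustrative only, not diff content):

import torch

def check_image_tokens(inputs_embeds, special_image_mask, image_features):
    # Original transformers code (breaks export on data-dependent values):
    #   if inputs_embeds[special_image_mask].numel() != image_features.numel():
    #       raise ValueError("Image features and image tokens do not match")
    # Patched pattern: register the condition as a runtime assertion instead.
    torch._check(
        inputs_embeds[special_image_mask].numel() == image_features.numel(),
        lambda: "Image features and image tokens do not match",
    )
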