onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +412 -12
  3. onnx_diagnostic/export/api.py +111 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +12 -7
  7. onnx_diagnostic/export/onnx_plug.py +531 -0
  8. onnx_diagnostic/ext_test_case.py +163 -48
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +222 -0
  11. onnx_diagnostic/helpers/helper.py +108 -37
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +531 -6
  15. onnx_diagnostic/helpers/ort_session.py +45 -19
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +131 -8
  18. onnx_diagnostic/reference/ort_evaluator.py +228 -46
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/code_sample.py +2 -1
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  38. onnx_diagnostic/torch_models/validate.py +64 -2
  39. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  40. onnx_diagnostic/torch_onnx/sbs.py +969 -312
  41. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
  44. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
  45. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  46. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py
@@ -0,0 +1,177 @@
+from typing import List, Optional, Tuple
+import packaging.version as pv
+import torch
+import transformers
+from .patch_helper import _has_transformers
+
+patch_is_initialized = _has_transformers("4.56.99")
+patch_DynamicCache = pv.Version(transformers.__version__) < pv.Version("4.51")
+
+try:
+    # transformers >= 4.55.1
+    from transformers.cache_utils import DynamicLayer
+
+    patch_DynamicLayer = hasattr(DynamicLayer, "lazy_initialization")
+except ImportError:
+    patch_DynamicLayer = False
+
+
+if patch_DynamicLayer:
+
+    class patched_DynamicLayer:
+        _PATCHES_ = ["lazy_initialization"]
+        _PATCHED_CLASS_ = DynamicLayer
+
+        def lazy_initialization(self, key_states: torch.Tensor):
+            self.dtype, self.device = key_states.dtype, key_states.device
+            new_shape = list(key_states.shape)
+            new_shape[-2] = 0
+            # PATCHED: use a tensor with an empty shape and not an empty list to initialize
+            self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
+            self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
+            if patch_is_initialized:
+                self.is_initialized = True
+
+
+if patch_DynamicCache:
+    from typing import Any, Dict
+    from transformers.cache_utils import DynamicCache
+
+    class patched_DynamicCache:
+        """
+        Applies modifications implemented in PR
+        `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
+        """
+
+        _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
+        _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
+
+        def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+            """Returns the sequence length of the cached states.
+            A layer index can be optionally passed."""
+            # TODO: deprecate this function in favor of `cache_position`
+            is_empty_layer = (
+                len(self.key_cache) == 0  # no cache in any layer
+                or len(self.key_cache)
+                <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
+                or self.key_cache[layer_idx].numel() == 0  # the layer has no cache
+            )
+            layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+            return layer_seq_length
+
+        def reorder_cache(self, beam_idx: torch.LongTensor):
+            """Reorders the cache for beam search, given the selected beam indices."""
+            for layer_idx in range(len(self.key_cache)):
+                if self.key_cache[layer_idx].numel():
+                    device = self.key_cache[layer_idx].device
+                    self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
+                        0, beam_idx.to(device)
+                    )
+                if self.value_cache[layer_idx].numel():
+                    device = self.value_cache[layer_idx].device
+                    self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
+                        0, beam_idx.to(device)
+                    )
+
+        def update(
+            self,
+            key_states: torch.Tensor,
+            value_states: torch.Tensor,
+            layer_idx: int,
+            cache_kwargs: Optional[Dict[str, Any]] = None,
+        ) -> Tuple[torch.Tensor, torch.Tensor]:
+            """
+            Updates the cache with the new `key_states`
+            and `value_states` for the layer `layer_idx`.
+            Parameters:
+                key_states (`torch.Tensor`):
+                    The new key states to cache.
+                value_states (`torch.Tensor`):
+                    The new value states to cache.
+                layer_idx (`int`):
+                    The index of the layer to cache the states for.
+                cache_kwargs (`Dict[str, Any]`, `optional`):
+                    Additional arguments for the cache subclass.
+                    No additional arguments are used in `DynamicCache`.
+            Return:
+                A tuple containing the updated key and value states.
+            """
+            # Update the number of seen tokens
+            if layer_idx == 0:
+                if hasattr(self, "_seen_tokens"):
+                    self._seen_tokens += key_states.shape[-2]
+
+            # Update the cache
+            if key_states is not None:
+                if len(self.key_cache) <= layer_idx:
+                    # There may be skipped layers, fill them with empty tensors
+                    for _ in range(len(self.key_cache), layer_idx):
+                        self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
+                        self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
+                    self.key_cache.append(key_states)
+                    self.value_cache.append(value_states)
+                elif not self.key_cache[
+                    layer_idx
+                ].numel():  # prefer not t.numel() over len(t) == 0 so the model can be exported
+                    # fills previously skipped layers; checking for tensor causes errors
+                    self.key_cache[layer_idx] = key_states
+                    self.value_cache[layer_idx] = value_states
+                else:
+                    torch._check(
+                        len(self.key_cache[layer_idx].shape) == len(key_states.shape),
+                        lambda: (
+                            f"Rank mismatch len(self.key_cache[layer_idx].shape)="
+                            f"{len(self.key_cache[layer_idx].shape)}, "
+                            f"len(key_states.shape)={len(key_states.shape)}"
+                        ),
+                    )
+                    self.key_cache[layer_idx] = torch.cat(
+                        [self.key_cache[layer_idx], key_states], dim=-2
+                    )
+                    self.value_cache[layer_idx] = torch.cat(
+                        [self.value_cache[layer_idx], value_states], dim=-2
+                    )
+            return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+        def crop(self, max_length: int):
+            """Crop the past key values up to a new `max_length`
+            in terms of tokens. `max_length` can also be
+            negative to remove `max_length` tokens.
+            This is used in assisted decoding and contrastive search.
+            """
+            # In case it is negative
+            if max_length < 0:
+                max_length = self.get_seq_length() - abs(max_length)
+
+            if self.get_seq_length() <= max_length:
+                return
+
+            if hasattr(self, "_seen_tokens"):
+                self._seen_tokens = max_length
+            for idx in range(len(self.key_cache)):
+                if self.key_cache[idx].numel():
+                    self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
+                    self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
+
+        @classmethod
+        def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache:
+            """This is the opposite of the above `batch_split()` method.
+            This will be used by `stack_model_outputs` in
+            `generation.utils`"""
+            cache = cls()
+            for idx in range(len(splits[0])):
+                key_cache = [
+                    current.key_cache[idx]
+                    for current in splits
+                    if current.key_cache[idx].numel()
+                ]
+                value_cache = [
+                    current.value_cache[idx]
+                    for current in splits
+                    if current.value_cache[idx].numel()
+                ]
+                if key_cache != []:
+                    layer_keys = torch.cat(key_cache, dim=0)
+                    layer_values = torch.cat(value_cache, dim=0)
+                    cache.update(layer_keys, layer_values, idx)
+            return cache
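
Both patch classes above follow the convention used across onnx_diagnostic/torch_export_patches: _PATCHES_ lists the method names to override and _PATCHED_CLASS_ points at the transformers class being patched, and the modules under torch_export_patches install and remove these overrides around the export call. As a rough, hypothetical sketch of that idea (not the package's actual implementation), applying and reverting such a patch class could look like this:

# Hypothetical sketch only: apply the methods listed in _PATCHES_ onto
# _PATCHED_CLASS_ and restore the originals afterwards.
from contextlib import contextmanager


@contextmanager
def apply_patch(patch_cls):
    target = patch_cls._PATCHED_CLASS_
    # remember what the target class defines itself (None if the method is inherited)
    saved = {name: target.__dict__.get(name) for name in patch_cls._PATCHES_}
    try:
        for name in patch_cls._PATCHES_:
            # go through __dict__ so classmethod/staticmethod descriptors stay intact
            setattr(target, name, patch_cls.__dict__[name])
        yield target
    finally:
        for name, original in saved.items():
            if original is None:
                delattr(target, name)  # method was inherited: drop the override
            else:
                setattr(target, name, original)


# usage sketch: export while DynamicLayer.lazy_initialization is patched
# with apply_patch(patched_DynamicLayer):
#     ep = torch.export.export(model, args, dynamic_shapes=dynamic_shapes)
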
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py
@@ -0,0 +1,54 @@
+import torch
+import transformers
+
+try:
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3Model  # noqa: F401
+
+    patch_gemma3 = True
+except ImportError:
+    patch_gemma3 = False
+
+
+if patch_gemma3:
+
+    class patched_Gemma3Model(torch.nn.Module):
+        _PATCHES_ = ["get_placeholder_mask"]
+        _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3Model
+        _PATCHED_PR_ = "https://github.com/huggingface/transformers/pull/41319"
+
+        def get_placeholder_mask(
+            self,
+            input_ids: torch.LongTensor,
+            inputs_embeds: torch.FloatTensor,
+            image_features: torch.FloatTensor,
+        ):
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(
+                        self.config.image_token_id,
+                        dtype=torch.long,
+                        device=inputs_embeds.device,
+                    )
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = special_image_mask.sum()
+            special_image_mask = (
+                special_image_mask.unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            n_image_features = image_features.shape[0] * image_features.shape[1]
+            # PATCHED: replaced the data-dependent ValueError below with torch._check
+            # if inputs_embeds[special_image_mask].numel() != image_features.numel():
+            #     raise ValueError( ... )
+            torch._check(
+                inputs_embeds[special_image_mask].numel() == image_features.numel(),
+                lambda: (
+                    f"Image features and image tokens do not match: tokens: "
+                    f"{n_image_tokens}, features {n_image_features}"
+                ),
+            )
+            return special_image_mask
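
The Gemma3 patch shows the recurring pattern in these new patch files: a data-dependent "if ...: raise ValueError(...)" is replaced with torch._check, which records the condition as a runtime assertion instead of forcing torch.export to guard on a value it cannot resolve symbolically. A minimal, self-contained sketch of the same pattern (assuming a recent PyTorch where boolean-mask indexing yields unbacked symbolic sizes under torch.export; the module and names here are illustrative, not taken from the package):

import torch


class MaskedSum(torch.nn.Module):
    def forward(self, x, mask):
        selected = x[mask]  # how many rows survive is data-dependent
        # A plain `if selected.numel() != x.shape[-1]: raise ValueError(...)` would
        # guard on a data-dependent value and usually break torch.export.
        # torch._check keeps the condition as a runtime assertion instead.
        torch._check(
            selected.numel() == x.shape[-1],
            lambda: "the mask must select exactly one row",
        )
        return selected.sum()


ep = torch.export.export(
    MaskedSum(), (torch.randn(3, 4), torch.tensor([False, True, False]))
)

With the original if/raise formulation, export typically fails with a data-dependent guard error; with torch._check the condition stays in the exported program as an assertion checked at run time.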