onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (46)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +412 -12
  3. onnx_diagnostic/export/api.py +111 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +12 -7
  7. onnx_diagnostic/export/onnx_plug.py +531 -0
  8. onnx_diagnostic/ext_test_case.py +163 -48
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +222 -0
  11. onnx_diagnostic/helpers/helper.py +108 -37
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +531 -6
  15. onnx_diagnostic/helpers/ort_session.py +45 -19
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +131 -8
  18. onnx_diagnostic/reference/ort_evaluator.py +228 -46
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/code_sample.py +2 -1
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  38. onnx_diagnostic/torch_models/validate.py +64 -2
  39. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  40. onnx_diagnostic/torch_onnx/sbs.py +969 -312
  41. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
  44. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
  45. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  46. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
onnx_diagnostic/tasks/summarization.py
@@ -1,23 +1,16 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
 from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
-from ..helpers.config_helper import (
-    update_config,
-    check_hasattr,
-    _pick,
-    default_num_hidden_layers as nhl,
-)
+from ..helpers.config_helper import update_config, check_hasattr

 __TASK__ = "summarization"


 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
-    kwargs: Dict[str, Any] = {}
-    if hasattr(config, "num_decoder_layers"):
-        config.num_decoder_layers = min(config.num_decoder_layers, 2)
-    if hasattr(config, "num_hidden_layers"):
-        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
+    check_hasattr(config, "vocab_size")
+    # Bart architecture does not like too much that the number of layers is changed.
+    kwargs = dict(vocab_size=2056)
     update_config(config, kwargs)
     return kwargs

@@ -25,96 +18,66 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
 def get_inputs(
     model: torch.nn.Module,
     config: Optional[Any],
+    batch_size: int,
+    sequence_length: int,
     dummy_max_token_id: int,
-    num_key_value_heads_encoder: int,
-    num_key_value_heads_decoder: int,
-    num_hidden_layers: int,
-    head_dim_encoder: int,
-    head_dim_decoder: int,
-    batch_size: int = 2,
-    sequence_length: int = 30,
-    sequence_length2: int = 3,
+    past_length: int = 30,
+    past_length2: int = 4,
+    decoder_attention_heads: Optional[int] = None,
+    encoder_attention_heads: Optional[int] = None,
+    encoder_ffn_dim: Optional[int] = None,
+    decoder_ffn_dim: Optional[int] = None,
+    num_hidden_layers: Optional[int] = None,
     add_second_input: int = 1,
     **kwargs,  # unused
 ):
     """
-    Generates input for task ``summarization``.
-
-    :param model: model to get the missing information
-    :param config: configuration used to generate the model
-    :param head_dim_encoder: last dimension of the cache for the encoder
-    :param head_dim_decoder: last dimension of the cache for the decoder
-    :param num_key_value_heads_encoder: number of heads for the encoder
-    :param num_key_value_heads_decoder: number of heads for the decoder
-    :param dummy_max_token_id: dummy max token id
-    :param batch_size: batch size
-    :param sequence_length: sequence length
-    :param sequence_length2: new sequence length
-    :return: dictionary
-
-    Stolen inputs for one model.
+    Generates inputs for task ``feature-extraction``.
+    Example:

     ::

-        cache_position:T7s1
-        past_key_values:EncoderDecoderCache(
-            self_attention_cache=DynamicCache(
-                key_cache=#6[T1s1x8x1x64,...],
-                value_cache=#6[T1s1x8x1x64,...]),
-            cross_attention_cache=DynamicCache(
-                key_cache=#6[T1s1x8x16x64,...],
-                value_cache=#6[T1s1x8x16x64,...])),
-        decoder_input_ids:T7s1x1,
-        encoder_outputs:dict(last_hidden_state:T1s1x16x512)
+        input_ids:T7s1x13[101,72654:A16789.23076923077],
+        token_type_ids:T7s1x13[0,0:A0.0],
+        attention_mask:T7s1x13[1,1:A1.0])
     """
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = "batch"
-    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
-    cache_length = "cache_length_key"  # torch.export.Dim("cache_length", min=1, max=4096)
-    cache_length2 = "cache_length_val"  # torch.export.Dim("cache_length2", min=1, max=4096)
-
+    seq_length = "sequence_length"
     shapes = {
         "input_ids": {0: batch, 1: seq_length},
-        "decoder_input_ids": {0: batch, 1: "seq_ids"},
-        "attention_mask": {0: batch, 1: "seq_mask"},
-        # "cache_position": {0: batch, 1: torch.export.Dim.DYNAMIC},
-        "past_key_values": [
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)],
-            [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)],
-        ],
-        # one these is selected based on the forward method signature
-        # "encoder_last_hidden_state": {0: batch, 1: torch.export.Dim.DYNAMIC},
-        # "encoder_outputs": {0: batch, 1: torch.export.Dim.DYNAMIC},
+        "attention_mask": {0: batch, 1: seq_length},
     }
-
     inputs = dict(
         input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
             torch.int64
         ),
-        decoder_input_ids=torch.randint(
-            0, dummy_max_token_id, (batch_size, sequence_length2)
-        ).to(torch.int64),
         attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
-        # cache_position=torch.arange(sequence_length, sequence_length + sequence_length2)
-        # .to(torch.int64)
-        # .expand((batch_size, -1)),
-        past_key_values=make_encoder_decoder_cache(
+    )
+    if (
+        encoder_attention_heads
+        and decoder_attention_heads
+        and encoder_ffn_dim
+        and decoder_ffn_dim
+        and num_hidden_layers
+    ):
+        inputs["past_key_values"] = make_encoder_decoder_cache(
             make_dynamic_cache(
                 [
                     (
                         torch.randn(
                             batch_size,
-                            num_key_value_heads_encoder,
-                            sequence_length,
-                            head_dim_encoder,
+                            encoder_attention_heads,
+                            past_length,
+                            encoder_ffn_dim,
                         ),
                         torch.randn(
                             batch_size,
-                            num_key_value_heads_encoder,
-                            sequence_length,
-                            head_dim_encoder,
+                            encoder_attention_heads,
+                            past_length,
+                            encoder_ffn_dim,
                         ),
                     )
                     for i in range(num_hidden_layers)
@@ -125,22 +88,28 @@ def get_inputs(
                     (
                         torch.randn(
                             batch_size,
-                            num_key_value_heads_decoder,
-                            sequence_length2,
-                            head_dim_decoder,
+                            decoder_attention_heads,
+                            past_length2,
+                            decoder_ffn_dim,
                         ),
                         torch.randn(
                             batch_size,
-                            num_key_value_heads_decoder,
-                            sequence_length2,
-                            head_dim_decoder,
+                            decoder_attention_heads,
+                            past_length2,
+                            decoder_ffn_dim,
                         ),
                     )
                     for i in range(num_hidden_layers)
                 ]
             ),
-        ),
-    )
+        )
+        cache_length = "cache_length_key"
+        cache_length2 = "cache_length_val"
+        shapes["past_key_values"] = [  # type: ignore[assignment]
+            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)],
+            [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)],
+        ]
+
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
         assert (
@@ -149,15 +118,16 @@ def get_inputs(
         res["inputs2"] = get_inputs(
             model=model,
             config=config,
-            dummy_max_token_id=dummy_max_token_id,
-            num_key_value_heads_encoder=num_key_value_heads_encoder,
-            num_key_value_heads_decoder=num_key_value_heads_decoder,
-            num_hidden_layers=num_hidden_layers,
-            head_dim_encoder=head_dim_encoder,
-            head_dim_decoder=head_dim_decoder,
             batch_size=batch_size + 1,
             sequence_length=sequence_length + add_second_input,
-            sequence_length2=sequence_length2 + 1,
+            dummy_max_token_id=dummy_max_token_id,
+            past_length=past_length,
+            past_length2=past_length2,
+            decoder_attention_heads=decoder_attention_heads,
+            encoder_attention_heads=encoder_attention_heads,
+            encoder_ffn_dim=encoder_ffn_dim,
+            decoder_ffn_dim=decoder_ffn_dim,
+            num_hidden_layers=num_hidden_layers,
             add_second_input=0,
             **kwargs,
         )["inputs"]
@@ -171,57 +141,22 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     If the configuration is None, the function selects typical dimensions.
     """
     if config is not None:
-        check_hasattr(
-            config,
-            "vocab_size",
-            "hidden_size",
-            "num_attention_heads",
-            ("num_hidden_layers", "num_layers"),
-            ("n_positions", "d_model"),
-            (
-                "num_key_value_heads",
-                "num_heads",
-                ("decoder_attention_heads", "encoder_attention_heads"),
-            ),
-        )
-        # exceptions = {
-        #     "PLBartForConditionalGeneration": (
-        #         lambda c: c.encoder_attention_heads + c.decoder_attention_heads
-        #     )
-        # }
+        check_hasattr(config, "vocab_size")
     kwargs = dict(
         batch_size=2,
-        sequence_length=30,
-        sequence_length2=3,
-        head_dim_encoder=(
-            16 if config is None else int(_pick(config, "encoder_ffn_dim") ** 0.5)
-        ),
-        head_dim_decoder=(
-            16 if config is None else int(_pick(config, "decoder_ffn_dim") ** 0.5)
-        ),
-        dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
-        num_hidden_layers=(
-            8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
-        ),
-        num_key_value_heads_encoder=(
-            16
-            if config is None
-            else _pick(
-                config,
-                "encoder_attention_heads",
-                "num_key_value_heads",
-                "num_heads",
-            )
-        ),
-        num_key_value_heads_decoder=(
-            16
-            if config is None
-            else _pick(
-                config,
-                "decoder_attention_heads",
-                "num_key_value_heads",
-                "num_heads",
-            )
-        ),
+        sequence_length=12,
+        past_length=30,
+        past_length2=4,
+        dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
     )
+    for att in [
+        "decoder_attention_heads",
+        "encoder_attention_heads",
+        "encoder_ffn_dim",
+        "decoder_ffn_dim",
+        "num_hidden_layers",
+    ]:
+        if hasattr(config, att):
+            kwargs[att] = getattr(config, att)
+    kwargs["decoder_ffn_dim"] = kwargs["encoder_ffn_dim"] = 64
     return kwargs, get_inputs
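
For orientation, here is a minimal sketch of how the reworked helper could be invoked directly. Only the function name and parameter names come from the diff above; the call pattern and the concrete values are assumptions for illustration (in the package these helpers are normally driven by the task and validation machinery):

    from onnx_diagnostic.tasks.summarization import get_inputs

    # Build dummy encoder/decoder inputs, including a past_key_values cache,
    # because all cache dimensions are provided explicitly.
    data = get_inputs(
        model=None,            # stand-in: the model is not inspected in this code path
        config=None,
        batch_size=2,
        sequence_length=12,
        dummy_max_token_id=31999,
        past_length=30,
        past_length2=4,
        encoder_attention_heads=8,
        decoder_attention_heads=8,
        encoder_ffn_dim=64,
        decoder_ffn_dim=64,
        num_hidden_layers=2,
        add_second_input=0,    # skip the second input set
    )
    print(data["inputs"]["input_ids"].shape)  # torch.Size([2, 12])
    print(list(data["dynamic_shapes"]))       # ['input_ids', 'attention_mask', 'past_key_values']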
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py
@@ -0,0 +1,236 @@
+from typing import Optional
+import torch
+import transformers
+from .patch_helper import _has_transformers
+
+patch_sdpa_is_causal = _has_transformers("4.99")
+
+
+def common_eager_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        # PATCHED
+        # The two following lines were added.
+        if attention_mask is not None and attention_mask.ndim == 4:
+            attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask
+
+    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+
+    if head_mask is not None:
+        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+    attn_weights = torch.nn.functional.dropout(
+        attn_weights, p=dropout, training=module.training
+    )
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
+def patched_sdpa_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    is_causal: Optional[bool] = None,
+    **kwargs,
+) -> tuple[torch.Tensor, None]:
+    """
+    manual patch for function
+    ``transformers.integrations.sdpa_attention.sdpa_attention_forward``
+    """
+    assert not kwargs.get("output_attentions", False), (
+        "`sdpa` attention does not support `output_attentions=True`."
+        " Please set your attention to `eager` if you want any of these features."
+    )
+    torch._check(
+        query.shape[0] == key.shape[0] or query.shape[0] == 1,
+        lambda: (
+            f"broadcast issue query (1): {query.shape}, key: {key.shape}, "
+            f"value: {value.shape}"
+        ),
+    )
+    torch._check(
+        key.shape[0] == value.shape[0] or key.shape[0] == 1,
+        lambda: (
+            f"broadcast issue query (2): {query.shape}, key: {key.shape}, "
+            f"value: {value.shape}"
+        ),
+    )
+
+    sdpa_kwargs = {}
+    if hasattr(module, "num_key_value_groups"):
+        if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
+            key = transformers.integrations.sdpa_attention.repeat_kv(
+                key, module.num_key_value_groups
+            )
+            value = transformers.integrations.sdpa_attention.repeat_kv(
+                value, module.num_key_value_groups
+            )
+        else:
+            sdpa_kwargs = {"enable_gqa": True}
+
+    if attention_mask is not None and attention_mask.ndim == 4:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+
+    torch._check(
+        attention_mask is None or attention_mask.shape[3] == key.shape[2],
+        lambda: "Attention mask shape incompatible with key shape.",
+    )
+
+    if patch_sdpa_is_causal:
+        # transformers>=4.55
+        is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
+
+        # PATCHED: remove the test query.shape[2] > 1
+        # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
+        # and we split the test to keep the minimum in torch.cond
+        is_causal = attention_mask is None and is_causal
+
+        if not is_causal:
+            torch._check(query.shape[0] > 0)
+            torch._check(query.shape[1] > 0)
+            torch._check(query.shape[2] > 0)
+            torch._check(query.shape[3] > 0)
+            torch._check(key.shape[0] > 0)
+            torch._check(key.shape[1] > 0)
+            torch._check(key.shape[2] > 0)
+            torch._check(key.shape[3] > 0)
+            torch._check(value.shape[0] > 0)
+            torch._check(value.shape[1] > 0)
+            torch._check(value.shape[2] > 0)
+            torch._check(value.shape[3] > 0)
+
+        return (
+            torch.nn.functional.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                attn_mask=attention_mask,
+                dropout_p=dropout,
+                scale=scaling,
+                is_causal=is_causal,
+                **sdpa_kwargs,
+            )
+            .transpose(1, 2)
+            .contiguous(),
+            None,
+        )
+    else:
+        # transformers<4.55
+        if is_causal is None and attention_mask is not None:
+            is_causal = False
+        if is_causal is not None:
+            return (
+                torch.nn.functional.scaled_dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    attn_mask=attention_mask,
+                    dropout_p=dropout,
+                    scale=scaling,
+                    is_causal=is_causal,
+                    **sdpa_kwargs,
+                )
+                .transpose(1, 2)
+                .contiguous(),
+                None,
+            )
+
+        # To avoid the following errors:
+        # is_causal=query.shape[2] > 1
+        # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool
+        # is_causal=torch.tensor(query.shape[2] > 1)
+        # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor
+        attn_output = torch.cond(
+            query.shape[2] > 1,  # distinction between prefill and decoding steps
+            lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                dropout_p=dropout,
+                scale=scaling,
+                is_causal=True,
+                **sdpa_kwargs,
+            ).contiguous(),
+            lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                dropout_p=dropout,
+                scale=scaling,
+                is_causal=False,
+                **sdpa_kwargs,
+            ).contiguous(),
+            [query, key, value],
+        )
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        return attn_output, None
+
+
+def patched_model_bart_eager_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
+    return common_eager_attention_forward(
+        module,
+        query,
+        key,
+        value,
+        attention_mask=attention_mask,
+        scaling=scaling,
+        dropout=dropout,
+        head_mask=head_mask,
+        **kwargs,
+    )
+
+
+def patched_modeling_marian_eager_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    """[patch:transformers.models.marian.modeling_marian.eager_attention_forward]"""
+    return common_eager_attention_forward(
+        module,
+        query,
+        key,
+        value,
+        attention_mask=attention_mask,
+        scaling=scaling,
+        dropout=dropout,
+        head_mask=head_mask,
+        **kwargs,
+    )
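
A rough, self-contained illustration of the patched SDPA entry point above (a sketch, not part of the package; the bare ``torch.nn.Module()`` stand-in and the tensor sizes are made up, and the import path is inferred from the file list):

    import torch
    from onnx_diagnostic.torch_export_patches.patches._patch_transformers_attention import (
        patched_sdpa_attention_forward,
    )

    module = torch.nn.Module()    # stand-in: no num_key_value_groups, so no GQA handling
    q = torch.randn(2, 8, 4, 64)  # (batch, num_heads, seq_len, head_dim)
    k = torch.randn(2, 8, 4, 64)
    v = torch.randn(2, 8, 4, 64)
    out, weights = patched_sdpa_attention_forward(
        module, q, k, v, attention_mask=None, dropout=0.0, scaling=None, is_causal=False
    )
    print(out.shape)  # torch.Size([2, 4, 8, 64]), transposed back to (batch, seq_len, num_heads, head_dim)
    print(weights)    # None: the SDPA path never returns attention weights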
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py
@@ -0,0 +1,50 @@
+from typing import Optional
+import inspect
+import transformers
+
+try:
+    from transformers.cache_utils import parse_processor_args  # noqa: F401
+
+    patch_parse_processor_args = True
+except ImportError:
+    patch_parse_processor_args = False
+
+
+if patch_parse_processor_args:
+
+    def _init_cache_inspect():
+        res = {}
+        for processor_class in transformers.cache_utils.PROCESSOR_CLASS_MAP.values():
+            try:
+                params = list(inspect.signature(processor_class.__init__).parameters)[2:]
+                res[processor_class.__init__] = params
+            except Exception:
+                res[processor_class.__init__] = None
+        return res
+
+    _cache_inspect = _init_cache_inspect()
+
+    def patched_parse_processor_args(
+        processor_class: Optional[type["CacheProcessor"]], kwargs: dict  # noqa: F821
+    ) -> tuple[dict, dict]:
+        """[patch:transformers.cache_utils.parse_processor_args]"""
+        # If not patched...
+        # Fails with transformers>=4.54 because function ``parse_processor_args``
+        # relies in inspect and the exporter is not very fond of that.
+        # torch._dynamo.exc.Unsupported: id() with unsupported args
+        # Explanation: Dynamo doesn't know how to trace id()
+        # call with args
+        # (GetAttrVariable(ConstantVariable(NoneType: None), __init__),)
+        # Hint: Supported args are Tensors, and functions/nn.Modules/user-defined
+        # objects from outside the compiled region.
+        # Hint: It may be possible to write Dynamo tracing rules for this code.
+        #
+        # The patch is caching the signature to avoid any call to inspect.
+        if processor_class is None:
+            return {}, kwargs
+        params = _cache_inspect[processor_class.__init__]
+        if params is None:
+            return {}, kwargs
+        processor_kwargs = {k: kwargs[k] for k in params if k in kwargs}
+        remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs}
+        return processor_kwargs, remaining_kwargs
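
The idea of the patch is to run ``inspect`` once at import time and reuse the cached parameter list inside the traced code. A generic sketch of the same idea with a toy class (every name below is invented for illustration; it is not the transformers API):

    import inspect

    class ToyProcessor:
        def __init__(self, cache, window_size=4, dtype="float16"):
            self.window_size = window_size
            self.dtype = dtype

    # Computed once at import time, so torch.export/dynamo never traces inspect().
    _TOY_PARAMS = list(inspect.signature(ToyProcessor.__init__).parameters)[2:]  # skip self, cache

    def split_kwargs(kwargs: dict) -> tuple[dict, dict]:
        taken = {k: kwargs[k] for k in _TOY_PARAMS if k in kwargs}
        rest = {k: v for k, v in kwargs.items() if k not in taken}
        return taken, rest

    print(split_kwargs({"window_size": 8, "max_cache_len": 128}))
    # ({'window_size': 8}, {'max_cache_len': 128})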
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py
@@ -0,0 +1,89 @@
+from dataclasses import dataclass
+from typing import Optional
+import torch
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from .patch_helper import _has_transformers
+
+
+def _patch_make_causal_mask(
+    input_ids_shape: torch.Size,
+    dtype: torch.dtype,
+    device: torch.device,
+    past_key_values_length: int = 0,
+    sliding_window: Optional[int] = None,
+):
+    """Patched method."""
+    bsz, tgt_len = input_ids_shape
+    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat(
+            [
+                torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device),
+                mask,
+            ],
+            dim=-1,
+        )
+
+    if sliding_window is not None:
+        diagonal = past_key_values_length - sliding_window - 1
+
+        context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
+        # PATCHED: removed if is_torchdynamo_compiling(): mask = mask.clone()
+        # and used masked_fill instead of masked_fill_
+        # In this case, the current implementation of torch fails (17/12/2024).
+        # Try model Phi-3.5-Mini-Instruct.
+        mask = mask.masked_fill(context_mask, torch.finfo(dtype).min)
+
+    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+@dataclass
+class patched_AttentionMaskConverter:
+    """
+    Patches
+    ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
+    """
+
+    # This method was fixed in 4.51 at least.
+    _PATCHES_ = ["_make_causal_mask"] if not _has_transformers("4.48.3") else []
+    _PATCHED_CLASS_ = AttentionMaskConverter
+
+    @staticmethod
+    def _make_causal_mask(
+        *args,
+        **kwargs,
+        # input_ids_shape: torch.Size,
+        # dtype: torch.dtype,
+        # device: torch.device,
+        # past_key_values_length: int = 0,
+        # sliding_window: Optional[int] = None,
+    ):
+        """
+        Patched method.
+
+        This static method may be called with ``AttentionMaskConverter._make_causal_mask``
+        or ``self._make_causal_mask``. That changes this argument is receives.
+        That should not matter but...
+        The patch should be implemented in another way. static methods do not play well
+        with a simple replacement.
+        Fortunately, this patch does not seem to be needed anymore with transformers>=4.48.3.
+        """
+        if args:
+            index = 0 if isinstance(args[0], (tuple, torch.Size)) else 1
+            names = [
+                "input_ids_shape",
+                "dtype",
+                "device",
+                "past_key_values_length",
+                "sliding_window",
+            ]
+            for i, a in enumerate(args):
+                if i < index:
+                    continue
+                kwargs[names[i - index]] = a
+        return _patch_make_causal_mask(**kwargs)
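
A small sanity check of the patched mask builder (a sketch under assumptions, not taken from the package tests; the import path is inferred from the file list above). For a query of length 3 with 2 cached tokens, the returned mask has shape (batch, 1, tgt_len, tgt_len + past_key_values_length) and is 0 where attention is allowed:

    import torch
    from onnx_diagnostic.torch_export_patches.patches._patch_transformers_causal_mask import (
        _patch_make_causal_mask,
    )

    mask = _patch_make_causal_mask(
        torch.Size([1, 3]),
        dtype=torch.float32,
        device=torch.device("cpu"),
        past_key_values_length=2,
    )
    print(mask.shape)               # torch.Size([1, 1, 3, 5])
    print((mask == 0).int()[0, 0])
    # tensor([[1, 1, 1, 0, 0],
    #         [1, 1, 1, 1, 0],
    #         [1, 1, 1, 1, 1]], dtype=torch.int32)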