onnx-diagnostic 0.8.3__py3-none-any.whl → 0.8.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +47 -10
  3. onnx_diagnostic/export/api.py +81 -50
  4. onnx_diagnostic/export/control_flow_research.py +10 -5
  5. onnx_diagnostic/export/onnx_plug.py +250 -61
  6. onnx_diagnostic/ext_test_case.py +99 -53
  7. onnx_diagnostic/helpers/dot_helper.py +37 -25
  8. onnx_diagnostic/helpers/helper.py +44 -38
  9. onnx_diagnostic/helpers/onnx_helper.py +441 -18
  10. onnx_diagnostic/helpers/ort_session.py +8 -8
  11. onnx_diagnostic/helpers/torch_helper.py +28 -2
  12. onnx_diagnostic/reference/ort_evaluator.py +6 -29
  13. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +1 -0
  14. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +10 -1
  15. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +168 -113
  16. onnx_diagnostic/torch_models/code_sample.py +2 -1
  17. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  18. onnx_diagnostic/torch_models/validate.py +14 -1
  19. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  20. onnx_diagnostic/torch_onnx/sbs.py +11 -5
  21. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +48 -4
  22. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/METADATA +1 -1
  23. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/RECORD +26 -26
  24. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/WHEEL +0 -0
  25. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/licenses/LICENSE.txt +0 -0
  26. {onnx_diagnostic-0.8.3.dist-info → onnx_diagnostic-0.8.5.dist-info}/top_level.txt +0 -0

onnx_diagnostic/reference/ort_evaluator.py
@@ -18,10 +18,11 @@ from onnx.defs import onnx_opset_version
 import onnxruntime
 from ..helpers import string_type
 from ..helpers.onnx_helper import (
-    pretty_onnx,
+    get_hidden_inputs,
     dtype_to_tensor_dtype,
-    to_array_extended,
     np_dtype_to_tensor_dtype,
+    to_array_extended,
+    pretty_onnx,
 )
 from ..helpers.torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
 from ..helpers.ort_session import (
@@ -472,39 +473,15 @@ class OnnxruntimeEvaluator:
                     yield from self.enumerate_nodes(att.g.node)
         yield node

-    @classmethod
-    def _get_hidden_inputs(cls, graph: GraphProto) -> Set[str]:
-        """
-        Returns the hidden inputs (inputs coming from an upper context)
-        used by a subgraph.
-        """
-        hidden = set()
-        memo = (
-            {i.name for i in graph.initializer}
-            | {i.name for i in graph.sparse_initializer}
-            | {i.name for i in graph.input}
-        )
-        for node in graph.node:
-            for i in node.input:
-                if i not in memo:
-                    hidden.add(i)
-            for att in node.attribute:
-                if att.type == AttributeProto.GRAPH and att.g:
-                    hid = cls._get_hidden_inputs(att.g)
-                    less = set(h for h in hid if h not in memo)
-                    hidden |= less
-            memo |= set(node.output)
-        return hidden
-
     @classmethod
     def _get_hidden_node_inputs(cls, node: NodeProto) -> Set[str]:
-        """Calls multiple _get_hidden_inputs on every attribute."""
+        """Calls multiple get_hidden_inputs on every attribute."""
         if node.op_type not in {"Loop", "Scan", "If"}:
             return set()
         hidden = set()
         for att in node.attribute:
             if att.type == AttributeProto.GRAPH:
-                hidden |= cls._get_hidden_inputs(att.g)
+                hidden |= get_hidden_inputs(att.g)
         return hidden - (hidden & set(node.input))

     def _get_sess(
@@ -624,7 +601,7 @@ class OnnxruntimeEvaluator:
             value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
             vinputs.append(value)

-        reduced_set = self._get_hidden_inputs(g)
+        reduced_set = get_hidden_inputs(g)
         for i, v in context.items():
             if i in reduced_set and i not in unique_names:
                 unique_names.add(i)
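
The _get_hidden_inputs classmethod was removed from OnnxruntimeEvaluator; the evaluator now imports the shared get_hidden_inputs helper from onnx_diagnostic.helpers.onnx_helper. A minimal usage sketch on a subgraph (the one-node If-branch graph below is illustrative, not taken from the package):

    import onnx
    import onnx.helper as oh
    from onnx_diagnostic.helpers.onnx_helper import get_hidden_inputs

    # "outer" is neither an input, an initializer, nor produced by a previous node,
    # so it comes from the enclosing graph and is reported as a hidden input.
    branch = oh.make_graph(
        [oh.make_node("Add", ["X", "outer"], ["Y"])],
        "then_branch",
        [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1])],
        [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1])],
    )
    print(get_hidden_inputs(branch))  # expected: {'outer'}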

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py
@@ -118,6 +118,7 @@ def patched_sdpa_attention_forward(
     torch._check(value.shape[1] > 0)
     torch._check(value.shape[2] > 0)
     torch._check(value.shape[3] > 0)
+
     return (
         torch.nn.functional.scaled_dot_product_attention(
             query,

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py
@@ -1,3 +1,4 @@
+import inspect
 from typing import Callable, List, Optional, Tuple
 import torch

@@ -19,6 +20,12 @@ if patch_masking_utils:
         prepare_padding_mask,
     )

+    _prepare_padding_mask_kwargs = (
+        dict(_slice=False)
+        if "_slice" in inspect.signature(prepare_padding_mask).parameters
+        else {}
+    )
+
     try:
         # transformers>=5.0
         from transformers.masking_utils import (
@@ -132,7 +139,9 @@ if patch_masking_utils:
     ) -> Optional[torch.Tensor]:
         """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
         q_length = cache_position.shape[0]
-        padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
+        padding_mask = prepare_padding_mask(
+            attention_mask, kv_length, kv_offset, **_prepare_padding_mask_kwargs
+        )
         if allow_is_causal_skip and _ignore_causal_mask_sdpa(
             padding_mask, q_length, kv_length, kv_offset, local_size
         ):
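
The _slice keyword of prepare_padding_mask only exists in some transformers versions, so the patch now probes the signature with inspect before forwarding it. A generic sketch of the same feature-detection pattern (the helper and the two stand-in functions below are hypothetical):

    import inspect

    def call_with_supported_kwargs(fn, *args, **maybe_kwargs):
        # Forward only the keyword arguments the callee actually declares.
        accepted = inspect.signature(fn).parameters
        return fn(*args, **{k: v for k, v in maybe_kwargs.items() if k in accepted})

    def old_prepare(mask, kv_length):               # older API: no _slice parameter
        return mask, kv_length

    def new_prepare(mask, kv_length, _slice=True):  # newer API: _slice exists
        return mask, kv_length, _slice

    print(call_with_supported_kwargs(old_prepare, "m", 8, _slice=False))
    print(call_with_supported_kwargs(new_prepare, "m", 8, _slice=False))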

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py
@@ -1,9 +1,11 @@
 import os
-from typing import Callable, Optional
+from typing import Callable, Optional, Tuple
 import onnx
+import onnx.helper as oh
 import torch
 import torch.nn.functional as F
 from ...export.onnx_plug import EagerDirectReplacementWithOnnx
+from ...helpers.torch_helper import torch_dtype_to_onnx_dtype
 from .patch_helper import _is_torchdynamo_exporting
 from ._patch_transformers_attention import patched_sdpa_attention_forward

@@ -22,7 +24,43 @@ if patch_qwen2_5:
     onnx_plugs_op = onnxscript.values.Opset("onnx_plug", 1)
     op = onnxscript.opset22
+    op23 = onnxscript.onnx_opset.opset23
     msft_op = onnxscript.values.Opset("com.microsoft", 1)
+    STOPAT = (
+        int(os.environ.get("STOPAT", None))
+        if os.environ.get("STOPAT", None) is not None
+        else None
+    )
+
+    def _add_com_microsoft_opset(function_proto: onnx.FunctionProto) -> onnx.FunctionProto:
+        opsets = {d.domain: d.version for d in function_proto.opset_import}
+        if "com.microsoft" not in opsets:
+            d = function_proto.opset_import.add()
+            d.domain = "com.microsoft"
+            d.version = 1
+        return function_proto
+
+    def _update_sequence_type(
+        itype: int, function_proto: onnx.FunctionProto
+    ) -> onnx.FunctionProto:
+        proto = oh.make_function(
+            function_proto.domain,
+            function_proto.name,
+            function_proto.input,
+            function_proto.output,
+            [
+                (
+                    oh.make_node("SequenceEmpty", node.input, node.output, dtype=itype)
+                    if node.op_type == "SequenceEmpty"
+                    else node
+                )
+                for node in function_proto.node
+            ],
+            attributes=function_proto.attribute,
+            attribute_protos=function_proto.attribute_proto,
+            opset_imports=function_proto.opset_import,
+        )
+        return proto

     @onnxscript.script(opset=onnx_plugs_op)
     def LoopMHAAttention(
@@ -32,7 +70,6 @@ if patch_qwen2_5:
         cu_seqlens,
         scaling: float = 0.11180339887498948,
         num_heads: int = 16,
-        itype: int = onnx.TensorProto.FLOAT,
     ):
         to_3d_shape = op.Constant(value_ints=[0, 0, -1])
         query_transposed = op.Transpose(query_states, perm=[0, 2, 1, 3])
@@ -45,7 +82,7 @@ if patch_qwen2_5:
         seq_axis = op.Constant(value_ints=[1])
         seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32)
         # attn_output = op.Slice(value_3d, [0], [0], seq_axis)
-        seq_attn = op.SequenceEmpty(dtype=itype)
+        seq_attn = op.SequenceEmpty(dtype=onnx.TensorProto.FLOAT)
         for i_patch in range(num_patches):
             i_1d = op.Reshape(i_patch, [1])
             i_plus_1_1d = i_1d + 1
@@ -55,11 +92,7 @@ if patch_qwen2_5:
             key_i = op.Slice(key_3d, start, end, seq_axis_int32)
             value_i = op.Slice(value_3d, start, end, seq_axis_int32)
             mha_output = msft_op.MultiHeadAttention(
-                query_i,
-                key_i,
-                value_i,
-                num_heads=num_heads,
-                scale=scaling,
+                query_i, key_i, value_i, num_heads=num_heads, scale=scaling
             )
             # attn_output = op.Concat(attn_output, mha_output, axis=1)
             seq_attn = op.SequenceInsert(seq_attn, mha_output)
@@ -67,13 +100,47 @@ if patch_qwen2_5:
         attn_output_4d = op.Reshape(attn_output, output_shape)
         return attn_output_4d

-    def _add_com_microsoft_opset(function_proto):
-        opsets = {d.domain: d.version for d in function_proto.opset_import}
-        if "com.microsoft" not in opsets:
-            d = function_proto.opset_import.add()
-            d.domain = "com.microsoft"
-            d.version = 1
-        return function_proto
+    @onnxscript.script(opset=onnx_plugs_op)
+    def LoopAttention23(
+        query_states,
+        key_states,
+        value_states,
+        cu_seqlens,
+        scaling: float = 0.11180339887498948,
+        num_heads: int = 16,
+    ):
+        to_3d_shape = op23.Constant(value_ints=[0, 0, -1])
+        query_transposed = op23.Transpose(query_states, perm=[0, 2, 1, 3])
+        output_shape = op23.Shape(query_transposed)
+        query_3d = op23.Reshape(query_transposed, to_3d_shape)
+        value_3d = op23.Reshape(op23.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape)
+        key_3d = op23.Reshape(op23.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape)
+        cu_seqlens = op23.Cast(cu_seqlens, to=onnx.TensorProto.INT32)
+        num_patches = op23.Size(cu_seqlens) - 1
+        seq_axis = op23.Constant(value_ints=[1])
+        seq_axis_int32 = op23.Cast(seq_axis, to=onnx.TensorProto.INT32)
+        seq_attn = op23.SequenceEmpty(dtype=onnx.TensorProto.FLOAT)
+        for i_patch in range(num_patches):
+            i_1d = op23.Reshape(i_patch, [1])
+            i_plus_1_1d = i_1d + 1
+            start = op23.Gather(cu_seqlens, i_1d, axis=0)
+            end = op23.Gather(cu_seqlens, i_plus_1_1d, axis=0)
+            query_i = op23.Slice(query_3d, start, end, seq_axis_int32)
+            key_i = op23.Slice(key_3d, start, end, seq_axis_int32)
+            value_i = op23.Slice(value_3d, start, end, seq_axis_int32)
+            mha_output = op23.Attention(
+                query_i,
+                key_i,
+                value_i,
+                scale=scaling,
+                q_num_heads=num_heads,
+                kv_num_heads=num_heads,
+                softmax_precision=onnx.TensorProto.FLOAT,
+            )
+            seq_attn = op23.SequenceInsert(seq_attn, mha_output)
+        attn_output = op23.ConcatFromSequence(seq_attn, axis=1)
+        attn_output_4d = op23.Reshape(attn_output, output_shape)
+        return attn_output_4d

     @onnxscript.script(opset=onnx_plugs_op)
     def PackedAttention(
@@ -132,8 +199,40 @@ if patch_qwen2_5:
         cu_seqlens: torch.Tensor,  # F7su19
         scaling: float = 0,
         num_heads: int = 16,
-        itype: int = onnx.TensorProto.FLOAT,
     ) -> torch.Tensor:
+        """
+        The loop can be removed with the following code
+        but it hits memory overflow for big inputs.
+
+        .. code-block:: python
+
+            # make square mask
+            indices = torch.arange(
+                cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
+            )
+            dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(
+                cu_seqlens.dtype
+            )
+            dot = dot.sum(dim=0)
+            mask = dot.unsqueeze(1) - dot.unsqueeze(0)
+            bool_mask = mask == 0
+            bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)
+
+            torch._check(bool_mask.shape[2] == key_states.shape[2])
+            torch._check(bool_mask.shape[3] == key_states.shape[2])
+
+            attn_output, _ = attention_interface(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                attention_mask=bool_mask,
+                scaling=self.scaling,
+                dropout=0.0 if not self.training else self.attention_dropout,
+                is_causal=False,
+                **kwargs,
+            )
+        """
         lengths = cu_seqlens[1:] - cu_seqlens[:-1]
         splits = [
             torch.split(tensor, lengths.tolist(), dim=2)
@@ -156,36 +255,58 @@ if patch_qwen2_5:
         attn_output = torch.cat(attn_outputs, dim=1)
         return attn_output

-    # not ideal
-    qwen_sdpa_attention_packed_versatile = EagerDirectReplacementWithOnnx(
-        qwen_sdpa_attention,
-        lambda qs, *args, **kwargs: torch.empty(
-            (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
-            dtype=qs.dtype,
-            device=qs.device,
-        ),
-        _add_com_microsoft_opset(PackedAttention.to_function_proto()),
-        n_inputs=4,
-        n_outputs=1,
-        kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT),
-        name="qwen_sdpa_attention_packed",
-    )
-    PLUGS.append(qwen_sdpa_attention_packed_versatile)
+    def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
+        first_tensor = next(a for a in args if a is not None)
+        dtype = first_tensor.dtype
+        strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
+        itype = torch_dtype_to_onnx_dtype(dtype)
+        if strategy is not None:
+            return strategy, itype
+        if dtype == torch.float32 or itype == onnx.TensorProto.FLOAT:
+            if opset >= 23:
+                return "LOOPA23", itype
+            return "LOOPMHA", itype
+        if dtype == torch.float16 or itype == onnx.TensorProto.FLOAT16:
+            # first_tensor may be a SymbolicTensor (onnx).
+            # is_cuda is not available.
+            if hasattr(first_tensor, "is_cuda") and first_tensor.is_cuda:
+                return "PACKED", itype
+            return "LOOPMHA", itype
+        raise AssertionError(
+            f"Unable to handle type {torch.dtype} (itype={itype}) "
+            f"on device {torch.device} with opset={opset}"
+        )

-    qwen_sdpa_attention_loopmha_versatile = EagerDirectReplacementWithOnnx(
+    qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
         qwen_sdpa_attention,
         lambda qs, *args, **kwargs: torch.empty(
             (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
             dtype=qs.dtype,
             device=qs.device,
         ),
-        _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
+        {
+            ("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
+                PackedAttention.to_function_proto()
+            ),
+            ("LOOPA23", onnx.TensorProto.FLOAT): LoopAttention23.to_function_proto(),
+            ("LOOPA23", onnx.TensorProto.FLOAT16): _update_sequence_type(
+                onnx.TensorProto.FLOAT16, LoopAttention23.to_function_proto()
+            ),
+            ("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
+                LoopMHAAttention.to_function_proto()
+            ),
+            ("LOOPMHA", onnx.TensorProto.FLOAT16): _update_sequence_type(
+                onnx.TensorProto.FLOAT16,
+                _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
+            ),
+        },
         n_inputs=4,
         n_outputs=1,
-        kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT),
-        name="qwen_sdpa_attention_loopmha",
+        kwargs=dict(scaling=0.11180339887498948, num_heads=16),
+        name="qwen_sdpa_attention_versatile",
+        version_selector=qwen_version_selector,
     )
-    PLUGS.append(qwen_sdpa_attention_loopmha_versatile)
+    PLUGS.append(qwen_sdpa_attention_versatile)

     class patched_Qwen2_5_VLForConditionalGeneration:
         _PATCHES_ = ["prepare_inputs_for_generation"]
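
The two fixed plugs from 0.8.3 (qwen_sdpa_attention_packed and qwen_sdpa_attention_loopmha) are merged into a single qwen_sdpa_attention_versatile plug: version_selector maps the target opset, the tensor dtype and the optional QWEN25ATTENTION override to one of the registered FunctionProto variants. A self-contained sketch of the default selection rules only (the function name and the is_cuda flag are hypothetical; the environment override is ignored here):

    import torch

    def select_strategy(opset: int, dtype: torch.dtype, is_cuda: bool) -> str:
        # float32: prefer the opset-23 Attention loop when the target opset allows it.
        if dtype == torch.float32:
            return "LOOPA23" if opset >= 23 else "LOOPMHA"
        # float16: PackedAttention (com.microsoft) on CUDA, otherwise the MHA loop.
        if dtype == torch.float16:
            return "PACKED" if is_cuda else "LOOPMHA"
        raise AssertionError(f"unsupported dtype {dtype}")

    assert select_strategy(23, torch.float32, False) == "LOOPA23"
    assert select_strategy(22, torch.float32, False) == "LOOPMHA"
    assert select_strategy(23, torch.float16, True) == "PACKED"
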
@@ -434,6 +555,8 @@ if patch_qwen2_5:
                     position_embeddings=position_embeddings,
                     **kwargs,
                 )
+                if STOPAT is not None and layer_num > STOPAT:
+                    break

             hidden_states = self.merger(hidden_states)
             reverse_indices = torch.argsort(window_index)
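
STOPAT is parsed once at import time; when set, the patched vision forward breaks out of the loop over blocks early, which helps bisect numerical differences layer by layer. A minimal sketch:

    import os

    # Must be set before the patch module is imported, since STOPAT is read at import time.
    os.environ["STOPAT"] = "3"  # break out of the block loop once layer_num > 3
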
@@ -473,9 +596,7 @@ if patch_qwen2_5:
         _PATCHED_CLASS_ = (
             transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLVisionAttention
         )
-        STRATEGY_FOR_ATTENTION = lambda: os.environ.get(  # noqa: E731
-            "QWEN25ATTENTION", "PACKED"
-        )
+        STRATEGY_FOR_ATTENTION = lambda: os.environ.get("QWEN25ATTENTION", None)  # noqa: E731

         def forward(
             self,
@@ -519,14 +640,15 @@ if patch_qwen2_5:
                 self.config._attn_implementation
             ]

-            is_sdpa = (
+            is_sdpa_or_eager = (
                 attention_interface
                 is transformers.integrations.sdpa_attention.sdpa_attention_forward
                 or attention_interface is patched_sdpa_attention_forward
+                or attention_interface
+                is transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.eager_attention_forward
             )
-            attention_strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
-            if is_sdpa and attention_strategy in "PACKED":
-                attn_output = qwen_sdpa_attention_packed_versatile(
+            if is_sdpa_or_eager:
+                attn_output = qwen_sdpa_attention_versatile(
                     query_states,
                     key_states,
                     value_states,
@@ -558,78 +680,10 @@ if patch_qwen2_5:
                     ),
                     version=1,
                 )
-            elif is_sdpa and attention_strategy == "LOOPMHA":
-                attn_output = qwen_sdpa_attention_loopmha_versatile(
-                    query_states,
-                    key_states,
-                    value_states,
-                    cu_seqlens,
-                    self.scaling,
-                    self.num_heads,
-                    (
-                        onnx.TensorProto.FLOAT
-                        if query_states.dtype == torch.float32
-                        else (
-                            onnx.TensorProto.FLOAT16
-                            if query_states.dtype == torch.float16
-                            else onnx.TensorProto.BFLOAT16
-                        )
-                    ),
-                )
-
-                # to rewrite later with a for loop
-                # def _iteration(start_end, query_states, key_states, value_states):
-                #     return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
-                #         self,
-                #         start_end,
-                #         query_states,
-                #         key_states,
-                #         value_states,
-                #         scaling=self.scaling,
-                #         dropout=0.0 if not self.training else self.attention_dropout,
-                #     )
-
-                # starts = cu_seqlens[:-1]
-                # ends = cu_seqlens[1:]
-                # torch._check(starts.shape[0] > 0)
-                # torch._check(ends.shape[0] > 0)
-                # starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
-                # attn_outputs = [
-                #     _iteration(start_end, query_states, key_states, value_states)
-                #     for start_end in starts_ends
-                # ]
-                # attn_output = torch.cat(attn_outputs, dim=1)
-            elif is_sdpa and attention_strategy == "BIGMASK":
-                # make square mask
-                indices = torch.arange(
-                    cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
-                )
-                dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(
-                    cu_seqlens.dtype
-                )
-                dot = dot.sum(dim=0)
-                mask = dot.unsqueeze(1) - dot.unsqueeze(0)
-                bool_mask = mask == 0
-                bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)
-
-                torch._check(bool_mask.shape[2] == key_states.shape[2])
-                torch._check(bool_mask.shape[3] == key_states.shape[2])
-
-                attn_output, _ = attention_interface(
-                    self,
-                    query_states,
-                    key_states,
-                    value_states,
-                    attention_mask=bool_mask,
-                    scaling=self.scaling,
-                    dropout=0.0 if not self.training else self.attention_dropout,
-                    is_causal=False,
-                    **kwargs,
-                )
             else:
                 raise NotImplementedError(
-                    f"No corresponding export strategy for "
-                    f"{attention_strategy!r}, "
+                    f"No corresponding export strategy for implementation "
+                    f"{self.config._attn_implementation!r}, "
                     f"(use QWEN25ATTENTION to change it), and attention_interface="
                     f"{attention_interface!r} (use sdpa)"
                 )
@@ -653,6 +707,7 @@ if patch_qwen2_5:
                 )
             else:
                 # Other implementations: Process each chunk separately
+                # = qwen_sdpa_attention
                 lengths = cu_seqlens[1:] - cu_seqlens[:-1]
                 splits = [
                     torch.split(tensor, lengths.tolist(), dim=2)

onnx_diagnostic/torch_models/code_sample.py
@@ -236,7 +236,7 @@ def code_sample(
             )
         )
     """
-    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+    model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
         model_id,
         subfolder,
         same_as_pretrained=same_as_pretrained,
@@ -256,6 +256,7 @@ def code_sample(
         model_kwargs=mop,
         subfolder=subfolder,
         add_second_input=False,
+        submodule=submodule,
     )
     if drop_inputs:
         update = {}

onnx_diagnostic/torch_models/hghub/model_inputs.py
@@ -26,17 +26,26 @@ def _code_needing_rewriting(model: Any) -> Any:


 def _preprocess_model_id(
-    model_id: str, subfolder: Optional[str], same_as_pretrained: bool, use_pretrained: bool
-) -> Tuple[str, Optional[str], bool, bool]:
+    model_id: str,
+    subfolder: Optional[str],
+    same_as_pretrained: bool,
+    use_pretrained: bool,
+    submodule: Optional[str] = None,
+) -> Tuple[str, Optional[str], bool, bool, Optional[str]]:
+    if "::" in model_id:
+        assert (
+            not submodule
+        ), f"submodule={submodule!r} cannot be defined in model_id={model_id!r} as well"
+        model_id, submodule = model_id.split("::", maxsplit=1)
     if subfolder or "//" not in model_id:
-        return model_id, subfolder, same_as_pretrained, use_pretrained
+        return model_id, subfolder, same_as_pretrained, use_pretrained, submodule
     spl = model_id.split("//")
     if spl[-1] == "pretrained":
-        return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
+        return _preprocess_model_id("//".join(spl[:-1]), "", True, True, submodule)
     if spl[-1] in {"transformer", "vae"}:
         # known subfolder
-        return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
-    return model_id, subfolder, same_as_pretrained, use_pretrained
+        return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained, submodule
+    return model_id, subfolder, same_as_pretrained, use_pretrained, submodule


 def get_untrained_model_with_inputs(
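
A model id can now carry a submodule path after "::"; _preprocess_model_id splits it off and returns it as a fifth element. A usage sketch of the private helper (the model id below is hypothetical):

    from onnx_diagnostic.torch_models.hghub.model_inputs import _preprocess_model_id

    print(
        _preprocess_model_id(
            "Qwen/Qwen2.5-VL-3B-Instruct::visual",  # hypothetical id with a submodule suffix
            subfolder=None,
            same_as_pretrained=False,
            use_pretrained=False,
        )
    )
    # expected: ('Qwen/Qwen2.5-VL-3B-Instruct', None, False, False, 'visual')
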
@@ -54,6 +63,7 @@ def get_untrained_model_with_inputs(
     subfolder: Optional[str] = None,
     use_only_preinstalled: bool = False,
     config_reduction: Optional[Callable[[Any, str], Dict]] = None,
+    submodule: Optional[str] = None,
 ) -> Dict[str, Any]:
     """
     Gets a non initialized model similar to the original model
@@ -82,6 +92,7 @@ def get_untrained_model_with_inputs(
         <onnx_diagnostic.torch_models.hghub.reduce_model_config>`,
         this function takes a configuration and a task (string)
         as arguments
+    :param submodule: use a submodule instead of the main model
     :return: dictionary with a model, inputs, dynamic shapes, and the configuration,
         some necessary rewriting as well

@@ -108,11 +119,12 @@ def get_untrained_model_with_inputs(
         f"model_id={model_id!r}, preinstalled model is only available "
         f"if use_only_preinstalled is False."
     )
-    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+    model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
         model_id,
         subfolder,
         same_as_pretrained=same_as_pretrained,
         use_pretrained=use_pretrained,
+        submodule=submodule,
     )
     if verbose:
         print(
@@ -147,6 +159,8 @@ def get_untrained_model_with_inputs(
     if verbose:
         print(f"[get_untrained_model_with_inputs] architecture={arch!r}")
         print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
+        if submodule:
+            print(f"[get_untrained_model_with_inputs] submodule={submodule!r}")
     if task is None:
         task = task_from_arch(arch, model_id=model_id, subfolder=subfolder)
     if verbose:
@@ -357,6 +371,19 @@ def get_untrained_model_with_inputs(
     if diff_config is not None:
         res["dump_info"] = dict(config_diff=diff_config)

+    if submodule:
+        path = submodule.split("::") if "::" in submodule else [submodule]
+        for p in path:
+            assert hasattr(model, p), (
+                f"Unable to find submodule {p!r} in in class {type(model)}, "
+                f"submodule={submodule!r}, possible candidates: "
+                f"{[k for k in dir(model) if isinstance(getattr(model, k), torch.nn.Module)]}"
+            )
+            model = getattr(model, p)
+
+    if verbose:
+        print(f"[get_untrained_model_with_inputs] model class={model.__class__.__name__!r}")
+
     sizes = compute_model_size(model)
     res["model"] = model
     res["configuration"] = config