onnx-diagnostic 0.8.3__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,10 +18,11 @@ from onnx.defs import onnx_opset_version
 import onnxruntime
 from ..helpers import string_type
 from ..helpers.onnx_helper import (
-    pretty_onnx,
+    get_hidden_inputs,
     dtype_to_tensor_dtype,
-    to_array_extended,
     np_dtype_to_tensor_dtype,
+    to_array_extended,
+    pretty_onnx,
 )
 from ..helpers.torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
 from ..helpers.ort_session import (
@@ -472,39 +473,15 @@ class OnnxruntimeEvaluator:
                     yield from self.enumerate_nodes(att.g.node)
             yield node

-    @classmethod
-    def _get_hidden_inputs(cls, graph: GraphProto) -> Set[str]:
-        """
-        Returns the hidden inputs (inputs coming from an upper context)
-        used by a subgraph.
-        """
-        hidden = set()
-        memo = (
-            {i.name for i in graph.initializer}
-            | {i.name for i in graph.sparse_initializer}
-            | {i.name for i in graph.input}
-        )
-        for node in graph.node:
-            for i in node.input:
-                if i not in memo:
-                    hidden.add(i)
-            for att in node.attribute:
-                if att.type == AttributeProto.GRAPH and att.g:
-                    hid = cls._get_hidden_inputs(att.g)
-                    less = set(h for h in hid if h not in memo)
-                    hidden |= less
-            memo |= set(node.output)
-        return hidden
-
     @classmethod
     def _get_hidden_node_inputs(cls, node: NodeProto) -> Set[str]:
-        """Calls multiple _get_hidden_inputs on every attribute."""
+        """Calls multiple get_hidden_inputs on every attribute."""
         if node.op_type not in {"Loop", "Scan", "If"}:
             return set()
         hidden = set()
         for att in node.attribute:
             if att.type == AttributeProto.GRAPH:
-                hidden |= cls._get_hidden_inputs(att.g)
+                hidden |= get_hidden_inputs(att.g)
         return hidden - (hidden & set(node.input))

     def _get_sess(
@@ -624,7 +601,7 @@ class OnnxruntimeEvaluator:
             value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
             vinputs.append(value)

-        reduced_set = self._get_hidden_inputs(g)
+        reduced_set = get_hidden_inputs(g)
        for i, v in context.items():
            if i in reduced_set and i not in unique_names:
                unique_names.add(i)
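The three hunks above replace the private classmethod _get_hidden_inputs with the shared helper get_hidden_inputs imported from ..helpers.onnx_helper. A minimal sketch of the helper's behaviour, assuming it keeps the removed classmethod's contract (a GraphProto in, the set of names read from the outer scope out):

.. code-block:: python

    import onnx
    import onnx.helper as oh
    from onnx_diagnostic.helpers.onnx_helper import get_hidden_inputs

    # A subgraph (e.g. an If/Loop body) that reads "outer" from the enclosing
    # graph: "outer" is neither a graph input, an initializer, nor the output
    # of a previous node, so it is a hidden input.
    body = oh.make_graph(
        [oh.make_node("Add", ["x", "outer"], ["y"])],
        "body",
        [oh.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [1])],
        [oh.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1])],
    )
    print(get_hidden_inputs(body))  # expected: {'outer'}
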
@@ -118,6 +118,7 @@ def patched_sdpa_attention_forward(
     torch._check(value.shape[1] > 0)
     torch._check(value.shape[2] > 0)
     torch._check(value.shape[3] > 0)
+
     return (
         torch.nn.functional.scaled_dot_product_attention(
             query,
@@ -1,9 +1,11 @@
 import os
-from typing import Callable, Optional
+from typing import Callable, Optional, Tuple
 import onnx
+import onnx.helper as oh
 import torch
 import torch.nn.functional as F
 from ...export.onnx_plug import EagerDirectReplacementWithOnnx
+from ...helpers.torch_helper import torch_dtype_to_onnx_dtype
 from .patch_helper import _is_torchdynamo_exporting
 from ._patch_transformers_attention import patched_sdpa_attention_forward

@@ -22,7 +24,43 @@ if patch_qwen2_5:

     onnx_plugs_op = onnxscript.values.Opset("onnx_plug", 1)
     op = onnxscript.opset22
+    op24 = onnxscript.onnx_opset.opset24
     msft_op = onnxscript.values.Opset("com.microsoft", 1)
+    STOPAT = (
+        int(os.environ.get("STOPAT", None))
+        if os.environ.get("STOPAT", None) is not None
+        else None
+    )
+
+    def _add_com_microsoft_opset(function_proto: onnx.FunctionProto) -> onnx.FunctionProto:
+        opsets = {d.domain: d.version for d in function_proto.opset_import}
+        if "com.microsoft" not in opsets:
+            d = function_proto.opset_import.add()
+            d.domain = "com.microsoft"
+            d.version = 1
+        return function_proto
+
+    def _update_sequence_type(
+        itype: int, function_proto: onnx.FunctionProto
+    ) -> onnx.FunctionProto:
+        proto = oh.make_function(
+            function_proto.domain,
+            function_proto.name,
+            function_proto.input,
+            function_proto.output,
+            [
+                (
+                    oh.make_node("SequenceEmpty", node.input, node.output, dtype=itype)
+                    if node.op_type == "SequenceEmpty"
+                    else node
+                )
+                for node in function_proto.node
+            ],
+            attributes=function_proto.attribute,
+            attribute_protos=function_proto.attribute_proto,
+            opset_imports=function_proto.opset_import,
+        )
+        return proto

     @onnxscript.script(opset=onnx_plugs_op)
     def LoopMHAAttention(
@@ -32,7 +70,6 @@ if patch_qwen2_5:
         cu_seqlens,
         scaling: float = 0.11180339887498948,
         num_heads: int = 16,
-        itype: int = onnx.TensorProto.FLOAT,
     ):
         to_3d_shape = op.Constant(value_ints=[0, 0, -1])
         query_transposed = op.Transpose(query_states, perm=[0, 2, 1, 3])
@@ -45,7 +82,7 @@ if patch_qwen2_5:
         seq_axis = op.Constant(value_ints=[1])
         seq_axis_int32 = op.Cast(seq_axis, to=onnx.TensorProto.INT32)
         # attn_output = op.Slice(value_3d, [0], [0], seq_axis)
-        seq_attn = op.SequenceEmpty(dtype=itype)
+        seq_attn = op.SequenceEmpty(dtype=onnx.TensorProto.FLOAT)
         for i_patch in range(num_patches):
             i_1d = op.Reshape(i_patch, [1])
             i_plus_1_1d = i_1d + 1
@@ -55,11 +92,7 @@ if patch_qwen2_5:
             key_i = op.Slice(key_3d, start, end, seq_axis_int32)
             value_i = op.Slice(value_3d, start, end, seq_axis_int32)
             mha_output = msft_op.MultiHeadAttention(
-                query_i,
-                key_i,
-                value_i,
-                num_heads=num_heads,
-                scale=scaling,
+                query_i, key_i, value_i, num_heads=num_heads, scale=scaling
             )
             # attn_output = op.Concat(attn_output, mha_output, axis=1)
             seq_attn = op.SequenceInsert(seq_attn, mha_output)
@@ -67,13 +100,47 @@ if patch_qwen2_5:
         attn_output_4d = op.Reshape(attn_output, output_shape)
         return attn_output_4d

-    def _add_com_microsoft_opset(function_proto):
-        opsets = {d.domain: d.version for d in function_proto.opset_import}
-        if "com.microsoft" not in opsets:
-            d = function_proto.opset_import.add()
-            d.domain = "com.microsoft"
-            d.version = 1
-        return function_proto
+    @onnxscript.script(opset=onnx_plugs_op)
+    def LoopAttention24(
+        query_states,
+        key_states,
+        value_states,
+        cu_seqlens,
+        scaling: float = 0.11180339887498948,
+        num_heads: int = 16,
+    ):
+        to_3d_shape = op24.Constant(value_ints=[0, 0, -1])
+        query_transposed = op24.Transpose(query_states, perm=[0, 2, 1, 3])
+        output_shape = op24.Shape(query_transposed)
+        query_3d = op24.Reshape(query_transposed, to_3d_shape)
+        value_3d = op24.Reshape(op24.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape)
+        key_3d = op24.Reshape(op24.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape)
+        cu_seqlens = op24.Cast(cu_seqlens, to=onnx.TensorProto.INT32)
+        num_patches = op24.Size(cu_seqlens) - 1
+        seq_axis = op24.Constant(value_ints=[1])
+        seq_axis_int32 = op24.Cast(seq_axis, to=onnx.TensorProto.INT32)
+        seq_attn = op24.SequenceEmpty(dtype=onnx.TensorProto.FLOAT)
+        for i_patch in range(num_patches):
+            i_1d = op24.Reshape(i_patch, [1])
+            i_plus_1_1d = i_1d + 1
+            start = op24.Gather(cu_seqlens, i_1d, axis=0)
+            end = op24.Gather(cu_seqlens, i_plus_1_1d, axis=0)
+            query_i = op24.Slice(query_3d, start, end, seq_axis_int32)
+            key_i = op24.Slice(key_3d, start, end, seq_axis_int32)
+            value_i = op24.Slice(value_3d, start, end, seq_axis_int32)
+            mha_output = op24.Attention(
+                query_i,
+                key_i,
+                value_i,
+                scale=scaling,
+                q_num_heads=num_heads,
+                kv_num_heads=num_heads,
+                softmax_precision=onnx.TensorProto.FLOAT,
+            )
+            seq_attn = op24.SequenceInsert(seq_attn, mha_output)
+        attn_output = op24.ConcatFromSequence(seq_attn, axis=1)
+        attn_output_4d = op24.Reshape(attn_output, output_shape)
+        return attn_output_4d

     @onnxscript.script(opset=onnx_plugs_op)
     def PackedAttention(
@@ -132,8 +199,40 @@ if patch_qwen2_5:
         cu_seqlens: torch.Tensor,  # F7su19
         scaling: float = 0,
         num_heads: int = 16,
-        itype: int = onnx.TensorProto.FLOAT,
     ) -> torch.Tensor:
+        """
+        The loop can be removed with the following code
+        but it hits memory overflow for big inputs.
+
+        .. code-block:: python
+
+            # make square mask
+            indices = torch.arange(
+                cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
+            )
+            dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(
+                cu_seqlens.dtype
+            )
+            dot = dot.sum(dim=0)
+            mask = dot.unsqueeze(1) - dot.unsqueeze(0)
+            bool_mask = mask == 0
+            bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)
+
+            torch._check(bool_mask.shape[2] == key_states.shape[2])
+            torch._check(bool_mask.shape[3] == key_states.shape[2])
+
+            attn_output, _ = attention_interface(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                attention_mask=bool_mask,
+                scaling=self.scaling,
+                dropout=0.0 if not self.training else self.attention_dropout,
+                is_causal=False,
+                **kwargs,
+            )
+        """
         lengths = cu_seqlens[1:] - cu_seqlens[:-1]
         splits = [
             torch.split(tensor, lengths.tolist(), dim=2)
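The docstring added above keeps the single-call masked variant only as documentation: the (1, 1, S, S) mask and the attention weights grow quadratically with the total number of patches S, which is why the per-patch loop stays. A rough, illustrative estimate (patch count and head count are hypothetical, and assume the SDPA math backend materializes the weights):

.. code-block:: python

    # Memory grows roughly as num_heads * S * S * 4 bytes of float32 attention
    # weights, plus S * S bytes of boolean mask.
    S, num_heads = 16_384, 16                      # hypothetical large image/video input
    mask_gib = S * S / 2**30                       # ~0.25 GiB of mask
    weights_gib = num_heads * S * S * 4 / 2**30    # ~16 GiB of attention weights
    print(mask_gib, weights_gib)
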
@@ -156,36 +255,58 @@ if patch_qwen2_5:
         attn_output = torch.cat(attn_outputs, dim=1)
         return attn_output

-    # not ideal
-    qwen_sdpa_attention_packed_versatile = EagerDirectReplacementWithOnnx(
-        qwen_sdpa_attention,
-        lambda qs, *args, **kwargs: torch.empty(
-            (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
-            dtype=qs.dtype,
-            device=qs.device,
-        ),
-        _add_com_microsoft_opset(PackedAttention.to_function_proto()),
-        n_inputs=4,
-        n_outputs=1,
-        kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT),
-        name="qwen_sdpa_attention_packed",
-    )
-    PLUGS.append(qwen_sdpa_attention_packed_versatile)
+    def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
+        first_tensor = next(a for a in args if a is not None)
+        dtype = first_tensor.dtype
+        strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
+        itype = torch_dtype_to_onnx_dtype(dtype)
+        if strategy is not None:
+            return strategy, itype
+        if dtype == torch.float32 or itype == onnx.TensorProto.FLOAT:
+            if opset >= 24:
+                return "LOOPA24", itype
+            return "LOOPMHA", itype
+        if dtype == torch.float16 or itype == onnx.TensorProto.FLOAT16:
+            # first_tensor may be a SymbolicTensor (onnx).
+            # is_cuda is not available.
+            if hasattr(first_tensor, "is_cuda") and first_tensor.is_cuda:
+                return "PACKED", itype
+            return "LOOPMHA", itype
+        raise AssertionError(
+            f"Unable to handle type {torch.dtype} (itype={itype}) "
+            f"on device {torch.device} with opset={opset}"
+        )

-    qwen_sdpa_attention_loopmha_versatile = EagerDirectReplacementWithOnnx(
+    qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
         qwen_sdpa_attention,
         lambda qs, *args, **kwargs: torch.empty(
             (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
             dtype=qs.dtype,
             device=qs.device,
         ),
-        _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
+        {
+            ("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
+                PackedAttention.to_function_proto()
+            ),
+            ("LOOPA24", onnx.TensorProto.FLOAT): LoopAttention24.to_function_proto(),
+            ("LOOPA24", onnx.TensorProto.FLOAT16): _update_sequence_type(
+                onnx.TensorProto.FLOAT16, LoopAttention24.to_function_proto()
+            ),
+            ("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
+                LoopMHAAttention.to_function_proto()
+            ),
+            ("LOOPMHA", onnx.TensorProto.FLOAT16): _update_sequence_type(
+                onnx.TensorProto.FLOAT16,
+                _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
+            ),
+        },
         n_inputs=4,
         n_outputs=1,
-        kwargs=dict(scaling=0.11180339887498948, num_heads=16, itype=onnx.TensorProto.FLOAT),
-        name="qwen_sdpa_attention_loopmha",
+        kwargs=dict(scaling=0.11180339887498948, num_heads=16),
+        name="qwen_sdpa_attention_versatile",
+        version_selector=qwen_version_selector,
     )
-    PLUGS.append(qwen_sdpa_attention_loopmha_versatile)
+    PLUGS.append(qwen_sdpa_attention_versatile)

     class patched_Qwen2_5_VLForConditionalGeneration:
         _PATCHES_ = ["prepare_inputs_for_generation"]
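The two fixed plugs (PACKED and LOOPMHA) are merged above into a single qwen_sdpa_attention_versatile plug: qwen_version_selector picks a key into the dictionary of FunctionProtos based on the element type, the target opset and, when defined, the QWEN25ATTENTION environment variable. A hedged sketch of forcing a strategy by hand:

.. code-block:: python

    import os

    # STRATEGY_FOR_ATTENTION reads this variable; when it is set the selector
    # returns it unchanged, otherwise float32 picks LOOPA24 (opset >= 24) or
    # LOOPMHA, and float16 picks PACKED on CUDA or LOOPMHA elsewhere.
    os.environ["QWEN25ATTENTION"] = "LOOPMHA"
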
@@ -434,6 +555,8 @@ if patch_qwen2_5:
                     position_embeddings=position_embeddings,
                     **kwargs,
                 )
+                if STOPAT is not None and layer_num > STOPAT:
+                    break

             hidden_states = self.merger(hidden_states)
             reverse_indices = torch.argsort(window_index)
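STOPAT, parsed from the environment at import time in an earlier hunk, gives the patched vision forward an early exit: once layer_num passes the threshold the block loop stops, which shortens debugging runs of the exporter. A small usage sketch (the variable must be set before the patch module is imported, since STOPAT is read at import time):

.. code-block:: python

    import os

    # Only run the first few vision blocks in the patched forward; leave the
    # variable unset (the default) to keep the full loop.
    os.environ["STOPAT"] = "2"
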
@@ -473,9 +596,7 @@ if patch_qwen2_5:
         _PATCHED_CLASS_ = (
             transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLVisionAttention
         )
-        STRATEGY_FOR_ATTENTION = lambda: os.environ.get(  # noqa: E731
-            "QWEN25ATTENTION", "PACKED"
-        )
+        STRATEGY_FOR_ATTENTION = lambda: os.environ.get("QWEN25ATTENTION", None)  # noqa: E731

         def forward(
             self,
@@ -519,14 +640,15 @@ if patch_qwen2_5:
                 self.config._attn_implementation
             ]

-            is_sdpa = (
+            is_sdpa_or_eager = (
                 attention_interface
                 is transformers.integrations.sdpa_attention.sdpa_attention_forward
                 or attention_interface is patched_sdpa_attention_forward
+                or attention_interface
+                is transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.eager_attention_forward
             )
-            attention_strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
-            if is_sdpa and attention_strategy in "PACKED":
-                attn_output = qwen_sdpa_attention_packed_versatile(
+            if is_sdpa_or_eager:
+                attn_output = qwen_sdpa_attention_versatile(
                     query_states,
                     key_states,
                     value_states,
@@ -558,78 +680,10 @@ if patch_qwen2_5:
                     ),
                     version=1,
                 )
-            elif is_sdpa and attention_strategy == "LOOPMHA":
-                attn_output = qwen_sdpa_attention_loopmha_versatile(
-                    query_states,
-                    key_states,
-                    value_states,
-                    cu_seqlens,
-                    self.scaling,
-                    self.num_heads,
-                    (
-                        onnx.TensorProto.FLOAT
-                        if query_states.dtype == torch.float32
-                        else (
-                            onnx.TensorProto.FLOAT16
-                            if query_states.dtype == torch.float16
-                            else onnx.TensorProto.BFLOAT16
-                        )
-                    ),
-                )
-
-                # to rewrite later with a for loop
-                # def _iteration(start_end, query_states, key_states, value_states):
-                #     return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
-                #         self,
-                #         start_end,
-                #         query_states,
-                #         key_states,
-                #         value_states,
-                #         scaling=self.scaling,
-                #         dropout=0.0 if not self.training else self.attention_dropout,
-                #     )
-
-                # starts = cu_seqlens[:-1]
-                # ends = cu_seqlens[1:]
-                # torch._check(starts.shape[0] > 0)
-                # torch._check(ends.shape[0] > 0)
-                # starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
-                # attn_outputs = [
-                #     _iteration(start_end, query_states, key_states, value_states)
-                #     for start_end in starts_ends
-                # ]
-                # attn_output = torch.cat(attn_outputs, dim=1)
-            elif is_sdpa and attention_strategy == "BIGMASK":
-                # make square mask
-                indices = torch.arange(
-                    cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
-                )
-                dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(
-                    cu_seqlens.dtype
-                )
-                dot = dot.sum(dim=0)
-                mask = dot.unsqueeze(1) - dot.unsqueeze(0)
-                bool_mask = mask == 0
-                bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)
-
-                torch._check(bool_mask.shape[2] == key_states.shape[2])
-                torch._check(bool_mask.shape[3] == key_states.shape[2])
-
-                attn_output, _ = attention_interface(
-                    self,
-                    query_states,
-                    key_states,
-                    value_states,
-                    attention_mask=bool_mask,
-                    scaling=self.scaling,
-                    dropout=0.0 if not self.training else self.attention_dropout,
-                    is_causal=False,
-                    **kwargs,
-                )
             else:
                 raise NotImplementedError(
-                    f"No corresponding export strategy for "
-                    f"{attention_strategy!r}, "
+                    f"No corresponding export strategy for implementation "
+                    f"{self.config._attn_implementation!r}, "
                     f"(use QWEN25ATTENTION to change it), and attention_interface="
                     f"{attention_interface!r} (use sdpa)"
                 )
@@ -653,6 +707,7 @@ if patch_qwen2_5:
                 )
             else:
                 # Other implementations: Process each chunk separately
+                # = qwen_sdpa_attention
                 lengths = cu_seqlens[1:] - cu_seqlens[:-1]
                 splits = [
                     torch.split(tensor, lengths.tolist(), dim=2)
@@ -236,7 +236,7 @@ def code_sample(
             )
         )
     """
-    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+    model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
         model_id,
         subfolder,
         same_as_pretrained=same_as_pretrained,
@@ -256,6 +256,7 @@ def code_sample(
         model_kwargs=mop,
         subfolder=subfolder,
         add_second_input=False,
+        submodule=submodule,
     )
     if drop_inputs:
         update = {}
@@ -26,17 +26,26 @@ def _code_needing_rewriting(model: Any) -> Any:


 def _preprocess_model_id(
-    model_id: str, subfolder: Optional[str], same_as_pretrained: bool, use_pretrained: bool
-) -> Tuple[str, Optional[str], bool, bool]:
+    model_id: str,
+    subfolder: Optional[str],
+    same_as_pretrained: bool,
+    use_pretrained: bool,
+    submodule: Optional[str] = None,
+) -> Tuple[str, Optional[str], bool, bool, Optional[str]]:
+    if "::" in model_id:
+        assert (
+            not submodule
+        ), f"submodule={submodule!r} cannot be defined in model_id={model_id!r} as well"
+        model_id, submodule = model_id.split("::", maxsplit=1)
     if subfolder or "//" not in model_id:
-        return model_id, subfolder, same_as_pretrained, use_pretrained
+        return model_id, subfolder, same_as_pretrained, use_pretrained, submodule
     spl = model_id.split("//")
     if spl[-1] == "pretrained":
-        return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
+        return _preprocess_model_id("//".join(spl[:-1]), "", True, True, submodule)
     if spl[-1] in {"transformer", "vae"}:
         # known subfolder
-        return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
-    return model_id, subfolder, same_as_pretrained, use_pretrained
+        return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained, submodule
+    return model_id, subfolder, same_as_pretrained, use_pretrained, submodule


 def get_untrained_model_with_inputs(
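_preprocess_model_id now also strips an optional submodule suffix: everything after the first "::" in the model id is returned as a fifth element. An illustration of the parsing rule above (the helper is private and the model id is hypothetical; the call is shown only to document the contract):

.. code-block:: python

    # assuming _preprocess_model_id is in scope (it is the private helper patched above)
    parsed = _preprocess_model_id("owner/some-model::visual", None, False, False)
    # -> ("owner/some-model", None, False, False, "visual")
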
@@ -54,6 +63,7 @@ def get_untrained_model_with_inputs(
     subfolder: Optional[str] = None,
     use_only_preinstalled: bool = False,
     config_reduction: Optional[Callable[[Any, str], Dict]] = None,
+    submodule: Optional[str] = None,
 ) -> Dict[str, Any]:
     """
     Gets a non initialized model similar to the original model
@@ -82,6 +92,7 @@ def get_untrained_model_with_inputs(
         <onnx_diagnostic.torch_models.hghub.reduce_model_config>`,
         this function takes a configuration and a task (string)
         as arguments
+    :param submodule: use a submodule instead of the main model
     :return: dictionary with a model, inputs, dynamic shapes, and the configuration,
         some necessary rewriting as well

@@ -108,11 +119,12 @@ def get_untrained_model_with_inputs(
             f"model_id={model_id!r}, preinstalled model is only available "
             f"if use_only_preinstalled is False."
         )
-    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+    model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
         model_id,
         subfolder,
         same_as_pretrained=same_as_pretrained,
         use_pretrained=use_pretrained,
+        submodule=submodule,
     )
     if verbose:
         print(
@@ -147,6 +159,8 @@ def get_untrained_model_with_inputs(
     if verbose:
         print(f"[get_untrained_model_with_inputs] architecture={arch!r}")
         print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
+        if submodule:
+            print(f"[get_untrained_model_with_inputs] submodule={submodule!r}")
     if task is None:
         task = task_from_arch(arch, model_id=model_id, subfolder=subfolder)
     if verbose:
@@ -357,6 +371,19 @@ def get_untrained_model_with_inputs(
     if diff_config is not None:
         res["dump_info"] = dict(config_diff=diff_config)

+    if submodule:
+        path = submodule.split("::") if "::" in submodule else [submodule]
+        for p in path:
+            assert hasattr(model, p), (
+                f"Unable to find submodule {p!r} in in class {type(model)}, "
+                f"submodule={submodule!r}, possible candidates: "
+                f"{[k for k in dir(model) if isinstance(getattr(model, k), torch.nn.Module)]}"
+            )
+            model = getattr(model, p)
+
+    if verbose:
+        print(f"[get_untrained_model_with_inputs] model class={model.__class__.__name__!r}")
+
     sizes = compute_model_size(model)
     res["model"] = model
     res["configuration"] = config
@@ -349,13 +349,15 @@ def _prepare_validation(
     verbose,
     output_names,
     dump_folder,
+    submodule,
 ):
     main_validation_begin = time.perf_counter()
-    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+    model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
         model_id,
         subfolder,
         same_as_pretrained=same_as_pretrained,
         use_pretrained=use_pretrained,
+        submodule=submodule,
     )
     time_preprocess_model_id = time.perf_counter() - main_validation_begin
     patch_kwargs = make_patch_kwargs(patch=patch, rewrite=rewrite)
@@ -364,6 +366,7 @@ def _prepare_validation(
     summary.update(
         dict(
             version_model_id=model_id,
+            version_submodule=submodule,
             version_do_run=str(do_run),
             version_dtype=str(dtype or ""),
             version_device=str(device or ""),
@@ -444,6 +447,7 @@ def _prepare_validation(
         dump_folder,
         folder_name,
         patch_kwargs,
+        submodule,
     )


@@ -460,6 +464,7 @@ def _get_untrained_model_with_inputs(
     inputs2,
     quiet,
     dump_folder,
+    submodule,
 ):
     iop = input_options or {}
     mop = model_options or {}
@@ -480,6 +485,7 @@ def _get_untrained_model_with_inputs(
                 model_kwargs=mop,
                 subfolder=sub,
                 add_second_input=i2,
+                submodule=submodule,
             )
         )
     ),
@@ -842,6 +848,7 @@ def validate_model(
     ort_logs: bool = False,
     quiet_input_sets: Optional[Set[str]] = None,
     save_ep: Optional[str] = None,
+    submodule: Optional[str] = None,
 ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
     """
     Validates a model.
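validate_model forwards the new submodule argument through _prepare_validation and _get_untrained_model_with_inputs, so a single sub-module can be validated without exporting the whole network. A hedged sketch (keyword and import path assumed from this diff; the model id and attribute are examples, other arguments keep their defaults):

.. code-block:: python

    from onnx_diagnostic.torch_models.validate import validate_model

    summary, data = validate_model(
        "hf-org/model",       # hypothetical model id
        submodule="visual",   # validate model.visual instead of the full model
    )
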
@@ -902,6 +909,7 @@ def validate_model(
         even if quiet is False
     :param save_ep: if not empty, this can be used to save the input sets and
         the exported program
+    :param submodule: to test not the model but a submodule of this model
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces

@@ -966,6 +974,7 @@ def validate_model(
        use_pretrained=use_pretrained,
        same_as_pretrained=same_as_pretrained,
        save_ep=save_ep,
+        submodule=submodule,
    )
    if dump_folder:
        with open(dump_stats, "w") as f:
@@ -1053,6 +1062,7 @@ def _validate_model_step1(
    use_pretrained,
    same_as_pretrained,
    save_ep,
+    submodule,
 ):
    assert not do_same or do_run, (
        f"Discrepancies cannot be measured if the model is not run, "
@@ -1067,6 +1077,7 @@ def _validate_model_step1(
        dump_folder,
        folder_name,
        patch_kwargs,
+        submodule,
    ) = _prepare_validation(
        model_id=model_id,
        subfolder=subfolder,
@@ -1093,6 +1104,7 @@ def _validate_model_step1(
        verbose=verbose,
        output_names=output_names,
        dump_folder=dump_folder,
+        submodule=submodule,
    )

    data, iop, mop = _get_untrained_model_with_inputs(
@@ -1108,6 +1120,7 @@ def _validate_model_step1(
        inputs2=inputs2,
        quiet=quiet,
        dump_folder=dump_folder,
+        submodule=submodule,
    )

    second_input_keys = [k for k in data if k.startswith("inputs") and k != "inputs"]