onnx-diagnostic 0.8.6__py3-none-any.whl → 0.8.8__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (39)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +108 -3
  3. onnx_diagnostic/ci_models/ci_helpers.py +12 -7
  4. onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
  5. onnx_diagnostic/ci_models/export_qwen25_vl.py +12 -4
  6. onnx_diagnostic/export/api.py +295 -5
  7. onnx_diagnostic/export/cf_simple_loop_for.py +195 -10
  8. onnx_diagnostic/export/dynamic_shapes.py +45 -3
  9. onnx_diagnostic/export/shape_helper.py +1 -0
  10. onnx_diagnostic/ext_test_case.py +9 -2
  11. onnx_diagnostic/helpers/bench_run.py +1 -1
  12. onnx_diagnostic/helpers/cache_helper.py +0 -8
  13. onnx_diagnostic/helpers/fake_tensor_helper.py +26 -5
  14. onnx_diagnostic/helpers/helper.py +30 -1
  15. onnx_diagnostic/helpers/log_helper.py +1 -3
  16. onnx_diagnostic/helpers/optim_helper.py +116 -0
  17. onnx_diagnostic/helpers/ort_session.py +5 -0
  18. onnx_diagnostic/tasks/image_text_to_text.py +19 -9
  19. onnx_diagnostic/tasks/text2text_generation.py +84 -48
  20. onnx_diagnostic/tasks/text_generation.py +3 -0
  21. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +28 -2
  22. onnx_diagnostic/torch_export_patches/patch_details.py +3 -3
  23. onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
  24. onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +14 -5
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +12 -1
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
  29. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
  30. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +22 -24
  31. onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
  32. onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
  33. onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
  34. onnx_diagnostic/torch_models/validate.py +48 -0
  35. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/METADATA +3 -1
  36. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/RECORD +39 -36
  37. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/WHEEL +0 -0
  38. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/licenses/LICENSE.txt +0 -0
  39. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/onnx_export_errors.py

@@ -221,6 +221,7 @@ def _patch_torch(
     catch_constraints: bool,
     stop_if_static: int,
 ) -> Tuple[Optional[Callable], ...]:
+    import packaging.version as pv
     import torch
     import torch.jit
     import torch._export.non_strict_utils  # produce_guards_and_solve_constraints
@@ -238,6 +239,11 @@ def _patch_torch(
         patched_ShapeEnv,
     )

+    if pv.Version(torch.__version__) >= pv.Version("2.9.99"):
+        from .patches.patch_torch import patched_DynamicDimConstraintPrinter
+    else:
+        patched_DynamicDimConstraintPrinter = None
+
     f___constrain_user_specified_dimhint_range = None
     f__broadcast_in_dim_meta = None
     f__broadcast_shapes = None
@@ -259,6 +265,17 @@ def _patch_torch(
         print(f"[torch_export_patches] stop_if_static={stop_if_static!r}")
         print("[torch_export_patches] patch pytorch")

+    # torch.tx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol
+    if patched_DynamicDimConstraintPrinter is not None:
+        f__print_symbol = (
+            torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol
+        )
+        torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol = (
+            patched_DynamicDimConstraintPrinter._print_Symbol
+        )
+    else:
+        f__print_symbol = None
+
     # torch.vmap
     f_vmap = torch.vmap
     torch.vmap = patched_vmap
@@ -392,6 +409,7 @@ def _patch_torch(
         f_shape_env__log_guard,
         f_shape_env__set_replacement,
         f_vmap,
+        f__print_symbol,
     )


@@ -416,6 +434,7 @@ def _unpatch_torch(
     f_shape_env__log_guard: Optional[Callable],
     f_shape_env__set_replacement: Optional[Callable],
     f_vmap: Optional[Callable],
+    f__print_symbol: Optional[Callable],
 ):
     import torch
     import torch.jit
@@ -423,6 +442,10 @@ def _unpatch_torch(
     from torch.fx.experimental.symbolic_shapes import ShapeEnv

     # this should disappear when torch.jit is removed
+    if f__print_symbol is not None:
+        torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol = (
+            f__print_symbol
+        )
     torch.vmap = f_vmap
     torch.jit.isinstance = f_jit_isinstance
     torch._dynamo.mark_static_address = f_mark_static_address
@@ -848,8 +871,9 @@ def torch_export_patches(
         this is done by function :func:`transform_method
         <onnx_diagnostic.torch_export_patches.patch_module.transform_method>`,
         its documentation provides possible values
-    :param dump_rewriting: dumps rewriting information in file beginning with that prefix
-    :param patch_details: if specified, this class is used to stored every rewritten done.
+    :param dump_rewriting: dumps rewriting information in file beginning with that prefix,
+        this only applied on the automated rewritings
+    :param patch_details: if specified, this class is used to stored every applied rewriting.
     :param verbose: to show which patches is applied
     :param profile: starts profiling whatever is called inside the context manager,
         output the profiling into a text file
@@ -992,6 +1016,7 @@ def torch_export_patches(
             f_shape_env__log_guard,
             f_shape_env__set_replacement,
             f_vmap,
+            f__print_Symbol,
         ) = _patch_torch(
             verbose, patch_details, patch_torch, catch_constraints, stop_if_static
         )
@@ -1067,6 +1092,7 @@ def torch_export_patches(
                 f_shape_env__log_guard,
                 f_shape_env__set_replacement,
                 f_vmap,
+                f__print_Symbol,
             )

         if patch_transformers:
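The `_patch_torch` / `_unpatch_torch` changes above follow the library's usual save-patch-restore pattern: the original `DynamicDimConstraintPrinter._print_Symbol` is stored in `f__print_symbol`, the patched method is swapped in, and the saved callable is put back when the context manager exits. A minimal, self-contained sketch of that pattern on a toy class (illustrative only, not the library's actual code):

```python
class Printer:
    def render(self, value):
        return f"original:{value}"


def patched_render(self, value):
    return f"patched:{value}"


saved_render = Printer.render      # keep the original, as _patch_torch keeps f__print_symbol
Printer.render = patched_render    # install the patch
try:
    assert Printer().render(3) == "patched:3"
finally:
    Printer.render = saved_render  # restore on exit, as _unpatch_torch does
assert Printer().render(3) == "original:3"
```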
onnx_diagnostic/torch_export_patches/patch_details.py

@@ -191,7 +191,7 @@ class PatchDetails:
         ep = torch.export.export(
             model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
         )
-        patches = details.patches_involded_in_graph(ep.graph)
+        patches = details.patches_involved_in_graph(ep.graph)
         report = details.make_report(patches, format="rst")
         print(report)
     """
@@ -235,7 +235,7 @@ class PatchDetails:
         """Returns the data for a dataframe."""
         return [p.to_dict() for p in self.patched]

-    def patches_involded_in_graph(
+    def patches_involved_in_graph(
         self, graph: "torch.fx.Graph"  # noqa: F821
     ) -> List[Tuple[PatchInfo, List["torch.fx.Node"]]]:  # noqa: F821
         """
@@ -322,7 +322,7 @@ class PatchDetails:
         """
         Creates a report based on the involved patches.

-        :param patches: from method :meth:`patches_involded_in_graph`
+        :param patches: from method :meth:`patches_involved_in_graph`
         :param format: format of the report
         :return: report
         """
onnx_diagnostic/torch_export_patches/patch_expressions.py

@@ -101,7 +101,10 @@ def patched_selector(fct: Callable, patched_fct: Callable) -> Callable:


 def patched_float_arange(start, end, step):
-    """Patched arange when start, end, step are floats."""
+    """
+    Patched arange when start, end, step are floats.
+    This patch should not be needed after 2.10.
+    """
     if is_torchdynamo_exporting():
         return torch.ops.patched.float_arange(start, end, step)
     else:
onnx_diagnostic/torch_export_patches/patch_module.py

@@ -596,33 +596,41 @@ class RewriteControlFlow(ast.NodeTransformer):
         elts=[
             *[
                 ast.Call(
-                    ast.Attribute(
-                        value=ast.Name(id="torch", ctx=ast.Load()),
-                        attr="arange",
-                        ctx=ast.Load(),
-                    ),
-                    args=[
-                        ast.Subscript(
-                            value=ast.Attribute(
-                                value=ast.Name(id=v, ctx=ast.Load()),
-                                attr="shape",
+                    func=ast.Attribute(
+                        value=ast.Call(
+                            ast.Attribute(
+                                value=ast.Name(id="torch", ctx=ast.Load()),
+                                attr="arange",
                                 ctx=ast.Load(),
                             ),
-                            slice=ast.Constant(value=0, ctx=ast.Load()),
+                            args=[
+                                ast.Subscript(
+                                    value=ast.Attribute(
+                                        value=ast.Name(id=v, ctx=ast.Load()),
+                                        attr="shape",
+                                        ctx=ast.Load(),
+                                    ),
+                                    slice=ast.Constant(value=0, ctx=ast.Load()),
+                                    ctx=ast.Load(),
+                                ),
+                            ],
+                            keywords=[
+                                ast.keyword(
+                                    arg="dtype",
+                                    value=ast.Attribute(
+                                        value=ast.Name(id="torch", ctx=ast.Load()),
+                                        attr="int64",
+                                        ctx=ast.Load(),
+                                    ),
+                                )
+                            ],
                             ctx=ast.Load(),
                         ),
-                    ],
-                    keywords=[
-                        ast.keyword(
-                            arg="dtype",
-                            value=ast.Attribute(
-                                value=ast.Name(id="torch", ctx=ast.Load()),
-                                attr="int64",
-                                ctx=ast.Load(),
-                            ),
-                        )
-                    ],
-                    ctx=ast.Load(),
+                        attr="unsqueeze",
+                        ctx=ast.Load(),
+                    ),
+                    args=[ast.Constant(value=1)],
+                    keywords=[],
                 )
                 for v in scan_shape_vars
             ],
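In plain Python, this rewrite changes the expression the transformer generates for each `v` in `scan_shape_vars` from a 1-D index tensor to a column vector: the `torch.arange` call is now wrapped in a second `ast.Call` that applies `.unsqueeze(1)`. A small sketch of the two generated expressions (with `v` standing in for one scanned tensor):

```python
import torch

v = torch.zeros(4, 3)  # stand-in for one of the scan_shape_vars

old_expr = torch.arange(v.shape[0], dtype=torch.int64)               # shape (4,)
new_expr = torch.arange(v.shape[0], dtype=torch.int64).unsqueeze(1)  # shape (4, 1)

assert old_expr.shape == (4,)
assert new_expr.shape == (4, 1)
```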
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py

@@ -22,13 +22,22 @@ if patch_DynamicLayer:
         _PATCHES_ = ["lazy_initialization"]
         _PATCHED_CLASS_ = DynamicLayer

-        def lazy_initialization(self, key_states: torch.Tensor):
+        def lazy_initialization(
+            self, key_states: torch.Tensor, value_states: torch.Tensor = None
+        ):
             self.dtype, self.device = key_states.dtype, key_states.device
-            new_shape = list(key_states.shape)
-            new_shape[-2] = 0
+            assert (
+                hasattr(key_states, "shape") and key_states is not None
+            ), f"Attribute 'shape' is wrong for type {type(key_states)}"
+            like = torch.narrow(key_states, dim=-2, start=0, length=0)
             # PATCHED: used a tensor with an empty shape and not en empty list to initialize
-            self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
-            self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
+            if isinstance(key_states, torch._subclasses.fake_tensor.FakeTensor):
+                with key_states.fake_mode:
+                    self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device)
+                    self.values = torch.empty_like(like, dtype=self.dtype, device=self.device)
+            else:
+                self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device)
+                self.values = torch.empty_like(like, dtype=self.dtype, device=self.device)
             if patch_is_initialized:
                 self.is_initialized = True

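The patched `lazy_initialization` now builds its empty key/value buffers from a zero-length `torch.narrow` view instead of a reshaped list, so the same code path works for regular tensors and for `FakeTensor` inputs seen during export. A quick illustration of the trick on a real tensor:

```python
import torch

key_states = torch.rand(2, 8, 5, 64)  # (batch, heads, seq_len, head_dim)

# zero-length slice along the sequence axis, then an empty buffer shaped like it
like = torch.narrow(key_states, dim=-2, start=0, length=0)
keys = torch.empty_like(like, dtype=key_states.dtype, device=key_states.device)

assert keys.shape == (2, 8, 0, 64)
```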
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py (new file)

@@ -0,0 +1,80 @@
+import torch
+
+try:
+    import transformers.models.funnel.modeling_funnel
+
+    patch_funnel = True
+except ImportError:
+    patch_funnel = False
+
+if patch_funnel:
+    from transformers.models.funnel.modeling_funnel import _relative_shift_gather
+
+    class patched_FunnelAttentionStructure(torch.nn.Module):
+        _PATCHES_ = ["relative_pos"]
+        _PATCHED_CLASS_ = transformers.models.funnel.modeling_funnel.FunnelAttentionStructure
+
+        def relative_pos(
+            self, pos: torch.Tensor, stride: int, pooled_pos=None, shift: int = 1
+        ) -> torch.Tensor:
+            if pooled_pos is None:
+                pooled_pos = pos
+            ref_point = pooled_pos[0] - pos[0]
+            # PATCHED
+            num_remove = shift * pooled_pos.shape[0]
+            max_dist = ref_point + num_remove * stride
+            min_dist = pooled_pos[0] - pos[-1]
+            return torch.arange(
+                max_dist.to(torch.long),
+                (min_dist - 1).to(torch.long),
+                torch.tensor(-stride, dtype=torch.long),
+                dtype=torch.long,
+                device=pos.device,
+            )
+
+    class patched_FunnelRelMultiheadAttention(torch.nn.Module):
+        _PATCHES_ = ["relative_positional_attention"]
+        _PATCHED_CLASS_ = (
+            transformers.models.funnel.modeling_funnel.FunnelRelMultiheadAttention
+        )
+
+        def relative_positional_attention(
+            self, position_embeds, q_head, context_len, cls_mask=None
+        ):
+            """Relative attention score for the positional encodings"""
+            # q_head has shape batch_size x sea_len x n_head x d_head
+            if self.config.attention_type == "factorized":
+                phi, pi, psi, omega = position_embeds
+                # Shape n_head x d_head
+                u = self.r_r_bias * self.scale
+                # Shape d_model x n_head x d_head
+                w_r = self.r_kernel
+
+                # Shape batch_size x sea_len x n_head x d_model
+                q_r_attention = torch.einsum("binh,dnh->bind", q_head + u, w_r)
+                q_r_attention_1 = q_r_attention * phi[:, None]
+                q_r_attention_2 = q_r_attention * pi[:, None]
+
+                # Shape batch_size x n_head x seq_len x context_len
+                positional_attn = torch.einsum(
+                    "bind,jd->bnij", q_r_attention_1, psi
+                ) + torch.einsum("bind,jd->bnij", q_r_attention_2, omega)
+            else:
+                shift = 2 if q_head.shape[1] != context_len else 1
+                r = position_embeds[self.block_index][shift - 1]
+                # Shape n_head x d_head
+                v = self.r_r_bias * self.scale
+                # Shape d_model x n_head x d_head
+                w_r = self.r_kernel
+
+                # Shape max_rel_len x n_head x d_model
+                r_head = torch.einsum("td,dnh->tnh", r, w_r)
+                # Shape batch_size x n_head x seq_len x max_rel_len
+                positional_attn = torch.einsum("binh,tnh->bnit", q_head + v, r_head)
+                # Shape batch_size x n_head x seq_len x context_len
+                positional_attn = _relative_shift_gather(positional_attn, context_len, shift)
+
+            if cls_mask is not None:
+                # PATCHED
+                positional_attn = positional_attn * cls_mask
+            return positional_attn
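Like the other patch files touched in this release, the new funnel module declares its replacements through the `_PATCHES_` / `_PATCHED_CLASS_` attributes, which tell the patch machinery which methods to swap onto which transformers class. A rough, hypothetical sketch of how such a declaration can be consumed (the real wiring lives inside `torch_export_patches`, not in this helper):

```python
def apply_class_patches(patched_cls):
    """Hypothetical helper: swap the methods listed in _PATCHES_ onto _PATCHED_CLASS_."""
    target = patched_cls._PATCHED_CLASS_
    originals = {}
    for name in patched_cls._PATCHES_:
        originals[name] = getattr(target, name)            # keep the original method
        setattr(target, name, getattr(patched_cls, name))  # install the patched one
    return originals  # keep these around to restore the class afterwards


# usage sketch: originals = apply_class_patches(patched_FunnelAttentionStructure)
```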
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

@@ -256,10 +256,21 @@ if patch_qwen2_5:
             return attn_output

     def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
+        import onnx_ir
+
         first_float_tensor = next(
             a
             for a in args
-            if a is not None and a.dtype in {torch.float16, torch.float32, torch.bfloat16}
+            if a is not None
+            and a.dtype
+            in {
+                torch.float16,
+                torch.float32,
+                torch.bfloat16,
+                onnx_ir.DataType.BFLOAT16,
+                onnx_ir.DataType.FLOAT16,
+                onnx_ir.DataType.FLOAT,
+            }
         )
         dtype = first_float_tensor.dtype
         strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
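The widened set lets `qwen_version_selector` pick the first floating-point argument whether its dtype is a `torch.dtype` or an `onnx_ir.DataType`. A minimal sketch of the membership test, assuming `onnx_ir` is installed:

```python
import torch
import onnx_ir

FLOAT_DTYPES = {
    torch.float16,
    torch.float32,
    torch.bfloat16,
    onnx_ir.DataType.BFLOAT16,
    onnx_ir.DataType.FLOAT16,
    onnx_ir.DataType.FLOAT,
}

assert torch.rand(2).dtype in FLOAT_DTYPES        # torch tensors are matched
assert onnx_ir.DataType.FLOAT in FLOAT_DTYPES     # onnx_ir values are matched too
```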
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py

@@ -214,7 +214,7 @@ def patched_dynamic_rope_update(rope_forward):
                 cond,
                 (lambda x, y: x.clone()),
                 (lambda x, y: y.clone()),
-                [long_inv_freq, original_inv_freq],
+                [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
             )
             setattr(self, f"{prefix}inv_freq", inv_freq)
             # if seq_len > original_max_position_embeddings:
@@ -293,7 +293,7 @@ def patched_dynamic_rope_update(rope_forward):
                 cond,
                 (lambda x, y: x.clone()),
                 (lambda x, y: y.clone()),
-                [long_inv_freq, original_inv_freq],
+                [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
             )
             setattr(self, f"{prefix}inv_freq", inv_freq)

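The only change in both hunks is the added `.to(original_inv_freq.dtype)` cast: `torch.cond` traces both branches, so the two candidate tensors must agree on dtype, presumably because `long_inv_freq` can come out of the NTK computation in a different precision. A small standalone sketch of the same call shape:

```python
import torch


def pick_inv_freq(cond, long_inv_freq, original_inv_freq):
    # both branches must return tensors with matching metadata, hence the cast
    return torch.cond(
        cond,
        (lambda x, y: x.clone()),
        (lambda x, y: y.clone()),
        [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
    )


inv_freq = pick_inv_freq(
    torch.tensor(True),
    torch.rand(4, dtype=torch.float64),
    torch.rand(4, dtype=torch.float32),
)
assert inv_freq.dtype == torch.float32
```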
onnx_diagnostic/torch_export_patches/patches/patch_torch.py

@@ -5,6 +5,7 @@ import os
 import traceback
 from functools import reduce
 from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
+import sympy
 import torch
 from torch._subclasses.fake_tensor import FakeTensorMode

@@ -1091,3 +1092,17 @@ def patched__broadcast_in_dim_meta_level_2(
         new_strides.append(a.stride()[original_idx] * a.size()[original_idx])

     return a.as_strided(shape, new_strides, a.storage_offset())
+
+
+class patched_DynamicDimConstraintPrinter:
+    """
+    Patches
+    ``torch.tx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol``.
+    Valid for ``torch>=2.10``.
+    """
+
+    def _print_Symbol(self, expr: sympy.Symbol) -> str:
+        assert isinstance(expr, sympy.Symbol), str(type(expr))
+        if self.symbol_to_source.get(expr):
+            return self.symbol_to_source[expr][0].name
+        return str(expr)
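`patched_DynamicDimConstraintPrinter._print_Symbol` is a standard sympy printer hook: when the printer knows a source for a symbol, it prints that source's name instead of the raw symbol. A toy, self-contained version of the same idea using sympy's `StrPrinter` (names and mapping are made up for the example):

```python
import sympy
from sympy.printing.str import StrPrinter


class SourceNamePrinter(StrPrinter):
    """Toy printer: prints symbols through an external mapping, like symbol_to_source."""

    def __init__(self, symbol_to_source):
        super().__init__()
        self.symbol_to_source = symbol_to_source

    def _print_Symbol(self, expr):
        if self.symbol_to_source.get(expr):
            return self.symbol_to_source[expr][0]
        return str(expr)


s0 = sympy.Symbol("s0")
printer = SourceNamePrinter({s0: ["L['x'].size()[0]"]})
print(printer.doprint(s0 + 1))  # -> L['x'].size()[0] + 1
```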
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

@@ -1,29 +1,37 @@
 # transformers
 from typing import List
 from .patch_helper import _has_transformers
-
 from ._patch_transformers_attention import (
     patched_sdpa_attention_forward,
     patched_model_bart_eager_attention_forward,
     patched_modeling_marian_eager_attention_forward,
 )
+from ._patch_transformers_generation_mixin import patched_GenerationMixin
+from ._patch_transformers_causal_mask import patched_AttentionMaskConverter
+from ._patch_transformers_rotary_embedding import (
+    patched__compute_dynamic_ntk_parameters,
+    patched_dynamic_rope_update,
+    patched_GemmaRotaryEmbedding,
+    patched_LlamaRotaryEmbedding,
+    patched_MistralRotaryEmbedding,
+    patched_MixtralRotaryEmbedding,
+    patched_PhiRotaryEmbedding,
+)
+from ._patch_transformers_idefics import patched_IdeficsEmbedding, patched_IdeficsAttention
+from ._patch_transformers_sam_mask_decoder import patched_SamMaskDecoder
+
+# transformers dependent patches

 from ._patch_transformers_cache_utils import patch_parse_processor_args

 if patch_parse_processor_args:
     from ._patch_transformers_cache_utils import patched_parse_processor_args
-
-from ._patch_transformers_causal_mask import patched_AttentionMaskConverter
-
 from ._patch_transformers_dynamic_cache import patch_DynamicLayer, patch_DynamicCache

 if patch_DynamicLayer:
     from ._patch_transformers_dynamic_cache import patched_DynamicLayer
 if patch_DynamicCache:
     from ._patch_transformers_dynamic_cache import patched_DynamicCache
-
-from ._patch_transformers_generation_mixin import patched_GenerationMixin
-
 from ._patch_transformers_masking_utils import patch_masking_utils

 if patch_masking_utils:
@@ -33,15 +41,7 @@ if patch_masking_utils:
         patched_sdpa_mask_recent_torch,
     )

-from ._patch_transformers_rotary_embedding import (
-    patched__compute_dynamic_ntk_parameters,
-    patched_dynamic_rope_update,
-    patched_GemmaRotaryEmbedding,
-    patched_LlamaRotaryEmbedding,
-    patched_MistralRotaryEmbedding,
-    patched_MixtralRotaryEmbedding,
-    patched_PhiRotaryEmbedding,
-)
+# transformers models dependent patches

 if _has_transformers("4.51"):
     from ._patch_transformers_rotary_embedding import patched_Phi3RotaryEmbedding
@@ -54,16 +54,11 @@ if _has_transformers("4.52"):
 if _has_transformers("4.53"):
     from ._patch_transformers_rotary_embedding import patched_SmolLM3RotaryEmbedding

-# Models
-
 from ._patch_transformers_gemma3 import patch_gemma3

 if patch_gemma3:
     from ._patch_transformers_gemma3 import patched_Gemma3Model

-from ._patch_transformers_idefics import patched_IdeficsEmbedding, patched_IdeficsAttention
-
-
 from ._patch_transformers_qwen2 import patch_qwen2

 if patch_qwen2:
@@ -80,14 +75,17 @@ if patch_qwen2:
         patched_Qwen2_5_VLModel,
         PLUGS as PLUGS_Qwen25,
     )
-
 from ._patch_transformers_qwen3 import patch_qwen3

 if patch_qwen3:
     from ._patch_transformers_qwen3 import patched_Qwen3MoeSparseMoeBlock
+from ._patch_transformers_funnel import patch_funnel

-
-from ._patch_transformers_sam_mask_decoder import patched_SamMaskDecoder
+if patch_funnel:
+    from ._patch_transformers_funnel import (
+        patched_FunnelAttentionStructure,
+        patched_FunnelRelMultiheadAttention,
+    )


 def get_transformers_plugs() -> List["EagerDirectReplacementWithOnnx"]:  # noqa: F821
onnx_diagnostic/torch_models/hghub/hub_api.py

@@ -184,7 +184,18 @@ def _trygetattr(config, attname):
     return None


+def rewrite_architecture_name(name: Optional[str]) -> Optional[str]:
+    if name == "ConditionalDETRForObjectDetection":
+        return "ConditionalDetrForObjectDetection"
+    return name
+
+
 def architecture_from_config(config) -> Optional[str]:
+    """Guesses the architecture (class) of the model described by this config."""
+    return rewrite_architecture_name(_architecture_from_config(config))
+
+
+def _architecture_from_config(config) -> Optional[str]:
     """Guesses the architecture (class) of the model described by this config."""
     if isinstance(config, dict):
         if "_class_name" in config:
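A minimal usage sketch of the helper added above, assuming it is imported from `onnx_diagnostic.torch_models.hghub.hub_api` as listed in the files changed:

```python
from onnx_diagnostic.torch_models.hghub.hub_api import rewrite_architecture_name

# the only rewrite defined so far normalizes the DETR spelling used by some configs
assert rewrite_architecture_name("ConditionalDETRForObjectDetection") == (
    "ConditionalDetrForObjectDetection"
)
# anything else passes through unchanged
assert rewrite_architecture_name("ResNetForImageClassification") == (
    "ResNetForImageClassification"
)
```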
onnx_diagnostic/torch_models/hghub/hub_data.py

@@ -5,7 +5,10 @@ from typing import Dict, List

 __date__ = "2025-06-21"

-__data_arch_values__ = {"ResNetForImageClassification": dict(image_size=224)}
+__data_arch_values__ = {
+    "ConditionalDETRForObjectDetection": dict(image_size=224),
+    "ResNetForImageClassification": dict(image_size=224),
+}

 __data_arch__ = textwrap.dedent(
     """
@@ -32,6 +35,7 @@ __data_arch__ = textwrap.dedent(
     ConvNextV2Model,image-feature-extraction
     CosmosTransformer3DModel,image-to-video
     CvtModel,feature-extraction
+    ClvpModelForConditionalGeneration,audio-feature-extraction
     DPTModel,image-feature-extraction
     Data2VecAudioModel,feature-extraction
     Data2VecTextModel,feature-extraction
@@ -49,6 +53,8 @@ __data_arch__ = textwrap.dedent(
     ElectraModel,feature-extraction
     EsmModel,feature-extraction
     FalconMambaForCausalLM,text-generation
+    FunnelBaseModel,feature-extraction
+    FuyuForCausalLM,image-text-to-text
     GLPNModel,image-feature-extraction
     GPT2LMHeadModel,text-generation
     GPTBigCodeModel,feature-extraction
@@ -63,6 +69,7 @@ __data_arch__ = textwrap.dedent(
     Glm4vMoeForConditionalGeneration,image-text-to-text
     GraniteForCausalLM,text-generation
     GroupViTModel,feature-extraction
+    HeliumForCausalLM,text-generation
     HieraForImageClassification,image-classification
     HubertModel,feature-extraction
     IBertModel,feature-extraction
@@ -136,6 +143,7 @@ __data_arch__ = textwrap.dedent(
     SwinModel,image-feature-extraction
     Swinv2Model,image-feature-extraction
     T5ForConditionalGeneration,text2text-generation
+    T5GemmaForConditionalGeneration,text2text-generation
     TableTransformerModel,image-feature-extraction
     TableTransformerForObjectDetection,object-detection
     UNet2DConditionModel,text-to-image
onnx_diagnostic/torch_models/hghub/model_inputs.py

@@ -64,6 +64,7 @@ def get_untrained_model_with_inputs(
     use_only_preinstalled: bool = False,
     config_reduction: Optional[Callable[[Any, str], Dict]] = None,
     submodule: Optional[str] = None,
+    skip_inputs: bool = False,
 ) -> Dict[str, Any]:
     """
     Gets a non initialized model similar to the original model
@@ -93,6 +94,7 @@ def get_untrained_model_with_inputs(
         this function takes a configuration and a task (string)
         as arguments
     :param submodule: use a submodule instead of the main model
+    :param skip_inputs: do not generate the inputs
     :return: dictionary with a model, inputs, dynamic shapes, and the configuration,
         some necessary rewriting as well

@@ -332,13 +334,12 @@ def get_untrained_model_with_inputs(
             f"[get_untrained_model_with_inputs] "
             f"instantiate_specific_model(2) {cls_model}"
         )
-
     try:
         if type(config) is dict:
             model = cls_model(**config)
         else:
            model = cls_model(config)
-    except RuntimeError as e:
+    except (RuntimeError, AttributeError, ValueError) as e:
        raise RuntimeError(
            f"Unable to instantiate class {cls_model.__name__} with\n{config}"
        ) from e
@@ -350,23 +351,27 @@
     )

     # input kwargs
-    seed = int(os.environ.get("SEED", "17")) + 1
-    torch.manual_seed(seed)
-    kwargs, fct = random_input_kwargs(config, task)  # type: ignore[arg-type]
-    if verbose:
-        print(f"[get_untrained_model_with_inputs] use fct={fct}")
-    if os.environ.get("PRINT_CONFIG") in (1, "1"):
-        print(f"-- input kwargs for task {task!r}")
-        pprint.pprint(kwargs)
-    if inputs_kwargs:
-        kwargs.update(inputs_kwargs)
-
-    # This line is important. Some models may produce different
-    # outputs even with the same inputs in training mode.
-    model.eval()  # type: ignore[union-attr]
-    res = fct(model, config, add_second_input=add_second_input, **kwargs)
-
-    res["input_kwargs"] = kwargs
+    if not skip_inputs:
+        seed = int(os.environ.get("SEED", "17")) + 1
+        torch.manual_seed(seed)
+        kwargs, fct = random_input_kwargs(config, task)  # type: ignore[arg-type]
+        if verbose:
+            print(f"[get_untrained_model_with_inputs] use fct={fct}")
+        if os.environ.get("PRINT_CONFIG") in (1, "1"):
+            print(f"-- input kwargs for task {task!r}")
+            pprint.pprint(kwargs)
+        if inputs_kwargs:
+            kwargs.update(inputs_kwargs)
+
+        # This line is important. Some models may produce different
+        # outputs even with the same inputs in training mode.
+        model.eval()  # type: ignore[union-attr]
+        res = fct(model, config, add_second_input=add_second_input, **kwargs)
+
+        res["input_kwargs"] = kwargs
+    else:
+        res = {}
+
     res["model_kwargs"] = mkwargs
     if diff_config is not None:
         res["dump_info"] = dict(config_diff=diff_config)
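A short usage sketch of the new `skip_inputs` flag. The model id is only an example, the first argument is assumed to be a Hugging Face model id as elsewhere in the package, and the exact keys of the returned dictionary beyond what this hunk shows (`model_kwargs`, possibly `dump_info`) are not spelled out here:

```python
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs

# example model id; with skip_inputs=True the whole random-input branch above is
# bypassed, so no "input_kwargs" entry is produced for the model
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", skip_inputs=True)
print(sorted(data))
```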