onnx-diagnostic 0.7.11__py3-none-any.whl → 0.7.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +5 -2
  3. onnx_diagnostic/export/dynamic_shapes.py +11 -2
  4. onnx_diagnostic/helpers/helper.py +11 -5
  5. onnx_diagnostic/helpers/log_helper.py +65 -12
  6. onnx_diagnostic/helpers/mini_onnx_builder.py +17 -0
  7. onnx_diagnostic/helpers/model_builder_helper.py +1 -0
  8. onnx_diagnostic/helpers/rt_helper.py +55 -37
  9. onnx_diagnostic/helpers/torch_helper.py +31 -7
  10. onnx_diagnostic/reference/torch_evaluator.py +2 -2
  11. onnx_diagnostic/tasks/data/__init__.py +13 -0
  12. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  13. onnx_diagnostic/tasks/image_text_to_text.py +256 -141
  14. onnx_diagnostic/tasks/text_generation.py +15 -0
  15. onnx_diagnostic/torch_export_patches/eval/__init__.py +177 -150
  16. onnx_diagnostic/torch_export_patches/eval/model_cases.py +19 -1
  17. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +40 -14
  18. onnx_diagnostic/torch_export_patches/patch_inputs.py +10 -6
  19. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +116 -10
  20. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +269 -4
  21. onnx_diagnostic/torch_models/hghub/hub_api.py +4 -10
  22. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +36 -0
  23. onnx_diagnostic/torch_models/hghub/model_inputs.py +32 -4
  24. onnx_diagnostic/torch_models/validate.py +337 -113
  25. onnx_diagnostic/torch_onnx/sbs.py +2 -1
  26. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/METADATA +11 -31
  27. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/RECORD +30 -28
  28. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/WHEEL +0 -0
  29. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/licenses/LICENSE.txt +0 -0
  30. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/top_level.txt +0 -0

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

@@ -1,7 +1,7 @@
 import inspect
 import os
 import traceback
-from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 import torch
 from torch._subclasses.fake_tensor import FakeTensorMode
 
@@ -65,6 +65,8 @@ def patch__check_input_constraints_for_graph(
     verbose: int = 0,
 ) -> None:
     try:
+        # PATCHED: catches exception and prints out the information instead of
+        # stopping the conversion.
         return previous_function(input_placeholders, flat_args_with_path, range_constraints)
     except Exception as e:
         if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")):
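
The patch only reports the constraint failure by default: the guarded branch runs when SKIP_SOLVE_CONSTRAINTS is set to a falsy integer, and the default value is "1". A minimal sketch (not from the package) of toggling that switch before exporting; the tiny model is only there to give torch.export.export something to trace:

    import os
    import torch

    # With the default "1", `not int(...)` is False and the stricter branch is skipped;
    # setting "0" before exporting makes the patched wrapper enter it.
    os.environ["SKIP_SOLVE_CONSTRAINTS"] = "0"

    class Tiny(torch.nn.Module):
        def forward(self, x):
            return x * 2

    ep = torch.export.export(Tiny(), (torch.randn(2, 3),))
    print(type(ep))
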
@@ -122,8 +124,7 @@ def patched_infer_size(a, b):
         if b1 or b2 or b3:
             expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
         else:
-            # In this case, the current implementation of torch fails (17/12/2024).
-            # Try model SmolLM.
+            # PATCHED: generic case, the dimension is known, no need to assert
             expandedSizes[i] = torch.sym_max(sizeA, sizeB)
     return tuple(expandedSizes)
 
@@ -132,7 +133,11 @@ def patched__broadcast_shapes(*_shapes):
     """Patches ``torch._refs._broadcast_shapes``."""
     from functools import reduce
     from torch._prims_common import IntLike
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_size_oblivious,
+        guard_or_false,
+        is_nested_int,
+    )
 
     shapes = tuple(
         (x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes)
@@ -142,17 +147,30 @@ def patched__broadcast_shapes(*_shapes):
     if len(shapes) == 0:
         return None
 
-    # Type checking
-    # TODO: make common validations available as utils
     for shape in shapes:
-        assert isinstance(shape, Sequence)
+        if not isinstance(shape, Sequence):
+            raise RuntimeError(
+                "Input shapes should be of type ints, a tuple of ints, "
+                "or a list of ints, got ",
+                shape,
+            )
 
     # Computes common shape
-    common_shape = [  # List[Union[int, torch.SymInt]]
-        1,
-    ] * reduce(max, (len(shape) for shape in shapes))
+    common_shape = [1] * reduce(max, (len(shape) for shape in shapes))
     for _arg_idx, shape in enumerate(shapes):
         for idx in range(-1, -1 - len(shape), -1):
+            if is_nested_int(shape[idx]):
+                # Broadcasting is allowed for (j0, 1) or (j0, j0);
+                # not (j0, j1), (j0, 5), etc.
+                if is_nested_int(common_shape[idx]) and guard_or_false(
+                    shape[idx] == common_shape[idx]
+                ):
+                    continue
+            else:
+                if guard_or_false(shape[idx] == common_shape[idx]):
+                    continue
+            # PATCHED: two cases, if == for sure, no broadcast,
+            # otherwise maybe broadcast with max(dimensions)
             if guard_size_oblivious(common_shape[idx] == 1):
                 if shape[idx] < 0:
                     raise ValueError(
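
The new branch relies on guard_or_false, which answers a symbolic comparison only when it is certainly true and otherwise returns False without recording a guard, so provably equal dimensions are passed through and everything else falls back to a size-oblivious broadcast. A simplified, plain-int sketch of the control flow described by the PATCHED comment (error handling and SymInt semantics are omitted, ordinary comparisons stand in for the guard helpers):

    from functools import reduce

    def broadcast_shapes_sketch(*shapes):
        common = [1] * reduce(max, (len(s) for s in shapes))
        for shape in shapes:
            for idx in range(-1, -1 - len(shape), -1):
                if shape[idx] == common[idx]:  # guard_or_false(...) in the patch
                    continue
                if common[idx] == 1:
                    common[idx] = shape[idx]
                else:
                    # "maybe broadcast with max(dimensions)", torch.sym_max in the patch
                    common[idx] = max(common[idx], shape[idx])
        return tuple(common)

    print(broadcast_shapes_sketch((3, 1), (1, 4)))  # (3, 4)
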
@@ -172,6 +190,7 @@ class patched_ShapeEnv:
     ) -> None:
         if self.frozen:
             self.counter["ignored_backward_guard"] += 1
+            # PATCHED: raised an exception instead of logging.
             raise AssertionError(
                 f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
                 f"this could result in accuracy problems"
@@ -338,11 +357,13 @@ class patched_ShapeEnv:
             },
         )
 
+        # PATCHED: removed lines
         # if config.print_specializations:
         #     self.log.warning(
         #         "Specializing %s to %s", self.var_to_sources[a][0].name(), tgt
         #     )
         #     self.log.debug("SPECIALIZATION", stack_info=True)
+        # PATCHED: replaces logging by raising an exception
         assert msg != "range_refined_to_singleton", (
             f"patched_ShapeEnv: A dynamic dimension becomes static! "
             f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
@@ -364,6 +385,7 @@ class patched_ShapeEnv:
         self, prefix: str, g: "SympyBoolean", forcing_spec: bool  # noqa: F821
     ) -> None:
         self._log_guard_remember(prefix=prefix, g=g, forcing_spec=forcing_spec)
+        # PATCHED: removed
         # It happens too often to be relevant.
         # sloc, _maybe_extra_debug = self._get_stack_summary(True)
         # warnings.warn(
@@ -464,3 +486,87 @@ def patched_vmap(func, in_dims=0, out_dims=0):
             return results
 
     return wrapped
+
+
+def patched__constrain_user_specified_dimhint_range(
+    symint: torch.SymInt,
+    hint: int,
+    dim: "_DimHint",  # noqa: F821
+    range_constraints,
+    shape_env,
+    keypath: "KeyPath",  # noqa: F821
+    i: Optional[int] = None,
+) -> Optional[str]:
+    """Patches ``torch._export.non_strict_utils._constrain_user_specified_dimhint_range``."""
+    from torch._export.non_strict_utils import is_int, int_oo, _DimHintType, ValueRanges
+
+    trace_vr = (
+        range_constraints[symint.node.expr]
+        if not is_int(symint)
+        else ValueRanges(int(symint), int(symint))
+    )
+    # warn on 0/1 specialization for Dim.AUTO; not an actual error
+    # PATCHED: remove logging
+    # if dim.type == _DimHintType.AUTO and trace_vr.is_singleton() and hint in (0, 1):
+    #     pathstr = f"inputs{pytree.keystr(keypath)}"
+    #     if i is not None:
+    #         pathstr += f".shape[{i}]"
+    #     msg = (
+    #         f"dimension {pathstr} 0/1 specialized; Dim.AUTO was specified along "
+    #         f"with a sample input with hint = {hint}."
+    #     )
+    #     log.warning(msg)
+
+    try:
+        user_vr = ValueRanges(
+            lower=0 if dim.min is None else dim.min,
+            upper=int_oo if dim.max is None else dim.max,
+        )
+        if is_int(symint):
+            out_vr = trace_vr & user_vr
+        else:
+            range_constraints[symint.node.expr] &= user_vr
+            shape_env.var_to_range[symint.node._expr] &= user_vr
+            out_vr = range_constraints[symint.node.expr]
+
+        # check for Dim.DYNAMIC specializations; special case error message on 0/1
+        if dim.type == _DimHintType.DYNAMIC and out_vr.is_singleton():
+            path = f"inputs{torch.utils._pytree.keystr(keypath)}"
+            if i is not None:
+                path += f".shape[{i}]"
+            if (
+                trace_vr.is_singleton()
+                and hint in (0, 1)
+                # PATCHED: line removed
+                # and not torch.fx.experimental._config.backed_size_oblivious
+            ):
+                return None
+                # PATCHED: line removed
+                # msg = (
+                #     f"- Received user-specified dim hint "
+                #     f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
+                #     f"but export 0/1 specialized due to hint of "
+                #     f"{hint} for dimension {path}."
+                # )
+            else:
+                msg = (
+                    f"- Received user-specified dim hint "
+                    f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
+                    f"but tracing inferred a static shape of "
+                    f"{out_vr.lower} for dimension {path}."
+                )
+            return msg
+
+    except torch.utils._sympy.value_ranges.ValueRangeError:
+        path = f"inputs{torch.utils._pytree.keystr(keypath)}"
+        if i is not None:
+            path += f".shape[{i}]"
+        msg = (
+            f"- Received user-specified min/max range of [{dim.min}, {dim.max}], "
+            f"conflicting with the inferred min/max range of "
+            f"[{trace_vr.lower}, {trace_vr.upper}], "
+            f"for {path}."
+        )
+        return msg
+
+    return None
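
Most of the helper's logic is ValueRanges arithmetic: the range traced for the symbol is intersected (&) with the user-specified [min, max], and a singleton result means the dimension ended up static despite a Dim.DYNAMIC hint. A minimal sketch of that intersection, assuming the ValueRanges and int_oo utilities from recent torch releases:

    from torch.utils._sympy.value_ranges import ValueRanges
    from torch.utils._sympy.numbers import int_oo

    traced_vr = ValueRanges(2, 2)                 # tracing pinned the size to 2
    user_vr = ValueRanges(lower=0, upper=int_oo)  # Dim.DYNAMIC with no explicit bounds

    out_vr = traced_vr & user_vr
    print(out_vr, out_vr.is_singleton())  # singleton -> the patched helper builds a message
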

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

@@ -1,13 +1,20 @@
 import inspect
 import math
+import os
 from dataclasses import dataclass
 from functools import wraps
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union
 import packaging.version as pv
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.cache_utils import StaticCache, Cache
+from transformers.generation.utils import (
+    GenerateNonBeamOutput,
+    GenerationConfig,
+    StoppingCriteriaList,
+    LogitsProcessorList,
+)
 
 try:
     from transformers.cache_utils import parse_processor_args  # noqa: F401
@@ -114,6 +121,7 @@ if patch_masking_utils:
         """manual patch for function ``transformers.masking_utils.eager_mask``."""
         # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
         _ = kwargs.pop("allow_is_causal_skip", None)
+        # PATCHED: this line called the patched version of sdpa_mask
         mask = patched_sdpa_mask_recent_torch(
             batch_size=batch_size,
             cache_position=cache_position,
@@ -126,7 +134,7 @@ if patch_masking_utils:
             **kwargs,
         )
         min_dtype = torch.finfo(dtype).min
-        # The patched line.
+        # PATCHED: the following line
         # we need 0s where the tokens should be taken into account,
         # and -inf otherwise (mask is already of boolean type)
         # mask =
@@ -158,6 +166,7 @@ if patch_masking_utils:
             mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
         batch_arange = torch.arange(batch_size, device=cache_position.device)
         head_arange = torch.arange(1, device=cache_position.device)
+        # PATCHED: this line calls the patched version of vmap_for_bhqkv
         causal_mask = patched__vmap_for_bhqkv(mask_function)(
             batch_arange, head_arange, cache_position, kv_arange
         )
@@ -214,6 +223,7 @@ if patch_DynamicLayer:
             self.dtype, self.device = key_states.dtype, key_states.device
             new_shape = list(key_states.shape)
             new_shape[-2] = 0
+            # PATCHED: used a tensor with an empty shape and not an empty list to initialize
             self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
             self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
             if patch_is_initialized:
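
Seeding the lazy cache with a tensor whose sequence dimension has length zero (instead of an empty Python list) keeps every later concatenation a plain tensor op the exporter can reason about. A standalone illustration with made-up shapes:

    import torch

    key_states = torch.randn(1, 8, 4, 64)  # (batch, heads, seq_len, head_dim), example only
    new_shape = list(key_states.shape)
    new_shape[-2] = 0
    keys = torch.empty(new_shape, dtype=key_states.dtype, device=key_states.device)

    # the first append works exactly like every later one
    keys = torch.cat([keys, key_states], dim=-2)
    print(keys.shape)  # torch.Size([1, 8, 4, 64])
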
@@ -248,6 +258,8 @@ def _patch_make_causal_mask(
         diagonal = past_key_values_length - sliding_window - 1
 
         context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
+        # PATCHED: removed if is_torchdynamo_compiling(): mask = mask.clone()
+        # and used masked_fill instead of masked_fill_
         # In this case, the current implementation of torch fails (17/12/2024).
         # Try model Phi-3.5-Mini-Instruct.
         mask = mask.masked_fill(context_mask, torch.finfo(dtype).min)
@@ -455,7 +467,16 @@ class patched_GenerationMixin:
     _PATCHES_ = [
         "_cache_dependant_input_preparation",
         "_cache_dependant_input_preparation_exporting",
-        "prepare_inputs_for_generation",
+        (
+            None
+            if pv.Version(transformers.__version__) >= pv.Version("4.56")
+            else "prepare_inputs_for_generation"
+        ),
+        (
+            "_sample"
+            if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0")
+            else None
+        ),
     ]
     _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin
 
@@ -588,7 +609,7 @@ class patched_GenerationMixin:
         model_inputs = {}
         # - some models don't have `Cache` support
         #   (which implies they don't expect `cache_position` in `forward`)
-        if self._supports_cache_class:
+        if getattr(self, "_supports_cache_class", False):
             model_inputs["cache_position"] = cache_position
         # - `cache_position` was not a mandatory input in
         #   `prepare_inputs_for_generation` for those models, and this
@@ -728,6 +749,192 @@ class patched_GenerationMixin:
         model_inputs.pop("labels", None)
         return model_inputs
 
+    def _sample(
+        self,
+        input_ids: torch.LongTensor,
+        logits_processor: "LogitsProcessorList",  # noqa: F821
+        stopping_criteria: "StoppingCriteriaList",  # noqa: F821
+        generation_config: "GenerationConfig",  # noqa: F821
+        synced_gpus: bool = False,
+        streamer: Optional["BaseStreamer"] = None,  # noqa: F821
+        **model_kwargs,
+    ) -> Union["GenerateNonBeamOutput", torch.LongTensor]:  # noqa: F821
+        """
+        2025/09/29: updates for Gemma3 models, fix for eager mode as well as the export.
+        """
+        # init values
+        pad_token_id = generation_config._pad_token_tensor
+        output_attentions = generation_config.output_attentions
+        output_hidden_states = generation_config.output_hidden_states
+        output_scores = generation_config.output_scores
+        output_logits = generation_config.output_logits
+        return_dict_in_generate = generation_config.return_dict_in_generate
+        has_eos_stopping_criteria = any(
+            hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
+        )
+        do_sample = generation_config.do_sample
+
+        # init attention / hidden states / scores tuples
+        scores = () if (return_dict_in_generate and output_scores) else None
+        raw_logits = () if (return_dict_in_generate and output_logits) else None
+        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = (
+            () if (return_dict_in_generate and output_hidden_states) else None
+        )
+
+        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+        if return_dict_in_generate and self.config.is_encoder_decoder:
+            encoder_attentions = (
+                model_kwargs["encoder_outputs"].get("attentions")
+                if output_attentions
+                else None
+            )
+            encoder_hidden_states = (
+                model_kwargs["encoder_outputs"].get("hidden_states")
+                if output_hidden_states
+                else None
+            )
+
+        # keep track of which sequences are already finished
+        batch_size, cur_len = input_ids.shape[:2]
+        this_peer_finished = False
+        unfinished_sequences = torch.ones(
+            batch_size, dtype=torch.long, device=input_ids.device
+        )
+        model_kwargs = self._get_initial_cache_position(
+            cur_len, input_ids.device, model_kwargs
+        )
+
+        model_forward = self.__call__
+        compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
+        if compile_forward:
+            os.environ["TOKENIZERS_PARALLELISM"] = "0"
+            # If we use FA2 and a static cache, we cannot compile with fullgraph
+            if self.config._attn_implementation == "flash_attention_2":
+                # only raise warning if the user passed an explicit compile-config
+                if (
+                    generation_config.compile_config is not None
+                    and generation_config.compile_config.fullgraph
+                ):
+                    generation_config.compile_config.fullgraph = False
+            model_forward = self.get_compiled_call(generation_config.compile_config)
+
+        if generation_config.prefill_chunk_size is not None:
+            model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
+            is_prefill = False
+        else:
+            is_prefill = True
+
+        while self._has_unfinished_sequences(
+            this_peer_finished, synced_gpus, device=input_ids.device
+        ):
+            # prepare model inputs
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+            if is_prefill:
+                outputs = self(**model_inputs, return_dict=True)
+                is_prefill = False
+            else:
+                outputs = model_forward(**model_inputs, return_dict=True)
+
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
+            if synced_gpus and this_peer_finished:
+                continue
+
+            next_token_logits = outputs.logits[:, -1, :].to(
+                copy=True, dtype=torch.float32, device=input_ids.device
+            )
+
+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+
+            # Store scores, attentions and hidden_states when required
+            if return_dict_in_generate:
+                if output_scores:
+                    scores += (next_token_scores,)
+                if output_logits:
+                    raw_logits += (next_token_logits,)
+                if output_attentions:
+                    decoder_attentions += (
+                        (outputs.decoder_attentions,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.attentions,)
+                    )
+                    if self.config.is_encoder_decoder:
+                        cross_attentions += (outputs.cross_attentions,)
+
+                if output_hidden_states:
+                    decoder_hidden_states += (
+                        (outputs.decoder_hidden_states,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.hidden_states,)
+                    )
+
+            # token selection
+            if do_sample:
+                probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(next_token_scores, dim=-1)
+
+            # finished sentences should have their next token be a padding token
+            if has_eos_stopping_criteria:
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                    1 - unfinished_sequences
+                )
+
+            # update generated ids, model inputs, and length for next step
+            # PATCHED: the two following lines, next_tokens can be 2D already for this model
+            next_tokens_2d = (
+                next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None]
+            )
+            input_ids = torch.cat([input_ids, next_tokens_2d], dim=-1)
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
+
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+            this_peer_finished = unfinished_sequences.max() == 0
+            cur_len += 1
+
+            # This is needed to properly delete outputs.logits which may be very large
+            # for first iteration
+            # Otherwise a reference to outputs is kept which keeps
+            # the logits alive in the next iteration
+            del outputs
+
+        if streamer is not None:
+            streamer.end()
+
+        if return_dict_in_generate:
+            if self.config.is_encoder_decoder:
+                return transformers.generation.utils.GenerateEncoderDecoderOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    encoder_attentions=encoder_attentions,
+                    encoder_hidden_states=encoder_hidden_states,
+                    decoder_attentions=decoder_attentions,
+                    cross_attentions=cross_attentions,
+                    decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+            else:
+                return transformers.generation.utils.GenerateDecoderOnlyOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    attentions=decoder_attentions,
+                    hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+        else:
+            return input_ids
+
 
 def patched__compute_dynamic_ntk_parameters(
     config: Optional[transformers.PretrainedConfig] = None,
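
The change called out by the PATCHED comment is that next_tokens may already be 2D for this model, so the loop only adds the trailing dimension when it is missing before appending to input_ids. A standalone sketch of that normalization:

    import torch

    input_ids = torch.tensor([[1, 2, 3]])

    for next_tokens in (torch.tensor([7]), torch.tensor([[8]])):  # 1D or already 2D
        next_tokens_2d = (
            next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None]
        )
        input_ids = torch.cat([input_ids, next_tokens_2d], dim=-1)

    print(input_ids)  # tensor([[1, 2, 3, 7, 8]])
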
@@ -791,6 +998,7 @@ def patched__compute_dynamic_ntk_parameters(
     if seq_len is None:
         seq_len = max_position_embeddings
     else:
+        # PATCHED: remove the line using max
         torch._check(isinstance(seq_len, torch.Tensor))
         seq_len = torch.maximum(
             seq_len,
@@ -896,6 +1104,7 @@ def patched_dynamic_rope_update(rope_forward):
         )
         original_inv_freq = self.original_inv_freq.to(device)
 
+        # PATCHED: uses torch.cond instead of a test
         cond = (seq_len > original_max_position_embeddings).item()
         inv_freq = torch.cond(
             cond,
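
Using torch.cond instead of a Python if keeps both the long-sequence and the short-sequence branch in the exported graph rather than baking in whichever branch the example input happened to take. A minimal, self-contained torch.cond sketch (the names are unrelated to the rope code):

    import torch

    def long_branch(inv_freq):
        return inv_freq / 2.0

    def short_branch(inv_freq):
        return inv_freq.clone()

    inv_freq = torch.arange(1.0, 5.0)
    cond = torch.tensor(True)  # stands in for `seq_len > original_max_position_embeddings`
    print(torch.cond(cond, long_branch, short_branch, (inv_freq,)))
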
@@ -967,6 +1176,7 @@ def patched_dynamic_rope_update(rope_forward):
 
         original_inv_freq = self.original_inv_freq.to(device)
         cond = (seq_len >= self.original_max_seq_len).item()
+        # PATCHED: uses torch.cond instead of a test
         inv_freq = torch.cond(
             cond,
             (lambda x, y: x.clone()),
@@ -1002,6 +1212,7 @@ def common_eager_attention_forward(
 
     attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
     if attention_mask is not None:
+        # PATCHED
         # The two following lines were added.
        if attention_mask is not None and attention_mask.ndim == 4:
             attention_mask = attention_mask[:, :, :, : key.shape[-2]]
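
The two added lines trim a 4D attention mask down to the key length, so a mask built for a longer (padded or cached) sequence still lines up with key before it is added to the attention weights. Sketch with dummy shapes:

    import torch

    key = torch.randn(1, 8, 6, 64)             # (batch, heads, kv_len, head_dim)
    attention_mask = torch.zeros(1, 1, 4, 10)  # last dimension longer than kv_len

    if attention_mask is not None and attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]

    print(attention_mask.shape)  # torch.Size([1, 1, 4, 6])
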
@@ -1074,6 +1285,7 @@ def patched_modeling_marian_eager_attention_forward(
 class common_RotaryEmbedding(torch.nn.Module):
     # This may cause some issues.
     # @torch.no_grad()
+    # PATCHED: the decorator
     @patched_dynamic_rope_update
     def forward(self, x, position_ids):
         inv_freq_expanded = (
@@ -1629,3 +1841,56 @@ if patch_qwen3:
             batch_size, sequence_length, hidden_dim
         )
         return final_hidden_states, router_logits
+
+
+try:
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3Model  # noqa: F401
+
+    patch_gemma3 = True
+except ImportError:
+    patch_gemma3 = False
+
+
+if patch_gemma3:
+
+    class patched_Gemma3Model(torch.nn.Module):
+        _PATCHES_ = ["get_placeholder_mask"]
+        _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3Model
+        _PATCHED_PR_ = "https://github.com/huggingface/transformers/pull/41319"
+
+        def get_placeholder_mask(
+            self,
+            input_ids: torch.LongTensor,
+            inputs_embeds: torch.FloatTensor,
+            image_features: torch.FloatTensor,
+        ):
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(
+                        self.config.image_token_id,
+                        dtype=torch.long,
+                        device=inputs_embeds.device,
+                    )
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = special_image_mask.sum()
+            special_image_mask = (
+                special_image_mask.unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            n_image_features = image_features.shape[0] * image_features.shape[1]
+            # PATCHED: torch._check
+            # if inputs_embeds[special_image_mask].numel() != image_features.numel():
+            #     raise ValueError( ... )
+            torch._check(
+                inputs_embeds[special_image_mask].numel() == image_features.numel(),
+                lambda: (
+                    f"Image features and image tokens do not match: tokens: "
+                    f"{n_image_tokens}, features {n_image_features}"
+                ),
+            )
+            return special_image_mask
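
Replacing the ValueError with torch._check expresses the token/feature consistency requirement as an assertion torch.export can keep in the graph instead of tracing through a data-dependent Python branch. A minimal torch._check sketch:

    import torch

    a = torch.randn(4, 8)
    b = torch.randn(2, 16)

    # passes silently because the condition holds; raises with the lambda's message otherwise
    torch._check(
        a.numel() == b.numel(),
        lambda: f"element counts do not match: {a.numel()} vs {b.numel()}",
    )
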

onnx_diagnostic/torch_models/hghub/hub_api.py

@@ -289,21 +289,17 @@ def task_from_tags(tags: Union[str, List[str]]) -> str:
 
 def enumerate_model_list(
     n: int = 50,
-    task: Optional[str] = None,
-    library: Optional[str] = None,
-    tags: Optional[Union[str, List[str]]] = None,
+    pipeline_tag: Optional[str] = None,
     search: Optional[str] = None,
     dump: Optional[str] = None,
-    filter: Optional[str] = None,
+    filter: Optional[Union[str, List[str]]] = None,
     verbose: int = 0,
 ):
     """
     Enumerates models coming from :epkg:`huggingface_hub`.
 
     :param n: number of models to retrieve (-1 for all)
-    :param task: see :meth:`huggingface_hub.HfApi.list_models`
-    :param tags: see :meth:`huggingface_hub.HfApi.list_models`
-    :param library: see :meth:`huggingface_hub.HfApi.list_models`
+    :param pipeline_tag: see :meth:`huggingface_hub.HfApi.list_models`
     :param search: see :meth:`huggingface_hub.HfApi.list_models`
     :param filter: see :meth:`huggingface_hub.HfApi.list_models`
     :param dump: dumps the result in this csv file
@@ -311,9 +307,7 @@ def enumerate_model_list(
     """
     api = HfApi()
     models = api.list_models(
-        task=task,
-        library=library,
-        tags=tags,
+        pipeline_tag=pipeline_tag,
         search=search,
         full=True,
         filter=filter,
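
The task, library and tags arguments are folded into pipeline_tag and a more permissive filter, matching the current huggingface_hub.HfApi.list_models signature. A hedged usage sketch of the updated helper (the call queries the Hugging Face Hub, and the values passed here are only illustrative):

    from onnx_diagnostic.torch_models.hghub.hub_api import enumerate_model_list

    for info in enumerate_model_list(
        n=5,
        pipeline_tag="text-generation",
        filter="transformers",
        verbose=1,
    ):
        print(info)
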

onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

@@ -4829,3 +4829,39 @@ def _ccached_microsoft_phi3_mini_128k_instruct():
             "vocab_size": 32064,
         }
     )
+
+
+def _ccached_google_gemma_3_4b_it_like():
+    "google/gemma-3-4b-it"
+    return transformers.Gemma3Config(
+        **{
+            "architectures": ["Gemma3ForConditionalGeneration"],
+            "boi_token_index": 255999,
+            "eoi_token_index": 256000,
+            "eos_token_id": [1, 106],
+            "image_token_index": 262144,
+            "initializer_range": 0.02,
+            "mm_tokens_per_image": 256,
+            "model_type": "gemma3",
+            "text_config": {
+                "hidden_size": 2560,
+                "intermediate_size": 10240,
+                "model_type": "gemma3_text",
+                "num_hidden_layers": 34,
+                "rope_scaling": {"factor": 8.0, "rope_type": "linear"},
+                "sliding_window": 1024,
+            },
+            "torch_dtype": "bfloat16",
+            "transformers_version": "4.50.0.dev0",
+            "vision_config": {
+                "hidden_size": 1152,
+                "image_size": 896,
+                "intermediate_size": 4304,
+                "model_type": "siglip_vision_model",
+                "num_attention_heads": 16,
+                "num_hidden_layers": 27,
+                "patch_size": 14,
+                "vision_use_head": False,
+            },
+        }
+    )
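
The new cached configuration reproduces the google/gemma-3-4b-it settings so dummy inputs can be built without downloading the real config. A short sketch of reading it back, assuming a transformers version that ships Gemma3Config (the helper is internal to the package):

    import transformers
    from onnx_diagnostic.torch_models.hghub.hub_data_cached_configs import (
        _ccached_google_gemma_3_4b_it_like,
    )

    config = _ccached_google_gemma_3_4b_it_like()
    assert isinstance(config, transformers.Gemma3Config)
    print(config.model_type, config.text_config.hidden_size, config.vision_config.image_size)
    # gemma3 2560 896
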