onnx-diagnostic: onnx_diagnostic-0.7.12-py3-none-any.whl → onnx_diagnostic-0.7.13-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/export/dynamic_shapes.py +11 -2
- onnx_diagnostic/helpers/helper.py +11 -5
- onnx_diagnostic/helpers/mini_onnx_builder.py +17 -0
- onnx_diagnostic/helpers/model_builder_helper.py +1 -0
- onnx_diagnostic/helpers/rt_helper.py +2 -1
- onnx_diagnostic/helpers/torch_helper.py +31 -7
- onnx_diagnostic/reference/torch_evaluator.py +2 -2
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/image_text_to_text.py +256 -141
- onnx_diagnostic/tasks/text_generation.py +15 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +177 -150
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +19 -1
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +29 -14
- onnx_diagnostic/torch_export_patches/patch_inputs.py +10 -6
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +116 -10
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +269 -4
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +36 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +31 -3
- onnx_diagnostic/torch_models/validate.py +114 -36
- onnx_diagnostic/torch_onnx/sbs.py +2 -1
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.13.dist-info}/METADATA +11 -31
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.13.dist-info}/RECORD +27 -25
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.13.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.13.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.13.dist-info}/top_level.txt +0 -0
--- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py
@@ -1,7 +1,7 @@
 import inspect
 import os
 import traceback
-from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 import torch
 from torch._subclasses.fake_tensor import FakeTensorMode

@@ -65,6 +65,8 @@ def patch__check_input_constraints_for_graph(
     verbose: int = 0,
 ) -> None:
     try:
+        # PATCHED: catches exception and prints out the information instead of
+        # stopping the conversion.
         return previous_function(input_placeholders, flat_args_with_path, range_constraints)
     except Exception as e:
         if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")):
@@ -122,8 +124,7 @@ def patched_infer_size(a, b):
         if b1 or b2 or b3:
             expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
         else:
-            #
-            # Try model SmolLM.
+            # PATCHED: generic case, the dimension is known, no need to assert
             expandedSizes[i] = torch.sym_max(sizeA, sizeB)
     return tuple(expandedSizes)

@@ -132,7 +133,11 @@ def patched__broadcast_shapes(*_shapes):
     """Patches ``torch._refs._broadcast_shapes``."""
     from functools import reduce
     from torch._prims_common import IntLike
-    from torch.fx.experimental.symbolic_shapes import
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_size_oblivious,
+        guard_or_false,
+        is_nested_int,
+    )

     shapes = tuple(
         (x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes)
@@ -142,17 +147,30 @@ def patched__broadcast_shapes(*_shapes):
     if len(shapes) == 0:
         return None

-    # Type checking
-    # TODO: make common validations available as utils
     for shape in shapes:
-
+        if not isinstance(shape, Sequence):
+            raise RuntimeError(
+                "Input shapes should be of type ints, a tuple of ints, "
+                "or a list of ints, got ",
+                shape,
+            )

     # Computes common shape
-    common_shape = [
-        1,
-    ] * reduce(max, (len(shape) for shape in shapes))
+    common_shape = [1] * reduce(max, (len(shape) for shape in shapes))
     for _arg_idx, shape in enumerate(shapes):
         for idx in range(-1, -1 - len(shape), -1):
+            if is_nested_int(shape[idx]):
+                # Broadcasting is allowed for (j0, 1) or (j0, j0);
+                # not (j0, j1), (j0, 5), etc.
+                if is_nested_int(common_shape[idx]) and guard_or_false(
+                    shape[idx] == common_shape[idx]
+                ):
+                    continue
+            else:
+                if guard_or_false(shape[idx] == common_shape[idx]):
+                    continue
+            # PATCHED: two cases, if == for sure, no broadcast,
+            # otherwise maybe broadcast with max(dimensions)
             if guard_size_oblivious(common_shape[idx] == 1):
                 if shape[idx] < 0:
                     raise ValueError(
@@ -172,6 +190,7 @@ class patched_ShapeEnv:
     ) -> None:
         if self.frozen:
             self.counter["ignored_backward_guard"] += 1
+            # PATCHED: raised an exception instead of logging.
             raise AssertionError(
                 f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
                 f"this could result in accuracy problems"
@@ -338,11 +357,13 @@ class patched_ShapeEnv:
             },
         )

+        # PATCHED: removed lines
         # if config.print_specializations:
         #     self.log.warning(
         #         "Specializing %s to %s", self.var_to_sources[a][0].name(), tgt
        #     )
        #     self.log.debug("SPECIALIZATION", stack_info=True)
+        # PATCHED: replaces logging by raising an exception
         assert msg != "range_refined_to_singleton", (
             f"patched_ShapeEnv: A dynamic dimension becomes static! "
             f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
@@ -364,6 +385,7 @@ class patched_ShapeEnv:
         self, prefix: str, g: "SympyBoolean", forcing_spec: bool  # noqa: F821
     ) -> None:
         self._log_guard_remember(prefix=prefix, g=g, forcing_spec=forcing_spec)
+        # PATCHED: removed
         # It happens too often to be relevant.
         # sloc, _maybe_extra_debug = self._get_stack_summary(True)
         # warnings.warn(
@@ -464,3 +486,87 @@ def patched_vmap(func, in_dims=0, out_dims=0):
             return results

     return wrapped
+
+
+def patched__constrain_user_specified_dimhint_range(
+    symint: torch.SymInt,
+    hint: int,
+    dim: "_DimHint",  # noqa: F821
+    range_constraints,
+    shape_env,
+    keypath: "KeyPath",  # noqa: F821
+    i: Optional[int] = None,
+) -> Optional[str]:
+    """Patches ``torch._export.non_strict_utils._constrain_user_specified_dimhint_range``."""
+    from torch._export.non_strict_utils import is_int, int_oo, _DimHintType, ValueRanges
+
+    trace_vr = (
+        range_constraints[symint.node.expr]
+        if not is_int(symint)
+        else ValueRanges(int(symint), int(symint))
+    )
+    # warn on 0/1 specialization for Dim.AUTO; not an actual error
+    # PATCHED: remove logging
+    # if dim.type == _DimHintType.AUTO and trace_vr.is_singleton() and hint in (0, 1):
+    #     pathstr = f"inputs{pytree.keystr(keypath)}"
+    #     if i is not None:
+    #         pathstr += f".shape[{i}]"
+    #     msg = (
+    #         f"dimension {pathstr} 0/1 specialized; Dim.AUTO was specified along "
+    #         f"with a sample input with hint = {hint}."
+    #     )
+    #     log.warning(msg)
+
+    try:
+        user_vr = ValueRanges(
+            lower=0 if dim.min is None else dim.min,
+            upper=int_oo if dim.max is None else dim.max,
+        )
+        if is_int(symint):
+            out_vr = trace_vr & user_vr
+        else:
+            range_constraints[symint.node.expr] &= user_vr
+            shape_env.var_to_range[symint.node._expr] &= user_vr
+            out_vr = range_constraints[symint.node.expr]
+
+        # check for Dim.DYNAMIC specializations; special case error message on 0/1
+        if dim.type == _DimHintType.DYNAMIC and out_vr.is_singleton():
+            path = f"inputs{torch.utils._pytree.keystr(keypath)}"
+            if i is not None:
+                path += f".shape[{i}]"
+            if (
+                trace_vr.is_singleton()
+                and hint in (0, 1)
+                # PATCHED: line removed
+                # and not torch.fx.experimental._config.backed_size_oblivious
+            ):
+                return None
+                # PATCHED: line removed
+                # msg = (
+                #     f"- Received user-specified dim hint "
+                #     f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
+                #     f"but export 0/1 specialized due to hint of "
+                #     f"{hint} for dimension {path}."
+                # )
+            else:
+                msg = (
+                    f"- Received user-specified dim hint "
+                    f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
+                    f"but tracing inferred a static shape of "
+                    f"{out_vr.lower} for dimension {path}."
+                )
+                return msg
+
+    except torch.utils._sympy.value_ranges.ValueRangeError:
+        path = f"inputs{torch.utils._pytree.keystr(keypath)}"
+        if i is not None:
+            path += f".shape[{i}]"
+        msg = (
+            f"- Received user-specified min/max range of [{dim.min}, {dim.max}], "
+            f"conflicting with the inferred min/max range of "
+            f"[{trace_vr.lower}, {trace_vr.upper}], "
+            f"for {path}."
+        )
+        return msg
+
+    return None
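Note on the broadcast patches above: they relax PyTorch's requirement that two symbolic dimensions be proven equal; when neither side is statically 1 and equality cannot be decided, the patched helpers keep the symbolic max of the two dimensions instead of failing. A minimal, illustrative sketch of that rule on concrete integers (not library code; all names below are invented for illustration):

def broadcast_shapes(shape_a, shape_b):
    # Right-align the two shapes and combine them dimension by dimension.
    la, lb = len(shape_a), len(shape_b)
    out = []
    for i in range(1, max(la, lb) + 1):
        a = shape_a[-i] if i <= la else 1
        b = shape_b[-i] if i <= lb else 1
        if a == 1:
            out.append(b)
        elif b == 1:
            out.append(a)
        else:
            # Patched rule: instead of asserting a == b (which fails when both
            # sides are unrelated symbolic dimensions), keep the max of the two.
            out.append(max(a, b))
    return tuple(reversed(out))

print(broadcast_shapes((4, 1, 3), (2, 3)))  # (4, 2, 3)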
--- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -1,13 +1,20 @@
 import inspect
 import math
+import os
 from dataclasses import dataclass
 from functools import wraps
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union
 import packaging.version as pv
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.cache_utils import StaticCache, Cache
+from transformers.generation.utils import (
+    GenerateNonBeamOutput,
+    GenerationConfig,
+    StoppingCriteriaList,
+    LogitsProcessorList,
+)

 try:
     from transformers.cache_utils import parse_processor_args  # noqa: F401
@@ -114,6 +121,7 @@ if patch_masking_utils:
         """manual patch for function ``transformers.masking_utils.eager_mask``."""
         # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
         _ = kwargs.pop("allow_is_causal_skip", None)
+        # PATCHED: this line called the patched version of sdpa_mask
         mask = patched_sdpa_mask_recent_torch(
             batch_size=batch_size,
             cache_position=cache_position,
@@ -126,7 +134,7 @@ if patch_masking_utils:
             **kwargs,
         )
         min_dtype = torch.finfo(dtype).min
-        #
+        # PATCHED: the following line
         # we need 0s where the tokens should be taken into account,
         # and -inf otherwise (mask is already of boolean type)
         # mask =
@@ -158,6 +166,7 @@ if patch_masking_utils:
             mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
         batch_arange = torch.arange(batch_size, device=cache_position.device)
         head_arange = torch.arange(1, device=cache_position.device)
+        # PATCHED: this line calls the patched version of vmap_for_bhqkv
         causal_mask = patched__vmap_for_bhqkv(mask_function)(
             batch_arange, head_arange, cache_position, kv_arange
         )
@@ -214,6 +223,7 @@ if patch_DynamicLayer:
             self.dtype, self.device = key_states.dtype, key_states.device
             new_shape = list(key_states.shape)
             new_shape[-2] = 0
+            # PATCHED: used a tensor with an empty shape and not en empty list to initialize
             self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
             self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
     if patch_is_initialized:
@@ -248,6 +258,8 @@ def _patch_make_causal_mask(
         diagonal = past_key_values_length - sliding_window - 1

         context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
+        # PATCHED: removed if is_torchdynamo_compiling(): mask = mask.clone()
+        # and used masked_fill instead of masked_fill_
         # In this case, the current implementation of torch fails (17/12/2024).
         # Try model Phi-3.5-Mini-Instruct.
         mask = mask.masked_fill(context_mask, torch.finfo(dtype).min)
@@ -455,7 +467,16 @@ class patched_GenerationMixin:
     _PATCHES_ = [
         "_cache_dependant_input_preparation",
         "_cache_dependant_input_preparation_exporting",
-
+        (
+            None
+            if pv.Version(transformers.__version__) >= pv.Version("4.56")
+            else "prepare_inputs_for_generation"
+        ),
+        (
+            "_sample"
+            if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0")
+            else None
+        ),
     ]
     _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin

@@ -588,7 +609,7 @@ class patched_GenerationMixin:
         model_inputs = {}
         # - some models don't have `Cache` support
         # (which implies they don't expect `cache_position` in `forward`)
-        if self
+        if getattr(self, "_supports_cache_class", False):
             model_inputs["cache_position"] = cache_position
         # - `cache_position` was not a mandatory input in
         # `prepare_inputs_for_generation` for those models, and this
@@ -728,6 +749,192 @@ class patched_GenerationMixin:
         model_inputs.pop("labels", None)
         return model_inputs

+    def _sample(
+        self,
+        input_ids: torch.LongTensor,
+        logits_processor: "LogitsProcessorList",  # noqa: F821
+        stopping_criteria: "StoppingCriteriaList",  # noqa: F821
+        generation_config: "GenerationConfig",  # noqa: F821
+        synced_gpus: bool = False,
+        streamer: Optional["BaseStreamer"] = None,  # noqa: F821
+        **model_kwargs,
+    ) -> Union["GenerateNonBeamOutput", torch.LongTensor]:  # noqa: F821
+        """
+        2025/09/29: updates for Gemma3 models, fix for eager mode as well as the export.
+        """
+        # init values
+        pad_token_id = generation_config._pad_token_tensor
+        output_attentions = generation_config.output_attentions
+        output_hidden_states = generation_config.output_hidden_states
+        output_scores = generation_config.output_scores
+        output_logits = generation_config.output_logits
+        return_dict_in_generate = generation_config.return_dict_in_generate
+        has_eos_stopping_criteria = any(
+            hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
+        )
+        do_sample = generation_config.do_sample
+
+        # init attention / hidden states / scores tuples
+        scores = () if (return_dict_in_generate and output_scores) else None
+        raw_logits = () if (return_dict_in_generate and output_logits) else None
+        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = (
+            () if (return_dict_in_generate and output_hidden_states) else None
+        )
+
+        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+        if return_dict_in_generate and self.config.is_encoder_decoder:
+            encoder_attentions = (
+                model_kwargs["encoder_outputs"].get("attentions")
+                if output_attentions
+                else None
+            )
+            encoder_hidden_states = (
+                model_kwargs["encoder_outputs"].get("hidden_states")
+                if output_hidden_states
+                else None
+            )
+
+        # keep track of which sequences are already finished
+        batch_size, cur_len = input_ids.shape[:2]
+        this_peer_finished = False
+        unfinished_sequences = torch.ones(
+            batch_size, dtype=torch.long, device=input_ids.device
+        )
+        model_kwargs = self._get_initial_cache_position(
+            cur_len, input_ids.device, model_kwargs
+        )
+
+        model_forward = self.__call__
+        compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
+        if compile_forward:
+            os.environ["TOKENIZERS_PARALLELISM"] = "0"
+            # If we use FA2 and a static cache, we cannot compile with fullgraph
+            if self.config._attn_implementation == "flash_attention_2":
+                # only raise warning if the user passed an explicit compile-config
+                if (
+                    generation_config.compile_config is not None
+                    and generation_config.compile_config.fullgraph
+                ):
+                    generation_config.compile_config.fullgraph = False
+            model_forward = self.get_compiled_call(generation_config.compile_config)
+
+        if generation_config.prefill_chunk_size is not None:
+            model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
+            is_prefill = False
+        else:
+            is_prefill = True
+
+        while self._has_unfinished_sequences(
+            this_peer_finished, synced_gpus, device=input_ids.device
+        ):
+            # prepare model inputs
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+            if is_prefill:
+                outputs = self(**model_inputs, return_dict=True)
+                is_prefill = False
+            else:
+                outputs = model_forward(**model_inputs, return_dict=True)
+
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
+            if synced_gpus and this_peer_finished:
+                continue
+
+            next_token_logits = outputs.logits[:, -1, :].to(
+                copy=True, dtype=torch.float32, device=input_ids.device
+            )
+
+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+
+            # Store scores, attentions and hidden_states when required
+            if return_dict_in_generate:
+                if output_scores:
+                    scores += (next_token_scores,)
+                if output_logits:
+                    raw_logits += (next_token_logits,)
+                if output_attentions:
+                    decoder_attentions += (
+                        (outputs.decoder_attentions,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.attentions,)
+                    )
+                    if self.config.is_encoder_decoder:
+                        cross_attentions += (outputs.cross_attentions,)
+
+                if output_hidden_states:
+                    decoder_hidden_states += (
+                        (outputs.decoder_hidden_states,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.hidden_states,)
+                    )
+
+            # token selection
+            if do_sample:
+                probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(next_token_scores, dim=-1)
+
+            # finished sentences should have their next token be a padding token
+            if has_eos_stopping_criteria:
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                    1 - unfinished_sequences
+                )
+
+            # update generated ids, model inputs, and length for next step
+            # PATCHED: the two following lines, next_tokens can 2D already for this model
+            next_tokens_2d = (
+                next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None]
+            )
+            input_ids = torch.cat([input_ids, next_tokens_2d], dim=-1)
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
+
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+            this_peer_finished = unfinished_sequences.max() == 0
+            cur_len += 1
+
+            # This is needed to properly delete outputs.logits which may be very large
+            # for first iteration
+            # Otherwise a reference to outputs is kept which keeps
+            # the logits alive in the next iteration
+            del outputs
+
+        if streamer is not None:
+            streamer.end()
+
+        if return_dict_in_generate:
+            if self.config.is_encoder_decoder:
+                return transformers.generation.utils.GenerateEncoderDecoderOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    encoder_attentions=encoder_attentions,
+                    encoder_hidden_states=encoder_hidden_states,
+                    decoder_attentions=decoder_attentions,
+                    cross_attentions=cross_attentions,
+                    decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+            else:
+                return transformers.generation.utils.GenerateDecoderOnlyOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    attentions=decoder_attentions,
+                    hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+        else:
+            return input_ids
+

 def patched__compute_dynamic_ntk_parameters(
     config: Optional[transformers.PretrainedConfig] = None,
@@ -791,6 +998,7 @@ def patched__compute_dynamic_ntk_parameters(
     if seq_len is None:
         seq_len = max_position_embeddings
     else:
+        # PATCHED: remove the line using max
         torch._check(isinstance(seq_len, torch.Tensor))
         seq_len = torch.maximum(
             seq_len,
@@ -896,6 +1104,7 @@ def patched_dynamic_rope_update(rope_forward):
             )
             original_inv_freq = self.original_inv_freq.to(device)

+            # PATCHED: uses torch.cond instead of a test
             cond = (seq_len > original_max_position_embeddings).item()
             inv_freq = torch.cond(
                 cond,
@@ -967,6 +1176,7 @@ def patched_dynamic_rope_update(rope_forward):

             original_inv_freq = self.original_inv_freq.to(device)
             cond = (seq_len >= self.original_max_seq_len).item()
+            # PATCHED: uses torch.cond instead of a test
             inv_freq = torch.cond(
                 cond,
                 (lambda x, y: x.clone()),
@@ -1002,6 +1212,7 @@ def common_eager_attention_forward(

     attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
     if attention_mask is not None:
+        # PATCHED
         # The two following lines were added.
         if attention_mask is not None and attention_mask.ndim == 4:
             attention_mask = attention_mask[:, :, :, : key.shape[-2]]
@@ -1074,6 +1285,7 @@ def patched_modeling_marian_eager_attention_forward(
 class common_RotaryEmbedding(torch.nn.Module):
     # This may cause some issues.
     # @torch.no_grad()
+    # PATCHED: the decorator
     @patched_dynamic_rope_update
     def forward(self, x, position_ids):
         inv_freq_expanded = (
@@ -1629,3 +1841,56 @@ if patch_qwen3:
             batch_size, sequence_length, hidden_dim
         )
         return final_hidden_states, router_logits
+
+
+try:
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3Model  # noqa: F401
+
+    patch_gemma3 = True
+except ImportError:
+    patch_gemma3 = False
+
+
+if patch_gemma3:
+
+    class patched_Gemma3Model(torch.nn.Module):
+        _PATCHES_ = ["get_placeholder_mask"]
+        _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3Model
+        _PATCHED_PR_ = "https://github.com/huggingface/transformers/pull/41319"
+
+        def get_placeholder_mask(
+            self,
+            input_ids: torch.LongTensor,
+            inputs_embeds: torch.FloatTensor,
+            image_features: torch.FloatTensor,
+        ):
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(
+                        self.config.image_token_id,
+                        dtype=torch.long,
+                        device=inputs_embeds.device,
+                    )
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = special_image_mask.sum()
+            special_image_mask = (
+                special_image_mask.unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            n_image_features = image_features.shape[0] * image_features.shape[1]
+            # PATCHED: torch._check
+            # if inputs_embeds[special_image_mask].numel() != image_features.numel():
+            #     raise ValueError( ... )
+            torch._check(
+                inputs_embeds[special_image_mask].numel() == image_features.numel(),
+                lambda: (
+                    f"Image features and image tokens do not match: tokens: "
+                    f"{n_image_tokens}, features {n_image_features}"
+                ),
+            )
+            return special_image_mask
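Note on the Gemma3 patch above: it replaces a Python `raise ValueError(...)` with `torch._check`, which records the condition as a graph-level assertion that `torch.export` can trace through instead of branching on a data-dependent value. A small, hedged sketch of the same idiom outside the library (the function and tensor names below are invented for illustration):

import torch

def scatter_image_features(inputs_embeds, image_features, special_image_mask):
    # The boolean mask marks embedding positions that belong to image tokens.
    # torch._check keeps the size-consistency check inside the traced graph,
    # which is what makes the patched code exportable.
    torch._check(
        inputs_embeds[special_image_mask].numel() == image_features.numel(),
        lambda: "image features and image tokens do not match",
    )
    return inputs_embeds.masked_scatter(special_image_mask, image_features)

emb = torch.zeros(1, 6, 4)
mask = torch.zeros(1, 6, 4, dtype=torch.bool)
mask[0, 2:4] = True
feats = torch.ones(2, 4)
print(scatter_image_features(emb, feats, mask).shape)  # torch.Size([1, 6, 4])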
--- a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
+++ b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
@@ -4829,3 +4829,39 @@ def _ccached_microsoft_phi3_mini_128k_instruct():
             "vocab_size": 32064,
         }
     )
+
+
+def _ccached_google_gemma_3_4b_it_like():
+    "google/gemma-3-4b-it"
+    return transformers.Gemma3Config(
+        **{
+            "architectures": ["Gemma3ForConditionalGeneration"],
+            "boi_token_index": 255999,
+            "eoi_token_index": 256000,
+            "eos_token_id": [1, 106],
+            "image_token_index": 262144,
+            "initializer_range": 0.02,
+            "mm_tokens_per_image": 256,
+            "model_type": "gemma3",
+            "text_config": {
+                "hidden_size": 2560,
+                "intermediate_size": 10240,
+                "model_type": "gemma3_text",
+                "num_hidden_layers": 34,
+                "rope_scaling": {"factor": 8.0, "rope_type": "linear"},
+                "sliding_window": 1024,
+            },
+            "torch_dtype": "bfloat16",
+            "transformers_version": "4.50.0.dev0",
+            "vision_config": {
+                "hidden_size": 1152,
+                "image_size": 896,
+                "intermediate_size": 4304,
+                "model_type": "siglip_vision_model",
+                "num_attention_heads": 16,
+                "num_hidden_layers": 27,
+                "patch_size": 14,
+                "vision_use_head": False,
+            },
+        }
+    )
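Note: cached entries such as `_ccached_google_gemma_3_4b_it_like` above store the Hugging Face configuration so that a model skeleton can be built without network access. A hedged sketch of the underlying idea (assumes a transformers release that ships Gemma3; the reduced sizes are made up for illustration and are not the cached values):

import transformers

config = transformers.Gemma3Config(
    text_config={"model_type": "gemma3_text", "hidden_size": 256, "num_hidden_layers": 2},
    vision_config={"model_type": "siglip_vision_model", "num_hidden_layers": 2},
)
# Sub-config dicts are materialized into config objects, usable offline.
print(type(config.text_config).__name__, config.text_config.num_hidden_layers)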
--- a/onnx_diagnostic/torch_models/hghub/model_inputs.py
+++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py
@@ -57,7 +57,7 @@ def get_untrained_model_with_inputs(
         to get a smaller model
     :param use_pretrained: download the pretrained weights as well
     :param use_preinstalled: use preinstalled configurations
-    :param add_second_input: provides
+    :param add_second_input: provides others inputs to check a model
         supports different shapes
     :param subfolder: subfolder to use for this model id
     :param use_only_preinstalled: use only preinstalled version
@@ -193,7 +193,7 @@ def get_untrained_model_with_inputs(
         )
         if verbose:
             print(
-                f"[get_untrained_model_with_inputs] -- done in "
+                f"[get_untrained_model_with_inputs] -- done(1) in "
                 f"{time.perf_counter() - begin}s"
             )
     else:
@@ -250,14 +250,36 @@ def get_untrained_model_with_inputs(
     )
     if verbose:
         print(
-            f"[get_untrained_model_with_inputs] -- done in "
+            f"[get_untrained_model_with_inputs] -- done(2) in "
            f"{time.perf_counter() - begin}s"
        )

     seed = int(os.environ.get("SEED", "17"))
     torch.manual_seed(seed)
+
+    if verbose:
+        begin = time.perf_counter()
+        print(
+            f"[get_untrained_model_with_inputs] "
+            f"instantiate_specific_model {cls_model}"
+        )
+
     model = instantiate_specific_model(cls_model, config)
+
+    if verbose:
+        print(
+            f"[get_untrained_model_with_inputs] -- done(3) in "
+            f"{time.perf_counter() - begin}s (model is {type(model)})"
+        )
+
     if model is None:
+
+        if verbose:
+            print(
+                f"[get_untrained_model_with_inputs] "
+                f"instantiate_specific_model(2) {cls_model}"
+            )
+
         try:
             if type(config) is dict:
                 model = cls_model(**config)
@@ -268,6 +290,12 @@ def get_untrained_model_with_inputs(
                 f"Unable to instantiate class {cls_model.__name__} with\n{config}"
             ) from e

+    if verbose:
+        print(
+            f"[get_untrained_model_with_inputs] -- done(4) in "
+            f"{time.perf_counter() - begin}s (model is {type(model)})"
+        )
+
     # input kwargs
     seed = int(os.environ.get("SEED", "17")) + 1
     torch.manual_seed(seed)