onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +412 -12
- onnx_diagnostic/export/api.py +111 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +12 -7
- onnx_diagnostic/export/onnx_plug.py +531 -0
- onnx_diagnostic/ext_test_case.py +163 -48
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +222 -0
- onnx_diagnostic/helpers/helper.py +108 -37
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +531 -6
- onnx_diagnostic/helpers/ort_session.py +45 -19
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +131 -8
- onnx_diagnostic/reference/ort_evaluator.py +228 -46
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +64 -2
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +969 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
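
The bulk of this release refactors the transformers export patches: the monolithic `patch_transformers.py` (+64 -2608) is split into the per-model `_patch_transformers_*.py` modules listed above, and two of the new files (`_patch_transformers_generation_mixin.py` and `_patch_transformers_idefics.py`) are reproduced below. As a hedged illustration only: these patches are normally activated through the `torch_export_patches` context manager exposed by `onnx_diagnostic.torch_export_patches`; the exact import and the `patch_transformers` keyword are assumptions based on the project's documented usage and are not part of this diff.

```python
import torch

# Assumed entry point: import path and keyword follow the project's documentation,
# not this diff.
from onnx_diagnostic.torch_export_patches import torch_export_patches


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x.sin() + 1


model, inputs = TinyModel(), (torch.randn(2, 3),)

# While the context manager is active, patch classes such as patched_GenerationMixin
# below replace the corresponding transformers methods so torch.export can trace them.
with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(model, inputs)
print(ep)
```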

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py (new file, +486 lines)

@@ -0,0 +1,486 @@
import inspect
import os
from typing import Optional, Tuple, Union
import packaging.version as pv
import torch
import transformers
from transformers.cache_utils import StaticCache, Cache
from .patch_helper import _is_torchdynamo_exporting


class patched_GenerationMixin:
    """
    Applies modifications implemented in PR
    `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
    """

    _PATCHES_ = [
        "_cache_dependant_input_preparation",
        "_cache_dependant_input_preparation_exporting",
        (
            None
            if pv.Version(transformers.__version__) >= pv.Version("4.56")
            else "prepare_inputs_for_generation"
        ),
        (
            "_sample"
            if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0")
            else None
        ),
    ]
    _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin

    def _cache_dependant_input_preparation(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: Optional[torch.FloatTensor],
        cache_position: Optional[torch.LongTensor],
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """
        Generic cache-dependent input preparation.
        The code is put in a separate function to allow granular unit testing
        as it needs a different implementation to be exportable.

        If we have cache: let's slice `input_ids` through `cache_position`,
        to keep only the unprocessed tokens
        - Exception 1: when passing input_embeds,
          input_ids may be missing entries
        - Exception 2: some generation methods do special slicing of input_ids,
          so we don't need to do it here
        - Exception 3: with synced GPUs cache_position may go out of bounds,
          but we only want a dummy token in that case.
        - Exception 4: If input_embeds are passed then slice it through
          `cache_position`, to keep only the unprocessed tokens and
          generate the first token for each sequence.
          Later use the generated input ids for continuation.

        The current implementation does not rely on ``self`` and could be
        a class method. It is left as a standard method to be easily rewritten.
        """
        if _is_torchdynamo_exporting():
            return self._cache_dependant_input_preparation_exporting(
                input_ids, inputs_embeds, cache_position
            )
        if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
            inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
        elif inputs_embeds is not None or (  # Exception 1
            cache_position[-1] >= input_ids.shape[1]
        ):  # Exception 3
            input_ids = input_ids[:, -cache_position.shape[0] :]
        elif (
            input_ids.shape[1] != cache_position.shape[0]
        ):  # Default case (the "else", a no op, is Exception 2)
            input_ids = input_ids[:, cache_position]
        return inputs_embeds, input_ids

    def _cache_dependant_input_preparation_exporting(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: Optional[torch.FloatTensor],
        cache_position: Optional[torch.LongTensor],
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """
        This method implements method ``_cache_dependant_input_preparation``
        with :func:`torch.cond` to make it exportable with :func:`torch.export.export`.
        The code is put in a separate function to allow granular unit testing.
        """
        if inputs_embeds is None:
            input_ids = input_ids[:, cache_position]
        else:
            # This is the code we need to implement with torch.cond.
            #  if input_ids.shape[1] == 0:
            #      inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
            #  else:
            #      if cache_position[-1] >= input_ids.shape[1]:
            #          input_ids = input_ids[:, -cache_position.shape[0] :]
            #      else:
            #          if input_ids.shape[1] != cache_position.shape[0]:
            #              input_ids = input_ids[:, cache_position]
            def branch_1(inputs_embeds, cache_position):
                return inputs_embeds[:, -cache_position.shape[0] :].clone()

            def branch_2(input_ids, cache_position):
                return input_ids[:, -cache_position.shape[0] :].clone()

            def branch_3(input_ids, cache_position):
                return input_ids[:, cache_position].clone()

            inputs_embeds, input_ids = torch.cond(
                input_ids.shape[1] == 0,
                (
                    lambda input_ids, inputs_embeds, cache_position: (
                        branch_1(inputs_embeds, cache_position),
                        input_ids.clone(),
                    )
                ),
                (
                    lambda input_ids, inputs_embeds, cache_position: (
                        inputs_embeds,
                        torch.cond(
                            cache_position[-1] >= input_ids.shape[1],
                            branch_2,
                            lambda input_ids, cache_position: (
                                torch.cond(
                                    input_ids.shape[1] != cache_position.shape[0],
                                    branch_3,
                                    (lambda input_ids, cache_position: input_ids),
                                    [input_ids, cache_position],
                                )
                            ),
                            [input_ids, cache_position],
                        ),
                    )
                ),
                [input_ids, inputs_embeds, cache_position],
            )
        return inputs_embeds, input_ids

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """
        Prepare the model inputs for generation.
        It includes operations like computing the 4D attention mask or
        slicing inputs given the existing cache.

        See the forward pass in the model documentation
        for expected arguments (different models might have different
        requirements for e.g. `past_key_values`).
        This function should work as is for most LLMs.
        """

        # 1. Handle BC:
        model_inputs = {}
        # - some models don't have `Cache` support
        #   (which implies they don't expect `cache_position` in `forward`)
        if getattr(self, "_supports_cache_class", False):
            model_inputs["cache_position"] = cache_position
        # - `cache_position` was not a mandatory input in
        #   `prepare_inputs_for_generation` for those models, and this
        #   function may be called outside of `generate`.
        #   Handle most use cases by creating `cache_position` on the fly
        #   (this alternative is not as robust as calling
        #   `generate` and letting it create `cache_position`)
        elif cache_position is None:
            past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
            cache_position = torch.arange(
                past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device
            )

        # 2. Generic cache-dependent input preparation
        if past_key_values is not None:
            model_inputs["past_key_values"] = past_key_values
            inputs_embeds, input_ids = self._cache_dependant_input_preparation(
                input_ids, inputs_embeds, cache_position
            )

        # 3. Prepare base model inputs
        input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
        # if `inputs_embeds` are passed, we only want
        # to use them in the 1st generation step for every prompt.
        if not self.config.is_encoder_decoder:
            if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
                model_inputs[input_ids_key] = None
                model_inputs["inputs_embeds"] = inputs_embeds
            else:
                # `clone` calls in this function ensure a consistent stride. See #32227
                model_inputs[input_ids_key] = input_ids.clone(
                    memory_format=torch.contiguous_format
                )
                model_inputs["inputs_embeds"] = None
        else:
            model_inputs[input_ids_key] = input_ids.clone(
                memory_format=torch.contiguous_format
            )

        # 4. Create missing `position_ids` on the fly
        encoder_attention_mask = attention_mask if self.config.is_encoder_decoder else None
        attention_mask = (
            kwargs.pop("decoder_attention_mask", None)
            if self.config.is_encoder_decoder
            else attention_mask
        )
        attention_mask_key = (
            "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask"
        )
        position_ids_key = (
            "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids"
        )
        if (
            attention_mask is not None
            and kwargs.get(position_ids_key) is None
            and position_ids_key in set(inspect.signature(self.forward).parameters.keys())
        ):
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            kwargs[position_ids_key] = (
                position_ids  # placed in kwargs for further processing (see below)
            )

        # 5. Slice model inputs if it's an input
        # that should have the same length as `input_ids`
        for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:
            model_input = kwargs.get(model_input_name)
            if model_input is not None:
                if past_key_values is not None:
                    current_input_length = (
                        model_inputs["inputs_embeds"].shape[1]
                        if model_inputs.get("inputs_embeds") is not None
                        else model_inputs[input_ids_key].shape[1]
                    )
                    model_input = model_input[:, -current_input_length:]
                    model_input = model_input.clone(memory_format=torch.contiguous_format)
                model_inputs[model_input_name] = model_input

        # 6. Create 4D attention mask if we are using a
        # `StaticCache` (important for performant compiled forward pass)
        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs[input_ids_key].shape
                device = model_inputs[input_ids_key].device

            # Create the causal mask with fixed shape in advance,
            # to reduce recompilations. If the function to create
            # the 4D causal mask exists,
            # it should be present in the base model (XXXModel class).
            base_model = getattr(self, self.base_model_prefix, None)
            if base_model is None:
                causal_mask_creation_function = getattr(
                    self, "_prepare_4d_causal_attention_mask_with_cache_position", None
                )
            else:
                causal_mask_creation_function = getattr(
                    base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
                )
            if causal_mask_creation_function is None:
                pass
                # logger.warning_once(
                #     f"{self.__class__.__name__} has no "
                #     "`_prepare_4d_causal_attention_mask_with_cache_position` method "
                #     "defined in its base modeling class. "
                #     "Compiled forward passes will be sub-optimal. If you're "
                #     "writing code, see Llama for an example implementation. "
                #     "If you're a user, please report this "
                #     "issue on GitHub."
                # )
            else:
                attention_mask = causal_mask_creation_function(
                    attention_mask,
                    sequence_length=sequence_length,
                    target_length=past_key_values.get_max_cache_shape(),
                    dtype=self.dtype,
                    device=device,
                    cache_position=cache_position,
                    batch_size=batch_size,
                    config=self.config,
                    past_key_values=past_key_values,
                )
        if attention_mask is not None:
            model_inputs[attention_mask_key] = attention_mask

        if encoder_attention_mask is not None:
            model_inputs["attention_mask"] = encoder_attention_mask

        # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
        for key, value in kwargs.items():
            if key not in model_inputs:
                model_inputs[key] = value

        # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
        model_inputs.pop("labels", None)
        return model_inputs

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: "LogitsProcessorList",  # noqa: F821
        stopping_criteria: "StoppingCriteriaList",  # noqa: F821
        generation_config: "GenerationConfig",  # noqa: F821
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,  # noqa: F821
        **model_kwargs,
    ) -> Union["GenerateNonBeamOutput", torch.LongTensor]:  # noqa: F821
        """
        2025/09/29: updates for Gemma3 models, fix for eager mode as well as the export.
        """
        # init values
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(
            hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
        )
        do_sample = generation_config.do_sample

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = (
            () if (return_dict_in_generate and output_hidden_states) else None
        )

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = (
                model_kwargs["encoder_outputs"].get("attentions")
                if output_attentions
                else None
            )
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states")
                if output_hidden_states
                else None
            )

        # keep track of which sequences are already finished
        batch_size, cur_len = input_ids.shape[:2]
        this_peer_finished = False
        unfinished_sequences = torch.ones(
            batch_size, dtype=torch.long, device=input_ids.device
        )
        model_kwargs = self._get_initial_cache_position(
            cur_len, input_ids.device, model_kwargs
        )

        model_forward = self.__call__
        compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
        if compile_forward:
            os.environ["TOKENIZERS_PARALLELISM"] = "0"
            # If we use FA2 and a static cache, we cannot compile with fullgraph
            if self.config._attn_implementation == "flash_attention_2":
                # only raise warning if the user passed an explicit compile-config
                if (
                    generation_config.compile_config is not None
                    and generation_config.compile_config.fullgraph
                ):
                    generation_config.compile_config.fullgraph = False
            model_forward = self.get_compiled_call(generation_config.compile_config)

        if generation_config.prefill_chunk_size is not None:
            model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
            is_prefill = False
        else:
            is_prefill = True

        while self._has_unfinished_sequences(
            this_peer_finished, synced_gpus, device=input_ids.device
        ):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            if is_prefill:
                outputs = self(**model_inputs, return_dict=True)
                is_prefill = False
            else:
                outputs = model_forward(**model_inputs, return_dict=True)

            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )
            if synced_gpus and this_peer_finished:
                continue

            next_token_logits = outputs.logits[:, -1, :].to(
                copy=True, dtype=torch.float32, device=input_ids.device
            )

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,)
                        if self.config.is_encoder_decoder
                        else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # token selection
            if do_sample:
                probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                    1 - unfinished_sequences
                )

            # update generated ids, model inputs, and length for next step
            # PATCHED: the two following lines, next_tokens can be 2D already for this model
            next_tokens_2d = (
                next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None]
            )
            input_ids = torch.cat([input_ids, next_tokens_2d], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0
            cur_len += 1

            # This is needed to properly delete outputs.logits which may be very large
            # for first iteration
            # Otherwise a reference to outputs is kept which keeps
            # the logits alive in the next iteration
            del outputs

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return transformers.generation.utils.GenerateEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
            else:
                return transformers.generation.utils.GenerateDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
        else:
            return input_ids
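
The `_cache_dependant_input_preparation_exporting` method above replaces data-dependent Python `if`/`else` with nested `torch.cond` calls so that `torch.export.export` can trace both sides of every branch. Below is a minimal, self-contained sketch of the same pattern (not the library's code): both branches receive the same operands and return tensors of matching dtype and shape, and the predicate is computed from runtime data.

```python
import torch


class CondExample(torch.nn.Module):
    # Minimal sketch of the torch.cond pattern used above (not the library's code):
    # both branches take the same operands, return tensors of the same shape/dtype,
    # and the predicate depends on runtime values.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def then_branch(t):
            return (t * 2.0).clone()

        def else_branch(t):
            return (t - 1.0).clone()

        return torch.cond(x.sum() > 0, then_branch, else_branch, [x])


ep = torch.export.export(CondExample(), (torch.randn(2, 4),))
print(ep)  # the exported graph contains one higher-order `cond` node with two subgraphs
```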
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py (new file, +156 lines)

@@ -0,0 +1,156 @@
from typing import Callable, Optional, Tuple
import packaging.version as pv
import torch
import transformers


class patched_IdeficsEmbedding(torch.nn.Module):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # if seq_len > self.max_seq_len_cached:
        #     self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached):
            t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq)
            # freqs = torch.einsum("i,j->ij", t, inv_freq)
            freqs = t.reshape((-1, 1)) * inv_freq.reshape((1, -1))
            emb = torch.cat((freqs, freqs), dim=-1)
            return emb.cos().to(x.dtype), emb.sin().to(x.dtype)

        def _set_cos_sin_cache_else(_x, _inv_freq, _seq_len, cos_cached, sin_cached):
            torch._check(seq_len.item() <= cos_cached.shape[0])
            co = cos_cached[: seq_len.item()].detach().clone()
            torch._check(seq_len.item() <= sin_cached.shape[0])
            si = sin_cached[: seq_len.item()].detach().clone()
            return co.to(dtype=x.dtype), si.to(dtype=x.dtype)

        cos_cached, sin_cached = torch.cond(
            (seq_len > self.max_seq_len_cached).item(),
            _set_cos_sin_cache_then,
            _set_cos_sin_cache_else,
            [x, self.inv_freq, seq_len, self.cos_cached, self.sin_cached],
        )
        return cos_cached, sin_cached


class patched_IdeficsAttention(torch.nn.Module):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsAttention

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # if key_value_states are provided this layer is used as a cross-attention layer
        is_cross_attention = self.is_cross_attention or key_value_states is not None

        bsz, q_len, _ = hidden_states.size()

        query_states = (
            self.q_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        if not is_cross_attention:
            key_states = (
                self.k_proj(hidden_states)
                .view(bsz, q_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
            )
            value_states = (
                self.v_proj(hidden_states)
                .view(bsz, q_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
            )
        else:
            _, kv_len, _ = (
                key_value_states.size()
            )  # Note that, in this case, `kv_len` == `kv_seq_len`
            key_states = (
                self.k_proj(key_value_states)
                .view(bsz, kv_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
            )
            value_states = (
                self.v_proj(key_value_states)
                .view(bsz, kv_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
            )

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += cache_position[0]

        if not is_cross_attention:
            rotary_length = torch.maximum(
                torch.tensor(kv_seq_len, dtype=torch.int64),
                torch.tensor(q_len, dtype=torch.int64),
            )
            cos, sin = self.rotary_emb(value_states, seq_len=rotary_length)
            query_states, key_states = (
                transformers.models.idefics.modeling_idefics.apply_rotary_pos_emb(
                    query_states, key_states, cos, sin, position_ids
                )
            )
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # sin and cos are specific to RoPE models;
            # cache_position needed for the static cache
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        if self.qk_layer_norms:
            query_states = self.q_layer_norm(query_states)
            key_states = self.k_layer_norm(key_states)

        attention_interface: Callable = (
            transformers.models.idefics.modeling_idefics.eager_attention_forward
        )

        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                transformers.models.idefics.modeling_idefics.logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support "
                    "`output_attentions=True`. Falling back to "
                    "eager attention. This warning can be removed using the argument "
                    '`attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
                    self.config._attn_implementation
                ]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        if output_attentions:
            attn_weights = None

        if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
            return attn_output, attn_weights, past_key_value
        return attn_output, attn_weights
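
Each patch class in these modules declares the attributes it overrides (`_PATCHES_`, sometimes gated on the installed transformers version, with `None` entries meaning "skip") and the class it targets (`_PATCHED_CLASS_`). The following is a minimal sketch of how such a convention could be applied and rolled back; the `apply_patch_class` helper is hypothetical and is not the library's own mechanism, which lives in `onnx_diagnostic.torch_export_patches`.

```python
import contextlib


@contextlib.contextmanager
def apply_patch_class(patch_cls):
    """Hypothetical helper: install the methods listed in ``_PATCHES_`` onto
    ``_PATCHED_CLASS_`` and restore the originals on exit."""
    target = patch_cls._PATCHED_CLASS_
    # Entries may be None when a patch is version-gated (see patched_GenerationMixin).
    names = [n for n in patch_cls._PATCHES_ if n is not None]
    saved = {n: getattr(target, n) for n in names if hasattr(target, n)}
    try:
        for n in names:
            setattr(target, n, getattr(patch_cls, n))
        yield target
    finally:
        for n, original in saved.items():
            setattr(target, n, original)


# Usage sketch: the export-friendly GenerationMixin methods are active only inside the block.
# with apply_patch_class(patched_GenerationMixin):
#     ...
```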