onnx_diagnostic-0.8.1-py3-none-any.whl → onnx_diagnostic-0.8.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +387 -12
  3. onnx_diagnostic/export/api.py +118 -5
  4. onnx_diagnostic/export/control_flow.py +214 -0
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +135 -0
  7. onnx_diagnostic/export/onnx_plug.py +396 -0
  8. onnx_diagnostic/ext_test_case.py +118 -25
  9. onnx_diagnostic/helpers/cache_helper.py +218 -204
  10. onnx_diagnostic/helpers/dot_helper.py +210 -0
  11. onnx_diagnostic/helpers/helper.py +92 -26
  12. onnx_diagnostic/helpers/log_helper.py +26 -4
  13. onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
  14. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  15. onnx_diagnostic/helpers/onnx_helper.py +115 -16
  16. onnx_diagnostic/helpers/ort_session.py +37 -11
  17. onnx_diagnostic/helpers/rt_helper.py +547 -0
  18. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  19. onnx_diagnostic/helpers/torch_helper.py +108 -6
  20. onnx_diagnostic/reference/ort_evaluator.py +233 -28
  21. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  22. onnx_diagnostic/tasks/image_text_to_text.py +5 -1
  23. onnx_diagnostic/tasks/summarization.py +72 -137
  24. onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
  26. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  34. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  35. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  36. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
  37. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  38. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  39. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  40. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  41. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
  42. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
  43. onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
  44. onnx_diagnostic/torch_models/validate.py +50 -1
  45. onnx_diagnostic/torch_onnx/sbs.py +963 -312
  46. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
  47. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
  48. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
  49. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
  50. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  51. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py
@@ -0,0 +1,486 @@
+import inspect
+import os
+from typing import Optional, Tuple, Union
+import packaging.version as pv
+import torch
+import transformers
+from transformers.cache_utils import StaticCache, Cache
+from .patch_helper import _is_torchdynamo_exporting
+
+
+class patched_GenerationMixin:
+    """
+    Applies modifications implemented in PR
+    `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
+    """
+
+    _PATCHES_ = [
+        "_cache_dependant_input_preparation",
+        "_cache_dependant_input_preparation_exporting",
+        (
+            None
+            if pv.Version(transformers.__version__) >= pv.Version("4.56")
+            else "prepare_inputs_for_generation"
+        ),
+        (
+            "_sample"
+            if pv.Version(transformers.__version__) == pv.Version("4.57.0.dev0")
+            else None
+        ),
+    ]
+    _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin
+
+    def _cache_dependant_input_preparation(
+        self,
+        input_ids: torch.LongTensor,
+        inputs_embeds: Optional[torch.FloatTensor],
+        cache_position: Optional[torch.LongTensor],
+    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
+        """
+        Generic cache-dependent input preparation.
+        The code is put in a separate function to allow granular unit testing
+        as it needs a different implementation to be exportable.
+
+        If we have a cache, slice `input_ids` through `cache_position`
+        to keep only the unprocessed tokens:
+        - Exception 1: when passing input_embeds,
+          input_ids may be missing entries
+        - Exception 2: some generation methods do special slicing of input_ids,
+          so we don't need to do it here
+        - Exception 3: with synced GPUs cache_position may go out of bounds,
+          but we only want a dummy token in that case.
+        - Exception 4: if input_embeds are passed, slice them through
+          `cache_position` to keep only the unprocessed tokens and
+          generate the first token for each sequence.
+          Later use the generated input ids for continuation.
+
+        The current implementation does not rely on ``self`` and could be
+        a class method. It is left as a standard method to be easily rewritten.
+        """
+        if _is_torchdynamo_exporting():
+            return self._cache_dependant_input_preparation_exporting(
+                input_ids, inputs_embeds, cache_position
+            )
+        if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
+            inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+        elif inputs_embeds is not None or (  # Exception 1
+            cache_position[-1] >= input_ids.shape[1]
+        ):  # Exception 3
+            input_ids = input_ids[:, -cache_position.shape[0] :]
+        elif (
+            input_ids.shape[1] != cache_position.shape[0]
+        ):  # Default case (the "else", a no-op, is Exception 2)
+            input_ids = input_ids[:, cache_position]
+        return inputs_embeds, input_ids
+
+    def _cache_dependant_input_preparation_exporting(
+        self,
+        input_ids: torch.LongTensor,
+        inputs_embeds: Optional[torch.FloatTensor],
+        cache_position: Optional[torch.LongTensor],
+    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
+        """
+        This method implements ``_cache_dependant_input_preparation``
+        with :func:`torch.cond` to make it exportable with :func:`torch.export.export`.
+        The code is put in a separate function to allow granular unit testing.
+        """
+        if inputs_embeds is None:
+            input_ids = input_ids[:, cache_position]
+        else:
+            # This is the code we need to implement with torch.cond.
+            # if input_ids.shape[1] == 0:
+            #     inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+            # else:
+            #     if cache_position[-1] >= input_ids.shape[1]:
+            #         input_ids = input_ids[:, -cache_position.shape[0] :]
+            #     else:
+            #         if input_ids.shape[1] != cache_position.shape[0]:
+            #             input_ids = input_ids[:, cache_position]
+            def branch_1(inputs_embeds, cache_position):
+                return inputs_embeds[:, -cache_position.shape[0] :].clone()
+
+            def branch_2(input_ids, cache_position):
+                return input_ids[:, -cache_position.shape[0] :].clone()
+
+            def branch_3(input_ids, cache_position):
+                return input_ids[:, cache_position].clone()
+
+            inputs_embeds, input_ids = torch.cond(
+                input_ids.shape[1] == 0,
+                (
+                    lambda input_ids, inputs_embeds, cache_position: (
+                        branch_1(inputs_embeds, cache_position),
+                        input_ids.clone(),
+                    )
+                ),
+                (
+                    lambda input_ids, inputs_embeds, cache_position: (
+                        inputs_embeds,
+                        torch.cond(
+                            cache_position[-1] >= input_ids.shape[1],
+                            branch_2,
+                            lambda input_ids, cache_position: (
+                                torch.cond(
+                                    input_ids.shape[1] != cache_position.shape[0],
+                                    branch_3,
+                                    (lambda input_ids, cache_position: input_ids),
+                                    [input_ids, cache_position],
+                                )
+                            ),
+                            [input_ids, cache_position],
+                        ),
+                    )
+                ),
+                [input_ids, inputs_embeds, cache_position],
+            )
+        return inputs_embeds, input_ids
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[Cache] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ):
+        """
+        Prepare the model inputs for generation.
+        It includes operations like computing the 4D attention mask or
+        slicing inputs given the existing cache.
+
+        See the forward pass in the model documentation
+        for expected arguments (different models might have different
+        requirements for e.g. `past_key_values`).
+        This function should work as is for most LLMs.
+        """
+
+        # 1. Handle BC:
+        model_inputs = {}
+        # - some models don't have `Cache` support
+        #   (which implies they don't expect `cache_position` in `forward`)
+        if getattr(self, "_supports_cache_class", False):
+            model_inputs["cache_position"] = cache_position
+        # - `cache_position` was not a mandatory input in
+        #   `prepare_inputs_for_generation` for those models, and this
+        #   function may be called outside of `generate`.
+        #   Handle most use cases by creating `cache_position` on the fly
+        #   (this alternative is not as robust as calling
+        #   `generate` and letting it create `cache_position`)
+        elif cache_position is None:
+            past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device
+            )
+
+        # 2. Generic cache-dependent input preparation
+        if past_key_values is not None:
+            model_inputs["past_key_values"] = past_key_values
+            inputs_embeds, input_ids = self._cache_dependant_input_preparation(
+                input_ids, inputs_embeds, cache_position
+            )
+
+        # 3. Prepare base model inputs
+        input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+        # if `inputs_embeds` are passed, we only want
+        # to use them in the 1st generation step for every prompt.
+        if not self.config.is_encoder_decoder:
+            if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
+                model_inputs[input_ids_key] = None
+                model_inputs["inputs_embeds"] = inputs_embeds
+            else:
+                # `clone` calls in this function ensure a consistent stride. See #32227
+                model_inputs[input_ids_key] = input_ids.clone(
+                    memory_format=torch.contiguous_format
+                )
+                model_inputs["inputs_embeds"] = None
+        else:
+            model_inputs[input_ids_key] = input_ids.clone(
+                memory_format=torch.contiguous_format
+            )
+
+        # 4. Create missing `position_ids` on the fly
+        encoder_attention_mask = attention_mask if self.config.is_encoder_decoder else None
+        attention_mask = (
+            kwargs.pop("decoder_attention_mask", None)
+            if self.config.is_encoder_decoder
+            else attention_mask
+        )
+        attention_mask_key = (
+            "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask"
+        )
+        position_ids_key = (
+            "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids"
+        )
+        if (
+            attention_mask is not None
+            and kwargs.get(position_ids_key) is None
+            and position_ids_key in set(inspect.signature(self.forward).parameters.keys())
+        ):
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            kwargs[position_ids_key] = (
+                position_ids  # placed in kwargs for further processing (see below)
+            )
+
+        # 5. Slice model inputs if it's an input
+        #    that should have the same length as `input_ids`
+        for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:
+            model_input = kwargs.get(model_input_name)
+            if model_input is not None:
+                if past_key_values is not None:
+                    current_input_length = (
+                        model_inputs["inputs_embeds"].shape[1]
+                        if model_inputs.get("inputs_embeds") is not None
+                        else model_inputs[input_ids_key].shape[1]
+                    )
+                    model_input = model_input[:, -current_input_length:]
+                    model_input = model_input.clone(memory_format=torch.contiguous_format)
+                model_inputs[model_input_name] = model_input
+
+        # 6. Create 4D attention mask if we are using a
+        #    `StaticCache` (important for performant compiled forward pass)
+        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
+            else:
+                batch_size, sequence_length = model_inputs[input_ids_key].shape
+                device = model_inputs[input_ids_key].device
+
+            # Create the causal mask with fixed shape in advance,
+            # to reduce recompilations. If the function to create
+            # the 4D causal mask exists,
+            # it should be present in the base model (XXXModel class).
+            base_model = getattr(self, self.base_model_prefix, None)
+            if base_model is None:
+                causal_mask_creation_function = getattr(
+                    self, "_prepare_4d_causal_attention_mask_with_cache_position", None
+                )
+            else:
+                causal_mask_creation_function = getattr(
+                    base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
+                )
+            if causal_mask_creation_function is None:
+                pass
+                # logger.warning_once(
+                #     f"{self.__class__.__name__} has no "
+                #     "`_prepare_4d_causal_attention_mask_with_cache_position` method "
+                #     "defined in its base modeling class. "
+                #     "Compiled forward passes will be sub-optimal. If you're "
+                #     "writing code, see Llama for an example implementation. "
+                #     "If you're a user, please report this "
+                #     "issue on GitHub."
+                # )
+            else:
+                attention_mask = causal_mask_creation_function(
+                    attention_mask,
+                    sequence_length=sequence_length,
+                    target_length=past_key_values.get_max_cache_shape(),
+                    dtype=self.dtype,
+                    device=device,
+                    cache_position=cache_position,
+                    batch_size=batch_size,
+                    config=self.config,
+                    past_key_values=past_key_values,
+                )
+        if attention_mask is not None:
+            model_inputs[attention_mask_key] = attention_mask
+
+        if encoder_attention_mask is not None:
+            model_inputs["attention_mask"] = encoder_attention_mask
+
+        # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
+        # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
+        model_inputs.pop("labels", None)
+        return model_inputs
+
+    def _sample(
+        self,
+        input_ids: torch.LongTensor,
+        logits_processor: "LogitsProcessorList",  # noqa: F821
+        stopping_criteria: "StoppingCriteriaList",  # noqa: F821
+        generation_config: "GenerationConfig",  # noqa: F821
+        synced_gpus: bool = False,
+        streamer: Optional["BaseStreamer"] = None,  # noqa: F821
+        **model_kwargs,
+    ) -> Union["GenerateNonBeamOutput", torch.LongTensor]:  # noqa: F821
+        """
+        2025/09/29: updates for Gemma3 models, fixes for eager mode as well as the export.
+        """
+        # init values
+        pad_token_id = generation_config._pad_token_tensor
+        output_attentions = generation_config.output_attentions
+        output_hidden_states = generation_config.output_hidden_states
+        output_scores = generation_config.output_scores
+        output_logits = generation_config.output_logits
+        return_dict_in_generate = generation_config.return_dict_in_generate
+        has_eos_stopping_criteria = any(
+            hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
+        )
+        do_sample = generation_config.do_sample
+
+        # init attention / hidden states / scores tuples
+        scores = () if (return_dict_in_generate and output_scores) else None
+        raw_logits = () if (return_dict_in_generate and output_logits) else None
+        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = (
+            () if (return_dict_in_generate and output_hidden_states) else None
+        )
+
+        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+        if return_dict_in_generate and self.config.is_encoder_decoder:
+            encoder_attentions = (
+                model_kwargs["encoder_outputs"].get("attentions")
+                if output_attentions
+                else None
+            )
+            encoder_hidden_states = (
+                model_kwargs["encoder_outputs"].get("hidden_states")
+                if output_hidden_states
+                else None
+            )
+
+        # keep track of which sequences are already finished
+        batch_size, cur_len = input_ids.shape[:2]
+        this_peer_finished = False
+        unfinished_sequences = torch.ones(
+            batch_size, dtype=torch.long, device=input_ids.device
+        )
+        model_kwargs = self._get_initial_cache_position(
+            cur_len, input_ids.device, model_kwargs
+        )
+
+        model_forward = self.__call__
+        compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
+        if compile_forward:
+            os.environ["TOKENIZERS_PARALLELISM"] = "0"
+            # If we use FA2 and a static cache, we cannot compile with fullgraph
+            if self.config._attn_implementation == "flash_attention_2":
+                # only raise warning if the user passed an explicit compile-config
+                if (
+                    generation_config.compile_config is not None
+                    and generation_config.compile_config.fullgraph
+                ):
+                    generation_config.compile_config.fullgraph = False
+            model_forward = self.get_compiled_call(generation_config.compile_config)
+
+        if generation_config.prefill_chunk_size is not None:
+            model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
+            is_prefill = False
+        else:
+            is_prefill = True
+
+        while self._has_unfinished_sequences(
+            this_peer_finished, synced_gpus, device=input_ids.device
+        ):
+            # prepare model inputs
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+            if is_prefill:
+                outputs = self(**model_inputs, return_dict=True)
+                is_prefill = False
+            else:
+                outputs = model_forward(**model_inputs, return_dict=True)
+
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
+            if synced_gpus and this_peer_finished:
+                continue
+
+            next_token_logits = outputs.logits[:, -1, :].to(
+                copy=True, dtype=torch.float32, device=input_ids.device
+            )
+
+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+
+            # Store scores, attentions and hidden_states when required
+            if return_dict_in_generate:
+                if output_scores:
+                    scores += (next_token_scores,)
+                if output_logits:
+                    raw_logits += (next_token_logits,)
+                if output_attentions:
+                    decoder_attentions += (
+                        (outputs.decoder_attentions,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.attentions,)
+                    )
+                    if self.config.is_encoder_decoder:
+                        cross_attentions += (outputs.cross_attentions,)
+
+                if output_hidden_states:
+                    decoder_hidden_states += (
+                        (outputs.decoder_hidden_states,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.hidden_states,)
+                    )
+
+            # token selection
+            if do_sample:
+                probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(next_token_scores, dim=-1)
+
+            # finished sentences should have their next token be a padding token
+            if has_eos_stopping_criteria:
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                    1 - unfinished_sequences
+                )
+
+            # update generated ids, model inputs, and length for next step
+            # PATCHED: the two following lines, next_tokens can be 2D already for this model
+            next_tokens_2d = (
+                next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None]
+            )
+            input_ids = torch.cat([input_ids, next_tokens_2d], dim=-1)
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
+
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+            this_peer_finished = unfinished_sequences.max() == 0
+            cur_len += 1
+
+            # This is needed to properly delete outputs.logits which may be very large
+            # for the first iteration.
+            # Otherwise a reference to outputs is kept which keeps
+            # the logits alive in the next iteration
+            del outputs
+
+        if streamer is not None:
+            streamer.end()
+
+        if return_dict_in_generate:
+            if self.config.is_encoder_decoder:
+                return transformers.generation.utils.GenerateEncoderDecoderOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    encoder_attentions=encoder_attentions,
+                    encoder_hidden_states=encoder_hidden_states,
+                    decoder_attentions=decoder_attentions,
+                    cross_attentions=cross_attentions,
+                    decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+            else:
+                return transformers.generation.utils.GenerateDecoderOnlyOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    attentions=decoder_attentions,
+                    hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+        else:
+            return input_ids
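
The exporting variant above rewrites the Python if/elif chain as nested torch.cond calls so that both branches end up in the exported graph. The minimal, self-contained sketch below is not part of the diff; it assumes a recent PyTorch 2.x where torch.cond and torch.export are available. Both branch functions receive the same operands and return tensors with matching shape and dtype, and the .clone() calls keep the branch outputs from aliasing their inputs, mirroring what the patch does.

    import torch

    def then_branch(x):
        # taken when the predicate evaluates to True
        return (x * 2.0).clone()

    def else_branch(x):
        # both branches must accept the same operands and
        # return tensors with matching shape and dtype
        return (x + 1.0).clone()

    class CondDemo(torch.nn.Module):
        def forward(self, x):
            # the predicate is a data-dependent scalar tensor,
            # so export cannot specialize it away and keeps both branches
            return torch.cond(x.sum() > 0, then_branch, else_branch, [x])

    model = CondDemo()
    x = torch.randn(2, 6)
    print(model(x))                           # eager mode runs only the selected branch
    print(torch.export.export(model, (x,)))   # the exported graph contains both branches

In the patched method the same mechanism is nested twice to cover the three-way decision on input_ids and inputs_embeds.
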
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py
@@ -0,0 +1,156 @@
+from typing import Callable, Optional, Tuple
+import packaging.version as pv
+import torch
+import transformers
+
+
+class patched_IdeficsEmbedding(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # if seq_len > self.max_seq_len_cached:
+        #     self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached):
+            t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq)
+            # freqs = torch.einsum("i,j->ij", t, inv_freq)
+            freqs = t.reshape((-1, 1)) * inv_freq.reshape((1, -1))
+            emb = torch.cat((freqs, freqs), dim=-1)
+            return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
+
+        def _set_cos_sin_cache_else(_x, _inv_freq, _seq_len, cos_cached, sin_cached):
+            torch._check(seq_len.item() <= cos_cached.shape[0])
+            co = cos_cached[: seq_len.item()].detach().clone()
+            torch._check(seq_len.item() <= sin_cached.shape[0])
+            si = sin_cached[: seq_len.item()].detach().clone()
+            return co.to(dtype=x.dtype), si.to(dtype=x.dtype)
+
+        cos_cached, sin_cached = torch.cond(
+            (seq_len > self.max_seq_len_cached).item(),
+            _set_cos_sin_cache_then,
+            _set_cos_sin_cache_else,
+            [x, self.inv_freq, seq_len, self.cos_cached, self.sin_cached],
+        )
+        return cos_cached, sin_cached
+
+
+class patched_IdeficsAttention(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsAttention
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        is_cross_attention = self.is_cross_attention or key_value_states is not None
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = (
+            self.q_proj(hidden_states)
+            .view(bsz, q_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+        )
+        if not is_cross_attention:
+            key_states = (
+                self.k_proj(hidden_states)
+                .view(bsz, q_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+            value_states = (
+                self.v_proj(hidden_states)
+                .view(bsz, q_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+        else:
+            _, kv_len, _ = (
+                key_value_states.size()
+            )  # Note that, in this case, `kv_len` == `kv_seq_len`
+            key_states = (
+                self.k_proj(key_value_states)
+                .view(bsz, kv_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+            value_states = (
+                self.v_proj(key_value_states)
+                .view(bsz, kv_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += cache_position[0]
+
+        if not is_cross_attention:
+            rotary_length = torch.maximum(
+                torch.tensor(kv_seq_len, dtype=torch.int64),
+                torch.tensor(q_len, dtype=torch.int64),
+            )
+            cos, sin = self.rotary_emb(value_states, seq_len=rotary_length)
+            query_states, key_states = (
+                transformers.models.idefics.modeling_idefics.apply_rotary_pos_emb(
+                    query_states, key_states, cos, sin, position_ids
+                )
+            )
+            # [bsz, nh, t, hd]
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models;
+            # cache_position needed for the static cache
+            cache_kwargs = {"cache_position": cache_position}
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        if self.qk_layer_norms:
+            query_states = self.q_layer_norm(query_states)
+            key_states = self.k_layer_norm(key_states)
+
+        attention_interface: Callable = (
+            transformers.models.idefics.modeling_idefics.eager_attention_forward
+        )
+
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and output_attentions:
+                transformers.models.idefics.modeling_idefics.logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support "
+                    "`output_attentions=True`. Falling back to "
+                    "eager attention. This warning can be removed using the argument "
+                    '`attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
+                    self.config._attn_implementation
+                ]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if output_attentions:
+            attn_weights = None
+
+        if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
+            return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights
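
Both new files follow the patch convention used throughout the package: each patched_* class names its target in _PATCHED_CLASS_ and lists the methods to swap in _PATCHES_ (entries may be None when a patch does not apply to the installed transformers version). The snippet below is a hypothetical illustration of how such a class could be applied and reverted; the actual mechanism presumably lives elsewhere in onnx_diagnostic.torch_export_patches and is not shown in this diff.

    import contextlib

    @contextlib.contextmanager
    def apply_patch(patch_cls):
        # hypothetical helper: swap in the methods listed in _PATCHES_
        # and restore the originals on exit
        target = patch_cls._PATCHED_CLASS_
        saved = {}
        for name in patch_cls._PATCHES_:
            if name is None:  # version-dependent entries may be disabled
                continue
            saved[name] = getattr(target, name)
            setattr(target, name, getattr(patch_cls, name))
        try:
            yield target
        finally:
            for name, original in saved.items():
                setattr(target, name, original)

    # usage sketch, using the class defined in the first new file:
    # with apply_patch(patched_GenerationMixin):
    #     exported = torch.export.export(model, args, kwargs=model_kwargs)
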