liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.

Files changed (115)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +46 -15
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  15. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  16. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  17. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  18. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  19. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  20. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  21. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  22. liger_kernel/ops/backends/registry.py +61 -0
  23. liger_kernel/ops/cross_entropy.py +134 -65
  24. liger_kernel/ops/dyt.py +115 -180
  25. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  26. liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
  27. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  28. liger_kernel/ops/geglu.py +6 -4
  29. liger_kernel/ops/group_norm.py +7 -7
  30. liger_kernel/ops/grpo_loss.py +312 -0
  31. liger_kernel/ops/jsd.py +2 -1
  32. liger_kernel/ops/kl_div.py +9 -5
  33. liger_kernel/ops/layer_norm.py +146 -78
  34. liger_kernel/ops/llama4_rope.py +225 -0
  35. liger_kernel/ops/multi_token_attention.py +207 -0
  36. liger_kernel/ops/poly_norm.py +390 -0
  37. liger_kernel/ops/rms_norm.py +398 -99
  38. liger_kernel/ops/rope.py +1 -1
  39. liger_kernel/ops/softmax.py +201 -0
  40. liger_kernel/ops/sparsemax.py +179 -0
  41. liger_kernel/ops/swiglu.py +1 -1
  42. liger_kernel/ops/tiled_mlp.py +136 -0
  43. liger_kernel/ops/utils.py +14 -0
  44. liger_kernel/transformers/__init__.py +208 -17
  45. liger_kernel/transformers/auto_model.py +21 -0
  46. liger_kernel/transformers/cross_entropy.py +9 -4
  47. liger_kernel/transformers/dyt.py +6 -4
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -1
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +122 -20
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -1
  57. liger_kernel/transformers/group_norm.py +1 -1
  58. liger_kernel/transformers/grpo_loss.py +153 -0
  59. liger_kernel/transformers/jsd.py +1 -1
  60. liger_kernel/transformers/kl_div.py +1 -1
  61. liger_kernel/transformers/layer_norm.py +1 -1
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/exaone4.py +136 -0
  64. liger_kernel/transformers/model/falcon_h1.py +122 -0
  65. liger_kernel/transformers/model/gemma.py +57 -27
  66. liger_kernel/transformers/model/gemma2.py +65 -28
  67. liger_kernel/transformers/model/gemma3.py +331 -0
  68. liger_kernel/transformers/model/glm4.py +141 -0
  69. liger_kernel/transformers/model/glm4v.py +163 -0
  70. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  71. liger_kernel/transformers/model/gpt_oss.py +211 -0
  72. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  73. liger_kernel/transformers/model/internvl.py +157 -0
  74. liger_kernel/transformers/model/llama.py +109 -27
  75. liger_kernel/transformers/model/llama4.py +121 -0
  76. liger_kernel/transformers/model/llava.py +111 -136
  77. liger_kernel/transformers/model/loss_utils.py +50 -12
  78. liger_kernel/transformers/model/mistral.py +51 -34
  79. liger_kernel/transformers/model/mixtral.py +50 -29
  80. liger_kernel/transformers/model/mllama.py +46 -24
  81. liger_kernel/transformers/model/olmo2.py +47 -22
  82. liger_kernel/transformers/model/olmo3.py +142 -0
  83. liger_kernel/transformers/model/output_classes.py +147 -0
  84. liger_kernel/transformers/model/paligemma.py +50 -14
  85. liger_kernel/transformers/model/phi3.py +47 -172
  86. liger_kernel/transformers/model/qwen2.py +55 -23
  87. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  88. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  89. liger_kernel/transformers/model/qwen3.py +136 -0
  90. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  91. liger_kernel/transformers/model/qwen3_next.py +146 -0
  92. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  93. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  94. liger_kernel/transformers/model/smollm3.py +199 -0
  95. liger_kernel/transformers/model/smolvlm.py +158 -0
  96. liger_kernel/transformers/monkey_patch.py +2018 -244
  97. liger_kernel/transformers/multi_token_attention.py +64 -0
  98. liger_kernel/transformers/poly_norm.py +42 -0
  99. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  100. liger_kernel/transformers/rms_norm.py +54 -6
  101. liger_kernel/transformers/rope.py +45 -1
  102. liger_kernel/transformers/softmax.py +12 -0
  103. liger_kernel/transformers/sparsemax.py +16 -0
  104. liger_kernel/transformers/swiglu.py +39 -1
  105. liger_kernel/transformers/tiled_mlp.py +125 -0
  106. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  107. liger_kernel/transformers/tvd.py +1 -1
  108. liger_kernel/utils.py +63 -0
  109. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
  110. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  111. liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
  112. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  115. {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/qwen2.py

@@ -7,17 +7,14 @@ import torch
 
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
-from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 
-@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -31,6 +28,7 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Copy paste Qwen2's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -85,6 +83,13 @@ def lce_forward_deprecated(
     loss = None
     logits = None
 
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
     if self.training and (labels is not None):
         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
@@ -125,8 +130,7 @@ def lce_forward_deprecated(
     )
 
 
-@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -140,9 +144,10 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
-    **loss_kwargs,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
     r"""
     Args:
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -150,10 +155,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
     Returns:
 
@@ -192,36 +199,61 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-    # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
-        loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
            labels=labels,
+            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
 
-    else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
-        if labels is not None:
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
+                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
-                **loss_kwargs,
+                **kwargs,
            )
 
-    return CausalLMOutputWithPast(
+    if not return_dict:
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with token accuracy field
+    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
    )
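
The hunk above replaces `num_logits_to_keep` with `logits_to_keep`, which may be either an `int` (keep the last N positions, `0` meaning all) or a 1D index tensor for packed inputs. A minimal standalone sketch of that slicing rule (plain PyTorch, not the package's code; the tensor shapes are only illustrative):

import torch

# Illustrative activations: (batch=2, seq_len=8, hidden=16).
hidden_states = torch.randn(2, 8, 16)

def keep_positions(hidden_states, logits_to_keep):
    # Same rule as in the diff: an int keeps the last N positions (0 keeps all),
    # a 1D tensor keeps explicit sequence indices.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden_states[:, slice_indices, :]

print(keep_positions(hidden_states, 0).shape)                        # torch.Size([2, 8, 16])
print(keep_positions(hidden_states, 2).shape)                        # torch.Size([2, 2, 16])
print(keep_positions(hidden_states, torch.tensor([0, 3, 7])).shape)  # torch.Size([2, 3, 16])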
liger_kernel/transformers/model/qwen2_5_vl.py

@@ -5,18 +5,14 @@ from typing import Union
 
 import torch
 
-from torch.nn import CrossEntropyLoss
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import _CONFIG_FOR_DOC
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import QWEN2_5_VL_INPUTS_DOCSTRING
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
+from transformers.utils import can_return_tuple
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerQwen2_5_VLCausalLMOutputWithPast
 
 
-@add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+@can_return_tuple
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -36,17 +32,26 @@ def lce_forward(
     rope_deltas: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
     second_per_grid_ts: Optional[torch.Tensor] = None,
-    **loss_kwargs,
-) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerQwen2_5_VLCausalLMOutputWithPast]:
     r"""
-    Copy paste Qwen2_5_VL's forward but replace torch cross entropy with liger fused linear cross entropy
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-    Returns:
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+    pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
+        The tensors corresponding to the input videos. Pixel values can be obtained using
+        [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses
+        [`Qwen2_5_VLImageProcessor`] for processing videos.
+    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+        The temporal, height and width of feature shape of each image in LLM.
+    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+        The temporal, height and width of feature shape of each video in LLM.
+    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+        The rope index difference between sequence length and multimodal rope.
+    second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+        The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
 
     Example:
 
@@ -78,78 +83,20 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
     ```"""
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    if inputs_embeds is None:
-        inputs_embeds = self.model.embed_tokens(input_ids)
-        if pixel_values is not None:
-            pixel_values = pixel_values.type(self.visual.dtype)
-            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
-            n_image_features = image_embeds.shape[0]
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-
-            mask = input_ids == self.config.image_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            image_mask = mask_expanded.to(inputs_embeds.device)
-
-            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-        if pixel_values_videos is not None:
-            pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
-            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
-            n_video_features = video_embeds.shape[0]
-            if n_video_tokens != n_video_features:
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-                )
-
-            mask = input_ids == self.config.video_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            video_mask = mask_expanded.to(inputs_embeds.device)
-
-            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
-
-        if attention_mask is not None:
-            attention_mask = attention_mask.to(inputs_embeds.device)
-
-    # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
-    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
-        # calculate RoPE index once per generation in the pre-fill stage only
-        if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
-            position_ids, rope_deltas = self.get_rope_index(
-                input_ids,
-                image_grid_thw,
-                video_grid_thw,
-                second_per_grid_ts,
-                attention_mask,
-            )
-            self.rope_deltas = rope_deltas
-        # then use the prev pre-calculated rope-deltas to get the correct position ids
-        else:
-            batch_size, seq_length, _ = inputs_embeds.shape
-            delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0
-            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
-            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-            if cache_position is not None:  # otherwise `deltas` is an int `0`
-                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
-            position_ids = position_ids.add(delta)
-            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-
     outputs = self.model(
-        input_ids=None,
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
+        second_per_grid_ts=second_per_grid_ts,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
@@ -159,46 +106,58 @@ def lce_forward(
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
+        **kwargs,
    )
 
     hidden_states = outputs[0]
 
+    shift_labels = kwargs.pop("shift_labels", None)
     loss = None
     logits = None
+    token_accuracy = None
 
-    if self.training and (labels is not None):
-        loss = LigerForCausalLMLoss(
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
            hidden_states=hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
+            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
     else:
        logits = self.lm_head(hidden_states)
-        if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
+
+        loss = None
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.config.vocab_size,
+            )
 
     if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return Qwen2_5_VLCausalLMOutputWithPast(
+    # Return Qwen2.5-VL output with token accuracy
+    return LigerQwen2_5_VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
-        rope_deltas=rope_deltas,
+        rope_deltas=outputs.rope_deltas,
+        token_accuracy=token_accuracy,
    )
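
The same `skip_logits` handling appears in each of these forwards: passing `True` without labels is an error, and leaving it unset defaults to "skip only when training with labels (or `shift_labels`)". A minimal sketch of that defaulting rule as a standalone helper (hypothetical function for illustration only, not part of liger_kernel):

from typing import Optional

def resolve_skip_logits(skip_logits: Optional[bool], training: bool, has_labels: bool) -> bool:
    # Mirrors the checks added in the diff: an explicit True without labels raises,
    # and None falls back to "skip only when training with labels".
    if skip_logits and not has_labels:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        return training and has_labels
    return skip_logits

assert resolve_skip_logits(None, training=True, has_labels=True) is True    # training: fused loss, no logits
assert resolve_skip_logits(None, training=False, has_labels=True) is False  # eval: logits are materialized
assert resolve_skip_logits(False, training=True, has_labels=True) is False  # explicit override wins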
liger_kernel/transformers/model/qwen2_vl.py

@@ -5,20 +5,14 @@ from typing import Union
 
 import torch
 
-from packaging import version
-from torch.nn import CrossEntropyLoss
-from transformers import __version__ as transformers_version
-from transformers.models.qwen2_vl.modeling_qwen2_vl import _CONFIG_FOR_DOC
-from transformers.models.qwen2_vl.modeling_qwen2_vl import QWEN2_VL_INPUTS_DOCSTRING
-from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
+from transformers.utils import can_return_tuple
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerQwen2VLCausalLMOutputWithPast
 
 
-@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+@can_return_tuple
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -37,18 +31,24 @@ def lce_forward(
     video_grid_thw: Optional[torch.LongTensor] = None,
     rope_deltas: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    **loss_kwargs,
-) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerQwen2VLCausalLMOutputWithPast]:
     r"""
-    Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
-
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-    Returns:
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+    pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
+        The tensors corresponding to the input videos. Pixel values can be obtained using
+        [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
+        [`Qwen2VLImageProcessor`] for processing videos.
+    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+        The temporal, height and width of feature shape of each image in LLM.
+    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+        The temporal, height and width of feature shape of each video in LLM.
+    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+        The rope index difference between sequence length and multimodal rope.
 
     Example:
 
@@ -80,80 +80,19 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
     ```"""
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    if inputs_embeds is None:
-        inputs_embeds = self.model.embed_tokens(input_ids)
-        if pixel_values is not None:
-            pixel_values = pixel_values.type(self.visual.get_dtype())
-            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
-            n_image_features = image_embeds.shape[0]
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-            image_mask = (
-                (input_ids == self.config.image_token_id)
-                .unsqueeze(-1)
-                .expand_as(inputs_embeds)
-                .to(inputs_embeds.device)
-            )
-            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-        if pixel_values_videos is not None:
-            pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
-            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
-            n_video_features = video_embeds.shape[0]
-            if n_video_tokens != n_video_features:
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-                )
-            video_mask = (
-                (input_ids == self.config.video_token_id)
-                .unsqueeze(-1)
-                .expand_as(inputs_embeds)
-                .to(inputs_embeds.device)
-            )
-            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
-
-        if attention_mask is not None:
-            attention_mask = attention_mask.to(inputs_embeds.device)
-
-    if version.parse(transformers_version) > version.parse("4.46.3"):
-        # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
-        # https://github.com/huggingface/transformers/issues/33401
-        # While correct, this breaks equivalence with past versions of Qwen2-VL from
-        # transformers and leads to failed tests or users noticing differences in results.
-        # TODO: remove above conditional when liger drops support for transformers<4.47.0
-        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
-        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
-            # calculate RoPE index once per generation in the pre-fill stage only
-            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
-                position_ids, rope_deltas = self.get_rope_index(
-                    input_ids, image_grid_thw, video_grid_thw, attention_mask
-                )
-                self.rope_deltas = rope_deltas
-            # then use the prev pre-calculated rope-deltas to get the correct position ids
-            else:
-                batch_size, seq_length, _ = inputs_embeds.shape
-                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
-                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
-                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-                if cache_position is not None:  # otherwise `deltas` is an int `0`
-                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
-                position_ids = position_ids.add(delta)
-                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-
     outputs = self.model(
-        input_ids=None,
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
@@ -163,46 +102,58 @@ def lce_forward(
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
+        **kwargs,
    )
 
     hidden_states = outputs[0]
 
+    shift_labels = kwargs.pop("shift_labels", None)
     loss = None
     logits = None
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
 
-    if self.training and (labels is not None):
-        loss = LigerForCausalLMLoss(
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
            hidden_states=hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
+            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
     else:
        logits = self.lm_head(hidden_states)
-        if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
+
+        loss = None
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.config.vocab_size,
+            )
 
     if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return Qwen2VLCausalLMOutputWithPast(
+    # Return Qwen2VL output with token accuracy
+    return LigerQwen2VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
-        rope_deltas=rope_deltas,
+        rope_deltas=outputs.rope_deltas,
+        token_accuracy=token_accuracy,
    )
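
When `return_dict=False`, all three patched forwards now assemble the plain tuple the same way: `loss` is prepended when present and the new `token_accuracy` value is appended when present. A small standalone sketch of that assembly with stand-in values (no liger_kernel imports; the concrete numbers are illustrative only):

# Stand-in values; in the real forward these come from the fused loss and the backbone outputs.
loss = 0.73
logits = None            # None when skip_logits is True
token_accuracy = 0.41
outputs_rest = ("past_key_values", "hidden_states", "attentions")  # stand-in for outputs[1:]

output_tuple = (logits,) + outputs_rest
output = (loss,) + output_tuple if loss is not None else output_tuple
output = output + (token_accuracy,) if token_accuracy is not None else output
print(output)  # (0.73, None, 'past_key_values', 'hidden_states', 'attentions', 0.41)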