liger-kernel-nightly 0.5.6.dev20250403190551__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +35 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +25 -9
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  15. liger_kernel/ops/backends/registry.py +61 -0
  16. liger_kernel/ops/cross_entropy.py +124 -64
  17. liger_kernel/ops/dyt.py +115 -180
  18. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  19. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  20. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  21. liger_kernel/ops/geglu.py +3 -2
  22. liger_kernel/ops/group_norm.py +2 -1
  23. liger_kernel/ops/grpo_loss.py +312 -0
  24. liger_kernel/ops/jsd.py +2 -1
  25. liger_kernel/ops/kl_div.py +13 -6
  26. liger_kernel/ops/layer_norm.py +146 -78
  27. liger_kernel/ops/llama4_rope.py +225 -0
  28. liger_kernel/ops/multi_token_attention.py +207 -0
  29. liger_kernel/ops/poly_norm.py +390 -0
  30. liger_kernel/ops/rms_norm.py +283 -56
  31. liger_kernel/ops/rope.py +1 -1
  32. liger_kernel/ops/softmax.py +201 -0
  33. liger_kernel/ops/sparsemax.py +179 -0
  34. liger_kernel/ops/swiglu.py +1 -1
  35. liger_kernel/ops/tiled_mlp.py +136 -0
  36. liger_kernel/ops/utils.py +2 -0
  37. liger_kernel/transformers/__init__.py +205 -19
  38. liger_kernel/transformers/cross_entropy.py +9 -4
  39. liger_kernel/transformers/dyt.py +6 -4
  40. liger_kernel/transformers/experimental/__init__.py +5 -0
  41. liger_kernel/transformers/experimental/embedding.py +1 -1
  42. liger_kernel/transformers/fsdp.py +55 -0
  43. liger_kernel/transformers/functional.py +122 -20
  44. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  45. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  46. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  47. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  48. liger_kernel/transformers/geglu.py +1 -1
  49. liger_kernel/transformers/group_norm.py +1 -1
  50. liger_kernel/transformers/grpo_loss.py +153 -0
  51. liger_kernel/transformers/jsd.py +1 -1
  52. liger_kernel/transformers/kl_div.py +1 -1
  53. liger_kernel/transformers/layer_norm.py +1 -1
  54. liger_kernel/transformers/llama4_rope.py +93 -0
  55. liger_kernel/transformers/model/falcon_h1.py +122 -0
  56. liger_kernel/transformers/model/gemma.py +50 -25
  57. liger_kernel/transformers/model/gemma2.py +55 -23
  58. liger_kernel/transformers/model/gemma3.py +117 -120
  59. liger_kernel/transformers/model/glm4.py +141 -0
  60. liger_kernel/transformers/model/glm4v.py +163 -0
  61. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  62. liger_kernel/transformers/model/gpt_oss.py +211 -0
  63. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  64. liger_kernel/transformers/model/internvl.py +157 -0
  65. liger_kernel/transformers/model/llama.py +102 -25
  66. liger_kernel/transformers/model/llama4.py +121 -0
  67. liger_kernel/transformers/model/llava.py +111 -136
  68. liger_kernel/transformers/model/loss_utils.py +50 -12
  69. liger_kernel/transformers/model/mistral.py +36 -23
  70. liger_kernel/transformers/model/mixtral.py +45 -25
  71. liger_kernel/transformers/model/mllama.py +39 -22
  72. liger_kernel/transformers/model/olmo2.py +40 -20
  73. liger_kernel/transformers/model/olmo3.py +142 -0
  74. liger_kernel/transformers/model/output_classes.py +147 -0
  75. liger_kernel/transformers/model/paligemma.py +50 -14
  76. liger_kernel/transformers/model/phi3.py +47 -177
  77. liger_kernel/transformers/model/qwen2.py +48 -21
  78. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  79. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  80. liger_kernel/transformers/model/qwen3.py +136 -0
  81. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  82. liger_kernel/transformers/model/qwen3_next.py +146 -0
  83. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  84. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  85. liger_kernel/transformers/model/smollm3.py +199 -0
  86. liger_kernel/transformers/model/smolvlm.py +158 -0
  87. liger_kernel/transformers/monkey_patch.py +1678 -160
  88. liger_kernel/transformers/multi_token_attention.py +64 -0
  89. liger_kernel/transformers/poly_norm.py +42 -0
  90. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  91. liger_kernel/transformers/rms_norm.py +48 -5
  92. liger_kernel/transformers/rope.py +45 -1
  93. liger_kernel/transformers/softmax.py +12 -0
  94. liger_kernel/transformers/sparsemax.py +16 -0
  95. liger_kernel/transformers/swiglu.py +39 -1
  96. liger_kernel/transformers/tiled_mlp.py +133 -0
  97. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  98. liger_kernel/transformers/tvd.py +1 -1
  99. liger_kernel/utils.py +36 -0
  100. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/METADATA +68 -38
  101. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  102. liger_kernel/transformers/gema3_rms.py +0 -8
  103. liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/RECORD +0 -82
  104. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
--- a/liger_kernel/transformers/model/qwen2_5_vl.py
+++ b/liger_kernel/transformers/model/qwen2_5_vl.py
@@ -5,18 +5,14 @@ from typing import Union
 
 import torch
 
-from torch.nn import CrossEntropyLoss
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import _CONFIG_FOR_DOC
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import QWEN2_5_VL_INPUTS_DOCSTRING
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
+from transformers.utils import can_return_tuple
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerQwen2_5_VLCausalLMOutputWithPast
 
 
-@add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+@can_return_tuple
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -36,17 +32,26 @@ def lce_forward(
     rope_deltas: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
     second_per_grid_ts: Optional[torch.Tensor] = None,
-    **loss_kwargs,
-) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerQwen2_5_VLCausalLMOutputWithPast]:
     r"""
-    Copy paste Qwen2_5_VL's forward but replace torch cross entropy with liger fused linear cross entropy
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-    Returns:
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+    pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
+        The tensors corresponding to the input videos. Pixel values can be obtained using
+        [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses
+        [`Qwen2_5_VLImageProcessor`] for processing videos.
+    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+        The temporal, height and width of feature shape of each image in LLM.
+    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+        The temporal, height and width of feature shape of each video in LLM.
+    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+        The rope index difference between sequence length and multimodal rope.
+    second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+        The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
 
     Example:
 
@@ -78,78 +83,20 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
     ```"""
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    if inputs_embeds is None:
-        inputs_embeds = self.model.embed_tokens(input_ids)
-        if pixel_values is not None:
-            pixel_values = pixel_values.type(self.visual.dtype)
-            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
-            n_image_features = image_embeds.shape[0]
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-
-            mask = input_ids == self.config.image_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            image_mask = mask_expanded.to(inputs_embeds.device)
-
-            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-        if pixel_values_videos is not None:
-            pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
-            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
-            n_video_features = video_embeds.shape[0]
-            if n_video_tokens != n_video_features:
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-                )
-
-            mask = input_ids == self.config.video_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            video_mask = mask_expanded.to(inputs_embeds.device)
-
-            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
-
-        if attention_mask is not None:
-            attention_mask = attention_mask.to(inputs_embeds.device)
-
-    # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
-    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
-        # calculate RoPE index once per generation in the pre-fill stage only
-        if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
-            position_ids, rope_deltas = self.get_rope_index(
-                input_ids,
-                image_grid_thw,
-                video_grid_thw,
-                second_per_grid_ts,
-                attention_mask,
-            )
-            self.rope_deltas = rope_deltas
-        # then use the prev pre-calculated rope-deltas to get the correct position ids
-        else:
-            batch_size, seq_length, _ = inputs_embeds.shape
-            delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0
-            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
-            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-            if cache_position is not None:  # otherwise `deltas` is an int `0`
-                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
-            position_ids = position_ids.add(delta)
-            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-
     outputs = self.model(
-        input_ids=None,
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
+        second_per_grid_ts=second_per_grid_ts,
         position_ids=position_ids,
         attention_mask=attention_mask,
         past_key_values=past_key_values,
@@ -159,46 +106,58 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
 
+    shift_labels = kwargs.pop("shift_labels", None)
     loss = None
     logits = None
+    token_accuracy = None
 
-    if self.training and (labels is not None):
-        loss = LigerForCausalLMLoss(
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
     else:
         logits = self.lm_head(hidden_states)
-        if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
+
+        loss = None
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.config.vocab_size,
+            )
 
     if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return Qwen2_5_VLCausalLMOutputWithPast(
+    # Return Qwen2.5-VL output with token accuracy
+    return LigerQwen2_5_VLCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        rope_deltas=rope_deltas,
+        rope_deltas=outputs.rope_deltas,
+        token_accuracy=token_accuracy,
     )
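
The patched forward above gates its loss path on the new `skip_logits` argument; the same gate recurs in the Qwen2-VL and Qwen3 forwards below. A minimal, hedged sketch of just that decision logic, assuming only what the diff shows (`resolve_skip_logits` is a hypothetical helper for illustration, not part of liger_kernel's API):

```python
from typing import Optional


def resolve_skip_logits(
    skip_logits: Optional[bool],
    training: bool,
    has_labels: bool,
    has_shift_labels: bool,
) -> bool:
    # Mirrors the branching added in the diff: an explicit skip_logits=True
    # without any labels is an error, since there would be neither a loss to
    # compute nor logits to return.
    if skip_logits and not (has_labels or has_shift_labels):
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        # Default: during training with labels, take the fused path.
        skip_logits = training and (has_labels or has_shift_labels)
    return skip_logits
```

When the gate is taken, `LigerForCausalLMLoss` fuses the `lm_head` projection with the cross entropy, so the full `(batch, seq_len, vocab)` logits tensor is never materialized; otherwise the standard `self.loss_function` path runs on materialized logits.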
--- a/liger_kernel/transformers/model/qwen2_vl.py
+++ b/liger_kernel/transformers/model/qwen2_vl.py
@@ -5,20 +5,14 @@ from typing import Union
 
 import torch
 
-from packaging import version
-from torch.nn import CrossEntropyLoss
-from transformers import __version__ as transformers_version
-from transformers.models.qwen2_vl.modeling_qwen2_vl import _CONFIG_FOR_DOC
-from transformers.models.qwen2_vl.modeling_qwen2_vl import QWEN2_VL_INPUTS_DOCSTRING
-from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
+from transformers.utils import can_return_tuple
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerQwen2VLCausalLMOutputWithPast
 
 
-@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+@can_return_tuple
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -37,18 +31,24 @@ def lce_forward(
     video_grid_thw: Optional[torch.LongTensor] = None,
     rope_deltas: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    **loss_kwargs,
-) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerQwen2VLCausalLMOutputWithPast]:
     r"""
-    Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
-
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-    Returns:
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+    pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
+        The tensors corresponding to the input videos. Pixel values can be obtained using
+        [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
+        [`Qwen2VLImageProcessor`] for processing videos.
+    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+        The temporal, height and width of feature shape of each image in LLM.
+    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+        The temporal, height and width of feature shape of each video in LLM.
+    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+        The rope index difference between sequence length and multimodal rope.
 
     Example:
 
@@ -80,80 +80,19 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
     ```"""
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    if inputs_embeds is None:
-        inputs_embeds = self.model.embed_tokens(input_ids)
-        if pixel_values is not None:
-            pixel_values = pixel_values.type(self.visual.get_dtype())
-            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
-            n_image_features = image_embeds.shape[0]
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-            image_mask = (
-                (input_ids == self.config.image_token_id)
-                .unsqueeze(-1)
-                .expand_as(inputs_embeds)
-                .to(inputs_embeds.device)
-            )
-            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-        if pixel_values_videos is not None:
-            pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
-            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
-            n_video_features = video_embeds.shape[0]
-            if n_video_tokens != n_video_features:
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-                )
-            video_mask = (
-                (input_ids == self.config.video_token_id)
-                .unsqueeze(-1)
-                .expand_as(inputs_embeds)
-                .to(inputs_embeds.device)
-            )
-            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
-
-        if attention_mask is not None:
-            attention_mask = attention_mask.to(inputs_embeds.device)
-
-    if version.parse(transformers_version) > version.parse("4.46.3"):
-        # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
-        # https://github.com/huggingface/transformers/issues/33401
-        # While correct, this breaks equivalence with past versions of Qwen2-VL from
-        # transformers and leads to failed tests or users noticing differences in results.
-        # TODO: remove above conditional when liger drops support for transformers<4.47.0
-        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
-        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
-            # calculate RoPE index once per generation in the pre-fill stage only
-            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
-                position_ids, rope_deltas = self.get_rope_index(
-                    input_ids, image_grid_thw, video_grid_thw, attention_mask
-                )
-                self.rope_deltas = rope_deltas
-            # then use the prev pre-calculated rope-deltas to get the correct position ids
-            else:
-                batch_size, seq_length, _ = inputs_embeds.shape
-                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
-                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
-                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-                if cache_position is not None:  # otherwise `deltas` is an int `0`
-                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
-                position_ids = position_ids.add(delta)
-                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-
     outputs = self.model(
-        input_ids=None,
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
         position_ids=position_ids,
         attention_mask=attention_mask,
         past_key_values=past_key_values,
@@ -163,46 +102,58 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
 
+    shift_labels = kwargs.pop("shift_labels", None)
     loss = None
     logits = None
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
 
-    if self.training and (labels is not None):
-        loss = LigerForCausalLMLoss(
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **loss_kwargs,
+            **kwargs,
        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
     else:
         logits = self.lm_head(hidden_states)
-        if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
+
+        loss = None
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.config.vocab_size,
+            )
 
     if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        output_tuple = (logits,) + outputs[1:]
+        output = (loss,) + output_tuple if loss is not None else output_tuple
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
 
-    return Qwen2VLCausalLMOutputWithPast(
+    # Return Qwen2VL output with token accuracy
+    return LigerQwen2VLCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        rope_deltas=rope_deltas,
+        rope_deltas=outputs.rope_deltas,
+        token_accuracy=token_accuracy,
    )
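
The `shift_labels` keyword that both patched forwards now pop out of `**kwargs` follows the usual Hugging Face convention: labels that have already been shifted left by one position (next-token targets, padded with -100), so the loss function skips its own internal shift. A small, hedged illustration of that convention (not liger_kernel code):

```python
import torch

labels = torch.tensor([[10, 11, 12, 13]])

# Pre-shifted targets: position i holds the target for the logit at position i,
# with -100 appended so the final position is ignored by the loss.
shift_labels = torch.cat(
    [labels[..., 1:], torch.full((labels.size(0), 1), -100)], dim=-1
)
print(shift_labels)  # tensor([[  11,   12,   13, -100]])
```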
--- /dev/null
+++ b/liger_kernel/transformers/model/qwen3.py
@@ -0,0 +1,136 @@
+from typing import List
+from typing import Optional
+from typing import Union
+
+import torch
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+
+
+def lce_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    **kwargs,
+) -> LigerCausalLMOutputWithPast:
+    r"""
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    logits_to_keep (`int` or `torch.Tensor`, *optional*):
+        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+        This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Qwen3ForCausalLM
+
+    >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
+    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    # Remove output-control parameters that shouldn't be passed to loss functions
+    kwargs.pop("return_dict", None)
+    logits = None
+    loss = None
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    # Compute loss
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
+
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with accuracy field
+    return LigerCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
+    )
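
For context, a hedged usage sketch of the patched Qwen3 forward. It assumes the nightly exposes `apply_liger_kernel_to_qwen3` (following the existing `apply_liger_kernel_to_<model>` naming in `monkey_patch.py`) and uses an arbitrary example checkpoint; it is illustrative, not a verbatim API reference:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: this entry point exists in the new release, mirroring the
# apply_liger_kernel_to_<model> functions already exported for other models.
from liger_kernel.transformers import apply_liger_kernel_to_qwen3

apply_liger_kernel_to_qwen3()  # patches Qwen3ForCausalLM.forward with lce_forward

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

batch = tokenizer("Hey, are you conscious?", return_tensors="pt")
labels = batch["input_ids"].clone()

model.train()
out = model(**batch, labels=labels)
print(out.loss)    # computed by the fused linear cross entropy path
print(out.logits)  # None: skip_logits defaulted to True, logits were never materialized

model.eval()
with torch.no_grad():
    out = model(**batch, labels=labels, skip_logits=False)
print(out.logits.shape)  # full (batch, seq_len, vocab) logits are materialized again
```

The returned `LigerCausalLMOutputWithPast` also carries a `token_accuracy` field (it may be `None` unless the loss path populates it), which is the new output class introduced in `output_classes.py`.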