liger-kernel 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. liger_kernel/chunked_loss/functional.py +2 -0
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +346 -0
  4. liger_kernel/chunked_loss/grpo_loss.py +134 -60
  5. liger_kernel/chunked_loss/jsd_loss.py +12 -7
  6. liger_kernel/ops/cross_entropy.py +3 -2
  7. liger_kernel/ops/dyt.py +225 -0
  8. liger_kernel/ops/fused_linear_jsd.py +2 -1
  9. liger_kernel/ops/jsd.py +32 -12
  10. liger_kernel/ops/kl_div.py +15 -8
  11. liger_kernel/ops/layer_norm.py +14 -1
  12. liger_kernel/ops/rms_norm.py +12 -1
  13. liger_kernel/transformers/__init__.py +133 -15
  14. liger_kernel/transformers/dyt.py +20 -0
  15. liger_kernel/transformers/functional.py +5 -0
  16. liger_kernel/transformers/gema3_rms.py +8 -0
  17. liger_kernel/transformers/model/gemma.py +17 -20
  18. liger_kernel/transformers/model/gemma2.py +17 -21
  19. liger_kernel/transformers/model/gemma3.py +335 -0
  20. liger_kernel/transformers/model/llama.py +17 -19
  21. liger_kernel/transformers/model/llava.py +369 -0
  22. liger_kernel/transformers/model/loss_utils.py +64 -0
  23. liger_kernel/transformers/model/mistral.py +28 -25
  24. liger_kernel/transformers/model/mixtral.py +20 -26
  25. liger_kernel/transformers/model/mllama.py +17 -19
  26. liger_kernel/transformers/model/olmo2.py +17 -20
  27. liger_kernel/transformers/model/paligemma.py +397 -0
  28. liger_kernel/transformers/model/phi3.py +17 -19
  29. liger_kernel/transformers/model/qwen2.py +17 -19
  30. liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
  31. liger_kernel/transformers/model/qwen2_vl.py +9 -10
  32. liger_kernel/transformers/monkey_patch.py +392 -13
  33. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/METADATA +11 -6
  34. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/RECORD +38 -31
  35. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/WHEEL +1 -1
  36. liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
  37. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/LICENSE +0 -0
  38. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/NOTICE +0 -0
  39. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/top_level.txt +0 -0
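Most of the churn lands in the monkey-patching layer (liger_kernel/transformers/__init__.py and monkey_patch.py) plus the per-model lce_forward replacements shown in the hunks below. As a minimal usage sketch of how those patched forwards get enabled (the checkpoint name is a placeholder; apply_liger_kernel_to_qwen2 is the existing entry point for the first file shown below, called here with its defaults):

import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_qwen2

# Patch the Hugging Face Qwen2 implementation in place so that the Liger
# kernels (including the fused-linear-cross-entropy forward below) are used.
apply_liger_kernel_to_qwen2()

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B",  # placeholder checkpoint, substitute your own
    torch_dtype=torch.bfloat16,
)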
liger_kernel/transformers/model/qwen2.py

@@ -11,8 +11,10 @@ from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
 from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
@@ -124,6 +126,7 @@ def lce_forward_deprecated(
     )
 
 
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
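The new decorator comes from transformers.utils.deprecation and keeps old call sites working while num_logits_to_keep is renamed. A toy sketch of its effect on a standalone function (not the patched forward itself):

from transformers.utils.deprecation import deprecate_kwarg

@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def toy_forward(logits_to_keep=0):
    return logits_to_keep

# The old keyword still works: the decorator emits a deprecation warning and
# forwards the value to the new name, so this prints 5.
print(toy_forward(num_logits_to_keep=5))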
@@ -139,7 +142,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -149,10 +152,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
     Returns:
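The widened logits_to_keep docstring matches upstream transformers: an int keeps only the trailing positions (with 0 meaning all of them), while a 1D tensor selects explicit sequence indices for packed inputs. A small plain-PyTorch sketch of both indexing modes, mirroring the slice_indices logic in the next hunk (shapes invented for illustration):

import torch

hidden_states = torch.randn(2, 6, 8)  # (batch, seq_len, hidden)

# int form: keep the last k positions; 0 is the "keep everything" special case
print(hidden_states[:, slice(-1, None), :].shape)  # torch.Size([2, 1, 8])
print(hidden_states[:, slice(-0, None), :].shape)  # torch.Size([2, 6, 8])

# tensor form: keep explicit sequence positions, e.g. the last token of each
# packed sub-sequence
keep = torch.tensor([2, 5])
print(hidden_states[:, keep, :].shape)  # torch.Size([2, 2, 8])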
@@ -199,24 +204,17 @@ def lce_forward(
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None):
-        # We do the same thing as ForCausalLMLoss but using Liger FLCE
-
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
-
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
 
     else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
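LigerForCausalLMLoss comes from the new liger_kernel/transformers/model/loss_utils.py (+64 lines), which this section does not show. Judging from the inline code it replaces above, it presumably performs the usual causal shift and flatten, then calls the fused linear cross entropy with the num_items_in_batch-aware reduction; the following is a rough sketch of that behavior, not the actual implementation:

import torch
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss


def causal_lm_loss_sketch(hidden_states, lm_head_weight, labels, hidden_size, num_items_in_batch=None, **kwargs):
    # Shift so that tokens < n predict token n, exactly as the removed inline code did.
    shift_hidden_states = hidden_states[..., :-1, :].contiguous().view(-1, hidden_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)

    # Sum-reduce when the Trainer supplies a global token count, then normalize
    # by it; otherwise fall back to a plain mean.
    reduction = "sum" if num_items_in_batch is not None else "mean"
    lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
    loss = lce(lm_head_weight, shift_hidden_states, shift_labels)
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss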
liger_kernel/transformers/model/qwen2_5_vl.py

@@ -12,7 +12,7 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalL
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
@@ -36,6 +36,7 @@ def lce_forward(
     rope_deltas: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
     second_per_grid_ts: Optional[torch.Tensor] = None,
+    **loss_kwargs,
 ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
     r"""
     Copy paste Qwen2_5_VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -166,15 +167,13 @@ def lce_forward(
     logits = None
 
     if self.training and (labels is not None):
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # Flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
     else:
         logits = self.lm_head(hidden_states)
         if labels is not None:
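The VL forwards also gain **loss_kwargs, presumably so the Trainer's num_items_in_batch can reach the fused loss just as it does for the text models above, instead of the old unconditional mean reduction. A toy plain-PyTorch demonstration (not Liger code) of why sum-reduction divided by the global token count is the gradient-accumulation-safe choice:

import torch
import torch.nn.functional as F

logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))

full = F.cross_entropy(logits, labels, reduction="mean")

# Averaging per-chunk means only matches the full-batch mean when chunks are equal-sized.
chunk_means = [F.cross_entropy(logits[i:i + 3], labels[i:i + 3], reduction="mean") for i in (0, 3, 6)]
naive = torch.stack(chunk_means).mean()

# Summing per-chunk losses and dividing by the global token count matches the
# full-batch mean (up to float error), which is what num_items_in_batch enables.
chunk_sums = [F.cross_entropy(logits[i:i + 3], labels[i:i + 3], reduction="sum") for i in (0, 3, 6)]
correct = torch.stack(chunk_sums).sum() / len(labels)

print(full.item(), naive.item(), correct.item())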
liger_kernel/transformers/model/qwen2_vl.py

@@ -14,7 +14,7 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutput
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
@@ -37,6 +37,7 @@ def lce_forward(
     video_grid_thw: Optional[torch.LongTensor] = None,
     rope_deltas: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    **loss_kwargs,
 ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
     r"""
     Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -170,15 +171,13 @@ def lce_forward(
     logits = None
 
     if self.training and (labels is not None):
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # Flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
     else:
         logits = self.lm_head(hidden_states)
         if labels is not None: