liger-kernel 0.6.1__py3-none-any.whl → 0.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  2. liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
  3. liger_kernel/chunked_loss/grpo_loss.py +38 -4
  4. liger_kernel/chunked_loss/jsd_loss.py +5 -2
  5. liger_kernel/ops/cross_entropy.py +59 -53
  6. liger_kernel/ops/fused_linear_cross_entropy.py +83 -17
  7. liger_kernel/ops/layer_norm.py +4 -6
  8. liger_kernel/ops/llama4_rope.py +225 -0
  9. liger_kernel/ops/poly_norm.py +386 -0
  10. liger_kernel/transformers/__init__.py +32 -0
  11. liger_kernel/transformers/experimental/__init__.py +5 -0
  12. liger_kernel/transformers/functional.py +9 -0
  13. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -1
  14. liger_kernel/transformers/llama4_rope.py +93 -0
  15. liger_kernel/transformers/model/falcon_h1.py +108 -0
  16. liger_kernel/transformers/model/gemma.py +2 -1
  17. liger_kernel/transformers/model/gemma2.py +8 -2
  18. liger_kernel/transformers/model/gemma3.py +27 -2
  19. liger_kernel/transformers/model/glm4.py +2 -1
  20. liger_kernel/transformers/model/glm4v.py +151 -0
  21. liger_kernel/transformers/model/glm4v_moe.py +153 -0
  22. liger_kernel/transformers/model/internvl.py +150 -0
  23. liger_kernel/transformers/model/llama.py +2 -1
  24. liger_kernel/transformers/model/llama4.py +2 -1
  25. liger_kernel/transformers/model/llava.py +6 -2
  26. liger_kernel/transformers/model/loss_utils.py +3 -0
  27. liger_kernel/transformers/model/mistral.py +2 -1
  28. liger_kernel/transformers/model/mixtral.py +8 -2
  29. liger_kernel/transformers/model/mllama.py +6 -3
  30. liger_kernel/transformers/model/olmo2.py +2 -1
  31. liger_kernel/transformers/model/paligemma.py +19 -0
  32. liger_kernel/transformers/model/phi3.py +10 -160
  33. liger_kernel/transformers/model/qwen2.py +2 -1
  34. liger_kernel/transformers/model/qwen2_5_vl.py +7 -2
  35. liger_kernel/transformers/model/qwen2_vl.py +7 -2
  36. liger_kernel/transformers/model/qwen3.py +2 -1
  37. liger_kernel/transformers/model/qwen3_moe.py +8 -2
  38. liger_kernel/transformers/model/qwen3_next.py +134 -0
  39. liger_kernel/transformers/model/smollm3.py +2 -1
  40. liger_kernel/transformers/model/smolvlm.py +158 -0
  41. liger_kernel/transformers/monkey_patch.py +552 -23
  42. liger_kernel/transformers/multi_token_attention.py +1 -1
  43. liger_kernel/transformers/poly_norm.py +42 -0
  44. liger_kernel/transformers/rms_norm.py +7 -0
  45. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/METADATA +14 -11
  46. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/RECORD +50 -39
  47. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/WHEEL +0 -0
  48. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/licenses/LICENSE +0 -0
  49. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/licenses/NOTICE +0 -0
  50. {liger_kernel-0.6.1.dist-info → liger_kernel-0.6.3.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/__init__.py

@@ -10,9 +10,16 @@ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinea
  from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401
  from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
  from liger_kernel.transformers.jsd import LigerJSD # noqa: F401
+ from liger_kernel.transformers.kl_div import LigerKLDIVLoss # noqa: F401
  from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
+ from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb # noqa: F401
+ from liger_kernel.transformers.llama4_rope import liger_llama4_vision_rotary_pos_emb # noqa: F401
+ from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention # noqa: F401
+ from liger_kernel.transformers.poly_norm import LigerPolyNorm # noqa: F401
  from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
  from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
+ from liger_kernel.transformers.softmax import LigerSoftmax # noqa: F401
+ from liger_kernel.transformers.sparsemax import LigerSparsemax # noqa: F401
  from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F401
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
  from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP # noqa: F401

@@ -24,12 +31,16 @@ if TYPE_CHECKING:
  from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401
  from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401
  from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_falcon_h1 # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401

@@ -44,7 +55,9 @@ if TYPE_CHECKING:
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3 # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next # noqa: F401
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3 # noqa: F401
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smolvlm # noqa: F401


  # Check if 'transformers' is installed

@@ -82,12 +95,16 @@ def __getattr__(name: str):
  monkey_patch_symbols = {
  "_apply_liger_kernel",
  "_apply_liger_kernel_to_instance",
+ "apply_liger_kernel_to_falcon_h1",
  "apply_liger_kernel_to_gemma",
  "apply_liger_kernel_to_gemma2",
  "apply_liger_kernel_to_gemma3",
  "apply_liger_kernel_to_gemma3_text",
  "apply_liger_kernel_to_glm4",
+ "apply_liger_kernel_to_glm4v",
+ "apply_liger_kernel_to_glm4v_moe",
  "apply_liger_kernel_to_granite",
+ "apply_liger_kernel_to_internvl",
  "apply_liger_kernel_to_llama",
  "apply_liger_kernel_to_llava",
  "apply_liger_kernel_to_llama4",

@@ -102,7 +119,9 @@ def __getattr__(name: str):
  "apply_liger_kernel_to_qwen2_vl",
  "apply_liger_kernel_to_qwen3",
  "apply_liger_kernel_to_qwen3_moe",
+ "apply_liger_kernel_to_qwen3_next",
  "apply_liger_kernel_to_smollm3",
+ "apply_liger_kernel_to_smolvlm",
  }

  if name in monkey_patch_symbols:

@@ -123,13 +142,20 @@ __all__ = [
  "LigerJSD",
  "LigerLayerNorm",
  "LigerFusedAddRMSNorm",
+ "LigerPolyNorm",
  "LigerRMSNorm",
  "liger_rotary_pos_emb",
+ "liger_llama4_text_rotary_pos_emb",
+ "liger_llama4_vision_rotary_pos_emb",
  "LigerBlockSparseTop2MLP",
  "LigerPhi3SwiGLUMLP",
  "LigerQwen3MoeSwiGLUMLP",
  "LigerSwiGLUMLP",
  "LigerTVDLoss",
+ "LigerKLDIVLoss",
+ "LigerMultiTokenAttention",
+ "LigerSoftmax",
+ "LigerSparsemax",
  ]

  # Add transformer-dependent symbols only if available

@@ -139,12 +165,16 @@ if _TRANSFORMERS_AVAILABLE:
  "AutoLigerKernelForCausalLM",
  "_apply_liger_kernel",
  "_apply_liger_kernel_to_instance",
+ "apply_liger_kernel_to_falcon_h1",
  "apply_liger_kernel_to_gemma",
  "apply_liger_kernel_to_gemma2",
  "apply_liger_kernel_to_gemma3",
  "apply_liger_kernel_to_gemma3_text",
  "apply_liger_kernel_to_glm4",
+ "apply_liger_kernel_to_glm4v",
+ "apply_liger_kernel_to_glm4v_moe",
  "apply_liger_kernel_to_granite",
+ "apply_liger_kernel_to_internvl",
  "apply_liger_kernel_to_llama",
  "apply_liger_kernel_to_llava",
  "apply_liger_kernel_to_llama4",

@@ -159,6 +189,8 @@ if _TRANSFORMERS_AVAILABLE:
  "apply_liger_kernel_to_qwen2_vl",
  "apply_liger_kernel_to_qwen3",
  "apply_liger_kernel_to_qwen3_moe",
+ "apply_liger_kernel_to_qwen3_next",
  "apply_liger_kernel_to_smollm3",
+ "apply_liger_kernel_to_smolvlm",
  ]
  )
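The new `apply_liger_kernel_to_*` entry points exported above are used like the existing ones: call them before the model is instantiated so the Hugging Face modeling code is patched in place. A minimal sketch, assuming the usual zero-argument call with default kernel toggles; the checkpoint name is illustrative only:

```python
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next

# Patch the Qwen3-Next modeling code before loading the model. Per-kernel keyword
# toggles are assumed to follow the existing apply_liger_kernel_to_* convention.
apply_liger_kernel_to_qwen3_next()

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-Next-80B-A3B-Instruct",  # illustrative checkpoint name
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
```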
liger_kernel/transformers/experimental/__init__.py (new file)

@@ -0,0 +1,5 @@
+ from liger_kernel.transformers.experimental.embedding import LigerEmbedding # noqa: F401
+
+ __all__ = [
+ "LigerEmbedding",
+ ]
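The experimental namespace now re-exports `LigerEmbedding`. A minimal sketch, assuming an `nn.Embedding`-style constructor (this diff does not show the signature) and a CUDA device for the Triton kernel:

```python
import torch

from liger_kernel.transformers.experimental import LigerEmbedding

emb = LigerEmbedding(32000, 768).to("cuda")  # assumed (num_embeddings, embedding_dim) signature
token_ids = torch.randint(0, 32000, (2, 16), device="cuda")
vectors = emb(token_ids)  # expected shape: (2, 16, 768)
```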
liger_kernel/transformers/functional.py

@@ -12,6 +12,7 @@ from liger_kernel.ops.jsd import LigerJSDFunction
  from liger_kernel.ops.kl_div import LigerKLDivLossFunction
  from liger_kernel.ops.layer_norm import LigerLayerNormFunction
  from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
+ from liger_kernel.ops.poly_norm import LigerPolyNormFunction
  from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
  from liger_kernel.ops.rms_norm import LigerRMSNormFunction
  from liger_kernel.ops.rope import LigerRopeFunction

@@ -64,6 +65,8 @@ def liger_fused_linear_cross_entropy(
  reduction: str = "mean",
  softcap: Optional[float] = None,
  return_z_loss: bool = False,
+ accum_dtype=None,
+ use_token_scaling: bool = False,
  ):
  loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
  input,

@@ -77,6 +80,8 @@ def liger_fused_linear_cross_entropy(
  reduction,
  softcap,
  return_z_loss,
+ accum_dtype,
+ use_token_scaling,
  )
  if not return_z_loss:
  return loss

@@ -254,6 +259,10 @@ def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama",
  return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)


+ def liger_poly_norm(X, W, B, eps=1e-6, in_place=True):
+ return LigerPolyNormFunction.apply(X, W, B, eps, in_place)
+
+
  def liger_fused_add_rms_norm(X, R, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
  return LigerFusedAddRMSNormFunction.apply(X, R, W, eps, offset, casting_mode, in_place)

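The functional API gains `liger_poly_norm` and two new fused-linear-cross-entropy arguments. A hedged sketch of calling them; argument order `(input, weight, target)` follows the existing functional API, and the PolyNorm parameter shapes plus the exact `accum_dtype` / `use_token_scaling` semantics are assumptions, not spelled out in this diff:

```python
import torch

from liger_kernel.transformers.functional import liger_fused_linear_cross_entropy, liger_poly_norm

N, H, V = 16, 64, 128  # tokens, hidden size, vocab size (illustrative)
x = torch.randn(N, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(V, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = torch.randint(0, V, (N,), device="cuda")

# New knobs: accumulate reductions in a wider dtype and enable per-token scaling.
loss = liger_fused_linear_cross_entropy(
    x, w, y,
    accum_dtype=torch.float32,
    use_token_scaling=True,
)
loss.backward()

# PolyNorm: weight assumed to hold one entry per polynomial order, bias assumed scalar.
h = torch.randn(N, H, device="cuda", dtype=torch.bfloat16)
poly_w = torch.ones(3, device="cuda", dtype=torch.bfloat16) / 3
poly_b = torch.zeros(1, device="cuda", dtype=torch.bfloat16)
out = liger_poly_norm(h, poly_w, poly_b, eps=1e-6)
```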
liger_kernel/transformers/fused_linear_cross_entropy.py

@@ -15,6 +15,8 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
  reduction: str = "mean",
  softcap: Optional[float] = None,
  return_z_loss: bool = False,
+ accum_dtype: Optional[torch.dtype] = None,
+ use_token_scaling: bool = False,
  ):
  super().__init__()
  assert (label_smoothing >= 0) and (label_smoothing <= 1), (

@@ -23,7 +25,8 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
  assert reduction in {
  "mean",
  "sum",
- }, f"reduction must be 'mean' or 'sum'. Got: {reduction}"
+ "none",
+ }, f"reduction must be 'mean' or 'sum' or 'none'. Got: {reduction}"
  assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
  self.ce_weight = ce_weight
  self.ignore_index = ignore_index

@@ -32,6 +35,8 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
  self.reduction = reduction
  self.softcap = softcap
  self.return_z_loss = return_z_loss
+ self.accum_dtype = accum_dtype
+ self.use_token_scaling = use_token_scaling

  def forward(self, lin_weight, _input, target, bias=None):
  loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(

@@ -46,6 +51,8 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
  self.reduction,
  self.softcap,
  self.return_z_loss,
+ self.accum_dtype,
+ self.use_token_scaling,
  )
  if not self.return_z_loss:
  return loss
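The module wrapper now accepts `reduction="none"` as well as the two new constructor arguments, which it forwards to the kernel. A minimal sketch using the `forward(lin_weight, _input, target, bias=None)` signature shown above; the accumulation/scaling behavior itself is assumed to match the functional API:

```python
import torch

from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

loss_fn = LigerFusedLinearCrossEntropyLoss(
    reduction="none",              # newly allowed: returns a per-token loss
    accum_dtype=torch.float32,     # assumed: accumulate reductions in fp32
    use_token_scaling=True,        # assumed: per-token loss scaling
)

N, H, V = 16, 64, 128
lm_head_weight = torch.randn(V, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
hidden = torch.randn(N, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, V, (N,), device="cuda")

per_token_loss = loss_fn(lm_head_weight, hidden, target)
per_token_loss.sum().backward()
```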
liger_kernel/transformers/llama4_rope.py (new file)

@@ -0,0 +1,93 @@
+ """
+ Liger Kernel implementation of Llama4 Rotary Position Embedding (RoPE).
+ Supports both text and vision RoPE variants with fused operations for optimal performance.
+ """
+
+ import torch
+
+ from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction
+
+
+ def liger_llama4_text_rotary_pos_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Liger-optimized implementation of Llama4 text rotary position embedding.
+
+ This implementation uses a fused Triton kernel for complex multiplication,
+ providing significant performance improvements over the original PyTorch implementation.
+
+ Args:
+ xq (torch.Tensor): Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
+ xk (torch.Tensor): Key tensor of shape (batch_size, seq_len, num_heads, head_dim)
+ freqs_cis (torch.Tensor): Complex frequency tensor from Llama4TextRotaryEmbedding
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors
+ """
+ # Use fused Triton kernel for complex RoPE
+ return LigerLlama4RopeFunction.apply(xq, xk, freqs_cis)
+
+
+ def liger_llama4_vision_rotary_pos_emb(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ freqs_ci: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Liger-optimized implementation of Llama4 vision rotary position embedding.
+
+ This implementation uses the same fused Triton kernel as text RoPE,
+ providing performance improvements for vision transformer attention.
+
+ Args:
+ query (torch.Tensor): Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
+ key (torch.Tensor): Key tensor of shape (batch_size, seq_len, num_heads, head_dim)
+ freqs_ci (torch.Tensor): Complex frequency tensor for 2D positions
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors
+ """
+ # Handle broadcasting for vision RoPE
+ if freqs_ci.dim() == 3:
+ try:
+ # Try the regular 3D expansion
+ freqs_ci = freqs_ci.unsqueeze(0).expand(query.shape[0], -1, -1)
+ except RuntimeError as e:
+ if "expand" in str(e) and "4" in str(e):
+ # The tensor is actually 4D internally, handle it differently
+ freqs_ci = freqs_ci.squeeze(1) # Remove the middle dimension
+ freqs_ci = freqs_ci.unsqueeze(0).expand(query.shape[0], -1, -1)
+ else:
+ raise e
+ elif freqs_ci.dim() == 4: # (1, seq_len, 1, head_dim//2) - already properly shaped
+ # Squeeze the middle dimension to get (1, seq_len, head_dim//2)
+ freqs_ci = freqs_ci.squeeze(2)
+ elif freqs_ci.dim() == 2: # (seq_len, head_dim//2) - needs expansion
+ freqs_ci = freqs_ci.unsqueeze(0).expand(query.shape[0], -1, -1)
+ else:
+ raise ValueError(f"Unexpected freqs_ci shape: {freqs_ci.shape}")
+
+ # Use the same fused kernel as text RoPE
+ return LigerLlama4RopeFunction.apply(query, key, freqs_ci)
+
+
+ # Note: We only patch the functions, not the classes
+ # The original Llama4TextRotaryEmbedding and Llama4VisionRotaryEmbedding classes remain unchanged
+
+
+ # Convenience functions for monkey patching
+ def apply_liger_llama4_rope_full(modeling_module):
+ """
+ Apply Liger optimizations to Llama4 RoPE functions.
+
+ Args:
+ modeling_module: The transformers modeling module to patch
+ """
+ # Replace the text RoPE function
+ modeling_module.apply_rotary_emb = liger_llama4_text_rotary_pos_emb
+
+ # Replace the vision RoPE function
+ modeling_module.vision_apply_rotary_emb = liger_llama4_vision_rotary_pos_emb
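The convenience patcher above swaps the module-level RoPE helpers for the fused Triton versions. A hedged sketch of pointing it at the Hugging Face Llama4 modeling module (the module path and attribute names are assumed from the standard transformers layout; whether `apply_liger_kernel_to_llama4` already wires this up internally is not shown in this hunk):

```python
from transformers.models.llama4 import modeling_llama4  # assumed HF module path

from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full

# After this call, Llama4 attention resolves modeling_llama4.apply_rotary_emb and
# modeling_llama4.vision_apply_rotary_emb to the Liger fused-kernel implementations.
apply_liger_llama4_rope_full(modeling_llama4)
```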
liger_kernel/transformers/model/falcon_h1.py (new file)

@@ -0,0 +1,108 @@
+ from typing import TYPE_CHECKING
+ from typing import Optional
+ from typing import Union
+
+ import torch
+
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ if TYPE_CHECKING:
+ from transformers.models.falcon_h1.modeling_falcon_h1 import FalconHybridMambaAttentionDynamicCache
+
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+ def lce_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional["FalconHybridMambaAttentionDynamicCache"] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ skip_logits: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[tuple, CausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, FalconH1ForCausalLM
+
+ >>> model = FalconH1ForCausalLM.from_pretrained("...")
+ >>> tokenizer = AutoTokenizer.from_pretrained("...")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]
+
+ shift_labels = kwargs.pop("shift_labels", None)
+ logits = None
+ loss = None
+ # if in training mode, don't materialize logits
+ if skip_logits and labels is None:
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+ if skip_logits is None:
+ # By default, if in training mode, don't materialize logits
+ skip_logits = self.training and labels is not None
+
+ if skip_logits:
+ loss = LigerForCausalLMLoss(
+ hidden_states=kept_hidden_states,
+ lm_head_weight=self.lm_head.weight,
+ labels=labels,
+ shift_labels=shift_labels,
+ hidden_size=self.config.hidden_size,
+ **kwargs,
+ )
+ else:
+ logits = self.lm_head(kept_hidden_states)
+ if labels is not None or shift_labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
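This `lce_forward` is what `apply_liger_kernel_to_falcon_h1` (exported in the `__init__` changes above) is expected to install on `FalconH1ForCausalLM`. A hedged usage sketch showing the training-mode path where logits are never materialized; the checkpoint name is illustrative and the zero-argument patch call assumes the usual default kernel toggles:

```python
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_falcon_h1

apply_liger_kernel_to_falcon_h1()  # patch before instantiating the model

model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/Falcon-H1-0.5B-Base",  # illustrative checkpoint
    torch_dtype=torch.bfloat16,
).cuda()
model.train()

input_ids = torch.randint(0, model.config.vocab_size, (1, 128), device="cuda")
labels = input_ids.clone()

# In training mode with labels, skip_logits defaults to True: the loss is computed
# directly from hidden states via LigerForCausalLMLoss and out.logits is None.
out = model(input_ids=input_ids, labels=labels)
out.loss.backward()
```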
liger_kernel/transformers/model/gemma.py

@@ -228,10 +228,11 @@ def lce_forward(
  )
  else:
  logits = self.lm_head(kept_hidden_states)
- if labels is not None:
+ if labels is not None or shift_labels is not None:
  loss = self.loss_function(
  logits=logits,
  labels=labels,
+ shift_labels=shift_labels,
  vocab_size=self.config.vocab_size,
  **kwargs,
  )
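This hunk, and the gemma2, gemma3, and glm4 hunks that follow, let the logits path compute a loss from pre-shifted labels passed as `shift_labels`, even when `labels` is absent. A hedged sketch of what such a tensor looks like (the model call itself is left as a comment since only a fragment of the forward is shown here):

```python
import torch

# Labels shifted by one position ahead of time, as a packing-aware collator might
# produce; -100 marks positions the loss should ignore.
input_ids = torch.randint(0, 32000, (1, 16))
shift_labels = input_ids.roll(shifts=-1, dims=1)
shift_labels[:, -1] = -100

# With the patched forward, either of these now yields a loss:
#   out = model(input_ids=input_ids, labels=input_ids)
#   out = model(input_ids=input_ids, shift_labels=shift_labels)
```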
liger_kernel/transformers/model/gemma2.py

@@ -252,8 +252,14 @@ def lce_forward(
  logits = logits * self.config.final_logit_softcapping

  loss = None
- if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+ if labels is not None or shift_labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ shift_labels=shift_labels,
+ vocab_size=self.vocab_size,
+ **kwargs,
+ )

  if not return_dict:
  output = (logits,) + outputs[1:]
liger_kernel/transformers/model/gemma3.py

@@ -119,8 +119,14 @@ def causal_forward(
  logits = logits / self.config.final_logit_softcapping
  logits = torch.tanh(logits)
  logits = logits * self.config.final_logit_softcapping
- if labels is not None:
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+ if labels is not None or shift_labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ shift_labels=shift_labels,
+ vocab_size=self.vocab_size,
+ **loss_kwargs,
+ )

  if not return_dict:
  output = (logits,) + outputs[1:]

@@ -275,6 +281,25 @@ def multimodal_forward(
  # Flatten the tokens
  loss_fct = nn.CrossEntropyLoss()

+ flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+ flat_labels = shift_labels.view(-1).to(shift_logits.device)
+ loss = loss_fct(flat_logits, flat_labels)
+ elif shift_labels is not None:
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ shift_logits = logits[..., :-1, :]
+ if attention_mask is not None:
+ # we use the input attention mask to shift the logits and labels, because it is 2D.
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+ shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+ shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+ shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+ else:
+ shift_logits = shift_logits.contiguous()
+ shift_labels = shift_labels.contiguous()
+ # Flatten the tokens
+ loss_fct = nn.CrossEntropyLoss()
+
  flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
  flat_labels = shift_labels.view(-1).to(shift_logits.device)
  loss = loss_fct(flat_logits, flat_labels)
liger_kernel/transformers/model/glm4.py

@@ -111,10 +111,11 @@ def lce_forward(

  else:
  logits = self.lm_head(kept_hidden_states)
- if labels is not None:
+ if labels is not None or shift_labels is not None:
  loss = self.loss_function(
  logits=logits,
  labels=labels,
+ shift_labels=shift_labels,
  vocab_size=self.config.vocab_size,
  **kwargs,
  )
liger_kernel/transformers/model/glm4v.py (new file)

@@ -0,0 +1,151 @@
+ from typing import List
+ from typing import Optional
+ from typing import Tuple
+ from typing import Union
+
+ import torch
+
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.utils.deprecation import deprecate_kwarg
+
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+ def lce_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ skip_logits: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> from transformers import AutoTokenizer, Glm4vForConditionalGeneration
+
+ >>> MODEL_PATH = "THUDM/GLM-4.1V-9B-Thinking"
+ >>> messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "url": "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png"
+ },
+ {
+ "type": "text",
+ "text": "describe this image"
+ }
+ ],
+ }
+ ]
+ >>> processor = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
+ >>> model = Glm4vForConditionalGeneration.from_pretrained(
+ pretrained_model_name_or_path=MODEL_PATH,
+ dtype=torch.bfloat16,
+ device_map="auto",
+ )
+ >>> inputs = processor.apply_chat_template(
+ messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ return_dict=True,
+ return_tensors="pt"
+ ).to(model.device)
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=8192)
+ output_text = processor.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
+ <think>Got it, let's describe the image. First, there's a vintage car, specifically a Volkswagen Beetle
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ kept_hidden_states = hidden_states[:, slice_indices, :]
+
+ shift_labels = kwargs.pop("shift_labels", None)
+ logits = None
+ loss = None
+
+ if skip_logits and labels is None and shift_labels is None:
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+ if skip_logits is None:
+ # By default, if in training mode, don't materialize logits
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+ if skip_logits:
+ loss = LigerForCausalLMLoss(
+ hidden_states=kept_hidden_states,
+ lm_head_weight=self.lm_head.weight,
+ labels=labels,
+ shift_labels=shift_labels,
+ hidden_size=self.config.hidden_size,
+ **kwargs,
+ )
+
+ else:
+ logits = self.lm_head(kept_hidden_states)
+ if labels is not None or shift_labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ shift_labels=shift_labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
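This GLM-4V `lce_forward` is the target of the new `apply_liger_kernel_to_glm4v` entry point exported above. A hedged usage sketch; the checkpoint is taken from the docstring example, and the zero-argument patch call assumes the usual default kernel toggles:

```python
import torch
from transformers import Glm4vForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_glm4v

apply_liger_kernel_to_glm4v()  # patch GLM-4V modeling code before loading weights

model = Glm4vForConditionalGeneration.from_pretrained(
    "THUDM/GLM-4.1V-9B-Thinking",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.train()
# During training with labels (or shift_labels) supplied, the patched forward skips
# logit materialization and computes the loss via LigerForCausalLMLoss.
```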