liger-kernel-nightly 0.5.10.dev20250624183504-py3-none-any.whl → 0.6.3.dev20251121010306-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic; see the registry's release page for more details.

Files changed (68)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
  7. liger_kernel/chunked_loss/grpo_loss.py +38 -4
  8. liger_kernel/chunked_loss/jsd_loss.py +23 -7
  9. liger_kernel/ops/cross_entropy.py +118 -62
  10. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  11. liger_kernel/ops/fused_linear_cross_entropy.py +113 -21
  12. liger_kernel/ops/geglu.py +1 -1
  13. liger_kernel/ops/layer_norm.py +124 -89
  14. liger_kernel/ops/llama4_rope.py +225 -0
  15. liger_kernel/ops/poly_norm.py +386 -0
  16. liger_kernel/ops/rms_norm.py +2 -2
  17. liger_kernel/ops/rope.py +1 -1
  18. liger_kernel/ops/swiglu.py +1 -1
  19. liger_kernel/ops/tiled_mlp.py +136 -0
  20. liger_kernel/transformers/__init__.py +50 -0
  21. liger_kernel/transformers/cross_entropy.py +8 -3
  22. liger_kernel/transformers/experimental/__init__.py +5 -0
  23. liger_kernel/transformers/functional.py +38 -6
  24. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  25. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -4
  26. liger_kernel/transformers/llama4_rope.py +93 -0
  27. liger_kernel/transformers/model/falcon_h1.py +122 -0
  28. liger_kernel/transformers/model/gemma.py +28 -8
  29. liger_kernel/transformers/model/gemma2.py +31 -8
  30. liger_kernel/transformers/model/gemma3.py +100 -110
  31. liger_kernel/transformers/model/glm4.py +18 -5
  32. liger_kernel/transformers/model/glm4v.py +163 -0
  33. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  34. liger_kernel/transformers/model/internvl.py +157 -0
  35. liger_kernel/transformers/model/llama.py +26 -7
  36. liger_kernel/transformers/model/llama4.py +121 -0
  37. liger_kernel/transformers/model/llava.py +18 -6
  38. liger_kernel/transformers/model/loss_utils.py +34 -3
  39. liger_kernel/transformers/model/mistral.py +17 -10
  40. liger_kernel/transformers/model/mixtral.py +24 -9
  41. liger_kernel/transformers/model/mllama.py +18 -7
  42. liger_kernel/transformers/model/olmo2.py +18 -5
  43. liger_kernel/transformers/model/output_classes.py +147 -0
  44. liger_kernel/transformers/model/paligemma.py +41 -5
  45. liger_kernel/transformers/model/phi3.py +24 -159
  46. liger_kernel/transformers/model/qwen2.py +26 -4
  47. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  48. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  49. liger_kernel/transformers/model/qwen3.py +22 -6
  50. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  51. liger_kernel/transformers/model/qwen3_next.py +146 -0
  52. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  53. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  54. liger_kernel/transformers/model/smollm3.py +199 -0
  55. liger_kernel/transformers/model/smolvlm.py +158 -0
  56. liger_kernel/transformers/monkey_patch.py +1090 -116
  57. liger_kernel/transformers/multi_token_attention.py +1 -1
  58. liger_kernel/transformers/poly_norm.py +42 -0
  59. liger_kernel/transformers/rms_norm.py +7 -0
  60. liger_kernel/transformers/rope.py +43 -0
  61. liger_kernel/transformers/tiled_mlp.py +133 -0
  62. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +26 -24
  63. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  64. liger_kernel_nightly-0.5.10.dev20250624183504.dist-info/RECORD +0 -95
  65. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  66. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  67. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +0 -0
  68. {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
@@ -3,6 +3,7 @@ from typing import Optional
3
3
  import torch
4
4
 
5
5
  from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
6
+ from liger_kernel.transformers.functional import CrossEntropyOutput
6
7
 
7
8
 
8
9
  class LigerCrossEntropyLoss(torch.nn.Module):
@@ -15,6 +16,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
15
16
  reduction: str = "mean",
16
17
  softcap: Optional[float] = None,
17
18
  return_z_loss: bool = False,
19
+ return_token_accuracy: bool = False,
18
20
  ):
19
21
  super().__init__()
20
22
  assert (label_smoothing >= 0) and (label_smoothing <= 1), (
@@ -33,9 +35,10 @@ class LigerCrossEntropyLoss(torch.nn.Module):
33
35
  self.reduction = reduction
34
36
  self.softcap = softcap
35
37
  self.return_z_loss = return_z_loss
38
+ self.return_token_accuracy = return_token_accuracy
36
39
 
37
40
  def forward(self, _input: torch.Tensor, target: torch.Tensor):
38
- loss, z_loss = LigerCrossEntropyFunction.apply(
41
+ loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
39
42
  _input,
40
43
  target,
41
44
  self.weight,
@@ -45,7 +48,9 @@ class LigerCrossEntropyLoss(torch.nn.Module):
45
48
  self.reduction,
46
49
  self.softcap,
47
50
  self.return_z_loss,
51
+ self.return_token_accuracy,
48
52
  )
49
- if not self.return_z_loss:
53
+ if not self.return_z_loss and not self.return_token_accuracy:
50
54
  return loss
51
- return loss, z_loss
55
+
56
+ return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
@@ -0,0 +1,5 @@
1
+ from liger_kernel.transformers.experimental.embedding import LigerEmbedding # noqa: F401
2
+
3
+ __all__ = [
4
+ "LigerEmbedding",
5
+ ]
@@ -1,7 +1,11 @@
1
+ from dataclasses import dataclass
1
2
  from typing import Optional
2
3
 
4
+ import torch
5
+
3
6
  from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
4
7
  from liger_kernel.ops.dyt import LigerDyTFunction
8
+ from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction
5
9
  from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
6
10
  from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
7
11
  from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
@@ -11,6 +15,7 @@ from liger_kernel.ops.jsd import LigerJSDFunction
11
15
  from liger_kernel.ops.kl_div import LigerKLDivLossFunction
12
16
  from liger_kernel.ops.layer_norm import LigerLayerNormFunction
13
17
  from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
18
+ from liger_kernel.ops.poly_norm import LigerPolyNormFunction
14
19
  from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
15
20
  from liger_kernel.ops.rms_norm import LigerRMSNormFunction
16
21
  from liger_kernel.ops.rope import LigerRopeFunction
@@ -20,6 +25,13 @@ from liger_kernel.ops.swiglu import LigerSiLUMulFunction
20
25
  from liger_kernel.ops.tvd import LigerTVDLossFunction
21
26
 
22
27
 
28
@dataclass
class CrossEntropyOutput:
    """Bundle of results produced by the Liger cross-entropy entry points.

    Only ``loss`` must be supplied. The two auxiliary tensors default to
    ``None`` so that callers can fill in just the extras they computed.
    """

    # Main cross-entropy loss value.
    loss: torch.Tensor
    # Optional auxiliary z-loss (None when not provided).
    z_loss: Optional[torch.Tensor] = None
    # Optional token-level accuracy (None when not provided).
    token_accuracy: Optional[torch.Tensor] = None
33
+
34
+
23
35
  # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
24
36
  # `weight` and `size_average` are placeholders and not implemented yet
25
37
  def liger_cross_entropy(
@@ -34,8 +46,9 @@ def liger_cross_entropy(
34
46
  lse_square_scale: float = 0.0,
35
47
  softcap: Optional[float] = None,
36
48
  return_z_loss: bool = False,
49
+ return_token_accuracy: bool = False,
37
50
  ):
38
- loss, z_loss = LigerCrossEntropyFunction.apply(
51
+ loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
39
52
  input,
40
53
  target,
41
54
  weight,
@@ -45,10 +58,13 @@ def liger_cross_entropy(
45
58
  reduction,
46
59
  softcap,
47
60
  return_z_loss,
61
+ return_token_accuracy,
48
62
  )
49
- if not return_z_loss:
63
+
64
+ if not return_z_loss and not return_token_accuracy:
50
65
  return loss
51
- return loss, z_loss
66
+
67
+ return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
52
68
 
53
69
 
54
70
  def liger_fused_linear_cross_entropy(
@@ -63,8 +79,11 @@ def liger_fused_linear_cross_entropy(
63
79
  reduction: str = "mean",
64
80
  softcap: Optional[float] = None,
65
81
  return_z_loss: bool = False,
82
+ accum_dtype=None,
83
+ use_token_scaling: bool = False,
84
+ return_token_accuracy: bool = False,
66
85
  ):
67
- loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
86
+ loss, z_loss, token_accuracy = LigerFusedLinearCrossEntropyFunction.apply(
68
87
  input,
69
88
  weight,
70
89
  target,
@@ -76,10 +95,15 @@ def liger_fused_linear_cross_entropy(
76
95
  reduction,
77
96
  softcap,
78
97
  return_z_loss,
98
+ accum_dtype,
99
+ use_token_scaling,
100
+ return_token_accuracy,
79
101
  )
80
- if not return_z_loss:
102
+
103
+ if not return_z_loss and not return_token_accuracy:
81
104
  return loss
82
- return loss, z_loss
105
+
106
+ return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
83
107
 
84
108
 
85
109
  def liger_fused_linear_jsd(
@@ -253,6 +277,14 @@ def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama",
253
277
  return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
254
278
 
255
279
 
280
def liger_poly_norm(X, W, B, eps=1e-6, in_place=True):
    """Functional entry point for Liger's PolyNorm kernel.

    Forwards every argument positionally to ``LigerPolyNormFunction.apply``
    and returns its result unchanged.

    Args:
        X: Input tensor to normalize.
        W: Weight tensor passed through to the kernel.
        B: Bias tensor passed through to the kernel.
        eps: Epsilon passed through to the kernel (default ``1e-6``).
        in_place: Passed through to the kernel; presumably allows it to
            reuse input buffers — confirm in ``liger_kernel.ops.poly_norm``.
    """
    normalized = LigerPolyNormFunction.apply(X, W, B, eps, in_place)
    return normalized
282
+
283
+
284
def liger_fused_add_rms_norm(X, R, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
    """Functional entry point for Liger's fused residual-add + RMSNorm kernel.

    All arguments are forwarded positionally to
    ``LigerFusedAddRMSNormFunction.apply``, whose result is returned as-is
    (presumably the normalized output together with the updated residual —
    confirm in ``liger_kernel.ops.fused_add_rms_norm``).

    Args:
        X: Hidden-state tensor.
        R: Residual tensor fused into the normalization.
        W: RMSNorm weight.
        eps: Epsilon passed through to the kernel.
        offset: Offset passed through to the kernel (default ``0.0``).
        casting_mode: Casting mode string passed through (default ``"llama"``).
        in_place: Passed through to the kernel (default ``True``).
    """
    fused = LigerFusedAddRMSNormFunction.apply(X, R, W, eps, offset, casting_mode, in_place)
    return fused
286
+
287
+
256
288
  def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
257
289
  return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
258
290
 
@@ -0,0 +1,39 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction
5
+
6
+
7
class LigerFusedAddRMSNorm(nn.Module):
    """RMSNorm module backed by Liger's fused residual-add + RMSNorm kernel.

    ``forward`` hands ``hidden_states`` and ``residual``, the learnable
    ``weight`` and the stored configuration to
    ``LigerFusedAddRMSNormFunction.apply`` and returns its result unchanged.

    Args:
        hidden_size: Length of the learnable ``weight`` vector.
        eps: Epsilon forwarded to the kernel (stored as ``variance_epsilon``).
        offset: Offset forwarded to the kernel.
        casting_mode: Casting mode string forwarded to the kernel.
        init_fn: ``"ones"`` or ``"zeros"`` — initial value of ``weight``.
        in_place: Forwarded to the kernel.

    Raises:
        AssertionError: If ``init_fn`` is neither ``"ones"`` nor ``"zeros"``.
    """

    def __init__(
        self,
        hidden_size,
        eps=1e-6,
        offset=0.0,
        casting_mode="llama",
        init_fn="ones",
        in_place=False,
    ):
        super().__init__()
        # Kept as an assert so callers relying on AssertionError are unaffected.
        assert init_fn in [
            "ones",
            "zeros",
        ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
        self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
        self.variance_epsilon = eps
        self.offset = offset
        self.casting_mode = casting_mode
        self.in_place = in_place

    def forward(self, hidden_states, residual):
        """Run the fused residual-add + RMSNorm kernel on ``hidden_states`` and ``residual``."""
        return LigerFusedAddRMSNormFunction.apply(
            hidden_states,
            residual,
            self.weight,
            self.variance_epsilon,
            self.offset,
            self.casting_mode,
            self.in_place,
        )

    def extra_repr(self):
        """Summarize the module configuration for ``print(module)``."""
        # casting_mode is part of the configuration passed to the kernel; without it here it is
        # invisible when inspecting a model.
        return (
            f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, "
            f"casting_mode={self.casting_mode}, in_place={self.in_place}"
        )
@@ -3,6 +3,7 @@ from typing import Optional
3
3
  import torch
4
4
 
5
5
  from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
6
+ from liger_kernel.transformers.functional import CrossEntropyOutput
6
7
 
7
8
 
8
9
  class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
@@ -15,6 +16,9 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
15
16
  reduction: str = "mean",
16
17
  softcap: Optional[float] = None,
17
18
  return_z_loss: bool = False,
19
+ accum_dtype: Optional[torch.dtype] = None,
20
+ use_token_scaling: bool = False,
21
+ return_token_accuracy: bool = False,
18
22
  ):
19
23
  super().__init__()
20
24
  assert (label_smoothing >= 0) and (label_smoothing <= 1), (
@@ -23,7 +27,8 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
23
27
  assert reduction in {
24
28
  "mean",
25
29
  "sum",
26
- }, f"reduction must be 'mean' or 'sum'. Got: {reduction}"
30
+ "none",
31
+ }, f"reduction must be 'mean' or 'sum' or 'none'. Got: {reduction}"
27
32
  assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
28
33
  self.ce_weight = ce_weight
29
34
  self.ignore_index = ignore_index
@@ -32,9 +37,12 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
32
37
  self.reduction = reduction
33
38
  self.softcap = softcap
34
39
  self.return_z_loss = return_z_loss
40
+ self.accum_dtype = accum_dtype
41
+ self.use_token_scaling = use_token_scaling
42
+ self.return_token_accuracy = return_token_accuracy
35
43
 
36
44
  def forward(self, lin_weight, _input, target, bias=None):
37
- loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
45
+ loss, z_loss, token_accuracy = LigerFusedLinearCrossEntropyFunction.apply(
38
46
  _input,
39
47
  lin_weight,
40
48
  target,
@@ -46,7 +54,11 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
46
54
  self.reduction,
47
55
  self.softcap,
48
56
  self.return_z_loss,
57
+ self.accum_dtype,
58
+ self.use_token_scaling,
59
+ self.return_token_accuracy,
49
60
  )
50
- if not self.return_z_loss:
61
+ if not self.return_z_loss and not self.return_token_accuracy:
51
62
  return loss
52
- return loss, z_loss
63
+
64
+ return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
@@ -0,0 +1,93 @@
1
+ """
2
+ Liger Kernel implementation of Llama4 Rotary Position Embedding (RoPE).
3
+ Supports both text and vision RoPE variants with fused operations for optimal performance.
4
+ """
5
+
6
+ import torch
7
+
8
+ from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction
9
+
10
+
11
def liger_llama4_text_rotary_pos_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate Llama4 text queries and keys with Liger's fused RoPE kernel.

    Intended as a replacement for the transformers Llama4 ``apply_rotary_emb``;
    the complex-valued rotation is delegated entirely to
    ``LigerLlama4RopeFunction``.

    Args:
        xq: Query tensor, documented as ``(batch_size, seq_len, num_heads, head_dim)``.
        xk: Key tensor, documented with the same layout as ``xq``.
        freqs_cis: Complex frequency tensor from ``Llama4TextRotaryEmbedding``.

    Returns:
        The rotated ``(query, key)`` pair produced by the kernel.
    """
    rotated_pair = LigerLlama4RopeFunction.apply(xq, xk, freqs_cis)
    return rotated_pair
32
+
33
+
34
def liger_llama4_vision_rotary_pos_emb(
    query: torch.Tensor,
    key: torch.Tensor,
    freqs_ci: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate Llama4 vision queries and keys with Liger's fused RoPE kernel.

    ``freqs_ci`` is first normalized to a 3-D tensor, then passed with
    ``query`` and ``key`` to the same kernel used for text RoPE
    (``LigerLlama4RopeFunction``).

    Shape handling:
        * 4-D ``(1, seq_len, 1, head_dim // 2)``: dim 2 is squeezed.
        * 3-D: dim 1 is squeezed (expected to be a singleton), then treated as 2-D.
        * 2-D ``(seq_len, head_dim // 2)``: a leading axis is added and expanded
          to ``query.shape[0]``.

    Args:
        query: Query tensor, documented as ``(batch_size, seq_len, num_heads, head_dim)``.
        key: Key tensor, documented with the same layout as ``query``.
        freqs_ci: Complex frequency tensor for the 2-D patch positions.

    Returns:
        The rotated ``(query, key)`` pair produced by the kernel.

    Raises:
        ValueError: If ``freqs_ci`` is not 2-, 3- or 4-D, or if a 3-D
            ``freqs_ci`` does not have a singleton dimension 1.
    """
    original_shape = freqs_ci.shape
    ndim = freqs_ci.dim()

    if ndim == 4:
        # Drop the singleton head axis: (1, seq_len, 1, d) -> (1, seq_len, d).
        freqs_ci = freqs_ci.squeeze(2)
    elif ndim in (2, 3):
        if ndim == 3:
            # Expanding the unsqueezed 3-D tensor with three sizes can never succeed (it is 4-D),
            # so squeeze the singleton dim 1 directly instead of catching and parsing the
            # RuntimeError text.
            freqs_ci = freqs_ci.squeeze(1)
            if freqs_ci.dim() != 2:
                raise ValueError(f"Unexpected freqs_ci shape: {original_shape}")
        # Broadcast (seq_len, d) over the batch without copying: (batch, seq_len, d).
        freqs_ci = freqs_ci.unsqueeze(0).expand(query.shape[0], -1, -1)
    else:
        raise ValueError(f"Unexpected freqs_ci shape: {freqs_ci.shape}")

    # Use the same fused kernel as text RoPE
    return LigerLlama4RopeFunction.apply(query, key, freqs_ci)
75
+
76
+
77
+ # Note: We only patch the functions, not the classes
78
+ # The original Llama4TextRotaryEmbedding and Llama4VisionRotaryEmbedding classes remain unchanged
79
+
80
+
81
+ # Convenience functions for monkey patching
82
def apply_liger_llama4_rope_full(modeling_module):
    """Monkey-patch a Llama4 modeling module with the Liger RoPE functions.

    Rebinds ``modeling_module.apply_rotary_emb`` to
    ``liger_llama4_text_rotary_pos_emb`` and, after it,
    ``modeling_module.vision_apply_rotary_emb`` to
    ``liger_llama4_vision_rotary_pos_emb``. The module is mutated in place and
    nothing is returned.

    Args:
        modeling_module: The transformers modeling module to patch.
    """
    replacements = {
        "apply_rotary_emb": liger_llama4_text_rotary_pos_emb,
        "vision_apply_rotary_emb": liger_llama4_vision_rotary_pos_emb,
    }
    for attr_name, replacement in replacements.items():
        setattr(modeling_module, attr_name, replacement)
@@ -0,0 +1,122 @@
1
+ from typing import TYPE_CHECKING
2
+ from typing import Optional
3
+ from typing import Union
4
+
5
+ import torch
6
+
7
+ if TYPE_CHECKING:
8
+ from transformers.models.falcon_h1.modeling_falcon_h1 import FalconHybridMambaAttentionDynamicCache
9
+
10
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
11
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
12
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
13
+
14
+
15
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional["FalconHybridMambaAttentionDynamicCache"] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **kwargs,
) -> Union[tuple, LigerCausalLMOutputWithPast]:
    r"""
    FalconH1 causal-LM forward that can compute the loss with Liger's fused linear cross-entropy.

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    skip_logits (`bool`, *optional*):
        If True, the loss is computed from the hidden states via `LigerForCausalLMLoss` and logits are not
        materialized; requires `labels` or a `shift_labels` keyword argument. When None, defaults to True in
        training mode if `labels` or `shift_labels` is given, otherwise False.
    shift_labels (keyword argument, *optional*):
        Pre-shifted labels, popped from `kwargs` and used in place of / in addition to `labels` by both loss paths.

    Returns:
        If `return_dict` is False, a tuple `(loss?, logits, *outputs[1:], token_accuracy?)` where the optional
        entries are present only when not None. Otherwise a `LigerCausalLMOutputWithPast` that also carries
        `token_accuracy`.

    Raises:
        ValueError: If `skip_logits` is True but neither `labels` nor `shift_labels` is provided.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FalconH1ForCausalLM

    >>> model = FalconH1ForCausalLM.from_pretrained("...")
    >>> tokenizer = AutoTokenizer.from_pretrained("...")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None
    token_accuracy = None

    # shift_labels alone is a valid loss target, so only complain when both are missing.
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    # Compute loss
    if skip_logits:
        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
        loss, _, token_accuracy = unpack_cross_entropy_result(result)
    else:
        logits = self.lm_head(kept_hidden_states)
        if labels is not None or shift_labels is not None:
            # shift_labels was popped from kwargs above, so it must be forwarded explicitly.
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        if loss is not None:
            output = (loss,) + output
        if token_accuracy is not None:
            output = output + (token_accuracy,)
        return output

    # Return custom output class with token_accuracy field
    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        token_accuracy=token_accuracy,
    )
@@ -12,6 +12,8 @@ from transformers.utils.deprecation import deprecate_kwarg
12
12
 
13
13
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
14
14
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
15
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
16
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
15
17
 
16
18
 
17
19
  def lce_forward_deprecated(
@@ -27,6 +29,7 @@ def lce_forward_deprecated(
27
29
  output_hidden_states: Optional[bool] = None,
28
30
  return_dict: Optional[bool] = None,
29
31
  cache_position: Optional[torch.LongTensor] = None,
32
+ skip_logits: Optional[bool] = None,
30
33
  ) -> Union[Tuple, CausalLMOutputWithPast]:
31
34
  r"""
32
35
 
@@ -81,7 +84,14 @@ def lce_forward_deprecated(
81
84
  loss = None
82
85
  logits = None
83
86
 
84
- if self.training and (labels is not None):
87
+ if skip_logits and labels is None:
88
+ raise ValueError("skip_logits is True, but labels is None")
89
+
90
+ if skip_logits is None:
91
+ # By default, if in training mode, don't materialize logits
92
+ skip_logits = self.training and labels is not None
93
+
94
+ if skip_logits:
85
95
  shift_hidden_states = hidden_states[..., :-1, :].contiguous()
86
96
  shift_labels = labels[..., 1:].contiguous()
87
97
 
@@ -139,7 +149,7 @@ def lce_forward(
139
149
  logits_to_keep: Union[int, torch.Tensor] = 0,
140
150
  skip_logits: Optional[bool] = None,
141
151
  **kwargs,
142
- ) -> Union[Tuple, CausalLMOutputWithPast]:
152
+ ) -> Union[Tuple, LigerCausalLMOutputWithPast]:
143
153
  r"""
144
154
  Args:
145
155
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -201,6 +211,7 @@ def lce_forward(
201
211
  shift_labels = kwargs.pop("shift_labels", None)
202
212
  logits = None
203
213
  loss = None
214
+ token_accuracy = None
204
215
 
205
216
  if skip_logits and labels is None and shift_labels is None:
206
217
  raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -209,8 +220,9 @@ def lce_forward(
209
220
  # By default, if in training mode, don't materialize logits
210
221
  skip_logits = self.training and (labels is not None or shift_labels is not None)
211
222
 
223
+ # Compute loss
212
224
  if skip_logits:
213
- loss = LigerForCausalLMLoss(
225
+ result = LigerForCausalLMLoss(
214
226
  hidden_states=kept_hidden_states,
215
227
  lm_head_weight=self.lm_head.weight,
216
228
  labels=labels,
@@ -218,24 +230,32 @@ def lce_forward(
218
230
  hidden_size=self.config.hidden_size,
219
231
  **kwargs,
220
232
  )
233
+ loss, _, token_accuracy = unpack_cross_entropy_result(result)
221
234
  else:
222
235
  logits = self.lm_head(kept_hidden_states)
223
- if labels is not None:
236
+ if labels is not None or shift_labels is not None:
224
237
  loss = self.loss_function(
225
238
  logits=logits,
226
239
  labels=labels,
240
+ shift_labels=shift_labels,
227
241
  vocab_size=self.config.vocab_size,
228
242
  **kwargs,
229
243
  )
230
244
 
231
245
  if not return_dict:
232
- output = (logits,) + outputs[1:]
233
- return (loss,) + output if loss is not None else output
234
-
235
- return CausalLMOutputWithPast(
246
+ output_tuple = (logits,) + outputs[1:]
247
+ if loss is not None:
248
+ output_tuple = (loss,) + output_tuple
249
+ if token_accuracy is not None:
250
+ output_tuple = output_tuple + (token_accuracy,)
251
+ return output_tuple
252
+
253
+ # Return custom output class with token_accuracy field
254
+ return LigerCausalLMOutputWithPast(
236
255
  loss=loss,
237
256
  logits=logits,
238
257
  past_key_values=outputs.past_key_values,
239
258
  hidden_states=outputs.hidden_states,
240
259
  attentions=outputs.attentions,
260
+ token_accuracy=token_accuracy,
241
261
  )
@@ -13,6 +13,8 @@ from transformers.utils.deprecation import deprecate_kwarg
13
13
 
14
14
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
15
15
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
16
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
17
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
16
18
 
17
19
  logger = logging.getLogger(__name__)
18
20
 
@@ -30,6 +32,7 @@ def lce_forward_deprecated(
30
32
  output_hidden_states: Optional[bool] = None,
31
33
  return_dict: Optional[bool] = None,
32
34
  cache_position: Optional[torch.LongTensor] = None,
35
+ skip_logits: Optional[bool] = None,
33
36
  **kwargs,
34
37
  ) -> Union[Tuple, CausalLMOutputWithPast]:
35
38
  r"""
@@ -85,7 +88,14 @@ def lce_forward_deprecated(
85
88
  loss = None
86
89
  logits = None
87
90
 
88
- if self.training and (labels is not None):
91
+ if skip_logits and labels is None:
92
+ raise ValueError("skip_logits is True, but labels is None")
93
+
94
+ if skip_logits is None:
95
+ # By default, if in training mode, don't materialize logits
96
+ skip_logits = self.training and labels is not None
97
+
98
+ if skip_logits:
89
99
  shift_hidden_states = hidden_states[..., :-1, :].contiguous()
90
100
  shift_labels = labels[..., 1:].contiguous()
91
101
 
@@ -150,7 +160,7 @@ def lce_forward(
150
160
  logits_to_keep: Union[int, torch.Tensor] = 0,
151
161
  skip_logits: Optional[bool] = None,
152
162
  **kwargs,
153
- ) -> Union[Tuple, CausalLMOutputWithPast]:
163
+ ) -> Union[Tuple, LigerCausalLMOutputWithPast]:
154
164
  r"""
155
165
  Args:
156
166
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -217,6 +227,7 @@ def lce_forward(
217
227
  shift_labels = kwargs.pop("shift_labels", None)
218
228
  logits = None
219
229
  loss = None
230
+ token_accuracy = None
220
231
 
221
232
  if skip_logits and labels is None and shift_labels is None:
222
233
  raise ValueError("skip_logits is True, but labels and shift_labels are None")
@@ -225,8 +236,9 @@ def lce_forward(
225
236
  # By default, if in training mode, don't materialize logits
226
237
  skip_logits = self.training and (labels is not None or shift_labels is not None)
227
238
 
239
+ # Compute loss
228
240
  if skip_logits:
229
- loss = LigerForCausalLMLoss(
241
+ result = LigerForCausalLMLoss(
230
242
  hidden_states=kept_hidden_states,
231
243
  lm_head_weight=self.lm_head.weight,
232
244
  labels=labels,
@@ -235,6 +247,7 @@ def lce_forward(
235
247
  final_logit_softcapping=self.config.final_logit_softcapping,
236
248
  **kwargs,
237
249
  )
250
+ loss, _, token_accuracy = unpack_cross_entropy_result(result)
238
251
 
239
252
  else:
240
253
  logits = self.lm_head(kept_hidden_states)
@@ -244,17 +257,27 @@ def lce_forward(
244
257
  logits = logits * self.config.final_logit_softcapping
245
258
 
246
259
  loss = None
247
- if labels is not None:
248
- loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
260
+ if labels is not None or shift_labels is not None:
261
+ loss = self.loss_function(
262
+ logits=logits,
263
+ labels=labels,
264
+ shift_labels=shift_labels,
265
+ vocab_size=self.vocab_size,
266
+ **kwargs,
267
+ )
249
268
 
250
269
  if not return_dict:
251
- output = (logits,) + outputs[1:]
252
- return (loss,) + output if loss is not None else output
270
+ output_tuple = (logits,) + outputs[1:]
271
+ output_tuple = (loss,) + output_tuple if loss is not None else output_tuple
272
+ output_tuple = output_tuple + (token_accuracy,) if token_accuracy is not None else output_tuple
273
+ return output_tuple
253
274
 
254
- return CausalLMOutputWithPast(
275
+ # Return custom output class with token_accuracy field
276
+ return LigerCausalLMOutputWithPast(
255
277
  loss=loss,
256
278
  logits=logits,
257
279
  past_key_values=outputs.past_key_values,
258
280
  hidden_states=outputs.hidden_states,
259
281
  attentions=outputs.attentions,
282
+ token_accuracy=token_accuracy,
260
283
  )