liger-kernel-nightly 0.5.6.dev20250403190551__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +35 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +25 -9
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  15. liger_kernel/ops/backends/registry.py +61 -0
  16. liger_kernel/ops/cross_entropy.py +124 -64
  17. liger_kernel/ops/dyt.py +115 -180
  18. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  19. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  20. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  21. liger_kernel/ops/geglu.py +3 -2
  22. liger_kernel/ops/group_norm.py +2 -1
  23. liger_kernel/ops/grpo_loss.py +312 -0
  24. liger_kernel/ops/jsd.py +2 -1
  25. liger_kernel/ops/kl_div.py +13 -6
  26. liger_kernel/ops/layer_norm.py +146 -78
  27. liger_kernel/ops/llama4_rope.py +225 -0
  28. liger_kernel/ops/multi_token_attention.py +207 -0
  29. liger_kernel/ops/poly_norm.py +390 -0
  30. liger_kernel/ops/rms_norm.py +283 -56
  31. liger_kernel/ops/rope.py +1 -1
  32. liger_kernel/ops/softmax.py +201 -0
  33. liger_kernel/ops/sparsemax.py +179 -0
  34. liger_kernel/ops/swiglu.py +1 -1
  35. liger_kernel/ops/tiled_mlp.py +136 -0
  36. liger_kernel/ops/utils.py +2 -0
  37. liger_kernel/transformers/__init__.py +205 -19
  38. liger_kernel/transformers/cross_entropy.py +9 -4
  39. liger_kernel/transformers/dyt.py +6 -4
  40. liger_kernel/transformers/experimental/__init__.py +5 -0
  41. liger_kernel/transformers/experimental/embedding.py +1 -1
  42. liger_kernel/transformers/fsdp.py +55 -0
  43. liger_kernel/transformers/functional.py +122 -20
  44. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  45. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  46. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  47. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  48. liger_kernel/transformers/geglu.py +1 -1
  49. liger_kernel/transformers/group_norm.py +1 -1
  50. liger_kernel/transformers/grpo_loss.py +153 -0
  51. liger_kernel/transformers/jsd.py +1 -1
  52. liger_kernel/transformers/kl_div.py +1 -1
  53. liger_kernel/transformers/layer_norm.py +1 -1
  54. liger_kernel/transformers/llama4_rope.py +93 -0
  55. liger_kernel/transformers/model/falcon_h1.py +122 -0
  56. liger_kernel/transformers/model/gemma.py +50 -25
  57. liger_kernel/transformers/model/gemma2.py +55 -23
  58. liger_kernel/transformers/model/gemma3.py +117 -120
  59. liger_kernel/transformers/model/glm4.py +141 -0
  60. liger_kernel/transformers/model/glm4v.py +163 -0
  61. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  62. liger_kernel/transformers/model/gpt_oss.py +211 -0
  63. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  64. liger_kernel/transformers/model/internvl.py +157 -0
  65. liger_kernel/transformers/model/llama.py +102 -25
  66. liger_kernel/transformers/model/llama4.py +121 -0
  67. liger_kernel/transformers/model/llava.py +111 -136
  68. liger_kernel/transformers/model/loss_utils.py +50 -12
  69. liger_kernel/transformers/model/mistral.py +36 -23
  70. liger_kernel/transformers/model/mixtral.py +45 -25
  71. liger_kernel/transformers/model/mllama.py +39 -22
  72. liger_kernel/transformers/model/olmo2.py +40 -20
  73. liger_kernel/transformers/model/olmo3.py +142 -0
  74. liger_kernel/transformers/model/output_classes.py +147 -0
  75. liger_kernel/transformers/model/paligemma.py +50 -14
  76. liger_kernel/transformers/model/phi3.py +47 -177
  77. liger_kernel/transformers/model/qwen2.py +48 -21
  78. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  79. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  80. liger_kernel/transformers/model/qwen3.py +136 -0
  81. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  82. liger_kernel/transformers/model/qwen3_next.py +146 -0
  83. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  84. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  85. liger_kernel/transformers/model/smollm3.py +199 -0
  86. liger_kernel/transformers/model/smolvlm.py +158 -0
  87. liger_kernel/transformers/monkey_patch.py +1678 -160
  88. liger_kernel/transformers/multi_token_attention.py +64 -0
  89. liger_kernel/transformers/poly_norm.py +42 -0
  90. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  91. liger_kernel/transformers/rms_norm.py +48 -5
  92. liger_kernel/transformers/rope.py +45 -1
  93. liger_kernel/transformers/softmax.py +12 -0
  94. liger_kernel/transformers/sparsemax.py +16 -0
  95. liger_kernel/transformers/swiglu.py +39 -1
  96. liger_kernel/transformers/tiled_mlp.py +133 -0
  97. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  98. liger_kernel/transformers/tvd.py +1 -1
  99. liger_kernel/utils.py +36 -0
  100. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/METADATA +68 -38
  101. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  102. liger_kernel/transformers/gema3_rms.py +0 -8
  103. liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/RECORD +0 -82
  104. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/grpo_loss.py
@@ -0,0 +1,153 @@
+ import torch
+
+ from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
+ from liger_kernel.ops import GrpoLossFunction
+
+
+ def triton_grpo_loss(
+     logits,
+     old_logp,
+     ref_logp,
+     completion_ids,
+     advantages,
+     completion_mask=None,
+     temperature=0.9,
+     beta=0.04,
+     eps_low=0.2,
+     eps_high=0.4,
+     inplace=True,
+     loss_type="dapo",
+     max_completion_length=None,
+     importance_sampling_level="token",
+     reduce=False,
+ ):
+     assert logits is not None and completion_ids is not None and advantages is not None, (
+         "must provide logits, completion_ids and advantages"
+     )
+     if importance_sampling_level != "token":
+         raise ValueError(
+             f"Triton GRPO loss only supports token-level importance sampling. Got {importance_sampling_level}."
+         )
+
+     per_token_loss, per_token_kl, is_clipped = GrpoLossFunction.apply(
+         logits,
+         old_logp,
+         ref_logp,
+         completion_ids,
+         advantages,
+         completion_mask,
+         temperature,
+         beta,
+         eps_low,
+         eps_high,
+         inplace,
+     )
+     if not reduce:
+         return per_token_loss, per_token_kl, is_clipped
+
+     loss = _reduce_grpo_loss(
+         per_token_loss,
+         completion_mask,
+         loss_type=loss_type,
+         max_completion_length=max_completion_length,
+     )
+
+     metrics = []
+     if beta != 0.0 and per_token_kl is not None:
+         metrics.append(_masked_mean(per_token_kl, completion_mask))
+     metrics.append(_masked_mean(is_clipped.float(), completion_mask))
+     return loss, metrics
+
+
+ def _reduce_grpo_loss(per_token_loss, completion_mask, loss_type, max_completion_length):
+     mask = completion_mask
+     if mask is None:
+         mask = torch.ones_like(per_token_loss, dtype=per_token_loss.dtype, device=per_token_loss.device)
+     mask = mask.to(per_token_loss.dtype)
+
+     if loss_type == "grpo":
+         per_seq = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
+         return per_seq.mean()
+     if loss_type == "bnpo":
+         return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
+     if loss_type == "dr_grpo":
+         if max_completion_length is None:
+             raise ValueError("max_completion_length must be provided when using loss_type='dr_grpo'")
+         batch = per_token_loss.shape[0]
+         return (per_token_loss * mask).sum() / (batch * max_completion_length)
+     if loss_type == "dapo":
+         normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(mask)
+         return (per_token_loss * mask).sum() / normalizer
+     raise ValueError(f"Unsupported loss_type '{loss_type}' for Triton GRPO loss.")
+
+
+ def _masked_mean(values, mask):
+     if mask is None:
+         mask = torch.ones_like(values, dtype=values.dtype, device=values.device)
+     mask = mask.to(values.dtype)
+     return (values * mask).sum() / mask.sum().clamp(min=1.0)
+
+
+ # This is a demo of how to use grpo_loss in GRPOTrainer. The TRL version must be 0.16.
+ """
+ import torch
+ import trl
+ assert trl.__version__.startswith("0.16"), "please pip install trl==0.16"
+ from trl.extras.profiling import profiling_decorator
+
+ @profiling_decorator
+ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
+     # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
+     logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+     return fused_selective_log_softmax(logits, input_ids, self.temperature, mask=attention_mask)
+
+ @profiling_decorator
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+     if return_outputs:
+         raise ValueError("The GRPOTrainer does not support returning outputs")
+     # Compute the per-token log probabilities for the model
+
+     prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+     completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+     input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+     attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+     logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+     logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+
+     ref_per_token_logps = inputs["ref_per_token_logps"]
+     advantages = inputs["advantages"]
+     old_per_token_logps = inputs["old_per_token_logps"]
+
+     per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(logits,
+                                                                 old_per_token_logps,
+                                                                 ref_per_token_logps,
+                                                                 completion_ids,
+                                                                 advantages,
+                                                                 completion_mask,
+                                                                 self.temperature,
+                                                                 self.beta,
+                                                                 self.epsilon_low,
+                                                                 self.epsilon_high,)
+     loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
+
+     # Log the metrics
+     mode = "eval" if self.control.should_evaluate else "train"
+
+     if self.beta != 0.0:
+         mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
+         self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
+     clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
+     self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
+     return loss
+
+ trl.GRPOTrainer._get_per_token_logps = _get_per_token_logps
+ trl.GRPOTrainer.compute_loss = compute_loss
+ trigger = None
+ """
+
+ # add this line at the first line of grpo.py in open-r1
+ """
+ from liger_kernel.transformers.grpo_loss import trigger
+ """
liger_kernel/transformers/jsd.py
@@ -2,7 +2,7 @@ from typing import Optional
 
  import torch
 
- from liger_kernel.ops.jsd import LigerJSDFunction
+ from liger_kernel.ops import LigerJSDFunction
 
 
  class LigerJSD(torch.nn.Module):
liger_kernel/transformers/kl_div.py
@@ -1,6 +1,6 @@
  import torch.nn as nn
 
- from liger_kernel.ops.kl_div import LigerKLDivLossFunction
+ from liger_kernel.ops import LigerKLDivLossFunction
 
 
  class LigerKLDIVLoss(nn.KLDivLoss):
liger_kernel/transformers/layer_norm.py
@@ -1,7 +1,7 @@
  import torch
  import torch.nn as nn
 
- from liger_kernel.ops.layer_norm import LigerLayerNormFunction
+ from liger_kernel.ops import LigerLayerNormFunction
 
 
  class LigerLayerNorm(nn.Module):
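
The three hunks above follow a single pattern: the autograd `Function` classes are now imported from the `liger_kernel.ops` package itself, which the new 141-line `liger_kernel/ops/__init__.py` presumably re-exports. A short sketch of the two styles; that the old submodule paths keep working is an assumption based on the submodules still being present in the file listing.

```python
# New package-level imports used by the 0.6.x modules above:
from liger_kernel.ops import LigerJSDFunction, LigerKLDivLossFunction, LigerLayerNormFunction

# The deeper submodule paths still exist in the wheel, so the pre-0.6.x style
# presumably continues to work as well:
from liger_kernel.ops.layer_norm import LigerLayerNormFunction  # noqa: F811
```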
liger_kernel/transformers/llama4_rope.py
@@ -0,0 +1,93 @@
+ """
+ Liger Kernel implementation of Llama4 Rotary Position Embedding (RoPE).
+ Supports both text and vision RoPE variants with fused operations for optimal performance.
+ """
+
+ import torch
+
+ from liger_kernel.ops import LigerLlama4RopeFunction
+
+
+ def liger_llama4_text_rotary_pos_emb(
+     xq: torch.Tensor,
+     xk: torch.Tensor,
+     freqs_cis: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Liger-optimized implementation of Llama4 text rotary position embedding.
+
+     This implementation uses a fused Triton kernel for complex multiplication,
+     providing significant performance improvements over the original PyTorch implementation.
+
+     Args:
+         xq (torch.Tensor): Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
+         xk (torch.Tensor): Key tensor of shape (batch_size, seq_len, num_heads, head_dim)
+         freqs_cis (torch.Tensor): Complex frequency tensor from Llama4TextRotaryEmbedding
+
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors
+     """
+     # Use fused Triton kernel for complex RoPE
+     return LigerLlama4RopeFunction.apply(xq, xk, freqs_cis)
+
+
+ def liger_llama4_vision_rotary_pos_emb(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     freqs_ci: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Liger-optimized implementation of Llama4 vision rotary position embedding.
+
+     This implementation uses the same fused Triton kernel as text RoPE,
+     providing performance improvements for vision transformer attention.
+
+     Args:
+         query (torch.Tensor): Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
+         key (torch.Tensor): Key tensor of shape (batch_size, seq_len, num_heads, head_dim)
+         freqs_ci (torch.Tensor): Complex frequency tensor for 2D positions
+
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors
+     """
+     # Handle broadcasting for vision RoPE
+     if freqs_ci.dim() == 3:
+         try:
+             # Try the regular 3D expansion
+             freqs_ci = freqs_ci.unsqueeze(0).expand(query.shape[0], -1, -1)
+         except RuntimeError as e:
+             if "expand" in str(e) and "4" in str(e):
+                 # The tensor is actually 4D internally, handle it differently
+                 freqs_ci = freqs_ci.squeeze(1)  # Remove the middle dimension
+                 freqs_ci = freqs_ci.unsqueeze(0).expand(query.shape[0], -1, -1)
+             else:
+                 raise e
+     elif freqs_ci.dim() == 4:  # (1, seq_len, 1, head_dim//2) - already properly shaped
+         # Squeeze the middle dimension to get (1, seq_len, head_dim//2)
+         freqs_ci = freqs_ci.squeeze(2)
+     elif freqs_ci.dim() == 2:  # (seq_len, head_dim//2) - needs expansion
+         freqs_ci = freqs_ci.unsqueeze(0).expand(query.shape[0], -1, -1)
+     else:
+         raise ValueError(f"Unexpected freqs_ci shape: {freqs_ci.shape}")
+
+     # Use the same fused kernel as text RoPE
+     return LigerLlama4RopeFunction.apply(query, key, freqs_ci)
+
+
+ # Note: We only patch the functions, not the classes
+ # The original Llama4TextRotaryEmbedding and Llama4VisionRotaryEmbedding classes remain unchanged
+
+
+ # Convenience functions for monkey patching
+ def apply_liger_llama4_rope_full(modeling_module):
+     """
+     Apply Liger optimizations to Llama4 RoPE functions.
+
+     Args:
+         modeling_module: The transformers modeling module to patch
+     """
+     # Replace the text RoPE function
+     modeling_module.apply_rotary_emb = liger_llama4_text_rotary_pos_emb
+
+     # Replace the vision RoPE function
+     modeling_module.vision_apply_rotary_emb = liger_llama4_vision_rotary_pos_emb
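
A minimal sketch of how the convenience patcher could be wired up; the `transformers.models.llama4.modeling_llama4` module path and the presence of module-level `apply_rotary_emb` / `vision_apply_rotary_emb` helpers there are assumptions about the Hugging Face layout, not something this diff asserts.

```python
from transformers.models.llama4 import modeling_llama4  # assumed HF module path

from liger_kernel.transformers.llama4_rope import apply_liger_llama4_rope_full

# Rebind the module-level RoPE helpers so Llama4 attention layers that resolve
# them at call time go through the fused Triton kernel instead.
apply_liger_llama4_rope_full(modeling_llama4)
```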
liger_kernel/transformers/model/falcon_h1.py
@@ -0,0 +1,122 @@
+ from typing import TYPE_CHECKING
+ from typing import Optional
+ from typing import Union
+
+ import torch
+
+ if TYPE_CHECKING:
+     from transformers.models.falcon_h1.modeling_falcon_h1 import FalconHybridMambaAttentionDynamicCache
+
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+
+
+ def lce_forward(
+     self,
+     input_ids: torch.LongTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional["FalconHybridMambaAttentionDynamicCache"] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     logits_to_keep: Union[int, torch.Tensor] = 0,
+     skip_logits: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     **kwargs,
+ ) -> Union[tuple, LigerCausalLMOutputWithPast]:
+     r"""
+     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+         config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+         (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, FalconH1ForCausalLM
+
+     >>> model = FalconH1ForCausalLM.from_pretrained("...")
+     >>> tokenizer = AutoTokenizer.from_pretrained("...")
+
+     >>> prompt = "Hey, are you conscious? Can you talk to me?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+     ```"""
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+     )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         cache_position=cache_position,
+         **kwargs,
+     )
+
+     hidden_states = outputs[0]
+     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+     kept_hidden_states = hidden_states[:, slice_indices, :]
+
+     shift_labels = kwargs.pop("shift_labels", None)
+     logits = None
+     loss = None
+     token_accuracy = None
+
+     # if in training mode, don't materialize logits
+     if skip_logits and labels is None:
+         raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+     if skip_logits is None:
+         # By default, if in training mode, don't materialize logits
+         skip_logits = self.training and labels is not None
+
+     # Compute loss
+     if skip_logits:
+         result = LigerForCausalLMLoss(
+             hidden_states=kept_hidden_states,
+             lm_head_weight=self.lm_head.weight,
+             labels=labels,
+             shift_labels=shift_labels,
+             hidden_size=self.config.hidden_size,
+             **kwargs,
+         )
+         loss, _, token_accuracy = unpack_cross_entropy_result(result)
+     else:
+         logits = self.lm_head(kept_hidden_states)
+         if labels is not None or shift_labels is not None:
+             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         output = ((loss,) + output) if loss is not None else output
+         output = output + (token_accuracy,) if token_accuracy is not None else output
+         return output
+
+     # Return custom output class with token_accuracy field
+     return LigerCausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+         token_accuracy=token_accuracy,
+     )
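
The new `lce_forward` is meant to be monkey-patched onto the Hugging Face model; the expanded `monkey_patch.py` in this release presumably provides a higher-level helper for Falcon-H1. A manual sketch, assuming transformers ships `FalconH1ForCausalLM` under `transformers.models.falcon_h1`:

```python
from transformers.models.falcon_h1 import modeling_falcon_h1  # assumed HF module path

from liger_kernel.transformers.model.falcon_h1 import lce_forward

# With labels present and the model in training mode, the patched forward skips
# materializing logits and runs the fused linear cross-entropy path instead.
modeling_falcon_h1.FalconH1ForCausalLM.forward = lce_forward
```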
liger_kernel/transformers/model/gemma.py
@@ -8,18 +8,14 @@ import torch
  from torch.nn import CrossEntropyLoss
  from transformers.cache_utils import Cache
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
- from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg
 
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
 
 
- @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward_deprecated(
      self,
      input_ids: torch.LongTensor = None,
@@ -33,6 +29,7 @@ def lce_forward_deprecated(
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      cache_position: Optional[torch.LongTensor] = None,
+     skip_logits: Optional[bool] = None,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
      r"""
 
@@ -87,7 +84,14 @@ def lce_forward_deprecated(
      loss = None
      logits = None
 
-     if self.training and (labels is not None):
+     if skip_logits and labels is None:
+         raise ValueError("skip_logits is True, but labels is None")
+
+     if skip_logits is None:
+         # By default, if in training mode, don't materialize logits
+         skip_logits = self.training and labels is not None
+
+     if skip_logits:
          shift_hidden_states = hidden_states[..., :-1, :].contiguous()
          shift_labels = labels[..., 1:].contiguous()
 
@@ -129,8 +133,6 @@ def lce_forward_deprecated(
 
 
  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
@@ -145,8 +147,9 @@ def lce_forward(
      return_dict: Optional[bool] = None,
      cache_position: Optional[torch.LongTensor] = None,
      logits_to_keep: Union[int, torch.Tensor] = 0,
-     **loss_kwargs,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
+     skip_logits: Optional[bool] = None,
+     **kwargs,
+ ) -> Union[Tuple, LigerCausalLMOutputWithPast]:
      r"""
      Args:
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -197,40 +200,62 @@ def lce_forward(
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
          cache_position=cache_position,
+         **kwargs,
      )
 
      hidden_states = outputs[0]
+     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+     kept_hidden_states = hidden_states[:, slice_indices, :]
 
+     shift_labels = kwargs.pop("shift_labels", None)
      logits = None
      loss = None
-     # if in training mode, don't materialize logits
-     if self.training and (labels is not None):
-         loss = LigerForCausalLMLoss(
-             hidden_states=hidden_states,
+     token_accuracy = None
+
+     if skip_logits and labels is None and shift_labels is None:
+         raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+     if skip_logits is None:
+         # By default, if in training mode, don't materialize logits
+         skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+     # Compute loss
+     if skip_logits:
+         result = LigerForCausalLMLoss(
+             hidden_states=kept_hidden_states,
              lm_head_weight=self.lm_head.weight,
              labels=labels,
+             shift_labels=shift_labels,
              hidden_size=self.config.hidden_size,
-             **loss_kwargs,
+             **kwargs,
          )
-     else:  # if in inference mode materialize logits
-         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-         logits = self.lm_head(hidden_states[:, slice_indices, :])
-         if labels is not None:
+         loss, _, token_accuracy = unpack_cross_entropy_result(result)
+     else:
+         logits = self.lm_head(kept_hidden_states)
+         if labels is not None or shift_labels is not None:
              loss = self.loss_function(
                  logits=logits,
                  labels=labels,
+                 shift_labels=shift_labels,
                  vocab_size=self.config.vocab_size,
-                 **loss_kwargs,
+                 **kwargs,
              )
 
      if not return_dict:
-         output = (logits,) + outputs[1:]
-         return (loss,) + output if loss is not None else output
-
-     return CausalLMOutputWithPast(
+         output_tuple = (logits,) + outputs[1:]
+         if loss is not None:
+             output_tuple = (loss,) + output_tuple
+         if token_accuracy is not None:
+             output_tuple = output_tuple + (token_accuracy,)
+         return output_tuple
+
+     # Return custom output class with token_accuracy field
+     return LigerCausalLMOutputWithPast(
          loss=loss,
          logits=logits,
          past_key_values=outputs.past_key_values,
          hidden_states=outputs.hidden_states,
          attentions=outputs.attentions,
+         token_accuracy=token_accuracy,
      )
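
The net effect of the Gemma changes (mirrored across the other model files in this release) is that the fused path now also reports token accuracy through the new `LigerCausalLMOutputWithPast.token_accuracy` field. A minimal consumption sketch, assuming a model whose forward has been patched to this `lce_forward`; whether `token_accuracy` is actually populated depends on the loss kwargs, so it is checked for `None`.

```python
def training_step(model, batch):
    """Run the patched forward on one batch and report token accuracy if present."""
    outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["labels"],
        skip_logits=True,  # fused linear cross-entropy path; logits are never materialized
        return_dict=True,
    )
    # token_accuracy may be None unless accuracy tracking is enabled in the loss kwargs.
    if outputs.token_accuracy is not None:
        print(f"token_accuracy: {outputs.token_accuracy}")
    return outputs.loss
```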