liger-kernel-nightly 0.5.6.dev20250411201510__py3-none-any.whl → 0.5.6.dev20250411224032__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
liger_kernel/chunked_loss/fused_linear_ppo.py

@@ -32,6 +32,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  epsilon_low=0.2,
  epsilon_high=0.2,
  beta=0.04,
+ loss_type="bnpo",
+ max_completion_length=None,
  temperature=1.0,
  compiled=True,
  use_ref_model=False,
@@ -57,6 +59,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  epsilon_low: Lower bound for clipping the importance sampling ratio
  epsilon_high: Upper bound for clipping the importance sampling ratio
  beta: Weight for the KL penalty
+ loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo")
+ max_completion_length: Maximum completion length required for "dr_grpo"
  temperature: Temperature for the logits
  compiled: Whether to use torch compile
  use_ref_model: Whether to use a reference model
@@ -68,6 +72,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  )
  if ref_per_token_logps is not None and ref_input is not None:
  raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
+ if loss_type == "dr_grpo":
+ assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'"
  # Initialize accumulators
  loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
  grad_weight = torch.zeros_like(weight)  # [V, H]
@@ -84,6 +90,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  epsilon_low=epsilon_low,
  epsilon_high=epsilon_high,
  beta=beta,
+ loss_type=loss_type,
+ max_completion_length=max_completion_length,
  temperature=temperature,
  use_ref_model=use_ref_model,
  ppo_loss_fn=cls.ppo_loss_fn,
@@ -251,6 +259,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  epsilon_low=0.2,
  epsilon_high=0.2,
  beta=0.04,
+ loss_type="bnpo",
+ max_completion_length=None,
  temperature=1.0,
  use_ref_model=False,
  ppo_loss_fn=None,
@@ -280,6 +290,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  epsilon_low=epsilon_low,
  epsilon_high=epsilon_high,
  beta=beta,
+ loss_type=loss_type,
+ max_completion_length=max_completion_length,
  )

  return chunk_loss, chunk_metrics
@@ -303,6 +315,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  def backward(ctx, grad_output, *grad_metrics):
  """Backward pass for PPO loss."""
  grad_input, grad_weight, grad_bias = ctx.saved_tensors
+
  if grad_output != 1.0:
  grad_input = grad_input * grad_output
  grad_weight = grad_weight * grad_output
@@ -328,4 +341,6 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
  None, # grad_compiled
  None, # grad_use_ref_model
  None, # grad_chunk_size
+ None, # grad_loss_type
+ None, # grad_max_completion_length
  )
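The two extra None entries follow the torch.autograd.Function contract: backward must return one value per forward argument, and non-tensor arguments (here the loss_type string and the max_completion_length int) get None. A minimal sketch of that general pattern, with hypothetical names unrelated to this kernel:

    import torch

    class ScaleByFactor(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, factor, mode="plain"):
            ctx.factor = factor
            return x * factor

        @staticmethod
        def backward(ctx, grad_output):
            # one return value per forward argument: a gradient for x,
            # None for the non-differentiable `factor` and `mode`
            return grad_output * ctx.factor, None, None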
liger_kernel/chunked_loss/grpo_loss.py

@@ -27,6 +27,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  epsilon_low=0.2,
  epsilon_high=0.2,
  beta=0.04,
+ loss_type="bnpo",  # ["grpo", "bnpo", "dr_grpo"]
+ max_completion_length=None,  # Required for dr_grpo
  **kwargs,
  ):
  """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -61,7 +63,21 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
  # and TRL GRPO implementation
  # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
- loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+ if loss_type == "grpo":
+ # Average per-sequence loss
+ loss = (
+ (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
+ ).sum() / full_attention_mask.shape[0]
+ elif loss_type == "bnpo":
+ # Batch Normalized Per-token loss (original implementation)
+ loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+ elif loss_type == "dr_grpo":
+ # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
+ if max_completion_length is None:
+ raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
+ loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+ else:
+ raise ValueError(f"Unknown loss type: {loss_type}")

  # Calculate metrics
  metrics = []
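For reference, the three normalizations can be compared in isolation. A minimal standalone sketch, assuming per_token_loss and attention_mask are [batch, seq_len] tensors (the fused kernel computes per_token_loss chunk-wise; this only mirrors the reduction step above):

    import torch

    def grpo_style_reduction(per_token_loss, attention_mask, loss_type="bnpo", max_completion_length=None):
        if loss_type == "grpo":
            # mean over each sequence's valid tokens, then mean over the batch
            per_seq = (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
            return per_seq.sum() / attention_mask.shape[0]
        if loss_type == "bnpo":
            # one global mean over all valid tokens in the batch
            return (per_token_loss * attention_mask).sum() / torch.clamp(attention_mask.sum(), min=1.0)
        if loss_type == "dr_grpo":
            # fixed denominator: batch_size * max_completion_length
            return (per_token_loss * attention_mask).sum() / (attention_mask.shape[0] * max_completion_length)
        raise ValueError(f"Unknown loss type: {loss_type}")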
@@ -91,6 +107,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  beta=0.04,
  epsilon_low=0.2,
  epsilon_high=0.2,
+ loss_type="bnpo",
+ max_completion_length=None,
  temperature=1.0,
  compiled=True,
  use_ref_model=True,
@@ -110,6 +128,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
  ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
  beta (float): Weight for the KL penalty
+ loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+ max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
  temperature (float): Temperature for the logits
  compiled (bool): Whether to use torch compile
  use_ref_model (bool): Whether to use a reference model
@@ -134,6 +154,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  beta=beta,
  epsilon_low=epsilon_low,
  epsilon_high=epsilon_high,
+ loss_type=loss_type,
+ max_completion_length=max_completion_length,
  temperature=temperature,
  compiled=compiled,
  use_ref_model=use_ref_model,
@@ -161,6 +183,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
  None, # grad_beta
  None, # grad_epsilon_low
  None, # grad_epsilon_high
+ None, # grad_loss_type (string, not differentiable)
+ None, # grad_max_completion_length (int, not differentiable)
  None, # grad_temperature
  None, # grad_compiled
  None, # grad_use_ref_model
@@ -179,6 +203,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
  chunk_size: int = 1,
  epsilon_low: float = 0.2,
  epsilon_high: float = 0.2,
+ loss_type: str = "bnpo",
+ max_completion_length: int | None = None,
  temperature: float = 1.0,
  ):
  """
@@ -189,6 +215,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
  chunk_size (int): Size of chunks for processing.
  epsilon_low (float): Lower bound for the importance sampling ratio.
  epsilon_high (float): Upper bound for the importance sampling ratio.
+ loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+ max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
  temperature (float): Temperature for the logits.
  """
  super().__init__()
@@ -198,6 +226,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
  self.chunk_size = chunk_size
  self.epsilon_low = epsilon_low
  self.epsilon_high = epsilon_high
+ self.loss_type = loss_type
+ self.max_completion_length = max_completion_length
  self.temperature = temperature

  def forward(
@@ -229,6 +259,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
  self.beta,
  self.epsilon_low,
  self.epsilon_high,
+ self.loss_type,
+ self.max_completion_length,
  self.temperature,
  self.compiled,
  self.use_ref_model,
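At the module level the new arguments are simply stored and forwarded, so selecting the dr_grpo variant is a constructor choice. A minimal construction sketch, assuming the class is imported from the module shown in the hunk headers (all other constructor defaults left untouched):

    from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss

    # "dr_grpo" normalizes by batch_size * max_completion_length, so the length is required;
    # "grpo" and "bnpo" do not use max_completion_length.
    loss_fn = LigerFusedLinearGRPOLoss(
        loss_type="dr_grpo",
        max_completion_length=256,
    )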
liger_kernel/transformers/model/gemma2.py

@@ -222,7 +222,7 @@ def lce_forward(
  lm_head_weight=self.lm_head.weight,
  labels=labels,
  hidden_size=self.config.hidden_size,
- softcap=self.config.final_logit_softcapping,
+ final_logit_softcapping=self.config.final_logit_softcapping,
  **loss_kwargs,
  )
liger_kernel/transformers/model/gemma3.py

@@ -112,7 +112,7 @@ def causal_forward(
  lm_head_weight=self.lm_head.weight,
  labels=labels,
  hidden_size=self.config.hidden_size,
- softcap=self.config.final_logit_softcapping,
+ final_logit_softcapping=self.config.final_logit_softcapping,
  **loss_kwargs,
  )
liger_kernel/transformers/model/loss_utils.py

@@ -1,14 +1,18 @@
+ from typing import Optional
+
+ import torch
  import torch.nn as nn

  import liger_kernel.transformers.functional as F


  def fixed_fused_linear_cross_entropy(
- hidden_states,
- lm_head_weight,
- target,
- num_items_in_batch: int = None,
+ hidden_states: torch.Tensor,
+ lm_head_weight: torch.Tensor,
+ target: torch.Tensor,
+ num_items_in_batch: Optional[int] = None,
  ignore_index: int = -100,
+ final_logit_softcapping: Optional[float] = None,
  **kwargs,
  ):
  reduction = "sum" if num_items_in_batch is not None else "mean"
@@ -18,7 +22,7 @@ def fixed_fused_linear_cross_entropy(
  target,
  reduction=reduction,
  ignore_index=ignore_index,
- **kwargs,
+ softcap=final_logit_softcapping,
  )
  if reduction == "sum":
  loss = loss / num_items_in_batch
@@ -31,15 +35,17 @@ def LigerForCausalLMLoss(
  lm_head_weight,
  labels,
  hidden_size: int,
- num_items_in_batch: int = None,
+ num_items_in_batch: Optional[int] = None,
  ignore_index: int = -100,
+ shift_labels: Optional[torch.Tensor] = None,
+ final_logit_softcapping: Optional[float] = None,
  **kwargs,
  ):
  # Skip upcast since intermediate values for the loss are all fp32 in kernel
- labels = labels.to(hidden_states.device)
- # Shift so that token < n predict n
- labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
- shift_labels = labels[..., 1:].contiguous()
+ if shift_labels is None:
+ # Shift so that token < n predict n
+ labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
+ shift_labels = labels[..., 1:].contiguous()

  # Flatten the tokens
  hidden_states = hidden_states.view(-1, hidden_size)
@@ -52,6 +58,7 @@ def LigerForCausalLMLoss(
  shift_labels,
  num_items_in_batch,
  ignore_index,
+ final_logit_softcapping,
  **kwargs,
  )
  return loss
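The end-to-end effect is that the Gemma2/Gemma3 forwards now pass final_logit_softcapping by name, LigerForCausalLMLoss forwards it, and fixed_fused_linear_cross_entropy translates it into the kernel's softcap argument, while callers may also supply pre-shifted labels via shift_labels. A minimal call sketch with illustrative shapes (the fused Triton kernel itself needs a CUDA device to run):

    import torch
    from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss

    # illustrative sizes: batch=2, seq_len=8, hidden=32, vocab=64
    hidden_states = torch.randn(2, 8, 32, device="cuda")
    lm_head_weight = torch.randn(64, 32, device="cuda")
    labels = torch.randint(0, 64, (2, 8), device="cuda")

    loss = LigerForCausalLMLoss(
        hidden_states=hidden_states,
        lm_head_weight=lm_head_weight,
        labels=labels,
        hidden_size=32,
        final_logit_softcapping=30.0,  # forwarded to the kernel as softcap
    )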
METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel_nightly
- Version: 0.5.6.dev20250411201510
+ Version: 0.5.6.dev20250411224032
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
RECORD

@@ -7,10 +7,10 @@ liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNic
  liger_kernel/chunked_loss/dpo_loss.py,sha256=xZwGqS04si9zXyob95SAdalC-hajZg8fWINqiqffN8k,5855
  liger_kernel/chunked_loss/functional.py,sha256=9G3nKm-Bi7uoZRFkL8wwGMl6juDl4bSzDvTa5GHZPzg,955
  liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=ooR-qnZCyWJN935oHCSWLaKKKyaYERyhNczRGi1VOiw,11935
- liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=-E4AuWY-y2bMo_kAmEQBgQ92UJh3L5IiCRGVcfMJOCE,12731
+ liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=AA19cpv6D8mo5RbSK5GRCcZoOSnpxV_Z1eJlAsC5eic,13434
  liger_kernel/chunked_loss/fused_linear_preference.py,sha256=ojB42jYPu0c4ki96Ft-hy7Sf6fh_WikG-aWNrlZzSio,18362
  liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=RiuK3UtRwH9T6jZ36sA8Urj-TVuOLOO2syLg_JOQapY,13437
- liger_kernel/chunked_loss/grpo_loss.py,sha256=6Mb4ZT6MfnOr4Xo681rMR0LKkhzJhInvQp8wp2YVMK0,8913
+ liger_kernel/chunked_loss/grpo_loss.py,sha256=eh6mErFUZsSQrgRRefuXdk-LG0gS7Rg2r-U9CtbH3eU,10834
  liger_kernel/chunked_loss/jsd_loss.py,sha256=u2ahkuHsbhpNaKcpBCz5gCMDk9ou-P04DHji592dIBo,7067
  liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsmSbQyqwQY,7529
  liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
@@ -56,11 +56,11 @@ liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_J
  liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
  liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  liger_kernel/transformers/model/gemma.py,sha256=-JoHKWjtYPpxHQa6QbCwnzX_cctRZG2ZTsaUv-dmOt4,9816
- liger_kernel/transformers/model/gemma2.py,sha256=tLl1v-O8K0NZ7BQcSf1dE3450-xV72RAk4E5oTPcu_s,10907
- liger_kernel/transformers/model/gemma3.py,sha256=PjAfFtupT9EW0sb57Hx8UJXcnvq9HFgNndeAE4EqyPw,16086
+ liger_kernel/transformers/model/gemma2.py,sha256=n4MZupFGDMvtnvkvkNhRrxXS3ZF341BVfyLjrOXp10g,10923
+ liger_kernel/transformers/model/gemma3.py,sha256=ge3JYchiKvX1G1Zp00jX2zmQK2K7ymJoZAxbb2ggslw,16102
  liger_kernel/transformers/model/llama.py,sha256=UVXQLRW7rCU5vPab54dLNS3ER37eM446peHX00Yz6eA,10493
  liger_kernel/transformers/model/llava.py,sha256=b0pEagjUbu2-eS9xegjyfl1DwIXLwZcNpff55ibaMbA,17601
- liger_kernel/transformers/model/loss_utils.py,sha256=Z-fUrf-cUDUjUIH7Tl9OL2hT8nmtx7ES3kg8syuWKy4,1476
+ liger_kernel/transformers/model/loss_utils.py,sha256=WWAMdiONPaXpIvxyOim_0igLrYh0yyOok5Q9_L9xvZw,1787
  liger_kernel/transformers/model/mistral.py,sha256=RacuKcckuDK6oSraCGD0R0bm-fE0K3q-lkYaAC56C2E,5481
  liger_kernel/transformers/model/mixtral.py,sha256=gLcqGabdv1XnuciS9b-TpkTDnGL8K32Hoq9j2vZMBRY,11502
  liger_kernel/transformers/model/mllama.py,sha256=75mxtmMsNd_q8KlKeawj2uMP6v2KjDuUi4nsUKM5jqA,11308
@@ -74,9 +74,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
- liger_kernel_nightly-0.5.6.dev20250411201510.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
- liger_kernel_nightly-0.5.6.dev20250411201510.dist-info/METADATA,sha256=exdcHfLuKkUQ2NIene0sQ5hEn8mB98YKJ43XfirrGwM,23297
- liger_kernel_nightly-0.5.6.dev20250411201510.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
- liger_kernel_nightly-0.5.6.dev20250411201510.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
- liger_kernel_nightly-0.5.6.dev20250411201510.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
- liger_kernel_nightly-0.5.6.dev20250411201510.dist-info/RECORD,,
+ liger_kernel_nightly-0.5.6.dev20250411224032.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+ liger_kernel_nightly-0.5.6.dev20250411224032.dist-info/METADATA,sha256=cahu5M3U8t37VfXE0xLqBEH8tLSpXZzbbfwlyS86AuE,23297
+ liger_kernel_nightly-0.5.6.dev20250411224032.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+ liger_kernel_nightly-0.5.6.dev20250411224032.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+ liger_kernel_nightly-0.5.6.dev20250411224032.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+ liger_kernel_nightly-0.5.6.dev20250411224032.dist-info/RECORD,,