liger-kernel-nightly 0.5.6.dev20250408223717__py3-none-any.whl → 0.5.6.dev20250411210855__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/liger_kernel/chunked_loss/fused_linear_ppo.py
+++ b/liger_kernel/chunked_loss/fused_linear_ppo.py
@@ -32,6 +32,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 epsilon_low=0.2,
 epsilon_high=0.2,
 beta=0.04,
+loss_type="bnpo",
+max_completion_length=None,
 temperature=1.0,
 compiled=True,
 use_ref_model=False,
@@ -57,6 +59,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 epsilon_low: Lower bound for clipping the importance sampling ratio
 epsilon_high: Upper bound for clipping the importance sampling ratio
 beta: Weight for the KL penalty
+loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo")
+max_completion_length: Maximum completion length required for "dr_grpo"
 temperature: Temperature for the logits
 compiled: Whether to use torch compile
 use_ref_model: Whether to use a reference model
@@ -68,6 +72,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 )
 if ref_per_token_logps is not None and ref_input is not None:
     raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
+if loss_type == "dr_grpo":
+    assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'"
 # Initialize accumulators
 loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
 grad_weight = torch.zeros_like(weight)  # [V, H]
@@ -84,6 +90,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 epsilon_low=epsilon_low,
 epsilon_high=epsilon_high,
 beta=beta,
+loss_type=loss_type,
+max_completion_length=max_completion_length,
 temperature=temperature,
 use_ref_model=use_ref_model,
 ppo_loss_fn=cls.ppo_loss_fn,
@@ -251,6 +259,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 epsilon_low=0.2,
 epsilon_high=0.2,
 beta=0.04,
+loss_type="bnpo",
+max_completion_length=None,
 temperature=1.0,
 use_ref_model=False,
 ppo_loss_fn=None,
@@ -280,6 +290,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 epsilon_low=epsilon_low,
 epsilon_high=epsilon_high,
 beta=beta,
+loss_type=loss_type,
+max_completion_length=max_completion_length,
 )

 return chunk_loss, chunk_metrics
@@ -303,6 +315,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 def backward(ctx, grad_output, *grad_metrics):
     """Backward pass for PPO loss."""
     grad_input, grad_weight, grad_bias = ctx.saved_tensors
+
     if grad_output != 1.0:
         grad_input = grad_input * grad_output
         grad_weight = grad_weight * grad_output
@@ -328,4 +341,6 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 None, # grad_compiled
 None, # grad_use_ref_model
 None, # grad_chunk_size
+None, # grad_loss_type
+None, # grad_max_completion_length
 )
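
The two added Nones keep backward() aligned with forward()'s argument list: a torch.autograd.Function must return one gradient per forward input, with None for non-differentiable arguments such as the new loss_type string and max_completion_length int. A minimal sketch of the pattern (the Scale class and its arguments are illustrative, not taken from this diff):

    import torch

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, factor, mode="fast"):
            # x is a tensor; factor (float) and mode (str) are not differentiable.
            ctx.factor = factor
            return x * factor

        @staticmethod
        def backward(ctx, grad_output):
            # One return value per forward argument, in order:
            # the gradient for x, then None for factor and for mode.
            return grad_output * ctx.factor, None, None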
--- a/liger_kernel/chunked_loss/grpo_loss.py
+++ b/liger_kernel/chunked_loss/grpo_loss.py
@@ -27,6 +27,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 epsilon_low=0.2,
 epsilon_high=0.2,
 beta=0.04,
+loss_type="bnpo",  # ["grpo", "bnpo", "dr_grpo"]
+max_completion_length=None,  # Required for dr_grpo
 **kwargs,
 ):
 """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -61,7 +63,21 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
 # and TRL GRPO implementation
 # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
-loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+if loss_type == "grpo":
+    # Average per-sequence loss
+    loss = (
+        (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
+    ).sum() / full_attention_mask.shape[0]
+elif loss_type == "bnpo":
+    # Batch Normalized Per-token loss (original implementation)
+    loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+elif loss_type == "dr_grpo":
+    # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
+    if max_completion_length is None:
+        raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
+    loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+else:
+    raise ValueError(f"Unknown loss type: {loss_type}")

 # Calculate metrics
 metrics = []
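
The three branches differ only in the denominator applied to the masked per-token loss sum. A standalone sketch of the arithmetic (tensor values and shapes here are illustrative; the kernel operates on its own chunked tensors):

    import torch

    # Illustrative shapes: [batch, seq_len] per-token loss and a 0/1 completion mask.
    per_token_loss = torch.rand(4, 16)
    mask = (torch.rand(4, 16) > 0.2).float()
    max_completion_length = 16

    # "grpo": mean over each sequence's own tokens, then mean over sequences.
    grpo = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).sum() / mask.shape[0]

    # "bnpo": one global mean over every unmasked token in the batch.
    bnpo = (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)

    # "dr_grpo": fixed denominator batch_size * max_completion_length, so the
    # normalizer does not depend on how long each sampled completion happens to be.
    dr_grpo = (per_token_loss * mask).sum() / (mask.shape[0] * max_completion_length)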
@@ -91,6 +107,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 beta=0.04,
 epsilon_low=0.2,
 epsilon_high=0.2,
+loss_type="bnpo",
+max_completion_length=None,
 temperature=1.0,
 compiled=True,
 use_ref_model=True,
@@ -110,6 +128,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
 ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
 beta (float): Weight for the KL penalty
+loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
 temperature (float): Temperature for the logits
 compiled (bool): Whether to use torch compile
 use_ref_model (bool): Whether to use a reference model
@@ -134,6 +154,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 beta=beta,
 epsilon_low=epsilon_low,
 epsilon_high=epsilon_high,
+loss_type=loss_type,
+max_completion_length=max_completion_length,
 temperature=temperature,
 compiled=compiled,
 use_ref_model=use_ref_model,
@@ -161,6 +183,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 None, # grad_beta
 None, # grad_epsilon_low
 None, # grad_epsilon_high
+None, # grad_loss_type (string, not differentiable)
+None, # grad_max_completion_length (int, not differentiable)
 None, # grad_temperature
 None, # grad_compiled
 None, # grad_use_ref_model
@@ -179,6 +203,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
 chunk_size: int = 1,
 epsilon_low: float = 0.2,
 epsilon_high: float = 0.2,
+loss_type: str = "bnpo",
+max_completion_length: int | None = None,
 temperature: float = 1.0,
 ):
 """
@@ -189,6 +215,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
 chunk_size (int): Size of chunks for processing.
 epsilon_low (float): Lower bound for the importance sampling ratio.
 epsilon_high (float): Upper bound for the importance sampling ratio.
+loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
 temperature (float): Temperature for the logits.
 """
 super().__init__()
@@ -198,6 +226,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
 self.chunk_size = chunk_size
 self.epsilon_low = epsilon_low
 self.epsilon_high = epsilon_high
+self.loss_type = loss_type
+self.max_completion_length = max_completion_length
 self.temperature = temperature

 def forward(
@@ -229,6 +259,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
 self.beta,
 self.epsilon_low,
 self.epsilon_high,
+self.loss_type,
+self.max_completion_length,
 self.temperature,
 self.compiled,
 self.use_ref_model,
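
Taken together, the normalization is now chosen at construction time and threaded through the module's forward into the autograd function. A hedged usage sketch (only loss_type and max_completion_length are confirmed by this diff; the other constructor defaults are left implicit):

    from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss

    # Default behavior is unchanged: "bnpo" batch-normalized per-token loss.
    bnpo_loss = LigerFusedLinearGRPOLoss()

    # "dr_grpo" needs the fixed normalizer; constructing without
    # max_completion_length trips the base class's forward-time assertion.
    dr_grpo_loss = LigerFusedLinearGRPOLoss(loss_type="dr_grpo", max_completion_length=256)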
--- a/liger_kernel/ops/kl_div.py
+++ b/liger_kernel/ops/kl_div.py
@@ -6,6 +6,7 @@ import triton.language as tl

 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device


 def get_num_warps(BLOCK_SIZE):
@@ -115,9 +116,12 @@ def _kldiv_kernel_backward(

 def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
     BT, V = y_pred.shape
-
-    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-    num_warps = get_num_warps(BLOCK_SIZE)
+    BLOCK_SIZE = (
+        min(8192, triton.next_power_of_2(V))
+        if infer_device() == "xpu"
+        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    )
+    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)

     grid = (BT,)
     reduction = _str_to_reduction_mode[reduction]
@@ -155,9 +159,12 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]

 def kldiv_backward_triton(target, grad_output, new_grads, log_target):
     BT, V = target.shape
-
-    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-    num_warps = get_num_warps(BLOCK_SIZE)
+    BLOCK_SIZE = (
+        min(8192, triton.next_power_of_2(V))
+        if infer_device() == "xpu"
+        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    )
+    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)

     grid = (BT,)
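
Both kldiv launch sites now branch on the detected device before choosing Triton launch parameters. A condensed sketch of the selection (MAX_FUSED_SIZE's value and the non-XPU warp heuristic below are assumptions, included only to make the snippet self-contained):

    import triton
    from liger_kernel.utils import infer_device

    MAX_FUSED_SIZE = 65536  # assumed value; the real constant lives in kl_div.py

    def get_num_warps(block_size):
        # Illustrative stand-in for the module's warp heuristic.
        return max(4, min(32, block_size // 256))

    def kldiv_launch_params(V):
        # XPU caps the block at 8192 elements and pins num_warps to 32;
        # other devices keep the original MAX_FUSED_SIZE / get_num_warps path.
        if infer_device() == "xpu":
            return min(8192, triton.next_power_of_2(V)), 32
        block_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
        return block_size, get_num_warps(block_size)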
 
--- a/liger_kernel_nightly-0.5.6.dev20250408223717.dist-info/METADATA
+++ b/liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.6.dev20250408223717
+Version: 0.5.6.dev20250411210855
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
--- a/liger_kernel_nightly-0.5.6.dev20250408223717.dist-info/RECORD
+++ b/liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/RECORD
@@ -7,10 +7,10 @@ liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNic
 liger_kernel/chunked_loss/dpo_loss.py,sha256=xZwGqS04si9zXyob95SAdalC-hajZg8fWINqiqffN8k,5855
 liger_kernel/chunked_loss/functional.py,sha256=9G3nKm-Bi7uoZRFkL8wwGMl6juDl4bSzDvTa5GHZPzg,955
 liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=ooR-qnZCyWJN935oHCSWLaKKKyaYERyhNczRGi1VOiw,11935
-liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=-E4AuWY-y2bMo_kAmEQBgQ92UJh3L5IiCRGVcfMJOCE,12731
+liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=AA19cpv6D8mo5RbSK5GRCcZoOSnpxV_Z1eJlAsC5eic,13434
 liger_kernel/chunked_loss/fused_linear_preference.py,sha256=ojB42jYPu0c4ki96Ft-hy7Sf6fh_WikG-aWNrlZzSio,18362
 liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=RiuK3UtRwH9T6jZ36sA8Urj-TVuOLOO2syLg_JOQapY,13437
-liger_kernel/chunked_loss/grpo_loss.py,sha256=6Mb4ZT6MfnOr4Xo681rMR0LKkhzJhInvQp8wp2YVMK0,8913
+liger_kernel/chunked_loss/grpo_loss.py,sha256=eh6mErFUZsSQrgRRefuXdk-LG0gS7Rg2r-U9CtbH3eU,10834
 liger_kernel/chunked_loss/jsd_loss.py,sha256=u2ahkuHsbhpNaKcpBCz5gCMDk9ou-P04DHji592dIBo,7067
 liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsmSbQyqwQY,7529
 liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
@@ -23,7 +23,7 @@ liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHu
 liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
 liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
 liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
-liger_kernel/ops/kl_div.py,sha256=NkG7D6_DnPBzr-ohhYiQbRBnq_fbGmpn5UU7y0UBKQo,8420
+liger_kernel/ops/kl_div.py,sha256=ZjGdDLKWksHT9dZ0xF_TDgAkj5cuMTwwT5tr9E-_24o,8734
 liger_kernel/ops/layer_norm.py,sha256=vWCyOm-F2GMAilB-ozJcFeUQQLCJoTE_uiXq-_0uYuI,8356
 liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
 liger_kernel/ops/rms_norm.py,sha256=PP27OIBmV9By63i13jot9ylDowW0nuxY_JFIkaPLgL4,12078
@@ -74,9 +74,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.6.dev20250408223717.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.5.6.dev20250408223717.dist-info/METADATA,sha256=ZSAGbY1ejoXoRQzTkkCjTwZd-OQxWdTV1IukEftepgU,23297
-liger_kernel_nightly-0.5.6.dev20250408223717.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.5.6.dev20250408223717.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
-liger_kernel_nightly-0.5.6.dev20250408223717.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.5.6.dev20250408223717.dist-info/RECORD,,
+liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/METADATA,sha256=mX6Na52mRBO2g2I7Qqj34QGM17tMQAZLNjE7XX0g9fA,23297
+liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/RECORD,,