liger-kernel-nightly 0.5.6.dev20250408223717__py3-none-any.whl → 0.5.6.dev20250411210855__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/fused_linear_ppo.py +15 -0
- liger_kernel/chunked_loss/grpo_loss.py +33 -1
- liger_kernel/ops/kl_div.py +13 -6
- {liger_kernel_nightly-0.5.6.dev20250408223717.dist-info → liger_kernel_nightly-0.5.6.dev20250411210855.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.6.dev20250408223717.dist-info → liger_kernel_nightly-0.5.6.dev20250411210855.dist-info}/RECORD +9 -9
- {liger_kernel_nightly-0.5.6.dev20250408223717.dist-info → liger_kernel_nightly-0.5.6.dev20250411210855.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.6.dev20250408223717.dist-info → liger_kernel_nightly-0.5.6.dev20250411210855.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.6.dev20250408223717.dist-info → liger_kernel_nightly-0.5.6.dev20250411210855.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.6.dev20250408223717.dist-info → liger_kernel_nightly-0.5.6.dev20250411210855.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/fused_linear_ppo.py
CHANGED
@@ -32,6 +32,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
+        loss_type="bnpo",
+        max_completion_length=None,
         temperature=1.0,
         compiled=True,
         use_ref_model=False,
@@ -57,6 +59,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             epsilon_low: Lower bound for clipping the importance sampling ratio
             epsilon_high: Upper bound for clipping the importance sampling ratio
             beta: Weight for the KL penalty
+            loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo")
+            max_completion_length: Maximum completion length required for "dr_grpo"
             temperature: Temperature for the logits
             compiled: Whether to use torch compile
             use_ref_model: Whether to use a reference model
@@ -68,6 +72,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         )
         if ref_per_token_logps is not None and ref_input is not None:
             raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
+        if loss_type == "dr_grpo":
+            assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'"
         # Initialize accumulators
         loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
         grad_weight = torch.zeros_like(weight)  # [V, H]
@@ -84,6 +90,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             epsilon_low=epsilon_low,
             epsilon_high=epsilon_high,
             beta=beta,
+            loss_type=loss_type,
+            max_completion_length=max_completion_length,
             temperature=temperature,
             use_ref_model=use_ref_model,
             ppo_loss_fn=cls.ppo_loss_fn,
@@ -251,6 +259,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
+        loss_type="bnpo",
+        max_completion_length=None,
         temperature=1.0,
         use_ref_model=False,
         ppo_loss_fn=None,
@@ -280,6 +290,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             epsilon_low=epsilon_low,
             epsilon_high=epsilon_high,
             beta=beta,
+            loss_type=loss_type,
+            max_completion_length=max_completion_length,
         )

         return chunk_loss, chunk_metrics
@@ -303,6 +315,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
     def backward(ctx, grad_output, *grad_metrics):
         """Backward pass for PPO loss."""
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
+
         if grad_output != 1.0:
             grad_input = grad_input * grad_output
             grad_weight = grad_weight * grad_output
@@ -328,4 +341,6 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             None,  # grad_compiled
             None,  # grad_use_ref_model
             None,  # grad_chunk_size
+            None,  # grad_loss_type
+            None,  # grad_max_completion_length
         )
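The paired `None` entries close out the fused_linear_ppo.py changes: `torch.autograd.Function.backward` must return exactly one gradient per argument that `forward` received, so each new non-tensor argument gets a `None` placeholder. A minimal sketch of that contract, using a hypothetical `ScaleByMode` function that is not part of this package:

```python
import torch


class ScaleByMode(torch.autograd.Function):
    """Toy Function: one tensor input plus two non-tensor inputs."""

    @staticmethod
    def forward(ctx, x, mode="double", length=None):
        ctx.scale = 2.0 if mode == "double" else 1.0
        return x * ctx.scale

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward argument (x, mode, length);
        # strings and ints are not differentiable, so they map to None.
        return grad_output * ctx.scale, None, None


x = torch.randn(3, requires_grad=True)
ScaleByMode.apply(x, "double", None).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```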
liger_kernel/chunked_loss/grpo_loss.py
CHANGED
@@ -27,6 +27,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
+        loss_type="bnpo",  # ["grpo", "bnpo", "dr_grpo"]
+        max_completion_length=None,  # Required for dr_grpo
         **kwargs,
     ):
         """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -61,7 +63,21 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
         # and TRL GRPO implementation
         # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
-        loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+        if loss_type == "grpo":
+            # Average per-sequence loss
+            loss = (
+                (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
+            ).sum() / full_attention_mask.shape[0]
+        elif loss_type == "bnpo":
+            # Batch Normalized Per-token loss (original implementation)
+            loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+        elif loss_type == "dr_grpo":
+            # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
+            if max_completion_length is None:
+                raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
+            loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+        else:
+            raise ValueError(f"Unknown loss type: {loss_type}")

         # Calculate metrics
         metrics = []
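The three branches differ only in the denominator applied to the masked per-token loss. A self-contained sketch with made-up tensors (not package code; `full_attention_mask` is taken equal to `attention_mask`, which holds when the batch is not chunked) shows how the normalizations diverge once sequence lengths vary:

```python
import torch

# Illustrative values: batch of 2 sequences, 4 completion slots each.
per_token_loss = torch.tensor([[0.5, 0.5, 0.5, 0.5],
                               [1.0, 1.0, 0.0, 0.0]])
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]], dtype=torch.float32)
full_attention_mask = attention_mask  # identical here, no chunking
max_completion_length = 8  # assumed generation cap

# "grpo": mean over each sequence's valid tokens, then mean over the batch.
grpo = (
    (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
).sum() / full_attention_mask.shape[0]

# "bnpo": one global mean over all valid tokens in the batch.
bnpo = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)

# "dr_grpo": fixed denominator batch_size * max_completion_length.
dr_grpo = (per_token_loss * attention_mask).sum() / (
    full_attention_mask.shape[0] * max_completion_length
)

print(grpo.item(), bnpo.item(), dr_grpo.item())  # 0.75, ~0.6667, 0.25
```

Because the "dr_grpo" denominator is fixed at `batch_size * max_completion_length`, the scale of the loss is independent of how many valid tokens a batch happens to contain.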
@@ -91,6 +107,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         beta=0.04,
         epsilon_low=0.2,
         epsilon_high=0.2,
+        loss_type="bnpo",
+        max_completion_length=None,
         temperature=1.0,
         compiled=True,
         use_ref_model=True,
@@ -110,6 +128,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
             ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
             beta (float): Weight for the KL penalty
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
             temperature (float): Temperature for the logits
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
@@ -134,6 +154,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             beta=beta,
             epsilon_low=epsilon_low,
             epsilon_high=epsilon_high,
+            loss_type=loss_type,
+            max_completion_length=max_completion_length,
             temperature=temperature,
             compiled=compiled,
             use_ref_model=use_ref_model,
@@ -161,6 +183,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             None,  # grad_beta
             None,  # grad_epsilon_low
             None,  # grad_epsilon_high
+            None,  # grad_loss_type (string, not differentiable)
+            None,  # grad_max_completion_length (int, not differentiable)
             None,  # grad_temperature
             None,  # grad_compiled
             None,  # grad_use_ref_model
@@ -179,6 +203,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         chunk_size: int = 1,
         epsilon_low: float = 0.2,
         epsilon_high: float = 0.2,
+        loss_type: str = "bnpo",
+        max_completion_length: int | None = None,
         temperature: float = 1.0,
     ):
         """
@@ -189,6 +215,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             chunk_size (int): Size of chunks for processing.
             epsilon_low (float): Lower bound for the importance sampling ratio.
             epsilon_high (float): Upper bound for the importance sampling ratio.
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
             temperature (float): Temperature for the logits.
         """
         super().__init__()
@@ -198,6 +226,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         self.chunk_size = chunk_size
         self.epsilon_low = epsilon_low
         self.epsilon_high = epsilon_high
+        self.loss_type = loss_type
+        self.max_completion_length = max_completion_length
         self.temperature = temperature

     def forward(
@@ -229,6 +259,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             self.beta,
             self.epsilon_low,
             self.epsilon_high,
+            self.loss_type,
+            self.max_completion_length,
             self.temperature,
             self.compiled,
             self.use_ref_model,
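At the module level the new options surface as plain constructor arguments. A construction-only sketch, restricted to the parameters visible in this diff (the full `forward` signature is not shown here, so it is omitted):

```python
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss

# "dr_grpo" requires max_completion_length; the other keyword values shown
# are the defaults visible in this diff.
loss_fn = LigerFusedLinearGRPOLoss(
    chunk_size=1,
    epsilon_low=0.2,
    epsilon_high=0.2,
    loss_type="dr_grpo",
    max_completion_length=512,
    temperature=1.0,
)
```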
liger_kernel/ops/kl_div.py
CHANGED
@@ -6,6 +6,7 @@ import triton.language as tl

 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device


 def get_num_warps(BLOCK_SIZE):
@@ -115,9 +116,12 @@ def _kldiv_kernel_backward(

 def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
     BT, V = y_pred.shape
-    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-    num_warps = get_num_warps(BLOCK_SIZE)
-
+    BLOCK_SIZE = (
+        min(8192, triton.next_power_of_2(V))
+        if infer_device() == "xpu"
+        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    )
+    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)

     grid = (BT,)
     reduction = _str_to_reduction_mode[reduction]
@@ -155,9 +159,12 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]

 def kldiv_backward_triton(target, grad_output, new_grads, log_target):
     BT, V = target.shape
-    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-    num_warps = get_num_warps(BLOCK_SIZE)
-
+    BLOCK_SIZE = (
+        min(8192, triton.next_power_of_2(V))
+        if infer_device() == "xpu"
+        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    )
+    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)

     grid = (BT,)

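Both kernels now pick launch parameters per device: on XPU the block is capped at 8192 elements with a fixed 32 warps, while other devices keep the original `MAX_FUSED_SIZE` cap and warp heuristic. A standalone sketch of the selection logic (the `max_fused_size` value and the warp heuristic are assumptions standing in for the module's `MAX_FUSED_SIZE` and `get_num_warps`, not read from the package):

```python
def kldiv_launch_params(V: int, device: str, max_fused_size: int = 65536 // 4):
    # triton.next_power_of_2 equivalent, inlined so this runs without a GPU.
    next_pow2 = 1 << (max(V, 1) - 1).bit_length()
    if device == "xpu":
        return min(8192, next_pow2), 32
    block_size = min(max_fused_size, next_pow2)
    # Stand-in for get_num_warps: more warps for larger blocks.
    num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16)
    return block_size, num_warps


print(kldiv_launch_params(32000, "cuda"))  # (16384, 16) under the assumed cap
print(kldiv_launch_params(32000, "xpu"))   # (8192, 32)
```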
{liger_kernel_nightly-0.5.6.dev20250408223717.dist-info → liger_kernel_nightly-0.5.6.dev20250411210855.dist-info}/RECORD
CHANGED
@@ -7,10 +7,10 @@ liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNic
 liger_kernel/chunked_loss/dpo_loss.py,sha256=xZwGqS04si9zXyob95SAdalC-hajZg8fWINqiqffN8k,5855
 liger_kernel/chunked_loss/functional.py,sha256=9G3nKm-Bi7uoZRFkL8wwGMl6juDl4bSzDvTa5GHZPzg,955
 liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=ooR-qnZCyWJN935oHCSWLaKKKyaYERyhNczRGi1VOiw,11935
-liger_kernel/chunked_loss/fused_linear_ppo.py,sha256
+liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=AA19cpv6D8mo5RbSK5GRCcZoOSnpxV_Z1eJlAsC5eic,13434
 liger_kernel/chunked_loss/fused_linear_preference.py,sha256=ojB42jYPu0c4ki96Ft-hy7Sf6fh_WikG-aWNrlZzSio,18362
 liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=RiuK3UtRwH9T6jZ36sA8Urj-TVuOLOO2syLg_JOQapY,13437
-liger_kernel/chunked_loss/grpo_loss.py,sha256=
+liger_kernel/chunked_loss/grpo_loss.py,sha256=eh6mErFUZsSQrgRRefuXdk-LG0gS7Rg2r-U9CtbH3eU,10834
 liger_kernel/chunked_loss/jsd_loss.py,sha256=u2ahkuHsbhpNaKcpBCz5gCMDk9ou-P04DHji592dIBo,7067
 liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsmSbQyqwQY,7529
 liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
@@ -23,7 +23,7 @@ liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHu
 liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
 liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
 liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
-liger_kernel/ops/kl_div.py,sha256=
+liger_kernel/ops/kl_div.py,sha256=ZjGdDLKWksHT9dZ0xF_TDgAkj5cuMTwwT5tr9E-_24o,8734
 liger_kernel/ops/layer_norm.py,sha256=vWCyOm-F2GMAilB-ozJcFeUQQLCJoTE_uiXq-_0uYuI,8356
 liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
 liger_kernel/ops/rms_norm.py,sha256=PP27OIBmV9By63i13jot9ylDowW0nuxY_JFIkaPLgL4,12078
@@ -74,9 +74,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.6.
-liger_kernel_nightly-0.5.6.
-liger_kernel_nightly-0.5.6.
-liger_kernel_nightly-0.5.6.
-liger_kernel_nightly-0.5.6.
-liger_kernel_nightly-0.5.6.
+liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/METADATA,sha256=mX6Na52mRBO2g2I7Qqj34QGM17tMQAZLNjE7XX0g9fA,23297
+liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.6.dev20250411210855.dist-info/RECORD,,