liger-kernel-nightly 0.5.6.dev20250411201510__py3-none-any.whl → 0.5.6.dev20250411224032__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/fused_linear_ppo.py +15 -0
- liger_kernel/chunked_loss/grpo_loss.py +33 -1
- liger_kernel/transformers/model/gemma2.py +1 -1
- liger_kernel/transformers/model/gemma3.py +1 -1
- liger_kernel/transformers/model/loss_utils.py +17 -10
- {liger_kernel_nightly-0.5.6.dev20250411201510.dist-info → liger_kernel_nightly-0.5.6.dev20250411224032.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.6.dev20250411201510.dist-info → liger_kernel_nightly-0.5.6.dev20250411224032.dist-info}/RECORD +11 -11
- {liger_kernel_nightly-0.5.6.dev20250411201510.dist-info → liger_kernel_nightly-0.5.6.dev20250411224032.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510.dist-info → liger_kernel_nightly-0.5.6.dev20250411224032.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510.dist-info → liger_kernel_nightly-0.5.6.dev20250411224032.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.6.dev20250411201510.dist-info → liger_kernel_nightly-0.5.6.dev20250411224032.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/fused_linear_ppo.py

@@ -32,6 +32,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
+        loss_type="bnpo",
+        max_completion_length=None,
         temperature=1.0,
         compiled=True,
         use_ref_model=False,
@@ -57,6 +59,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             epsilon_low: Lower bound for clipping the importance sampling ratio
             epsilon_high: Upper bound for clipping the importance sampling ratio
             beta: Weight for the KL penalty
+            loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo")
+            max_completion_length: Maximum completion length required for "dr_grpo"
             temperature: Temperature for the logits
             compiled: Whether to use torch compile
             use_ref_model: Whether to use a reference model
@@ -68,6 +72,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             )
         if ref_per_token_logps is not None and ref_input is not None:
             raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
+        if loss_type == "dr_grpo":
+            assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'"
         # Initialize accumulators
         loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
         grad_weight = torch.zeros_like(weight)  # [V, H]
@@ -84,6 +90,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             epsilon_low=epsilon_low,
             epsilon_high=epsilon_high,
             beta=beta,
+            loss_type=loss_type,
+            max_completion_length=max_completion_length,
             temperature=temperature,
             use_ref_model=use_ref_model,
             ppo_loss_fn=cls.ppo_loss_fn,
@@ -251,6 +259,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
+        loss_type="bnpo",
+        max_completion_length=None,
         temperature=1.0,
         use_ref_model=False,
         ppo_loss_fn=None,
@@ -280,6 +290,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             epsilon_low=epsilon_low,
             epsilon_high=epsilon_high,
             beta=beta,
+            loss_type=loss_type,
+            max_completion_length=max_completion_length,
         )

         return chunk_loss, chunk_metrics
@@ -303,6 +315,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
     def backward(ctx, grad_output, *grad_metrics):
         """Backward pass for PPO loss."""
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
+
         if grad_output != 1.0:
             grad_input = grad_input * grad_output
             grad_weight = grad_weight * grad_output
@@ -328,4 +341,6 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             None,  # grad_compiled
             None,  # grad_use_ref_model
             None,  # grad_chunk_size
+            None,  # grad_loss_type
+            None,  # grad_max_completion_length
         )
liger_kernel/chunked_loss/grpo_loss.py

@@ -27,6 +27,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
+        loss_type="bnpo",  # ["grpo", "bnpo", "dr_grpo"]
+        max_completion_length=None,  # Required for dr_grpo
         **kwargs,
     ):
         """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -61,7 +63,21 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
         # and TRL GRPO implementation
         # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
-        loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+        if loss_type == "grpo":
+            # Average per-sequence loss
+            loss = (
+                (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
+            ).sum() / full_attention_mask.shape[0]
+        elif loss_type == "bnpo":
+            # Batch Normalized Per-token loss (original implementation)
+            loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+        elif loss_type == "dr_grpo":
+            # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
+            if max_completion_length is None:
+                raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
+            loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+        else:
+            raise ValueError(f"Unknown loss type: {loss_type}")

         # Calculate metrics
         metrics = []
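For reference, the three loss_type branches above differ only in the normalizer applied to the masked per-token loss. A minimal, self-contained sketch of the three reductions with toy values (variable names mirror the diff; the real code takes the batch size from full_attention_mask.shape[0]):

import torch

# Toy batch: 2 completions, 4 token positions each (0 = padding).
per_token_loss = torch.tensor([[1.0, 2.0, 3.0, 0.0],
                               [4.0, 0.0, 0.0, 0.0]])
attention_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0],
                               [1.0, 0.0, 0.0, 0.0]])
batch_size = attention_mask.shape[0]
max_completion_length = 4  # only needed for "dr_grpo"

masked = per_token_loss * attention_mask

# "grpo": per-sequence token mean, then averaged over the batch -> 3.0
grpo = (masked.sum(-1) / attention_mask.sum(-1).clamp(min=1.0)).sum() / batch_size

# "bnpo": one mean over every valid token in the batch -> 2.5
bnpo = masked.sum() / attention_mask.sum().clamp(min=1.0)

# "dr_grpo": fixed normalizer batch_size * max_completion_length -> 1.25
dr_grpo = masked.sum() / (batch_size * max_completion_length)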
@@ -91,6 +107,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         beta=0.04,
         epsilon_low=0.2,
         epsilon_high=0.2,
+        loss_type="bnpo",
+        max_completion_length=None,
         temperature=1.0,
         compiled=True,
         use_ref_model=True,
@@ -110,6 +128,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
             ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
             beta (float): Weight for the KL penalty
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
             temperature (float): Temperature for the logits
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
@@ -134,6 +154,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             beta=beta,
             epsilon_low=epsilon_low,
             epsilon_high=epsilon_high,
+            loss_type=loss_type,
+            max_completion_length=max_completion_length,
             temperature=temperature,
             compiled=compiled,
             use_ref_model=use_ref_model,
@@ -161,6 +183,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             None,  # grad_beta
             None,  # grad_epsilon_low
             None,  # grad_epsilon_high
+            None,  # grad_loss_type (string, not differentiable)
+            None,  # grad_max_completion_length (int, not differentiable)
             None,  # grad_temperature
             None,  # grad_compiled
             None,  # grad_use_ref_model
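The extra None entries follow the torch.autograd.Function contract: backward must return one value per argument that forward received, and non-differentiable arguments such as strings and ints get None. A minimal, package-independent sketch of that pattern (not the library's kernel, just the contract):

import torch

class ScaleByMode(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, mode="double"):
        ctx.scale = 2.0 if mode == "double" else 1.0
        return x * ctx.scale

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward arg: a gradient for x, None for the string mode.
        return grad_output * ctx.scale, None

y = ScaleByMode.apply(torch.ones(3, requires_grad=True), "double")
y.sum().backward()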
@@ -179,6 +203,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         chunk_size: int = 1,
         epsilon_low: float = 0.2,
         epsilon_high: float = 0.2,
+        loss_type: str = "bnpo",
+        max_completion_length: int | None = None,
         temperature: float = 1.0,
     ):
         """
@@ -189,6 +215,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             chunk_size (int): Size of chunks for processing.
             epsilon_low (float): Lower bound for the importance sampling ratio.
             epsilon_high (float): Upper bound for the importance sampling ratio.
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
             temperature (float): Temperature for the logits.
         """
         super().__init__()
@@ -198,6 +226,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         self.chunk_size = chunk_size
         self.epsilon_low = epsilon_low
         self.epsilon_high = epsilon_high
+        self.loss_type = loss_type
+        self.max_completion_length = max_completion_length
         self.temperature = temperature

     def forward(
@@ -229,6 +259,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             self.beta,
             self.epsilon_low,
             self.epsilon_high,
+            self.loss_type,
+            self.max_completion_length,
             self.temperature,
             self.compiled,
             self.use_ref_model,
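Taken together, the constructor and forward changes above mean the loss type is chosen when the module is built. A hypothetical usage sketch based only on the keywords visible in these hunks (argument values are illustrative; all other arguments are left at their defaults):

from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss

# "dr_grpo" is the only mode that requires max_completion_length.
loss_fn = LigerFusedLinearGRPOLoss(
    loss_type="dr_grpo",
    max_completion_length=512,  # illustrative value
)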
liger_kernel/transformers/model/gemma2.py

@@ -222,7 +222,7 @@ def lce_forward(
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             hidden_size=self.config.hidden_size,
-
+            final_logit_softcapping=self.config.final_logit_softcapping,
             **loss_kwargs,
         )

liger_kernel/transformers/model/gemma3.py

@@ -112,7 +112,7 @@ def causal_forward(
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             hidden_size=self.config.hidden_size,
-
+            final_logit_softcapping=self.config.final_logit_softcapping,
             **loss_kwargs,
         )

liger_kernel/transformers/model/loss_utils.py

@@ -1,14 +1,18 @@
+from typing import Optional
+
+import torch
 import torch.nn as nn

 import liger_kernel.transformers.functional as F


 def fixed_fused_linear_cross_entropy(
-    hidden_states,
-    lm_head_weight,
-    target,
-    num_items_in_batch: int = None,
+    hidden_states: torch.Tensor,
+    lm_head_weight: torch.Tensor,
+    target: torch.Tensor,
+    num_items_in_batch: Optional[int] = None,
     ignore_index: int = -100,
+    final_logit_softcapping: Optional[float] = None,
     **kwargs,
 ):
     reduction = "sum" if num_items_in_batch is not None else "mean"
@@ -18,7 +22,7 @@ def fixed_fused_linear_cross_entropy(
         target,
         reduction=reduction,
         ignore_index=ignore_index,
-
+        softcap=final_logit_softcapping,
     )
     if reduction == "sum":
         loss = loss / num_items_in_batch
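For context, the softcap argument applies Gemma-style final-logit soft capping inside the fused kernel, together with the cross entropy. A rough eager-mode sketch of the transform it corresponds to, assuming the usual cap * tanh(logits / cap) formulation:

import torch

def softcap_logits(logits: torch.Tensor, softcap: float) -> torch.Tensor:
    # Squash logits smoothly into (-softcap, softcap).
    return softcap * torch.tanh(logits / softcap)

logits = torch.randn(2, 8)
capped = softcap_logits(logits, softcap=30.0)  # 30.0 is a typical Gemma-2 final_logit_softcapping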
@@ -31,15 +35,17 @@ def LigerForCausalLMLoss(
     lm_head_weight,
     labels,
     hidden_size: int,
-    num_items_in_batch: int = None,
+    num_items_in_batch: Optional[int] = None,
     ignore_index: int = -100,
+    shift_labels: Optional[torch.Tensor] = None,
+    final_logit_softcapping: Optional[float] = None,
     **kwargs,
 ):
     # Skip upcast since intermediate values for the loss are all fp32 in kernel
-
-
-
-
+    if shift_labels is None:
+        # Shift so that token < n predict n
+        labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
+        shift_labels = labels[..., 1:].contiguous()

     # Flatten the tokens
     hidden_states = hidden_states.view(-1, hidden_size)
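The new branch only shifts labels when the caller has not already supplied shift_labels; the shift pads one ignore_index on the right and drops position 0, so position i is supervised by token i+1. A toy illustration of the pad-then-shift step (values are made up):

import torch
import torch.nn as nn

ignore_index = -100
labels = torch.tensor([[5, 6, 7, 8]])

padded = nn.functional.pad(labels, (0, 1), value=ignore_index)  # [[5, 6, 7, 8, -100]]
shift_labels = padded[..., 1:].contiguous()                     # [[6, 7, 8, -100]]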
@@ -52,6 +58,7 @@ def LigerForCausalLMLoss(
         shift_labels,
         num_items_in_batch,
         ignore_index,
+        final_logit_softcapping,
         **kwargs,
     )
     return loss
liger_kernel_nightly-0.5.6.dev20250411224032.dist-info/RECORD

@@ -7,10 +7,10 @@ liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNic
 liger_kernel/chunked_loss/dpo_loss.py,sha256=xZwGqS04si9zXyob95SAdalC-hajZg8fWINqiqffN8k,5855
 liger_kernel/chunked_loss/functional.py,sha256=9G3nKm-Bi7uoZRFkL8wwGMl6juDl4bSzDvTa5GHZPzg,955
 liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=ooR-qnZCyWJN935oHCSWLaKKKyaYERyhNczRGi1VOiw,11935
-liger_kernel/chunked_loss/fused_linear_ppo.py,sha256
+liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=AA19cpv6D8mo5RbSK5GRCcZoOSnpxV_Z1eJlAsC5eic,13434
 liger_kernel/chunked_loss/fused_linear_preference.py,sha256=ojB42jYPu0c4ki96Ft-hy7Sf6fh_WikG-aWNrlZzSio,18362
 liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=RiuK3UtRwH9T6jZ36sA8Urj-TVuOLOO2syLg_JOQapY,13437
-liger_kernel/chunked_loss/grpo_loss.py,sha256=
+liger_kernel/chunked_loss/grpo_loss.py,sha256=eh6mErFUZsSQrgRRefuXdk-LG0gS7Rg2r-U9CtbH3eU,10834
 liger_kernel/chunked_loss/jsd_loss.py,sha256=u2ahkuHsbhpNaKcpBCz5gCMDk9ou-P04DHji592dIBo,7067
 liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsmSbQyqwQY,7529
 liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
@@ -56,11 +56,11 @@ liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_J
 liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
 liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 liger_kernel/transformers/model/gemma.py,sha256=-JoHKWjtYPpxHQa6QbCwnzX_cctRZG2ZTsaUv-dmOt4,9816
-liger_kernel/transformers/model/gemma2.py,sha256=
-liger_kernel/transformers/model/gemma3.py,sha256=
+liger_kernel/transformers/model/gemma2.py,sha256=n4MZupFGDMvtnvkvkNhRrxXS3ZF341BVfyLjrOXp10g,10923
+liger_kernel/transformers/model/gemma3.py,sha256=ge3JYchiKvX1G1Zp00jX2zmQK2K7ymJoZAxbb2ggslw,16102
 liger_kernel/transformers/model/llama.py,sha256=UVXQLRW7rCU5vPab54dLNS3ER37eM446peHX00Yz6eA,10493
 liger_kernel/transformers/model/llava.py,sha256=b0pEagjUbu2-eS9xegjyfl1DwIXLwZcNpff55ibaMbA,17601
-liger_kernel/transformers/model/loss_utils.py,sha256=
+liger_kernel/transformers/model/loss_utils.py,sha256=WWAMdiONPaXpIvxyOim_0igLrYh0yyOok5Q9_L9xvZw,1787
 liger_kernel/transformers/model/mistral.py,sha256=RacuKcckuDK6oSraCGD0R0bm-fE0K3q-lkYaAC56C2E,5481
 liger_kernel/transformers/model/mixtral.py,sha256=gLcqGabdv1XnuciS9b-TpkTDnGL8K32Hoq9j2vZMBRY,11502
 liger_kernel/transformers/model/mllama.py,sha256=75mxtmMsNd_q8KlKeawj2uMP6v2KjDuUi4nsUKM5jqA,11308
@@ -74,9 +74,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.6.
-liger_kernel_nightly-0.5.6.
-liger_kernel_nightly-0.5.6.
-liger_kernel_nightly-0.5.6.
-liger_kernel_nightly-0.5.6.
-liger_kernel_nightly-0.5.6.
+liger_kernel_nightly-0.5.6.dev20250411224032.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.6.dev20250411224032.dist-info/METADATA,sha256=cahu5M3U8t37VfXE0xLqBEH8tLSpXZzbbfwlyS86AuE,23297
+liger_kernel_nightly-0.5.6.dev20250411224032.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.6.dev20250411224032.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.6.dev20250411224032.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.6.dev20250411224032.dist-info/RECORD,,