liger-kernel-nightly 0.6.2.dev20250903164350__py3-none-any.whl → 0.6.2.dev20250903164435__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
- liger_kernel/chunked_loss/grpo_loss.py +38 -4
- {liger_kernel_nightly-0.6.2.dev20250903164350.dist-info → liger_kernel_nightly-0.6.2.dev20250903164435.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.6.2.dev20250903164350.dist-info → liger_kernel_nightly-0.6.2.dev20250903164435.dist-info}/RECORD +8 -8
- {liger_kernel_nightly-0.6.2.dev20250903164350.dist-info → liger_kernel_nightly-0.6.2.dev20250903164435.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.2.dev20250903164350.dist-info → liger_kernel_nightly-0.6.2.dev20250903164435.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.2.dev20250903164350.dist-info → liger_kernel_nightly-0.6.2.dev20250903164435.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.2.dev20250903164350.dist-info → liger_kernel_nightly-0.6.2.dev20250903164435.dist-info}/top_level.txt +0 -0
@@ -34,6 +34,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         beta=0.04,
         loss_type="bnpo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         compiled=True,
         use_ref_model=False,
@@ -92,6 +93,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             beta=beta,
             loss_type=loss_type,
             max_completion_length=max_completion_length,
+            importance_sampling_level=importance_sampling_level,
             temperature=temperature,
             use_ref_model=use_ref_model,
             ppo_loss_fn=cls.ppo_loss_fn,
@@ -261,6 +263,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         beta=0.04,
         loss_type="bnpo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         use_ref_model=False,
         ppo_loss_fn=None,
@@ -292,6 +295,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             beta=beta,
             loss_type=loss_type,
             max_completion_length=max_completion_length,
+            importance_sampling_level=importance_sampling_level,
         )

         return chunk_loss, chunk_metrics
@@ -31,6 +31,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         beta=0.04,
         loss_type="bnpo",  # ["grpo", "bnpo", "dr_grpo"]
         max_completion_length=None,  # Required for dr_grpo
+        importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
         **kwargs,
     ):
         """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -50,7 +51,22 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):

         # Compute policy gradient loss with importance sampling ratio
         old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
-        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+        log_ratio = per_token_logps - old_per_token_logps
+
+        if importance_sampling_level == "token":
+            log_importance_weights = log_ratio
+        elif importance_sampling_level == "sequence":
+            log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
+            log_importance_weights = log_importance_weights.unsqueeze(-1)
+        else:
+            raise ValueError(
+                f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
+                "and 'sequence'."
+            )
+
+        # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
+        # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+        coef_1 = torch.exp(log_importance_weights)
         coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
         per_token_loss1 = coef_1 * advantages.unsqueeze(1)
         per_token_loss2 = coef_2 * advantages.unsqueeze(1)
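The hunk above is the core of the change: with importance_sampling_level="sequence", the per-token log-ratios are averaged over the completion so every token in a sequence shares one GSPO-style weight. A minimal standalone sketch (toy shapes and values, not library code) of how the two settings differ:

    import torch

    # Toy shapes: B sequences, T completion tokens (hypothetical values).
    B, T = 2, 5
    per_token_logps = torch.randn(B, T)
    old_per_token_logps = torch.randn(B, T)
    attention_mask = torch.ones(B, T)

    log_ratio = per_token_logps - old_per_token_logps

    # "token" level (GRPO default): one importance weight per token, shape (B, T).
    token_weights = torch.exp(log_ratio)

    # "sequence" level (GSPO): mask-weighted mean of the per-token log-ratios,
    # then one shared weight per sequence, kept as shape (B, 1) for broadcasting.
    seq_log_w = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
    seq_weights = torch.exp(seq_log_w).unsqueeze(-1)

    print(token_weights.shape, seq_weights.shape)  # torch.Size([2, 5]) torch.Size([2, 1])

Because the (B, 1) sequence weight broadcasts against advantages.unsqueeze(1), the downstream clipping and loss lines in the hunk work unchanged for both levels.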
@@ -85,9 +101,19 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         metrics = []
         if beta != 0.0:
             metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
-        is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
-            (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
-        )
+
+        # Adjust clipping metric calculation based on importance sampling level
+        if importance_sampling_level == "token":
+            is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+                (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
+            )
+        else:  # sequence level
+            # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
+            is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
+                (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
+            )
+            is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
+
         metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
         return loss, metrics

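The clip-fraction metric above must account for the new (B, 1) sequence-level shape: the per-sequence clipped flag is broadcast back to token granularity so the metric stays comparable with the token-level variant. A small self-contained sketch of that branch with made-up values:

    import torch

    B, T = 2, 5
    coef_1 = torch.rand(B, 1) * 2   # sequence-level importance weights, shape (B, 1)
    advantages = torch.randn(B)     # one advantage per sequence, shape (B,)
    attention_mask = torch.ones(B, T)
    epsilon_low, epsilon_high = 0.2, 0.2

    # A ratio counts as "clipped" only when it left the trust region on the
    # side where clipping actually changes the loss for that advantage sign.
    is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
        (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
    )
    # Broadcast the per-sequence flag to every completion token.
    is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
    clip_ratio = (is_clipped * attention_mask).sum() / attention_mask.sum().clamp(min=1.0)
    print(clip_ratio)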
@@ -111,6 +137,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         epsilon_high=0.2,
         loss_type="bnpo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         compiled=True,
         use_ref_model=True,
@@ -132,6 +159,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             beta (float): Weight for the KL penalty
             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
@@ -162,6 +190,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             compiled=compiled,
             use_ref_model=use_ref_model,
             chunk_size=chunk_size,
+            importance_sampling_level=importance_sampling_level,
         )

     @staticmethod
@@ -187,6 +216,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             None,  # grad_epsilon_high
             None,  # grad_loss_type (string, not differentiable)
             None,  # grad_max_completion_length (int, not differentiable)
+            None,  # grad_importance_sampling_level (string, not differentiable)
             None,  # grad_temperature
             None,  # grad_compiled
             None,  # grad_use_ref_model
@@ -207,6 +237,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         epsilon_high: float = 0.2,
         loss_type: str = "bnpo",
         max_completion_length: Optional[int] = None,
+        importance_sampling_level: str = "token",
         temperature: float = 1.0,
     ):
         """
@@ -219,6 +250,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             epsilon_high (float): Upper bound for the importance sampling ratio.
             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits.
         """
         super().__init__()
@@ -230,6 +262,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         self.epsilon_high = epsilon_high
         self.loss_type = loss_type
         self.max_completion_length = max_completion_length
+        self.importance_sampling_level = importance_sampling_level
         self.temperature = temperature

     def forward(
@@ -263,6 +296,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             self.epsilon_high,
             self.loss_type,
             self.max_completion_length,
+            self.importance_sampling_level,
             self.temperature,
             self.compiled,
             self.use_ref_model,
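Taken together, the diff threads one new constructor knob through the public module. A minimal usage sketch, assuming only the import path (confirmed by the RECORD entries below) and the class name and parameter shown in this diff; the forward() call is untouched by this change and not reproduced here:

    from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss

    # Token-level weighting: the default, matching prior GRPO behavior.
    grpo_loss = LigerFusedLinearGRPOLoss(importance_sampling_level="token")

    # Sequence-level weighting (GSPO-style), newly enabled by this diff.
    gspo_loss = LigerFusedLinearGRPOLoss(importance_sampling_level="sequence")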
@@ -8,10 +8,10 @@ liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNic
 liger_kernel/chunked_loss/dpo_loss.py,sha256=I83khNs3QQjuhr8U3NIOAACkbse6DNiBV-TulPZ0lXw,9006
 liger_kernel/chunked_loss/functional.py,sha256=-XPDbLml9dHmvoSU2VNTUrBDFehuzvuAGPikVetBMtI,1132
 liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=ooR-qnZCyWJN935oHCSWLaKKKyaYERyhNczRGi1VOiw,11935
-liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=
+liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=ZjpNP5VC-tXXIKb4AckkQ3iWWQeej-JoG4StJq3N0wg,13650
 liger_kernel/chunked_loss/fused_linear_preference.py,sha256=FIH85uUXAOgYx5Ax8MjFhJHVu-2pKtY7wSegd0zSyyY,18336
 liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=RiuK3UtRwH9T6jZ36sA8Urj-TVuOLOO2syLg_JOQapY,13437
-liger_kernel/chunked_loss/grpo_loss.py,sha256=
+liger_kernel/chunked_loss/grpo_loss.py,sha256=SkZuKoW8K94UbWR-OtfopsQkuQ8tFOr_90AGR6_Mhes,12844
 liger_kernel/chunked_loss/jsd_loss.py,sha256=gRhnmB8xwuz7FcMJi5v5eyBsq01owaCbcyyrF4rYtY0,7133
 liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsmSbQyqwQY,7529
 liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
@@ -96,9 +96,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.6.2.
-liger_kernel_nightly-0.6.2.
-liger_kernel_nightly-0.6.2.
-liger_kernel_nightly-0.6.2.
-liger_kernel_nightly-0.6.2.
-liger_kernel_nightly-0.6.2.
+liger_kernel_nightly-0.6.2.dev20250903164435.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.6.2.dev20250903164435.dist-info/METADATA,sha256=BgiSTSMznb0cvZyFqU68T0sEIAOBcf9hvuO6jIPCcC8,24504
+liger_kernel_nightly-0.6.2.dev20250903164435.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.6.2.dev20250903164435.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.6.2.dev20250903164435.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.6.2.dev20250903164435.dist-info/RECORD,,