liger-kernel-nightly 0.6.2.dev20250830153353__py3-none-any.whl → 0.6.2.dev20250903164435__py3-none-any.whl

This diff shows the changes between two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between the package versions as they appear in their respective public registries.
@@ -34,6 +34,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
34
34
  beta=0.04,
35
35
  loss_type="bnpo",
36
36
  max_completion_length=None,
37
+ importance_sampling_level="token",
37
38
  temperature=1.0,
38
39
  compiled=True,
39
40
  use_ref_model=False,
@@ -92,6 +93,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
92
93
  beta=beta,
93
94
  loss_type=loss_type,
94
95
  max_completion_length=max_completion_length,
96
+ importance_sampling_level=importance_sampling_level,
95
97
  temperature=temperature,
96
98
  use_ref_model=use_ref_model,
97
99
  ppo_loss_fn=cls.ppo_loss_fn,
@@ -261,6 +263,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
261
263
  beta=0.04,
262
264
  loss_type="bnpo",
263
265
  max_completion_length=None,
266
+ importance_sampling_level="token",
264
267
  temperature=1.0,
265
268
  use_ref_model=False,
266
269
  ppo_loss_fn=None,
@@ -292,6 +295,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
292
295
  beta=beta,
293
296
  loss_type=loss_type,
294
297
  max_completion_length=max_completion_length,
298
+ importance_sampling_level=importance_sampling_level,
295
299
  )
296
300
 
297
301
  return chunk_loss, chunk_metrics
@@ -31,6 +31,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
31
31
  beta=0.04,
32
32
  loss_type="bnpo", # ["grpo", "bnpo", "dr_grpo"]
33
33
  max_completion_length=None, # Required for dr_grpo
34
+ importance_sampling_level="token", # ["token", "sequence"] - new parameter for GSPO
34
35
  **kwargs,
35
36
  ):
36
37
  """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -50,7 +51,22 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
50
51
 
51
52
  # Compute policy gradient loss with importance sampling ratio
52
53
  old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
53
- coef_1 = torch.exp(per_token_logps - old_per_token_logps)
54
+ log_ratio = per_token_logps - old_per_token_logps
55
+
56
+ if importance_sampling_level == "token":
57
+ log_importance_weights = log_ratio
58
+ elif importance_sampling_level == "sequence":
59
+ log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
60
+ log_importance_weights = log_importance_weights.unsqueeze(-1)
61
+ else:
62
+ raise ValueError(
63
+ f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
64
+ "and 'sequence'."
65
+ )
66
+
67
+ # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
68
+ # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
69
+ coef_1 = torch.exp(log_importance_weights)
54
70
  coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
55
71
  per_token_loss1 = coef_1 * advantages.unsqueeze(1)
56
72
  per_token_loss2 = coef_2 * advantages.unsqueeze(1)
@@ -85,9 +101,19 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
85
101
  metrics = []
86
102
  if beta != 0.0:
87
103
  metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
88
- is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
89
- (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
90
- )
104
+
105
+ # Adjust clipping metric calculation based on importance sampling level
106
+ if importance_sampling_level == "token":
107
+ is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
108
+ (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
109
+ )
110
+ else: # sequence level
111
+ # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
112
+ is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
113
+ (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
114
+ )
115
+ is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
116
+
91
117
  metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
92
118
  return loss, metrics
93
119
 
@@ -111,6 +137,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
111
137
  epsilon_high=0.2,
112
138
  loss_type="bnpo",
113
139
  max_completion_length=None,
140
+ importance_sampling_level="token",
114
141
  temperature=1.0,
115
142
  compiled=True,
116
143
  use_ref_model=True,
@@ -132,6 +159,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
132
159
  beta (float): Weight for the KL penalty
133
160
  loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
134
161
  max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
162
+ importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
135
163
  temperature (float): Temperature for the logits
136
164
  compiled (bool): Whether to use torch compile
137
165
  use_ref_model (bool): Whether to use a reference model
@@ -162,6 +190,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
162
190
  compiled=compiled,
163
191
  use_ref_model=use_ref_model,
164
192
  chunk_size=chunk_size,
193
+ importance_sampling_level=importance_sampling_level,
165
194
  )
166
195
 
167
196
  @staticmethod
@@ -187,6 +216,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
187
216
  None, # grad_epsilon_high
188
217
  None, # grad_loss_type (string, not differentiable)
189
218
  None, # grad_max_completion_length (int, not differentiable)
219
+ None, # grad_importance_sampling_level (string, not differentiable)
190
220
  None, # grad_temperature
191
221
  None, # grad_compiled
192
222
  None, # grad_use_ref_model
@@ -207,6 +237,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
207
237
  epsilon_high: float = 0.2,
208
238
  loss_type: str = "bnpo",
209
239
  max_completion_length: Optional[int] = None,
240
+ importance_sampling_level: str = "token",
210
241
  temperature: float = 1.0,
211
242
  ):
212
243
  """
@@ -219,6 +250,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
219
250
  epsilon_high (float): Upper bound for the importance sampling ratio.
220
251
  loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
221
252
  max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
253
+ importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
222
254
  temperature (float): Temperature for the logits.
223
255
  """
224
256
  super().__init__()
@@ -230,6 +262,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
230
262
  self.epsilon_high = epsilon_high
231
263
  self.loss_type = loss_type
232
264
  self.max_completion_length = max_completion_length
265
+ self.importance_sampling_level = importance_sampling_level
233
266
  self.temperature = temperature
234
267
 
235
268
  def forward(
@@ -263,6 +296,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
263
296
  self.epsilon_high,
264
297
  self.loss_type,
265
298
  self.max_completion_length,
299
+ self.importance_sampling_level,
266
300
  self.temperature,
267
301
  self.compiled,
268
302
  self.use_ref_model,
@@ -9,7 +9,7 @@ from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunct
9
9
 
10
10
 
11
11
  class LigerMultiTokenAttention(nn.Module):
12
- """
12
+ r"""
13
13
  Multi-Token Attention:
14
14
  out = mask_{0}(conv2d(softmax(mask_{-\inf}(scores))))
15
15
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.6.2.dev20250830153353
3
+ Version: 0.6.2.dev20250903164435
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -8,10 +8,10 @@ liger_kernel/chunked_loss/cpo_loss.py,sha256=Gzz1eU4kgcbdubFVRy55e8A1Cr-r45UgNic
8
8
  liger_kernel/chunked_loss/dpo_loss.py,sha256=I83khNs3QQjuhr8U3NIOAACkbse6DNiBV-TulPZ0lXw,9006
9
9
  liger_kernel/chunked_loss/functional.py,sha256=-XPDbLml9dHmvoSU2VNTUrBDFehuzvuAGPikVetBMtI,1132
10
10
  liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=ooR-qnZCyWJN935oHCSWLaKKKyaYERyhNczRGi1VOiw,11935
11
- liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=AA19cpv6D8mo5RbSK5GRCcZoOSnpxV_Z1eJlAsC5eic,13434
11
+ liger_kernel/chunked_loss/fused_linear_ppo.py,sha256=ZjpNP5VC-tXXIKb4AckkQ3iWWQeej-JoG4StJq3N0wg,13650
12
12
  liger_kernel/chunked_loss/fused_linear_preference.py,sha256=FIH85uUXAOgYx5Ax8MjFhJHVu-2pKtY7wSegd0zSyyY,18336
13
13
  liger_kernel/chunked_loss/fused_linear_unpaired_preference.py,sha256=RiuK3UtRwH9T6jZ36sA8Urj-TVuOLOO2syLg_JOQapY,13437
14
- liger_kernel/chunked_loss/grpo_loss.py,sha256=kuqHkYV383sUxqJN-DMsfADHi2hxHVyKx5S24TNc8bQ,10866
14
+ liger_kernel/chunked_loss/grpo_loss.py,sha256=SkZuKoW8K94UbWR-OtfopsQkuQ8tFOr_90AGR6_Mhes,12844
15
15
  liger_kernel/chunked_loss/jsd_loss.py,sha256=gRhnmB8xwuz7FcMJi5v5eyBsq01owaCbcyyrF4rYtY0,7133
16
16
  liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsmSbQyqwQY,7529
17
17
  liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
@@ -59,7 +59,7 @@ liger_kernel/transformers/kl_div.py,sha256=WLffFbh1EExD2Eb1F7lN11fo9JJC-0751WJjZ
59
59
  liger_kernel/transformers/layer_norm.py,sha256=c9pk3PEasOKYR0rhe5e5nNrnYKVCEW4VC8S6LpCq9EQ,906
60
60
  liger_kernel/transformers/llama4_rope.py,sha256=kS6PSHEwf3dS7hD7C7p8S0geugx2EMCiP0h0F7LsUoY,3639
61
61
  liger_kernel/transformers/monkey_patch.py,sha256=pG3Yf0fMg4_0pAncc2wLtpdfXvmC5CROpNJ43-MmElM,93075
62
- liger_kernel/transformers/multi_token_attention.py,sha256=l9VDICK0dfmifUDW668hGscP8AHq2rYcM2oGUa3baRQ,1751
62
+ liger_kernel/transformers/multi_token_attention.py,sha256=K3NIY9_5TPgZ4_Rahn0xnkMXxD_fmlJHK4CWGYvGQp0,1752
63
63
  liger_kernel/transformers/qwen2vl_mrope.py,sha256=5EwSqrMdsL9MYspeBMXBsNJKvH0MOmRrtJXAJlnnlOI,1047
64
64
  liger_kernel/transformers/rms_norm.py,sha256=vkekcvTeWY8vL4H6hg3t0XeY0Ew_3OFMPHuzqlxPPVw,2719
65
65
  liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
@@ -96,9 +96,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
96
96
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
97
97
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
98
98
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
99
- liger_kernel_nightly-0.6.2.dev20250830153353.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
100
- liger_kernel_nightly-0.6.2.dev20250830153353.dist-info/METADATA,sha256=pdvNhCMdDJLC-ipmXC0fO7Nw_8EP9e0oNfbnU_TCPVg,24504
101
- liger_kernel_nightly-0.6.2.dev20250830153353.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
102
- liger_kernel_nightly-0.6.2.dev20250830153353.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
103
- liger_kernel_nightly-0.6.2.dev20250830153353.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
104
- liger_kernel_nightly-0.6.2.dev20250830153353.dist-info/RECORD,,
99
+ liger_kernel_nightly-0.6.2.dev20250903164435.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
100
+ liger_kernel_nightly-0.6.2.dev20250903164435.dist-info/METADATA,sha256=BgiSTSMznb0cvZyFqU68T0sEIAOBcf9hvuO6jIPCcC8,24504
101
+ liger_kernel_nightly-0.6.2.dev20250903164435.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
102
+ liger_kernel_nightly-0.6.2.dev20250903164435.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
103
+ liger_kernel_nightly-0.6.2.dev20250903164435.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
104
+ liger_kernel_nightly-0.6.2.dev20250903164435.dist-info/RECORD,,