liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic.

Files changed (114)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +304 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +21 -4
  18. liger_kernel/ops/cross_entropy.py +235 -84
  19. liger_kernel/ops/dyt.py +157 -0
  20. liger_kernel/ops/experimental/embedding.py +1 -3
  21. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  22. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  23. liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
  24. liger_kernel/ops/fused_linear_jsd.py +17 -34
  25. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  26. liger_kernel/ops/geglu.py +7 -18
  27. liger_kernel/ops/group_norm.py +305 -0
  28. liger_kernel/ops/grpo_loss.py +310 -0
  29. liger_kernel/ops/jsd.py +46 -21
  30. liger_kernel/ops/kl_div.py +23 -19
  31. liger_kernel/ops/layer_norm.py +150 -86
  32. liger_kernel/ops/llama4_rope.py +225 -0
  33. liger_kernel/ops/multi_token_attention.py +207 -0
  34. liger_kernel/ops/poly_norm.py +386 -0
  35. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  36. liger_kernel/ops/rms_norm.py +314 -84
  37. liger_kernel/ops/rope.py +32 -34
  38. liger_kernel/ops/softmax.py +201 -0
  39. liger_kernel/ops/sparsemax.py +179 -0
  40. liger_kernel/ops/swiglu.py +5 -9
  41. liger_kernel/ops/tiled_mlp.py +136 -0
  42. liger_kernel/ops/tvd.py +207 -0
  43. liger_kernel/ops/utils.py +8 -4
  44. liger_kernel/transformers/__init__.py +199 -24
  45. liger_kernel/transformers/auto_model.py +6 -13
  46. liger_kernel/transformers/cross_entropy.py +33 -20
  47. liger_kernel/transformers/dyt.py +22 -0
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -3
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +291 -13
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -4
  57. liger_kernel/transformers/group_norm.py +50 -0
  58. liger_kernel/transformers/grpo_loss.py +98 -0
  59. liger_kernel/transformers/jsd.py +2 -7
  60. liger_kernel/transformers/kl_div.py +1 -3
  61. liger_kernel/transformers/layer_norm.py +3 -9
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/falcon_h1.py +122 -0
  64. liger_kernel/transformers/model/gemma.py +77 -77
  65. liger_kernel/transformers/model/gemma2.py +283 -0
  66. liger_kernel/transformers/model/gemma3.py +331 -0
  67. liger_kernel/transformers/model/glm4.py +141 -0
  68. liger_kernel/transformers/model/glm4v.py +163 -0
  69. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  70. liger_kernel/transformers/model/internvl.py +157 -0
  71. liger_kernel/transformers/model/llama.py +128 -79
  72. liger_kernel/transformers/model/llama4.py +121 -0
  73. liger_kernel/transformers/model/llava.py +344 -0
  74. liger_kernel/transformers/model/loss_utils.py +95 -0
  75. liger_kernel/transformers/model/mistral.py +68 -64
  76. liger_kernel/transformers/model/mixtral.py +75 -91
  77. liger_kernel/transformers/model/mllama.py +63 -68
  78. liger_kernel/transformers/model/olmo2.py +141 -0
  79. liger_kernel/transformers/model/output_classes.py +147 -0
  80. liger_kernel/transformers/model/paligemma.py +432 -0
  81. liger_kernel/transformers/model/phi3.py +59 -213
  82. liger_kernel/transformers/model/qwen2.py +75 -72
  83. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  84. liger_kernel/transformers/model/qwen2_vl.py +78 -98
  85. liger_kernel/transformers/model/qwen3.py +136 -0
  86. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  87. liger_kernel/transformers/model/qwen3_next.py +146 -0
  88. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  89. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  90. liger_kernel/transformers/model/smollm3.py +199 -0
  91. liger_kernel/transformers/model/smolvlm.py +158 -0
  92. liger_kernel/transformers/monkey_patch.py +2106 -289
  93. liger_kernel/transformers/multi_token_attention.py +64 -0
  94. liger_kernel/transformers/poly_norm.py +42 -0
  95. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  96. liger_kernel/transformers/rms_norm.py +57 -6
  97. liger_kernel/transformers/rope.py +45 -2
  98. liger_kernel/transformers/softmax.py +12 -0
  99. liger_kernel/transformers/sparsemax.py +16 -0
  100. liger_kernel/transformers/swiglu.py +23 -8
  101. liger_kernel/transformers/tiled_mlp.py +133 -0
  102. liger_kernel/transformers/trainer/__init__.py +4 -0
  103. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  104. liger_kernel/transformers/tvd.py +13 -0
  105. liger_kernel/triton/__init__.py +1 -3
  106. liger_kernel/triton/monkey_patch.py +1 -3
  107. liger_kernel/utils.py +71 -0
  108. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
  109. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  110. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
  111. liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
  112. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/fused_linear_unpaired_preference.py
@@ -0,0 +1,341 @@
+ from abc import abstractmethod
+ from functools import partial
+
+ import torch
+
+ from torch.nn import functional as F
+
+
+ class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
+     @abstractmethod
+     def preference_loss_fn(*args, **kwargs):
+         """
+         To be extended by subclasses.
+         """
+         raise NotImplementedError("Preference loss function must be implemented.")
+
+     @staticmethod
+     def forward(
+         cls,
+         ctx,
+         _input,
+         weight,
+         target,
+         preference_labels,
+         bias=None,
+         chunk_size=1,
+         ignore_index=-100,
+         compiled=True,
+         use_ref_model=False,
+         ref_input=None,
+         ref_weight=None,
+         ref_bias=None,
+         average_log_prob=False,
+         **loss_kwargs,
+     ):
+         """
+         Base class for fused linear layer with unpaired preference loss like KTO
+         Expects _input to be stacked with chosen and rejected inputs on the batch dimension.
+
+         The mental model is:
+
+         forward()
+         ├── Loop over chunks
+             └── compute_loss()
+                 ├── chunk_forward()  # Compute logits and log probs
+                 └── prefer_loss()  # Calculate preference loss
+
+         Args:
+             _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
+             weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+             target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
+             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+             loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+             chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
+             ignore_index (int): Index to ignore for loss computation.
+             beta (float): Weight for the preference loss.
+             compiled (bool): Whether to use torch compile for chunk accumulation.
+             use_ref_model (bool): Whether to use a reference model for the alignment loss.
+             preference_labels (torch.Tensor): Boolean tensor indicating chosen (True) vs rejected (False) examples.
+                 Shape: (batch_size,).
+             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+             average_log_prob (bool): Whether to average the log probability per non-masked token.
+             loss_kwargs (dict): Other possible arguments that a loss function might need
+         """
+         # TODO: Tune CHUNK_SIZE to fully utilize the GPU
+         CHUNK_SIZE = chunk_size
+
+         # Gradients to be accumulated
+         grad_inputs = []
+         grad_weight = torch.zeros_like(weight)
+         grad_bias = torch.zeros_like(bias) if bias is not None else None
+
+         # Loss to be accumulated
+         loss_acc = torch.zeros((), device=_input.device)
+
+         # Metrics to be recorded
+         chosen_logps_sum = torch.zeros((), device=_input.device)
+         rejected_logps_sum = torch.zeros((), device=_input.device)
+         chosen_logits_sum = torch.zeros((), device=_input.device)
+         rejected_logits_sum = torch.zeros((), device=_input.device)
+         aggregated_aux_outputs = []
+
+         compute_loss = partial(
+             LigerFusedLinearUnpairedPreferenceBase._compute_loss,
+             preference_loss_fn=cls.preference_loss_fn,
+             full_target=target,
+             ignore_index=ignore_index,
+             use_ref_model=use_ref_model,
+             ref_weight=ref_weight,
+             ref_bias=ref_bias,
+             average_log_prob=average_log_prob,
+             **loss_kwargs,
+         )
+
+         def fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk):
+             """
+             Fused forward and backward pass for a chunk of input and target.
+             """
+             argnums = (0, 1, 4) if bias is not None else (0, 1)
+             return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
+                 input_chunk,
+                 weight,
+                 target_chunk,
+                 preference_labels_chunk,
+                 bias,
+                 ref_input_chunk=ref_input_chunk,
+             )
+
+         def accumulate_chunk(
+             input_chunk,
+             target_chunk,
+             preference_labels_chunk=None,
+             ref_input_chunk=None,
+         ):
+             (
+                 (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias),
+                 (
+                     chunk_loss,
+                     (
+                         chunk_chosen_logps_sum,
+                         chunk_rejected_logps_sum,
+                         chunk_chosen_logits_sum,
+                         chunk_rejected_logits_sum,
+                         *aux_outputs,
+                     ),
+                 ),
+             ) = fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)
+             if bias is not None:
+                 grad_bias.add_(chunk_grad_bias[0])  # accumulate bias gradient
+
+             # Accumulate gradients
+             grad_weight.add_(chunk_grad_weight)
+             grad_inputs.append(chunk_grad_input)
+
+             # Accumulate loss
+             loss_acc.add_(chunk_loss)
+
+             # Accumulate metrics
+             chosen_logps_sum.add_(chunk_chosen_logps_sum)
+             rejected_logps_sum.add_(chunk_rejected_logps_sum)
+             chosen_logits_sum.add_(chunk_chosen_logits_sum)
+             rejected_logits_sum.add_(chunk_rejected_logits_sum)
+
+             # aux_outputs
+             # Initialize storage for aux_outputs
+             if len(aggregated_aux_outputs) == 0:
+                 for aux in aux_outputs:
+                     aggregated_aux_outputs.append(torch.zeros((), device=aux.device))
+
+             # Process each aux_output
+             for i, aux in enumerate(aux_outputs):
+                 if aux.ndim == 0:
+                     aggregated_aux_outputs[i].add_(aux)
+
+         if compiled:
+             fused_fwd_bwd = torch.compile(fused_fwd_bwd)
+
+         # When not paired, use labels to separate chosen and rejected
+         assert preference_labels is not None, "preference_labels must be provided for unpaired preference loss"
+
+         chunks = max(1, _input.shape[0] // CHUNK_SIZE)
+         _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
+         _target_chunks = torch.chunk(target, chunks=chunks, dim=0)
+         _preference_labels_chunks = torch.chunk(preference_labels, chunks=chunks, dim=0)
+
+         if use_ref_model:
+             _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0)
+
+         for (
+             input_chunk,
+             target_chunk,
+             ref_input_chunk,
+             preference_labels_chunk,
+         ) in zip(
+             _input_chunks,
+             _target_chunks,
+             (_ref_input_chunks if use_ref_model else [None] * len(_input_chunks)),
+             _preference_labels_chunks,
+         ):
+             # mark input_chunk, target_chunk, and target dimension 1 (sequence length) as dynamic to prevent torch.compile recompilation
+             torch._dynamo.mark_dynamic(input_chunk, 1)
+             torch._dynamo.mark_dynamic(target_chunk, 1)
+             torch._dynamo.mark_dynamic(target, 1)
+             torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
+             torch._dynamo.mark_dynamic(preference_labels_chunk, 1)
+
+             # accumulate loss, gradients, and metrics
+             accumulate_chunk(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)
+
+         # Aggregate aux outputs lists into tensors
+         for i, aux in enumerate(aggregated_aux_outputs):
+             if isinstance(aux, list):
+                 aggregated_aux_outputs[i] = torch.cat(aux, dim=0)
+
+         ctx.save_for_backward(
+             torch.cat(grad_inputs, dim=0),
+             grad_weight,
+             grad_bias,
+         )
+
+         return_vars = (
+             chosen_logps_sum,
+             rejected_logps_sum,
+             chosen_logits_sum,
+             rejected_logits_sum,
+         )
+
+         return loss_acc, (*return_vars, *aggregated_aux_outputs)
+
+     @staticmethod
+     def backward(ctx, *grad_output):
+         grad_input, grad_weight, grad_bias = ctx.saved_tensors
+         if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)):
+             grad_input = grad_input * grad_output[0][0]
+             grad_weight = grad_weight * grad_output[0][0]
+             grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
+
+         return grad_input, grad_weight, None, None, grad_bias
+
+     @staticmethod
+     def chunk_forward(
+         input_chunk,
+         weight,
+         target_chunk,
+         preference_labels_chunk,
+         bias=None,
+         ignore_index=-100,
+         average_log_prob=False,
+     ):
+         logits_chunk = input_chunk @ weight.t()
+         if bias is not None:
+             logits_chunk = logits_chunk + bias
+         log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+         loss_mask_chunk = target_chunk != ignore_index
+         label_chunk = torch.where(loss_mask_chunk, target_chunk, 0)
+
+         per_token_logps_chunk = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
+         if average_log_prob:
+             log_probs = (per_token_logps_chunk * loss_mask_chunk).sum(-1) / loss_mask_chunk.sum(-1)
+         else:
+             log_probs = (per_token_logps_chunk * loss_mask_chunk).sum(-1)
+
+         chosen_logps_sum = (log_probs * preference_labels_chunk.unsqueeze(1)).sum()
+         rejected_logps_sum = (log_probs * (~preference_labels_chunk).unsqueeze(1)).sum()
+
+         chosen_logits_sum = (logits_chunk * preference_labels_chunk.unsqueeze(1)).sum()
+         rejected_logits_sum = (logits_chunk * (~preference_labels_chunk).unsqueeze(1)).sum()
+
+         return (
+             log_probs,
+             chosen_logps_sum,
+             rejected_logps_sum,
+             chosen_logits_sum,
+             rejected_logits_sum,
+         )
+
+     @staticmethod
+     def _compute_loss(
+         input_chunk,
+         weight,
+         target_chunk,
+         preference_labels_chunk,
+         bias=None,
+         preference_loss_fn=None,
+         full_target=None,
+         ignore_index=-100,
+         use_ref_model=False,
+         ref_input_chunk=None,
+         ref_weight=None,
+         ref_bias=None,
+         average_log_prob=False,
+         **loss_kwargs,
+     ):
+         """
+         Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
+         Args:
+             preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+             input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
+             weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+             target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
+             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+             full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+             ignore_index (int): Index to ignore for loss computation.
+             use_ref_model (bool): Whether to use a reference model for the alignment loss.
+             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+             average_log_prob (bool): Whether to average the log probability per non-masked token.
+             loss_kwargs (dict): Additional arguments for the loss function.
+         """
+         (
+             log_prob_chunk,
+             chosen_logps_sum,
+             rejected_logps_sum,
+             chosen_logits_sum,
+             rejected_logits_sum,
+         ) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
+             input_chunk,
+             weight,
+             target_chunk,
+             preference_labels_chunk,
+             bias=bias,
+             ignore_index=ignore_index,
+             average_log_prob=average_log_prob,
+         )
+
+         if use_ref_model:
+             with torch.no_grad():
+                 (
+                     ref_log_prob_chunk,
+                     _,
+                     _,
+                     _,
+                     _,
+                 ) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
+                     ref_input_chunk,
+                     ref_weight,
+                     target_chunk,
+                     preference_labels_chunk,
+                     ref_bias,
+                     ignore_index=ignore_index,
+                     average_log_prob=average_log_prob,
+                 )
+             loss_kwargs["ref_log_prob_chunk"] = ref_log_prob_chunk
+
+         preference_loss_outputs = preference_loss_fn(
+             log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs
+         )
+         if isinstance(preference_loss_outputs, tuple):
+             preference_loss_chunk, *aux_outputs = preference_loss_outputs
+         else:
+             preference_loss_chunk, aux_outputs = preference_loss_outputs, []
+
+         return_vars = (
+             chosen_logps_sum,
+             rejected_logps_sum,
+             chosen_logits_sum,
+             rejected_logits_sum,
+         )
+
+         return preference_loss_chunk, (*return_vars, *aux_outputs)
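
Editor's note: the core pattern in the file above is chunked gradient accumulation with torch.func.grad_and_value against a shared linear head, with weight gradients summed and input gradients concatenated across chunks. The following is a minimal standalone sketch of that pattern (plain PyTorch, no liger imports; all names, shapes, and the toy unpaired objective are illustrative and not the package's KTO loss):

import torch
import torch.nn.functional as F


def chunk_loss(input_chunk, weight, target_chunk, labels_chunk):
    # Linear head -> log-probs -> per-sequence summed log-prob of the target tokens.
    logits = input_chunk @ weight.t()
    logps = F.log_softmax(logits.float(), dim=-1)
    per_token = logps.gather(-1, target_chunk.unsqueeze(-1)).squeeze(-1)
    seq_logps = per_token.sum(-1)
    # Toy unpaired preference objective: push chosen sequences up, rejected down.
    signs = labels_chunk.float() * 2 - 1
    return -(signs * seq_logps).mean()


B, T, H, V, CHUNK = 8, 16, 32, 64, 2
x = torch.randn(B, T, H)
w = torch.randn(V, H)
y = torch.randint(0, V, (B, T))
labels = torch.rand(B) > 0.5  # True = chosen, False = rejected

grad_x_chunks, grad_w = [], torch.zeros_like(w)
loss_acc = torch.zeros(())
fwd_bwd = torch.func.grad_and_value(chunk_loss, argnums=(0, 1))

for xc, yc, lc in zip(x.split(CHUNK), y.split(CHUNK), labels.split(CHUNK)):
    (gx, gw), loss = fwd_bwd(xc, w, yc, lc)
    grad_x_chunks.append(gx)   # per-chunk input grads are concatenated at the end
    grad_w.add_(gw)            # weight grads are summed across chunks
    loss_acc.add_(loss.detach())

grad_x = torch.cat(grad_x_chunks, dim=0)
print(loss_acc, grad_x.shape, grad_w.shape)

The point of chunking is that the full (batch, seq_len, vocab) logits tensor never has to exist at once; each chunk's logits and gradients are materialized, reduced, and discarded before the next chunk.
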
liger_kernel/chunked_loss/grpo_loss.py
@@ -0,0 +1,304 @@
+ from typing import Optional
+
+ import torch
+
+ from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
+
+
+ def k3_loss_fn(log_p, log_q):
+     # computes k3 estimate of KL[q, p]
+     # ref: http://joschu.net/blog/kl-approx.html
+     return torch.exp(log_p - log_q) - (log_p - log_q) - 1.0
+
+
+ def clip_coef_fn(coef, epsilon_low, epsilon_high):
+     return torch.clamp(coef, 1 - epsilon_low, 1 + epsilon_high)
+
+
+ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
+     @staticmethod
+     def ppo_loss_fn(
+         log_probs,
+         selected_token_ids,
+         attention_mask,
+         advantages,
+         full_attention_mask,
+         ref_per_token_logps=None,  # shape: [chunk_size, seq_len]
+         old_per_token_logps=None,
+         ref_log_probs=None,  # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size])
+         epsilon_low=0.2,
+         epsilon_high=0.2,
+         beta=0.04,
+         loss_type="bnpo",  # ["grpo", "bnpo", "dr_grpo"]
+         max_completion_length=None,  # Required for dr_grpo
+         importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
+         **kwargs,
+     ):
+         """GRPO Loss Function matching GRPOTrainer implementation."""
+         per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
+             -1
+         )  # (batch_size, seq_len)
+
+         # Get reference model probabilities
+         if ref_per_token_logps is None:
+             if ref_log_probs is not None:
+                 with torch.no_grad():
+                     ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
+                         -1
+                     )
+             else:
+                 ref_per_token_logps = per_token_logps.detach()
+
+         # Compute policy gradient loss with importance sampling ratio
+         old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
+         log_ratio = per_token_logps - old_per_token_logps
+
+         if importance_sampling_level == "token":
+             log_importance_weights = log_ratio
+         elif importance_sampling_level == "sequence":
+             log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
+             log_importance_weights = log_importance_weights.unsqueeze(-1)
+         else:
+             raise ValueError(
+                 f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
+                 "and 'sequence'."
+             )
+
+         # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
+         # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+         coef_1 = torch.exp(log_importance_weights)
+         coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
+         per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+         per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+         per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+         if beta != 0.0:
+             # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
+             kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
+             # Combine losses
+             per_token_loss = per_token_loss + beta * kl_div
+
+         # Note: We normalize by the number of tokens in the batch (using full_attention_mask),
+         # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
+         # and TRL GRPO implementation
+         # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
+         if loss_type == "grpo":
+             # Average per-sequence loss
+             loss = (
+                 (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
+             ).sum() / full_attention_mask.shape[0]
+         elif loss_type == "bnpo":
+             # Batch Normalized Per-token loss (original implementation)
+             loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+         elif loss_type == "dr_grpo":
+             # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
+             if max_completion_length is None:
+                 raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
+             loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+         else:
+             raise ValueError(f"Unknown loss type: {loss_type}")
+
+         # Calculate metrics
+         metrics = []
+         if beta != 0.0:
+             metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
+
+         # Adjust clipping metric calculation based on importance sampling level
+         if importance_sampling_level == "token":
+             is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+                 (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
+             )
+         else:  # sequence level
+             # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
+             is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
+                 (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
+             )
+             is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
+
+         metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
+         return loss, metrics
+
+     @classmethod
+     def forward(
+         cls,
+         ctx,
+         _input,
+         weight,
+         selected_token_ids,
+         attention_mask,
+         advantages,
+         bias=None,
+         ref_per_token_logps=None,
+         old_per_token_logps=None,
+         ref_input=None,
+         ref_weight=None,
+         ref_bias=None,
+         beta=0.04,
+         epsilon_low=0.2,
+         epsilon_high=0.2,
+         loss_type="bnpo",
+         max_completion_length=None,
+         importance_sampling_level="token",
+         temperature=1.0,
+         compiled=True,
+         use_ref_model=True,
+         chunk_size=1,
+     ):
+         """
+         Fused linear layer with GRPO loss.
+         Args:
+             _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+             weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+             selected_token_ids (torch.Tensor): Selected token ids tensor. Shape: (batch_size, seq_len)
+             attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
+             advantages (torch.Tensor): Advantages tensor. Shape: (batch_size,)
+             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+             ref_per_token_logps: Reference model log probs per token tensor. Shape:(batch_size, seq_len)
+             ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+             ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+             ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+             beta (float): Weight for the KL penalty
+             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+             importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
+             temperature (float): Temperature for the logits
+             compiled (bool): Whether to use torch compile
+             use_ref_model (bool): Whether to use a reference model
+             chunk_size (int): Size of chunks for processing.
+         Returns:
+             torch.Tensor: Computed loss
+         """
+         return super().forward(
+             cls=cls,
+             ctx=ctx,
+             _input=_input,
+             weight=weight,
+             selected_token_ids=selected_token_ids,
+             attention_mask=attention_mask,
+             advantages=advantages,
+             bias=bias,
+             ref_per_token_logps=ref_per_token_logps,
+             old_per_token_logps=old_per_token_logps,
+             ref_input=ref_input,
+             ref_weight=ref_weight,
+             ref_bias=ref_bias,
+             beta=beta,
+             epsilon_low=epsilon_low,
+             epsilon_high=epsilon_high,
+             loss_type=loss_type,
+             max_completion_length=max_completion_length,
+             temperature=temperature,
+             compiled=compiled,
+             use_ref_model=use_ref_model,
+             chunk_size=chunk_size,
+             importance_sampling_level=importance_sampling_level,
+         )
+
+     @staticmethod
+     def backward(ctx, grad_output, *grad_metrics):
+         """Backward pass for GRPO loss.
+
+         Args:
+             grad_output: Gradient of the loss (scalar)
+             grad_metrics: Gradients of the metrics (not used in backward computation)
+         """
+         grads = LigerFusedLinearPPOBase.backward(ctx, grad_output)
+         return (
+             *grads[
+                 :6
+             ],  # grad_input, grad_weight, grad_selected_token_ids, grad_attention_mask, grad_advantages, grad_bias
+             None,  # grad_ref_per_token_logps
+             None,  # grad_old_per_token_logps
+             None,  # grad_ref_input
+             None,  # grad_ref_weight
+             None,  # grad_ref_bias
+             None,  # grad_beta
+             None,  # grad_epsilon_low
+             None,  # grad_epsilon_high
+             None,  # grad_loss_type (string, not differentiable)
+             None,  # grad_max_completion_length (int, not differentiable)
+             None,  # grad_importance_sampling_level (string, not differentiable)
+             None,  # grad_temperature
+             None,  # grad_compiled
+             None,  # grad_use_ref_model
+             None,  # grad_chunk_size
+         )
+
+
+ class LigerFusedLinearGRPOLoss(torch.nn.Module):
+     """Fused linear layer with GRPO loss."""
+
+     def __init__(
+         self,
+         beta: float = 0.04,
+         compiled: bool = True,
+         use_ref_model: bool = True,
+         chunk_size: int = 1,
+         epsilon_low: float = 0.2,
+         epsilon_high: float = 0.2,
+         loss_type: str = "bnpo",
+         max_completion_length: Optional[int] = None,
+         importance_sampling_level: str = "token",
+         temperature: float = 1.0,
+     ):
+         """
+         Args:
+             beta (float): Weight for the KL penalty.
+             compiled (bool): Whether to use torch compile.
+             use_ref_model (bool): Whether to use a reference model.
+             chunk_size (int): Size of chunks for processing.
+             epsilon_low (float): Lower bound for the importance sampling ratio.
+             epsilon_high (float): Upper bound for the importance sampling ratio.
+             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+             importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
+             temperature (float): Temperature for the logits.
+         """
+         super().__init__()
+         self.beta = beta
+         self.compiled = compiled
+         self.use_ref_model = use_ref_model
+         self.chunk_size = chunk_size
+         self.epsilon_low = epsilon_low
+         self.epsilon_high = epsilon_high
+         self.loss_type = loss_type
+         self.max_completion_length = max_completion_length
+         self.importance_sampling_level = importance_sampling_level
+         self.temperature = temperature
+
+     def forward(
+         self,
+         _input,
+         lin_weight,
+         selected_token_ids,
+         attention_mask,
+         advantages,
+         bias=None,
+         ref_per_token_logps=None,
+         old_per_token_logps=None,
+         ref_input=None,
+         ref_weight=None,
+         ref_bias=None,
+     ):
+         return LigerFusedLinearGRPOFunction.apply(
+             _input,
+             lin_weight,
+             selected_token_ids,
+             attention_mask,
+             advantages,
+             bias,
+             ref_per_token_logps,
+             old_per_token_logps,
+             ref_input,
+             ref_weight,
+             ref_bias,
+             self.beta,
+             self.epsilon_low,
+             self.epsilon_high,
+             self.loss_type,
+             self.max_completion_length,
+             self.importance_sampling_level,
+             self.temperature,
+             self.compiled,
+             self.use_ref_model,
+             self.chunk_size,
+         )
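
Editor's note: as a rough illustration of what ppo_loss_fn above computes per chunk, the following standalone sketch (plain PyTorch, toy tensors and illustrative values, not the package API) walks through the clipped importance-sampling surrogate, the k3 KL penalty, and the "bnpo" normalization:

import torch

# Hypothetical toy shapes: 2 sequences, 5 tokens each.
per_token_logps = torch.randn(2, 5)                        # current policy log-probs of the sampled tokens
old_per_token_logps = per_token_logps - 0.1 * torch.randn(2, 5)
ref_per_token_logps = per_token_logps - 0.1 * torch.randn(2, 5)
advantages = torch.tensor([0.7, -0.3])                     # one advantage per sequence
attention_mask = torch.ones(2, 5)
beta, eps_low, eps_high = 0.04, 0.2, 0.2

# Token-level importance ratio and clipped surrogate (same structure as ppo_loss_fn).
coef_1 = torch.exp(per_token_logps - old_per_token_logps)
coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)
surrogate = -torch.min(coef_1 * advantages.unsqueeze(1), coef_2 * advantages.unsqueeze(1))

# k3 estimator of KL[policy || ref]: exp(r) - r - 1 with r = ref_logp - policy_logp.
r = ref_per_token_logps - per_token_logps
kl = torch.exp(r) - r - 1.0

per_token_loss = surrogate + beta * kl
# "bnpo": normalize by the total number of unmasked tokens in the batch.
loss = (per_token_loss * attention_mask).sum() / attention_mask.sum().clamp(min=1.0)
print(loss)

The clipping only bites when the ratio has drifted outside [1 - eps_low, 1 + eps_high] in the direction that would otherwise increase the objective, which is exactly what the is_clipped metric in ppo_loss_fn tracks.
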