liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +307 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +63 -0
  18. liger_kernel/ops/__init__.py +141 -0
  19. liger_kernel/ops/backends/README.md +151 -0
  20. liger_kernel/ops/backends/__init__.py +13 -0
  21. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  22. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  23. liger_kernel/ops/backends/registry.py +61 -0
  24. liger_kernel/ops/cross_entropy.py +383 -114
  25. liger_kernel/ops/dyt.py +160 -0
  26. liger_kernel/ops/experimental/embedding.py +141 -0
  27. liger_kernel/ops/experimental/mm_int8int2.py +349 -0
  28. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  29. liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
  30. liger_kernel/ops/fused_linear_jsd.py +228 -0
  31. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  32. liger_kernel/ops/geglu.py +66 -64
  33. liger_kernel/ops/group_norm.py +306 -0
  34. liger_kernel/ops/grpo_loss.py +312 -0
  35. liger_kernel/ops/jsd.py +201 -0
  36. liger_kernel/ops/kl_div.py +262 -0
  37. liger_kernel/ops/layer_norm.py +320 -0
  38. liger_kernel/ops/llama4_rope.py +225 -0
  39. liger_kernel/ops/multi_token_attention.py +207 -0
  40. liger_kernel/ops/poly_norm.py +390 -0
  41. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  42. liger_kernel/ops/rms_norm.py +484 -88
  43. liger_kernel/ops/rope.py +122 -117
  44. liger_kernel/ops/softmax.py +201 -0
  45. liger_kernel/ops/sparsemax.py +179 -0
  46. liger_kernel/ops/swiglu.py +68 -65
  47. liger_kernel/ops/tiled_mlp.py +136 -0
  48. liger_kernel/ops/tvd.py +207 -0
  49. liger_kernel/ops/utils.py +82 -3
  50. liger_kernel/transformers/__init__.py +218 -6
  51. liger_kernel/transformers/auto_model.py +38 -0
  52. liger_kernel/transformers/cross_entropy.py +52 -7
  53. liger_kernel/transformers/dyt.py +22 -0
  54. liger_kernel/transformers/experimental/__init__.py +5 -0
  55. liger_kernel/transformers/experimental/embedding.py +26 -0
  56. liger_kernel/transformers/fsdp.py +55 -0
  57. liger_kernel/transformers/functional.py +301 -0
  58. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  59. liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
  60. liger_kernel/transformers/fused_linear_jsd.py +95 -0
  61. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  62. liger_kernel/transformers/geglu.py +6 -7
  63. liger_kernel/transformers/group_norm.py +50 -0
  64. liger_kernel/transformers/grpo_loss.py +153 -0
  65. liger_kernel/transformers/jsd.py +70 -0
  66. liger_kernel/transformers/kl_div.py +12 -0
  67. liger_kernel/transformers/layer_norm.py +24 -0
  68. liger_kernel/transformers/llama4_rope.py +93 -0
  69. liger_kernel/transformers/model/falcon_h1.py +122 -0
  70. liger_kernel/transformers/model/gemma.py +261 -0
  71. liger_kernel/transformers/model/gemma2.py +283 -0
  72. liger_kernel/transformers/model/gemma3.py +332 -0
  73. liger_kernel/transformers/model/glm4.py +141 -0
  74. liger_kernel/transformers/model/glm4v.py +163 -0
  75. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  76. liger_kernel/transformers/model/gpt_oss.py +211 -0
  77. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  78. liger_kernel/transformers/model/internvl.py +157 -0
  79. liger_kernel/transformers/model/llama.py +221 -41
  80. liger_kernel/transformers/model/llama4.py +121 -0
  81. liger_kernel/transformers/model/llava.py +344 -0
  82. liger_kernel/transformers/model/loss_utils.py +95 -0
  83. liger_kernel/transformers/model/mistral.py +145 -0
  84. liger_kernel/transformers/model/mixtral.py +293 -0
  85. liger_kernel/transformers/model/mllama.py +269 -0
  86. liger_kernel/transformers/model/olmo2.py +141 -0
  87. liger_kernel/transformers/model/olmo3.py +142 -0
  88. liger_kernel/transformers/model/output_classes.py +147 -0
  89. liger_kernel/transformers/model/paligemma.py +433 -0
  90. liger_kernel/transformers/model/phi3.py +120 -0
  91. liger_kernel/transformers/model/qwen2.py +259 -0
  92. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  93. liger_kernel/transformers/model/qwen2_vl.py +159 -0
  94. liger_kernel/transformers/model/qwen3.py +136 -0
  95. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  96. liger_kernel/transformers/model/qwen3_next.py +146 -0
  97. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  98. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  99. liger_kernel/transformers/model/smollm3.py +199 -0
  100. liger_kernel/transformers/model/smolvlm.py +158 -0
  101. liger_kernel/transformers/monkey_patch.py +2816 -21
  102. liger_kernel/transformers/multi_token_attention.py +64 -0
  103. liger_kernel/transformers/poly_norm.py +42 -0
  104. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  105. liger_kernel/transformers/rms_norm.py +75 -5
  106. liger_kernel/transformers/rope.py +47 -3
  107. liger_kernel/transformers/softmax.py +12 -0
  108. liger_kernel/transformers/sparsemax.py +16 -0
  109. liger_kernel/transformers/swiglu.py +62 -6
  110. liger_kernel/transformers/tiled_mlp.py +133 -0
  111. liger_kernel/transformers/trainer/__init__.py +4 -0
  112. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  113. liger_kernel/transformers/trainer_integration.py +2 -45
  114. liger_kernel/transformers/tvd.py +13 -0
  115. liger_kernel/triton/__init__.py +1 -3
  116. liger_kernel/triton/monkey_patch.py +1 -5
  117. liger_kernel/utils.py +96 -0
  118. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
  119. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
  120. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  121. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
  122. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
  123. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
  124. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
  125. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  126. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,341 @@
1
+ from abc import abstractmethod
2
+ from functools import partial
3
+
4
+ import torch
5
+
6
+ from torch.nn import functional as F
7
+
8
+
9
class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
    """Base autograd Function fusing the LM-head linear projection with an
    unpaired preference loss (e.g. KTO).

    The input batch is processed chunk-by-chunk so the full-vocabulary logits
    are never materialized for the whole batch at once; gradients are
    accumulated during the forward pass and merely rescaled in backward.
    Subclasses supply ``preference_loss_fn``.
    """

    @abstractmethod
    def preference_loss_fn(*args, **kwargs):
        """
        Compute the preference loss for one chunk of log-probs.
        To be extended by subclasses.

        NOTE: this class is not an ``abc.ABC``, so ``@abstractmethod`` is
        advisory only — the ``NotImplementedError`` below is the actual guard.
        """
        raise NotImplementedError("Preference loss function must be implemented.")
16
+
17
    @staticmethod
    def forward(
        cls,
        ctx,
        _input,
        weight,
        target,
        preference_labels,
        bias=None,
        chunk_size=1,
        ignore_index=-100,
        compiled=True,
        use_ref_model=False,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
        average_log_prob=False,
        **loss_kwargs,
    ):
        """
        Fused linear layer + unpaired preference loss (e.g. KTO).
        Expects _input to be stacked with chosen and rejected inputs on the batch dimension.

        The mental model is:

        forward()
        ├── Loop over chunks
        └── _compute_loss()
            ├── chunk_forward()        # Compute logits and log probs
            └── preference_loss_fn()   # Calculate preference loss

        Args:
            cls: Concrete subclass; supplies ``preference_loss_fn``.
            ctx: torch.autograd.Function context used to stash accumulated gradients.
            _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
            target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
            preference_labels (torch.Tensor): Boolean tensor indicating chosen (True) vs
                rejected (False) examples. Shape: (batch_size,).
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
            ignore_index (int): Target index to ignore for loss computation.
            compiled (bool): Whether to use torch.compile for chunk accumulation.
            use_ref_model (bool): Whether to use a reference model for the alignment loss.
            ref_input (torch.Tensor, optional): Reference model input; required when use_ref_model.
            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
            average_log_prob (bool): Whether to average the log probability per non-masked token.
            loss_kwargs (dict): Other possible arguments that a loss function might need.

        Returns:
            Tuple of ``(loss_acc, (chosen_logps_sum, rejected_logps_sum,
            chosen_logits_sum, rejected_logits_sum, *aggregated_aux_outputs))``.
        """
        # TODO: Tune CHUNK_SIZE to fully utilize the GPU
        CHUNK_SIZE = chunk_size

        # Gradients to be accumulated across chunks
        grad_inputs = []
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias) if bias is not None else None

        # Loss to be accumulated (scalar)
        loss_acc = torch.zeros((), device=_input.device)

        # Metrics to be recorded (scalar running sums)
        chosen_logps_sum = torch.zeros((), device=_input.device)
        rejected_logps_sum = torch.zeros((), device=_input.device)
        chosen_logits_sum = torch.zeros((), device=_input.device)
        rejected_logits_sum = torch.zeros((), device=_input.device)
        aggregated_aux_outputs = []

        # Bind all chunk-invariant arguments once so fused_fwd_bwd only varies
        # in the per-chunk tensors (keeps torch.compile specialization stable).
        compute_loss = partial(
            LigerFusedLinearUnpairedPreferenceBase._compute_loss,
            preference_loss_fn=cls.preference_loss_fn,
            full_target=target,
            ignore_index=ignore_index,
            use_ref_model=use_ref_model,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            average_log_prob=average_log_prob,
            **loss_kwargs,
        )

        def fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk):
            """
            Fused forward and backward pass for a chunk of input and target.
            Differentiates w.r.t. input (0), weight (1) and, if present, bias (4).
            """
            argnums = (0, 1, 4) if bias is not None else (0, 1)
            return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
                input_chunk,
                weight,
                target_chunk,
                preference_labels_chunk,
                bias,
                ref_input_chunk=ref_input_chunk,
            )

        def accumulate_chunk(
            input_chunk,
            target_chunk,
            preference_labels_chunk=None,
            ref_input_chunk=None,
        ):
            # grad_and_value returns ((grads...), (loss, aux)); chunk_grad_bias
            # is captured with a star so it is an empty list when bias is None.
            (
                (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias),
                (
                    chunk_loss,
                    (
                        chunk_chosen_logps_sum,
                        chunk_rejected_logps_sum,
                        chunk_chosen_logits_sum,
                        chunk_rejected_logits_sum,
                        *aux_outputs,
                    ),
                ),
            ) = fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)
            if bias is not None:
                grad_bias.add_(chunk_grad_bias[0])  # accumulate bias gradient

            # Accumulate gradients
            grad_weight.add_(chunk_grad_weight)
            grad_inputs.append(chunk_grad_input)

            # Accumulate loss
            loss_acc.add_(chunk_loss)

            # Accumulate metrics
            chosen_logps_sum.add_(chunk_chosen_logps_sum)
            rejected_logps_sum.add_(chunk_rejected_logps_sum)
            chosen_logits_sum.add_(chunk_chosen_logits_sum)
            rejected_logits_sum.add_(chunk_rejected_logits_sum)

            # aux_outputs
            # Initialize storage for aux_outputs on the first chunk
            if len(aggregated_aux_outputs) == 0:
                for aux in aux_outputs:
                    aggregated_aux_outputs.append(torch.zeros((), device=aux.device))

            # Process each aux_output
            # NOTE(review): only scalar (ndim == 0) aux outputs are accumulated;
            # non-scalar aux outputs are silently dropped here — confirm intended.
            for i, aux in enumerate(aux_outputs):
                if aux.ndim == 0:
                    aggregated_aux_outputs[i].add_(aux)

        if compiled:
            fused_fwd_bwd = torch.compile(fused_fwd_bwd)

        # When not paired, use labels to separate chosen and rejected
        assert preference_labels is not None, "preference_labels must be provided for unpaired preference loss"

        # At least one chunk even if batch_size < CHUNK_SIZE
        chunks = max(1, _input.shape[0] // CHUNK_SIZE)
        _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
        _target_chunks = torch.chunk(target, chunks=chunks, dim=0)
        _preference_labels_chunks = torch.chunk(preference_labels, chunks=chunks, dim=0)

        if use_ref_model:
            _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0)

        for (
            input_chunk,
            target_chunk,
            ref_input_chunk,
            preference_labels_chunk,
        ) in zip(
            _input_chunks,
            _target_chunks,
            (_ref_input_chunks if use_ref_model else [None] * len(_input_chunks)),
            _preference_labels_chunks,
        ):
            # mark input_chunk, target_chunk, and target dimension 1 (sequence length) as dynamic to prevent torch.compile recompilation
            torch._dynamo.mark_dynamic(input_chunk, 1)
            torch._dynamo.mark_dynamic(target_chunk, 1)
            torch._dynamo.mark_dynamic(target, 1)
            torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
            torch._dynamo.mark_dynamic(preference_labels_chunk, 1)

            # accumulate loss, gradients, and metrics
            accumulate_chunk(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)

        # Aggregate aux outputs lists into tensors
        # NOTE(review): entries are initialized as scalar tensors above and never
        # become lists, so this branch appears unreachable — confirm intended.
        for i, aux in enumerate(aggregated_aux_outputs):
            if isinstance(aux, list):
                aggregated_aux_outputs[i] = torch.cat(aux, dim=0)

        ctx.save_for_backward(
            torch.cat(grad_inputs, dim=0),
            grad_weight,
            grad_bias,
        )

        return_vars = (
            chosen_logps_sum,
            rejected_logps_sum,
            chosen_logits_sum,
            rejected_logits_sum,
        )

        return loss_acc, (*return_vars, *aggregated_aux_outputs)
210
+
211
    @staticmethod
    def backward(ctx, *grad_output):
        """Rescale the gradients pre-computed during forward.

        Gradients w.r.t. ``_input``, ``weight`` and ``bias`` were already
        accumulated chunk-by-chunk in ``forward`` (via
        ``torch.func.grad_and_value``), so backward only multiplies them by the
        upstream gradient of the loss when it is not exactly 1.0.

        Returns grads for the first five forward args
        (_input, weight, target, preference_labels, bias); subclasses are
        expected to pad Nones for their additional forward arguments.
        """
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        # grad_output[0] is the gradient flowing into loss_acc; skip the
        # multiplications in the common case where it is exactly 1.0.
        if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)):
            grad_input = grad_input * grad_output[0][0]
            grad_weight = grad_weight * grad_output[0][0]
            grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None

        return grad_input, grad_weight, None, None, grad_bias
220
+
221
    @staticmethod
    def chunk_forward(
        input_chunk,
        weight,
        target_chunk,
        preference_labels_chunk,
        bias=None,
        ignore_index=-100,
        average_log_prob=False,
    ):
        """Project one chunk through the LM head and reduce to per-sequence
        log-probs plus chosen/rejected summary sums.

        Args:
            input_chunk: hidden states for the chunk — presumably
                (chunk_size, seq_len, hidden_size); TODO confirm against caller.
            weight: LM-head weight, (vocab_size, hidden_size).
            target_chunk: target token ids, (chunk_size, seq_len).
            preference_labels_chunk: bool tensor; True marks a chosen example.
            bias: optional LM-head bias, (vocab_size,).
            ignore_index: target value excluded from the loss mask.
            average_log_prob: average (instead of sum) per-token log-probs over
                non-masked tokens.
        """
        logits_chunk = input_chunk @ weight.t()
        if bias is not None:
            logits_chunk = logits_chunk + bias
        # log-softmax in fp32 for numerical stability
        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
        loss_mask_chunk = target_chunk != ignore_index
        # Replace ignored positions with index 0 so gather stays in-bounds;
        # their contribution is zeroed by loss_mask_chunk below.
        label_chunk = torch.where(loss_mask_chunk, target_chunk, 0)

        per_token_logps_chunk = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
        if average_log_prob:
            log_probs = (per_token_logps_chunk * loss_mask_chunk).sum(-1) / loss_mask_chunk.sum(-1)
        else:
            log_probs = (per_token_logps_chunk * loss_mask_chunk).sum(-1)

        # NOTE(review): log_probs is 1-D (chunk,), so multiplying by
        # preference_labels_chunk.unsqueeze(1) — shape (chunk, 1) — broadcasts
        # to (chunk, chunk). The sums are correct when the chunk holds a single
        # example (chunk_size=1, the default); verify intended for larger chunks.
        chosen_logps_sum = (log_probs * preference_labels_chunk.unsqueeze(1)).sum()
        rejected_logps_sum = (log_probs * (~preference_labels_chunk).unsqueeze(1)).sum()

        chosen_logits_sum = (logits_chunk * preference_labels_chunk.unsqueeze(1)).sum()
        rejected_logits_sum = (logits_chunk * (~preference_labels_chunk).unsqueeze(1)).sum()

        return (
            log_probs,
            chosen_logps_sum,
            rejected_logps_sum,
            chosen_logits_sum,
            rejected_logits_sum,
        )
257
+
258
    @staticmethod
    def _compute_loss(
        input_chunk,
        weight,
        target_chunk,
        preference_labels_chunk,
        bias=None,
        preference_loss_fn=None,
        full_target=None,
        ignore_index=-100,
        use_ref_model=False,
        ref_input_chunk=None,
        ref_weight=None,
        ref_bias=None,
        average_log_prob=False,
        **loss_kwargs,
    ):
        """
        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.

        Args:
            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, sequence_length, hidden_size).
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size, sequence_length).
            preference_labels_chunk (torch.Tensor): Boolean labels for this chunk; True = chosen.
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
            ignore_index (int): Index to ignore for loss computation.
            use_ref_model (bool): Whether to use a reference model for the alignment loss.
            ref_input_chunk (torch.Tensor, optional): Reference model input chunk; used when use_ref_model.
            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
            average_log_prob (bool): Whether to average the log probability per non-masked token.
            loss_kwargs (dict): Additional arguments for the loss function.

        Returns:
            ``(loss_chunk, (chosen_logps_sum, rejected_logps_sum,
            chosen_logits_sum, rejected_logits_sum, *aux_outputs))``.
        """
        (
            log_prob_chunk,
            chosen_logps_sum,
            rejected_logps_sum,
            chosen_logits_sum,
            rejected_logits_sum,
        ) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
            input_chunk,
            weight,
            target_chunk,
            preference_labels_chunk,
            bias=bias,
            ignore_index=ignore_index,
            average_log_prob=average_log_prob,
        )

        if use_ref_model:
            # Reference model is frozen: no gradients through its log-probs.
            with torch.no_grad():
                (
                    ref_log_prob_chunk,
                    _,
                    _,
                    _,
                    _,
                ) = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
                    ref_input_chunk,
                    ref_weight,
                    target_chunk,
                    preference_labels_chunk,
                    ref_bias,
                    ignore_index=ignore_index,
                    average_log_prob=average_log_prob,
                )
            loss_kwargs["ref_log_prob_chunk"] = ref_log_prob_chunk

        preference_loss_outputs = preference_loss_fn(
            log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs
        )
        # Loss functions may return either a bare loss or (loss, *aux).
        if isinstance(preference_loss_outputs, tuple):
            preference_loss_chunk, *aux_outputs = preference_loss_outputs
        else:
            preference_loss_chunk, aux_outputs = preference_loss_outputs, []

        return_vars = (
            chosen_logps_sum,
            rejected_logps_sum,
            chosen_logits_sum,
            rejected_logits_sum,
        )

        return preference_loss_chunk, (*return_vars, *aux_outputs)
@@ -0,0 +1,307 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+
5
+ from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
6
+
7
+
8
def k3_loss_fn(log_p, log_q):
    """k3 estimator of KL[q, p] (see http://joschu.net/blog/kl-approx.html).

    Computes exp(log_p - log_q) - (log_p - log_q) - 1 elementwise; this is
    non-negative and unbiased for the KL divergence.
    """
    delta = log_p - log_q
    return torch.exp(delta) - delta - 1.0
12
+
13
+
14
+ def clip_coef_fn(coef, epsilon_low, epsilon_high):
15
+ return torch.clamp(coef, 1 - epsilon_low, 1 + epsilon_high)
16
+
17
+
18
class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
    """Fused linear + GRPO loss autograd Function (chunked PPO-family loss)."""

    @staticmethod
    def ppo_loss_fn(
        log_probs,
        selected_token_ids,
        attention_mask,
        advantages,
        full_attention_mask,
        ref_per_token_logps=None,  # shape: [chunk_size, seq_len]
        old_per_token_logps=None,
        ref_log_probs=None,  # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size])
        epsilon_low=0.2,
        epsilon_high=0.2,
        beta=0.04,
        loss_type="dapo",  # ["grpo", "bnpo", "dr_grpo", "dapo"]
        max_completion_length=None,  # Required for dr_grpo
        importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
        **kwargs,
    ):
        """GRPO Loss Function matching GRPOTrainer implementation.

        Returns:
            (loss, metrics): metrics holds the mean KL term (only when
            beta != 0) followed by the clip ratio.
        """
        per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
            -1
        )  # (batch_size, seq_len)

        # Get reference model probabilities
        if ref_per_token_logps is None:
            if ref_log_probs is not None:
                with torch.no_grad():
                    ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
                        -1
                    )
            else:
                # No reference model: fall back to the detached policy
                # log-probs, which makes the k3 KL penalty exactly zero.
                ref_per_token_logps = per_token_logps.detach()

        # Compute policy gradient loss with importance sampling ratio
        old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
        log_ratio = per_token_logps - old_per_token_logps

        if importance_sampling_level == "token":
            log_importance_weights = log_ratio
        elif importance_sampling_level == "sequence":
            # GSPO: one importance weight per sequence — mask-weighted mean of
            # the per-token log-ratios (clamp avoids div-by-zero on empty masks).
            log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
            log_importance_weights = log_importance_weights.unsqueeze(-1)
        else:
            raise ValueError(
                f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
                "and 'sequence'."
            )

        # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
        # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
        coef_1 = torch.exp(log_importance_weights)
        coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        # PPO-style pessimistic (clipped) surrogate objective
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
        if beta != 0.0:
            # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
            kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
            # Combine losses
            per_token_loss = per_token_loss + beta * kl_div

        # Note: We normalize by the number of tokens in the batch (using full_attention_mask),
        # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
        # and TRL GRPO implementation
        # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
        if loss_type == "grpo":
            # Average per-sequence loss
            loss = (
                (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
            ).sum() / full_attention_mask.shape[0]
        elif loss_type == "bnpo":
            # Batch Normalized Per-token loss (original implementation)
            loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
        elif loss_type == "dr_grpo":
            # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
            if max_completion_length is None:
                raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
            loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
        elif loss_type == "dapo":
            loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
            loss = (per_token_loss * attention_mask).sum() / loss_normalizer
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")

        # Calculate metrics
        metrics = []
        if beta != 0.0:
            metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))

        # Adjust clipping metric calculation based on importance sampling level
        if importance_sampling_level == "token":
            # A token is "clipped" when the unclipped ratio left the trust
            # region on the side that the advantage sign makes binding.
            is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
                (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
            )
        else:  # sequence level
            # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
            is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
                (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
            )
            is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)

        metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
        return loss, metrics
122
+
123
    @classmethod
    def forward(
        cls,
        ctx,
        _input,
        weight,
        selected_token_ids,
        attention_mask,
        advantages,
        bias=None,
        ref_per_token_logps=None,
        old_per_token_logps=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
        beta=0.04,
        epsilon_low=0.2,
        epsilon_high=0.2,
        loss_type="dapo",
        max_completion_length=None,
        importance_sampling_level="token",
        temperature=1.0,
        compiled=True,
        use_ref_model=True,
        chunk_size=1,
    ):
        """
        Fused linear layer with GRPO loss. Delegates the chunked computation to
        LigerFusedLinearPPOBase.forward, which calls cls.ppo_loss_fn per chunk.

        Args:
            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
            selected_token_ids (torch.Tensor): Selected token ids tensor. Shape: (batch_size, seq_len)
            attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
            advantages (torch.Tensor): Advantages tensor. Shape: (batch_size,)
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
            ref_per_token_logps: Reference model log probs per token tensor. Shape:(batch_size, seq_len)
            old_per_token_logps (torch.Tensor, optional): Behavior-policy per-token log probs. Shape: (batch_size, seq_len)
            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
            beta (float): Weight for the KL penalty
            epsilon_low (float): Lower clip bound for the importance sampling ratio
            epsilon_high (float): Upper clip bound for the importance sampling ratio
            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
            temperature (float): Temperature for the logits
            compiled (bool): Whether to use torch compile
            use_ref_model (bool): Whether to use a reference model
            chunk_size (int): Size of chunks for processing.
        Returns:
            torch.Tensor: Computed loss
        """
        return super().forward(
            cls=cls,
            ctx=ctx,
            _input=_input,
            weight=weight,
            selected_token_ids=selected_token_ids,
            attention_mask=attention_mask,
            advantages=advantages,
            bias=bias,
            ref_per_token_logps=ref_per_token_logps,
            old_per_token_logps=old_per_token_logps,
            ref_input=ref_input,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            beta=beta,
            epsilon_low=epsilon_low,
            epsilon_high=epsilon_high,
            loss_type=loss_type,
            max_completion_length=max_completion_length,
            temperature=temperature,
            compiled=compiled,
            use_ref_model=use_ref_model,
            chunk_size=chunk_size,
            importance_sampling_level=importance_sampling_level,
        )
198
+
199
    @staticmethod
    def backward(ctx, grad_output, *grad_metrics):
        """Backward pass for GRPO loss.

        Delegates to the PPO base backward (which rescales the gradients
        pre-computed during forward) and pads Nones for every
        non-differentiable forward argument so the return arity matches
        forward's signature.

        Args:
            grad_output: Gradient of the loss (scalar)
            grad_metrics: Gradients of the metrics (not used in backward computation)
        """
        grads = LigerFusedLinearPPOBase.backward(ctx, grad_output)
        return (
            *grads[
                :6
            ],  # grad_input, grad_weight, grad_selected_token_ids, grad_attention_mask, grad_advantages, grad_bias
            None,  # grad_ref_per_token_logps
            None,  # grad_old_per_token_logps
            None,  # grad_ref_input
            None,  # grad_ref_weight
            None,  # grad_ref_bias
            None,  # grad_beta
            None,  # grad_epsilon_low
            None,  # grad_epsilon_high
            None,  # grad_loss_type (string, not differentiable)
            None,  # grad_max_completion_length (int, not differentiable)
            None,  # grad_importance_sampling_level (string, not differentiable)
            None,  # grad_temperature
            None,  # grad_compiled
            None,  # grad_use_ref_model
            None,  # grad_chunk_size
        )
228
+
229
+
230
class LigerFusedLinearGRPOLoss(torch.nn.Module):
    """Fused linear layer with GRPO loss.

    nn.Module wrapper that stores the loss hyperparameters and forwards them
    (positionally, in the exact order expected by the autograd Function) to
    LigerFusedLinearGRPOFunction.apply.
    """

    def __init__(
        self,
        beta: float = 0.04,
        compiled: bool = True,
        use_ref_model: bool = True,
        chunk_size: int = 1,
        epsilon_low: float = 0.2,
        epsilon_high: float = 0.2,
        loss_type: str = "dapo",
        max_completion_length: Optional[int] = None,
        importance_sampling_level: str = "token",
        temperature: float = 1.0,
    ):
        """
        Args:
            beta (float): Weight for the KL penalty.
            compiled (bool): Whether to use torch compile.
            use_ref_model (bool): Whether to use a reference model.
            chunk_size (int): Size of chunks for processing.
            epsilon_low (float): Lower bound for the importance sampling ratio.
            epsilon_high (float): Upper bound for the importance sampling ratio.
            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
            temperature (float): Temperature for the logits.
        """
        super().__init__()
        self.beta = beta
        self.compiled = compiled
        self.use_ref_model = use_ref_model
        self.chunk_size = chunk_size
        self.epsilon_low = epsilon_low
        self.epsilon_high = epsilon_high
        self.loss_type = loss_type
        self.max_completion_length = max_completion_length
        self.importance_sampling_level = importance_sampling_level
        self.temperature = temperature

    def forward(
        self,
        _input,
        lin_weight,
        selected_token_ids,
        attention_mask,
        advantages,
        bias=None,
        ref_per_token_logps=None,
        old_per_token_logps=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
    ):
        """Compute the fused GRPO loss; see LigerFusedLinearGRPOFunction.forward
        for argument shapes and the structure of the return value."""
        # NOTE: Function.apply takes positional args only — the order below must
        # match LigerFusedLinearGRPOFunction.forward exactly.
        return LigerFusedLinearGRPOFunction.apply(
            _input,
            lin_weight,
            selected_token_ids,
            attention_mask,
            advantages,
            bias,
            ref_per_token_logps,
            old_per_token_logps,
            ref_input,
            ref_weight,
            ref_bias,
            self.beta,
            self.epsilon_low,
            self.epsilon_high,
            self.loss_type,
            self.max_completion_length,
            self.importance_sampling_level,
            self.temperature,
            self.compiled,
            self.use_ref_model,
            self.chunk_size,
        )