liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +307 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +63 -0
  18. liger_kernel/ops/__init__.py +141 -0
  19. liger_kernel/ops/backends/README.md +151 -0
  20. liger_kernel/ops/backends/__init__.py +13 -0
  21. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  22. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  23. liger_kernel/ops/backends/registry.py +61 -0
  24. liger_kernel/ops/cross_entropy.py +383 -114
  25. liger_kernel/ops/dyt.py +160 -0
  26. liger_kernel/ops/experimental/embedding.py +141 -0
  27. liger_kernel/ops/experimental/mm_int8int2.py +349 -0
  28. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  29. liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
  30. liger_kernel/ops/fused_linear_jsd.py +228 -0
  31. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  32. liger_kernel/ops/geglu.py +66 -64
  33. liger_kernel/ops/group_norm.py +306 -0
  34. liger_kernel/ops/grpo_loss.py +312 -0
  35. liger_kernel/ops/jsd.py +201 -0
  36. liger_kernel/ops/kl_div.py +262 -0
  37. liger_kernel/ops/layer_norm.py +320 -0
  38. liger_kernel/ops/llama4_rope.py +225 -0
  39. liger_kernel/ops/multi_token_attention.py +207 -0
  40. liger_kernel/ops/poly_norm.py +390 -0
  41. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  42. liger_kernel/ops/rms_norm.py +484 -88
  43. liger_kernel/ops/rope.py +122 -117
  44. liger_kernel/ops/softmax.py +201 -0
  45. liger_kernel/ops/sparsemax.py +179 -0
  46. liger_kernel/ops/swiglu.py +68 -65
  47. liger_kernel/ops/tiled_mlp.py +136 -0
  48. liger_kernel/ops/tvd.py +207 -0
  49. liger_kernel/ops/utils.py +82 -3
  50. liger_kernel/transformers/__init__.py +218 -6
  51. liger_kernel/transformers/auto_model.py +38 -0
  52. liger_kernel/transformers/cross_entropy.py +52 -7
  53. liger_kernel/transformers/dyt.py +22 -0
  54. liger_kernel/transformers/experimental/__init__.py +5 -0
  55. liger_kernel/transformers/experimental/embedding.py +26 -0
  56. liger_kernel/transformers/fsdp.py +55 -0
  57. liger_kernel/transformers/functional.py +301 -0
  58. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  59. liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
  60. liger_kernel/transformers/fused_linear_jsd.py +95 -0
  61. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  62. liger_kernel/transformers/geglu.py +6 -7
  63. liger_kernel/transformers/group_norm.py +50 -0
  64. liger_kernel/transformers/grpo_loss.py +153 -0
  65. liger_kernel/transformers/jsd.py +70 -0
  66. liger_kernel/transformers/kl_div.py +12 -0
  67. liger_kernel/transformers/layer_norm.py +24 -0
  68. liger_kernel/transformers/llama4_rope.py +93 -0
  69. liger_kernel/transformers/model/falcon_h1.py +122 -0
  70. liger_kernel/transformers/model/gemma.py +261 -0
  71. liger_kernel/transformers/model/gemma2.py +283 -0
  72. liger_kernel/transformers/model/gemma3.py +332 -0
  73. liger_kernel/transformers/model/glm4.py +141 -0
  74. liger_kernel/transformers/model/glm4v.py +163 -0
  75. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  76. liger_kernel/transformers/model/gpt_oss.py +211 -0
  77. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  78. liger_kernel/transformers/model/internvl.py +157 -0
  79. liger_kernel/transformers/model/llama.py +221 -41
  80. liger_kernel/transformers/model/llama4.py +121 -0
  81. liger_kernel/transformers/model/llava.py +344 -0
  82. liger_kernel/transformers/model/loss_utils.py +95 -0
  83. liger_kernel/transformers/model/mistral.py +145 -0
  84. liger_kernel/transformers/model/mixtral.py +293 -0
  85. liger_kernel/transformers/model/mllama.py +269 -0
  86. liger_kernel/transformers/model/olmo2.py +141 -0
  87. liger_kernel/transformers/model/olmo3.py +142 -0
  88. liger_kernel/transformers/model/output_classes.py +147 -0
  89. liger_kernel/transformers/model/paligemma.py +433 -0
  90. liger_kernel/transformers/model/phi3.py +120 -0
  91. liger_kernel/transformers/model/qwen2.py +259 -0
  92. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  93. liger_kernel/transformers/model/qwen2_vl.py +159 -0
  94. liger_kernel/transformers/model/qwen3.py +136 -0
  95. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  96. liger_kernel/transformers/model/qwen3_next.py +146 -0
  97. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  98. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  99. liger_kernel/transformers/model/smollm3.py +199 -0
  100. liger_kernel/transformers/model/smolvlm.py +158 -0
  101. liger_kernel/transformers/monkey_patch.py +2816 -21
  102. liger_kernel/transformers/multi_token_attention.py +64 -0
  103. liger_kernel/transformers/poly_norm.py +42 -0
  104. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  105. liger_kernel/transformers/rms_norm.py +75 -5
  106. liger_kernel/transformers/rope.py +47 -3
  107. liger_kernel/transformers/softmax.py +12 -0
  108. liger_kernel/transformers/sparsemax.py +16 -0
  109. liger_kernel/transformers/swiglu.py +62 -6
  110. liger_kernel/transformers/tiled_mlp.py +133 -0
  111. liger_kernel/transformers/trainer/__init__.py +4 -0
  112. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  113. liger_kernel/transformers/trainer_integration.py +2 -45
  114. liger_kernel/transformers/tvd.py +13 -0
  115. liger_kernel/triton/__init__.py +1 -3
  116. liger_kernel/triton/monkey_patch.py +1 -5
  117. liger_kernel/utils.py +96 -0
  118. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
  119. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
  120. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  121. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
  122. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
  123. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
  124. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
  125. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  126. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
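Most of these additions are consumed through the `liger_kernel.transformers` integration layer (monkey_patch.py, auto_model.py and the model-specific modules above). As a hedged sketch of typical usage, assuming the `apply_liger_kernel_to_llama` entry point exported by `liger_kernel.transformers` (flag names may differ between versions):

```python
# Hedged sketch: enable the Liger kernels before loading a Llama model via the
# monkey-patch module listed above. Entry-point and flag names are taken from the
# liger_kernel.transformers package and may differ between versions.
from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama(
    rope=True,                        # Liger RoPE kernel
    rms_norm=True,                    # Liger RMSNorm kernel
    swiglu=True,                      # Liger SwiGLU MLP kernel
    fused_linear_cross_entropy=True,  # fused lm_head + cross entropy (avoids materializing full logits)
)
# The model can then be loaded normally with transformers; the patched modules
# are picked up automatically.
```

The model-specific files in the listing suggest analogous `apply_liger_kernel_to_*` patches for the other supported architectures.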
liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -0,0 +1,292 @@
+ from abc import abstractmethod
+ from functools import partial
+ from typing import Tuple
+ from typing import Union
+
+ import torch
+
+ from torch.nn import functional as F
+
+
+ class LigerFusedLinearDistillationBase(torch.autograd.Function):
+     @abstractmethod
+     def distillation_loss_fn(
+         student_logits,
+         teacher_logits,
+     ):
+         """
+         Compute distillation loss.
+         Args:
+             student_logits (torch.Tensor): Raw (temperature-scaled) logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
+             teacher_logits (torch.Tensor): Raw (temperature-scaled) logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+         Returns:
+             torch.Tensor: Sum of distillation losses for the chunk. The class will handle
+                 converting this to a mean loss by dividing by the full batch_size * sequence_length in _compute_loss.
+         """
+         raise NotImplementedError("Distillation loss function must be implemented.")
+
+     @staticmethod
+     def chunk_forward(
+         student_input_chunk,
+         student_weight,
+         teacher_input_chunk,
+         teacher_weight,
+         target_chunk,
+         student_bias=None,
+         teacher_bias=None,
+         ignore_index=-100,
+         compute_ce_loss=True,
+     ):
+         # Student
+         student_logits_chunk = student_input_chunk @ student_weight.t()
+         if student_bias is not None:
+             student_logits_chunk += student_bias
+         student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1)
+
+         # Teacher
+         with torch.no_grad():
+             teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
+             if teacher_bias is not None:
+                 teacher_logits_chunk += teacher_bias
+
+         # The hard/task loss
+         ce_loss = 0.0
+         if compute_ce_loss:
+             ce_loss = F.nll_loss(
+                 student_log_probs_chunk.view(-1, student_log_probs_chunk.shape[-1]),
+                 target_chunk.view(-1),
+                 reduction="sum",
+                 ignore_index=ignore_index,
+             )
+
+         return student_logits_chunk, teacher_logits_chunk, ce_loss
+
+     @staticmethod
+     def _compute_loss(
+         student_input_chunk,
+         student_weight,
+         teacher_input_chunk,
+         teacher_weight,
+         target_chunk,
+         student_bias=None,
+         teacher_bias=None,
+         distillation_loss_fn=None,
+         full_target=None,
+         ignore_index=-100,
+         weight_hard_loss=0.5,
+         weight_soft_loss=0.5,
+         compute_ce_loss=True,
+         temperature=1,
+         **loss_kwargs,
+     ):
+         """
+         Compute the total loss for a chunk of input and target, using a knowledge distillation loss function.
+         Args:
+             distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+             student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
+             student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size).
+             teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size).
+             teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size).
+             target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
+             student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+             teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+             full_target (torch.Tensor): Full target tensor. Shape: (batch_size * sequence_length,).
+             ignore_index (int): Index to ignore for loss computation.
+             weight_hard_loss (float): Weight for hard loss.
+             weight_soft_loss (float): Weight for soft loss.
+             compute_ce_loss (bool): Whether to compute CE loss.
+             temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scaling).
+             loss_kwargs (dict): Additional arguments for the loss function.
+         """
+         (
+             student_logits_chunk,
+             teacher_logits_chunk,
+             hard_loss,
+         ) = LigerFusedLinearDistillationBase.chunk_forward(
+             student_input_chunk,
+             student_weight,
+             teacher_input_chunk,
+             teacher_weight,
+             target_chunk,
+             student_bias=student_bias,
+             teacher_bias=teacher_bias,
+             ignore_index=ignore_index,
+             compute_ce_loss=compute_ce_loss,
+         )
+
+         student_logits_chunk /= temperature
+         teacher_logits_chunk /= temperature
+
+         # If the teacher and student vocab sizes differ, pad the student logits to match the teacher's.
+         # This only applies when they share exactly the same vocab and tokenizer and the teacher logits
+         # are merely padded for training efficiency, e.g.
+         # https://huggingface.co/Qwen/Qwen1.5-72B-Chat/discussions/1#662883f568adf59b07b176d2
+         teacher_vocab_size = teacher_weight.shape[0]
+         student_vocab_size = student_weight.shape[0]
+         if teacher_vocab_size > student_vocab_size:
+             pad_size = teacher_vocab_size - student_vocab_size
+             pad_tensor = torch.zeros(
+                 (*student_logits_chunk.shape[:-1], pad_size),
+                 dtype=student_logits_chunk.dtype,
+                 device=student_logits_chunk.device,
+             )
+             student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)
+
+         hard_loss /= full_target.shape[0]
+
+         soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
+         soft_loss /= full_target.shape[0]
+
+         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
+         return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)
+
+     @staticmethod
+     def forward(
+         cls,
+         ctx,
+         student_input,
+         student_weight,
+         teacher_input,
+         teacher_weight,
+         target,
+         student_bias=None,
+         teacher_bias=None,
+         chunk_size=1024,
+         ignore_index=-100,
+         weight_hard_loss=0.5,
+         weight_soft_loss=0.5,
+         beta=0.5,
+         compute_ce_loss=True,
+         temperature=1.0,
+         compiled=True,
+         return_soft_hard_loss=False,
+         **loss_kwargs,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+         """
+         Base class for fused linear layer with distillation loss.
+         Gradients only need to be computed for the student model.
+
+         Args:
+             student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, student_hidden_size).
+             student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, student_hidden_size).
+             teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, teacher_hidden_size).
+             teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, teacher_hidden_size).
+             target (torch.Tensor): Ground-truth label tensor. Shape: (batch_size * seq_len,).
+             student_bias (torch.Tensor, optional): Student bias tensor. Shape: (vocab_size,).
+             teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,).
+             loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+             chunk_size (int): Size of a chunk.
+             ignore_index (int): Index to ignore for loss computation.
+             weight_hard_loss (float): Weight for hard/task loss.
+             weight_soft_loss (float): Weight for soft/distillation loss.
+             beta (float): Interpolation coefficient between 0 and 1 (default: 0.5).
+             compute_ce_loss (bool): Whether to compute CE loss.
+             temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scaling).
+             compiled (bool): Whether to use torch compile for chunk accumulation.
+             return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
+             loss_kwargs (dict): Other possible arguments that a loss function might need.
+         """
+         CHUNK_SIZE = chunk_size
+         grad_weight = torch.zeros_like(student_weight)
+         grad_inputs = []
+         grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
+         loss_acc = torch.zeros((), device=student_input.device)
+         soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
+         hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
+
+         loss_func_to_call = partial(
+             LigerFusedLinearDistillationBase._compute_loss,
+             distillation_loss_fn=cls.distillation_loss_fn,
+             full_target=target,
+             ignore_index=ignore_index,
+             weight_hard_loss=weight_hard_loss,
+             weight_soft_loss=weight_soft_loss,
+             compute_ce_loss=compute_ce_loss,
+             temperature=temperature,
+             beta=beta,
+             **loss_kwargs,
+         )
+
+         def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
+             if student_bias is not None:
+                 (
+                     (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
+                     (
+                         chunk_loss,
+                         (
+                             chunk_soft_loss,
+                             chunk_hard_loss,
+                             chunk_student_logits,
+                             chunk_teacher_logits,
+                         ),
+                     ),
+                 ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1, 5), has_aux=True)(
+                     student_input_chunk,
+                     student_weight,
+                     teacher_input_chunk,
+                     teacher_weight,
+                     target_chunk,
+                     student_bias,
+                     teacher_bias,
+                 )
+                 grad_bias.add_(chunk_grad_bias)
+             else:
+                 (
+                     (chunk_grad_input, chunk_grad_weight),
+                     (
+                         chunk_loss,
+                         (
+                             chunk_soft_loss,
+                             chunk_hard_loss,
+                             chunk_student_logits,
+                             chunk_teacher_logits,
+                         ),
+                     ),
+                 ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1), has_aux=True)(
+                     student_input_chunk,
+                     student_weight,
+                     teacher_input_chunk,
+                     teacher_weight,
+                     target_chunk,
+                     student_bias,
+                     teacher_bias,
+                 )
+             grad_weight.add_(chunk_grad_weight)
+             loss_acc.add_(chunk_loss)
+             if return_soft_hard_loss:
+                 soft_loss_acc.add_(chunk_soft_loss)
+                 hard_loss_acc.add_(chunk_hard_loss)
+             return chunk_grad_input
+
+         if compiled:
+             accumulate_chunk = torch.compile(accumulate_chunk)
+
+         num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
+         _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
+         _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
+         _target_chunks = torch.chunk(target, chunks=num_chunks, dim=0)
+
+         for student_input_chunk, teacher_input_chunk, target_chunk in zip(
+             _student_input_chunks, _teacher_input_chunks, _target_chunks
+         ):
+             grad_input = accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk)
+             grad_inputs.append(grad_input)
+
+         ctx.save_for_backward(
+             torch.cat(grad_inputs, dim=0),
+             grad_weight,
+             grad_bias,
+         )
+         if return_soft_hard_loss:
+             return loss_acc, soft_loss_acc, hard_loss_acc
+         return loss_acc
+
+     @staticmethod
+     def backward(ctx, grad_output, *args):
+         grad_input, grad_weight, grad_bias = ctx.saved_tensors
+         if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
+             grad_input = grad_input * grad_output
+             grad_weight = grad_weight * grad_output
+             grad_bias = grad_bias * grad_output if grad_bias is not None else None
+
+         return grad_input, grad_weight, None, None, None, grad_bias
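To make the subclassing contract above concrete, here is a minimal, illustrative sketch of a subclass that plugs a plain KL-divergence distillation loss into `LigerFusedLinearDistillationBase`. The class name and the choice of KL are assumptions for illustration only; the package's shipped losses (e.g. `liger_kernel/chunked_loss/jsd_loss.py` in the listing) define their own loss functions and wrappers.

```python
# Hypothetical example subclass; not part of the package.
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase


class ToyFusedLinearKLFunction(LigerFusedLinearDistillationBase):
    @staticmethod
    def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
        # Contract from the base class: return the *sum* of per-token losses for the
        # chunk; _compute_loss divides by batch_size * seq_len afterwards.
        # `beta` arrives via loss_kwargs from forward() and is unused in this toy loss.
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_probs = F.softmax(teacher_logits, dim=-1)
        return F.kl_div(student_log_probs, teacher_probs, reduction="sum")

    @classmethod
    def forward(cls, ctx, student_input, student_weight, teacher_input, teacher_weight, target):
        # Delegate to the chunked base implementation with its default hyperparameters;
        # cls is passed explicitly so the base can look up cls.distillation_loss_fn.
        return super().forward(
            cls,
            ctx,
            student_input,
            student_weight,
            teacher_input,
            teacher_weight,
            target,
        )
```

A caller would invoke it as `loss = ToyFusedLinearKLFunction.apply(student_hidden, student_lm_head_weight, teacher_hidden, teacher_lm_head_weight, labels)`; the base class then chunks the batch, accumulates the loss, and stores gradients for the student inputs and weights only.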
liger_kernel/chunked_loss/fused_linear_ppo.py
@@ -0,0 +1,366 @@
+ from abc import abstractmethod
+ from functools import partial
+
+ import torch
+ import torch._dynamo.config
+ import torch.nn.functional as F
+
+
+ class LigerFusedLinearPPOBase(torch.autograd.Function):
+     @abstractmethod
+     def ppo_loss_fn(*args, **kwargs):
+         """
+         To be extended by subclasses.
+         """
+         raise NotImplementedError("PPO loss function must be implemented.")
+
+     @staticmethod
+     def forward(
+         cls,
+         ctx,
+         _input,
+         weight,
+         selected_token_ids,
+         attention_mask,
+         advantages,
+         bias=None,
+         ref_per_token_logps=None,
+         old_per_token_logps=None,
+         ref_input=None,
+         ref_weight=None,
+         ref_bias=None,
+         epsilon_low=0.2,
+         epsilon_high=0.2,
+         beta=0.04,
+         loss_type="dapo",
+         max_completion_length=None,
+         importance_sampling_level="token",
+         temperature=1.0,
+         compiled=True,
+         use_ref_model=False,
+         chunk_size=1,
+     ):
+         # TODO: check torch compile matmul
+         """Chunked forward pass for PPO loss computation.
+
+         Args:
+             cls: The class
+             ctx: Context for backward
+             _input: Input tensor
+             weight: Weight tensor
+             selected_token_ids: Selected token ids tensor
+             attention_mask: Attention mask tensor
+             advantages: Advantages tensor
+             bias: Bias tensor
+             ref_per_token_logps: Reference model per-token log probabilities tensor
+             old_per_token_logps: Old per-token log probabilities tensor
+             ref_input: Reference model input tensor
+             ref_weight: Reference model weight tensor
+             ref_bias: Reference model bias tensor
+             epsilon_low: Lower bound for clipping the importance sampling ratio
+             epsilon_high: Upper bound for clipping the importance sampling ratio
+             beta: Weight for the KL penalty
+             loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo")
+             max_completion_length: Maximum completion length, required for "dr_grpo"
+             temperature: Temperature for the logits
+             compiled: Whether to use torch compile
+             use_ref_model: Whether to use a reference model
+             chunk_size: Size of chunks for processing in other loss modules
+         """
+         if use_ref_model:
+             assert ref_per_token_logps is not None or ref_input is not None, (
+                 "If use_ref_model is True, ref_per_token_logps or ref_input must be provided"
+             )
+             if ref_per_token_logps is not None and ref_input is not None:
+                 raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
+         if loss_type == "dr_grpo":
+             assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'"
+         # Initialize accumulators
+         loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
+         grad_weight = torch.zeros_like(weight)  # [V, H]
+         grad_inputs = []
+         grad_bias = torch.zeros_like(bias) if bias is not None else None  # [V]
+         aggregated_metrics = []
+
+         # Create a partial function with fixed arguments
+         compute_loss = partial(
+             LigerFusedLinearPPOBase._compute_chunk_loss,
+             ref_weight=ref_weight,
+             ref_bias=ref_bias,
+             full_attention_mask=attention_mask,
+             epsilon_low=epsilon_low,
+             epsilon_high=epsilon_high,
+             beta=beta,
+             loss_type=loss_type,
+             max_completion_length=max_completion_length,
+             importance_sampling_level=importance_sampling_level,
+             temperature=temperature,
+             use_ref_model=use_ref_model,
+             ppo_loss_fn=cls.ppo_loss_fn,
+         )
+
+         def fused_fwd_bwd(
+             input_chunk,
+             selected_token_ids_chunk,
+             attention_mask_chunk,
+             advantages_chunk,
+             ref_per_token_logps_chunk,
+             old_per_token_logps_chunk,
+             ref_input_chunk,
+         ):
+             """Fused forward and backward for a chunk."""
+             argnums = (0, 1, 5) if bias is not None else (0, 1)
+             return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
+                 input_chunk,  # arg 0
+                 weight,  # arg 1
+                 selected_token_ids_chunk,  # arg 2
+                 attention_mask_chunk,  # arg 3
+                 advantages_chunk,  # arg 4
+                 bias,  # arg 5
+                 ref_per_token_logps_chunk=ref_per_token_logps_chunk,  # arg 6
+                 old_per_token_logps_chunk=old_per_token_logps_chunk,  # arg 7
+                 ref_input_chunk=ref_input_chunk,  # arg 8
+             )
+
+         def accumulate_chunk(
+             input_chunk,
+             selected_token_ids_chunk,
+             attention_mask_chunk,
+             advantages_chunk,
+             ref_per_token_logps_chunk=None,
+             old_per_token_logps_chunk=None,
+             ref_input_chunk=None,
+         ):
+             (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
+                 input_chunk,
+                 selected_token_ids_chunk,
+                 attention_mask_chunk,
+                 advantages_chunk,
+                 ref_per_token_logps_chunk,
+                 old_per_token_logps_chunk,
+                 ref_input_chunk,
+             )
+             if bias is not None:
+                 grad_bias.add_(chunk_grad_bias[0])
+
+             # Accumulate gradients and loss
+             grad_weight.add_(chunk_grad_weight)
+             grad_inputs.append(chunk_grad_input)
+             loss_acc.add_(chunk_loss)
+             # Initialize storage for metrics on the first chunk
+             if len(aggregated_metrics) == 0:
+                 for metric in chunk_metrics:
+                     if metric.ndim == 0:
+                         aggregated_metrics.append(torch.zeros((), device=metric.device))
+                     else:
+                         aggregated_metrics.append([])
+
+             # Accumulate metrics
+             for i, metric in enumerate(chunk_metrics):
+                 if metric.ndim == 0:
+                     aggregated_metrics[i].add_(metric)
+                 else:
+                     aggregated_metrics[i].append(metric)
+
+         if compiled:
+             # TODO: Figure out what is better to compile here
+             # accumulate_chunk = torch.compile(accumulate_chunk)
+             fused_fwd_bwd = torch.compile(fused_fwd_bwd)
+
+         # Process input in chunks based on chunk_size
+         chunks = max(1, _input.shape[0] // chunk_size)
+         _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
+         _selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0)
+         _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
+         _advantages_chunks = torch.chunk(advantages, chunks=chunks, dim=0)
+         _ref_per_token_logps_chunks = (
+             torch.chunk(ref_per_token_logps, chunks=chunks, dim=0)
+             if use_ref_model and ref_per_token_logps is not None
+             else [None] * chunks
+         )
+         _old_per_token_logps_chunks = (
+             torch.chunk(old_per_token_logps, chunks=chunks, dim=0)
+             if old_per_token_logps is not None
+             else [None] * chunks
+         )
+         # If ref_per_token_logps is not None, we don't need ref_input to calculate the log probs
+         _ref_input_chunks = (
+             torch.chunk(ref_input, chunks=chunks, dim=0)
+             if use_ref_model and ref_per_token_logps is None
+             else [None] * chunks
+         )
+
+         for (
+             input_chunk,
+             selected_token_ids_chunk,
+             attention_mask_chunk,
+             advantages_chunk,
+             ref_per_token_logps_chunk,
+             old_per_token_logps_chunk,
+             ref_input_chunk,
+         ) in zip(
+             _input_chunks,
+             _selected_token_ids_chunks,
+             _attention_mask_chunks,
+             _advantages_chunks,
+             _ref_per_token_logps_chunks,
+             _old_per_token_logps_chunks,
+             _ref_input_chunks,
+         ):
+             # Mark dynamic dimensions
+             torch._dynamo.mark_dynamic(input_chunk, 1)
+             torch._dynamo.mark_dynamic(selected_token_ids_chunk, 1)
+             torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
+             if ref_per_token_logps_chunk is not None:
+                 torch._dynamo.mark_dynamic(ref_per_token_logps_chunk, 1)
+             if ref_input_chunk is not None:
+                 torch._dynamo.mark_dynamic(ref_input_chunk, 1)
+             if old_per_token_logps_chunk is not None:
+                 torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
+
+             accumulate_chunk(
+                 input_chunk,
+                 selected_token_ids_chunk,
+                 attention_mask_chunk,
+                 advantages_chunk,
+                 ref_per_token_logps_chunk,
+                 old_per_token_logps_chunk,
+                 ref_input_chunk,
+             )
+
+         # Combine gradients
+         grad_input = torch.cat(grad_inputs, dim=0)
+
+         # Save for backward
+         ctx.save_for_backward(grad_input, grad_weight, grad_bias)
+
+         # Finalize metrics
+         final_metrics = []
+         for metric in aggregated_metrics:
+             if isinstance(metric, list):
+                 final_metrics.append(torch.cat(metric, dim=0))
+             else:
+                 final_metrics.append(metric)
+
+         return loss_acc, tuple(final_metrics)
+
+     @staticmethod
+     def _compute_dapo_normalizer(attention_mask):
+         """Global active tokens averaged per process."""
+         normalizer = attention_mask.to(torch.float32).sum()
+         world_size = 1
+         if torch.distributed.is_available() and torch.distributed.is_initialized():
+             import torch.distributed as dist
+
+             normalizer = normalizer.clone()
+             dist.all_reduce(normalizer, op=dist.ReduceOp.SUM)
+             world_size = dist.get_world_size()
+
+         normalizer = normalizer / world_size
+         return torch.clamp(normalizer, min=1.0)
+
+     @staticmethod
+     def _compute_chunk_loss(
+         input_chunk,
+         weight,
+         selected_token_ids_chunk,
+         attention_mask_chunk,
+         advantages_chunk,
+         bias=None,
+         ref_per_token_logps_chunk=None,
+         old_per_token_logps_chunk=None,
+         ref_input_chunk=None,
+         ref_weight=None,
+         ref_bias=None,
+         full_attention_mask=None,
+         epsilon_low=0.2,
+         epsilon_high=0.2,
+         beta=0.04,
+         loss_type="dapo",
+         max_completion_length=None,
+         importance_sampling_level="token",
+         temperature=1.0,
+         use_ref_model=False,
+         ppo_loss_fn=None,
+     ):
+         """Compute loss for a single chunk."""
+         # Get policy log probabilities using chunk_forward
+         log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(input_chunk, weight, bias=bias, temperature=temperature)
+
+         # Get reference log probabilities if needed
+         ref_log_probs = None
+         if use_ref_model and ref_per_token_logps_chunk is None:
+             with torch.no_grad():
+                 ref_log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(
+                     ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature
+                 )
+
+         # Compute chunk loss and metrics using the provided loss function
+         chunk_loss, chunk_metrics = ppo_loss_fn(
+             log_probs=log_probs,
+             selected_token_ids=selected_token_ids_chunk,
+             attention_mask=attention_mask_chunk,
+             advantages=advantages_chunk,
+             full_attention_mask=full_attention_mask,
+             ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else None,
+             old_per_token_logps=old_per_token_logps_chunk.float() if old_per_token_logps_chunk is not None else None,
+             ref_log_probs=ref_log_probs,  # used when ref_per_token_logps is None
+             epsilon_low=epsilon_low,
+             epsilon_high=epsilon_high,
+             beta=beta,
+             loss_type=loss_type,
+             max_completion_length=max_completion_length,
+             importance_sampling_level=importance_sampling_level,
+         )
+
+         return chunk_loss, chunk_metrics
+
+     @staticmethod
+     def chunk_forward(input_chunk, weight, bias=None, temperature=1.0):
+         """Forward pass computation for a single chunk without explicit reshaping."""
+         # Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
+         logits = torch.matmul(input_chunk, weight.t())
+         if bias is not None:
+             logits = logits + bias  # Broadcasts bias to [B, T, V]
+         if temperature != 1.0:
+             logits = logits / temperature
+
+         # Compute log probabilities using softmax over the last dimension
+         log_probs = F.log_softmax(logits.float(), dim=-1)
+
+         return log_probs, logits
+
+     @staticmethod
+     def backward(ctx, grad_output, *grad_metrics):
+         """Backward pass for PPO loss."""
+         grad_input, grad_weight, grad_bias = ctx.saved_tensors
+
+         if grad_output != 1.0:
+             grad_input = grad_input * grad_output
+             grad_weight = grad_weight * grad_output
+             if grad_bias is not None:
+                 grad_bias = grad_bias * grad_output
+
+         return (
+             grad_input,
+             grad_weight,
+             None,  # grad_selected_token_ids
+             None,  # grad_attention_mask
+             None,  # grad_advantages
+             grad_bias,
+             None,  # grad_ref_per_token_logps
+             None,  # grad_old_per_token_logps
+             None,  # grad_ref_input
+             None,  # grad_ref_weight
+             None,  # grad_ref_bias
+             None,  # grad_epsilon_low
+             None,  # grad_epsilon_high
+             None,  # grad_beta
+             None,  # grad_loss_type
+             None,  # grad_max_completion_length
+             None,  # grad_importance_sampling_level
+             None,  # grad_temperature
+             None,  # grad_compiled
+             None,  # grad_use_ref_model
+             None,  # grad_chunk_size
+         )
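For orientation, the hook that concrete losses implement is `ppo_loss_fn`. Below is a hedged, toy sketch of a clipped-ratio, token-level objective wired into that hook; the class name and the assumed tensor layout (log_probs: [B, T, V], selected_token_ids and attention_mask: [B, T], advantages: [B]) are illustrative assumptions, and the shipped implementations (e.g. `liger_kernel/chunked_loss/grpo_loss.py` in the listing) additionally override `forward`/`backward` to adapt the argument list.

```python
# Hypothetical example subclass; only the ppo_loss_fn hook is shown.
import torch

from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase


class ToyClippedPPOFunction(LigerFusedLinearPPOBase):
    @staticmethod
    def ppo_loss_fn(
        log_probs,              # [B, T, V] policy log-probabilities from chunk_forward
        selected_token_ids,     # [B, T] sampled token ids
        attention_mask,         # [B, T] 1 for completion tokens, 0 for padding
        advantages,             # [B] per-sequence advantages (assumed layout)
        full_attention_mask=None,
        old_per_token_logps=None,
        epsilon_low=0.2,
        epsilon_high=0.2,
        **kwargs,               # beta, ref_log_probs, loss_type, ... ignored in this toy loss
    ):
        # Log-probability of each sampled token.
        per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(-1)
        # Importance ratio against the behaviour policy (detach => ratio == 1 on-policy).
        old = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
        ratio = torch.exp(per_token_logps - old)
        adv = advantages.unsqueeze(-1)  # broadcast the per-sequence advantage over tokens
        # Clipped surrogate objective.
        per_token_loss = -torch.min(ratio * adv, torch.clamp(ratio, 1.0 - epsilon_low, 1.0 + epsilon_high) * adv)
        mask = attention_mask.to(per_token_loss.dtype)
        # Normalize by the *global* token count so that per-chunk losses sum to the batch mean.
        denom_mask = full_attention_mask if full_attention_mask is not None else attention_mask
        denom = denom_mask.to(per_token_loss.dtype).sum().clamp(min=1.0)
        loss = (per_token_loss * mask).sum() / denom
        # The base class accumulates (loss, metrics) across chunks; scalar metrics are summed.
        mean_ratio = (ratio * mask).sum() / mask.sum().clamp(min=1.0)
        return loss, (mean_ratio.detach(),)
```

Calling the hook directly makes the contract visible: it returns a scalar chunk loss already normalized by the full batch's token count, plus a tuple of metrics that `forward` aggregates across chunks before returning `(loss_acc, tuple(final_metrics))`.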