liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +307 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +63 -0
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +383 -114
- liger_kernel/ops/dyt.py +160 -0
- liger_kernel/ops/experimental/embedding.py +141 -0
- liger_kernel/ops/experimental/mm_int8int2.py +349 -0
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
- liger_kernel/ops/fused_linear_jsd.py +228 -0
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +66 -64
- liger_kernel/ops/group_norm.py +306 -0
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +201 -0
- liger_kernel/ops/kl_div.py +262 -0
- liger_kernel/ops/layer_norm.py +320 -0
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +484 -88
- liger_kernel/ops/rope.py +122 -117
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +68 -65
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +82 -3
- liger_kernel/transformers/__init__.py +218 -6
- liger_kernel/transformers/auto_model.py +38 -0
- liger_kernel/transformers/cross_entropy.py +52 -7
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +26 -0
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +301 -0
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
- liger_kernel/transformers/fused_linear_jsd.py +95 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +6 -7
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +70 -0
- liger_kernel/transformers/kl_div.py +12 -0
- liger_kernel/transformers/layer_norm.py +24 -0
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +261 -0
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +332 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +221 -41
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +145 -0
- liger_kernel/transformers/model/mixtral.py +293 -0
- liger_kernel/transformers/model/mllama.py +269 -0
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +433 -0
- liger_kernel/transformers/model/phi3.py +120 -0
- liger_kernel/transformers/model/qwen2.py +259 -0
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +159 -0
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2816 -21
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +75 -5
- liger_kernel/transformers/rope.py +47 -3
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +62 -6
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/trainer_integration.py +2 -45
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -5
- liger_kernel/utils.py +96 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from functools import partial
|
|
3
|
+
from typing import Tuple
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from torch.nn import functional as F
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class LigerFusedLinearDistillationBase(torch.autograd.Function):
    """Chunked fused linear-projection + knowledge-distillation loss.

    Subclasses provide ``distillation_loss_fn``. ``forward`` walks the flattened inputs in
    chunks, computes the weighted hard (CE) + soft (distillation) loss and its gradients
    w.r.t. the student parameters eagerly via ``torch.func.grad_and_value``, and stores those
    gradients so ``backward`` only has to rescale them.
    """

    @abstractmethod
    def distillation_loss_fn(
        student_logits,
        teacher_logits,
    ):
        """
        Compute the soft (distillation) loss for one chunk.

        Args:
            student_logits (torch.Tensor): Student logits already divided by the temperature.
                Shape: (rows_in_chunk, vocab_size).
            teacher_logits (torch.Tensor): Teacher logits already divided by the temperature.
                Shape: (rows_in_chunk, vocab_size).
        Returns:
            torch.Tensor: The *summed* distillation loss of the chunk. ``_compute_loss`` turns it
            into a mean by dividing by ``full_target.shape[0]``.
        """
        raise NotImplementedError("Distillation loss function must be implemented.")

    @staticmethod
    def chunk_forward(
        student_input_chunk,
        student_weight,
        teacher_input_chunk,
        teacher_weight,
        target_chunk,
        student_bias=None,
        teacher_bias=None,
        ignore_index=-100,
        compute_ce_loss=True,
    ):
        """Project one chunk through the student and teacher heads.

        Returns:
            tuple: ``(student_logits, teacher_logits, ce_loss)``. ``ce_loss`` is the summed NLL of
            the student's log-softmax against ``target_chunk`` (skipping ``ignore_index``), or the
            Python float ``0.0`` when ``compute_ce_loss`` is False.
        """
        # Student logits participate in the differentiated computation.
        s_logits = student_input_chunk @ student_weight.t()
        if student_bias is not None:
            s_logits += student_bias
        s_log_probs = F.log_softmax(s_logits.float(), dim=-1)

        # Teacher logits are treated as constants: no gradient is tracked for them.
        with torch.no_grad():
            t_logits = teacher_input_chunk @ teacher_weight.t()
            if teacher_bias is not None:
                t_logits += teacher_bias

        hard = 0.0
        if compute_ce_loss:
            vocab = s_log_probs.shape[-1]
            hard = F.nll_loss(
                s_log_probs.view(-1, vocab),
                target_chunk.view(-1),
                reduction="sum",
                ignore_index=ignore_index,
            )

        return s_logits, t_logits, hard

    @staticmethod
    def _compute_loss(
        student_input_chunk,
        student_weight,
        teacher_input_chunk,
        teacher_weight,
        target_chunk,
        student_bias=None,
        teacher_bias=None,
        distillation_loss_fn=None,
        full_target=None,
        ignore_index=-100,
        weight_hard_loss=0.5,
        weight_soft_loss=0.5,
        compute_ce_loss=True,
        temperature=1,
        **loss_kwargs,
    ):
        """
        Weighted hard + soft loss for a single chunk (the function differentiated per chunk).

        Args:
            student_input_chunk (torch.Tensor): Shape (chunk_size, student_hidden_size).
            student_weight (torch.Tensor): Shape (vocab_size, student_hidden_size).
            teacher_input_chunk (torch.Tensor): Shape (chunk_size, teacher_hidden_size).
            teacher_weight (torch.Tensor): Shape (vocab_size, teacher_hidden_size).
            target_chunk (torch.Tensor): Shape (chunk_size,).
            student_bias (torch.Tensor, optional): Shape (vocab_size,).
            teacher_bias (torch.Tensor, optional): Shape (vocab_size,).
            distillation_loss_fn (callable): Returns the summed soft loss of a chunk.
            full_target (torch.Tensor): Whole target tensor; its length is the mean denominator.
            ignore_index (int): Target value excluded from the CE loss.
            weight_hard_loss (float): Coefficient of the hard (CE) loss.
            weight_soft_loss (float): Coefficient of the soft (distillation) loss.
            compute_ce_loss (bool): Whether to compute the CE loss at all.
            temperature (float): Both logit sets are divided by this before the soft loss.
            loss_kwargs (dict): Forwarded verbatim to ``distillation_loss_fn``.
        Returns:
            tuple: ``(loss, (soft_loss, hard_loss, student_logits, teacher_logits))``.
        """
        s_logits, t_logits, hard_loss = LigerFusedLinearDistillationBase.chunk_forward(
            student_input_chunk,
            student_weight,
            teacher_input_chunk,
            teacher_weight,
            target_chunk,
            student_bias=student_bias,
            teacher_bias=teacher_bias,
            ignore_index=ignore_index,
            compute_ce_loss=compute_ce_loss,
        )

        s_logits /= temperature
        t_logits /= temperature

        # Same vocabulary/tokenizer, but the teacher head may be padded (e.g. for efficiency, see
        # https://huggingface.co/Qwen/Qwen1.5-72B-Chat/discussions/1#662883f568adf59b07b176d2):
        # right-pad the student logits with zeros so both sides have the same width.
        missing_cols = teacher_weight.shape[0] - student_weight.shape[0]
        if missing_cols > 0:
            filler = torch.zeros(
                (*s_logits.shape[:-1], missing_cols),
                dtype=s_logits.dtype,
                device=s_logits.device,
            )
            s_logits = torch.cat([s_logits, filler], dim=-1)

        # Both losses are per-chunk sums; dividing by the full row count yields a global mean.
        denom = full_target.shape[0]
        hard_loss /= denom

        soft_loss = distillation_loss_fn(s_logits, t_logits, **loss_kwargs)
        soft_loss /= denom

        total = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
        return total, (soft_loss, hard_loss, s_logits, t_logits)

    @staticmethod
    def forward(
        cls,
        ctx,
        student_input,
        student_weight,
        teacher_input,
        teacher_weight,
        target,
        student_bias=None,
        teacher_bias=None,
        chunk_size=1024,
        ignore_index=-100,
        weight_hard_loss=0.5,
        weight_soft_loss=0.5,
        beta=0.5,
        compute_ce_loss=True,
        temperature=1.0,
        compiled=True,
        return_soft_hard_loss=False,
        **loss_kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        Chunked forward pass; gradients are computed here for the student parameters only.

        Args:
            cls: Concrete subclass; its ``distillation_loss_fn`` is used.
            ctx: Autograd context; receives (grad_input, grad_weight, grad_bias).
            student_input (torch.Tensor): Shape (batch_size * seq_len, student_hidden_size).
            student_weight (torch.Tensor): Shape (vocab_size, student_hidden_size).
            teacher_input (torch.Tensor): Shape (batch_size * seq_len, teacher_hidden_size).
            teacher_weight (torch.Tensor): Shape (vocab_size, teacher_hidden_size).
            target (torch.Tensor): Shape (batch_size * seq_len,).
            student_bias (torch.Tensor, optional): Shape (vocab_size,).
            teacher_bias (torch.Tensor, optional): Shape (vocab_size,).
            chunk_size (int): Approximate number of rows per chunk.
            ignore_index (int): Target value excluded from the CE loss.
            weight_hard_loss (float): Coefficient of the hard/task loss.
            weight_soft_loss (float): Coefficient of the soft/distillation loss.
            beta (float): Forwarded to ``distillation_loss_fn`` as the ``beta`` keyword.
            compute_ce_loss (bool): Whether to compute the CE loss.
            temperature (float): Logit temperature. Default ``1.0`` (no scaling).
            compiled (bool): Wrap the per-chunk step in ``torch.compile``.
            return_soft_hard_loss (bool): Also return the soft and hard loss totals. Default False.
            loss_kwargs (dict): Extra keywords forwarded to ``distillation_loss_fn``.
        Returns:
            ``loss``, or ``(loss, soft_loss, hard_loss)`` when ``return_soft_hard_loss`` is True.
        """
        device = student_input.device
        grad_weight = torch.zeros_like(student_weight)
        grad_bias = None if student_bias is None else torch.zeros_like(student_bias)
        grad_input_chunks = []
        loss_acc = torch.zeros((), device=device)
        soft_loss_acc = None
        hard_loss_acc = None
        if return_soft_hard_loss:
            soft_loss_acc = torch.zeros((), device=device)
            hard_loss_acc = torch.zeros((), device=device)

        chunk_loss_fn = partial(
            LigerFusedLinearDistillationBase._compute_loss,
            distillation_loss_fn=cls.distillation_loss_fn,
            full_target=target,
            ignore_index=ignore_index,
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            compute_ce_loss=compute_ce_loss,
            temperature=temperature,
            beta=beta,
            **loss_kwargs,
        )

        # Positional indices into _compute_loss: 0 = student input, 1 = student weight,
        # 5 = student bias (only differentiated when a bias exists).
        argnums = (0, 1, 5) if student_bias is not None else (0, 1)

        def accumulate_chunk(s_chunk, t_chunk, y_chunk):
            grads, (chunk_loss, aux) = torch.func.grad_and_value(chunk_loss_fn, argnums=argnums, has_aux=True)(
                s_chunk,
                student_weight,
                t_chunk,
                teacher_weight,
                y_chunk,
                student_bias,
                teacher_bias,
            )
            chunk_grad_input, chunk_grad_weight, *chunk_grad_bias = grads
            chunk_soft_loss, chunk_hard_loss, _, _ = aux

            if student_bias is not None:
                grad_bias.add_(chunk_grad_bias[0])
            grad_weight.add_(chunk_grad_weight)
            loss_acc.add_(chunk_loss)
            if return_soft_hard_loss:
                soft_loss_acc.add_(chunk_soft_loss)
                hard_loss_acc.add_(chunk_hard_loss)
            return chunk_grad_input

        if compiled:
            accumulate_chunk = torch.compile(accumulate_chunk)

        num_chunks = max(1, student_input.shape[0] // chunk_size)
        chunk_triples = zip(
            torch.chunk(student_input, chunks=num_chunks, dim=0),
            torch.chunk(teacher_input, chunks=num_chunks, dim=0),
            torch.chunk(target, chunks=num_chunks, dim=0),
        )
        for s_chunk, t_chunk, y_chunk in chunk_triples:
            grad_input_chunks.append(accumulate_chunk(s_chunk, t_chunk, y_chunk))

        ctx.save_for_backward(
            torch.cat(grad_input_chunks, dim=0),
            grad_weight,
            grad_bias,
        )
        if not return_soft_hard_loss:
            return loss_acc
        return loss_acc, soft_loss_acc, hard_loss_acc

    @staticmethod
    def backward(ctx, grad_output, *args):
        """Scale the gradients precomputed in ``forward`` by the upstream gradient.

        Gradients arriving for the optional soft/hard loss outputs (``*args``) are ignored.
        Returns six entries: (input, weight, None, None, None, bias).
        """
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        if grad_output != 1.0:
            grad_input = grad_input * grad_output
            grad_weight = grad_weight * grad_output
            if grad_bias is not None:
                grad_bias = grad_bias * grad_output

        return grad_input, grad_weight, None, None, None, grad_bias
|
|
@@ -0,0 +1,366 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from functools import partial
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch._dynamo.config
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LigerFusedLinearPPOBase(torch.autograd.Function):
|
|
10
|
+
@abstractmethod
|
|
11
|
+
def ppo_loss_fn(*args, **kwargs):
|
|
12
|
+
"""
|
|
13
|
+
To be extended by subclasses.
|
|
14
|
+
"""
|
|
15
|
+
raise NotImplementedError("PPO loss function must be implemented.")
|
|
16
|
+
|
|
17
|
+
@staticmethod
|
|
18
|
+
def forward(
|
|
19
|
+
cls,
|
|
20
|
+
ctx,
|
|
21
|
+
_input,
|
|
22
|
+
weight,
|
|
23
|
+
selected_token_ids,
|
|
24
|
+
attention_mask,
|
|
25
|
+
advantages,
|
|
26
|
+
bias=None,
|
|
27
|
+
ref_per_token_logps=None,
|
|
28
|
+
old_per_token_logps=None,
|
|
29
|
+
ref_input=None,
|
|
30
|
+
ref_weight=None,
|
|
31
|
+
ref_bias=None,
|
|
32
|
+
epsilon_low=0.2,
|
|
33
|
+
epsilon_high=0.2,
|
|
34
|
+
beta=0.04,
|
|
35
|
+
loss_type="dapo",
|
|
36
|
+
max_completion_length=None,
|
|
37
|
+
importance_sampling_level="token",
|
|
38
|
+
temperature=1.0,
|
|
39
|
+
compiled=True,
|
|
40
|
+
use_ref_model=False,
|
|
41
|
+
chunk_size=1,
|
|
42
|
+
):
|
|
43
|
+
# TODO: check torch compile matmul
|
|
44
|
+
"""Chunked forward pass for PPO loss computation.
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
cls: The class
|
|
48
|
+
ctx: Context for backward
|
|
49
|
+
_input: Input tensor
|
|
50
|
+
weight: Weight tensor
|
|
51
|
+
selected_token_ids: Selected token ids tensor
|
|
52
|
+
attention_mask: Attention mask tensor
|
|
53
|
+
advantages: Advantages tensor
|
|
54
|
+
bias: Bias tensor
|
|
55
|
+
ref_per_token_logps: Reference model log probs per token tensor
|
|
56
|
+
old_per_token_logps: Old per token log probabilities tensor
|
|
57
|
+
ref_input: Reference model input tensor
|
|
58
|
+
ref_weight: Reference model weight tensor
|
|
59
|
+
ref_bias: Reference model bias tensor
|
|
60
|
+
epsilon_low: Lower bound for clipping the importance sampling ratio
|
|
61
|
+
epsilon_high: Upper bound for clipping the importance sampling ratio
|
|
62
|
+
beta: Weight for the KL penalty
|
|
63
|
+
loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo")
|
|
64
|
+
max_completion_length: Maximum completion length required for "dr_grpo"
|
|
65
|
+
temperature: Temperature for the logits
|
|
66
|
+
compiled: Whether to use torch compile
|
|
67
|
+
use_ref_model: Whether to use a reference model
|
|
68
|
+
chunk_size: Size of chunks for processing in other loss modules
|
|
69
|
+
"""
|
|
70
|
+
if use_ref_model:
|
|
71
|
+
assert ref_per_token_logps is not None or ref_input is not None, (
|
|
72
|
+
"If use_ref_model is True, ref_per_token_logps or ref_input must be provided"
|
|
73
|
+
)
|
|
74
|
+
if ref_per_token_logps is not None and ref_input is not None:
|
|
75
|
+
raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
|
|
76
|
+
if loss_type == "dr_grpo":
|
|
77
|
+
assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'"
|
|
78
|
+
# Initialize accumulators
|
|
79
|
+
loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
|
|
80
|
+
grad_weight = torch.zeros_like(weight) # [V, H]
|
|
81
|
+
grad_inputs = []
|
|
82
|
+
grad_bias = torch.zeros_like(bias) if bias is not None else None # [V]
|
|
83
|
+
aggregated_metrics = []
|
|
84
|
+
|
|
85
|
+
# Create a partial function with fixed arguments
|
|
86
|
+
compute_loss = partial(
|
|
87
|
+
LigerFusedLinearPPOBase._compute_chunk_loss,
|
|
88
|
+
ref_weight=ref_weight,
|
|
89
|
+
ref_bias=ref_bias,
|
|
90
|
+
full_attention_mask=attention_mask,
|
|
91
|
+
epsilon_low=epsilon_low,
|
|
92
|
+
epsilon_high=epsilon_high,
|
|
93
|
+
beta=beta,
|
|
94
|
+
loss_type=loss_type,
|
|
95
|
+
max_completion_length=max_completion_length,
|
|
96
|
+
importance_sampling_level=importance_sampling_level,
|
|
97
|
+
temperature=temperature,
|
|
98
|
+
use_ref_model=use_ref_model,
|
|
99
|
+
ppo_loss_fn=cls.ppo_loss_fn,
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
def fused_fwd_bwd(
|
|
103
|
+
input_chunk,
|
|
104
|
+
selected_token_ids_chunk,
|
|
105
|
+
attention_mask_chunk,
|
|
106
|
+
advantages_chunk,
|
|
107
|
+
ref_per_token_logps_chunk,
|
|
108
|
+
old_per_token_logps_chunk,
|
|
109
|
+
ref_input_chunk,
|
|
110
|
+
):
|
|
111
|
+
"""Fused forward and backward for a chunk."""
|
|
112
|
+
argnums = (0, 1, 5) if bias is not None else (0, 1)
|
|
113
|
+
return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
|
|
114
|
+
input_chunk, # arg 0
|
|
115
|
+
weight, # arg 1
|
|
116
|
+
selected_token_ids_chunk, # arg 2
|
|
117
|
+
attention_mask_chunk, # arg 3
|
|
118
|
+
advantages_chunk, # arg 4
|
|
119
|
+
bias, # arg 5
|
|
120
|
+
ref_per_token_logps_chunk=ref_per_token_logps_chunk, # arg 6
|
|
121
|
+
old_per_token_logps_chunk=old_per_token_logps_chunk, # arg 7
|
|
122
|
+
ref_input_chunk=ref_input_chunk, # arg 8
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
def accumulate_chunk(
|
|
126
|
+
input_chunk,
|
|
127
|
+
selected_token_ids_chunk,
|
|
128
|
+
attention_mask_chunk,
|
|
129
|
+
advantages_chunk,
|
|
130
|
+
ref_per_token_logps_chunk=None,
|
|
131
|
+
old_per_token_logps_chunk=None,
|
|
132
|
+
ref_input_chunk=None,
|
|
133
|
+
):
|
|
134
|
+
(chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
|
|
135
|
+
input_chunk,
|
|
136
|
+
selected_token_ids_chunk,
|
|
137
|
+
attention_mask_chunk,
|
|
138
|
+
advantages_chunk,
|
|
139
|
+
ref_per_token_logps_chunk,
|
|
140
|
+
old_per_token_logps_chunk,
|
|
141
|
+
ref_input_chunk,
|
|
142
|
+
)
|
|
143
|
+
if bias is not None:
|
|
144
|
+
grad_bias.add_(chunk_grad_bias[0])
|
|
145
|
+
|
|
146
|
+
# Accumulate gradients and loss
|
|
147
|
+
grad_weight.add_(chunk_grad_weight)
|
|
148
|
+
grad_inputs.append(chunk_grad_input)
|
|
149
|
+
loss_acc.add_(chunk_loss)
|
|
150
|
+
# Initialize storage for metrics on first chunk
|
|
151
|
+
if len(aggregated_metrics) == 0:
|
|
152
|
+
for metric in chunk_metrics:
|
|
153
|
+
if metric.ndim == 0:
|
|
154
|
+
aggregated_metrics.append(torch.zeros((), device=metric.device))
|
|
155
|
+
else:
|
|
156
|
+
aggregated_metrics.append([])
|
|
157
|
+
|
|
158
|
+
# Accumulate metrics
|
|
159
|
+
for i, metric in enumerate(chunk_metrics):
|
|
160
|
+
if metric.ndim == 0:
|
|
161
|
+
aggregated_metrics[i].add_(metric)
|
|
162
|
+
else:
|
|
163
|
+
aggregated_metrics[i].append(metric)
|
|
164
|
+
|
|
165
|
+
if compiled:
|
|
166
|
+
# TODO: Figure out what is better to compile here
|
|
167
|
+
# accumulate_chunk = torch.compile(accumulate_chunk)
|
|
168
|
+
fused_fwd_bwd = torch.compile(fused_fwd_bwd)
|
|
169
|
+
|
|
170
|
+
# Process input in chunks based on chunk_size
|
|
171
|
+
chunks = max(1, _input.shape[0] // chunk_size)
|
|
172
|
+
_input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
|
|
173
|
+
_selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0)
|
|
174
|
+
_attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
|
|
175
|
+
_advantages_chunks = torch.chunk(advantages, chunks=chunks, dim=0)
|
|
176
|
+
_ref_per_token_logps_chunks = (
|
|
177
|
+
torch.chunk(ref_per_token_logps, chunks=chunks, dim=0)
|
|
178
|
+
if use_ref_model and ref_per_token_logps is not None
|
|
179
|
+
else [None] * chunks
|
|
180
|
+
)
|
|
181
|
+
_old_per_token_logps_chunks = (
|
|
182
|
+
torch.chunk(old_per_token_logps, chunks=chunks, dim=0)
|
|
183
|
+
if old_per_token_logps is not None
|
|
184
|
+
else [None] * chunks
|
|
185
|
+
)
|
|
186
|
+
# if ref_log_probs is not none, then we don't need ref_input to calculate the log probs
|
|
187
|
+
_ref_input_chunks = (
|
|
188
|
+
torch.chunk(ref_input, chunks=chunks, dim=0)
|
|
189
|
+
if use_ref_model and ref_per_token_logps is None
|
|
190
|
+
else [None] * chunks
|
|
191
|
+
)
|
|
192
|
+
|
|
193
|
+
for (
|
|
194
|
+
input_chunk,
|
|
195
|
+
selected_token_ids_chunk,
|
|
196
|
+
attention_mask_chunk,
|
|
197
|
+
advantages_chunk,
|
|
198
|
+
ref_per_token_logps_chunk,
|
|
199
|
+
old_per_token_logps_chunk,
|
|
200
|
+
ref_input_chunk,
|
|
201
|
+
) in zip(
|
|
202
|
+
_input_chunks,
|
|
203
|
+
_selected_token_ids_chunks,
|
|
204
|
+
_attention_mask_chunks,
|
|
205
|
+
_advantages_chunks,
|
|
206
|
+
_ref_per_token_logps_chunks,
|
|
207
|
+
_old_per_token_logps_chunks,
|
|
208
|
+
_ref_input_chunks,
|
|
209
|
+
):
|
|
210
|
+
# Mark dynamic dimensions
|
|
211
|
+
torch._dynamo.mark_dynamic(input_chunk, 1)
|
|
212
|
+
torch._dynamo.mark_dynamic(selected_token_ids_chunk, 1)
|
|
213
|
+
torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
|
|
214
|
+
if ref_per_token_logps_chunk is not None:
|
|
215
|
+
torch._dynamo.mark_dynamic(ref_per_token_logps_chunk, 1)
|
|
216
|
+
if ref_input_chunk is not None:
|
|
217
|
+
torch._dynamo.mark_dynamic(ref_input_chunk, 1)
|
|
218
|
+
if old_per_token_logps_chunk is not None:
|
|
219
|
+
torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
|
|
220
|
+
|
|
221
|
+
accumulate_chunk(
|
|
222
|
+
input_chunk,
|
|
223
|
+
selected_token_ids_chunk,
|
|
224
|
+
attention_mask_chunk,
|
|
225
|
+
advantages_chunk,
|
|
226
|
+
ref_per_token_logps_chunk,
|
|
227
|
+
old_per_token_logps_chunk,
|
|
228
|
+
ref_input_chunk,
|
|
229
|
+
)
|
|
230
|
+
|
|
231
|
+
# Combine gradients
|
|
232
|
+
grad_input = torch.cat(grad_inputs, dim=0)
|
|
233
|
+
|
|
234
|
+
# Save for backward
|
|
235
|
+
ctx.save_for_backward(grad_input, grad_weight, grad_bias)
|
|
236
|
+
|
|
237
|
+
# Finalize metrics
|
|
238
|
+
final_metrics = []
|
|
239
|
+
for metric in aggregated_metrics:
|
|
240
|
+
if isinstance(metric, list):
|
|
241
|
+
final_metrics.append(torch.cat(metric, dim=0))
|
|
242
|
+
else:
|
|
243
|
+
final_metrics.append(metric)
|
|
244
|
+
|
|
245
|
+
return loss_acc, tuple(final_metrics)
|
|
246
|
+
|
|
247
|
+
@staticmethod
|
|
248
|
+
def _compute_dapo_normalizer(attention_mask):
|
|
249
|
+
"""Global active tokens averaged per process."""
|
|
250
|
+
normalizer = attention_mask.to(torch.float32).sum()
|
|
251
|
+
world_size = 1
|
|
252
|
+
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
|
253
|
+
import torch.distributed as dist
|
|
254
|
+
|
|
255
|
+
normalizer = normalizer.clone()
|
|
256
|
+
dist.all_reduce(normalizer, op=dist.ReduceOp.SUM)
|
|
257
|
+
world_size = dist.get_world_size()
|
|
258
|
+
|
|
259
|
+
normalizer = normalizer / world_size
|
|
260
|
+
return torch.clamp(normalizer, min=1.0)
|
|
261
|
+
|
|
262
|
+
@staticmethod
|
|
263
|
+
def _compute_chunk_loss(
|
|
264
|
+
input_chunk,
|
|
265
|
+
weight,
|
|
266
|
+
selected_token_ids_chunk,
|
|
267
|
+
attention_mask_chunk,
|
|
268
|
+
advantages_chunk,
|
|
269
|
+
bias=None,
|
|
270
|
+
ref_per_token_logps_chunk=None,
|
|
271
|
+
old_per_token_logps_chunk=None,
|
|
272
|
+
ref_input_chunk=None,
|
|
273
|
+
ref_weight=None,
|
|
274
|
+
ref_bias=None,
|
|
275
|
+
full_attention_mask=None,
|
|
276
|
+
epsilon_low=0.2,
|
|
277
|
+
epsilon_high=0.2,
|
|
278
|
+
beta=0.04,
|
|
279
|
+
loss_type="dapo",
|
|
280
|
+
max_completion_length=None,
|
|
281
|
+
importance_sampling_level="token",
|
|
282
|
+
temperature=1.0,
|
|
283
|
+
use_ref_model=False,
|
|
284
|
+
ppo_loss_fn=None,
|
|
285
|
+
):
|
|
286
|
+
"""Compute loss for a single chunk."""
|
|
287
|
+
# Get policy log probabilities using chunk_forward
|
|
288
|
+
log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(input_chunk, weight, bias=bias, temperature=temperature)
|
|
289
|
+
|
|
290
|
+
# Get reference log probabilities if needed
|
|
291
|
+
ref_log_probs = None
|
|
292
|
+
if use_ref_model and ref_per_token_logps_chunk is None:
|
|
293
|
+
with torch.no_grad():
|
|
294
|
+
ref_log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(
|
|
295
|
+
ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature
|
|
296
|
+
)
|
|
297
|
+
|
|
298
|
+
# Compute chunk loss and metrics using the provided loss function
|
|
299
|
+
chunk_loss, chunk_metrics = ppo_loss_fn(
|
|
300
|
+
log_probs=log_probs,
|
|
301
|
+
selected_token_ids=selected_token_ids_chunk,
|
|
302
|
+
attention_mask=attention_mask_chunk,
|
|
303
|
+
advantages=advantages_chunk,
|
|
304
|
+
full_attention_mask=full_attention_mask,
|
|
305
|
+
ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else None,
|
|
306
|
+
old_per_token_logps=old_per_token_logps_chunk.float() if old_per_token_logps_chunk is not None else None,
|
|
307
|
+
ref_log_probs=ref_log_probs, # used when ref_per_token_logps is None
|
|
308
|
+
epsilon_low=epsilon_low,
|
|
309
|
+
epsilon_high=epsilon_high,
|
|
310
|
+
beta=beta,
|
|
311
|
+
loss_type=loss_type,
|
|
312
|
+
max_completion_length=max_completion_length,
|
|
313
|
+
importance_sampling_level=importance_sampling_level,
|
|
314
|
+
)
|
|
315
|
+
|
|
316
|
+
return chunk_loss, chunk_metrics
|
|
317
|
+
|
|
318
|
+
@staticmethod
|
|
319
|
+
def chunk_forward(input_chunk, weight, bias=None, temperature=1.0):
|
|
320
|
+
"""Forward pass computation for a single chunk without explicit reshaping."""
|
|
321
|
+
# Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
|
|
322
|
+
logits = torch.matmul(input_chunk, weight.t())
|
|
323
|
+
if bias is not None:
|
|
324
|
+
logits = logits + bias # Broadcasts bias to [B, T, V]
|
|
325
|
+
if temperature != 1.0:
|
|
326
|
+
logits = logits / temperature
|
|
327
|
+
|
|
328
|
+
# Compute log probabilities using softmax over the last dimension
|
|
329
|
+
log_probs = F.log_softmax(logits.float(), dim=-1)
|
|
330
|
+
|
|
331
|
+
return log_probs, logits
|
|
332
|
+
|
|
333
|
+
@staticmethod
|
|
334
|
+
def backward(ctx, grad_output, *grad_metrics):
|
|
335
|
+
"""Backward pass for PPO loss."""
|
|
336
|
+
grad_input, grad_weight, grad_bias = ctx.saved_tensors
|
|
337
|
+
|
|
338
|
+
if grad_output != 1.0:
|
|
339
|
+
grad_input = grad_input * grad_output
|
|
340
|
+
grad_weight = grad_weight * grad_output
|
|
341
|
+
if grad_bias is not None:
|
|
342
|
+
grad_bias = grad_bias * grad_output
|
|
343
|
+
|
|
344
|
+
return (
|
|
345
|
+
grad_input,
|
|
346
|
+
grad_weight,
|
|
347
|
+
None, # grad_selected_token_ids
|
|
348
|
+
None, # grad_attention_mask
|
|
349
|
+
None, # grad_advantages
|
|
350
|
+
grad_bias,
|
|
351
|
+
None, # grad_ref_per_token_logps
|
|
352
|
+
None, # grad_old_per_token_logps
|
|
353
|
+
None, # grad_ref_input
|
|
354
|
+
None, # grad_ref_weight
|
|
355
|
+
None, # grad_ref_bias
|
|
356
|
+
None, # grad_epsilon_low
|
|
357
|
+
None, # grad_epsilon_high
|
|
358
|
+
None, # grad_beta
|
|
359
|
+
None, # grad_loss_type
|
|
360
|
+
None, # grad_max_completion_length
|
|
361
|
+
None, # grad_importance_sampling_level
|
|
362
|
+
None, # grad_temperature
|
|
363
|
+
None, # grad_compiled
|
|
364
|
+
None, # grad_use_ref_model
|
|
365
|
+
None, # grad_chunk_size
|
|
366
|
+
)
|