liger-kernel-nightly 0.6.2.dev20251011154427__py3-none-any.whl → 0.6.4.dev20260107111351__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
- liger_kernel/chunked_loss/grpo_loss.py +8 -5
- liger_kernel/chunked_loss/jsd_loss.py +39 -11
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +43 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +244 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +75 -12
- liger_kernel/ops/dyt.py +5 -2
- liger_kernel/ops/fused_add_rms_norm.py +5 -1
- liger_kernel/ops/fused_linear_cross_entropy.py +45 -14
- liger_kernel/ops/geglu.py +5 -3
- liger_kernel/ops/group_norm.py +2 -1
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/layer_norm.py +86 -66
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +131 -49
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +30 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +48 -25
- liger_kernel/transformers/fused_add_rms_norm.py +1 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +57 -2
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +1 -1
- liger_kernel/transformers/model/falcon_h1.py +19 -5
- liger_kernel/transformers/model/gemma.py +17 -6
- liger_kernel/transformers/model/gemma2.py +14 -5
- liger_kernel/transformers/model/gemma3.py +26 -12
- liger_kernel/transformers/model/glm4.py +16 -4
- liger_kernel/transformers/model/glm4v.py +16 -4
- liger_kernel/transformers/model/glm4v_moe.py +23 -4
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +12 -5
- liger_kernel/transformers/model/llama.py +14 -5
- liger_kernel/transformers/model/llama4.py +16 -4
- liger_kernel/transformers/model/llava.py +12 -4
- liger_kernel/transformers/model/loss_utils.py +31 -3
- liger_kernel/transformers/model/mistral.py +15 -6
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +12 -4
- liger_kernel/transformers/model/olmo2.py +16 -4
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +23 -5
- liger_kernel/transformers/model/phi3.py +14 -7
- liger_kernel/transformers/model/qwen2.py +16 -3
- liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
- liger_kernel/transformers/model/qwen2_vl.py +16 -4
- liger_kernel/transformers/model/qwen3.py +20 -5
- liger_kernel/transformers/model/qwen3_moe.py +19 -5
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +15 -6
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +702 -48
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +15 -3
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +18 -1
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +52 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/METADATA +12 -3
- liger_kernel_nightly-0.6.4.dev20260107111351.dist-info/RECORD +130 -0
- liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/RECORD +0 -107
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/cosine_similarity_loss.py
CHANGED

@@ -1,3 +1,6 @@
+from typing import Tuple
+from typing import Union
+
 import torch
 import torch.nn.functional as F

@@ -6,7 +9,13 @@ from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinear

 class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
     @staticmethod
-    def distillation_loss_fn(
+    def distillation_loss_fn(
+        student_logits,
+        teacher_logits,
+        target=None,
+        ignore_index=None,
+        beta=1.0,
+    ):
         """
         Compute Cosine loss (Cosine Similarity Loss).
         Args:
@@ -41,7 +50,8 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase)
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
-
+        return_soft_hard_loss: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         return super().forward(
             cls=cls,
             ctx=ctx,
@@ -59,11 +69,12 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase)
             ignore_index=ignore_index,
             temperature=temperature,
             compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
         )

     @staticmethod
-    def backward(ctx, grad_output):
-        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]

         return (
             *grads,
@@ -75,6 +86,7 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase)
             None,  # temperature
             None,  # compiled
             None,  # chunk_size
+            None,  # return_soft_hard_loss
         )


@@ -88,6 +100,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -98,6 +111,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
         self.compiled = compiled
         self.beta = beta
         self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss

     def forward(
         self,
@@ -108,7 +122,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
         true_labels: torch.LongTensor,
         student_bias: torch.Tensor = None,
         teacher_bias: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         return LigerFusedLinearCosineSimilarityFunction.apply(
             student_input,
             student_weight,
@@ -124,4 +138,5 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
             self.temperature,
             self.compiled,
             self.chunk_size,
+            self.return_soft_hard_loss,
         )
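A minimal usage sketch of the new return_soft_hard_loss flag threaded through above. The tensor shapes, the positional argument order of LigerFusedLinearCosineSimilarityLoss.forward, and the availability of an accelerator for the fused kernels are assumptions for illustration, not guaranteed by this diff:

    import torch
    from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss

    # Hypothetical toy shapes: 8 flattened tokens, hidden size 16, vocab size 32.
    device = "cuda"  # assumed: the fused kernels expect a supported accelerator
    student_input = torch.randn(8, 16, device=device, requires_grad=True)
    student_weight = torch.randn(32, 16, device=device, requires_grad=True)
    teacher_input = torch.randn(8, 16, device=device)
    teacher_weight = torch.randn(32, 16, device=device)
    true_labels = torch.randint(0, 32, (8,), device=device)

    loss_fn = LigerFusedLinearCosineSimilarityLoss(return_soft_hard_loss=True)

    # Per the new Union[...] return annotation, the module now yields a tuple.
    loss, soft_loss, hard_loss = loss_fn(
        student_input, student_weight, teacher_input, teacher_weight, true_labels
    )

    # backward(ctx, grad_output, *args) ignores gradients for the extra outputs,
    # so calling backward on the combined loss alone still works.
    loss.backward()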
liger_kernel/chunked_loss/fused_linear_distillation.py
CHANGED

@@ -1,5 +1,7 @@
 from abc import abstractmethod
 from functools import partial
+from typing import Tuple
+from typing import Union

 import torch

@@ -11,6 +13,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
     def distillation_loss_fn(
         student_logits,
         teacher_logits,
+        target=None,
+        ignore_index=None,
     ):
         """
         Compute distillation loss.
@@ -130,10 +134,15 @@
             )
             student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)

-
+        num_valid_tokens = (full_target != ignore_index).sum()
+        num_valid_tokens = num_valid_tokens.clamp_min(1)  # to avoid division by zero

-
-
+        hard_loss /= num_valid_tokens
+
+        soft_loss = distillation_loss_fn(
+            student_logits_chunk, teacher_logits_chunk, target=target_chunk, ignore_index=ignore_index, **loss_kwargs
+        )
+        soft_loss /= num_valid_tokens

         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
         return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)
@@ -157,8 +166,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         compute_ce_loss=True,
         temperature=1.0,
         compiled=True,
+        return_soft_hard_loss=False,
         **loss_kwargs,
-    ):
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Base class for fused linear layer with distillation loss.
         Only need to compute gradients for student model.
@@ -180,6 +190,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             compute_ce_loss (bool): Whether to compute CE loss.
             temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             compiled (bool): Whether to use torch compile for chunk accumulation.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         CHUNK_SIZE = chunk_size
@@ -187,6 +198,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         grad_inputs = []
         grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
         loss_acc = torch.zeros((), device=student_input.device)
+        soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
+        hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None

         loss_func_to_call = partial(
             LigerFusedLinearDistillationBase._compute_loss,
@@ -247,6 +260,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             )
             grad_weight.add_(chunk_grad_weight)
             loss_acc.add_(chunk_loss)
+            if return_soft_hard_loss:
+                soft_loss_acc.add_(chunk_soft_loss)
+                hard_loss_acc.add_(chunk_hard_loss)
             return chunk_grad_input

         if compiled:
@@ -268,10 +284,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             grad_weight,
             grad_bias,
         )
+        if return_soft_hard_loss:
+            return loss_acc, soft_loss_acc, hard_loss_acc
         return loss_acc

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, *args):
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
         if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
             grad_input = grad_input * grad_output
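To make the new normalization in _compute_loss concrete, a small self-contained sketch of the token counting it now performs; the hard/soft loss values below are placeholders, only the counting logic mirrors the diff:

    import torch

    ignore_index = -100
    full_target = torch.tensor([5, 2, -100, 7, -100, -100])  # 3 real tokens, 3 ignored

    num_valid_tokens = (full_target != ignore_index).sum()   # tensor(3)
    num_valid_tokens = num_valid_tokens.clamp_min(1)          # guards an all-ignored batch

    # Both the hard (CE) loss and the soft (distillation) loss of a chunk are now
    # divided by this valid-token count, so padded or ignored tokens do not
    # contribute to the averaged loss.
    hard_loss = torch.tensor(6.0) / num_valid_tokens  # placeholder value
    soft_loss = torch.tensor(1.5) / num_valid_tokens  # placeholder value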
liger_kernel/chunked_loss/fused_linear_ppo.py
CHANGED

@@ -32,7 +32,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
-        loss_type="
+        loss_type="dapo",
         max_completion_length=None,
         importance_sampling_level="token",
         temperature=1.0,
@@ -60,7 +60,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             epsilon_low: Lower bound for clipping the importance sampling ratio
             epsilon_high: Upper bound for clipping the importance sampling ratio
             beta: Weight for the KL penalty
-            loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo")
+            loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo")
             max_completion_length: Maximum completion length required for "dr_grpo"
             temperature: Temperature for the logits
             compiled: Whether to use torch compile
@@ -244,6 +244,21 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):

         return loss_acc, tuple(final_metrics)

+    @staticmethod
+    def _compute_dapo_normalizer(attention_mask):
+        """Global active tokens averaged per process."""
+        normalizer = attention_mask.to(torch.float32).sum()
+        world_size = 1
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            import torch.distributed as dist
+
+            normalizer = normalizer.clone()
+            dist.all_reduce(normalizer, op=dist.ReduceOp.SUM)
+            world_size = dist.get_world_size()
+
+        normalizer = normalizer / world_size
+        return torch.clamp(normalizer, min=1.0)
+
     @staticmethod
     def _compute_chunk_loss(
         input_chunk,
@@ -261,7 +276,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
-        loss_type="
+        loss_type="dapo",
         max_completion_length=None,
         importance_sampling_level="token",
         temperature=1.0,
@@ -341,10 +356,11 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
            None,  # grad_epsilon_low
            None,  # grad_epsilon_high
            None,  # grad_beta
+           None,  # grad_loss_type
+           None,  # grad_max_completion_length
+           None,  # grad_importance_sampling_level
            None,  # grad_temperature
            None,  # grad_compiled
            None,  # grad_use_ref_model
            None,  # grad_chunk_size
-           None,  # grad_loss_type
-           None,  # grad_max_completion_length
         )
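A single-process illustration of the new _compute_dapo_normalizer helper added above (the call is a sketch; with no torch.distributed process group initialized, the normalizer is simply the number of active tokens, floored at 1):

    import torch
    from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase

    # Two sequences with 3 and 2 active completion tokens respectively.
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(attention_mask)
    print(normalizer)  # tensor(5.) on a single process; averaged over ranks when distributed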
liger_kernel/chunked_loss/grpo_loss.py
CHANGED

@@ -29,7 +29,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
-        loss_type="
+        loss_type="dapo",  # ["grpo", "bnpo", "dr_grpo", "dapo"]
         max_completion_length=None,  # Required for dr_grpo
         importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
         **kwargs,
@@ -94,6 +94,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             if max_completion_length is None:
                 raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
             loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+        elif loss_type == "dapo":
+            loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
+            loss = (per_token_loss * attention_mask).sum() / loss_normalizer
         else:
             raise ValueError(f"Unknown loss type: {loss_type}")

@@ -135,7 +138,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         beta=0.04,
         epsilon_low=0.2,
         epsilon_high=0.2,
-        loss_type="
+        loss_type="dapo",
         max_completion_length=None,
         importance_sampling_level="token",
         temperature=1.0,
@@ -157,7 +160,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
             ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
             beta (float): Weight for the KL penalty
-            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
             importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits
@@ -235,7 +238,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         chunk_size: int = 1,
         epsilon_low: float = 0.2,
         epsilon_high: float = 0.2,
-        loss_type: str = "
+        loss_type: str = "dapo",
         max_completion_length: Optional[int] = None,
         importance_sampling_level: str = "token",
         temperature: float = 1.0,
@@ -248,7 +251,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             chunk_size (int): Size of chunks for processing.
             epsilon_low (float): Lower bound for the importance sampling ratio.
             epsilon_high (float): Upper bound for the importance sampling ratio.
-            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
             importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits.
liger_kernel/chunked_loss/jsd_loss.py
CHANGED

@@ -1,5 +1,8 @@
 import math

+from typing import Tuple
+from typing import Union
+
 import torch
 import torch.nn.functional as F

@@ -8,35 +11,50 @@ from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinear

 class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
     @staticmethod
-    def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
+    def distillation_loss_fn(student_logits, teacher_logits, beta=0.5, target=None, ignore_index=-100):
         """
         Compute JSD loss (Jensen-Shannon Divergence Loss).
         Args:
             student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
             teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+            target (torch.Tensor): Target labels for masking. Shape: (chunk_size,).
+            ignore_index (int): Index to ignore in loss computation.
         Returns:
             torch.Tensor: Jensen-Shannon Divergence loss
+        Note:
+            - Uses reduction="none" to preserve per-token losses for masking
+            - KL divergence requires summing over vocab dimension (not mean)
+            - Masking excludes padding/prompt tokens from loss computation
         """
         student_log_probs = F.log_softmax(student_logits, dim=-1)
         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

         if beta == 0:
-            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="
+            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
         elif beta == 1:
-            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="
+            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
         else:
             # Compute probabilities (only required for mean calculation)
             log_mean_probs = torch.logsumexp(
                 torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
             )

-            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="
-            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="
+            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="none", log_target=True)
+            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="none", log_target=True)

             # JSD is the weighted average of the KL divergences
             jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
-
+
+        # Sum over vocab dimension (KL divergence definition)
+        jsd_loss = jsd_loss.sum(dim=-1)  # (chunk_size,)
+
+        # Apply ignore_index mask
+        if target is not None:
+            mask = target != ignore_index
+            jsd_loss = jsd_loss.masked_fill(~mask, 0.0)
+
+        return jsd_loss.sum()

     @classmethod
     def forward(
@@ -56,6 +74,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         """
         Fused linear layer with JSD distillation loss.
@@ -72,8 +91,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             temperature (float): Temperature for softening/sharpening distributions
             compiled (bool): Whether to use torch compile
             chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
         Returns:
-            torch.Tensor: Computed loss
+            torch.Tensor: Computed loss, or tuple (loss, soft_loss, hard_loss) if return_soft_hard_loss=True
         """
         return super().forward(
             cls=cls,
@@ -92,11 +112,12 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             ignore_index=ignore_index,
             temperature=temperature,
             compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
         )

     @staticmethod
-    def backward(ctx, grad_output):
-        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]

         return (
             *grads,
@@ -108,6 +129,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             None,  # temperature
             None,  # compiled
             None,  # chunk_size
+            None,  # return_soft_hard_loss
         )


@@ -125,6 +147,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         """
         Args:
@@ -135,6 +158,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             compiled (bool): Whether to use torch compile
             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
             chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
         """
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -145,6 +169,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         self.compiled = compiled
         self.beta = beta
         self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss

     def forward(
         self,
@@ -155,7 +180,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         true_labels: torch.LongTensor,
         student_bias: torch.Tensor = None,
         teacher_bias: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Compute the JSD distillation loss.

@@ -167,7 +192,9 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             true_labels (torch.LongTensor): Target labels tensor

         Returns:
-            torch.Tensor
+            torch.Tensor or Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                If return_soft_hard_loss is False: Computed combined loss
+                If return_soft_hard_loss is True: Tuple of (combined_loss, soft_loss, hard_loss)
         """
         return LigerFusedLinearJSDFunction.apply(
             student_input,
@@ -184,4 +211,5 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             self.temperature,
             self.compiled,
             self.chunk_size,
+            self.return_soft_hard_loss,
         )
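The new masking behaviour of the JSD loss can be exercised directly through the static distillation_loss_fn changed above; a minimal CPU sketch (random logits, hypothetical shapes):

    import torch
    from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction

    # Per-token JSD values are summed over the vocab dimension, tokens whose label
    # equals ignore_index are zeroed out, and the result is a summed (not mean) loss.
    student_logits = torch.randn(4, 10)
    teacher_logits = torch.randn(4, 10)
    target = torch.tensor([3, -100, 7, -100])  # two of the four tokens are masked out

    loss = LigerFusedLinearJSDFunction.distillation_loss_fn(
        student_logits, teacher_logits, beta=0.5, target=target, ignore_index=-100
    )
    print(loss)  # scalar: sum of the generalized JSD over the two unmasked tokens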
liger_kernel/ops/__init__.py
CHANGED

@@ -0,0 +1,141 @@
+"""
+Liger-Kernel operators with automatic vendor-specific replacement.
+
+This module provides two ways to import operators:
+
+1. Import from this package (recommended for Function classes):
+   from liger_kernel.ops import LigerGELUMulFunction
+
+   This automatically uses vendor-specific implementation if available.
+
+2. Import from submodules (for kernel functions or specific access):
+   from liger_kernel.ops.geglu import geglu_forward, geglu_backward
+
+   This always uses the default implementation (no auto-replacement).
+
+The replacement mechanism:
+1. Default implementations are imported from individual modules (e.g., geglu.py)
+2. On module load, device is detected via infer_device()
+3. If running on a supported vendor device (npu, xpu, etc.), the default
+   implementations are replaced with vendor-specific ones
+4. All subsequent imports from this package get the replaced versions
+
+Note: Direct imports from submodules (e.g., from liger_kernel.ops.geglu import ...)
+are NOT affected by the replacement mechanism.
+"""
+
+# =============================================================================
+# Import default implementations
+# Both Function classes and kernel functions are imported here.
+# All of these can be replaced by vendor-specific implementations.
+# =============================================================================
+
+from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction  # noqa: F401
+from liger_kernel.ops.cross_entropy import cross_entropy_backward  # noqa: F401
+from liger_kernel.ops.cross_entropy import cross_entropy_forward  # noqa: F401
+from liger_kernel.ops.dyt import LigerDyTFunction  # noqa: F401
+from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction  # noqa: F401
+from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction  # noqa: F401
+from liger_kernel.ops.fused_add_rms_norm import fused_add_rms_norm_backward  # noqa: F401
+from liger_kernel.ops.fused_add_rms_norm import fused_add_rms_norm_forward  # noqa: F401
+from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction  # noqa: F401
+from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_backward  # noqa: F401
+from liger_kernel.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_forward  # noqa: F401
+from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction  # noqa: F401
+from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_backward  # noqa: F401
+from liger_kernel.ops.fused_linear_jsd import fused_linear_jsd_forward  # noqa: F401
+from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction  # noqa: F401
+from liger_kernel.ops.geglu import LigerGELUMulFunction  # noqa: F401
+from liger_kernel.ops.geglu import geglu_backward  # noqa: F401
+from liger_kernel.ops.geglu import geglu_forward  # noqa: F401
+from liger_kernel.ops.group_norm import LigerGroupNormFunction  # noqa: F401
+from liger_kernel.ops.group_norm import group_norm_backward  # noqa: F401
+from liger_kernel.ops.group_norm import group_norm_forward  # noqa: F401
+from liger_kernel.ops.grpo_loss import GrpoLossFunction  # noqa: F401
+from liger_kernel.ops.jsd import LigerJSDFunction  # noqa: F401
+from liger_kernel.ops.jsd import jsd_backward  # noqa: F401
+from liger_kernel.ops.jsd import jsd_forward  # noqa: F401
+from liger_kernel.ops.kl_div import LigerKLDivLossFunction  # noqa: F401
+from liger_kernel.ops.layer_norm import LigerLayerNormFunction  # noqa: F401
+from liger_kernel.ops.layer_norm import layer_norm_backward  # noqa: F401
+from liger_kernel.ops.layer_norm import layer_norm_forward  # noqa: F401
+from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction  # noqa: F401
+from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction  # noqa: F401
+from liger_kernel.ops.poly_norm import LigerPolyNormFunction  # noqa: F401
+from liger_kernel.ops.poly_norm import poly_norm_backward  # noqa: F401
+from liger_kernel.ops.poly_norm import poly_norm_forward  # noqa: F401
+from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction  # noqa: F401
+from liger_kernel.ops.rms_norm import LigerRMSNormFunction  # noqa: F401
+from liger_kernel.ops.rms_norm import rms_norm_backward  # noqa: F401
+from liger_kernel.ops.rms_norm import rms_norm_forward  # noqa: F401
+from liger_kernel.ops.rope import LigerRopeFunction  # noqa: F401
+from liger_kernel.ops.rope import rope_backward  # noqa: F401
+from liger_kernel.ops.rope import rope_forward  # noqa: F401
+from liger_kernel.ops.softmax import LigerSoftmaxFunction  # noqa: F401
+from liger_kernel.ops.sparsemax import LigerSparsemaxFunction  # noqa: F401
+from liger_kernel.ops.swiglu import LigerSiLUMulFunction  # noqa: F401
+from liger_kernel.ops.swiglu import swiglu_backward  # noqa: F401
+from liger_kernel.ops.swiglu import swiglu_forward  # noqa: F401
+from liger_kernel.ops.tiled_mlp import LigerTiledMLPFunction  # noqa: F401
+from liger_kernel.ops.tiled_mlp import apply_tiled_mlp  # noqa: F401
+from liger_kernel.ops.tvd import LigerTVDLossFunction  # noqa: F401
+
+# NOTE: __all__ is intentionally NOT defined.
+# - Import from this package (liger_kernel.ops) -> subject to vendor replacement
+# - Import from submodules (liger_kernel.ops.geglu) -> always use default implementation
+
+
+# =============================================================================
+# Vendor-specific replacement logic
+# =============================================================================
+
+
+def _replace_with_vendor_ops():
+    """
+    Replace/add vendor-specific operator implementations.
+
+    This function is called automatically on module load. It:
+    1. Detects the current device (cuda, npu, xpu, etc.)
+    2. Looks up the vendor for that device via VENDOR_REGISTRY
+    3. Loads and applies vendor-specific implementations
+
+    Vendor implementations should be placed in:
+        liger_kernel/ops/backends/_<vendor>/ops/
+
+    If the vendor module defines __all__, only those symbols are exported.
+    Otherwise, all public symbols (not starting with _) are auto-discovered.
+
+    Note: Vendor can both override existing ops AND add new vendor-specific ops.
+    """
+    from liger_kernel.ops.backends import get_vendor_for_device
+    from liger_kernel.utils import infer_device
+
+    device = infer_device()
+
+    # Look up vendor info for this device
+    vendor_info = get_vendor_for_device(device)
+    if vendor_info is None:
+        return
+
+    try:
+        import importlib
+
+        vendor_ops = importlib.import_module(vendor_info.module_path)
+
+        # Get names to export: use __all__ if defined, otherwise auto-discover
+        names_to_export = getattr(vendor_ops, "__all__", None)
+
+        if names_to_export is None:
+            # Auto-discover: find all public symbols (classes and functions)
+            names_to_export = [name for name in dir(vendor_ops) if not name.startswith("_")]
+
+        # Replace or add to this module's globals
+        for name in names_to_export:
+            globals()[name] = getattr(vendor_ops, name)
+
+    except ImportError:
+        # Vendor module not available, use default implementations
+        pass


+_replace_with_vendor_ops()
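To make the two import paths described in the new module docstring concrete, a short sketch (the alias name is arbitrary):

    # Package-level import: subject to vendor replacement. On a registered backend
    # (e.g. Ascend NPU) this may resolve to the vendor implementation.
    from liger_kernel.ops import LigerGELUMulFunction

    # Submodule import: always the default implementation, bypassing
    # _replace_with_vendor_ops().
    from liger_kernel.ops.geglu import LigerGELUMulFunction as DefaultLigerGELUMulFunction

    # Inspect which backend, if any, would be used for the detected device.
    from liger_kernel.ops.backends import get_vendor_for_device
    from liger_kernel.utils import infer_device

    vendor = get_vendor_for_device(infer_device())
    print(vendor.module_path if vendor is not None else "using default Triton ops")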