liger-kernel-nightly 0.6.2.dev20251011154427__py3-none-any.whl → 0.6.4.dev20251202054858__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of liger-kernel-nightly might be problematic.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
- liger_kernel/chunked_loss/grpo_loss.py +8 -5
- liger_kernel/chunked_loss/jsd_loss.py +18 -5
- liger_kernel/ops/cross_entropy.py +65 -11
- liger_kernel/ops/dyt.py +5 -2
- liger_kernel/ops/fused_add_rms_norm.py +5 -1
- liger_kernel/ops/fused_linear_cross_entropy.py +43 -13
- liger_kernel/ops/geglu.py +2 -1
- liger_kernel/ops/group_norm.py +2 -1
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/layer_norm.py +86 -66
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +7 -2
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +27 -0
- liger_kernel/transformers/cross_entropy.py +8 -3
- liger_kernel/transformers/functional.py +29 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
- liger_kernel/transformers/grpo_loss.py +56 -1
- liger_kernel/transformers/model/falcon_h1.py +19 -5
- liger_kernel/transformers/model/gemma.py +17 -6
- liger_kernel/transformers/model/gemma2.py +14 -5
- liger_kernel/transformers/model/gemma3.py +25 -12
- liger_kernel/transformers/model/glm4.py +16 -4
- liger_kernel/transformers/model/glm4v.py +16 -4
- liger_kernel/transformers/model/glm4v_moe.py +23 -4
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +12 -5
- liger_kernel/transformers/model/llama.py +14 -5
- liger_kernel/transformers/model/llama4.py +16 -4
- liger_kernel/transformers/model/llava.py +12 -4
- liger_kernel/transformers/model/loss_utils.py +31 -3
- liger_kernel/transformers/model/mistral.py +15 -6
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +12 -4
- liger_kernel/transformers/model/olmo2.py +16 -4
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +22 -5
- liger_kernel/transformers/model/phi3.py +14 -7
- liger_kernel/transformers/model/qwen2.py +16 -3
- liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
- liger_kernel/transformers/model/qwen2_vl.py +16 -4
- liger_kernel/transformers/model/qwen3.py +20 -5
- liger_kernel/transformers/model/qwen3_moe.py +19 -5
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +15 -6
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +594 -19
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/rms_norm.py +7 -0
- liger_kernel/transformers/rope.py +43 -0
- liger_kernel/transformers/swiglu.py +17 -0
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/utils.py +25 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/METADATA +4 -1
- liger_kernel_nightly-0.6.4.dev20251202054858.dist-info/RECORD +118 -0
- liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/RECORD +0 -107
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/top_level.txt +0 -0

liger_kernel/chunked_loss/cosine_similarity_loss.py
CHANGED

@@ -1,3 +1,6 @@
+from typing import Tuple
+from typing import Union
+
 import torch
 import torch.nn.functional as F
 
@@ -41,7 +44,8 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
 temperature: float = 1.0,
 compiled: bool = True,
 chunk_size: int = 1024,
-):
+return_soft_hard_loss: bool = False,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
 return super().forward(
 cls=cls,
 ctx=ctx,
@@ -59,11 +63,12 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
 ignore_index=ignore_index,
 temperature=temperature,
 compiled=compiled,
+return_soft_hard_loss=return_soft_hard_loss,
 )
 
 @staticmethod
-def backward(ctx, grad_output):
-grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+def backward(ctx, grad_output, *args):
+grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
 
 return (
 *grads,
@@ -75,6 +80,7 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
 None, # temperature
 None, # compiled
 None, # chunk_size
+None, # return_soft_hard_loss
 )
 
 
@@ -88,6 +94,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
 temperature: float = 1.0,
 compiled: bool = True,
 chunk_size: int = 1024,
+return_soft_hard_loss: bool = False,
 ):
 super().__init__()
 assert temperature != 0, "Temperature cannot be 0."
@@ -98,6 +105,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
 self.compiled = compiled
 self.beta = beta
 self.chunk_size = chunk_size
+self.return_soft_hard_loss = return_soft_hard_loss
 
 def forward(
 self,
@@ -108,7 +116,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
 true_labels: torch.LongTensor,
 student_bias: torch.Tensor = None,
 teacher_bias: torch.Tensor = None,
-) -> torch.Tensor:
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
 return LigerFusedLinearCosineSimilarityFunction.apply(
 student_input,
 student_weight,
@@ -124,4 +132,5 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
 self.temperature,
 self.compiled,
 self.chunk_size,
+self.return_soft_hard_loss,
 )
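
With the new flag from the hunks above, the module returns `(combined_loss, soft_loss, hard_loss)` instead of a single scalar. A minimal usage sketch, not an official example: tensor shapes are illustrative, and the argument order ahead of `true_labels` is assumed from the base distillation API (only the trailing parameters appear in this diff).

```python
import torch

from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss

B_T, H, V = 8, 64, 128  # (batch*seq_len, hidden_size, vocab_size) -- illustrative only

loss_fn = LigerFusedLinearCosineSimilarityLoss(return_soft_hard_loss=True)

student_input = torch.randn(B_T, H, requires_grad=True)
student_weight = torch.randn(V, H, requires_grad=True)
teacher_input = torch.randn(B_T, H)
teacher_weight = torch.randn(V, H)
true_labels = torch.randint(0, V, (B_T,))

# Returns a 3-tuple instead of a scalar when return_soft_hard_loss=True.
loss, soft_loss, hard_loss = loss_fn(
    student_input, student_weight, teacher_input, teacher_weight, true_labels
)
loss.backward()
```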

liger_kernel/chunked_loss/fused_linear_distillation.py
CHANGED

@@ -1,5 +1,7 @@
 from abc import abstractmethod
 from functools import partial
+from typing import Tuple
+from typing import Union
 
 import torch
 
@@ -157,8 +159,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
 compute_ce_loss=True,
 temperature=1.0,
 compiled=True,
+return_soft_hard_loss=False,
 **loss_kwargs,
-):
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
 """
 Base class for fused linear layer with distillation loss.
 Only need to compute gradients for student model.
@@ -180,6 +183,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
 compute_ce_loss (bool): Whether to compute CE loss.
 temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
 compiled (bool): Whether to use torch compile for chunk accumulation.
+return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
 loss_kwargs (dict): Other possible arguments that a loss function might need
 """
 CHUNK_SIZE = chunk_size
@@ -187,6 +191,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
 grad_inputs = []
 grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
 loss_acc = torch.zeros((), device=student_input.device)
+soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
+hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
 
 loss_func_to_call = partial(
 LigerFusedLinearDistillationBase._compute_loss,
@@ -247,6 +253,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
 )
 grad_weight.add_(chunk_grad_weight)
 loss_acc.add_(chunk_loss)
+if return_soft_hard_loss:
+soft_loss_acc.add_(chunk_soft_loss)
+hard_loss_acc.add_(chunk_hard_loss)
 return chunk_grad_input
 
 if compiled:
@@ -268,10 +277,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
 grad_weight,
 grad_bias,
 )
+if return_soft_hard_loss:
+return loss_acc, soft_loss_acc, hard_loss_acc
 return loss_acc
 
 @staticmethod
-def backward(ctx, grad_output):
+def backward(ctx, grad_output, *args):
 grad_input, grad_weight, grad_bias = ctx.saved_tensors
 if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
 grad_input = grad_input * grad_output
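
Because the base class can now return three tensors, its `backward` gains `*args` so that autograd can hand it one gradient per forward output. A self-contained sketch of that pattern (a toy `torch.autograd.Function`, not the library's class):

```python
import torch


class MultiOutputLoss(torch.autograd.Function):
    """Toy sketch: once forward returns (loss, soft_loss, hard_loss), autograd passes backward
    one grad_output per forward output, so backward takes *args and ignores the gradients of
    the metric-only outputs."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        loss = x.pow(2).sum()
        soft_loss = 0.5 * loss.detach()  # metric-only output
        hard_loss = 0.5 * loss.detach()  # metric-only output
        return loss, soft_loss, hard_loss

    @staticmethod
    def backward(ctx, grad_output, *args):  # *args absorbs grads for soft_loss and hard_loss
        (x,) = ctx.saved_tensors
        return 2.0 * x * grad_output


x = torch.randn(4, requires_grad=True)
loss, soft_loss, hard_loss = MultiOutputLoss.apply(x)
loss.backward()
print(x.grad)  # 2 * x
```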

liger_kernel/chunked_loss/fused_linear_ppo.py
CHANGED

@@ -32,7 +32,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 epsilon_low=0.2,
 epsilon_high=0.2,
 beta=0.04,
-loss_type="
+loss_type="dapo",
 max_completion_length=None,
 importance_sampling_level="token",
 temperature=1.0,
@@ -60,7 +60,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 epsilon_low: Lower bound for clipping the importance sampling ratio
 epsilon_high: Upper bound for clipping the importance sampling ratio
 beta: Weight for the KL penalty
-loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo")
+loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo")
 max_completion_length: Maximum completion length required for "dr_grpo"
 temperature: Temperature for the logits
 compiled: Whether to use torch compile
@@ -244,6 +244,21 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 
 return loss_acc, tuple(final_metrics)
 
+@staticmethod
+def _compute_dapo_normalizer(attention_mask):
+"""Global active tokens averaged per process."""
+normalizer = attention_mask.to(torch.float32).sum()
+world_size = 1
+if torch.distributed.is_available() and torch.distributed.is_initialized():
+import torch.distributed as dist
+
+normalizer = normalizer.clone()
+dist.all_reduce(normalizer, op=dist.ReduceOp.SUM)
+world_size = dist.get_world_size()
+
+normalizer = normalizer / world_size
+return torch.clamp(normalizer, min=1.0)
+
 @staticmethod
 def _compute_chunk_loss(
 input_chunk,
@@ -261,7 +276,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 epsilon_low=0.2,
 epsilon_high=0.2,
 beta=0.04,
-loss_type="
+loss_type="dapo",
 max_completion_length=None,
 importance_sampling_level="token",
 temperature=1.0,
@@ -341,10 +356,11 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 None, # grad_epsilon_low
 None, # grad_epsilon_high
 None, # grad_beta
+None, # grad_loss_type
+None, # grad_max_completion_length
+None, # grad_importance_sampling_level
 None, # grad_temperature
 None, # grad_compiled
 None, # grad_use_ref_model
 None, # grad_chunk_size
-None, # grad_loss_type
-None, # grad_max_completion_length
 )
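
The new `dapo` loss type normalizes the summed per-token loss by the helper added above: the global number of active tokens, summed across data-parallel ranks, averaged per process, and clamped to at least one. The same computation as a standalone function (single-process example shown):

```python
import torch
import torch.distributed as dist


def compute_dapo_normalizer(attention_mask: torch.Tensor) -> torch.Tensor:
    """Count of active (unmasked) tokens, summed across ranks and averaged per process."""
    normalizer = attention_mask.to(torch.float32).sum()
    world_size = 1
    if dist.is_available() and dist.is_initialized():
        normalizer = normalizer.clone()
        dist.all_reduce(normalizer, op=dist.ReduceOp.SUM)  # global token count
        world_size = dist.get_world_size()
    return torch.clamp(normalizer / world_size, min=1.0)


# Two completions with 5 and 3 active tokens -> normalizer of 8 on a single process.
mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
print(compute_dapo_normalizer(mask))  # tensor(8.)
# The "dapo" branch then computes: loss = (per_token_loss * attention_mask).sum() / normalizer
```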

liger_kernel/chunked_loss/grpo_loss.py
CHANGED

@@ -29,7 +29,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 epsilon_low=0.2,
 epsilon_high=0.2,
 beta=0.04,
-loss_type="
+loss_type="dapo", # ["grpo", "bnpo", "dr_grpo", "dapo"]
 max_completion_length=None, # Required for dr_grpo
 importance_sampling_level="token", # ["token", "sequence"] - new parameter for GSPO
 **kwargs,
@@ -94,6 +94,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 if max_completion_length is None:
 raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
 loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+elif loss_type == "dapo":
+loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
+loss = (per_token_loss * attention_mask).sum() / loss_normalizer
 else:
 raise ValueError(f"Unknown loss type: {loss_type}")
 
@@ -135,7 +138,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 beta=0.04,
 epsilon_low=0.2,
 epsilon_high=0.2,
-loss_type="
+loss_type="dapo",
 max_completion_length=None,
 importance_sampling_level="token",
 temperature=1.0,
@@ -157,7 +160,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
 ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
 beta (float): Weight for the KL penalty
-loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "
+loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
 max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
 importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
 temperature (float): Temperature for the logits
@@ -235,7 +238,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
 chunk_size: int = 1,
 epsilon_low: float = 0.2,
 epsilon_high: float = 0.2,
-loss_type: str = "
+loss_type: str = "dapo",
 max_completion_length: Optional[int] = None,
 importance_sampling_level: str = "token",
 temperature: float = 1.0,
@@ -248,7 +251,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
 chunk_size (int): Size of chunks for processing.
 epsilon_low (float): Lower bound for the importance sampling ratio.
 epsilon_high (float): Upper bound for the importance sampling ratio.
-loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "
+loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
 max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
 importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
 temperature (float): Temperature for the logits.
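
At the user level, the chunked GRPO loss now defaults to `loss_type="dapo"`; `"grpo"`, `"bnpo"`, and `"dr_grpo"` remain selectable. A construction sketch using only the constructor arguments visible in this diff (all other parameters keep their defaults):

```python
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss

grpo_loss = LigerFusedLinearGRPOLoss(
    chunk_size=1,
    epsilon_low=0.2,
    epsilon_high=0.2,
    loss_type="dapo",                   # new default; ["grpo", "bnpo", "dr_grpo", "dapo"]
    max_completion_length=None,         # only required when loss_type="dr_grpo"
    importance_sampling_level="token",  # or "sequence" (GSPO)
    temperature=1.0,
)
```

These six hunks account for the file's entire +8/-5 change; only the default normalization strategy moves to DAPO.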

liger_kernel/chunked_loss/jsd_loss.py
CHANGED

@@ -1,5 +1,8 @@
 import math
 
+from typing import Tuple
+from typing import Union
+
 import torch
 import torch.nn.functional as F
 
@@ -56,6 +59,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
 temperature: float = 1.0,
 compiled: bool = True,
 chunk_size: int = 1024,
+return_soft_hard_loss: bool = False,
 ):
 """
 Fused linear layer with JSD distillation loss.
@@ -72,8 +76,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
 temperature (float): Temperature for softening/sharpening distributions
 compiled (bool): Whether to use torch compile
 chunk_size (int): Size of chunks for processing.
+return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
 Returns:
-torch.Tensor: Computed loss
+torch.Tensor: Computed loss, or tuple (loss, soft_loss, hard_loss) if return_soft_hard_loss=True
 """
 return super().forward(
 cls=cls,
@@ -92,11 +97,12 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
 ignore_index=ignore_index,
 temperature=temperature,
 compiled=compiled,
+return_soft_hard_loss=return_soft_hard_loss,
 )
 
 @staticmethod
-def backward(ctx, grad_output):
-grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+def backward(ctx, grad_output, *args):
+grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
 
 return (
 *grads,
@@ -108,6 +114,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
 None, # temperature
 None, # compiled
 None, # chunk_size
+None, # return_soft_hard_loss
 )
 
 
@@ -125,6 +132,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
 temperature: float = 1.0,
 compiled: bool = True,
 chunk_size: int = 1024,
+return_soft_hard_loss: bool = False,
 ):
 """
 Args:
@@ -135,6 +143,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
 compiled (bool): Whether to use torch compile
 beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
 chunk_size (int): Size of chunks for processing.
+return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
 """
 super().__init__()
 assert temperature != 0, "Temperature cannot be 0."
@@ -145,6 +154,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
 self.compiled = compiled
 self.beta = beta
 self.chunk_size = chunk_size
+self.return_soft_hard_loss = return_soft_hard_loss
 
 def forward(
 self,
@@ -155,7 +165,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
 true_labels: torch.LongTensor,
 student_bias: torch.Tensor = None,
 teacher_bias: torch.Tensor = None,
-) -> torch.Tensor:
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
 """
 Compute the JSD distillation loss.
 
@@ -167,7 +177,9 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
 true_labels (torch.LongTensor): Target labels tensor
 
 Returns:
-torch.Tensor
+torch.Tensor or Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+If return_soft_hard_loss is False: Computed combined loss
+If return_soft_hard_loss is True: Tuple of (combined_loss, soft_loss, hard_loss)
 """
 return LigerFusedLinearJSDFunction.apply(
 student_input,
@@ -184,4 +196,5 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
 self.temperature,
 self.compiled,
 self.chunk_size,
+self.return_soft_hard_loss,
 )

liger_kernel/ops/cross_entropy.py
CHANGED

@@ -10,8 +10,9 @@ from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
 from liger_kernel.utils import infer_device
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
 try:
 # typical import path with dispatch available
 from triton.language.extra.libdevice import tanh
@@ -32,6 +33,8 @@ def liger_cross_entropy_kernel(
 loss_ptr,
 z_loss_ptr,
 loss_stride,
+token_accuracy_ptr,
+token_accuracy_stride,
 n_cols,
 n_non_ignore,
 sum_non_ignore_weight,
@@ -42,6 +45,7 @@ def liger_cross_entropy_kernel(
 reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
 softcap,
 RETURN_Z_LOSS: tl.constexpr,
+RETURN_TOKEN_ACCURACY: tl.constexpr,
 BLOCK_SIZE: tl.constexpr,
 HAS_WEIGHT: tl.constexpr,
 HAS_SOFTCAPPING: tl.constexpr,
@@ -60,6 +64,8 @@ def liger_cross_entropy_kernel(
 loss_ptr: Pointer to tensor to store the loss.
 z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
 loss_stride (int): The stride of the loss tensor.
+token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+token_accuracy_stride (int): The stride of the token accuracy tensor.
 n_cols (int): The number of columns in the input tensor.
 n_non_ignore (float): The number of non-ignored elements in the batch.
 sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -69,7 +75,8 @@ def liger_cross_entropy_kernel(
 lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
 reduction (str): The string for the reduction to apply
 softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
-RETURN_Z_LOSS (int): The boolean value to decide whether
+RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
 BLOCK_SIZE (int): The block size for Triton operations.
 HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
 HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
@@ -92,11 +99,17 @@ def liger_cross_entropy_kernel(
 for i in range(0, n_cols, BLOCK_SIZE):
 X_offsets = i + tl.arange(0, BLOCK_SIZE)
 tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+# For ignored tokens, set token accuracy to 0
+if RETURN_TOKEN_ACCURACY:
+token_accuracy_ptr += program_id * token_accuracy_stride
+tl.store(token_accuracy_ptr, 0.0)
 return
 
 loss_ptr += program_id * loss_stride
 if RETURN_Z_LOSS:
 z_loss_ptr += program_id * loss_stride
+if RETURN_TOKEN_ACCURACY:
+token_accuracy_ptr += program_id * token_accuracy_stride
 
 if HAS_WEIGHT:
 weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -107,6 +120,7 @@ def liger_cross_entropy_kernel(
 # 3. [Online softmax] first pass: find max + sum
 m = float("-inf") # m is the max value. use the notation from the paper
 d = 0.0 # d is the sum. use the notation from the paper
+argmax_idx = 0 # Track the index of the maximum value for token accuracy computation
 ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
 if HAS_SOFTCAPPING:
 ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -127,6 +141,16 @@ def liger_cross_entropy_kernel(
 if HAS_SOFTCAPPING:
 X_block = softcap * tanh(X_block / softcap)
 block_max = tl.max(X_block)
+
+# Track argmax for accuracy computation
+if RETURN_TOKEN_ACCURACY and block_max > m:
+# Find the index of the maximum value in this block
+is_max_mask = X_block == block_max
+# Mask out invalid indices with a value larger than n_cols
+masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+# Get the first (smallest) index where max occurs
+argmax_idx = tl.min(masked_offsets)
+
 if label_smoothing > 0:
 # scale X beforehand to avoid overflow
 if HAS_WEIGHT:
@@ -256,6 +280,10 @@ def liger_cross_entropy_kernel(
 tl.store(loss_ptr, loss)
 if RETURN_Z_LOSS:
 tl.store(z_loss_ptr, z_loss)
+if RETURN_TOKEN_ACCURACY:
+# Store 1.0 if prediction is correct, 0.0 otherwise
+is_correct = 1.0 if argmax_idx == y else 0.0
+tl.store(token_accuracy_ptr, is_correct)
 
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
@@ -274,8 +302,12 @@ def cross_entropy_forward(
 reduction,
 softcap,
 return_z_loss,
+return_token_accuracy=False,
 ):
 assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+assert isinstance(return_token_accuracy, bool), (
+f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+)
 
 BT, V = _input.shape
 n_rows = BT
@@ -285,6 +317,9 @@ def cross_entropy_forward(
 # unreduced loss
 loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
 z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+token_accuracy_1d = (
+torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+)
 
 target_mask = target != ignore_index
 n_non_ignore = target_mask.sum().item()
@@ -321,6 +356,10 @@ def cross_entropy_forward(
 loss_ptr=loss_1d,
 z_loss_ptr=z_loss_1d,
 loss_stride=loss_1d.stride(-1), # always 1
+token_accuracy_ptr=token_accuracy_1d,
+token_accuracy_stride=token_accuracy_1d.stride(-1)
+if return_token_accuracy
+else 0, # always 1 if accuracy is enabled
 n_cols=V,
 n_non_ignore=n_non_ignore,
 sum_non_ignore_weight=sum_non_ignore_weight,
@@ -331,6 +370,7 @@ def cross_entropy_forward(
 reduction=reduction,
 softcap=softcap,
 RETURN_Z_LOSS=return_z_loss,
+RETURN_TOKEN_ACCURACY=return_token_accuracy,
 BLOCK_SIZE=BLOCK_SIZE,
 HAS_WEIGHT=True if weight is not None else False,
 HAS_SOFTCAPPING=True if softcap is not None else False,
@@ -343,11 +383,14 @@ def cross_entropy_forward(
 if reduction == "none":
 loss = loss_1d
 z_loss = z_loss_1d if return_z_loss else None
+token_accuracy = token_accuracy_1d if return_token_accuracy else None
 else:
 loss = torch.sum(loss_1d)
 z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+# For accuracy, we compute the mean across all non-ignored tokens
+token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None
 
-return loss, z_loss, _input
+return loss, z_loss, token_accuracy, _input
 
 
 def cross_entropy_backward(_input, grad_output):
@@ -395,6 +438,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
 reduction: str = "mean",
 softcap: Optional[float] = None,
 return_z_loss: bool = False,
+return_token_accuracy: bool = False,
 ):
 """
 The forward pass of the Liger Cross Entropy loss.
@@ -409,12 +453,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
 label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
 reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
 softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
-return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
 
 Returns:
-tuple: A tuple with the
+tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
 """
-loss, z_loss, _input = cross_entropy_forward(
+input_requires_grad = _input.requires_grad
+
+loss, z_loss, token_accuracy, _input = cross_entropy_forward(
 _input,
 target,
 weight,
@@ -424,29 +471,35 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
 reduction,
 softcap,
 return_z_loss,
+return_token_accuracy,
 )
 # TODO: investigation
 # If we don't detach the _input tensor, the memory will double
 # Not sure why but seems that there will be a time both grad and value exist but in different location
-ctx.save_for_backward(_input.detach())
+if input_requires_grad:
+ctx.save_for_backward(_input.detach())
 ctx.return_z_loss = return_z_loss
+ctx.return_token_accuracy = return_token_accuracy
 
-return loss, z_loss
+return loss, z_loss, token_accuracy
 
 @staticmethod
-def backward(ctx, grad_output,
+def backward(ctx, grad_output, grad_output2, grad_output3):
 """
 The backward pass of the Liger Cross Entropy loss.
 
 Parameters:
 ctx : The context object with saved tensors.
 grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-grad_output2 (
+grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
 Returns:
 tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
 """
 if ctx.return_z_loss:
-del
+del grad_output2 # z_loss is only for logging
+if ctx.return_token_accuracy:
+del grad_output3 # token_accuracy is only for metrics
 
 (_input,) = ctx.saved_tensors
 _input = cross_entropy_backward(_input, grad_output)
@@ -460,4 +513,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
 None,
 None,
 None,
+None,
 )
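
The kernel change is surfaced through `LigerCrossEntropyFunction`, whose forward now returns `(loss, z_loss, token_accuracy)`. A sketch of calling it directly with the new flag; the positional order of the leading arguments is assumed from the upstream signature (only the trailing parameters appear in this diff), and a Triton-capable device such as CUDA is required:

```python
import torch

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

device = "cuda"  # the Triton kernel needs a supported accelerator
logits = torch.randn(8, 32000, device=device, requires_grad=True)
targets = torch.randint(0, 32000, (8,), device=device)

# Assumed positional order: (_input, target, weight, ignore_index, lse_square_scale,
# label_smoothing, reduction, softcap, return_z_loss, return_token_accuracy)
loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
    logits, targets, None, -100, 0.0, 0.0, "mean", None, False, True
)

print(token_accuracy)  # fraction of non-ignored tokens whose argmax logit matches the target
loss.backward()
```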
liger_kernel/ops/dyt.py
CHANGED

@@ -7,8 +7,10 @@ import triton.language as tl
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import infer_device
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
 try:
 # typical import path with dispatch available
 from triton.language.extra.libdevice import tanh
@@ -125,7 +127,8 @@ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
 NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
 elif device == "xpu":
 NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
-
+elif device == "npu":
+NUM_SMS = get_npu_multi_processor_count()
 da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
 dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
 db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None

liger_kernel/ops/fused_add_rms_norm.py
CHANGED

@@ -9,8 +9,10 @@ from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import torch_to_triton_dtype
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
 try:
 # typical import path with dispatch available
 from triton.language.extra.libdevice import rsqrt
@@ -293,6 +295,8 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
 sm_count = torch.cuda.get_device_properties(S.device).multi_processor_count
 elif S.device.type == "xpu":
 sm_count = torch.xpu.get_device_properties(S.device).gpu_eu_count
+elif S.device.type == "npu":
+sm_count = get_npu_multi_processor_count()
 
 # fp32 for numerical stability especially.
 _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
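
The NPU additions follow one pattern across the Triton ops: skip the libdevice import when running on NPU, and size the per-SM scratch buffers from a device-appropriate compute-unit count. A condensed sketch of that dispatch, using the `get_npu_multi_processor_count` helper this release adds to `liger_kernel/utils.py` (note the XPU property name differs per kernel, e.g. `gpu_eu_count` here vs. `gpu_subslice_count` in dyt.py):

```python
import torch

from liger_kernel.utils import get_npu_multi_processor_count  # new helper in liger_kernel/utils.py


def compute_unit_count(device: torch.device) -> int:
    """Number of streaming multiprocessors / compute units used to size partial-reduction buffers."""
    if device.type == "cuda":
        return torch.cuda.get_device_properties(device).multi_processor_count
    if device.type == "xpu":
        return torch.xpu.get_device_properties(device).gpu_eu_count
    if device.type == "npu":
        return get_npu_multi_processor_count()
    raise ValueError(f"Unsupported device type: {device.type}")


# The backward kernels then allocate one partial-gradient row per compute unit, e.g.:
# _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
```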