liger-kernel 0.6.3__py3-none-any.whl → 0.6.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
- liger_kernel/chunked_loss/grpo_loss.py +8 -5
- liger_kernel/chunked_loss/jsd_loss.py +18 -5
- liger_kernel/ops/cross_entropy.py +59 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +30 -4
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/layer_norm.py +84 -65
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/transformers/__init__.py +19 -0
- liger_kernel/transformers/cross_entropy.py +8 -3
- liger_kernel/transformers/functional.py +24 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
- liger_kernel/transformers/grpo_loss.py +56 -1
- liger_kernel/transformers/model/falcon_h1.py +19 -5
- liger_kernel/transformers/model/gemma.py +17 -6
- liger_kernel/transformers/model/gemma2.py +14 -5
- liger_kernel/transformers/model/gemma3.py +25 -12
- liger_kernel/transformers/model/glm4.py +16 -4
- liger_kernel/transformers/model/glm4v.py +16 -4
- liger_kernel/transformers/model/glm4v_moe.py +23 -4
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +12 -5
- liger_kernel/transformers/model/llama.py +14 -5
- liger_kernel/transformers/model/llama4.py +16 -4
- liger_kernel/transformers/model/llava.py +12 -4
- liger_kernel/transformers/model/loss_utils.py +31 -3
- liger_kernel/transformers/model/mistral.py +15 -6
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +12 -4
- liger_kernel/transformers/model/olmo2.py +16 -4
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +22 -5
- liger_kernel/transformers/model/phi3.py +14 -7
- liger_kernel/transformers/model/qwen2.py +16 -3
- liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
- liger_kernel/transformers/model/qwen2_vl.py +16 -4
- liger_kernel/transformers/model/qwen3.py +20 -5
- liger_kernel/transformers/model/qwen3_moe.py +19 -5
- liger_kernel/transformers/model/qwen3_next.py +17 -5
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +15 -6
- liger_kernel/transformers/monkey_patch.py +398 -20
- liger_kernel/transformers/rope.py +43 -0
- liger_kernel/transformers/swiglu.py +17 -0
- liger_kernel/transformers/tiled_mlp.py +133 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/METADATA +4 -1
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/RECORD +55 -48
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/WHEEL +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.4.dist-info}/top_level.txt +0 -0
--- a/liger_kernel/chunked_loss/cosine_similarity_loss.py
+++ b/liger_kernel/chunked_loss/cosine_similarity_loss.py
@@ -1,3 +1,6 @@
+from typing import Tuple
+from typing import Union
+
 import torch
 import torch.nn.functional as F

@@ -41,7 +44,8 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase)
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
-    ):
+        return_soft_hard_loss: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         return super().forward(
             cls=cls,
             ctx=ctx,
@@ -59,11 +63,12 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase)
             ignore_index=ignore_index,
             temperature=temperature,
             compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
         )

     @staticmethod
-    def backward(ctx, grad_output):
-        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]

         return (
             *grads,
@@ -75,6 +80,7 @@ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase)
             None,  # temperature
             None,  # compiled
             None,  # chunk_size
+            None,  # return_soft_hard_loss
         )


@@ -88,6 +94,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -98,6 +105,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
         self.compiled = compiled
         self.beta = beta
         self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss

     def forward(
         self,
@@ -108,7 +116,7 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
         true_labels: torch.LongTensor,
         student_bias: torch.Tensor = None,
         teacher_bias: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         return LigerFusedLinearCosineSimilarityFunction.apply(
             student_input,
             student_weight,
@@ -124,4 +132,5 @@ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
             self.temperature,
             self.compiled,
             self.chunk_size,
+            self.return_soft_hard_loss,
         )
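The new return_soft_hard_loss flag turns the cosine-similarity distillation module into a three-way return. A minimal usage sketch follows; the tensor shapes and the positional argument order of forward are assumptions for illustration, only the flag and the (loss, soft_loss, hard_loss) tuple come from the diff above:

import torch

from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss

loss_fn = LigerFusedLinearCosineSimilarityLoss(return_soft_hard_loss=True)

BT, H, V = 8, 64, 32  # assumed (batch*seq, hidden, vocab) sizes
student_input = torch.randn(BT, H, requires_grad=True)
student_weight = torch.randn(V, H, requires_grad=True)
teacher_input = torch.randn(BT, H)
teacher_weight = torch.randn(V, H)
true_labels = torch.randint(0, V, (BT,))

# With the flag enabled, forward returns (combined_loss, soft_loss, hard_loss).
loss, soft_loss, hard_loss = loss_fn(
    student_input, student_weight, teacher_input, teacher_weight, true_labels
)
loss.backward()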
--- a/liger_kernel/chunked_loss/fused_linear_distillation.py
+++ b/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -1,5 +1,7 @@
 from abc import abstractmethod
 from functools import partial
+from typing import Tuple
+from typing import Union

 import torch

@@ -157,8 +159,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         compute_ce_loss=True,
         temperature=1.0,
         compiled=True,
+        return_soft_hard_loss=False,
         **loss_kwargs,
-    ):
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Base class for fused linear layer with distillation loss.
         Only need to compute gradients for student model.
@@ -180,6 +183,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             compute_ce_loss (bool): Whether to compute CE loss.
             temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             compiled (bool): Whether to use torch compile for chunk accumulation.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         CHUNK_SIZE = chunk_size
@@ -187,6 +191,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         grad_inputs = []
         grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
         loss_acc = torch.zeros((), device=student_input.device)
+        soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
+        hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None

         loss_func_to_call = partial(
             LigerFusedLinearDistillationBase._compute_loss,
@@ -247,6 +253,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             )
             grad_weight.add_(chunk_grad_weight)
             loss_acc.add_(chunk_loss)
+            if return_soft_hard_loss:
+                soft_loss_acc.add_(chunk_soft_loss)
+                hard_loss_acc.add_(chunk_hard_loss)
             return chunk_grad_input

         if compiled:
@@ -268,10 +277,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             grad_weight,
             grad_bias,
         )
+        if return_soft_hard_loss:
+            return loss_acc, soft_loss_acc, hard_loss_acc
         return loss_acc

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, *args):
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
         if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
             grad_input = grad_input * grad_output
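The change from backward(ctx, grad_output) to backward(ctx, grad_output, *args) mirrors how torch.autograd dispatches gradients: a custom Function receives one incoming gradient per forward output, so a forward that may now return (loss, soft_loss, hard_loss) receives three of them. A self-contained sketch of that mechanism (generic PyTorch behaviour the change relies on, not liger code):

import torch

class LossWithMetric(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.sum(), x.mean()  # second output is reported for logging only

    @staticmethod
    def backward(ctx, grad_loss, *extra_grads):
        # extra_grads holds one (unused) gradient per additional forward output,
        # which is exactly why the distillation backward now accepts *args.
        (x,) = ctx.saved_tensors
        return grad_loss * torch.ones_like(x)

x = torch.randn(4, requires_grad=True)
loss, metric = LossWithMetric.apply(x)
loss.backward()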
--- a/liger_kernel/chunked_loss/fused_linear_ppo.py
+++ b/liger_kernel/chunked_loss/fused_linear_ppo.py
@@ -32,7 +32,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
-        loss_type="
+        loss_type="dapo",
         max_completion_length=None,
         importance_sampling_level="token",
         temperature=1.0,
@@ -60,7 +60,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             epsilon_low: Lower bound for clipping the importance sampling ratio
             epsilon_high: Upper bound for clipping the importance sampling ratio
             beta: Weight for the KL penalty
-            loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo")
+            loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo")
             max_completion_length: Maximum completion length required for "dr_grpo"
             temperature: Temperature for the logits
             compiled: Whether to use torch compile
@@ -244,6 +244,21 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):

         return loss_acc, tuple(final_metrics)

+    @staticmethod
+    def _compute_dapo_normalizer(attention_mask):
+        """Global active tokens averaged per process."""
+        normalizer = attention_mask.to(torch.float32).sum()
+        world_size = 1
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            import torch.distributed as dist
+
+            normalizer = normalizer.clone()
+            dist.all_reduce(normalizer, op=dist.ReduceOp.SUM)
+            world_size = dist.get_world_size()
+
+        normalizer = normalizer / world_size
+        return torch.clamp(normalizer, min=1.0)
+
     @staticmethod
     def _compute_chunk_loss(
         input_chunk,
@@ -261,7 +276,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
-        loss_type="
+        loss_type="dapo",
         max_completion_length=None,
         importance_sampling_level="token",
         temperature=1.0,
@@ -341,10 +356,11 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             None,  # grad_epsilon_low
             None,  # grad_epsilon_high
             None,  # grad_beta
+            None,  # grad_loss_type
+            None,  # grad_max_completion_length
+            None,  # grad_importance_sampling_level
             None,  # grad_temperature
             None,  # grad_compiled
             None,  # grad_use_ref_model
             None,  # grad_chunk_size
-            None,  # grad_loss_type
-            None,  # grad_max_completion_length
         )
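The DAPO branch normalizes by the number of active completion tokens rather than by sequence count or maximum length. A worked single-process example of the arithmetic (the distributed all_reduce path above is not exercised here; the mask and loss values are made up for illustration):

import torch

# 2 sequences, 4 completion positions, 3 + 4 = 7 active tokens in total.
full_attention_mask = torch.tensor([[1, 1, 1, 0],
                                    [1, 1, 1, 1]], dtype=torch.float32)

# With a single process the normalizer is just the global active-token count, floored at 1.
normalizer = torch.clamp(full_attention_mask.sum(), min=1.0)  # -> 7.0

per_token_loss = torch.full_like(full_attention_mask, 0.5)    # assumed constant per-token loss
dapo_loss = (per_token_loss * full_attention_mask).sum() / normalizer  # 3.5 / 7.0 = 0.5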
--- a/liger_kernel/chunked_loss/grpo_loss.py
+++ b/liger_kernel/chunked_loss/grpo_loss.py
@@ -29,7 +29,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
-        loss_type="
+        loss_type="dapo",  # ["grpo", "bnpo", "dr_grpo", "dapo"]
         max_completion_length=None,  # Required for dr_grpo
         importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
         **kwargs,
@@ -94,6 +94,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
            if max_completion_length is None:
                raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
            loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+        elif loss_type == "dapo":
+            loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
+            loss = (per_token_loss * attention_mask).sum() / loss_normalizer
         else:
             raise ValueError(f"Unknown loss type: {loss_type}")

@@ -135,7 +138,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         beta=0.04,
         epsilon_low=0.2,
         epsilon_high=0.2,
-        loss_type="
+        loss_type="dapo",
         max_completion_length=None,
         importance_sampling_level="token",
         temperature=1.0,
@@ -157,7 +160,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
             ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
             beta (float): Weight for the KL penalty
-            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
             importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits
@@ -235,7 +238,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         chunk_size: int = 1,
         epsilon_low: float = 0.2,
         epsilon_high: float = 0.2,
-        loss_type: str = "
+        loss_type: str = "dapo",
         max_completion_length: Optional[int] = None,
         importance_sampling_level: str = "token",
         temperature: float = 1.0,
@@ -248,7 +251,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             chunk_size (int): Size of chunks for processing.
             epsilon_low (float): Lower bound for the importance sampling ratio.
             epsilon_high (float): Upper bound for the importance sampling ratio.
-            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
             importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits.
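A short constructor-level sketch of the surface that changes here; only the loss_type choices and the max_completion_length requirement for "dr_grpo" are taken from the diff, the rest is left at defaults:

from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss

# New default: DAPO-style normalization over the global active-token count.
dapo_loss_fn = LigerFusedLinearGRPOLoss(loss_type="dapo")

# The earlier loss types stay selectable; Dr. GRPO still requires max_completion_length.
dr_grpo_loss_fn = LigerFusedLinearGRPOLoss(loss_type="dr_grpo", max_completion_length=512)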
--- a/liger_kernel/chunked_loss/jsd_loss.py
+++ b/liger_kernel/chunked_loss/jsd_loss.py
@@ -1,5 +1,8 @@
 import math

+from typing import Tuple
+from typing import Union
+
 import torch
 import torch.nn.functional as F

@@ -56,6 +59,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         """
         Fused linear layer with JSD distillation loss.
@@ -72,8 +76,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             temperature (float): Temperature for softening/sharpening distributions
             compiled (bool): Whether to use torch compile
             chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
         Returns:
-            torch.Tensor: Computed loss
+            torch.Tensor: Computed loss, or tuple (loss, soft_loss, hard_loss) if return_soft_hard_loss=True
         """
         return super().forward(
             cls=cls,
@@ -92,11 +97,12 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             ignore_index=ignore_index,
             temperature=temperature,
             compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
         )

     @staticmethod
-    def backward(ctx, grad_output):
-        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]

         return (
             *grads,
@@ -108,6 +114,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             None,  # temperature
             None,  # compiled
             None,  # chunk_size
+            None,  # return_soft_hard_loss
         )


@@ -125,6 +132,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         """
         Args:
@@ -135,6 +143,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             compiled (bool): Whether to use torch compile
             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
             chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
         """
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -145,6 +154,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         self.compiled = compiled
         self.beta = beta
         self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss

     def forward(
         self,
@@ -155,7 +165,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         true_labels: torch.LongTensor,
         student_bias: torch.Tensor = None,
         teacher_bias: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Compute the JSD distillation loss.

@@ -167,7 +177,9 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             true_labels (torch.LongTensor): Target labels tensor

         Returns:
-            torch.Tensor
+            torch.Tensor or Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                If return_soft_hard_loss is False: Computed combined loss
+                If return_soft_hard_loss is True: Tuple of (combined_loss, soft_loss, hard_loss)
         """
         return LigerFusedLinearJSDFunction.apply(
             student_input,
@@ -184,4 +196,5 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             self.temperature,
             self.compiled,
             self.chunk_size,
+            self.return_soft_hard_loss,
         )
--- a/liger_kernel/ops/cross_entropy.py
+++ b/liger_kernel/ops/cross_entropy.py
@@ -32,6 +32,8 @@ def liger_cross_entropy_kernel(
     loss_ptr,
     z_loss_ptr,
     loss_stride,
+    token_accuracy_ptr,
+    token_accuracy_stride,
     n_cols,
     n_non_ignore,
     sum_non_ignore_weight,
@@ -42,6 +44,7 @@ def liger_cross_entropy_kernel(
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
+    RETURN_TOKEN_ACCURACY: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
@@ -60,6 +63,8 @@ def liger_cross_entropy_kernel(
     loss_ptr: Pointer to tensor to store the loss.
     z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
+    token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+    token_accuracy_stride (int): The stride of the token accuracy tensor.
     n_cols (int): The number of columns in the input tensor.
     n_non_ignore (float): The number of non-ignored elements in the batch.
     sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -69,7 +74,8 @@ def liger_cross_entropy_kernel(
     lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
     reduction (str): The string for the reduction to apply
     softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
-    RETURN_Z_LOSS (int): The boolean value to decide whether
+    RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+    RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
     BLOCK_SIZE (int): The block size for Triton operations.
     HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
     HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
@@ -92,11 +98,17 @@ def liger_cross_entropy_kernel(
         for i in range(0, n_cols, BLOCK_SIZE):
             X_offsets = i + tl.arange(0, BLOCK_SIZE)
             tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+        # For ignored tokens, set token accuracy to 0
+        if RETURN_TOKEN_ACCURACY:
+            token_accuracy_ptr += program_id * token_accuracy_stride
+            tl.store(token_accuracy_ptr, 0.0)
         return

     loss_ptr += program_id * loss_stride
     if RETURN_Z_LOSS:
         z_loss_ptr += program_id * loss_stride
+    if RETURN_TOKEN_ACCURACY:
+        token_accuracy_ptr += program_id * token_accuracy_stride

     if HAS_WEIGHT:
         weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -107,6 +119,7 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
+    argmax_idx = 0  # Track the index of the maximum value for token accuracy computation
     ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -127,6 +140,16 @@ def liger_cross_entropy_kernel(
         if HAS_SOFTCAPPING:
             X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
+
+        # Track argmax for accuracy computation
+        if RETURN_TOKEN_ACCURACY and block_max > m:
+            # Find the index of the maximum value in this block
+            is_max_mask = X_block == block_max
+            # Mask out invalid indices with a value larger than n_cols
+            masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+            # Get the first (smallest) index where max occurs
+            argmax_idx = tl.min(masked_offsets)
+
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
             if HAS_WEIGHT:
@@ -256,6 +279,10 @@ def liger_cross_entropy_kernel(
     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS:
         tl.store(z_loss_ptr, z_loss)
+    if RETURN_TOKEN_ACCURACY:
+        # Store 1.0 if prediction is correct, 0.0 otherwise
+        is_correct = 1.0 if argmax_idx == y else 0.0
+        tl.store(token_accuracy_ptr, is_correct)


 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
@@ -274,8 +301,12 @@ def cross_entropy_forward(
     reduction,
     softcap,
     return_z_loss,
+    return_token_accuracy=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )

     BT, V = _input.shape
     n_rows = BT
@@ -285,6 +316,9 @@ def cross_entropy_forward(
     # unreduced loss
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
     z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = (
+        torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+    )

     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
@@ -321,6 +355,10 @@ def cross_entropy_forward(
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
+        token_accuracy_ptr=token_accuracy_1d,
+        token_accuracy_stride=token_accuracy_1d.stride(-1)
+        if return_token_accuracy
+        else 0,  # always 1 if accuracy is enabled
         n_cols=V,
         n_non_ignore=n_non_ignore,
         sum_non_ignore_weight=sum_non_ignore_weight,
@@ -331,6 +369,7 @@ def cross_entropy_forward(
         reduction=reduction,
         softcap=softcap,
         RETURN_Z_LOSS=return_z_loss,
+        RETURN_TOKEN_ACCURACY=return_token_accuracy,
         BLOCK_SIZE=BLOCK_SIZE,
         HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
@@ -343,11 +382,14 @@ def cross_entropy_forward(
     if reduction == "none":
         loss = loss_1d
         z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None

-    return loss, z_loss, _input
+    return loss, z_loss, token_accuracy, _input


 def cross_entropy_backward(_input, grad_output):
@@ -395,6 +437,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         reduction: str = "mean",
         softcap: Optional[float] = None,
         return_z_loss: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -409,14 +452,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
         softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
-        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`

         Returns:
-        tuple: A tuple with the
+        tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
         """
         input_requires_grad = _input.requires_grad

-        loss, z_loss, _input = cross_entropy_forward(
+        loss, z_loss, token_accuracy, _input = cross_entropy_forward(
             _input,
             target,
             weight,
@@ -426,6 +470,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             reduction,
             softcap,
             return_z_loss,
+            return_token_accuracy,
         )
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
@@ -433,23 +478,27 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         if input_requires_grad:
             ctx.save_for_backward(_input.detach())
         ctx.return_z_loss = return_z_loss
+        ctx.return_token_accuracy = return_token_accuracy

-        return loss, z_loss
+        return loss, z_loss, token_accuracy

     @staticmethod
-    def backward(ctx, grad_output,
+    def backward(ctx, grad_output, grad_output2, grad_output3):
         """
         The backward pass of the Liger Cross Entropy loss.

         Parameters:
         ctx : The context object with saved tensors.
         grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-        grad_output2 (
+        grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+        grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
         Returns:
         tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
         if ctx.return_z_loss:
-            del
+            del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics

         (_input,) = ctx.saved_tensors
         _input = cross_entropy_backward(_input, grad_output)
@@ -463,4 +512,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
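For reference, the semantics the new RETURN_TOKEN_ACCURACY path stores can be written as plain, unfused PyTorch: 1.0 per non-ignored token whose argmax logit equals the target, 0.0 otherwise (including ignored tokens), then a mean over the non-ignored count in the reduced case. A small sketch of that behaviour with made-up sizes, not the Triton kernel itself:

import torch

logits = torch.randn(8, 100)   # (tokens, vocab), toy sizes
target = torch.randint(0, 100, (8,))
target[0] = -100               # ignore_index

mask = (target != -100).float()
per_token_correct = (logits.argmax(dim=-1) == target).float() * mask  # 1.0 where argmax == y
token_accuracy = per_token_correct.sum() / mask.sum()                 # mean over non-ignored tokens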