liger-kernel-nightly 0.5.10.dev20250624183504__py3-none-any.whl → 0.6.4.dev20251121224847__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly might be problematic.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/dpo_loss.py +54 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
- liger_kernel/chunked_loss/grpo_loss.py +46 -9
- liger_kernel/chunked_loss/jsd_loss.py +23 -7
- liger_kernel/ops/cross_entropy.py +118 -62
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +113 -21
- liger_kernel/ops/geglu.py +1 -1
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/layer_norm.py +133 -79
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/rms_norm.py +2 -2
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/transformers/__init__.py +59 -0
- liger_kernel/transformers/cross_entropy.py +8 -3
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/functional.py +38 -6
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -4
- liger_kernel/transformers/grpo_loss.py +56 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +28 -8
- liger_kernel/transformers/model/gemma2.py +31 -8
- liger_kernel/transformers/model/gemma3.py +100 -110
- liger_kernel/transformers/model/glm4.py +18 -5
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +26 -7
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +18 -6
- liger_kernel/transformers/model/loss_utils.py +34 -3
- liger_kernel/transformers/model/mistral.py +17 -10
- liger_kernel/transformers/model/mixtral.py +24 -9
- liger_kernel/transformers/model/mllama.py +18 -7
- liger_kernel/transformers/model/olmo2.py +18 -5
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +41 -5
- liger_kernel/transformers/model/phi3.py +24 -159
- liger_kernel/transformers/model/qwen2.py +26 -4
- liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
- liger_kernel/transformers/model/qwen2_vl.py +24 -7
- liger_kernel/transformers/model/qwen3.py +22 -6
- liger_kernel/transformers/model/qwen3_moe.py +27 -7
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +1278 -116
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/rms_norm.py +7 -0
- liger_kernel/transformers/rope.py +43 -0
- liger_kernel/transformers/swiglu.py +17 -0
- liger_kernel/transformers/tiled_mlp.py +133 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/METADATA +29 -24
- liger_kernel_nightly-0.6.4.dev20251121224847.dist-info/RECORD +118 -0
- liger_kernel_nightly-0.5.10.dev20250624183504.dist-info/RECORD +0 -95
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.4.dev20251121224847.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/__init__.py
@@ -1,3 +1,4 @@
+from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss  # noqa:F401
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss  # noqa: F401
liger_kernel/chunked_loss/cosine_similarity_loss.py
@@ -0,0 +1,136 @@
+from typing import Tuple
+from typing import Union
+
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
+    @staticmethod
+    def distillation_loss_fn(student_logits, teacher_logits, beta=1.0):
+        """
+        Compute Cosine loss (Cosine Similarity Loss).
+        Args:
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+            beta: Coefficient beta of generalized Cosine Similarity in the interval [0, 1]. Default: `1.0` (float): .
+        Returns:
+            torch.Tensor: cosine similarity loss
+        """
+        student_norm = F.normalize(student_logits, p=2, dim=-1)
+        teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)
+
+        cosine_sim = F.cosine_similarity(student_norm, teacher_norm, dim=-1)
+        loss = beta * (1 - cosine_sim)
+        return loss.sum()
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor,
+        teacher_bias: torch.Tensor,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            student_input=student_input,
+            student_weight=student_weight,
+            teacher_input=teacher_input,
+            teacher_weight=teacher_weight,
+            target=true_labels,
+            student_bias=student_bias,
+            teacher_bias=teacher_bias,
+            chunk_size=chunk_size,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            beta=beta,
+            ignore_index=ignore_index,
+            temperature=temperature,
+            compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
+
+        return (
+            *grads,
+            None,  # teacher_bias
+            None,  # weight_hard_loss
+            None,  # weight_soft_loss
+            None,  # beta
+            None,  # ignore_index
+            None,  # temperature
+            None,  # compiled
+            None,  # chunk_size
+            None,  # return_soft_hard_loss
+        )
+
+
+class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
+    def __init__(
+        self,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
+    ):
+        super().__init__()
+        assert temperature != 0, "Temperature cannot be 0."
+        self.weight_hard_loss = weight_hard_loss
+        self.weight_soft_loss = weight_soft_loss
+        self.ignore_index = ignore_index
+        self.temperature = temperature
+        self.compiled = compiled
+        self.beta = beta
+        self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss
+
+    def forward(
+        self,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor = None,
+        teacher_bias: torch.Tensor = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        return LigerFusedLinearCosineSimilarityFunction.apply(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            true_labels,
+            student_bias,
+            teacher_bias,
+            self.weight_hard_loss,
+            self.weight_soft_loss,
+            self.beta,
+            self.ignore_index,
+            self.temperature,
+            self.compiled,
+            self.chunk_size,
+            self.return_soft_hard_loss,
+        )
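Note: a minimal usage sketch of the new LigerFusedLinearCosineSimilarityLoss follows. The tensor shapes and hyperparameter values are illustrative assumptions, not values taken from the package's own tests.

# Hypothetical usage of the new cosine-similarity distillation loss (shapes assumed).
import torch
from liger_kernel.chunked_loss import LigerFusedLinearCosineSimilarityLoss

B_T, H, V = 8, 64, 128  # assumed (batch*seq_len, hidden, vocab) sizes
loss_fn = LigerFusedLinearCosineSimilarityLoss(weight_hard_loss=0.5, weight_soft_loss=0.5, beta=0.5)

student_input = torch.randn(B_T, H, requires_grad=True)
student_weight = torch.randn(V, H, requires_grad=True)
teacher_input = torch.randn(B_T, H)
teacher_weight = torch.randn(V, H)
true_labels = torch.randint(0, V, (B_T,))

loss = loss_fn(student_input, student_weight, teacher_input, teacher_weight, true_labels)
loss.backward()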
liger_kernel/chunked_loss/dpo_loss.py
@@ -13,6 +13,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         ref_chosen_logps=None,
         ref_rejected_logps=None,
         beta=0.1,
+        loss_type="sigmoid",
     ):
         """
         Paper: https://arxiv.org/pdf/2305.18290
@@ -48,8 +49,50 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         chosen_rewards = beta * chosen_logratios
         rejected_rewards = beta * rejected_logratios

-        logits_diff = beta * (chosen_logratios - rejected_logratios)
-        loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+        if loss_type == "sigmoid":
+            logits_diff = beta * (chosen_logratios - rejected_logratios)
+            loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "apo_zero":
+            # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
+            # Use this loss when you believe the chosen outputs are better than your model's default output
+            losses_chosen = 1 - F.sigmoid(beta * chosen_logratios)  # Increase chosen likelihood
+            losses_rejected = F.sigmoid(beta * rejected_logratios)
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "apo_down":
+            # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
+            # Use this loss when you believe the chosen outputs are worse than your model's default output.
+            # Decrease chosen likelihood and decrease rejected likelihood more
+            losses_chosen = F.sigmoid(beta * chosen_logratios)
+            losses_rejected = 1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios))
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "sppo_hard":
+            # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach,
+            # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
+            # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is
+            # set to 1 for the winner and 0 for the loser.
+            a = chosen_logps - ref_chosen_logps
+            b = rejected_logps - ref_rejected_logps
+            losses = (a - 0.5 / beta) ** 2 + (b + 0.5 / beta) ** 2
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "nca_pair":
+            losses = (
+                -F.logsigmoid(chosen_rewards)
+                - 0.5 * F.logsigmoid(-chosen_rewards)
+                - 0.5 * F.logsigmoid(-rejected_rewards)
+            )
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        else:
+            raise ValueError(
+                f"Unsupported loss_type: {loss_type}. Supported types are: sigmoid, apo_zero, apo_down, sppo_hard, nca_pair"
+            )
+
         return loss, chosen_rewards, rejected_rewards

     @classmethod
@@ -70,6 +113,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         use_ref_model=True,
         average_log_prob=False,
         chunk_size=1,
+        loss_type="sigmoid",
     ):
         """
         Fused linear layer with DPO loss.
@@ -108,12 +152,13 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             ref_bias=ref_bias,
             average_log_prob=average_log_prob,
             chunk_size=chunk_size,
+            loss_type=loss_type,
         )

     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None, None, None, None


 class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -130,6 +175,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         use_ref_model: bool = True,
         average_log_prob: bool = False,
         chunk_size: int = 1,
+        loss_type: str = "sigmoid",
     ):
         """
         Args:
@@ -149,6 +195,10 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         self.use_ref_model = use_ref_model
         self.average_log_prob = average_log_prob
         self.chunk_size = chunk_size
+        self.loss_type = loss_type
+        supported_loss_types = {"sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"}
+        if self.loss_type not in supported_loss_types:
+            raise ValueError(f"Unsupported loss_type: {self.loss_type}. Supported types are: {supported_loss_types}")

     def forward(
         self,
@@ -175,4 +225,5 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             self.use_ref_model,
             self.average_log_prob,
             self.chunk_size,
+            self.loss_type,
         )
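Note: the constructor-level validation above means an unsupported variant fails fast. A minimal selection sketch (argument values are assumptions for illustration):

# Sketch: choosing one of the new DPO loss variants; invalid names raise at construction.
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

dpo_default = LigerFusedLinearDPOLoss(beta=0.1, loss_type="sigmoid")  # original DPO objective
dpo_apo = LigerFusedLinearDPOLoss(beta=0.1, loss_type="apo_zero")     # APO, Eqn (7)

try:
    LigerFusedLinearDPOLoss(loss_type="hinge")  # hypothetical unsupported name
except ValueError as err:
    print(err)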
liger_kernel/chunked_loss/functional.py
@@ -1,3 +1,4 @@
+from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
 from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
@@ -9,6 +10,7 @@ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
 liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
 liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
 liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
+liger_fused_linear_cosine = LigerFusedLinearCosineSimilarityFunction.apply
 liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
 liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
 liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
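Note: the new functional alias mirrors its siblings. A minimal sketch of calling it directly, passing positional arguments in the order of the Function's forward shown earlier (shapes assumed):

# Sketch: invoking the raw autograd.Function entry point.
import torch
from liger_kernel.chunked_loss.functional import liger_fused_linear_cosine

B_T, H, V = 4, 32, 64  # assumed sizes
loss = liger_fused_linear_cosine(
    torch.randn(B_T, H, requires_grad=True),  # student_input
    torch.randn(V, H, requires_grad=True),    # student_weight
    torch.randn(B_T, H),                      # teacher_input
    torch.randn(V, H),                        # teacher_weight
    torch.randint(0, V, (B_T,)),              # true_labels
    None,                                     # student_bias
    None,                                     # teacher_bias
)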
liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -1,5 +1,7 @@
 from abc import abstractmethod
 from functools import partial
+from typing import Tuple
+from typing import Union

 import torch

@@ -157,8 +159,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         compute_ce_loss=True,
         temperature=1.0,
         compiled=True,
+        return_soft_hard_loss=False,
         **loss_kwargs,
-    ):
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Base class for fused linear layer with distillation loss.
         Only need to compute gradients for student model.
@@ -180,6 +183,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             compute_ce_loss (bool): Whether to compute CE loss.
             temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             compiled (bool): Whether to use torch compile for chunk accumulation.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         CHUNK_SIZE = chunk_size
@@ -187,6 +191,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         grad_inputs = []
         grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
         loss_acc = torch.zeros((), device=student_input.device)
+        soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
+        hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None

         loss_func_to_call = partial(
             LigerFusedLinearDistillationBase._compute_loss,
@@ -247,6 +253,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             )
             grad_weight.add_(chunk_grad_weight)
             loss_acc.add_(chunk_loss)
+            if return_soft_hard_loss:
+                soft_loss_acc.add_(chunk_soft_loss)
+                hard_loss_acc.add_(chunk_hard_loss)
             return chunk_grad_input

         if compiled:
@@ -268,10 +277,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             grad_weight,
             grad_bias,
         )
+        if return_soft_hard_loss:
+            return loss_acc, soft_loss_acc, hard_loss_acc
         return loss_acc

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, *args):
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
         if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
             grad_input = grad_input * grad_output
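Note: the backward signature changes because autograd passes one grad_output per forward output, so a forward that now optionally returns three tensors delivers extra gradients that backward must absorb. A self-contained toy illustrating the pattern (not liger code):

# Toy autograd.Function: extra detached outputs produce extra grad_outputs,
# which backward can absorb with *args, the same contract adopted above.
import torch

class MultiOut(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        loss = (x * x).sum()
        return loss, loss.detach(), loss.detach()  # stand-ins for (loss, soft, hard)

    @staticmethod
    def backward(ctx, grad_output, *args):  # *args holds grads for the extra outputs
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output

x = torch.randn(3, requires_grad=True)
loss, soft, hard = MultiOut.apply(x)
loss.backward()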
liger_kernel/chunked_loss/fused_linear_ppo.py
@@ -32,8 +32,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
-        loss_type="
+        loss_type="dapo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         compiled=True,
         use_ref_model=False,
@@ -59,7 +60,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             epsilon_low: Lower bound for clipping the importance sampling ratio
             epsilon_high: Upper bound for clipping the importance sampling ratio
             beta: Weight for the KL penalty
-            loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo")
+            loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo")
             max_completion_length: Maximum completion length required for "dr_grpo"
             temperature: Temperature for the logits
             compiled: Whether to use torch compile
@@ -92,6 +93,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             beta=beta,
             loss_type=loss_type,
             max_completion_length=max_completion_length,
+            importance_sampling_level=importance_sampling_level,
             temperature=temperature,
             use_ref_model=use_ref_model,
             ppo_loss_fn=cls.ppo_loss_fn,
@@ -242,6 +244,21 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):

         return loss_acc, tuple(final_metrics)

+    @staticmethod
+    def _compute_dapo_normalizer(attention_mask):
+        """Global active tokens averaged per process."""
+        normalizer = attention_mask.to(torch.float32).sum()
+        world_size = 1
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            import torch.distributed as dist
+
+            normalizer = normalizer.clone()
+            dist.all_reduce(normalizer, op=dist.ReduceOp.SUM)
+            world_size = dist.get_world_size()
+
+        normalizer = normalizer / world_size
+        return torch.clamp(normalizer, min=1.0)
+
     @staticmethod
     def _compute_chunk_loss(
         input_chunk,
@@ -259,8 +276,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
-        loss_type="
+        loss_type="dapo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         use_ref_model=False,
         ppo_loss_fn=None,
@@ -292,6 +310,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             beta=beta,
             loss_type=loss_type,
             max_completion_length=max_completion_length,
+            importance_sampling_level=importance_sampling_level,
         )

         return chunk_loss, chunk_metrics
@@ -337,10 +356,11 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             None,  # grad_epsilon_low
             None,  # grad_epsilon_high
             None,  # grad_beta
+            None,  # grad_loss_type
+            None,  # grad_max_completion_length
+            None,  # grad_importance_sampling_level
             None,  # grad_temperature
             None,  # grad_compiled
             None,  # grad_use_ref_model
             None,  # grad_chunk_size
-            None,  # grad_loss_type
-            None,  # grad_max_completion_length
         )
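Note: in the single-process case the DAPO normalizer above reduces to a clamped count of active tokens. A tiny numeric sketch (mask values assumed):

# Single-process arithmetic of _compute_dapo_normalizer: clamp(sum(mask), min=1).
import torch

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # assumed (B, T) completion mask
normalizer = torch.clamp(attention_mask.to(torch.float32).sum(), min=1.0)
print(normalizer)  # tensor(5.)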
liger_kernel/chunked_loss/grpo_loss.py
@@ -29,8 +29,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
-        loss_type="
+        loss_type="dapo",  # ["grpo", "bnpo", "dr_grpo", "dapo"]
         max_completion_length=None,  # Required for dr_grpo
+        importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
         **kwargs,
     ):
         """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -50,7 +51,22 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):

         # Compute policy gradient loss with importance sampling ratio
         old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
-        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+        log_ratio = per_token_logps - old_per_token_logps
+
+        if importance_sampling_level == "token":
+            log_importance_weights = log_ratio
+        elif importance_sampling_level == "sequence":
+            log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
+            log_importance_weights = log_importance_weights.unsqueeze(-1)
+        else:
+            raise ValueError(
+                f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
+                "and 'sequence'."
+            )
+
+        # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
+        # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+        coef_1 = torch.exp(log_importance_weights)
         coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
         per_token_loss1 = coef_1 * advantages.unsqueeze(1)
         per_token_loss2 = coef_2 * advantages.unsqueeze(1)
@@ -78,6 +94,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             if max_completion_length is None:
                 raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
             loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+        elif loss_type == "dapo":
+            loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
+            loss = (per_token_loss * attention_mask).sum() / loss_normalizer
         else:
             raise ValueError(f"Unknown loss type: {loss_type}")

@@ -85,9 +104,19 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         metrics = []
         if beta != 0.0:
             metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
-        is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
-            (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
-        )
+
+        # Adjust clipping metric calculation based on importance sampling level
+        if importance_sampling_level == "token":
+            is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+                (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
+            )
+        else:  # sequence level
+            # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
+            is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
+                (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
+            )
+            is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
+
         metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
         return loss, metrics

@@ -109,8 +138,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         beta=0.04,
         epsilon_low=0.2,
         epsilon_high=0.2,
-        loss_type="
+        loss_type="dapo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         compiled=True,
         use_ref_model=True,
@@ -130,8 +160,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
             ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
             beta (float): Weight for the KL penalty
-            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
@@ -162,6 +193,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             compiled=compiled,
             use_ref_model=use_ref_model,
             chunk_size=chunk_size,
+            importance_sampling_level=importance_sampling_level,
         )

     @staticmethod
@@ -187,6 +219,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             None,  # grad_epsilon_high
             None,  # grad_loss_type (string, not differentiable)
             None,  # grad_max_completion_length (int, not differentiable)
+            None,  # grad_importance_sampling_level (string, not differentiable)
             None,  # grad_temperature
             None,  # grad_compiled
             None,  # grad_use_ref_model
@@ -205,8 +238,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         chunk_size: int = 1,
         epsilon_low: float = 0.2,
         epsilon_high: float = 0.2,
-        loss_type: str = "
+        loss_type: str = "dapo",
         max_completion_length: Optional[int] = None,
+        importance_sampling_level: str = "token",
         temperature: float = 1.0,
     ):
         """
@@ -217,8 +251,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             chunk_size (int): Size of chunks for processing.
             epsilon_low (float): Lower bound for the importance sampling ratio.
             epsilon_high (float): Upper bound for the importance sampling ratio.
-            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits.
         """
         super().__init__()
@@ -230,6 +265,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         self.epsilon_high = epsilon_high
         self.loss_type = loss_type
         self.max_completion_length = max_completion_length
+        self.importance_sampling_level = importance_sampling_level
         self.temperature = temperature

     def forward(
@@ -263,6 +299,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             self.epsilon_high,
             self.loss_type,
             self.max_completion_length,
+            self.importance_sampling_level,
             self.temperature,
             self.compiled,
             self.use_ref_model,
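Note: a standalone sketch of the two importance-sampling levels; the sequence-level path is the GSPO-style variant that averages the per-token log-ratio over valid tokens before exponentiating (toy tensors):

# Token-level vs sequence-level importance weights, mirroring the branch above.
import torch

log_ratio = torch.randn(2, 4)  # assumed (B, T) per-token log-prob ratios
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.float32)

token_weights = torch.exp(log_ratio)  # "token": shape (B, T)

seq_log_w = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
sequence_weights = torch.exp(seq_log_w).unsqueeze(-1)  # "sequence": shape (B, 1)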
liger_kernel/chunked_loss/jsd_loss.py
@@ -1,3 +1,8 @@
+import math
+
+from typing import Tuple
+from typing import Union
+
 import torch
 import torch.nn.functional as F

@@ -25,8 +30,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
         else:
             # Compute probabilities (only required for mean calculation)
-            mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
-            log_mean_probs = mean_probs.log()
+            log_mean_probs = torch.logsumexp(
+                torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
+            )

             student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
             teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
@@ -53,6 +59,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         """
         Fused linear layer with JSD distillation loss.
@@ -69,8 +76,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             temperature (float): Temperature for softening/sharpening distributions
             compiled (bool): Whether to use torch compile
             chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
         Returns:
-            torch.Tensor: Computed loss
+            torch.Tensor: Computed loss, or tuple (loss, soft_loss, hard_loss) if return_soft_hard_loss=True
         """
         return super().forward(
             cls=cls,
@@ -89,11 +97,12 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             ignore_index=ignore_index,
             temperature=temperature,
             compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
         )

     @staticmethod
-    def backward(ctx, grad_output):
-        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]

         return (
             *grads,
@@ -105,6 +114,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             None,  # temperature
             None,  # compiled
             None,  # chunk_size
+            None,  # return_soft_hard_loss
         )


@@ -122,6 +132,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         """
         Args:
@@ -132,6 +143,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             compiled (bool): Whether to use torch compile
             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
             chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
         """
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -142,6 +154,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         self.compiled = compiled
         self.beta = beta
         self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss

     def forward(
         self,
@@ -152,7 +165,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         true_labels: torch.LongTensor,
         student_bias: torch.Tensor = None,
         teacher_bias: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Compute the JSD distillation loss.

@@ -164,7 +177,9 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             true_labels (torch.LongTensor): Target labels tensor

         Returns:
-            torch.Tensor
+            torch.Tensor or Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                If return_soft_hard_loss is False: Computed combined loss
+                If return_soft_hard_loss is True: Tuple of (combined_loss, soft_loss, hard_loss)
         """
         return LigerFusedLinearJSDFunction.apply(
             student_input,
@@ -181,4 +196,5 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             self.temperature,
             self.compiled,
             self.chunk_size,
+            self.return_soft_hard_loss,
         )
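Note: the rewritten mixture term computes log((1 - beta) * p_student + beta * p_teacher) entirely in log space, avoiding the exp-then-log round trip. A quick numerical check of the equivalence:

# Verifies the logsumexp form matches the naive mixture on well-behaved inputs.
import math
import torch

beta = 0.5
student_log_probs = torch.log_softmax(torch.randn(3, 7), dim=-1)
teacher_log_probs = torch.log_softmax(torch.randn(3, 7), dim=-1)

stable = torch.logsumexp(
    torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
)
naive = ((1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()).log()
assert torch.allclose(stable, naive, atol=1e-6)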