liger-kernel-nightly 0.5.10.dev20250624183504__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Note: this version of liger-kernel-nightly has been flagged as potentially problematic.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/dpo_loss.py +54 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +4 -0
- liger_kernel/chunked_loss/grpo_loss.py +38 -4
- liger_kernel/chunked_loss/jsd_loss.py +23 -7
- liger_kernel/ops/cross_entropy.py +118 -62
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +113 -21
- liger_kernel/ops/geglu.py +1 -1
- liger_kernel/ops/layer_norm.py +124 -89
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/rms_norm.py +2 -2
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/transformers/__init__.py +50 -0
- liger_kernel/transformers/cross_entropy.py +8 -3
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/functional.py +38 -6
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -4
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +28 -8
- liger_kernel/transformers/model/gemma2.py +31 -8
- liger_kernel/transformers/model/gemma3.py +100 -110
- liger_kernel/transformers/model/glm4.py +18 -5
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +26 -7
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +18 -6
- liger_kernel/transformers/model/loss_utils.py +34 -3
- liger_kernel/transformers/model/mistral.py +17 -10
- liger_kernel/transformers/model/mixtral.py +24 -9
- liger_kernel/transformers/model/mllama.py +18 -7
- liger_kernel/transformers/model/olmo2.py +18 -5
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +41 -5
- liger_kernel/transformers/model/phi3.py +24 -159
- liger_kernel/transformers/model/qwen2.py +26 -4
- liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
- liger_kernel/transformers/model/qwen2_vl.py +24 -7
- liger_kernel/transformers/model/qwen3.py +22 -6
- liger_kernel/transformers/model/qwen3_moe.py +27 -7
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +1090 -116
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/rms_norm.py +7 -0
- liger_kernel/transformers/rope.py +43 -0
- liger_kernel/transformers/tiled_mlp.py +133 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +26 -24
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
- liger_kernel_nightly-0.5.10.dev20250624183504.dist-info/RECORD +0 -95
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250624183504.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/__init__.py
@@ -1,3 +1,4 @@
+from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss  # noqa: F401
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss  # noqa: F401
liger_kernel/chunked_loss/cosine_similarity_loss.py (new file)
@@ -0,0 +1,136 @@
+from typing import Tuple
+from typing import Union
+
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
+    @staticmethod
+    def distillation_loss_fn(student_logits, teacher_logits, beta=1.0):
+        """
+        Compute Cosine loss (Cosine Similarity Loss).
+        Args:
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+            beta: Coefficient beta of generalized Cosine Similarity in the interval [0, 1]. Default: `1.0` (float).
+        Returns:
+            torch.Tensor: cosine similarity loss
+        """
+        student_norm = F.normalize(student_logits, p=2, dim=-1)
+        teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)
+
+        cosine_sim = F.cosine_similarity(student_norm, teacher_norm, dim=-1)
+        loss = beta * (1 - cosine_sim)
+        return loss.sum()
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor,
+        teacher_bias: torch.Tensor,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            student_input=student_input,
+            student_weight=student_weight,
+            teacher_input=teacher_input,
+            teacher_weight=teacher_weight,
+            target=true_labels,
+            student_bias=student_bias,
+            teacher_bias=teacher_bias,
+            chunk_size=chunk_size,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            beta=beta,
+            ignore_index=ignore_index,
+            temperature=temperature,
+            compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
+
+        return (
+            *grads,
+            None,  # teacher_bias
+            None,  # weight_hard_loss
+            None,  # weight_soft_loss
+            None,  # beta
+            None,  # ignore_index
+            None,  # temperature
+            None,  # compiled
+            None,  # chunk_size
+            None,  # return_soft_hard_loss
+        )
+
+
+class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
+    def __init__(
+        self,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
+    ):
+        super().__init__()
+        assert temperature != 0, "Temperature cannot be 0."
+        self.weight_hard_loss = weight_hard_loss
+        self.weight_soft_loss = weight_soft_loss
+        self.ignore_index = ignore_index
+        self.temperature = temperature
+        self.compiled = compiled
+        self.beta = beta
+        self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss
+
+    def forward(
+        self,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor = None,
+        teacher_bias: torch.Tensor = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        return LigerFusedLinearCosineSimilarityFunction.apply(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            true_labels,
+            student_bias,
+            teacher_bias,
+            self.weight_hard_loss,
+            self.weight_soft_loss,
+            self.beta,
+            self.ignore_index,
+            self.temperature,
+            self.compiled,
+            self.chunk_size,
+            self.return_soft_hard_loss,
+        )
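For orientation, here is a minimal usage sketch of the new module. The tensor sizes and the (num_tokens, hidden) / (vocab, hidden) layout are assumptions carried over from the other chunked distillation losses; the diff itself only defines the class.

import torch

from liger_kernel.chunked_loss import LigerFusedLinearCosineSimilarityLoss

# Illustrative sizes only: the loss fuses the student and teacher lm_head
# projections, so the inputs are pre-projection hidden states.
num_tokens, hidden_dim, vocab_size = 8, 64, 128

loss_fn = LigerFusedLinearCosineSimilarityLoss(weight_hard_loss=0.5, weight_soft_loss=0.5, beta=0.5)

student_input = torch.randn(num_tokens, hidden_dim, requires_grad=True)
student_weight = torch.randn(vocab_size, hidden_dim, requires_grad=True)
teacher_input = torch.randn(num_tokens, hidden_dim)
teacher_weight = torch.randn(vocab_size, hidden_dim)
true_labels = torch.randint(0, vocab_size, (num_tokens,))

loss = loss_fn(student_input, student_weight, teacher_input, teacher_weight, true_labels)
loss.backward()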
liger_kernel/chunked_loss/dpo_loss.py
@@ -13,6 +13,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         ref_chosen_logps=None,
         ref_rejected_logps=None,
         beta=0.1,
+        loss_type="sigmoid",
     ):
         """
         Paper: https://arxiv.org/pdf/2305.18290
@@ -48,8 +49,50 @@
         chosen_rewards = beta * chosen_logratios
         rejected_rewards = beta * rejected_logratios

-        logits_diff = beta * (chosen_logratios - rejected_logratios)
-        loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+        if loss_type == "sigmoid":
+            logits_diff = beta * (chosen_logratios - rejected_logratios)
+            loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "apo_zero":
+            # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
+            # Use this loss when you believe the chosen outputs are better than your model's default output
+            losses_chosen = 1 - F.sigmoid(beta * chosen_logratios)  # Increase chosen likelihood
+            losses_rejected = F.sigmoid(beta * rejected_logratios)
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "apo_down":
+            # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
+            # Use this loss when you believe the chosen outputs are worse than your model's default output.
+            # Decrease chosen likelihood and decrease rejected likelihood more
+            losses_chosen = F.sigmoid(beta * chosen_logratios)
+            losses_rejected = 1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios))
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "sppo_hard":
+            # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach,
+            # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
+            # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is
+            # set to 1 for the winner and 0 for the loser.
+            a = chosen_logps - ref_chosen_logps
+            b = rejected_logps - ref_rejected_logps
+            losses = (a - 0.5 / beta) ** 2 + (b + 0.5 / beta) ** 2
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "nca_pair":
+            losses = (
+                -F.logsigmoid(chosen_rewards)
+                - 0.5 * F.logsigmoid(-chosen_rewards)
+                - 0.5 * F.logsigmoid(-rejected_rewards)
+            )
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        else:
+            raise ValueError(
+                f"Unsupported loss_type: {loss_type}. Supported types are: sigmoid, apo_zero, apo_down, sppo_hard, nca_pair"
+            )
+
         return loss, chosen_rewards, rejected_rewards

     @classmethod
@@ -70,6 +113,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         use_ref_model=True,
         average_log_prob=False,
         chunk_size=1,
+        loss_type="sigmoid",
     ):
         """
         Fused linear layer with DPO loss.
@@ -108,12 +152,13 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             ref_bias=ref_bias,
             average_log_prob=average_log_prob,
             chunk_size=chunk_size,
+            loss_type=loss_type,
         )

     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None, None, None, None


 class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -130,6 +175,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         use_ref_model: bool = True,
         average_log_prob: bool = False,
         chunk_size: int = 1,
+        loss_type: str = "sigmoid",
     ):
         """
         Args:
@@ -149,6 +195,10 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         self.use_ref_model = use_ref_model
         self.average_log_prob = average_log_prob
         self.chunk_size = chunk_size
+        self.loss_type = loss_type
+        supported_loss_types = {"sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"}
+        if self.loss_type not in supported_loss_types:
+            raise ValueError(f"Unsupported loss_type: {self.loss_type}. Supported types are: {supported_loss_types}")

     def forward(
         self,
@@ -175,4 +225,5 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             self.use_ref_model,
             self.average_log_prob,
             self.chunk_size,
+            self.loss_type,
         )
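A hedged sketch of selecting the new knob. It follows the constructor shown in this hunk; beta and the other arguments are assumed unchanged from the pre-existing signature.

from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

# Pick one of the newly supported preference objectives; "sigmoid" remains the default.
dpo_loss = LigerFusedLinearDPOLoss(beta=0.1, loss_type="apo_zero")

# Unsupported names now fail fast in __init__ instead of deep inside the fused kernel.
try:
    LigerFusedLinearDPOLoss(loss_type="not_a_loss")
except ValueError as err:
    print(err)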
liger_kernel/chunked_loss/functional.py
@@ -1,3 +1,4 @@
+from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
 from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
@@ -9,6 +10,7 @@ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
 liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
 liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
 liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
+liger_fused_linear_cosine = LigerFusedLinearCosineSimilarityFunction.apply
 liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
 liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
 liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
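For completeness, a sketch of the functional entry point registered above. The positional order mirrors LigerFusedLinearCosineSimilarityFunction.forward and the trailing values simply restate its defaults; it reuses the tensors from the earlier sketch, so treat it as illustrative rather than canonical.

from liger_kernel.chunked_loss.functional import liger_fused_linear_cosine

loss = liger_fused_linear_cosine(
    student_input,
    student_weight,
    teacher_input,
    teacher_weight,
    true_labels,
    None,   # student_bias
    None,   # teacher_bias
    0.5,    # weight_hard_loss
    0.5,    # weight_soft_loss
    0.5,    # beta
    -100,   # ignore_index
    1.0,    # temperature
    True,   # compiled
    1024,   # chunk_size
    False,  # return_soft_hard_loss
)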
liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -1,5 +1,7 @@
 from abc import abstractmethod
 from functools import partial
+from typing import Tuple
+from typing import Union

 import torch

@@ -157,8 +159,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         compute_ce_loss=True,
         temperature=1.0,
         compiled=True,
+        return_soft_hard_loss=False,
         **loss_kwargs,
-    ):
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Base class for fused linear layer with distillation loss.
         Only need to compute gradients for student model.
@@ -180,6 +183,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             compute_ce_loss (bool): Whether to compute CE loss.
             temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             compiled (bool): Whether to use torch compile for chunk accumulation.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         CHUNK_SIZE = chunk_size
@@ -187,6 +191,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         grad_inputs = []
         grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
         loss_acc = torch.zeros((), device=student_input.device)
+        soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
+        hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None

         loss_func_to_call = partial(
             LigerFusedLinearDistillationBase._compute_loss,
@@ -247,6 +253,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             )
             grad_weight.add_(chunk_grad_weight)
             loss_acc.add_(chunk_loss)
+            if return_soft_hard_loss:
+                soft_loss_acc.add_(chunk_soft_loss)
+                hard_loss_acc.add_(chunk_hard_loss)
             return chunk_grad_input

         if compiled:
@@ -268,10 +277,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             grad_weight,
             grad_bias,
         )
+        if return_soft_hard_loss:
+            return loss_acc, soft_loss_acc, hard_loss_acc
         return loss_acc

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, *args):
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
         if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
             grad_input = grad_input * grad_output
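To show what the new flag changes at the call site, a small hedged sketch using the cosine loss and the illustrative tensors from the sketch above; the three-way unpacking corresponds to the tuple returned by the base forward when return_soft_hard_loss=True.

loss_fn = LigerFusedLinearCosineSimilarityLoss(return_soft_hard_loss=True)

loss, soft_loss, hard_loss = loss_fn(
    student_input, student_weight, teacher_input, teacher_weight, true_labels
)
loss.backward()  # the combined loss still drives the fused backward pass
print(float(soft_loss), float(hard_loss))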
liger_kernel/chunked_loss/fused_linear_ppo.py
@@ -34,6 +34,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         beta=0.04,
         loss_type="bnpo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         compiled=True,
         use_ref_model=False,
@@ -92,6 +93,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             beta=beta,
             loss_type=loss_type,
             max_completion_length=max_completion_length,
+            importance_sampling_level=importance_sampling_level,
             temperature=temperature,
             use_ref_model=use_ref_model,
             ppo_loss_fn=cls.ppo_loss_fn,
@@ -261,6 +263,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         beta=0.04,
         loss_type="bnpo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         use_ref_model=False,
         ppo_loss_fn=None,
@@ -292,6 +295,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             beta=beta,
             loss_type=loss_type,
             max_completion_length=max_completion_length,
+            importance_sampling_level=importance_sampling_level,
         )

         return chunk_loss, chunk_metrics
liger_kernel/chunked_loss/grpo_loss.py
@@ -31,6 +31,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         beta=0.04,
         loss_type="bnpo",  # ["grpo", "bnpo", "dr_grpo"]
         max_completion_length=None,  # Required for dr_grpo
+        importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
         **kwargs,
     ):
         """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -50,7 +51,22 @@

         # Compute policy gradient loss with importance sampling ratio
         old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
-        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+        log_ratio = per_token_logps - old_per_token_logps
+
+        if importance_sampling_level == "token":
+            log_importance_weights = log_ratio
+        elif importance_sampling_level == "sequence":
+            log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
+            log_importance_weights = log_importance_weights.unsqueeze(-1)
+        else:
+            raise ValueError(
+                f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
+                "and 'sequence'."
+            )
+
+        # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
+        # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+        coef_1 = torch.exp(log_importance_weights)
         coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
         per_token_loss1 = coef_1 * advantages.unsqueeze(1)
         per_token_loss2 = coef_2 * advantages.unsqueeze(1)
@@ -85,9 +101,19 @@
         metrics = []
         if beta != 0.0:
             metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
-        is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
-            (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
-        )
+
+        # Adjust clipping metric calculation based on importance sampling level
+        if importance_sampling_level == "token":
+            is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+                (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
+            )
+        else:  # sequence level
+            # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
+            is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
+                (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
+            )
+            is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
+
         metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
         return loss, metrics

@@ -111,6 +137,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         epsilon_high=0.2,
         loss_type="bnpo",
         max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         compiled=True,
         use_ref_model=True,
@@ -132,6 +159,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             beta (float): Weight for the KL penalty
             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
@@ -162,6 +190,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             compiled=compiled,
             use_ref_model=use_ref_model,
             chunk_size=chunk_size,
+            importance_sampling_level=importance_sampling_level,
         )

     @staticmethod
@@ -187,6 +216,7 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             None,  # grad_epsilon_high
             None,  # grad_loss_type (string, not differentiable)
             None,  # grad_max_completion_length (int, not differentiable)
+            None,  # grad_importance_sampling_level (string, not differentiable)
             None,  # grad_temperature
             None,  # grad_compiled
             None,  # grad_use_ref_model
@@ -207,6 +237,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         epsilon_high: float = 0.2,
         loss_type: str = "bnpo",
         max_completion_length: Optional[int] = None,
+        importance_sampling_level: str = "token",
         temperature: float = 1.0,
     ):
         """
@@ -219,6 +250,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             epsilon_high (float): Upper bound for the importance sampling ratio.
             loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
             max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits.
         """
         super().__init__()
@@ -230,6 +262,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         self.epsilon_high = epsilon_high
         self.loss_type = loss_type
         self.max_completion_length = max_completion_length
+        self.importance_sampling_level = importance_sampling_level
         self.temperature = temperature

     def forward(
@@ -263,6 +296,7 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             self.epsilon_high,
             self.loss_type,
             self.max_completion_length,
+            self.importance_sampling_level,
             self.temperature,
             self.compiled,
             self.use_ref_model,
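In short, "token" keeps a per-token importance ratio of shape (B, T), while "sequence" averages the per-token log-ratios over the attention mask and broadcasts a single weight per sequence of shape (B, 1), the GSPO-style variant. A hedged construction sketch, with all other arguments assumed to keep their defaults:

from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

# GSPO-style sequence-level importance sampling.
gspo_style_loss = LigerFusedLinearGRPOLoss(
    loss_type="bnpo",
    importance_sampling_level="sequence",
)

# Unchanged behavior: defaults to per-token ratios.
token_level_loss = LigerFusedLinearGRPOLoss()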
liger_kernel/chunked_loss/jsd_loss.py
@@ -1,3 +1,8 @@
+import math
+
+from typing import Tuple
+from typing import Union
+
 import torch
 import torch.nn.functional as F

@@ -25,8 +30,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
         else:
             # Compute probabilities (only required for mean calculation)
-
-
+            log_mean_probs = torch.logsumexp(
+                torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
+            )

             student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
             teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
@@ -53,6 +59,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         """
         Fused linear layer with JSD distillation loss.
@@ -69,8 +76,9 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             temperature (float): Temperature for softening/sharpening distributions
             compiled (bool): Whether to use torch compile
             chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
         Returns:
-            torch.Tensor: Computed loss
+            torch.Tensor: Computed loss, or tuple (loss, soft_loss, hard_loss) if return_soft_hard_loss=True
         """
         return super().forward(
             cls=cls,
@@ -89,11 +97,12 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             ignore_index=ignore_index,
             temperature=temperature,
             compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
         )

     @staticmethod
-    def backward(ctx, grad_output):
-        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]

         return (
             *grads,
@@ -105,6 +114,7 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             None,  # temperature
             None,  # compiled
             None,  # chunk_size
+            None,  # return_soft_hard_loss
         )


@@ -122,6 +132,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         temperature: float = 1.0,
         compiled: bool = True,
         chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
     ):
         """
         Args:
@@ -132,6 +143,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             compiled (bool): Whether to use torch compile
             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
             chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
         """
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -142,6 +154,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         self.compiled = compiled
         self.beta = beta
         self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss

     def forward(
         self,
@@ -152,7 +165,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         true_labels: torch.LongTensor,
         student_bias: torch.Tensor = None,
         teacher_bias: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Compute the JSD distillation loss.

@@ -164,7 +177,9 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             true_labels (torch.LongTensor): Target labels tensor

         Returns:
-            torch.Tensor
+            torch.Tensor or Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                If return_soft_hard_loss is False: Computed combined loss
+                If return_soft_hard_loss is True: Tuple of (combined_loss, soft_loss, hard_loss)
         """
         return LigerFusedLinearJSDFunction.apply(
             student_input,
@@ -181,4 +196,5 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             self.temperature,
             self.compiled,
             self.chunk_size,
+            self.return_soft_hard_loss,
         )
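The logsumexp rewrite above evaluates the log of the generalized-JSD mixture, log((1 - beta) * p_student + beta * p_teacher), without leaving log space. A small, hedged numerical check of that identity (standalone, not part of the package):

import math

import torch

beta = 0.5
student_log_probs = torch.log_softmax(torch.randn(4, 16), dim=-1)
teacher_log_probs = torch.log_softmax(torch.randn(4, 16), dim=-1)

# Same expression as the diff: logsumexp over the stacked, beta-shifted log-probs.
log_mean_probs = torch.logsumexp(
    torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
)

# Reference computed directly in probability space.
reference = torch.log((1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp())
assert torch.allclose(log_mean_probs, reference, atol=1e-6)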