liger-kernel-nightly 0.5.6.dev20250403190551__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/dpo_loss.py +61 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +35 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/grpo_loss.py +76 -5
- liger_kernel/chunked_loss/jsd_loss.py +25 -9
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +124 -64
- liger_kernel/ops/dyt.py +115 -180
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +3 -2
- liger_kernel/ops/group_norm.py +2 -1
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +2 -1
- liger_kernel/ops/kl_div.py +13 -6
- liger_kernel/ops/layer_norm.py +146 -78
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +283 -56
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +205 -19
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +6 -4
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +122 -20
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +50 -25
- liger_kernel/transformers/model/gemma2.py +55 -23
- liger_kernel/transformers/model/gemma3.py +117 -120
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +102 -25
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +111 -136
- liger_kernel/transformers/model/loss_utils.py +50 -12
- liger_kernel/transformers/model/mistral.py +36 -23
- liger_kernel/transformers/model/mixtral.py +45 -25
- liger_kernel/transformers/model/mllama.py +39 -22
- liger_kernel/transformers/model/olmo2.py +40 -20
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +50 -14
- liger_kernel/transformers/model/phi3.py +47 -177
- liger_kernel/transformers/model/qwen2.py +48 -21
- liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
- liger_kernel/transformers/model/qwen2_vl.py +59 -108
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +1678 -160
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +48 -5
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +39 -1
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +36 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/METADATA +68 -38
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
- liger_kernel/transformers/gema3_rms.py +0 -8
- liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/RECORD +0 -82
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/__init__.py
@@ -1,3 +1,4 @@
+from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss  # noqa:F401
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss  # noqa: F401
liger_kernel/chunked_loss/cosine_similarity_loss.py
@@ -0,0 +1,136 @@
+from typing import Tuple
+from typing import Union
+
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
+    @staticmethod
+    def distillation_loss_fn(student_logits, teacher_logits, beta=1.0):
+        """
+        Compute Cosine loss (Cosine Similarity Loss).
+        Args:
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+            beta: Coefficient beta of generalized Cosine Similarity in the interval [0, 1]. Default: `1.0` (float).
+        Returns:
+            torch.Tensor: cosine similarity loss
+        """
+        student_norm = F.normalize(student_logits, p=2, dim=-1)
+        teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)
+
+        cosine_sim = F.cosine_similarity(student_norm, teacher_norm, dim=-1)
+        loss = beta * (1 - cosine_sim)
+        return loss.sum()
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor,
+        teacher_bias: torch.Tensor,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            student_input=student_input,
+            student_weight=student_weight,
+            teacher_input=teacher_input,
+            teacher_weight=teacher_weight,
+            target=true_labels,
+            student_bias=student_bias,
+            teacher_bias=teacher_bias,
+            chunk_size=chunk_size,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            beta=beta,
+            ignore_index=ignore_index,
+            temperature=temperature,
+            compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
+
+        return (
+            *grads,
+            None,  # teacher_bias
+            None,  # weight_hard_loss
+            None,  # weight_soft_loss
+            None,  # beta
+            None,  # ignore_index
+            None,  # temperature
+            None,  # compiled
+            None,  # chunk_size
+            None,  # return_soft_hard_loss
+        )
+
+
+class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
+    def __init__(
+        self,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
+    ):
+        super().__init__()
+        assert temperature != 0, "Temperature cannot be 0."
+        self.weight_hard_loss = weight_hard_loss
+        self.weight_soft_loss = weight_soft_loss
+        self.ignore_index = ignore_index
+        self.temperature = temperature
+        self.compiled = compiled
+        self.beta = beta
+        self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss
+
+    def forward(
+        self,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor = None,
+        teacher_bias: torch.Tensor = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        return LigerFusedLinearCosineSimilarityFunction.apply(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            true_labels,
+            student_bias,
+            teacher_bias,
+            self.weight_hard_loss,
+            self.weight_soft_loss,
+            self.beta,
+            self.ignore_index,
+            self.temperature,
+            self.compiled,
+            self.chunk_size,
+            self.return_soft_hard_loss,
+        )
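Example (not part of the diff): a minimal usage sketch of the new LigerFusedLinearCosineSimilarityLoss, assuming the constructor and forward signatures shown above; shapes are illustrative and follow the (num_tokens, hidden) / (vocab, hidden) convention of the other chunked losses.

import torch
from liger_kernel.chunked_loss import LigerFusedLinearCosineSimilarityLoss

num_tokens, hidden, vocab = 8, 64, 128
loss_fn = LigerFusedLinearCosineSimilarityLoss(weight_hard_loss=0.5, weight_soft_loss=0.5, beta=1.0)

student_input = torch.randn(num_tokens, hidden, requires_grad=True)
student_weight = torch.randn(vocab, hidden, requires_grad=True)
teacher_input = torch.randn(num_tokens, hidden)
teacher_weight = torch.randn(vocab, hidden)
labels = torch.randint(0, vocab, (num_tokens,))

# Fuses the student/teacher lm_head projections with the distillation loss.
loss = loss_fn(student_input, student_weight, teacher_input, teacher_weight, labels)
loss.backward()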
liger_kernel/chunked_loss/dpo_loss.py
@@ -13,6 +13,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         ref_chosen_logps=None,
         ref_rejected_logps=None,
         beta=0.1,
+        loss_type="sigmoid",
     ):
         """
         Paper: https://arxiv.org/pdf/2305.18290
@@ -48,8 +49,50 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         chosen_rewards = beta * chosen_logratios
         rejected_rewards = beta * rejected_logratios
 
-        logits_diff = beta * (chosen_logratios - rejected_logratios)
-        loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+        if loss_type == "sigmoid":
+            logits_diff = beta * (chosen_logratios - rejected_logratios)
+            loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "apo_zero":
+            # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
+            # Use this loss when you believe the chosen outputs are better than your model's default output
+            losses_chosen = 1 - F.sigmoid(beta * chosen_logratios)  # Increase chosen likelihood
+            losses_rejected = F.sigmoid(beta * rejected_logratios)
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "apo_down":
+            # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
+            # Use this loss when you believe the chosen outputs are worse than your model's default output.
+            # Decrease chosen likelihood and decrease rejected likelihood more
+            losses_chosen = F.sigmoid(beta * chosen_logratios)
+            losses_rejected = 1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios))
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "sppo_hard":
+            # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach,
+            # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
+            # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is
+            # set to 1 for the winner and 0 for the loser.
+            a = chosen_logps - ref_chosen_logps
+            b = rejected_logps - ref_rejected_logps
+            losses = (a - 0.5 / beta) ** 2 + (b + 0.5 / beta) ** 2
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "nca_pair":
+            losses = (
+                -F.logsigmoid(chosen_rewards)
+                - 0.5 * F.logsigmoid(-chosen_rewards)
+                - 0.5 * F.logsigmoid(-rejected_rewards)
+            )
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        else:
+            raise ValueError(
+                f"Unsupported loss_type: {loss_type}. Supported types are: sigmoid, apo_zero, apo_down, sppo_hard, nca_pair"
+            )
+
         return loss, chosen_rewards, rejected_rewards
 
     @classmethod
@@ -68,7 +111,9 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=False,
         compiled=True,
         use_ref_model=True,
+        average_log_prob=False,
         chunk_size=1,
+        loss_type="sigmoid",
     ):
         """
         Fused linear layer with DPO loss.
@@ -85,6 +130,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             compute_nll_loss (bool): Whether to compute the NLL loss
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
+            average_log_prob (bool): Whether to average the log probability per non-masked token
             chunk_size (int): Size of chunks for processing.
         Returns:
             torch.Tensor: Computed loss
@@ -104,13 +150,15 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             ref_input=ref_input,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
             chunk_size=chunk_size,
+            loss_type=loss_type,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -125,7 +173,9 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         compute_nll_loss: bool = False,
         compiled: bool = True,
         use_ref_model: bool = True,
+        average_log_prob: bool = False,
         chunk_size: int = 1,
+        loss_type: str = "sigmoid",
     ):
         """
         Args:
@@ -134,6 +184,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             compute_nll_loss (bool): Whether to compute the NLL loss.
             compiled (bool): Whether to use the torch compiled kernel.
             use_ref_model (bool): Whether to use a reference model for the DPO loss.
+            average_log_prob (bool): Whether to average the log probability per non-masked token.
             chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
@@ -142,7 +193,12 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.use_ref_model = use_ref_model
+        self.average_log_prob = average_log_prob
         self.chunk_size = chunk_size
+        self.loss_type = loss_type
+        supported_loss_types = {"sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"}
+        if self.loss_type not in supported_loss_types:
+            raise ValueError(f"Unsupported loss_type: {self.loss_type}. Supported types are: {supported_loss_types}")
 
     def forward(
         self,
@@ -167,5 +223,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             self.compiled,
             self.use_ref_model,
+            self.average_log_prob,
             self.chunk_size,
+            self.loss_type,
         )
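Example (not part of the diff): a configuration sketch for the new DPO options, assuming the constructor shown above; "sigmoid" preserves the previous behaviour, while "apo_zero", "apo_down", "sppo_hard" and "nca_pair" select the newly added preference losses.

from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

dpo_loss = LigerFusedLinearDPOLoss(
    beta=0.1,
    use_ref_model=True,
    average_log_prob=True,  # average log-probs over non-masked tokens instead of summing
    loss_type="apo_zero",   # validated against {"sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"}
)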
liger_kernel/chunked_loss/functional.py
@@ -1,3 +1,4 @@
+from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
 from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
@@ -9,6 +10,7 @@ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
 liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
 liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
 liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
+liger_fused_linear_cosine = LigerFusedLinearCosineSimilarityFunction.apply
 liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
 liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
 liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -1,5 +1,7 @@
 from abc import abstractmethod
 from functools import partial
+from typing import Tuple
+from typing import Union
 
 import torch
 
@@ -157,8 +159,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         compute_ce_loss=True,
         temperature=1.0,
         compiled=True,
+        return_soft_hard_loss=False,
         **loss_kwargs,
-    ):
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Base class for fused linear layer with distillation loss.
         Only need to compute gradients for student model.
@@ -180,6 +183,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             compute_ce_loss (bool): Whether to compute CE loss.
             temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             compiled (bool): Whether to use torch compile for chunk accumulation.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         CHUNK_SIZE = chunk_size
@@ -187,6 +191,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         grad_inputs = []
         grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
         loss_acc = torch.zeros((), device=student_input.device)
+        soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
+        hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
 
         loss_func_to_call = partial(
             LigerFusedLinearDistillationBase._compute_loss,
@@ -247,6 +253,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             )
             grad_weight.add_(chunk_grad_weight)
             loss_acc.add_(chunk_loss)
+            if return_soft_hard_loss:
+                soft_loss_acc.add_(chunk_soft_loss)
+                hard_loss_acc.add_(chunk_hard_loss)
             return chunk_grad_input
 
         if compiled:
@@ -268,10 +277,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             grad_weight,
             grad_bias,
         )
+        if return_soft_hard_loss:
+            return loss_acc, soft_loss_acc, hard_loss_acc
         return loss_acc
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, *args):
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
         if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
             grad_input = grad_input * grad_output
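Example (not part of the diff): a sketch of the new return_soft_hard_loss flag, assuming the cosine-similarity subclass above; when the flag is set, the fused forward returns the combined loss together with its soft (distillation) and hard (cross-entropy) components.

import torch
from liger_kernel.chunked_loss import LigerFusedLinearCosineSimilarityLoss

loss_fn = LigerFusedLinearCosineSimilarityLoss(return_soft_hard_loss=True)
student_input = torch.randn(8, 64, requires_grad=True)
student_weight = torch.randn(128, 64, requires_grad=True)
teacher_input = torch.randn(8, 64)
teacher_weight = torch.randn(128, 64)
labels = torch.randint(0, 128, (8,))

# Three scalars instead of one.
loss, soft_loss, hard_loss = loss_fn(
    student_input, student_weight, teacher_input, teacher_weight, labels
)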
liger_kernel/chunked_loss/fused_linear_ppo.py
@@ -32,6 +32,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
+        loss_type="dapo",
+        max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         compiled=True,
         use_ref_model=False,
@@ -57,6 +60,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             epsilon_low: Lower bound for clipping the importance sampling ratio
             epsilon_high: Upper bound for clipping the importance sampling ratio
             beta: Weight for the KL penalty
+            loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo")
+            max_completion_length: Maximum completion length required for "dr_grpo"
             temperature: Temperature for the logits
             compiled: Whether to use torch compile
             use_ref_model: Whether to use a reference model
@@ -68,6 +73,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             )
         if ref_per_token_logps is not None and ref_input is not None:
             raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
+        if loss_type == "dr_grpo":
+            assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'"
         # Initialize accumulators
         loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
         grad_weight = torch.zeros_like(weight)  # [V, H]
@@ -84,6 +91,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             epsilon_low=epsilon_low,
             epsilon_high=epsilon_high,
             beta=beta,
+            loss_type=loss_type,
+            max_completion_length=max_completion_length,
+            importance_sampling_level=importance_sampling_level,
             temperature=temperature,
             use_ref_model=use_ref_model,
             ppo_loss_fn=cls.ppo_loss_fn,
@@ -234,6 +244,21 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 
         return loss_acc, tuple(final_metrics)
 
+    @staticmethod
+    def _compute_dapo_normalizer(attention_mask):
+        """Global active tokens averaged per process."""
+        normalizer = attention_mask.to(torch.float32).sum()
+        world_size = 1
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            import torch.distributed as dist
+
+            normalizer = normalizer.clone()
+            dist.all_reduce(normalizer, op=dist.ReduceOp.SUM)
+            world_size = dist.get_world_size()
+
+        normalizer = normalizer / world_size
+        return torch.clamp(normalizer, min=1.0)
+
     @staticmethod
     def _compute_chunk_loss(
         input_chunk,
@@ -251,6 +276,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
+        loss_type="dapo",
+        max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         use_ref_model=False,
         ppo_loss_fn=None,
@@ -280,6 +308,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             epsilon_low=epsilon_low,
             epsilon_high=epsilon_high,
             beta=beta,
+            loss_type=loss_type,
+            max_completion_length=max_completion_length,
+            importance_sampling_level=importance_sampling_level,
         )
 
         return chunk_loss, chunk_metrics
@@ -303,6 +334,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
     def backward(ctx, grad_output, *grad_metrics):
         """Backward pass for PPO loss."""
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
+
         if grad_output != 1.0:
             grad_input = grad_input * grad_output
             grad_weight = grad_weight * grad_output
@@ -324,6 +356,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             None,  # grad_epsilon_low
             None,  # grad_epsilon_high
             None,  # grad_beta
+            None,  # grad_loss_type
+            None,  # grad_max_completion_length
+            None,  # grad_importance_sampling_level
             None,  # grad_temperature
             None,  # grad_compiled
             None,  # grad_use_ref_model
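For reference (not part of the diff): in the single-process case the DAPO normalizer added above reduces to the clamped count of active tokens; a standalone sketch of the same arithmetic.

import torch

def dapo_normalizer_single_process(attention_mask: torch.Tensor) -> torch.Tensor:
    # Without torch.distributed initialized, the normalizer is simply the number
    # of non-masked tokens in the batch, clamped to at least 1.
    return torch.clamp(attention_mask.to(torch.float32).sum(), min=1.0)

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(dapo_normalizer_single_process(mask))  # tensor(5.)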
liger_kernel/chunked_loss/fused_linear_preference.py
@@ -222,7 +222,6 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
             (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
             (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
-            strict=False,
         ):
             input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
             ref_input_chunk = (
liger_kernel/chunked_loss/grpo_loss.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 
 from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
@@ -27,6 +29,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
+        loss_type="dapo",  # ["grpo", "bnpo", "dr_grpo", "dapo"]
+        max_completion_length=None,  # Required for dr_grpo
+        importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
         **kwargs,
     ):
         """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -46,7 +51,22 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 
         # Compute policy gradient loss with importance sampling ratio
         old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
-        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+        log_ratio = per_token_logps - old_per_token_logps
+
+        if importance_sampling_level == "token":
+            log_importance_weights = log_ratio
+        elif importance_sampling_level == "sequence":
+            log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
+            log_importance_weights = log_importance_weights.unsqueeze(-1)
+        else:
+            raise ValueError(
+                f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
+                "and 'sequence'."
+            )
+
+        # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
+        # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+        coef_1 = torch.exp(log_importance_weights)
         coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
         per_token_loss1 = coef_1 * advantages.unsqueeze(1)
         per_token_loss2 = coef_2 * advantages.unsqueeze(1)
@@ -61,15 +81,42 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
         # and TRL GRPO implementation
         # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
-        loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+        if loss_type == "grpo":
+            # Average per-sequence loss
+            loss = (
+                (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
+            ).sum() / full_attention_mask.shape[0]
+        elif loss_type == "bnpo":
+            # Batch Normalized Per-token loss (original implementation)
+            loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+        elif loss_type == "dr_grpo":
+            # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
+            if max_completion_length is None:
+                raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
+            loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+        elif loss_type == "dapo":
+            loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
+            loss = (per_token_loss * attention_mask).sum() / loss_normalizer
+        else:
+            raise ValueError(f"Unknown loss type: {loss_type}")
 
         # Calculate metrics
         metrics = []
         if beta != 0.0:
             metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
-        is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
-            (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
-        )
+
+        # Adjust clipping metric calculation based on importance sampling level
+        if importance_sampling_level == "token":
+            is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+                (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
+            )
+        else:  # sequence level
+            # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
+            is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
+                (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
+            )
+            is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
+
         metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
         return loss, metrics
 
@@ -91,6 +138,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         beta=0.04,
         epsilon_low=0.2,
         epsilon_high=0.2,
+        loss_type="dapo",
+        max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         compiled=True,
         use_ref_model=True,
@@ -110,6 +160,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
             ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
             beta (float): Weight for the KL penalty
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
+            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
@@ -134,10 +187,13 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             beta=beta,
             epsilon_low=epsilon_low,
             epsilon_high=epsilon_high,
+            loss_type=loss_type,
+            max_completion_length=max_completion_length,
             temperature=temperature,
             compiled=compiled,
             use_ref_model=use_ref_model,
             chunk_size=chunk_size,
+            importance_sampling_level=importance_sampling_level,
         )
 
     @staticmethod
@@ -161,6 +217,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             None,  # grad_beta
             None,  # grad_epsilon_low
             None,  # grad_epsilon_high
+            None,  # grad_loss_type (string, not differentiable)
+            None,  # grad_max_completion_length (int, not differentiable)
+            None,  # grad_importance_sampling_level (string, not differentiable)
             None,  # grad_temperature
             None,  # grad_compiled
             None,  # grad_use_ref_model
@@ -179,6 +238,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         chunk_size: int = 1,
         epsilon_low: float = 0.2,
         epsilon_high: float = 0.2,
+        loss_type: str = "dapo",
+        max_completion_length: Optional[int] = None,
+        importance_sampling_level: str = "token",
         temperature: float = 1.0,
     ):
         """
@@ -189,6 +251,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             chunk_size (int): Size of chunks for processing.
             epsilon_low (float): Lower bound for the importance sampling ratio.
             epsilon_high (float): Upper bound for the importance sampling ratio.
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
+            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits.
         """
         super().__init__()
@@ -198,6 +263,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         self.chunk_size = chunk_size
         self.epsilon_low = epsilon_low
         self.epsilon_high = epsilon_high
+        self.loss_type = loss_type
+        self.max_completion_length = max_completion_length
+        self.importance_sampling_level = importance_sampling_level
         self.temperature = temperature
 
     def forward(
@@ -229,6 +297,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             self.beta,
             self.epsilon_low,
             self.epsilon_high,
+            self.loss_type,
+            self.max_completion_length,
+            self.importance_sampling_level,
             self.temperature,
             self.compiled,
             self.use_ref_model,
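Example (not part of the diff): a configuration sketch for the new GRPO options, assuming the LigerFusedLinearGRPOLoss constructor shown above; "dr_grpo" additionally requires max_completion_length, and importance_sampling_level="sequence" selects the GSPO-style sequence-level importance ratio.

from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

grpo_loss = LigerFusedLinearGRPOLoss(
    loss_type="dr_grpo",                   # one of "grpo", "bnpo", "dr_grpo", "dapo" (default "dapo")
    max_completion_length=512,             # required when loss_type="dr_grpo"
    importance_sampling_level="sequence",  # "token" (default) or "sequence"
)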