liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly might be problematic.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
- liger_kernel/chunked_loss/dpo_loss.py +61 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/grpo_loss.py +76 -5
- liger_kernel/chunked_loss/jsd_loss.py +46 -15
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +134 -65
- liger_kernel/ops/dyt.py +115 -180
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +6 -4
- liger_kernel/ops/group_norm.py +7 -7
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +2 -1
- liger_kernel/ops/kl_div.py +9 -5
- liger_kernel/ops/layer_norm.py +146 -78
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +398 -99
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +208 -17
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +6 -4
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +122 -20
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +57 -27
- liger_kernel/transformers/model/gemma2.py +65 -28
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +109 -27
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +111 -136
- liger_kernel/transformers/model/loss_utils.py +50 -12
- liger_kernel/transformers/model/mistral.py +51 -34
- liger_kernel/transformers/model/mixtral.py +50 -29
- liger_kernel/transformers/model/mllama.py +46 -24
- liger_kernel/transformers/model/olmo2.py +47 -22
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +50 -14
- liger_kernel/transformers/model/phi3.py +47 -172
- liger_kernel/transformers/model/qwen2.py +55 -23
- liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
- liger_kernel/transformers/model/qwen2_vl.py +59 -108
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2018 -244
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +54 -6
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +39 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +63 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
- liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
- liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/__init__.py

@@ -1,3 +1,4 @@
+from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss  # noqa: F401
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss  # noqa: F401

liger_kernel/chunked_loss/cosine_similarity_loss.py

@@ -0,0 +1,142 @@
+from typing import Tuple
+from typing import Union
+
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
+    @staticmethod
+    def distillation_loss_fn(
+        student_logits,
+        teacher_logits,
+        target=None,
+        ignore_index=None,
+        beta=1.0,
+    ):
+        """
+        Compute Cosine loss (Cosine Similarity Loss).
+        Args:
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+            beta (float): Coefficient beta of generalized Cosine Similarity in the interval [0, 1]. Default: `1.0`.
+        Returns:
+            torch.Tensor: cosine similarity loss
+        """
+        student_norm = F.normalize(student_logits, p=2, dim=-1)
+        teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)
+
+        cosine_sim = F.cosine_similarity(student_norm, teacher_norm, dim=-1)
+        loss = beta * (1 - cosine_sim)
+        return loss.sum()
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor,
+        teacher_bias: torch.Tensor,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            student_input=student_input,
+            student_weight=student_weight,
+            teacher_input=teacher_input,
+            teacher_weight=teacher_weight,
+            target=true_labels,
+            student_bias=student_bias,
+            teacher_bias=teacher_bias,
+            chunk_size=chunk_size,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            beta=beta,
+            ignore_index=ignore_index,
+            temperature=temperature,
+            compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
+
+        return (
+            *grads,
+            None,  # teacher_bias
+            None,  # weight_hard_loss
+            None,  # weight_soft_loss
+            None,  # beta
+            None,  # ignore_index
+            None,  # temperature
+            None,  # compiled
+            None,  # chunk_size
+            None,  # return_soft_hard_loss
+        )
+
+
+class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
+    def __init__(
+        self,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
+    ):
+        super().__init__()
+        assert temperature != 0, "Temperature cannot be 0."
+        self.weight_hard_loss = weight_hard_loss
+        self.weight_soft_loss = weight_soft_loss
+        self.ignore_index = ignore_index
+        self.temperature = temperature
+        self.compiled = compiled
+        self.beta = beta
+        self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss
+
+    def forward(
+        self,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor = None,
+        teacher_bias: torch.Tensor = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        return LigerFusedLinearCosineSimilarityFunction.apply(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            true_labels,
+            student_bias,
+            teacher_bias,
+            self.weight_hard_loss,
+            self.weight_soft_loss,
+            self.beta,
+            self.ignore_index,
+            self.temperature,
+            self.compiled,
+            self.chunk_size,
+            self.return_soft_hard_loss,
+        )

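The new cosine-similarity distillation loss is also exported from `liger_kernel.chunked_loss` (see the `__init__.py` hunk above). Below is a minimal usage sketch, assuming the usual chunked-loss convention that inputs are flattened hidden states of shape `(batch * seq_len, hidden)` and the weights are `(vocab, hidden)` projection matrices; all sizes are made up for illustration:

```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearCosineSimilarityLoss

# Hypothetical sizes; real hidden/vocab sizes come from your student/teacher models.
bt, student_hidden, teacher_hidden, vocab = 8, 64, 128, 1000

loss_fn = LigerFusedLinearCosineSimilarityLoss(
    weight_hard_loss=0.5,  # cross-entropy against true_labels
    weight_soft_loss=0.5,  # cosine term between student/teacher logits
    beta=0.5,
    temperature=1.0,
)

student_input = torch.randn(bt, student_hidden, requires_grad=True)
student_weight = torch.randn(vocab, student_hidden, requires_grad=True)
teacher_input = torch.randn(bt, teacher_hidden)
teacher_weight = torch.randn(vocab, teacher_hidden)
true_labels = torch.randint(0, vocab, (bt,))

loss = loss_fn(student_input, student_weight, teacher_input, teacher_weight, true_labels)
loss.backward()
```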
liger_kernel/chunked_loss/dpo_loss.py

@@ -13,6 +13,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         ref_chosen_logps=None,
         ref_rejected_logps=None,
         beta=0.1,
+        loss_type="sigmoid",
     ):
         """
         Paper: https://arxiv.org/pdf/2305.18290
@@ -48,8 +49,50 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         chosen_rewards = beta * chosen_logratios
         rejected_rewards = beta * rejected_logratios
 
-        logits_diff = beta * (chosen_logratios - rejected_logratios)
-        loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+        if loss_type == "sigmoid":
+            logits_diff = beta * (chosen_logratios - rejected_logratios)
+            loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "apo_zero":
+            # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
+            # Use this loss when you believe the chosen outputs are better than your model's default output
+            losses_chosen = 1 - F.sigmoid(beta * chosen_logratios)  # Increase chosen likelihood
+            losses_rejected = F.sigmoid(beta * rejected_logratios)
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "apo_down":
+            # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
+            # Use this loss when you believe the chosen outputs are worse than your model's default output.
+            # Decrease chosen likelihood and decrease rejected likelihood more
+            losses_chosen = F.sigmoid(beta * chosen_logratios)
+            losses_rejected = 1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios))
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "sppo_hard":
+            # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach,
+            # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
+            # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is
+            # set to 1 for the winner and 0 for the loser.
+            a = chosen_logps - ref_chosen_logps
+            b = rejected_logps - ref_rejected_logps
+            losses = (a - 0.5 / beta) ** 2 + (b + 0.5 / beta) ** 2
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "nca_pair":
+            losses = (
+                -F.logsigmoid(chosen_rewards)
+                - 0.5 * F.logsigmoid(-chosen_rewards)
+                - 0.5 * F.logsigmoid(-rejected_rewards)
+            )
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        else:
+            raise ValueError(
+                f"Unsupported loss_type: {loss_type}. Supported types are: sigmoid, apo_zero, apo_down, sppo_hard, nca_pair"
+            )
+
         return loss, chosen_rewards, rejected_rewards
 
     @classmethod
@@ -68,7 +111,9 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=False,
         compiled=True,
         use_ref_model=True,
+        average_log_prob=False,
         chunk_size=1,
+        loss_type="sigmoid",
     ):
         """
         Fused linear layer with DPO loss.
@@ -85,6 +130,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             compute_nll_loss (bool): Whether to compute the NLL loss
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
+            average_log_prob (bool): Whether to average the log probability per non-masked token
             chunk_size (int): Size of chunks for processing.
         Returns:
             torch.Tensor: Computed loss
@@ -104,13 +150,15 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             ref_input=ref_input,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
            chunk_size=chunk_size,
+            loss_type=loss_type,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -125,7 +173,9 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         compute_nll_loss: bool = False,
         compiled: bool = True,
         use_ref_model: bool = True,
+        average_log_prob: bool = False,
         chunk_size: int = 1,
+        loss_type: str = "sigmoid",
     ):
         """
         Args:
@@ -134,6 +184,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             compute_nll_loss (bool): Whether to compute the NLL loss.
             compiled (bool): Whether to use the torch compiled kernel.
             use_ref_model (bool): Whether to use a reference model for the DPO loss.
+            average_log_prob (bool): Whether to average the log probability per non-masked token.
             chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
@@ -142,7 +193,12 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.use_ref_model = use_ref_model
+        self.average_log_prob = average_log_prob
         self.chunk_size = chunk_size
+        self.loss_type = loss_type
+        supported_loss_types = {"sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"}
+        if self.loss_type not in supported_loss_types:
+            raise ValueError(f"Unsupported loss_type: {self.loss_type}. Supported types are: {supported_loss_types}")
 
     def forward(
         self,
@@ -167,5 +223,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             self.compiled,
             self.use_ref_model,
+            self.average_log_prob,
             self.chunk_size,
+            self.loss_type,
         )

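The DPO module now exposes a `loss_type` selector and validates it at construction time. A small sketch of the new surface; the toy tensors below only mirror the per-pair arithmetic of the `sigmoid` and `apo_zero` branches added above (values and shapes are made up) and do not go through the fused kernel path:

```python
import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

dpo_loss = LigerFusedLinearDPOLoss(loss_type="apo_zero", average_log_prob=True)

try:
    LigerFusedLinearDPOLoss(loss_type="hinge")  # not in the supported set
except ValueError as err:
    print(err)

# Made-up log-ratios for two preference pairs.
beta = 0.1
chosen_logratios = torch.tensor([0.8, 0.3])
rejected_logratios = torch.tensor([-0.2, 0.1])

# Mean over pairs mirrors the kernel's sum() / (batch // 2) normalization.
sigmoid_loss = -F.logsigmoid(beta * (chosen_logratios - rejected_logratios)).mean()
apo_zero_loss = ((1 - F.sigmoid(beta * chosen_logratios)) + F.sigmoid(beta * rejected_logratios)).mean()
```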
liger_kernel/chunked_loss/functional.py

@@ -1,3 +1,4 @@
+from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
 from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
@@ -9,6 +10,7 @@ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
 liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
 liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
 liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
+liger_fused_linear_cosine = LigerFusedLinearCosineSimilarityFunction.apply
 liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
 liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
 liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply

liger_kernel/chunked_loss/fused_linear_distillation.py

@@ -1,5 +1,7 @@
 from abc import abstractmethod
 from functools import partial
+from typing import Tuple
+from typing import Union
 
 import torch
 
@@ -11,6 +13,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
     def distillation_loss_fn(
         student_logits,
         teacher_logits,
+        target=None,
+        ignore_index=None,
     ):
         """
         Compute distillation loss.
@@ -130,10 +134,15 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             )
             student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)
 
-        hard_loss /= full_target.shape[0]
+        num_valid_tokens = (full_target != ignore_index).sum()
+        num_valid_tokens = num_valid_tokens.clamp_min(1)  # to avoid division by zero
 
-        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
-        soft_loss /= full_target.shape[0]
+        hard_loss /= num_valid_tokens
+
+        soft_loss = distillation_loss_fn(
+            student_logits_chunk, teacher_logits_chunk, target=target_chunk, ignore_index=ignore_index, **loss_kwargs
+        )
+        soft_loss /= num_valid_tokens
 
         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
         return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)
@@ -157,8 +166,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         compute_ce_loss=True,
         temperature=1.0,
         compiled=True,
+        return_soft_hard_loss=False,
         **loss_kwargs,
-    ):
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Base class for fused linear layer with distillation loss.
         Only need to compute gradients for student model.
@@ -180,6 +190,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             compute_ce_loss (bool): Whether to compute CE loss.
             temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             compiled (bool): Whether to use torch compile for chunk accumulation.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         CHUNK_SIZE = chunk_size
@@ -187,6 +198,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         grad_inputs = []
         grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
         loss_acc = torch.zeros((), device=student_input.device)
+        soft_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
+        hard_loss_acc = torch.zeros((), device=student_input.device) if return_soft_hard_loss else None
 
         loss_func_to_call = partial(
             LigerFusedLinearDistillationBase._compute_loss,
@@ -247,6 +260,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             )
             grad_weight.add_(chunk_grad_weight)
             loss_acc.add_(chunk_loss)
+            if return_soft_hard_loss:
+                soft_loss_acc.add_(chunk_soft_loss)
+                hard_loss_acc.add_(chunk_hard_loss)
             return chunk_grad_input
 
         if compiled:
@@ -268,10 +284,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             grad_weight,
             grad_bias,
         )
+        if return_soft_hard_loss:
+            return loss_acc, soft_loss_acc, hard_loss_acc
         return loss_acc
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, *args):
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
         if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
             grad_input = grad_input * grad_output

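The distillation base now normalizes both the hard and soft losses by the number of non-ignored target tokens (clamped to at least one). A toy illustration of the new denominator, mirroring the lines added above:

```python
import torch

ignore_index = -100
full_target = torch.tensor([5, 7, ignore_index, ignore_index, 2])

# Padding/ignored positions no longer count toward the divisor, which is clamped to >= 1.
num_valid_tokens = (full_target != ignore_index).sum().clamp_min(1)
print(num_valid_tokens)  # tensor(3)
```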
liger_kernel/chunked_loss/fused_linear_ppo.py

@@ -32,11 +32,15 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
+        loss_type="dapo",
+        max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         compiled=True,
         use_ref_model=False,
         chunk_size=1,
     ):
+        # TODO: check torch compile matmul
         """Chunked forward pass for PPO loss computation.
 
         Args:
@@ -56,6 +60,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             epsilon_low: Lower bound for clipping the importance sampling ratio
             epsilon_high: Upper bound for clipping the importance sampling ratio
             beta: Weight for the KL penalty
+            loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo")
+            max_completion_length: Maximum completion length required for "dr_grpo"
             temperature: Temperature for the logits
             compiled: Whether to use torch compile
             use_ref_model: Whether to use a reference model
@@ -67,6 +73,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             )
         if ref_per_token_logps is not None and ref_input is not None:
             raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
+        if loss_type == "dr_grpo":
+            assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'"
         # Initialize accumulators
         loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
         grad_weight = torch.zeros_like(weight)  # [V, H]
@@ -83,6 +91,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             epsilon_low=epsilon_low,
             epsilon_high=epsilon_high,
             beta=beta,
+            loss_type=loss_type,
+            max_completion_length=max_completion_length,
+            importance_sampling_level=importance_sampling_level,
             temperature=temperature,
             use_ref_model=use_ref_model,
             ppo_loss_fn=cls.ppo_loss_fn,
@@ -233,6 +244,21 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 
         return loss_acc, tuple(final_metrics)
 
+    @staticmethod
+    def _compute_dapo_normalizer(attention_mask):
+        """Global active tokens averaged per process."""
+        normalizer = attention_mask.to(torch.float32).sum()
+        world_size = 1
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            import torch.distributed as dist
+
+            normalizer = normalizer.clone()
+            dist.all_reduce(normalizer, op=dist.ReduceOp.SUM)
+            world_size = dist.get_world_size()
+
+        normalizer = normalizer / world_size
+        return torch.clamp(normalizer, min=1.0)
+
     @staticmethod
     def _compute_chunk_loss(
         input_chunk,
@@ -250,6 +276,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
+        loss_type="dapo",
+        max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         use_ref_model=False,
         ppo_loss_fn=None,
@@ -279,6 +308,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             epsilon_low=epsilon_low,
             epsilon_high=epsilon_high,
             beta=beta,
+            loss_type=loss_type,
+            max_completion_length=max_completion_length,
+            importance_sampling_level=importance_sampling_level,
         )
 
         return chunk_loss, chunk_metrics
@@ -302,6 +334,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
     def backward(ctx, grad_output, *grad_metrics):
         """Backward pass for PPO loss."""
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
+
         if grad_output != 1.0:
             grad_input = grad_input * grad_output
             grad_weight = grad_weight * grad_output
@@ -323,6 +356,9 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
             None,  # grad_epsilon_low
             None,  # grad_epsilon_high
             None,  # grad_beta
+            None,  # grad_loss_type
+            None,  # grad_max_completion_length
+            None,  # grad_importance_sampling_level
             None,  # grad_temperature
             None,  # grad_compiled
             None,  # grad_use_ref_model

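The new `_compute_dapo_normalizer` counts active tokens globally and averages the count over processes; without an initialized `torch.distributed` process group it reduces to the clamped token count. A single-process illustration with a made-up mask:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

normalizer = attention_mask.to(torch.float32).sum()  # 5.0 active tokens
normalizer = torch.clamp(normalizer, min=1.0)        # world_size == 1, so no all-reduce/average
```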
liger_kernel/chunked_loss/fused_linear_preference.py

@@ -222,7 +222,6 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
             (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
             (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
-            strict=False,
         ):
             input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
             ref_input_chunk = (

liger_kernel/chunked_loss/grpo_loss.py

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 
 from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
@@ -27,6 +29,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         epsilon_low=0.2,
         epsilon_high=0.2,
         beta=0.04,
+        loss_type="dapo",  # ["grpo", "bnpo", "dr_grpo", "dapo"]
+        max_completion_length=None,  # Required for dr_grpo
+        importance_sampling_level="token",  # ["token", "sequence"] - new parameter for GSPO
         **kwargs,
     ):
         """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -46,7 +51,22 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 
         # Compute policy gradient loss with importance sampling ratio
         old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
-        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+        log_ratio = per_token_logps - old_per_token_logps
+
+        if importance_sampling_level == "token":
+            log_importance_weights = log_ratio
+        elif importance_sampling_level == "sequence":
+            log_importance_weights = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
+            log_importance_weights = log_importance_weights.unsqueeze(-1)
+        else:
+            raise ValueError(
+                f"Unknown importance sampling level: {importance_sampling_level}. Possible values are 'token' "
+                "and 'sequence'."
+            )
+
+        # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
+        # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
+        coef_1 = torch.exp(log_importance_weights)
         coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
         per_token_loss1 = coef_1 * advantages.unsqueeze(1)
         per_token_loss2 = coef_2 * advantages.unsqueeze(1)
@@ -61,15 +81,42 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
         # and TRL GRPO implementation
         # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
-        loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+        if loss_type == "grpo":
+            # Average per-sequence loss
+            loss = (
+                (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
+            ).sum() / full_attention_mask.shape[0]
+        elif loss_type == "bnpo":
+            # Batch Normalized Per-token loss (original implementation)
+            loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+        elif loss_type == "dr_grpo":
+            # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
+            if max_completion_length is None:
+                raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
+            loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+        elif loss_type == "dapo":
+            loss_normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(full_attention_mask)
+            loss = (per_token_loss * attention_mask).sum() / loss_normalizer
+        else:
+            raise ValueError(f"Unknown loss type: {loss_type}")
 
         # Calculate metrics
         metrics = []
         if beta != 0.0:
             metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
-        is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
-            (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
-        )
+
+        # Adjust clipping metric calculation based on importance sampling level
+        if importance_sampling_level == "token":
+            is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+                (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
+            )
+        else:  # sequence level
+            # For sequence level, coef_1 is shape (B, 1), advantages is shape (B,)
+            is_clipped = ((coef_1.squeeze(-1) < 1 - epsilon_low) & (advantages < 0)) | (
+                (coef_1.squeeze(-1) > 1 + epsilon_high) & (advantages > 0)
+            )
+            is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask)
+
         metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
         return loss, metrics
 
@@ -91,6 +138,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
         beta=0.04,
         epsilon_low=0.2,
         epsilon_high=0.2,
+        loss_type="dapo",
+        max_completion_length=None,
+        importance_sampling_level="token",
         temperature=1.0,
         compiled=True,
         use_ref_model=True,
@@ -110,6 +160,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
             ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
             beta (float): Weight for the KL penalty
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
+            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
@@ -134,10 +187,13 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             beta=beta,
             epsilon_low=epsilon_low,
             epsilon_high=epsilon_high,
+            loss_type=loss_type,
+            max_completion_length=max_completion_length,
             temperature=temperature,
             compiled=compiled,
             use_ref_model=use_ref_model,
             chunk_size=chunk_size,
+            importance_sampling_level=importance_sampling_level,
         )
 
     @staticmethod
@@ -161,6 +217,9 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
             None,  # grad_beta
             None,  # grad_epsilon_low
             None,  # grad_epsilon_high
+            None,  # grad_loss_type (string, not differentiable)
+            None,  # grad_max_completion_length (int, not differentiable)
+            None,  # grad_importance_sampling_level (string, not differentiable)
             None,  # grad_temperature
             None,  # grad_compiled
             None,  # grad_use_ref_model
@@ -179,6 +238,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         chunk_size: int = 1,
         epsilon_low: float = 0.2,
         epsilon_high: float = 0.2,
+        loss_type: str = "dapo",
+        max_completion_length: Optional[int] = None,
+        importance_sampling_level: str = "token",
         temperature: float = 1.0,
     ):
         """
@@ -189,6 +251,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             chunk_size (int): Size of chunks for processing.
             epsilon_low (float): Lower bound for the importance sampling ratio.
             epsilon_high (float): Upper bound for the importance sampling ratio.
+            loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo"). Defaults to "dapo".
+            max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
+            importance_sampling_level (str): Level of importance sampling ("token" or "sequence"). Defaults to "token".
             temperature (float): Temperature for the logits.
         """
         super().__init__()
@@ -198,6 +263,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
         self.chunk_size = chunk_size
         self.epsilon_low = epsilon_low
         self.epsilon_high = epsilon_high
+        self.loss_type = loss_type
+        self.max_completion_length = max_completion_length
+        self.importance_sampling_level = importance_sampling_level
         self.temperature = temperature
 
     def forward(
@@ -229,6 +297,9 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
             self.beta,
             self.epsilon_low,
             self.epsilon_high,
+            self.loss_type,
+            self.max_completion_length,
+            self.importance_sampling_level,
             self.temperature,
             self.compiled,
             self.use_ref_model,
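With the new parameters, `LigerFusedLinearGRPOLoss` selects both the loss normalization and the importance-sampling granularity. A minimal sketch; the toy tensors only show how the `sequence` level collapses token log-ratios to one weight per sequence (GSPO-style) before exponentiating, with made-up shapes:

```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

grpo_loss = LigerFusedLinearGRPOLoss(
    loss_type="dr_grpo",
    max_completion_length=256,  # required for dr_grpo
    importance_sampling_level="sequence",
)

# Made-up per-token log-ratios for a batch of 2 sequences of length 6.
log_ratio = torch.randn(2, 6)
attention_mask = torch.ones(2, 6)

seq_log_w = (log_ratio * attention_mask).sum(-1) / attention_mask.sum(-1).clamp(min=1.0)
coef_1 = torch.exp(seq_log_w.unsqueeze(-1))  # shape (B, 1); token level would keep (B, T)
```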