liger-kernel 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +2 -0
- liger_kernel/chunked_loss/cpo_loss.py +18 -8
- liger_kernel/chunked_loss/dpo_loss.py +20 -10
- liger_kernel/chunked_loss/functional.py +4 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
- liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
- liger_kernel/chunked_loss/jsd_loss.py +154 -0
- liger_kernel/chunked_loss/kto_loss.py +172 -0
- liger_kernel/chunked_loss/orpo_loss.py +8 -9
- liger_kernel/chunked_loss/simpo_loss.py +22 -8
- liger_kernel/env_report.py +5 -12
- liger_kernel/ops/cross_entropy.py +102 -51
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
- liger_kernel/ops/fused_linear_jsd.py +11 -29
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +2 -6
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +3 -5
- liger_kernel/ops/qwen2vl_mrope.py +21 -37
- liger_kernel/ops/rms_norm.py +14 -32
- liger_kernel/ops/rope.py +31 -33
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +16 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +4 -6
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +11 -7
- liger_kernel/transformers/fused_linear_cross_entropy.py +12 -7
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +3 -9
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +24 -54
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +36 -32
- liger_kernel/transformers/monkey_patch.py +43 -117
- liger_kernel/transformers/qwen2vl_mrope.py +2 -2
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/rope.py +2 -2
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/METADATA +38 -25
- liger_kernel-0.5.3.dist-info/RECORD +69 -0
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/WHEEL +1 -1
- liger_kernel-0.5.1.dist-info/RECORD +0 -65
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/LICENSE +0 -0
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/NOTICE +0 -0
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/jsd_loss.py
ADDED
@@ -0,0 +1,154 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
+    @staticmethod
+    def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
+        """
+        Compute JSD loss (Jensen-Shannon Divergence Loss).
+        Args:
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+        Returns:
+            torch.Tensor: Jensen-Shannon Divergence loss
+        """
+        student_log_probs = F.log_softmax(student_logits, dim=-1)
+        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
+
+        # Compute probabilities (only required for mean calculation)
+        mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
+        log_mean_probs = mean_probs.log()
+
+        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
+
+        # JSD is the weighted average of the KL divergences
+        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+        return jsd_loss
+
+    @staticmethod
+    def forward(
+        ctx,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+    ):
+        """
+        Fused linear layer with JSD distillation loss.
+        Args:
+            student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, hidden_size_student)
+            student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, hidden_size_student)
+            teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, hidden_size_teacher)
+            teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, hidden_size_teacher)
+            true_labels (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+            ignore_index (int): Index to ignore in loss computation
+            temperature (float): Temperature for softening/sharpening distributions
+            compiled (bool): Whether to use torch compile
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return LigerFusedLinearDistillationBase.forward(
+            ctx=ctx,
+            student_input=student_input,
+            student_weight=student_weight,
+            teacher_input=teacher_input,
+            teacher_weight=teacher_weight,
+            target=true_labels,
+            loss_fn=LigerFusedLinearJSDFunction.distillation_loss_fn,
+            chunk_size=1,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            beta=beta,
+            ignore_index=ignore_index,
+            temperature=temperature,
+            compiled=compiled,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:4]
+
+        return (*grads, None, None, None, None, None, None, None)
+
+
+class LigerFusedLinearJSDLoss(torch.nn.Module):
+    """
+    Fused linear layer with JSD distillation loss.
+    """
+
+    def __init__(
+        self,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+    ):
+        """
+        Args:
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            ignore_index (int): Index to ignore in the loss
+            temperature (float): Temperature for softening distributions
+            compiled (bool): Whether to use torch compile
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+        """
+        super().__init__()
+        assert temperature != 0, "Temperature cannot be 0."
+        self.weight_hard_loss = weight_hard_loss
+        self.weight_soft_loss = weight_soft_loss
+        self.ignore_index = ignore_index
+        self.temperature = temperature
+        self.compiled = compiled
+        self.beta = beta
+
+    def forward(
+        self,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+    ) -> torch.Tensor:
+        """
+        Compute the JSD distillation loss.
+
+        Args:
+            student_input (torch.Tensor): Student input tensor
+            student_weight (torch.Tensor): Student weight tensor
+            teacher_input (torch.Tensor): Teacher input tensor
+            teacher_weight (torch.Tensor): Teacher weight tensor
+            true_labels (torch.LongTensor): Target labels tensor
+
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return LigerFusedLinearJSDFunction.apply(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            true_labels,
+            self.weight_hard_loss,
+            self.weight_soft_loss,
+            self.beta,
+            self.ignore_index,
+            self.temperature,
+            self.compiled,
+        )
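For orientation, here is a minimal usage sketch of the new LigerFusedLinearJSDLoss module. The tensor sizes are illustrative assumptions, not taken from the package, and the snippet assumes an environment where the chunked-loss path runs (the repo exercises these kernels on GPU in its tests); the point is only to show the call signature, in which the student/teacher projections are fused with the generalized JSD so the full (batch * seq_len, vocab) logit matrix is never materialized at once.

import torch

from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss

# Illustrative sizes only; real models use far larger hidden and vocab dimensions.
batch_seq, student_hidden, teacher_hidden, vocab = 8, 64, 128, 1000

loss_fn = LigerFusedLinearJSDLoss(weight_hard_loss=0.5, weight_soft_loss=0.5, beta=0.5)

student_input = torch.randn(batch_seq, student_hidden, requires_grad=True)
student_weight = torch.randn(vocab, student_hidden, requires_grad=True)
teacher_input = torch.randn(batch_seq, teacher_hidden)
teacher_weight = torch.randn(vocab, teacher_hidden)
true_labels = torch.randint(0, vocab, (batch_seq,))

loss = loss_fn(student_input, student_weight, teacher_input, teacher_weight, true_labels)
loss.backward()  # gradients flow back to the student input and weight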
liger_kernel/chunked_loss/kto_loss.py
ADDED
@@ -0,0 +1,172 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_unpaired_preference import LigerFusedLinearUnpairedPreferenceBase
+
+
+class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
+    @staticmethod
+    def preference_loss_fn(
+        average_log_prob_chunk,
+        preference_labels_chunk,
+        full_target,
+        ref_average_log_prob_chunk=None,
+        beta=0.1,
+        kl=None,
+    ):
+        """
+        Implements the Kahneman-Tversky Optimization (KTO) loss function.
+        Paper: "KTO: Model Alignment as Prospect Theory-Guided Optimization"
+        https://arxiv.org/abs/2402.01306
+
+        KTO loss is inspired by prospect theory (https://en.wikipedia.org/wiki/Prospect_theory)
+        from behavioral economics, which models how humans make decisions under uncertainty.
+        The loss function is asymmetric, treating gains and losses differently, similar to
+        human decision-making patterns.
+
+        Formula:
+        When y is chosen:
+        L_KTO = 1 - σ(β * (log[π(x)/π₀(x)] - KL(π||π₀)_y))
+        When y is rejected:
+        L_KTO = 1 - σ(β * (KL(π||π₀)_y - log[π(x)/π₀(x)]))
+
+        Where:
+        - σ: Sigmoid function
+        - β: Temperature parameter controlling the strength of the preference signal
+        - π(x): Policy (current model)
+        - π₀(x): Reference policy (reference model)
+        - KL(π||π₀)_y: KL divergence estimated using the rejected response y
+
+        The loss encourages the model to:
+        1. Assign higher probability to chosen responses
+        2. Assign lower probability to rejected responses
+        3. Maintain reasonable distance from the reference model
+
+        Args:
+            chosen_logps: Log probabilities of chosen tokens (batch_size,)
+            rejected_logps: Log probabilities of rejected tokens (batch_size,)
+            full_target: Non chunked full target tensor
+            ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
+            ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
+            beta: Weight for the direct preference loss
+            kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
+        Returns:
+            Tuple of (loss, chosen_rewards, rejected_rewards):
+            - loss: The KTO loss value
+            - chosen_rewards: Reward signals for chosen responses (detached)
+            - rejected_rewards: Reward signals for rejected responses (detached)
+        """
+        logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
+        multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
+        if kl is not None:
+            losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
+        else:
+            losses = 1 - F.sigmoid(beta * logratios_chunk * multiplier_chunk)
+
+        return losses.sum() / (full_target.shape[0])
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        preference_labels,
+        bias=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        kl=None,
+        ignore_index=-100,
+        beta=0.1,
+        compiled=True,
+        use_ref_model=True,
+    ):
+        return LigerFusedLinearUnpairedPreferenceBase.forward(
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            preference_labels=preference_labels,
+            bias=bias,
+            loss_fn=LigerFusedLinearKTOFunction.preference_loss_fn,
+            ignore_index=ignore_index,
+            beta=beta,
+            compiled=compiled,
+            use_ref_model=use_ref_model,
+            ref_input=ref_input,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            kl=kl,
+        )
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grads = LigerFusedLinearUnpairedPreferenceBase.backward(ctx, grad_output)[:5]
+        return (
+            *grads,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+class LigerFusedLinearKTOLoss(torch.nn.Module):
+    """
+    Fused linear layer with Kahneman-Tversky Optimization (KTO) loss.
+    """
+
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        beta: float = 0.1,
+        compiled: bool = True,
+        use_ref_model: bool = False,
+    ):
+        """
+        Args:
+            ignore_index (int): Index to ignore in the loss calculation
+            beta (float): Temperature parameter for the KTO loss
+            compiled (bool): Whether to use compiled operations
+            use_ref_model (bool): Whether to use a reference model for the DPO loss.
+        """
+        super().__init__()
+        self.ignore_index = ignore_index
+        self.beta = beta
+        self.compiled = compiled
+        self.use_ref_model = use_ref_model
+
+    def forward(
+        self,
+        _input,
+        lin_weight,
+        target,
+        bias=None,
+        preference_labels=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        kl=None,
+    ):
+        return LigerFusedLinearKTOFunction.apply(
+            _input,
+            lin_weight,
+            target,
+            preference_labels,
+            bias,
+            ref_input,
+            ref_weight,
+            ref_bias,
+            kl,
+            self.ignore_index,
+            self.beta,
+            self.compiled,
+            self.use_ref_model,
+        )
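A rough usage sketch of the new LigerFusedLinearKTOLoss module follows. The shapes, the one-label-per-sequence layout of preference_labels, and the exact output contract are assumptions inferred from the signatures above; the authoritative contract lives in LigerFusedLinearUnpairedPreferenceBase and the package's tests.

import torch

from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss

# Illustrative sizes only (batch of 4 sequences, 2 tokens each).
batch, seq_len, hidden, vocab = 4, 2, 64, 1000

kto_loss = LigerFusedLinearKTOLoss(beta=0.1, use_ref_model=True)

_input = torch.randn(batch * seq_len, hidden, requires_grad=True)
lin_weight = torch.randn(vocab, hidden, requires_grad=True)
target = torch.randint(0, vocab, (batch * seq_len,))
preference_labels = torch.tensor([True, False, True, False])  # True = desirable response (assumed shape)
ref_input = torch.randn(batch * seq_len, hidden)
ref_weight = torch.randn(vocab, hidden)

# Output follows the unpaired-preference base class contract (loss plus auxiliary values).
out = kto_loss(
    _input,
    lin_weight,
    target,
    preference_labels=preference_labels,
    ref_input=ref_input,
    ref_weight=ref_weight,
)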
liger_kernel/chunked_loss/orpo_loss.py
CHANGED
@@ -1,13 +1,10 @@
 import torch
 import torch.nn.functional as F
 
-from liger_kernel.chunked_loss.fused_linear_preference import (
-    LigerFusedLinearPreferenceBase,
-)
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
 
 
 class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
-
     @staticmethod
     def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
         """
@@ -32,11 +29,10 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
             beta (float): Weight for the odds ratio loss.
         """
         log_odds = (chosen_logps - rejected_logps) - (
-            torch.log1p(-torch.exp(chosen_logps))
-            - torch.log1p(-torch.exp(rejected_logps))
+            torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
         )
         ratio = F.logsigmoid(log_odds)
-        loss = beta * ratio.sum() / (full_target.shape[0] // 2)
+        loss = -beta * ratio.sum() / (full_target.shape[0] // 2)
 
         chosen_rewards = beta * chosen_logps
         rejected_rewards = beta * rejected_logps
@@ -56,6 +52,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
         ignore_index=-100,
         beta=0.1,
         compute_nll_loss=True,
+        nll_target=None,
         compiled=True,
     ):
         return LigerFusedLinearPreferenceBase.forward(
@@ -68,13 +65,14 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
             ignore_index=ignore_index,
             beta=beta,
             compute_nll_loss=compute_nll_loss,
+            nll_target=nll_target,
             compiled=compiled,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None
+        return *grads, None, None, None, None, None
 
 
 class LigerFusedLinearORPOLoss(torch.nn.Module):
@@ -100,7 +98,7 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
 
-    def forward(self, lin_weight, _input, target, bias=None):
+    def forward(self, lin_weight, _input, target, bias=None, nll_target=None):
         return LigerFusedLinearORPOFunction.apply(
             _input,
             lin_weight,
@@ -109,5 +107,6 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
             self.ignore_index,
             self.beta,
             self.compute_nll_loss,
+            nll_target,
             self.compiled,
         )
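The substantive change in this file is the sign flip on the odds-ratio term. Since ratio = F.logsigmoid(log_odds) is always non-positive, negating it makes the returned value a non-negative quantity to minimize, i.e. the odds-ratio term in the conventional negative-log-sigmoid form used by the ORPO paper (σ is the sigmoid; probabilities are the length-normalized sequence probabilities produced by the preference base class):

\mathcal{L}_{\mathrm{OR}}
  = -\log \sigma\!\left(\log \frac{\mathrm{odds}_\theta(y_w \mid x)}{\mathrm{odds}_\theta(y_l \mid x)}\right),
\qquad
\mathrm{odds}_\theta(y \mid x) = \frac{P_\theta(y \mid x)}{1 - P_\theta(y \mid x)}

The remaining hunks thread the new nll_target argument through the function, its backward signature, and the module wrapper.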
liger_kernel/chunked_loss/simpo_loss.py
CHANGED
@@ -1,16 +1,18 @@
 import torch
 import torch.nn.functional as F
 
-from liger_kernel.chunked_loss.fused_linear_preference import (
-    LigerFusedLinearPreferenceBase,
-)
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
 
 
 class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
-
     @staticmethod
     def preference_loss_fn(
-        chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
+        chosen_logps,
+        rejected_logps,
+        full_target,
+        beta=0.1,
+        gamma=0.5,
+        label_smoothing=0.0,
     ):
         """
         Paper: https://arxiv.org/pdf/2405.14734
@@ -33,10 +35,17 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
             full_target: Non chunked full target tensor
             beta (float): beta weight
             gamma (float): gemma margin term
+            label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
         """
         logits = beta * (chosen_logps - rejected_logps) - gamma
-        loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
-        return loss
+        loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+            full_target.shape[0] // 2
+        )
+
+        chosen_rewards = beta * chosen_logps
+        rejected_rewards = beta * rejected_logps
+
+        return loss, chosen_rewards, rejected_rewards
 
     @staticmethod
     def forward(
@@ -48,6 +57,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
         ignore_index=-100,
         beta=0.1,
         alpha=1.0,
+        label_smoothing=0.0,
         compute_nll_loss=False,
         compiled=True,
         gamma=0.5,
@@ -63,6 +73,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
+           label_smoothing=label_smoothing,
            compiled=compiled,
            gamma=gamma,
        )
@@ -70,7 +81,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearSimPOLoss(torch.nn.Module):
@@ -83,6 +94,7 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
         ignore_index: int = -100,
         beta: float = 0.1,
         alpha: float = 1.0,
+        label_smoothing: float = 0.0,
         compute_nll_loss: bool = True,
         compiled: bool = True,
         gamma: float = 0.5,
@@ -96,6 +108,7 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
         self.ignore_index = ignore_index
         self.beta = beta
         self.alpha = alpha
+        self.label_smoothing = label_smoothing
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.gamma = gamma
@@ -109,6 +122,7 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
             self.ignore_index,
             self.beta,
             self.alpha,
+            self.label_smoothing,
             self.compute_nll_loss,
             self.compiled,
             self.gamma,
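With the added label_smoothing term, the per-pair objective computed above is, in the code's own variables (z = beta * (chosen_logps - rejected_logps) - gamma, ε = label_smoothing):

\ell(z) = -(1 - \varepsilon)\,\log \sigma(z) \;-\; \varepsilon\,\log \sigma(-z)

which reduces to the plain -log σ(z) SimPO loss as ε → 0, exactly as the new docstring line notes. The same change also has the loss function return per-example rewards (beta * chosen_logps and beta * rejected_logps) alongside the loss, as the ORPO loss function above already does.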
liger_kernel/env_report.py
CHANGED
@@ -1,12 +1,13 @@
 import platform
 import sys
+
 from importlib.metadata import version
 
 
 def print_env_report():
     """
 
-    Prints a report of the environment.
+    Prints a report of the environment. Useful for debugging and reproducibility.
     Usage:
     ```
     python -m liger_kernel.env_report
@@ -27,15 +28,9 @@ def print_env_report():
         import torch
 
         print(f"PyTorch version: {torch.__version__}")
-        cuda_version = (
-            torch.version.cuda if torch.cuda.is_available() else "Not available"
-        )
+        cuda_version = torch.version.cuda if torch.cuda.is_available() else "Not available"
         print(f"CUDA version: {cuda_version}")
-        hip_version = (
-            torch.version.hip
-            if torch.cuda.is_available() and torch.version.hip
-            else "Not available"
-        )
+        hip_version = torch.version.hip if torch.cuda.is_available() and torch.version.hip else "Not available"
        print(f"HIP(ROCm) version: {hip_version}")
 
     except ImportError:
@@ -58,9 +53,7 @@ def print_env_report():
         print("Transformers: Not installed")
 
     try:
-        xpu_version = (
-            torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
-        )
+        xpu_version = torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
         print(f"XPU version: {xpu_version}")
     except ImportError:
         print("XPU version: Unable to query")
|