liger-kernel 0.5.4__py3-none-any.whl → 0.5.5__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- liger_kernel/chunked_loss/cpo_loss.py +51 -11
- liger_kernel/chunked_loss/dpo_loss.py +30 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +3 -3
- liger_kernel/chunked_loss/fused_linear_preference.py +2 -2
- liger_kernel/chunked_loss/fused_linear_rlhf.py +33 -6
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +112 -17
- liger_kernel/chunked_loss/grpo_loss.py +37 -3
- liger_kernel/chunked_loss/jsd_loss.py +31 -6
- liger_kernel/chunked_loss/kto_loss.py +50 -12
- liger_kernel/chunked_loss/orpo_loss.py +37 -5
- liger_kernel/chunked_loss/simpo_loss.py +47 -11
- liger_kernel/ops/cross_entropy.py +4 -0
- liger_kernel/transformers/__init__.py +1 -0
- liger_kernel/transformers/model/qwen2_5_vl.py +205 -0
- liger_kernel/transformers/monkey_patch.py +68 -0
- liger_kernel/utils.py +1 -3
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.5.dist-info}/METADATA +3 -2
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.5.dist-info}/RECORD +22 -21
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.5.dist-info}/WHEEL +1 -1
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.5.dist-info}/LICENSE +0 -0
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.5.dist-info}/NOTICE +0 -0
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.5.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/jsd_loss.py

```diff
@@ -30,20 +30,24 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
         jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
         return jsd_loss

-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         student_input: torch.Tensor,
         student_weight: torch.Tensor,
         teacher_input: torch.Tensor,
         teacher_weight: torch.Tensor,
         true_labels: torch.LongTensor,
+        student_bias: torch.Tensor,
+        teacher_bias: torch.Tensor,
         weight_hard_loss: float = 0.5,
         weight_soft_loss: float = 0.5,
         beta: float = 0.5,
         ignore_index: int = -100,
         temperature: float = 1.0,
         compiled: bool = True,
+        chunk_size: int = 1024,
     ):
         """
         Fused linear layer with JSD distillation loss.
@@ -59,18 +63,21 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
             ignore_index (int): Index to ignore in loss computation
             temperature (float): Temperature for softening/sharpening distributions
             compiled (bool): Whether to use torch compile
+            chunk_size (int): Size of chunks for processing.
         Returns:
             torch.Tensor: Computed loss
         """
-        return LigerFusedLinearDistillationBase.forward(
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             student_input=student_input,
             student_weight=student_weight,
             teacher_input=teacher_input,
             teacher_weight=teacher_weight,
             target=true_labels,
-            loss_fn=LigerFusedLinearJSDFunction.distillation_loss_fn,
-            chunk_size=1,
+            student_bias=student_bias,
+            teacher_bias=teacher_bias,
+            chunk_size=chunk_size,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
             beta=beta,
@@ -81,9 +88,19 @@ class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):

     @staticmethod
     def backward(ctx, grad_output):
-        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:5]
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:6]

-        return (*grads, None, None, None, None, None, None)
+        return (
+            *grads,
+            None,  # teacher_bias
+            None,  # weight_hard_loss
+            None,  # weight_soft_loss
+            None,  # beta
+            None,  # ignore_index
+            None,  # temperature
+            None,  # compiled
+            None,  # chunk_size
+        )


 class LigerFusedLinearJSDLoss(torch.nn.Module):
@@ -99,6 +116,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         ignore_index: int = -100,
         temperature: float = 1.0,
         compiled: bool = True,
+        chunk_size: int = 1024,
     ):
         """
         Args:
@@ -108,6 +126,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             temperature (float): Temperature for softening distributions
             compiled (bool): Whether to use torch compile
             beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -117,6 +136,7 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         self.temperature = temperature
         self.compiled = compiled
         self.beta = beta
+        self.chunk_size = chunk_size

     def forward(
         self,
@@ -125,6 +145,8 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
         teacher_input: torch.Tensor,
         teacher_weight: torch.Tensor,
         true_labels: torch.LongTensor,
+        student_bias: torch.Tensor,
+        teacher_bias: torch.Tensor,
     ) -> torch.Tensor:
         """
         Compute the JSD distillation loss.
@@ -145,10 +167,13 @@ class LigerFusedLinearJSDLoss(torch.nn.Module):
             teacher_input,
             teacher_weight,
             true_labels,
+            student_bias,
+            teacher_bias,
             self.weight_hard_loss,
             self.weight_soft_loss,
             self.beta,
             self.ignore_index,
             self.temperature,
             self.compiled,
+            self.chunk_size,
         )
```
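The net effect on the public API: `LigerFusedLinearJSDLoss.forward` now takes `student_bias` and `teacher_bias` after `true_labels`, and the constructor accepts `chunk_size`. A minimal usage sketch; tensor sizes, the `None` biases, and the `liger_kernel.chunked_loss` import path are illustrative assumptions, and device placement is omitted:

```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss

BT, H, V = 8, 64, 128  # hypothetical batch*seq_len, hidden_size, vocab_size

loss_fn = LigerFusedLinearJSDLoss(beta=0.5, temperature=2.0, chunk_size=1024)

student_input = torch.randn(BT, H, requires_grad=True)
student_weight = torch.randn(V, H, requires_grad=True)
teacher_input = torch.randn(BT, H)
teacher_weight = torch.randn(V, H)
true_labels = torch.randint(0, V, (BT,))

# The two bias arguments are new in 0.5.5 and have no default in the module
# signature; passing None to keep the bias-free behavior is an assumption.
loss = loss_fn(
    student_input,
    student_weight,
    teacher_input,
    teacher_weight,
    true_labels,
    None,  # student_bias
    None,  # teacher_bias
)
loss.backward()
```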
liger_kernel/chunked_loss/kto_loss.py

```diff
@@ -7,10 +7,10 @@ from liger_kernel.chunked_loss.fused_linear_unpaired_preference import LigerFusedLinearUnpairedPreferenceBase
 class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
     @staticmethod
     def preference_loss_fn(
-        average_log_prob_chunk,
+        log_prob_chunk,
         preference_labels_chunk,
         full_target,
-        ref_average_log_prob_chunk=None,
+        ref_log_prob_chunk=None,
         beta=0.1,
         kl=None,
     ):
@@ -43,30 +43,34 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
         3. Maintain reasonable distance from the reference model

         Args:
-            average_log_prob_chunk: Average log probabilities for the chunk (batch_size,)
+            log_prob_chunk: Log probabilities for the chunk (batch_size,)
             preference_labels_chunk: Preference labels for the chunk (batch_size,)
             full_target: Non chunked full target tensor
-            ref_average_log_prob_chunk: Reference average log probs for the chunk (batch_size,)
+            ref_log_prob_chunk: Reference log probs for the chunk (batch_size,)
             beta: Weight for the KTO loss
             kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
         Returns:
             - loss: The KTO loss value
         """
-        if ref_average_log_prob_chunk is not None:
-            logratios_chunk = average_log_prob_chunk - ref_average_log_prob_chunk
+        if ref_log_prob_chunk is not None:
+            logratios_chunk = log_prob_chunk - ref_log_prob_chunk
         else:
-            logratios_chunk = average_log_prob_chunk
-
+            logratios_chunk = log_prob_chunk
         multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
         if kl is not None:
             losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
         else:
             losses = 1 - F.sigmoid(beta * logratios_chunk * multiplier_chunk)

-        return losses.sum() / (full_target.shape[0])
+        rewards = beta * logratios_chunk
+        chosen_rewards_sum = (rewards * preference_labels_chunk.unsqueeze(1)).sum()
+        rejected_rewards_sum = (rewards * (~preference_labels_chunk).unsqueeze(1)).sum()

-    @staticmethod
+        return losses.sum() / (full_target.shape[0]), chosen_rewards_sum, rejected_rewards_sum
+
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -81,15 +85,38 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
         beta=0.1,
         compiled=True,
         use_ref_model=True,
+        average_log_prob=False,
+        chunk_size=1,
     ):
-        return LigerFusedLinearUnpairedPreferenceBase.forward(
+        """
+        Fused linear layer with KTO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            preference_labels (torch.Tensor): Preference labels tensor. Shape: (batch_size,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+            kl (torch.Tensor, optional): KL divergence tensor. Shape: (batch_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Temperature parameter for the KTO loss
+            compiled (bool): Whether to use torch compile
+            use_ref_model (bool): Whether to use a reference model
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             _input=_input,
             weight=weight,
             target=target,
             preference_labels=preference_labels,
             bias=bias,
-            loss_fn=LigerFusedLinearKTOFunction.preference_loss_fn,
             ignore_index=ignore_index,
             beta=beta,
             compiled=compiled,
@@ -97,7 +124,9 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
             ref_input=ref_input,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
             kl=kl,
+            chunk_size=chunk_size,
         )

     @staticmethod
@@ -115,6 +144,7 @@ class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
             None,
             None,
             None,
+            None,
         )


@@ -129,6 +159,8 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
         beta: float = 0.1,
         compiled: bool = True,
         use_ref_model: bool = False,
+        average_log_prob: bool = False,
+        chunk_size: int = 1,
     ):
         """
         Args:
@@ -136,12 +168,16 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
             beta (float): Temperature parameter for the KTO loss
             compiled (bool): Whether to use compiled operations
             use_ref_model (bool): Whether to use a reference model for the DPO loss.
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing
         """
         super().__init__()
         self.ignore_index = ignore_index
         self.beta = beta
         self.compiled = compiled
         self.use_ref_model = use_ref_model
+        self.average_log_prob = average_log_prob
+        self.chunk_size = chunk_size

     def forward(
         self,
@@ -169,4 +205,6 @@ class LigerFusedLinearKTOLoss(torch.nn.Module):
             self.beta,
             self.compiled,
             self.use_ref_model,
+            self.average_log_prob,
+            self.chunk_size,
         )
```
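For callers, the visible changes are two new constructor knobs, `average_log_prob` and `chunk_size`; internally, `preference_loss_fn` now also returns summed chosen/rejected rewards alongside the loss. A construction sketch (the import path is an assumption based on the package's usual re-exports):

```python
from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss

kto_loss = LigerFusedLinearKTOLoss(
    ignore_index=-100,
    beta=0.1,
    use_ref_model=True,
    average_log_prob=True,  # new in 0.5.5: average log-probs over non-masked tokens
    chunk_size=1,           # new in 0.5.5: rows processed per chunk
)
```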
liger_kernel/chunked_loss/orpo_loss.py

```diff
@@ -42,8 +42,9 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):

         return loss, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen

-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -54,25 +55,43 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=True,
         nll_target=None,
         compiled=True,
+        chunk_size=1,
     ):
-        return LigerFusedLinearPreferenceBase.forward(
+        """
+        Fused linear layer with ORPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            nll_target (torch.LongTensor, optional): Target tensor for NLL loss. Shape: (batch_size * seq_len,)
+            compiled (bool): Whether to use torch compile
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
             ctx=ctx,
             _input=_input,
             weight=weight,
             target=target,
             bias=bias,
-            loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn,
             ignore_index=ignore_index,
             beta=beta,
             compute_nll_loss=compute_nll_loss,
             nll_target=nll_target,
             compiled=compiled,
+            chunk_size=chunk_size,
         )

     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None
+        return *grads, None, None, None, None, None, None


 class LigerFusedLinearORPOLoss(torch.nn.Module):
@@ -86,19 +105,31 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
         beta: float = 0.1,
         compute_nll_loss: bool = True,
         compiled: bool = True,
+        chunk_size: int = 1,
     ):
         """
         Args:
             ignore_index (int): Index to ignore in the loss.
             beta (float): Weight for the odds ratio loss.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         self.ignore_index = ignore_index
         self.beta = beta
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
+        self.chunk_size = chunk_size

-    def forward(self, lin_weight, _input, target, bias=None, nll_target=None):
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+        nll_target=None,
+    ):
         return LigerFusedLinearORPOFunction.apply(
             _input,
             lin_weight,
@@ -109,4 +140,5 @@ class LigerFusedLinearORPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             nll_target,
             self.compiled,
+            self.chunk_size,
         )
```
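Aside from the new `chunk_size`, the ORPO module's `forward` is only reformatted into a multi-line signature; the argument order (`lin_weight` before `_input`) is unchanged. A call sketch with illustrative sizes, assuming the loss is the first element of the returned outputs (preference losses expect chosen and rejected rows concatenated along the batch dimension):

```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss

BT, H, V = 4, 32, 64  # hypothetical: (chosen + rejected) rows, hidden, vocab
orpo = LigerFusedLinearORPOLoss(beta=0.1, chunk_size=1)

lin_weight = torch.randn(V, H, requires_grad=True)
_input = torch.randn(BT, H, requires_grad=True)
target = torch.randint(0, V, (BT,))

# Note the order: weight first, then input. The module also returns
# auxiliary statistics (rewards, log-odds terms) after the loss.
loss, *aux_outputs = orpo(lin_weight, _input, target)
loss.backward()
```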
liger_kernel/chunked_loss/simpo_loss.py

```diff
@@ -47,8 +47,9 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):

         return loss, chosen_rewards, rejected_rewards

-    @staticmethod
+    @classmethod
     def forward(
+        cls,
         ctx,
         _input,
         weight,
@@ -61,27 +62,47 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=False,
         compiled=True,
         gamma=0.5,
+        chunk_size=1,
     ):
-        return LigerFusedLinearPreferenceBase.forward(
-            ctx=ctx,
-            _input=_input,
-            weight=weight,
-            target=target,
-            bias=bias,
-            loss_fn=LigerFusedLinearSimPOFunction.preference_loss_fn,
-            compute_nll_loss=compute_nll_loss,
+        """
+        Fused linear layer with SimPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            alpha (float): Weight for the alpha parameter
+            label_smoothing (float): Label smoothing factor
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            compiled (bool): Whether to use torch compile
+            gamma (float): Weight for the gamma parameter
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
             ignore_index=ignore_index,
             alpha=alpha,
             beta=beta,
             label_smoothing=label_smoothing,
+            compute_nll_loss=compute_nll_loss,
             compiled=compiled,
             gamma=gamma,
+            chunk_size=chunk_size,
         )

     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None


 class LigerFusedLinearSimPOLoss(torch.nn.Module):
@@ -98,11 +119,18 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
         compute_nll_loss: bool = True,
         compiled: bool = True,
         gamma: float = 0.5,
+        chunk_size: int = 1,
     ):
         """
         Args:
             ignore_index (int): Index to ignore in the loss.
             beta (float): Weight for the odds ratio loss.
+            alpha (float): Weight for the alpha parameter.
+            label_smoothing (float): Label smoothing factor.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            gamma (float): Weight for the gamma parameter.
+            chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
         self.ignore_index = ignore_index
@@ -112,8 +140,15 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.gamma = gamma
+        self.chunk_size = chunk_size

-    def forward(self, lin_weight, _input, target, bias=None):
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+    ):
         return LigerFusedLinearSimPOFunction.apply(
             _input,
             lin_weight,
@@ -126,4 +161,5 @@ class LigerFusedLinearSimPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             self.compiled,
             self.gamma,
+            self.chunk_size,
         )
```
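SimPO mirrors ORPO: `chunk_size` is threaded through, `compute_nll_loss` is now forwarded explicitly to the base class, and `backward` emits one more `None` to match the extra argument. A construction sketch (hyperparameter values are illustrative, and the import path is assumed):

```python
from liger_kernel.chunked_loss import LigerFusedLinearSimPOLoss

simpo = LigerFusedLinearSimPOLoss(
    beta=2.0,      # illustrative
    gamma=0.5,     # target reward margin
    chunk_size=1,  # new in 0.5.5
)
```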
liger_kernel/ops/cross_entropy.py

```diff
@@ -285,6 +285,10 @@ def cross_entropy_forward(

     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
+    assert (target * target_mask).max() < _input.shape[-1], (
+        f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}"
+    )
+    assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0"
     sum_non_ignore_weight = n_non_ignore
     weight_sum = 0.0
     if weight is not None:
```
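The new assertions make an out-of-range label fail fast with a readable message instead of letting the kernel read out of bounds. The guard can be reproduced standalone; the wrapper name below is ours, but the assertion bodies are verbatim from the diff:

```python
import torch

def validate_targets(_input: torch.Tensor, target: torch.Tensor, ignore_index: int = -100) -> None:
    # Ignored positions are zeroed by the mask, so they never trip the check.
    target_mask = target != ignore_index
    assert (target * target_mask).max() < _input.shape[-1], (
        f"Target {target.max()} is out of bounds. Expected < {_input.shape[-1]}"
    )
    assert (target * target_mask).min() >= 0, f"Target {target.min()} is out of bounds. Expected >= 0"

logits = torch.randn(4, 10)
validate_targets(logits, torch.tensor([0, 9, -100, 3]))    # passes
# validate_targets(logits, torch.tensor([0, 10, -100, 3]))  # AssertionError
```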
liger_kernel/transformers/__init__.py

```diff
@@ -17,6 +17,7 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
 from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
 from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
```
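The re-export makes the new Qwen2.5-VL patch callable from the top-level `liger_kernel.transformers` namespace, like the other `apply_liger_kernel_to_*` helpers. A sketch assuming the sibling helpers' calling convention; the transformers class and checkpoint name are illustrative and not taken from the diff:

```python
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl

# Patch the transformers Qwen2.5-VL modeling code before instantiating the
# model so the Liger kernels are actually picked up.
apply_liger_kernel_to_qwen2_5_vl()

from transformers import Qwen2_5_VLForConditionalGeneration  # assumption: recent transformers

model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
```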