liger-kernel-nightly 0.5.2.dev20241223032630__py3-none-any.whl → 0.5.2.dev20241223042135__py3-none-any.whl
- liger_kernel/chunked_loss/cpo_loss.py +5 -11
- liger_kernel/chunked_loss/dpo_loss.py +1 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
- liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
- liger_kernel/chunked_loss/orpo_loss.py +2 -6
- liger_kernel/chunked_loss/simpo_loss.py +4 -8
- liger_kernel/env_report.py +4 -11
- liger_kernel/ops/cross_entropy.py +7 -10
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +7 -15
- liger_kernel/ops/fused_linear_jsd.py +11 -29
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +2 -6
- liger_kernel/ops/kl_div.py +4 -7
- liger_kernel/ops/layer_norm.py +3 -5
- liger_kernel/ops/qwen2vl_mrope.py +8 -25
- liger_kernel/ops/rms_norm.py +11 -29
- liger_kernel/ops/rope.py +8 -24
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +16 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +1 -3
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +2 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +3 -9
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +23 -53
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +16 -30
- liger_kernel/transformers/monkey_patch.py +43 -117
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/METADATA +1 -1
- liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD +66 -0
- liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD +0 -66
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/cpo_loss.py
CHANGED
@@ -1,17 +1,12 @@
 import torch
 import torch.nn.functional as F

-from liger_kernel.chunked_loss.fused_linear_preference import (
-    LigerFusedLinearPreferenceBase,
-)
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase


 class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
-
     @staticmethod
-    def preference_loss_fn(
-        chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0
-    ):
+    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0):
         """
         Paper: https://arxiv.org/pdf/2401.08417

@@ -35,10 +30,9 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
             label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
         """
         logits = beta * (chosen_logps - rejected_logps)
-        loss = (
-            -F.logsigmoid(logits) * (1 - label_smoothing)
-            - F.logsigmoid(-logits) * label_smoothing
-        ).sum() / (full_target.shape[0] // 2)
+        loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+            full_target.shape[0] // 2
+        )

         return loss

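The reformatting above does not change the math: the CPO loss is a label-smoothed `-logsigmoid` over the scaled log-probability margin, averaged over the number of (chosen, rejected) pairs. A minimal standalone sketch of that formula (not Liger's fused chunked path; the toy tensors and the `beta`/`label_smoothing` defaults simply mirror the diff):

```python
import torch
import torch.nn.functional as F


def cpo_preference_loss(chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0):
    # Scaled margin between chosen and rejected sequence log-probabilities.
    logits = beta * (chosen_logps - rejected_logps)
    # Label-smoothed sigmoid loss, normalized by the number of pairs (full_target.shape[0] // 2).
    loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
        full_target.shape[0] // 2
    )
    return loss


# Toy usage: two pairs, so full_target holds 4 sequences (2 chosen + 2 rejected).
chosen = torch.tensor([-1.0, -1.5])
rejected = torch.tensor([-2.0, -1.8])
full_target = torch.zeros(4, 8, dtype=torch.long)  # only its first dimension is used for normalization
print(cpo_preference_loss(chosen, rejected, full_target))
```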
liger_kernel/chunked_loss/dpo_loss.py
CHANGED
@@ -1,13 +1,10 @@
 import torch
 import torch.nn.functional as F

-from liger_kernel.chunked_loss.fused_linear_preference import (
-    LigerFusedLinearPreferenceBase,
-)
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase


 class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
-
     @staticmethod
     def preference_loss_fn(
         chosen_logps,
liger_kernel/chunked_loss/fused_linear_distillation.py
CHANGED
@@ -2,11 +2,11 @@ from abc import abstractmethod
 from functools import partial

 import torch
+
 from torch.nn import functional as F


 class LigerFusedLinearDistillationBase(torch.autograd.Function):
-
     @abstractmethod
     def distillation_loss_fn(student_logits, teacher_logits, temperature):
         """
@@ -89,25 +89,25 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             compute_ce_loss (bool): Whether to compute CE loss.
             loss_kwargs (dict): Additional arguments for the loss function.
         """
-
-
-
-
-
-
-
-
-
-
-
-
+        (
+            student_logits_chunk,
+            teacher_logits_chunk,
+            hard_loss,
+        ) = LigerFusedLinearDistillationBase.chunk_forward(
+            student_input_chunk,
+            student_weight,
+            teacher_input_chunk,
+            teacher_weight,
+            target_chunk,
+            student_bias=student_bias,
+            teacher_bias=teacher_bias,
+            ignore_index=ignore_index,
+            compute_ce_loss=compute_ce_loss,
         )

         hard_loss /= full_target.shape[0]

-        soft_loss = distillation_loss_fn(
-            student_logits_chunk, teacher_logits_chunk, temperature
-        )
+        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
         soft_loss /= full_target.shape[0]

         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
@@ -174,17 +174,18 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):

         def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
             if student_bias is not None:
-                (
-                    (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
+                (
+                    (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
                     (
-
-
-
-
+                        chunk_loss,
+                        (
+                            chunk_soft_loss,
+                            chunk_hard_loss,
+                            chunk_student_logits,
+                            chunk_teacher_logits,
+                        ),
                     ),
-                ) = torch.func.grad_and_value(
-                    loss_func_to_call, argnums=(0, 1, 5), has_aux=True
-                )(
+                ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1, 5), has_aux=True)(
                     student_input_chunk,
                     student_weight,
                     teacher_input_chunk,
@@ -195,17 +196,18 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
                 )
                 grad_bias.add_(chunk_grad_bias)
             else:
-                (
-                    (chunk_grad_input, chunk_grad_weight),
+                (
+                    (chunk_grad_input, chunk_grad_weight),
                     (
-
-
-
-
+                        chunk_loss,
+                        (
+                            chunk_soft_loss,
+                            chunk_hard_loss,
+                            chunk_student_logits,
+                            chunk_teacher_logits,
+                        ),
                     ),
-                ) = torch.func.grad_and_value(
-                    loss_func_to_call, argnums=(0, 1), has_aux=True
-                )(
+                ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1), has_aux=True)(
                     student_input_chunk,
                     student_weight,
                     teacher_input_chunk,
@@ -229,9 +231,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         for student_input_chunk, teacher_input_chunk, target_chunk in zip(
            _student_input_chunks, _teacher_input_chunks, _target_chunks
         ):
-            grad_input = accumulate_chunk(
-                student_input_chunk, teacher_input_chunk, target_chunk
-            )
+            grad_input = accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk)
            grad_inputs.append(grad_input)

         ctx.save_for_backward(
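The chunked distillation loss shown above normalizes both the CE (`hard_loss`) and distillation (`soft_loss`) terms by `full_target.shape[0]` and then mixes them with `weight_hard_loss` / `weight_soft_loss`. A minimal sketch of that weighting, with a KL-style `distillation_loss_fn` stub standing in for whatever the subclass supplies (the stub, the toy shapes, and the 0.5/0.5 weights are assumptions, not Liger code):

```python
import torch
import torch.nn.functional as F


def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0):
    # Stand-in soft loss: temperature-scaled KL(teacher || student), summed over the chunk.
    s = F.log_softmax(student_logits / temperature, dim=-1)
    t = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(s, t, reduction="sum")


def combine_losses(hard_loss, student_logits_chunk, teacher_logits_chunk, full_target,
                   weight_hard_loss=0.5, weight_soft_loss=0.5, temperature=1.0):
    hard_loss = hard_loss / full_target.shape[0]
    soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
    soft_loss = soft_loss / full_target.shape[0]
    return weight_hard_loss * hard_loss + weight_soft_loss * soft_loss


student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
full_target = torch.zeros(4, 16, dtype=torch.long)
hard_loss = torch.tensor(2.0)  # pretend summed CE over this chunk
print(combine_losses(hard_loss, student_logits, teacher_logits, full_target))
```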
liger_kernel/chunked_loss/fused_linear_preference.py
CHANGED
@@ -2,11 +2,11 @@ from abc import abstractmethod
 from functools import partial

 import torch
+
 from torch.nn import functional as F


 class LigerFusedLinearPreferenceBase(torch.autograd.Function):
-
     @abstractmethod
     def preference_loss_fn(*args, **kwargs):
         """
@@ -102,9 +102,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
            Fused forward and backward pass for a chunk of input and target.
            """
            if bias is not None:
-                return torch.func.grad_and_value(
-                    compute_loss, argnums=(0, 1, 3), has_aux=True
-                )(
+                return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 3), has_aux=True)(
                    input_chunk,
                    weight,
                    target_chunk,
@@ -112,43 +110,47 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                    ref_input_chunk=ref_input_chunk,
                )
            else:
-                return torch.func.grad_and_value(
-                    compute_loss, argnums=(0, 1), has_aux=True
-                )
+                return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
+                    input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk
+                )

        def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
            if bias is not None:
-                (
-                    (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
+                (
+                    (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
                    (
-
-
-
-
-
-
+                        chunk_loss,
+                        (
+                            chunk_chosen_logps,
+                            chunk_rejected_logps,
+                            chunk_chosen_logits_mean,
+                            chunk_rejected_logits_mean,
+                            chunk_nll_loss,
+                            *aux_outputs,
+                        ),
                    ),
                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
                grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
            else:
-                (
-                    (chunk_grad_input, chunk_grad_weight),
+                (
+                    (chunk_grad_input, chunk_grad_weight),
                    (
-
-
-
-
-
-
+                        chunk_loss,
+                        (
+                            chunk_chosen_logps,
+                            chunk_rejected_logps,
+                            chunk_chosen_logits_mean,
+                            chunk_rejected_logits_mean,
+                            chunk_nll_loss,
+                            *aux_outputs,
+                        ),
                    ),
                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)

            # Accumulate gradients
            grad_weight.add_(chunk_grad_weight)
            grad_chosen_inputs.append(chunk_grad_input[: chosen_target_chunk.shape[0]])
-            grad_rejected_inputs.append(
-                chunk_grad_input[chosen_target_chunk.shape[0] :]
-            )
+            grad_rejected_inputs.append(chunk_grad_input[chosen_target_chunk.shape[0] :])

            # Accumulate loss
            loss_acc.add_(chunk_loss)
@@ -165,9 +167,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
            if len(aggregated_aux_outputs) == 0:
                for aux in aux_outputs:
                    if aux.ndim == 0:
-                        aggregated_aux_outputs.append(
-                            torch.zeros((), device=aux.device)
-                        )
+                        aggregated_aux_outputs.append(torch.zeros((), device=aux.device))
                    else:
                        aggregated_aux_outputs.append([])

@@ -189,12 +189,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
        _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)

        if use_ref_model:
-            _ref_chosen_input_chunks = torch.chunk(
-                ref_input[:len_chosen], chunks=chunks, dim=0
-            )
-            _ref_rejected_input_chunks = torch.chunk(
-                ref_input[len_chosen:], chunks=chunks, dim=0
-            )
+            _ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0)
+            _ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0)

        for (
            chosen_input_chunk,
@@ -208,26 +204,15 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
            _rejected_input_chunks,
            _chosen_target_chunks,
            _rejected_target_chunks,
-            (
-                _ref_chosen_input_chunks
-                if use_ref_model
-                else [None] * len(_chosen_input_chunks)
-            ),
-            (
-                _ref_rejected_input_chunks
-                if use_ref_model
-                else [None] * len(_rejected_input_chunks)
-            ),
+            (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
+            (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
+            strict=False,
        ):
            input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
            ref_input_chunk = (
-                torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0)
-                if use_ref_model
-                else None
-            )
-            target_chunk = torch.cat(
-                [chosen_target_chunk, rejected_target_chunk], dim=0
+                torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0) if use_ref_model else None
            )
+            target_chunk = torch.cat([chosen_target_chunk, rejected_target_chunk], dim=0)

            # mark input_chunk, target_chunk, and target dimension 1 as dynamic to prevent torch.compile recompilation
            torch._dynamo.mark_dynamic(input_chunk, 1)
@@ -265,9 +250,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
    @staticmethod
    def backward(ctx, *grad_output):
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
-        if torch.ne(
-            grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)
-        ):
+        if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)):
            grad_input = grad_input * grad_output[0][0]
            grad_weight = grad_weight * grad_output[0][0]
            grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
@@ -301,9 +284,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
        loss_mask = target_chunk != ignore_index
        label_chunk = torch.where(loss_mask, target_chunk, 0)

-        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
-            -1
-        )
+        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)

        chosen_logps = average_log_prob[:len_chosen_chunk]
@@ -370,13 +351,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
            ignore_index=ignore_index,
            compute_nll_loss=compute_nll_loss,
        )
-        chosen_nll_loss = (
-            chosen_nll_loss
-            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
-        )
-        chosen_logits_mean = chosen_logits.sum() / (
-            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
-        )
+        chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+        chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
        rejected_logits_mean = rejected_logits.sum() / (
            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
        )
liger_kernel/chunked_loss/orpo_loss.py
CHANGED
@@ -1,13 +1,10 @@
 import torch
 import torch.nn.functional as F

-from liger_kernel.chunked_loss.fused_linear_preference import (
-    LigerFusedLinearPreferenceBase,
-)
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase


 class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
-
     @staticmethod
     def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
         """
@@ -32,8 +29,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
             beta (float): Weight for the odds ratio loss.
         """
         log_odds = (chosen_logps - rejected_logps) - (
-            torch.log1p(-torch.exp(chosen_logps))
-            - torch.log1p(-torch.exp(rejected_logps))
+            torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
         )
         ratio = F.logsigmoid(log_odds)
         loss = -beta * ratio.sum() / (full_target.shape[0] // 2)
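The joined expression above is the ORPO log-odds ratio: `log(p/(1-p))` for the chosen sequence minus the same quantity for the rejected one, with `p = exp(logps)`. A standalone sketch of the same formula (it expects average log-probabilities strictly below zero; the toy values are made up):

```python
import torch
import torch.nn.functional as F


def orpo_odds_ratio_loss(chosen_logps, rejected_logps, full_target, beta=0.1):
    # log-odds of chosen vs. rejected: (log p_c - log p_r) - (log(1-p_c) - log(1-p_r))
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
    )
    ratio = F.logsigmoid(log_odds)
    return -beta * ratio.sum() / (full_target.shape[0] // 2)


chosen = torch.tensor([-0.9, -1.2])
rejected = torch.tensor([-1.5, -1.4])
full_target = torch.zeros(4, 10, dtype=torch.long)
print(orpo_odds_ratio_loss(chosen, rejected, full_target))
```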
liger_kernel/chunked_loss/simpo_loss.py
CHANGED
@@ -1,13 +1,10 @@
 import torch
 import torch.nn.functional as F

-from liger_kernel.chunked_loss.fused_linear_preference import (
-    LigerFusedLinearPreferenceBase,
-)
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase


 class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
-
     @staticmethod
     def preference_loss_fn(
         chosen_logps,
@@ -41,10 +38,9 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
             label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
         """
         logits = beta * (chosen_logps - rejected_logps) - gamma
-        loss = (
-            -F.logsigmoid(logits) * (1 - label_smoothing)
-            - F.logsigmoid(-logits) * label_smoothing
-        ).sum() / (full_target.shape[0] // 2)
+        loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+            full_target.shape[0] // 2
+        )

         return loss

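SimPO reuses the same label-smoothed sigmoid loss as the CPO sketch earlier; the only difference visible in this hunk is the reward margin `gamma` subtracted from the scaled log-prob difference. A tiny sketch of that margin (the `gamma=0.5` default here is an example value, not taken from the diff):

```python
import torch


def simpo_logits(chosen_logps, rejected_logps, beta=0.1, gamma=0.5):
    # Same margin as CPO, shifted by the reward margin gamma before the sigmoid loss.
    return beta * (chosen_logps - rejected_logps) - gamma


print(simpo_logits(torch.tensor([-1.0]), torch.tensor([-2.0])))  # tensor([-0.4000])
```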
liger_kernel/env_report.py
CHANGED
@@ -1,5 +1,6 @@
 import platform
 import sys
+
 from importlib.metadata import version


@@ -27,15 +28,9 @@ def print_env_report():
         import torch

         print(f"PyTorch version: {torch.__version__}")
-        cuda_version = (
-            torch.version.cuda if torch.cuda.is_available() else "Not available"
-        )
+        cuda_version = torch.version.cuda if torch.cuda.is_available() else "Not available"
         print(f"CUDA version: {cuda_version}")
-        hip_version = (
-            torch.version.hip
-            if torch.cuda.is_available() and torch.version.hip
-            else "Not available"
-        )
+        hip_version = torch.version.hip if torch.cuda.is_available() and torch.version.hip else "Not available"
         print(f"HIP(ROCm) version: {hip_version}")

     except ImportError:
@@ -58,9 +53,7 @@ def print_env_report():
         print("Transformers: Not installed")

     try:
-        xpu_version = (
-            torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
-        )
+        xpu_version = torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
         print(f"XPU version: {xpu_version}")
     except ImportError:
         print("XPU version: Unable to query")
liger_kernel/ops/cross_entropy.py
CHANGED
@@ -1,11 +1,14 @@
 import operator
+
 from typing import Optional

 import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -92,9 +95,7 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
-    ori_X_y = tl.load(X_ptr + y).cast(
-        tl.float32
-    )  # we need to store the original value of X_y for the loss calculation
+    ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)

@@ -232,14 +233,10 @@ def cross_entropy_forward(
     return_z_loss,
 ):
     if not isinstance(return_z_loss, int):
-        assert (
-            return_z_loss in _bool_to_return_z_loss
-        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+        assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
         return_z_loss = _bool_to_return_z_loss[return_z_loss]
     else:
-        assert (
-            return_z_loss in _bool_to_return_z_loss
-        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+        assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"

     BT, V = _input.shape
     n_rows = BT
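The kernel comments above (`m is the max value`, `d is the sum`) refer to the online-softmax recurrence: a single pass over the logits keeps a running maximum and a rescaled running sum of exponentials. A plain-Python sketch of that textbook recurrence, not the Triton kernel itself:

```python
import math
import torch


def online_softmax_denominator(x):
    # After one pass, d equals sum(exp(x_i - max(x))) and m equals max(x).
    m = float("-inf")  # running max
    d = 0.0            # running sum of exp(x_i - m)
    for v in x.tolist():
        m_new = max(m, v)
        d = d * math.exp(m - m_new) + math.exp(v - m_new)
        m = m_new
    return m, d


x = torch.randn(16)
m, d = online_softmax_denominator(x)
print(torch.allclose(torch.tensor(d), torch.exp(x - x.max()).sum(), atol=1e-5))
```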
liger_kernel/ops/experimental/embedding.py
CHANGED
@@ -34,9 +34,7 @@ def embedding_forward_kernel(
     )

     output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
-    tl.store(
-        output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :]
-    )
+    tl.store(output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :])


 @triton.jit
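The `output_offsets` expression above builds a 2D grid of flat indices into a row-major output buffer by broadcasting row offsets against column offsets. The same arithmetic in plain PyTorch, with tiny made-up block sizes:

```python
import torch

embedding_dim = 8
offsets_m = torch.arange(0, 4)  # rows handled by this program instance
offsets_n = torch.arange(0, 8)  # columns within the embedding dimension

# Flat offsets into a row-major (num_rows, embedding_dim) buffer.
output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
print(output_offsets.shape)   # torch.Size([4, 8])
print(output_offsets[1, 3])   # row 1, col 3 -> 1 * 8 + 3 = 11
```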
liger_kernel/ops/experimental/mm_int8int2.py
CHANGED
@@ -37,9 +37,7 @@ def pack_weights(intweights: torch.Tensor, bits: int = 2) -> torch.Tensor:
     else:
         packed_tensor_shape = (row_dim, *original_shape[1:])

-    packed = torch.zeros(
-        packed_tensor_shape, device=intweights.device, dtype=torch.uint8
-    )
+    packed = torch.zeros(packed_tensor_shape, device=intweights.device, dtype=torch.uint8)
     unpacked = intweights.to(torch.uint8)

     def lshift(t: torch.Tensor, bits: int):
@@ -327,17 +325,13 @@ def matmul_kernel(


 def matmul(a, b):
-    assert (
-        a.shape[1] == b.shape[0] * 4
-    ), "Incompatible dimensions, the weight matrix need to be packed"
+    assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix need to be packed"
     assert a.is_contiguous(), "Matrix A must be contiguous"
     M, K = a.shape
     _, N = b.shape
     # c is in int32 to avoid any overflows or underflows
     c = torch.empty((M, N), device=a.device, dtype=torch.int32)
-    grid = lambda META: (
-        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
-    )
+    grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
     matmul_kernel[grid](
         a,
         b,
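`matmul` asserts `a.shape[1] == b.shape[0] * 4` because the int2 weights are packed four values per uint8 along the K dimension, and the flattened launch grid covers all (M, N) tiles in one dimension. A small sanity sketch of both calculations (the matrix and block sizes are arbitrary example values):

```python
import math

M, K, N = 64, 256, 128       # activation (M, K) against an unpacked weight of shape (K, N)
packed_K = K // 4            # four 2-bit values stored per uint8 row of the packed weight
assert K == packed_K * 4     # mirrors: a.shape[1] == b.shape[0] * 4

BLOCK_SIZE_M, BLOCK_SIZE_N = 32, 32  # example tile sizes
grid_size = math.ceil(M / BLOCK_SIZE_M) * math.ceil(N / BLOCK_SIZE_N)
print(packed_K, grid_size)   # 64 8 -> one flat launch dimension, as in the grid lambda above
```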
liger_kernel/ops/fused_linear_cross_entropy.py
CHANGED
@@ -2,12 +2,10 @@ import torch
 import triton

 from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
-from liger_kernel.ops.utils import (
-    amp_custom_bwd,
-    amp_custom_fwd,
-    element_mul_kernel,
-    is_hip,
-)
+from liger_kernel.ops.utils import amp_custom_bwd
+from liger_kernel.ops.utils import amp_custom_fwd
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip

 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -40,14 +38,10 @@ def fused_linear_cross_entropy_forward(
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

     inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
-    chunk_size = triton.next_power_of_2(
-        triton.cdiv(BT, inc_factor)
-    )  # (BT + inc_factor - 1) // inc_factor
+    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
     num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size

-    grad_weight = (
-        torch.zeros_like(weight, device=device) if weight.requires_grad else None
-    )
+    grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
     grad_input = torch.zeros_like(_input, device=device)
     grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
     # we use fp32 for loss accumulator
@@ -137,9 +131,7 @@ def fused_linear_cross_entropy_forward(
     return loss, grad_input, grad_weight, grad_bias


-def fused_linear_cross_entropy_backward(
-    grad_output, grad_input, grad_weight, grad_bias
-):
+def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
     if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
         # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
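The chunking math above trades memory for kernel launches: the larger the vocabulary V relative to the hidden size H, the fewer of the BT rows are materialized as logits at once. A worked example of the same arithmetic with assumed, Llama-like shapes chosen only for illustration:

```python
import math


def next_power_of_2(n):
    return 1 << (n - 1).bit_length()


BT, H, V = 4096, 4096, 128256                              # rows (batch*seq), hidden size, vocab size
inc_factor = math.ceil(V / H)                              # 32 -> logits are ~32x larger than the input chunk
chunk_size = next_power_of_2(math.ceil(BT / inc_factor))   # 128
num_chunks = math.ceil(BT / chunk_size)                    # 32
print(inc_factor, chunk_size, num_chunks)
```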
liger_kernel/ops/fused_linear_jsd.py
CHANGED
@@ -4,12 +4,10 @@ import torch
 import triton

 from liger_kernel.ops.jsd import _jsd_kernel
-from liger_kernel.ops.utils import (
-    amp_custom_bwd,
-    amp_custom_fwd,
-    element_mul_kernel,
-    is_hip,
-)
+from liger_kernel.ops.utils import amp_custom_bwd
+from liger_kernel.ops.utils import amp_custom_fwd
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip

 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -43,16 +41,10 @@ def fused_linear_jsd_forward(
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

     inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
-    chunk_size = triton.next_power_of_2(
-        triton.cdiv(BT, inc_factor)
-    )  # (BT + inc_factor - 1) // inc_factor
+    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
     num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size

-    grad_weight = (
-        torch.zeros_like(student_weight, device=device)
-        if student_weight.requires_grad
-        else None
-    )
+    grad_weight = torch.zeros_like(student_weight, device=device) if student_weight.requires_grad else None
     grad_input = torch.zeros_like(student_input)
     # we use fp32 for loss accumulator
     loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)
@@ -73,12 +65,8 @@ def fused_linear_jsd_forward(
         # shape: chunk_size x V
         # For anything starting from logits to the final JSD loss, we do computation
         # in FP32 to avoid losing numerical stability.
-        student_logits_chunk = (student_input_chunk @ student_weight.t()).to(
-            torch.float32
-        )
-        teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(
-            torch.float32
-        )
+        student_logits_chunk = (student_input_chunk @ student_weight.t()).to(torch.float32)
+        teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(torch.float32)
         chunk_n_rows = student_logits_chunk.shape[0]

         # unreduced loss
@@ -104,9 +92,7 @@ def fused_linear_jsd_forward(
             dX_ptr=student_prob_chunk,
             dX_stride=student_prob_chunk.stride(-2),
             label_ptr=(
-                shift_labels[start_idx:end_idx]
-                if has_label
-                else torch.empty(1, device=device)
+                shift_labels[start_idx:end_idx] if has_label else torch.empty(1, device=device)
             ),  # dummy ptr if no label
             beta=jsd_beta,
             n_non_ignore=n_non_ignore,
@@ -121,9 +107,7 @@ def fused_linear_jsd_forward(
         student_logits_chunk = (
             student_prob_chunk
             - torch.softmax(student_logits_chunk, dim=-1)
-            * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(
-                student_prob_chunk.shape
-            )
+            * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(student_prob_chunk.shape)
         ) / temperature
         # now we traverse back to grad w.r.t. input to `lm_head` and grad
         # w.r.t. `lm_head` which should be computed in original dtype
@@ -239,7 +223,5 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
     @amp_custom_bwd
     def backward(ctx, grad_output):
         (grad_input, grad_weight) = ctx.saved_tensors
-        grad_input, grad_weight = fused_linear_jsd_backward(
-            grad_output, grad_input, grad_weight
-        )
+        grad_input, grad_weight = fused_linear_jsd_backward(grad_output, grad_input, grad_weight)
         return (grad_input, grad_weight, None, None, None, None, None, None)
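As the comments in the JSD forward note, each chunk's logits are computed in the model dtype and immediately upcast to FP32 before any softmax/log arithmetic. A minimal sketch of that per-chunk pattern (the shapes, dtype, and temperature value are illustrative, not taken from the kernel):

```python
import torch

student_input = torch.randn(8, 16, dtype=torch.bfloat16)    # one (chunk_size, H) input chunk
student_weight = torch.randn(32, 16, dtype=torch.bfloat16)  # (V, H) lm_head weight

# Matmul in the original dtype, then upcast so the subsequent softmax/log ops stay numerically stable.
student_logits_chunk = (student_input @ student_weight.t()).to(torch.float32)
student_log_probs = torch.log_softmax(student_logits_chunk / 1.0, dim=-1)  # temperature = 1.0
print(student_logits_chunk.dtype, student_log_probs.shape)
```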