liger-kernel-nightly 0.5.2.dev20241223032630__py3-none-any.whl → 0.5.2.dev20241228022953__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cpo_loss.py +5 -12
- liger_kernel/chunked_loss/dpo_loss.py +1 -4
- liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
- liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
- liger_kernel/chunked_loss/orpo_loss.py +2 -6
- liger_kernel/chunked_loss/simpo_loss.py +4 -8
- liger_kernel/env_report.py +4 -11
- liger_kernel/ops/cross_entropy.py +7 -10
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +12 -17
- liger_kernel/ops/fused_linear_jsd.py +11 -29
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +2 -6
- liger_kernel/ops/kl_div.py +4 -7
- liger_kernel/ops/layer_norm.py +3 -5
- liger_kernel/ops/qwen2vl_mrope.py +8 -25
- liger_kernel/ops/rms_norm.py +11 -29
- liger_kernel/ops/rope.py +8 -24
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +16 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +1 -3
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +2 -6
- liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +3 -9
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +23 -53
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +16 -30
- liger_kernel/transformers/monkey_patch.py +43 -117
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/METADATA +1 -1
- liger_kernel_nightly-0.5.2.dev20241228022953.dist-info/RECORD +66 -0
- liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD +0 -66
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241228022953.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/cpo_loss.py
CHANGED
@@ -1,17 +1,12 @@
 import torch
 import torch.nn.functional as F

-from liger_kernel.chunked_loss.fused_linear_preference import (
-    LigerFusedLinearPreferenceBase,
-)
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase


 class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
-
     @staticmethod
-    def preference_loss_fn(
-        chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0
-    ):
+    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0):
         """
         Paper: https://arxiv.org/pdf/2401.08417

@@ -35,11 +30,9 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
         label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
         """
         logits = beta * (chosen_logps - rejected_logps)
-        loss = (
-            -F.logsigmoid(logits) * (1 - label_smoothing)
-            - F.logsigmoid(-logits) * label_smoothing
-        ).sum() / (full_target.shape[0] // 2)
-
+        loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+            full_target.shape[0] // 2
+        )
         return loss

     @staticmethod
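The reflowed expression above (and the near-identical one in simpo_loss.py further down) is the label-smoothed sigmoid preference loss. A minimal standalone sketch of what it computes, using toy tensors and a hypothetical helper name rather than the package API:

```python
import torch
import torch.nn.functional as F

def sigmoid_preference_loss(chosen_logps, rejected_logps, beta=0.1, label_smoothing=0.0):
    # positive logits mean the chosen completion is more likely than the rejected one
    logits = beta * (chosen_logps - rejected_logps)
    # label smoothing mixes in the loss of the flipped preference pair
    per_pair = -F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing
    # divide by the number of pairs (the kernel uses full_target.shape[0] // 2 for the same count)
    return per_pair.sum() / chosen_logps.shape[0]

chosen = torch.tensor([-1.0, -2.0])
rejected = torch.tensor([-3.0, -2.5])
print(sigmoid_preference_loss(chosen, rejected, label_smoothing=0.1))
```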
liger_kernel/chunked_loss/dpo_loss.py
CHANGED
@@ -1,13 +1,10 @@
 import torch
 import torch.nn.functional as F

-from liger_kernel.chunked_loss.fused_linear_preference import (
-    LigerFusedLinearPreferenceBase,
-)
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase


 class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
-
     @staticmethod
     def preference_loss_fn(
         chosen_logps,
liger_kernel/chunked_loss/fused_linear_distillation.py
CHANGED
@@ -2,11 +2,11 @@ from abc import abstractmethod
 from functools import partial

 import torch
+
 from torch.nn import functional as F


 class LigerFusedLinearDistillationBase(torch.autograd.Function):
-
     @abstractmethod
     def distillation_loss_fn(student_logits, teacher_logits, temperature):
         """
@@ -89,25 +89,25 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         compute_ce_loss (bool): Whether to compute CE loss.
         loss_kwargs (dict): Additional arguments for the loss function.
         """
-
-
-
-
-
-
-
-
-
-
-
-
+        (
+            student_logits_chunk,
+            teacher_logits_chunk,
+            hard_loss,
+        ) = LigerFusedLinearDistillationBase.chunk_forward(
+            student_input_chunk,
+            student_weight,
+            teacher_input_chunk,
+            teacher_weight,
+            target_chunk,
+            student_bias=student_bias,
+            teacher_bias=teacher_bias,
+            ignore_index=ignore_index,
+            compute_ce_loss=compute_ce_loss,
         )

         hard_loss /= full_target.shape[0]

-        soft_loss = distillation_loss_fn(
-            student_logits_chunk, teacher_logits_chunk, temperature
-        )
+        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
         soft_loss /= full_target.shape[0]

         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
@@ -174,17 +174,18 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):

         def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
             if student_bias is not None:
-                (
-
+                (
+                    (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
                     (
-
-
-
-
+                        chunk_loss,
+                        (
+                            chunk_soft_loss,
+                            chunk_hard_loss,
+                            chunk_student_logits,
+                            chunk_teacher_logits,
+                        ),
                     ),
-                ) = torch.func.grad_and_value(
-                    loss_func_to_call, argnums=(0, 1, 5), has_aux=True
-                )(
+                ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1, 5), has_aux=True)(
                     student_input_chunk,
                     student_weight,
                     teacher_input_chunk,
@@ -195,17 +196,18 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
                 )
                 grad_bias.add_(chunk_grad_bias)
             else:
-                (
-
+                (
+                    (chunk_grad_input, chunk_grad_weight),
                     (
-
-
-
-
+                        chunk_loss,
+                        (
+                            chunk_soft_loss,
+                            chunk_hard_loss,
+                            chunk_student_logits,
+                            chunk_teacher_logits,
+                        ),
                     ),
-                ) = torch.func.grad_and_value(
-                    loss_func_to_call, argnums=(0, 1), has_aux=True
-                )(
+                ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1), has_aux=True)(
                     student_input_chunk,
                     student_weight,
                     teacher_input_chunk,
@@ -229,9 +231,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         for student_input_chunk, teacher_input_chunk, target_chunk in zip(
             _student_input_chunks, _teacher_input_chunks, _target_chunks
         ):
-            grad_input = accumulate_chunk(
-                student_input_chunk, teacher_input_chunk, target_chunk
-            )
+            grad_input = accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk)
             grad_inputs.append(grad_input)

         ctx.save_for_backward(
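Several of the call sites reflowed above use `torch.func.grad_and_value` with `argnums` and `has_aux=True`, which is what produces the nested tuples being destructured. A minimal sketch of that return structure, using a toy loss function rather than the library's `loss_func_to_call`:

```python
import torch

def toy_loss(x, w):
    loss = ((x @ w) ** 2).sum()
    aux = x.detach().mean()  # auxiliary value carried alongside the scalar loss
    return loss, aux         # has_aux=True expects (loss, aux)

x, w = torch.randn(4, 3), torch.randn(3)
# grads is a tuple aligned with argnums; the value part is (loss, aux)
(grad_x, grad_w), (loss, aux) = torch.func.grad_and_value(toy_loss, argnums=(0, 1), has_aux=True)(x, w)
print(grad_x.shape, grad_w.shape, loss.item(), aux.item())
```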
liger_kernel/chunked_loss/fused_linear_preference.py
CHANGED
@@ -2,11 +2,11 @@ from abc import abstractmethod
 from functools import partial

 import torch
+
 from torch.nn import functional as F


 class LigerFusedLinearPreferenceBase(torch.autograd.Function):
-
     @abstractmethod
     def preference_loss_fn(*args, **kwargs):
         """
@@ -102,9 +102,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         Fused forward and backward pass for a chunk of input and target.
         """
         if bias is not None:
-            return torch.func.grad_and_value(
-                compute_loss, argnums=(0, 1, 3), has_aux=True
-            )(
+            return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 3), has_aux=True)(
                 input_chunk,
                 weight,
                 target_chunk,
@@ -112,43 +110,47 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                 ref_input_chunk=ref_input_chunk,
             )
         else:
-            return torch.func.grad_and_value(
-
-            )
+            return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
+                input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk
+            )

         def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
             if bias is not None:
-                (
-
+                (
+                    (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
                     (
-
-
-
-
-
-
+                        chunk_loss,
+                        (
+                            chunk_chosen_logps,
+                            chunk_rejected_logps,
+                            chunk_chosen_logits_mean,
+                            chunk_rejected_logits_mean,
+                            chunk_nll_loss,
+                            *aux_outputs,
+                        ),
                     ),
                 ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
                 grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
             else:
-                (
-
+                (
+                    (chunk_grad_input, chunk_grad_weight),
                     (
-
-
-
-
-
-
+                        chunk_loss,
+                        (
+                            chunk_chosen_logps,
+                            chunk_rejected_logps,
+                            chunk_chosen_logits_mean,
+                            chunk_rejected_logits_mean,
+                            chunk_nll_loss,
+                            *aux_outputs,
+                        ),
                     ),
                 ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)

             # Accumulate gradients
             grad_weight.add_(chunk_grad_weight)
             grad_chosen_inputs.append(chunk_grad_input[: chosen_target_chunk.shape[0]])
-            grad_rejected_inputs.append(
-                chunk_grad_input[chosen_target_chunk.shape[0] :]
-            )
+            grad_rejected_inputs.append(chunk_grad_input[chosen_target_chunk.shape[0] :])

             # Accumulate loss
             loss_acc.add_(chunk_loss)
@@ -165,9 +167,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         if len(aggregated_aux_outputs) == 0:
             for aux in aux_outputs:
                 if aux.ndim == 0:
-                    aggregated_aux_outputs.append(
-                        torch.zeros((), device=aux.device)
-                    )
+                    aggregated_aux_outputs.append(torch.zeros((), device=aux.device))
                 else:
                     aggregated_aux_outputs.append([])

@@ -189,12 +189,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)

         if use_ref_model:
-            _ref_chosen_input_chunks = torch.chunk(
-                ref_input[:len_chosen], chunks=chunks, dim=0
-            )
-            _ref_rejected_input_chunks = torch.chunk(
-                ref_input[len_chosen:], chunks=chunks, dim=0
-            )
+            _ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0)
+            _ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0)

         for (
             chosen_input_chunk,
@@ -208,26 +204,15 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             _rejected_input_chunks,
             _chosen_target_chunks,
             _rejected_target_chunks,
-            (
-                _ref_chosen_input_chunks
-                if use_ref_model
-                else [None] * len(_chosen_input_chunks)
-            ),
-            (
-                _ref_rejected_input_chunks
-                if use_ref_model
-                else [None] * len(_rejected_input_chunks)
-            ),
+            (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
+            (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
+            strict=False,
         ):
             input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
             ref_input_chunk = (
-                torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0)
-                if use_ref_model
-                else None
-            )
-            target_chunk = torch.cat(
-                [chosen_target_chunk, rejected_target_chunk], dim=0
+                torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0) if use_ref_model else None
             )
+            target_chunk = torch.cat([chosen_target_chunk, rejected_target_chunk], dim=0)

             # mark input_chunk, target_chunk, and target dimension 1 as dynamic to prevent torch.compile recompilation
             torch._dynamo.mark_dynamic(input_chunk, 1)
@@ -265,9 +250,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *grad_output):
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
-        if torch.ne(
-            grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)
-        ):
+        if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)):
             grad_input = grad_input * grad_output[0][0]
             grad_weight = grad_weight * grad_output[0][0]
             grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
@@ -301,9 +284,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         loss_mask = target_chunk != ignore_index
         label_chunk = torch.where(loss_mask, target_chunk, 0)

-        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
-            -1
-        )
+        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
         average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)

         chosen_logps = average_log_prob[:len_chosen_chunk]
@@ -370,13 +351,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             ignore_index=ignore_index,
             compute_nll_loss=compute_nll_loss,
         )
-        chosen_nll_loss = (
-            chosen_nll_loss
-            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
-        )
-        chosen_logits_mean = chosen_logits.sum() / (
-            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
-        )
+        chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+        chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
         rejected_logits_mean = rejected_logits.sum() / (
             full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
         )
liger_kernel/chunked_loss/orpo_loss.py
CHANGED
@@ -1,13 +1,10 @@
 import torch
 import torch.nn.functional as F

-from liger_kernel.chunked_loss.fused_linear_preference import (
-    LigerFusedLinearPreferenceBase,
-)
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase


 class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
-
     @staticmethod
     def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
         """
@@ -32,8 +29,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
         beta (float): Weight for the odds ratio loss.
         """
         log_odds = (chosen_logps - rejected_logps) - (
-            torch.log1p(-torch.exp(chosen_logps))
-            - torch.log1p(-torch.exp(rejected_logps))
+            torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
         )
         ratio = F.logsigmoid(log_odds)
         loss = -beta * ratio.sum() / (full_target.shape[0] // 2)
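The reflowed expression is the log odds-ratio used by ORPO: `torch.log1p(-torch.exp(logp))` is a numerically stable `log(1 - p)` given a log-probability. A small check with toy values (illustrative only, not the library API):

```python
import torch
import torch.nn.functional as F

chosen_logps = torch.tensor([-0.5])    # log p_chosen
rejected_logps = torch.tensor([-2.0])  # log p_rejected

log_odds = (chosen_logps - rejected_logps) - (
    torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)

# same quantity computed directly from the probabilities
p_c, p_r = chosen_logps.exp(), rejected_logps.exp()
direct = torch.log((p_c / (1 - p_c)) / (p_r / (1 - p_r)))
print(torch.allclose(log_odds, direct))   # True
print((-F.logsigmoid(log_odds)).item())   # per-pair odds-ratio penalty before the beta scaling
```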
liger_kernel/chunked_loss/simpo_loss.py
CHANGED
@@ -1,13 +1,10 @@
 import torch
 import torch.nn.functional as F

-from liger_kernel.chunked_loss.fused_linear_preference import (
-    LigerFusedLinearPreferenceBase,
-)
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase


 class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
-
     @staticmethod
     def preference_loss_fn(
         chosen_logps,
@@ -41,10 +38,9 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
         label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
         """
         logits = beta * (chosen_logps - rejected_logps) - gamma
-        loss = (
-            -F.logsigmoid(logits) * (1 - label_smoothing)
-            - F.logsigmoid(-logits) * label_smoothing
-        ).sum() / (full_target.shape[0] // 2)
+        loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+            full_target.shape[0] // 2
+        )

         return loss

liger_kernel/env_report.py
CHANGED
@@ -1,5 +1,6 @@
 import platform
 import sys
+
 from importlib.metadata import version


@@ -27,15 +28,9 @@ def print_env_report():
         import torch

         print(f"PyTorch version: {torch.__version__}")
-        cuda_version = (
-            torch.version.cuda if torch.cuda.is_available() else "Not available"
-        )
+        cuda_version = torch.version.cuda if torch.cuda.is_available() else "Not available"
         print(f"CUDA version: {cuda_version}")
-        hip_version = (
-            torch.version.hip
-            if torch.cuda.is_available() and torch.version.hip
-            else "Not available"
-        )
+        hip_version = torch.version.hip if torch.cuda.is_available() and torch.version.hip else "Not available"
         print(f"HIP(ROCm) version: {hip_version}")

     except ImportError:
@@ -58,9 +53,7 @@ def print_env_report():
         print("Transformers: Not installed")

     try:
-        xpu_version = (
-            torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
-        )
+        xpu_version = torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
         print(f"XPU version: {xpu_version}")
     except ImportError:
         print("XPU version: Unable to query")
liger_kernel/ops/cross_entropy.py
CHANGED
@@ -1,11 +1,14 @@
 import operator
+
 from typing import Optional

 import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -92,9 +95,7 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
-    ori_X_y = tl.load(X_ptr + y).cast(
-        tl.float32
-    )  # we need to store the original value of X_y for the loss calculation
+    ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)

@@ -232,14 +233,10 @@ def cross_entropy_forward(
     return_z_loss,
 ):
     if not isinstance(return_z_loss, int):
-        assert (
-            return_z_loss in _bool_to_return_z_loss
-        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+        assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
         return_z_loss = _bool_to_return_z_loss[return_z_loss]
     else:
-        assert (
-            return_z_loss in _bool_to_return_z_loss
-        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+        assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"

     BT, V = _input.shape
     n_rows = BT
liger_kernel/ops/experimental/embedding.py
CHANGED
@@ -34,9 +34,7 @@ def embedding_forward_kernel(
     )

     output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
-    tl.store(
-        output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :]
-    )
+    tl.store(output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :])


 @triton.jit
liger_kernel/ops/experimental/mm_int8int2.py
CHANGED
@@ -37,9 +37,7 @@ def pack_weights(intweights: torch.Tensor, bits: int = 2) -> torch.Tensor:
     else:
         packed_tensor_shape = (row_dim, *original_shape[1:])

-    packed = torch.zeros(
-        packed_tensor_shape, device=intweights.device, dtype=torch.uint8
-    )
+    packed = torch.zeros(packed_tensor_shape, device=intweights.device, dtype=torch.uint8)
     unpacked = intweights.to(torch.uint8)

     def lshift(t: torch.Tensor, bits: int):
@@ -327,17 +325,13 @@ def matmul_kernel(


 def matmul(a, b):
-    assert (
-        a.shape[1] == b.shape[0] * 4
-    ), "Incompatible dimensions, the weight matrix need to be packed"
+    assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix need to be packed"
     assert a.is_contiguous(), "Matrix A must be contiguous"
     M, K = a.shape
     _, N = b.shape
     # c is in int32 to avoid any overflows or underflows
     c = torch.empty((M, N), device=a.device, dtype=torch.int32)
-    grid = lambda META: (
-        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
-    )
+    grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
     matmul_kernel[grid](
         a,
         b,
liger_kernel/ops/fused_linear_cross_entropy.py
CHANGED
@@ -2,12 +2,10 @@ import torch
 import triton

 from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
-from liger_kernel.ops.utils import (
-    amp_custom_bwd,
-    amp_custom_fwd,
-    element_mul_kernel,
-    is_hip,
-)
+from liger_kernel.ops.utils import amp_custom_bwd
+from liger_kernel.ops.utils import amp_custom_fwd
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip

 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -40,14 +38,10 @@ def fused_linear_cross_entropy_forward(
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

     inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
-    chunk_size = triton.next_power_of_2(
-        triton.cdiv(BT, inc_factor)
-    )  # (BT + inc_factor - 1) // inc_factor
+    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
     num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size

-    grad_weight = (
-        torch.zeros_like(weight, device=device) if weight.requires_grad else None
-    )
+    grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
     grad_input = torch.zeros_like(_input, device=device)
     grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
     # we use fp32 for loss accumulator
@@ -133,15 +127,16 @@ def fused_linear_cross_entropy_forward(
             alpha=alpha,
         )

-    loss = torch.sum(loss_1d)
+    if reduction == "none":
+        loss = loss_1d
+    else:
+        loss = torch.sum(loss_1d)
     return loss, grad_input, grad_weight, grad_bias


-def fused_linear_cross_entropy_backward(
-    grad_output, grad_input, grad_weight, grad_bias
-):
+def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
-    if torch.
+    if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
         # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
         # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
         BT, H = grad_input.shape