liger-kernel 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +0 -0
- liger_kernel/chunked_loss/dpo_loss.py +57 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +206 -0
- liger_kernel/chunked_loss/orpo_loss.py +63 -0
- liger_kernel/env_report.py +2 -0
- liger_kernel/ops/cross_entropy.py +143 -30
- liger_kernel/ops/fused_linear_cross_entropy.py +20 -2
- liger_kernel/ops/group_norm.py +322 -0
- liger_kernel/ops/rms_norm.py +27 -6
- liger_kernel/transformers/cross_entropy.py +44 -12
- liger_kernel/transformers/functional.py +34 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +31 -4
- liger_kernel/transformers/group_norm.py +56 -0
- liger_kernel/transformers/model/gemma2.py +277 -0
- liger_kernel/transformers/model/qwen2_vl.py +43 -17
- liger_kernel/transformers/monkey_patch.py +106 -64
- liger_kernel/transformers/rms_norm.py +11 -3
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.2.dist-info}/METADATA +18 -82
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.2.dist-info}/RECORD +23 -16
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.2.dist-info}/WHEEL +1 -1
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.2.dist-info}/LICENSE +0 -0
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.2.dist-info}/NOTICE +0 -0
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.2.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/__init__.py
File without changes
liger_kernel/chunked_loss/dpo_loss.py
ADDED
@@ -0,0 +1,57 @@
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_preference import (
+    LigerFusedLinearPreferenceBase,
+)
+
+
+class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
+
+    @staticmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+        """
+        Compute DPO loss (Direct Preference Optimization).
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            beta (float): Weight for the direct preference loss.
+        """
+        logits_diff = beta * (chosen_logps - rejected_logps)
+        losses = -F.logsigmoid(logits_diff)
+        return losses.sum()
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        ignore_index=-100,
+        beta=0.1,
+        compute_nll_loss=True,
+        compiled=True,
+    ):
+        """
+        Fused linear layer with DPO (Direct Preference Optimization) loss.
+        Handles both the forward and backward pass of the final linear layer with DPO loss.
+        """
+        return LigerFusedLinearPreferenceBase.forward(
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
+            compute_nll_loss=compute_nll_loss,
+            ignore_index=ignore_index,
+            beta=beta,
+            compiled=compiled,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # Get gradients for _input, weight, bias, and target from the base class
+        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+        # Return these gradients, followed by None for the remaining inputs
+        return *grads, None, None, None, None
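Note on the formula: preference_loss_fn above is the reference-free DPO term, -log sigmoid(beta * (chosen_logps - rejected_logps)), summed over the batch; the base class then normalizes the result by the number of chosen/rejected pairs. A minimal plain-PyTorch sanity check of that term, with made-up values (nothing below is part of the package):

import torch
import torch.nn.functional as F

# assumed per-sequence average log-probabilities for 3 chosen/rejected pairs
chosen_logps = torch.tensor([-0.9, -1.2, -0.7])
rejected_logps = torch.tensor([-1.5, -1.1, -2.0])
beta = 0.1

# same arithmetic as LigerFusedLinearDPOFunction.preference_loss_fn
dpo_term = -F.logsigmoid(beta * (chosen_logps - rejected_logps)).sum()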
liger_kernel/chunked_loss/fused_linear_preference.py
ADDED
@@ -0,0 +1,206 @@
+from abc import abstractmethod
+from functools import partial
+
+import torch
+from torch.nn import functional as F
+
+
+class LigerFusedLinearPreferenceBase(torch.autograd.Function):
+
+    @abstractmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+        """
+        Compute preference loss.
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            beta (float): Weight for the odds ratio loss.
+        """
+        raise NotImplementedError("Preference loss function must be implemented.")
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        loss_fn=None,
+        chunk_size=1,
+        compute_nll_loss=True,
+        ignore_index=-100,
+        beta=0.1,
+        compiled=True,
+    ):
+        """
+        Base class for fused linear layer with preference loss.
+        Expects _input to be stacked with chosen and rejected inputs on the batch dimension.
+
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
+            compute_nll_loss (bool): Whether to compute NLL loss.
+            ignore_index (int): Index to ignore for loss computation.
+            beta (float): Weight for the odds ratio loss.
+            compiled (bool): Whether to use torch compile for chunk accumulation.
+        """
+        # TODO: Tune CHUNK_SIZE to fully utilize the GPU
+        CHUNK_SIZE = chunk_size
+
+        grad_weight = torch.zeros_like(weight)
+        grad_chosen_inputs = []
+        grad_rejected_inputs = []
+        grad_bias = torch.zeros_like(bias) if bias is not None else None
+        loss_acc = torch.zeros((), device=_input.device)
+
+        chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
+        loss_func_to_call = partial(
+            LigerFusedLinearPreferenceBase._compute_loss,
+            preference_loss_fn=loss_fn,
+            ignore_index=ignore_index,
+            beta=beta,
+            compute_nll_loss=compute_nll_loss,
+            full_target=target,
+        )
+
+        def accumulate_chunk(input_chunk, target_chunk):
+            if bias is not None:
+                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
+                    chunk_loss,
+                    (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
+                ) = torch.func.grad_and_value(
+                    loss_func_to_call, argnums=(0, 1, 3), has_aux=True
+                )(
+                    input_chunk, weight, target_chunk, bias
+                )
+                grad_bias.add_(chunk_grad_bias)
+            else:
+                (chunk_grad_input, chunk_grad_weight), (
+                    chunk_loss,
+                    (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
+                ) = torch.func.grad_and_value(
+                    loss_func_to_call, argnums=(0, 1), has_aux=True
+                )(
+                    input_chunk, weight, target_chunk
+                )
+            grad_weight.add_(chunk_grad_weight)
+            loss_acc.add_(chunk_loss)
+            return chunk_grad_input
+
+        len_chosen = target.shape[0] // 2
+        _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
+        _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
+        _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
+        _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
+
+        for (
+            chosen_input_chunk,
+            rejected_input_chunk,
+            chosen_target_chunk,
+            rejected_target_chunk,
+        ) in zip(
+            _chosen_input_chunks,
+            _rejected_input_chunks,
+            _chosen_target_chunks,
+            _rejected_target_chunks,
+        ):
+            input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
+            target_chunk = torch.cat(
+                [chosen_target_chunk, rejected_target_chunk], dim=0
+            )
+
+            if compiled:
+                accumulate_chunk = torch.compile(accumulate_chunk)
+            grad_input = accumulate_chunk(input_chunk, target_chunk)
+
+            grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
+            grad_rejected_inputs.append(grad_input[chosen_target_chunk.shape[0] :])
+
+        # combine grad_chosen_inputs and grad_rejected_inputs
+        grad_inputs = grad_chosen_inputs + grad_rejected_inputs
+
+        ctx.save_for_backward(
+            torch.cat(grad_inputs, dim=0),
+            grad_weight,
+            grad_bias,
+        )
+        return loss_acc
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_input, grad_weight, grad_bias = ctx.saved_tensors
+        if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
+            grad_input = grad_input * grad_output
+            grad_weight = grad_weight * grad_output
+            grad_bias = grad_bias * grad_output if grad_bias is not None else None
+
+        return grad_input, grad_weight, None, grad_bias, None, None, None
+
+    @staticmethod
+    def _compute_loss(
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        preference_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        beta=0.1,
+        compute_nll_loss=True,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
+        Args:
+            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+            ignore_index (int): Index to ignore for loss computation.
+            beta (float): Weight for the odds ratio loss.
+            loss_kwargs (dict): Additional arguments for the loss function.
+        """
+        len_chosen_chunk = target_chunk.shape[0] // 2
+
+        logits_chunk = input_chunk @ weight.t()  # chunk_size x V
+        if bias is not None:
+            logits_chunk = logits_chunk + bias
+        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+
+        chosen_nll_loss = 0.0
+        if compute_nll_loss:
+            chosen_nll_loss = F.nll_loss(
+                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
+                target_chunk[:len_chosen_chunk].view(-1),
+                reduction="sum",
+                ignore_index=ignore_index,
+            )
+            chosen_nll_loss = (
+                chosen_nll_loss
+                / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+            )
+
+        loss_mask = target_chunk != ignore_index
+        label_chunk = torch.where(loss_mask, target_chunk, 0)
+
+        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
+            -1
+        )
+        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+
+        chosen_logps = average_log_prob[:len_chosen_chunk]
+        rejected_logps = average_log_prob[len_chosen_chunk:]
+
+        alignment_loss = preference_loss_fn(
+            chosen_logps, rejected_logps, beta=beta, **loss_kwargs
+        )
+        alignment_loss = alignment_loss / (full_target.shape[0] // 2)
+
+        loss = chosen_nll_loss - alignment_loss
+        return loss, (alignment_loss, chosen_logps, rejected_logps)
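The key idea in LigerFusedLinearPreferenceBase.forward is that the full (batch, seq_len, vocab) logits tensor is never materialized: each chosen/rejected chunk is pushed through torch.func.grad_and_value, so per-chunk input and weight gradients are produced during the forward pass and accumulated into pre-allocated buffers. A stripped-down sketch of that accumulation pattern with a plain cross-entropy chunk loss (the toy sizes and names are assumptions for illustration, not the library's API):

import torch
import torch.nn.functional as F

def chunk_loss(input_chunk, weight, target_chunk):
    # stand-in for _compute_loss: project a chunk to logits and reduce to a scalar
    logits = input_chunk @ weight.t()
    return F.cross_entropy(logits.view(-1, logits.shape[-1]), target_chunk.view(-1))

B, T, H, V = 4, 8, 16, 32          # assumed toy sizes
x = torch.randn(B, T, H)
w = torch.randn(V, H)
y = torch.randint(0, V, (B, T))

grad_w = torch.zeros_like(w)       # accumulated like grad_weight above
loss_acc = torch.zeros(())
grad_x_chunks = []

for x_chunk, y_chunk in zip(torch.chunk(x, 2, dim=0), torch.chunk(y, 2, dim=0)):
    # argnums=(0, 1) -> gradients w.r.t. input_chunk and weight, plus the loss value
    (g_x, g_w), loss = torch.func.grad_and_value(chunk_loss, argnums=(0, 1))(x_chunk, w, y_chunk)
    grad_x_chunks.append(g_x)
    grad_w.add_(g_w)
    loss_acc.add_(loss)

grad_x = torch.cat(grad_x_chunks, dim=0)   # analogous to the saved grad_inputs

In the class above, the same pattern additionally threads a bias gradient and returns the preference-loss auxiliaries (chunk_chosen_logps, chunk_rejected_logps) via has_aux=True.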
liger_kernel/chunked_loss/orpo_loss.py
ADDED
@@ -0,0 +1,63 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_preference import (
+    LigerFusedLinearPreferenceBase,
+)
+
+
+class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
+
+    @staticmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+        """
+        Compute odds-ratio loss.
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            beta (float): Weight for the odds ratio loss.
+        """
+        log_odds = (chosen_logps - rejected_logps) - (
+            torch.log1p(-torch.exp(chosen_logps))
+            - torch.log1p(-torch.exp(rejected_logps))
+        )
+        ratio = F.logsigmoid(log_odds)
+        return beta * ratio.sum()
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        ignore_index=-100,
+        beta=0.1,
+        compute_nll_loss=True,
+        compiled=True,
+    ):
+        """
+        Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
+        Handles both the forward and backward pass of the final linear layer with ORPO loss.
+        Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
+        """
+
+        return LigerFusedLinearPreferenceBase.forward(
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn,
+            compute_nll_loss=compute_nll_loss,
+            ignore_index=ignore_index,
+            beta=beta,
+            compiled=compiled,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # Get gradients for _input, weight, bias, and target from the base class
+        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+        # Return these gradients, followed by None for the remaining inputs
+        return *grads, None, None, None, None
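The log-odds expression in preference_loss_fn is the ORPO odds-ratio term: with odds(p) = p / (1 - p), it equals log(odds(p_chosen) / odds(p_rejected)), written with log1p for numerical stability, and the returned value is beta * sum(log sigmoid(log_odds)). A quick equivalence check in plain PyTorch (the values below are made up):

import torch
import torch.nn.functional as F

chosen_logps = torch.tensor([-0.5, -1.0])    # assumed average log-probs (must be < 0)
rejected_logps = torch.tensor([-1.5, -2.5])
beta = 0.1

# direct form: log( odds(p_c) / odds(p_r) ) with odds(p) = p / (1 - p)
p_c, p_r = chosen_logps.exp(), rejected_logps.exp()
log_odds_direct = torch.log((p_c / (1 - p_c)) / (p_r / (1 - p_r)))

# log1p form used in the file above
log_odds = (chosen_logps - rejected_logps) - (
    torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)
assert torch.allclose(log_odds, log_odds_direct)

orpo_term = beta * F.logsigmoid(log_odds).sum()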
liger_kernel/env_report.py
CHANGED
@@ -4,11 +4,13 @@ import sys
 
 def print_env_report():
     """
+
     Prints a report of the environment. Useful for debugging and reproducibility.
     Usage:
     ```
    python -m liger_kernel.env_report
    ```
+
    """
    print("Environment Report:")
    print("-------------------")
liger_kernel/ops/cross_entropy.py
CHANGED
@@ -1,8 +1,24 @@
+import operator
+from typing import Optional
+
 import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import element_mul_kernel, is_hip
+from liger_kernel.ops.utils import compare_version, element_mul_kernel, is_hip
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import tanh
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import tanh
+else:
+    from triton.language.math import tanh
+
+_TRUE = tl.constexpr(1)
+_FALSE = tl.constexpr(0)
 
 
 @triton.jit
@@ -12,13 +28,18 @@ def liger_cross_entropy_kernel(
     Y_ptr,
     Y_stride,
     loss_ptr,
+    z_loss_ptr,
     loss_stride,
     n_cols,
     n_non_ignore,
     ignore_index,
+    lse_square_scale: tl.constexpr,
     label_smoothing: tl.constexpr,
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
+    softcap,
+    RETURN_Z_LOSS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    HAS_SOFTCAPPING: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
@@ -30,13 +51,18 @@ def liger_cross_entropy_kernel(
     Y_ptr: Pointer to target tensor.
     Y_stride (int): The stride of the target tensor.
     loss_ptr: Pointer to tensor to store the loss.
+    z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
     n_cols (int): The number of columns in the input tensor.
     n_non_ignore (int): The number of non-ignored elements in the batch.
     ignore_index (int): The index to ignore in the target.
     label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+    lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
+    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
     reduction (str): The string for the reduction to apply
+    softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
     BLOCK_SIZE (int): The block size for Triton operations.
+    HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
     """
 
     # https://github.com/triton-lang/triton/issues/1058
@@ -58,6 +84,7 @@ def liger_cross_entropy_kernel(
         return
 
     loss_ptr += program_id * loss_stride
+    z_loss_ptr += program_id * loss_stride
 
     # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
     # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
@@ -68,6 +95,8 @@ def liger_cross_entropy_kernel(
     ori_X_y = tl.load(
         X_ptr + y
     )  # we need to store the original value of X_y for the loss calculation
+    if HAS_SOFTCAPPING:
+        ori_X_y = softcap * tanh(ori_X_y / softcap)
 
     # Label smoothing is a general case of normal cross entropy
     # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
@@ -79,6 +108,8 @@ def liger_cross_entropy_kernel(
         X_block = tl.load(
             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
         )
+        if HAS_SOFTCAPPING:
+            X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
@@ -87,32 +118,49 @@ def liger_cross_entropy_kernel(
         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
         m = m_new
 
+    # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X)))))
+    #                    = log (e^(max(X)) * sum(e ^ (X_i - max(X))))
+    #                    = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
+    lse = m + tl.log(d)
+
     # 4. [Online Softmax] Second pass: compute gradients
     # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
     # dx_y = (softmax(x_y) - 1) / N
     # dx_i = softmax(x_i) / N, i != y
     # For label smoothing:
-    # dx_i = (softmax(
+    # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
     # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
     #      = dx_i - (1 - label_smoothing) / N
-    #
+    # With Z loss:
+    # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y
+    # dx_y = dx_i - (1 - label_smoothing) / N
     # For 'sum' reduction, no normalization is applied:
     # dx_y = softmax(x_y) - 1
     # dx_i = softmax(x_i), for i ≠ y
-    # For label smoothing:
-    # dx_i = (softmax(x_y) - label_smoothing / V), V = n_cols, i != y
-    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing))
-    #      = dx_i - (1 - label_smoothing)
 
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
         )
+        if HAS_SOFTCAPPING:
+            intermediate = tanh(X_block / softcap)
+            X_block = softcap * intermediate
+        # softmax(x_i)
+        X_block = tl.exp(X_block - m) / d
+        # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+        X_block += 2 * lse_square_scale * lse * X_block
+        # smoothing term
+        X_block += -eps
+        # special handle dx_y
+        X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+        # reduction scale
         if reduction == "mean":
-            X_block =
-
-
+            X_block = X_block / (n_non_ignore)
+        # chain rule
+        # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+        if HAS_SOFTCAPPING:
+            X_block = X_block * (1 - intermediate * intermediate)
 
         tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
 
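The new soft-capping path replaces each logit x with softcap * tanh(x / softcap) before the softmax, and the "chain rule" step in the hunk above multiplies the gradient by 1 - tanh^2(x / softcap), which is exactly the derivative of that capping function. A small autograd check of that derivative in plain PyTorch (the values are arbitrary):

import torch

softcap = 30.0
x = (torch.randn(8) * 50).requires_grad_()   # assumed raw logits, some beyond +/- softcap

capped = softcap * torch.tanh(x / softcap)   # forward soft-capping, as in the kernel
capped.sum().backward()

# analytic derivative used in the kernel's chain-rule step
manual = 1 - torch.tanh(x / softcap) ** 2
assert torch.allclose(x.grad, manual, atol=1e-6)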
@@ -124,35 +172,35 @@ def liger_cross_entropy_kernel(
 
     # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
     #      = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
+    #      = X_y - m - log d = X_y - lse
     # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
     # So we can safely calculate log (softmax(X_y)) without overflow
-    loss =
+    loss = lse - ori_X_y
 
     # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
     # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
     #          = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
     # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
-    #          = (1 - label_smoothing) * H(q, p) + (
+    #          = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
     # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
     # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
     # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
     if label_smoothing > 0:
-        smooth_loss = scaled_x_sum + label_smoothing *
+        smooth_loss = scaled_x_sum + label_smoothing * lse
         loss = loss * (1 - label_smoothing) + smooth_loss
 
+    # An auxiliary loss, z_loss
+    # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
+    z_loss = lse_square_scale * lse * lse
+    loss += z_loss
     # Normalize the loss by the number of non-ignored elements if reduction is "mean"
     if reduction == "mean":
+        z_loss = z_loss / n_non_ignore
         loss = loss / n_non_ignore
 
-    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
-    X_y = tl.load(X_ptr + y)
-    if reduction == "mean":
-        X_y += -(1 - label_smoothing) / (n_non_ignore)
-    else:
-        X_y += -(1 - label_smoothing)
-
     tl.store(loss_ptr, loss)
-
+    if RETURN_Z_LOSS == _TRUE:
+        tl.store(z_loss_ptr, z_loss)
 
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
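The z loss added here is the PaLM-style auxiliary term lse_square_scale * logsumexp(x)^2, computed per row and normalized the same way as the main loss. Ignoring ignore_index handling, label smoothing, and soft-capping, a rough plain-PyTorch counterpart of what the kernel accumulates under "mean" reduction might look like this (names and the scale value are illustrative assumptions):

import torch
import torch.nn.functional as F

lse_square_scale = 1e-4                         # assumed value
logits = torch.randn(4, 10)                     # assumed (BT, V) logits
target = torch.randint(0, 10, (4,))

ce = F.cross_entropy(logits, target, reduction="mean")
lse = torch.logsumexp(logits, dim=-1)           # per-row logsumexp, the kernel's `lse`
z_loss = (lse_square_scale * lse * lse).mean()  # auxiliary z loss
total = ce + z_loss                             # rough counterpart of the returned loss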
@@ -161,7 +209,32 @@ def liger_cross_entropy_kernel(
 MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning
 
 
-
+_bool_to_return_z_loss = {
+    True: _TRUE.value,
+    False: _FALSE.value,
+}
+
+
+def cross_entropy_forward(
+    _input,
+    target,
+    ignore_index,
+    lse_square_scale,
+    label_smoothing,
+    reduction,
+    softcap,
+    return_z_loss,
+):
+    if not isinstance(return_z_loss, int):
+        assert (
+            return_z_loss in _bool_to_return_z_loss
+        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+        return_z_loss = _bool_to_return_z_loss[return_z_loss]
+    else:
+        assert (
+            return_z_loss in _bool_to_return_z_loss
+        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+
     BT, V = _input.shape
     n_rows = BT
 
@@ -169,6 +242,10 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
 
     # unreduced loss
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
+    if return_z_loss == _TRUE.value:
+        z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
+    else:
+        z_loss_1d = loss_1d  # dummy ptr when return_z_loss == False
 
     n_non_ignore = (target != ignore_index).sum().item()
 
@@ -185,20 +262,30 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
         Y_ptr=target,
         Y_stride=target.stride(-1),  # always 1
         loss_ptr=loss_1d,
+        z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
         n_cols=V,
         n_non_ignore=n_non_ignore,
         ignore_index=ignore_index,
+        lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         reduction=reduction,
+        softcap=softcap if softcap is not None else 0.0,
+        RETURN_Z_LOSS=return_z_loss,
         BLOCK_SIZE=BLOCK_SIZE,
+        HAS_SOFTCAPPING=True if softcap is not None else False,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
         num_warps=32 if not is_hip() else 16,
     )
 
     loss = torch.sum(loss_1d)
-
+    if return_z_loss == _TRUE.value:
+        z_loss = torch.sum(z_loss_1d)
+    else:
+        z_loss = None
+
+    return loss, z_loss, _input
 
 
 def cross_entropy_backward(_input, grad_output):
@@ -233,7 +320,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
 
     @staticmethod
     def forward(
-        ctx,
+        ctx,
+        _input: torch.Tensor,
+        target: torch.Tensor,
+        ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
+        return_z_loss: bool = False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -243,33 +338,48 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
         target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
         ignore_index (int): The index to ignore in the target.
+        lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
+        softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
+        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
 
         Returns:
-
+        tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
         """
-        loss, _input = cross_entropy_forward(
-            _input,
+        loss, z_loss, _input = cross_entropy_forward(
+            _input,
+            target,
+            ignore_index,
+            lse_square_scale,
+            label_smoothing,
+            reduction,
+            softcap,
+            return_z_loss,
         )
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
         ctx.save_for_backward(_input.detach())
-
+        ctx.return_z_loss = return_z_loss
+
+        return loss, z_loss
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, grad_ouput2):
         """
         The backward pass of the Liger Cross Entropy loss.
 
         Parameters:
         ctx : The context object with saved tensors.
         grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-
+        grad_output2 (tenosr): No use.
         Returns:
         tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
+        if ctx.return_z_loss:
+            del grad_ouput2  # z_loss is only for logging
+
         (_input,) = ctx.saved_tensors
         _input = cross_entropy_backward(_input, grad_output)
         return (
@@ -278,4 +388,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
+            None,
         )
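With these changes, LigerCrossEntropyFunction.forward takes the extra lse_square_scale, softcap, and return_z_loss arguments and returns a (loss, z_loss) pair. A minimal usage sketch, assuming a CUDA device with Triton installed and made-up toy sizes (arguments to .apply are positional, mirroring the new signature):

import torch
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

BT, V = 8, 32000
logits = torch.randn(BT, V, device="cuda", requires_grad=True)
target = torch.randint(0, V, (BT,), device="cuda")

# (_input, target, ignore_index, lse_square_scale, label_smoothing, reduction, softcap, return_z_loss)
loss, z_loss = LigerCrossEntropyFunction.apply(
    logits, target, -100, 1e-4, 0.0, "mean", 30.0, True
)
loss.backward()   # z_loss is returned for logging only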