liger-kernel 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/cross_entropy.py +5 -39
- liger_kernel/ops/experimental/mm_int8int2.py +355 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +13 -10
- liger_kernel/ops/fused_linear_jsd.py +245 -0
- liger_kernel/ops/geglu.py +2 -2
- liger_kernel/ops/jsd.py +176 -0
- liger_kernel/ops/kl_div.py +45 -34
- liger_kernel/ops/rms_norm.py +67 -42
- liger_kernel/ops/swiglu.py +2 -2
- liger_kernel/ops/utils.py +62 -1
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/auto_model.py +18 -6
- liger_kernel/transformers/functional.py +4 -0
- liger_kernel/transformers/fused_linear_jsd.py +98 -0
- liger_kernel/transformers/jsd.py +75 -0
- liger_kernel/transformers/kl_div.py +3 -2
- liger_kernel/transformers/model/gemma.py +124 -1
- liger_kernel/transformers/model/llama.py +135 -4
- liger_kernel/transformers/model/mistral.py +3 -0
- liger_kernel/transformers/model/mixtral.py +153 -2
- liger_kernel/transformers/model/mllama.py +274 -0
- liger_kernel/transformers/model/phi3.py +140 -2
- liger_kernel/transformers/model/qwen2.py +123 -2
- liger_kernel/transformers/model/qwen2_vl.py +8 -1
- liger_kernel/transformers/monkey_patch.py +254 -129
- {liger_kernel-0.3.0.dist-info → liger_kernel-0.4.0.dist-info}/METADATA +74 -35
- liger_kernel-0.4.0.dist-info/NOTICE +58 -0
- liger_kernel-0.4.0.dist-info/RECORD +48 -0
- {liger_kernel-0.3.0.dist-info → liger_kernel-0.4.0.dist-info}/WHEEL +1 -1
- liger_kernel-0.3.0.dist-info/NOTICE +0 -4
- liger_kernel-0.3.0.dist-info/RECORD +0 -42
- {liger_kernel-0.3.0.dist-info → liger_kernel-0.4.0.dist-info}/LICENSE +0 -0
- {liger_kernel-0.3.0.dist-info → liger_kernel-0.4.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import triton
|
|
5
|
+
|
|
6
|
+
from liger_kernel.ops.jsd import _jsd_kernel
|
|
7
|
+
from liger_kernel.ops.utils import (
|
|
8
|
+
amp_custom_bwd,
|
|
9
|
+
amp_custom_fwd,
|
|
10
|
+
element_mul_kernel,
|
|
11
|
+
is_hip,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
# NOTE: the cap actually used in this module is half of that 65536 figure (i.e. 32768).
# It bounds the BLOCK_SIZE chosen in `fused_linear_jsd_forward` and `fused_linear_jsd_backward`.
MAX_FUSED_SIZE = 65536 // 2
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def fused_linear_jsd_forward(
    student_input,
    student_weight,
    teacher_input,
    teacher_weight,
    shift_labels,
    jsd_beta,
    ignore_index,
    has_label,
    temperature,
):
    """Compute the generalized JSD loss of two fused linear projections, chunk by chunk.

    The student and teacher logits are never materialized for all rows at once:
    the BT rows are split into chunks, and for every chunk the logits, the loss
    and the gradients w.r.t. the student input/weight are computed before the
    next chunk is processed.

    Args:
        student_input: student activations with shape (BT, H).
        student_weight: student projection weight with shape (V, H).
        teacher_input: teacher activations; sliced with the same row ranges as
            `student_input`.
        teacher_weight: teacher projection weight; its transpose is multiplied
            with the teacher activations.
        shift_labels: labels with shape (BT,) when `has_label` is True; unused otherwise.
        jsd_beta: beta coefficient forwarded to `_jsd_kernel`.
        ignore_index: label value whose rows are excluded from the loss.
        has_label: whether `shift_labels` should be consulted.
        temperature: divisor applied to both sets of logits before log-softmax.

    Returns:
        A tuple `(loss, grad_input, grad_weight)` where `loss` is a 0-dim fp32
        tensor, `grad_input` has the shape/dtype of `student_input`, and
        `grad_weight` has the shape/dtype of `student_weight`, or is None when
        `student_weight.requires_grad` is False.
    """
    device = student_input.device
    dtype = student_input.dtype

    # inputs have shape: BT x H
    # materialized activations will have shape: BT x V
    # the increase in memory = BT x V
    # reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
    # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
    # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
    # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
    BT, H = student_input.shape
    V = student_weight.shape[0]
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

    inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
    chunk_size = triton.next_power_of_2(
        triton.cdiv(BT, inc_factor)
    )  # (BT + inc_factor - 1) // inc_factor
    num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size

    grad_weight = (
        torch.zeros_like(student_weight, device=device)
        if student_weight.requires_grad
        else None
    )
    grad_input = torch.zeros_like(student_input)
    # We use an fp32 scalar as the loss accumulator. The element-wise loss is only
    # ever held for one chunk at a time, so no BT x V buffer is allocated.
    loss = torch.zeros((), dtype=torch.float32, device=device)

    if has_label:
        n_non_ignore = (shift_labels != ignore_index).sum().item()
        dummy_labels = None
    else:
        n_non_ignore = BT
        # dummy ptr if no label; allocated once instead of once per chunk
        dummy_labels = torch.empty(1, device=device)

    for chunk_id in range(num_chunks):
        start_idx = chunk_id * chunk_size
        end_idx = min((chunk_id + 1) * chunk_size, BT)

        # chunk both inputs, shape: chunk_size x H
        student_input_chunk = student_input[start_idx:end_idx]
        teacher_input_chunk = teacher_input[start_idx:end_idx]

        # shape: chunk_size x V
        # For anything starting from logits to the final JSD loss, we do computation
        # in FP32 to avoid losing numerical stability.
        student_logits_chunk = (student_input_chunk @ student_weight.t()).to(
            torch.float32
        )
        teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(
            torch.float32
        )
        chunk_n_rows = student_logits_chunk.shape[0]

        # Unreduced loss of this chunk, shape: chunk_size x V. It must be
        # zero-initialized because the kernel returns early (without writing any
        # loss) for rows whose label equals `ignore_index`.
        loss_chunk = torch.zeros(
            (chunk_n_rows, V), dtype=torch.float32, device=device
        )

        # log-softmax with temperature
        student_logits_chunk = student_logits_chunk / temperature
        teacher_logits_chunk = teacher_logits_chunk / temperature
        student_prob_chunk = torch.log_softmax(student_logits_chunk, dim=-1)
        teacher_prob_chunk = torch.log_softmax(teacher_logits_chunk, dim=-1)

        # ensure _input and target are contiguous
        student_prob_chunk = student_prob_chunk.contiguous()
        teacher_prob_chunk = teacher_prob_chunk.contiguous()

        # Here we calculate the gradient of prob_chunk in place so we can save memory:
        # X_ptr and dX_ptr both point at `student_prob_chunk`.
        _jsd_kernel[(chunk_n_rows,)](
            X_ptr=student_prob_chunk,
            X_stride=student_prob_chunk.stride(-2),
            Y_ptr=teacher_prob_chunk,
            Y_stride=teacher_prob_chunk.stride(-2),
            loss_ptr=loss_chunk,
            loss_stride=loss_chunk.stride(-2),
            dX_ptr=student_prob_chunk,
            dX_stride=student_prob_chunk.stride(-2),
            label_ptr=(
                shift_labels[start_idx:end_idx] if has_label else dummy_labels
            ),
            beta=jsd_beta,
            n_non_ignore=n_non_ignore,
            ignore_index=ignore_index,
            n_cols=V,
            BLOCK_SIZE=BLOCK_SIZE,
            HAS_LABEL=has_label,
        )
        loss += loss_chunk.sum()

        # `student_prob_chunk` now holds the gradient w.r.t. the log-probabilities,
        # shape: chunk_size x V. Back-propagate through log-softmax and the
        # temperature scaling to get the gradient w.r.t. the logits.
        student_logits_chunk = (
            student_prob_chunk
            - torch.softmax(student_logits_chunk, dim=-1)
            * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(
                student_prob_chunk.shape
            )
        ) / temperature
        # now we traverse back to grad w.r.t. input to `lm_head` and grad
        # w.r.t. `lm_head` which should be computed in original dtype
        student_logits_chunk = student_logits_chunk.to(dtype)
        grad_input[start_idx:end_idx] = student_logits_chunk @ student_weight

        if grad_weight is not None:
            grad_weight.add_(student_logits_chunk.t() @ student_input_chunk)

    return loss, grad_input, grad_weight
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def fused_linear_jsd_backward(grad_output, grad_input, grad_weight):
    """Scale the gradients cached by the forward pass with the upstream gradient.

    Args:
        grad_output: upstream gradient of the loss; compared against 1.0.
        grad_input: cached gradient w.r.t. the student input, shape (BT, H).
        grad_weight: cached gradient w.r.t. the student weight, shape (V, H), or None.

    Returns:
        The tuple `(grad_input, grad_weight)`. The same tensor objects that were
        passed in are returned; `element_mul_kernel` is expected to rescale them
        in place (per the original comment -- its definition is not in this file).
    """
    # If JSD is the last layer, grad_output is 1.0. Skip the mul to save time
    unit_grad = torch.tensor(1.0, device=grad_output.device)
    if not torch.ne(grad_output, unit_grad):
        return grad_input, grad_weight

    # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
    # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
    warps = 16 if is_hip() else 32
    n_tokens, hidden = grad_input.shape
    # The block size is derived from grad_input's hidden size and reused for grad_weight.
    block = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden))

    element_mul_kernel[(n_tokens,)](
        grad_input,
        grad_input.stride(-2),
        grad_output,
        hidden,
        BLOCK_SIZE=block,
        num_warps=warps,
    )

    # handle grad_weight: one program per vocabulary row
    if grad_weight is not None:
        vocab, hidden = grad_weight.shape
        element_mul_kernel[(vocab,)](
            grad_weight,
            grad_weight.stride(-2),
            grad_output,
            hidden,
            BLOCK_SIZE=block,
            num_warps=warps,
        )

    return grad_input, grad_weight
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class LigerFusedLinearJSDFunction(torch.autograd.Function):
    """
    Fusing the last linear layer with generalized JSD

    The final linear projection and the JSD loss are evaluated together so that the
    large logits tensor does not have to be materialized. Since JSD is the last layer,
    the gradients are already computed during the forward pass and only rescaled by
    the upstream gradient in the backward pass.
    """

    @staticmethod
    @amp_custom_fwd
    def forward(
        ctx,
        student_input: torch.Tensor,
        student_weight: torch.Tensor,
        teacher_input: torch.Tensor,
        teacher_weight: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
        jsd_beta: float = 0.5,
        ignore_index: int = -100,
        temperature: float = 1.0,
    ):
        """
        Args:

            student_input (torch.tensor): input of the last projection layer in student model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
            student_weight (torch.tensor): the last projection layer in student model, with shape (V, H), where V is vocab size
            teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
            teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size
            shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
            jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
            ignore_index (int): the index to ignore. Default: -100
            temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`

        Returns:
            loss (torch.Tensor): generalized JSD
        """
        has_label = shift_labels is not None
        if has_label:
            # Labels are validated against the number of teacher rows.
            expected_shape = (teacher_input.shape[0],)
            assert (
                shift_labels.shape == expected_shape
            ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
            shift_labels = shift_labels.contiguous()

        forward_outputs = fused_linear_jsd_forward(
            student_input,
            student_weight,
            teacher_input,
            teacher_weight,
            shift_labels,
            jsd_beta,
            ignore_index,
            has_label,
            temperature,
        )
        loss, grad_input, grad_weight = forward_outputs

        # Keep the precomputed gradients around for the backward pass.
        saved_grad_weight = None if grad_weight is None else grad_weight.detach()
        ctx.save_for_backward(grad_input.detach(), saved_grad_weight)
        return loss

    @staticmethod
    @amp_custom_bwd
    def backward(ctx, grad_output):
        """Rescale the cached gradients; only the two student tensors receive a gradient."""
        saved_grad_input, saved_grad_weight = ctx.saved_tensors
        grad_input, grad_weight = fused_linear_jsd_backward(
            grad_output, saved_grad_input, saved_grad_weight
        )
        # One entry per forward argument: the remaining six take no gradient.
        no_grads = (None,) * 6
        return (grad_input, grad_weight, *no_grads)
|
liger_kernel/ops/geglu.py
CHANGED
|
@@ -25,7 +25,7 @@ else:
|
|
|
25
25
|
def _geglu_tanh_forward_kernel(
|
|
26
26
|
a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
|
|
27
27
|
):
|
|
28
|
-
program_id = tl.program_id(0)
|
|
28
|
+
program_id = tl.program_id(0).to(tl.int64)
|
|
29
29
|
|
|
30
30
|
# locate start index
|
|
31
31
|
a += program_id * stride
|
|
@@ -52,7 +52,7 @@ def _geglu_tanh_forward_kernel(
|
|
|
52
52
|
def _geglu_tanh_backward_kernel(
|
|
53
53
|
dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
|
|
54
54
|
):
|
|
55
|
-
program_id = tl.program_id(0)
|
|
55
|
+
program_id = tl.program_id(0).to(tl.int64)
|
|
56
56
|
|
|
57
57
|
# locate start index
|
|
58
58
|
dc += program_id * stride
|
liger_kernel/ops/jsd.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import triton
|
|
5
|
+
import triton.language as tl
|
|
6
|
+
|
|
7
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@triton.jit
def _jsd_kernel(
    X_ptr,  # input in logspace, X = log Q
    X_stride,
    Y_ptr,  # ground truth in logspace, Y = log P
    Y_stride,
    loss_ptr,
    loss_stride,
    dX_ptr,
    dX_stride,
    label_ptr,
    beta,
    n_non_ignore: int,
    ignore_index: tl.constexpr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    HAS_LABEL: tl.constexpr,
):
    # Element-wise generalized JSD and its gradient w.r.t. X; one program handles one row.
    #
    # With M = beta * P + (1 - beta) * Q, P = e ^ Y, Q = e ^ X:
    #   JSD(beta)(P || Q) = beta * KL(P || M) + (1 - beta) * KL(Q || M)
    #                     = sum(beta * P * Y + (1 - beta) * Q * X - M * log M)
    #   grad_x_i          = (1 - beta) * Q * (X - log M)
    # Both the loss and the gradient are divided by n_non_ignore ("batchmean").
    # For beta = 0.5 this reduces to the usual JSD, M = (1/2) * (P + Q).
    pid = tl.program_id(0).to(tl.int64)
    # Advance every pointer to the row owned by this program.
    X_ptr += pid * X_stride
    dX_ptr += pid * dX_stride
    Y_ptr += pid * Y_stride
    loss_ptr += pid * loss_stride
    label_ptr += pid

    if HAS_LABEL:
        label = tl.load(label_ptr)
        if label == ignore_index:
            # Ignored row: zero its gradient and return without writing any loss,
            # so the loss buffer keeps whatever the caller initialized it with.
            for i in range(0, n_cols, BLOCK_SIZE):
                offsets = i + tl.arange(0, BLOCK_SIZE)
                tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols)
            return

    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_cols
        # Out-of-range lanes load -inf; their results are never stored (masked below).
        X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
        Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)

        Q = tl.exp(X)
        P = tl.exp(Y)
        M = beta * P + (1 - beta) * Q
        # NOTE(review): if an in-range element has X == Y == -inf, M is 0 and
        # M * log_M would presumably evaluate to NaN -- confirm inputs never hit this.
        log_M = tl.log(M)

        loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
        # reduction == "batchmean"
        loss = loss / n_non_ignore
        tl.store(loss_ptr + offsets, loss, mask=mask)

        # dX_ptr may alias X_ptr (the caller decides); X was already loaded above.
        dX = (1 - beta) * Q * (X - log_M) / n_non_ignore
        tl.store(dX_ptr + offsets, dX, mask=mask)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# Upper bound for the BLOCK_SIZE that `jsd_forward` passes to `_jsd_kernel`.
MAX_FUSED_SIZE = 65536
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
    """Launch `_jsd_kernel` over every row and reduce the element-wise loss.

    Args:
        _input: 2-D tensor (BT, V), student values in log-space.
        target: tensor read with the same row/column indexing, teacher values in log-space.
        shift_labels: labels with shape (BT,) when `has_label` is True; unused otherwise.
        beta: beta coefficient forwarded to the kernel.
        ignore_index: label value whose rows are excluded.
        has_label: whether `shift_labels` should be consulted.

    Returns:
        `(loss, dX)`: the summed loss cast to `_input.dtype`, and the gradient
        w.r.t. `_input` with the same shape and dtype as `_input`.
    """
    n_rows, n_cols = _input.shape
    block_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))

    # non reduction loss, kept in fp32; zero-filled so ignored rows contribute nothing
    elementwise_loss = torch.zeros(
        _input.shape, dtype=torch.float32, device=_input.device
    )
    dX = torch.empty_like(_input)

    if has_label:
        labels = shift_labels
        n_non_ignore = (shift_labels != ignore_index).sum().item()
    else:
        # dummy ptr if no label
        labels = torch.empty(1, device=_input.device)
        n_non_ignore = n_rows

    # one program per row
    launch = _jsd_kernel[(n_rows,)]
    launch(
        X_ptr=_input,  # input in logspace, X = log Q
        X_stride=_input.stride(-2),
        Y_ptr=target,  # ground truth in logspace, Y = log P
        Y_stride=target.stride(-2),
        loss_ptr=elementwise_loss,
        loss_stride=elementwise_loss.stride(-2),
        dX_ptr=dX,
        dX_stride=dX.stride(-2),
        label_ptr=labels,
        beta=beta,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
        n_cols=n_cols,
        BLOCK_SIZE=block_size,
        HAS_LABEL=has_label,
    )

    total_loss = torch.sum(elementwise_loss)
    return total_loss.to(_input.dtype), dX
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def jsd_backward(dX, grad_output):
    """Return the gradient cached by the forward pass, scaled by `grad_output`.

    When `grad_output` equals the scalar 1.0 (JSD is the last layer) the cached
    tensor `dX` is returned as is; otherwise a new tensor `grad_output * dX`
    is returned.
    """
    unit_grad = torch.tensor(1.0, device=grad_output.device)
    # Skip the multiplication when it would be a no-op.
    is_unit = torch.equal(grad_output, unit_grad)
    return dX if is_unit else grad_output * dX
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class LigerJSDFunction(torch.autograd.Function):
    r"""
    Forward and backward pass of the generalized Jensen-Shannon Divergence.

    .. math::
        JSD(\beta)(P || Q)
            = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))

    .. note::
        As all the other losses in PyTorch, this function expects the first argument,
        :attr:`_input`, to be the predictions, the output of the student model, in log-space
        and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space.
        This differs from the standard mathematical notation :math:`JSD(P || Q)` where
        :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        _input: torch.Tensor,
        target: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
        beta: float = 0.5,
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """
        Args:
            _input (torch.Tensor): predict values with shape (BT, V) in logspace
            target (torch.Tensor): ground truth values with shape (BT, V) in logspace
            shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
            beta (float): coefficient beta of generalized JSD in the open interval (0, 1)
            ignore_index (int): the index to ignore. Default: -100

        Returns:
            loss (torch.Tensor): generalized JSD
        """
        has_label = shift_labels is not None
        if has_label:
            expected_shape = (_input.shape[0],)
            assert (
                shift_labels.shape == expected_shape
            ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
            shift_labels = shift_labels.contiguous()

        forward_outputs = jsd_forward(
            _input, target, shift_labels, beta, ignore_index, has_label
        )
        loss, dX = forward_outputs
        # The gradient is already known; keep it for the backward pass.
        ctx.save_for_backward(dX)
        return loss

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """Return the scaled cached gradient for `_input`, followed by None for the other four forward arguments."""
        (cached_dX,) = ctx.saved_tensors
        grad_for_input = jsd_backward(cached_dX, grad_output)
        return (grad_for_input, None, None, None, None)
|
liger_kernel/ops/kl_div.py
CHANGED
|
@@ -4,13 +4,13 @@ import torch
|
|
|
4
4
|
import triton
|
|
5
5
|
import triton.language as tl
|
|
6
6
|
|
|
7
|
-
from liger_kernel.ops.utils import ensure_contiguous
|
|
7
|
+
from liger_kernel.ops.utils import ensure_contiguous, is_hip
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
def get_num_warps(BLOCK_SIZE):
|
|
11
11
|
num_warps = 4
|
|
12
12
|
if BLOCK_SIZE >= 32768:
|
|
13
|
-
num_warps = 32
|
|
13
|
+
num_warps = 32 if not is_hip() else 16
|
|
14
14
|
elif BLOCK_SIZE >= 8192:
|
|
15
15
|
num_warps = 16
|
|
16
16
|
elif BLOCK_SIZE >= 2048:
|
|
@@ -45,6 +45,7 @@ def _kldiv_kernel_forward(
|
|
|
45
45
|
loss_ptr, # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr
|
|
46
46
|
loss_stride, # int, output stride
|
|
47
47
|
n_cols, # int, number of columns in the input tensor
|
|
48
|
+
eps,
|
|
48
49
|
BLOCK_SIZE: tl.constexpr,
|
|
49
50
|
log_target: tl.constexpr = False,
|
|
50
51
|
reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
|
|
@@ -56,6 +57,7 @@ def _kldiv_kernel_forward(
|
|
|
56
57
|
|
|
57
58
|
base_offsets = tl.arange(0, BLOCK_SIZE)
|
|
58
59
|
|
|
60
|
+
loss_sum = 0.0
|
|
59
61
|
for i in range(0, n_cols, BLOCK_SIZE):
|
|
60
62
|
offsets = i + base_offsets
|
|
61
63
|
mask = offsets < n_cols
|
|
@@ -65,32 +67,33 @@ def _kldiv_kernel_forward(
|
|
|
65
67
|
# KL(y_true || y) = y_true * (log(y_true) - log(y))
|
|
66
68
|
# We compute KL(y_true || y) with y in the log-space
|
|
67
69
|
if not log_target:
|
|
68
|
-
loss = y_true * (tl.log(y_true) - y)
|
|
70
|
+
loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y)
|
|
69
71
|
else:
|
|
70
72
|
loss = tl.exp(y_true) * (y_true - y)
|
|
71
73
|
|
|
72
74
|
if reduction == _REDUCTION_MODE_NONE:
|
|
73
75
|
tl.store(loss_ptr + offsets, loss, mask=mask)
|
|
74
76
|
else:
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
77
|
+
loss_sum += tl.sum(loss, axis=0)
|
|
78
|
+
|
|
79
|
+
if reduction != _REDUCTION_MODE_NONE:
|
|
80
|
+
tl.store(loss_ptr, loss_sum)
|
|
78
81
|
|
|
79
82
|
|
|
80
83
|
@triton.jit
|
|
81
84
|
def _kldiv_kernel_backward(
|
|
82
|
-
input_ptr,
|
|
83
|
-
input_stride,
|
|
84
85
|
target_ptr,
|
|
85
86
|
target_stride,
|
|
87
|
+
new_grads_ptr,
|
|
88
|
+
new_grads_stride,
|
|
86
89
|
n_cols,
|
|
87
90
|
BLOCK_SIZE: tl.constexpr,
|
|
88
91
|
log_target: tl.constexpr = False,
|
|
89
92
|
):
|
|
90
93
|
pid = tl.program_id(0).to(tl.int64)
|
|
91
94
|
|
|
92
|
-
input_ptr += pid * input_stride
|
|
93
95
|
target_ptr += pid * target_stride
|
|
96
|
+
new_grads_ptr += pid * new_grads_stride
|
|
94
97
|
|
|
95
98
|
offsets = tl.arange(0, BLOCK_SIZE)
|
|
96
99
|
mask = offsets < n_cols
|
|
@@ -106,19 +109,19 @@ def _kldiv_kernel_backward(
|
|
|
106
109
|
else:
|
|
107
110
|
res = -tl.exp(target)
|
|
108
111
|
|
|
109
|
-
tl.store(
|
|
112
|
+
tl.store(new_grads_ptr + offsets, res, mask=mask)
|
|
110
113
|
|
|
111
114
|
|
|
112
|
-
def kldiv_forward_triton(y_pred, y_true, log_target, reduction): # [
|
|
113
|
-
|
|
115
|
+
def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
|
|
116
|
+
BT, V = y_pred.shape
|
|
114
117
|
|
|
115
|
-
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(
|
|
118
|
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
116
119
|
num_warps = get_num_warps(BLOCK_SIZE)
|
|
117
120
|
|
|
118
|
-
grid = (
|
|
121
|
+
grid = (BT,)
|
|
119
122
|
reduction = _str_to_reduction_mode[reduction]
|
|
120
123
|
|
|
121
|
-
out_size = (
|
|
124
|
+
out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
|
|
122
125
|
output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32)
|
|
123
126
|
|
|
124
127
|
_kldiv_kernel_forward[grid](
|
|
@@ -128,7 +131,8 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction): # [B, S] # [B
|
|
|
128
131
|
y_true.stride(0),
|
|
129
132
|
output_tensor,
|
|
130
133
|
output_tensor.stride(0),
|
|
131
|
-
|
|
134
|
+
V,
|
|
135
|
+
eps=eps,
|
|
132
136
|
BLOCK_SIZE=BLOCK_SIZE,
|
|
133
137
|
num_warps=num_warps,
|
|
134
138
|
log_target=log_target,
|
|
@@ -139,30 +143,30 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction): # [B, S] # [B
|
|
|
139
143
|
# https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
|
|
140
144
|
# https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372
|
|
141
145
|
if reduction == _REDUCTION_MODE_BATCHMEAN.value:
|
|
142
|
-
return output_tensor.sum() /
|
|
146
|
+
return output_tensor.sum() / BT
|
|
143
147
|
elif reduction == _REDUCTION_MODE_SUM.value:
|
|
144
148
|
return output_tensor.sum(dim=0)
|
|
145
149
|
elif reduction == _REDUCTION_MODE_MEAN.value:
|
|
146
|
-
return output_tensor.
|
|
150
|
+
return output_tensor.sum() / (BT * V)
|
|
147
151
|
else:
|
|
148
152
|
return output_tensor
|
|
149
153
|
|
|
150
154
|
|
|
151
|
-
def kldiv_backward_triton(
|
|
152
|
-
|
|
155
|
+
def kldiv_backward_triton(target, grad_output, new_grads, log_target):
|
|
156
|
+
BT, V = target.shape
|
|
153
157
|
|
|
154
|
-
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(
|
|
158
|
+
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
155
159
|
num_warps = get_num_warps(BLOCK_SIZE)
|
|
156
160
|
|
|
157
|
-
grid = (
|
|
161
|
+
grid = (BT,)
|
|
158
162
|
|
|
159
163
|
# We store the gradients in-place in the input tensor
|
|
160
164
|
_kldiv_kernel_backward[grid](
|
|
161
|
-
input,
|
|
162
|
-
input.stride(0),
|
|
163
165
|
target,
|
|
164
166
|
target.stride(0),
|
|
165
|
-
|
|
167
|
+
new_grads,
|
|
168
|
+
new_grads.stride(0),
|
|
169
|
+
V,
|
|
166
170
|
BLOCK_SIZE=BLOCK_SIZE,
|
|
167
171
|
num_warps=num_warps,
|
|
168
172
|
log_target=log_target,
|
|
@@ -170,9 +174,9 @@ def kldiv_backward_triton(input, target, grad_output, log_target):
|
|
|
170
174
|
|
|
171
175
|
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul then.
|
|
172
176
|
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
|
|
173
|
-
return
|
|
177
|
+
return new_grads
|
|
174
178
|
|
|
175
|
-
return
|
|
179
|
+
return new_grads * grad_output
|
|
176
180
|
|
|
177
181
|
|
|
178
182
|
class LigerKLDivLossFunction(torch.autograd.Function):
|
|
@@ -196,6 +200,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
|
|
|
196
200
|
y_true: torch.Tensor,
|
|
197
201
|
reduction: REDUCTION_LITERAL = "batchmean",
|
|
198
202
|
log_target: bool = False,
|
|
203
|
+
eps: float = 1e-10,
|
|
199
204
|
) -> torch.Tensor:
|
|
200
205
|
"""A forward pass for the KL Divergence Loss.
|
|
201
206
|
|
|
@@ -205,15 +210,16 @@ class LigerKLDivLossFunction(torch.autograd.Function):
|
|
|
205
210
|
y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`.
|
|
206
211
|
reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean".
|
|
207
212
|
log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False.
|
|
213
|
+
eps: (float, optional): A small value to avoid division by zero. Defaults to 1e-10.
|
|
208
214
|
|
|
209
215
|
Returns:
|
|
210
216
|
torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar.
|
|
211
217
|
"""
|
|
212
|
-
ctx.save_for_backward(
|
|
218
|
+
ctx.save_for_backward(y_true)
|
|
213
219
|
ctx.reduction = reduction
|
|
214
220
|
ctx.log_target = log_target
|
|
215
221
|
return kldiv_forward_triton(
|
|
216
|
-
y_pred, y_true, log_target=log_target, reduction=reduction
|
|
222
|
+
y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps
|
|
217
223
|
)
|
|
218
224
|
|
|
219
225
|
@staticmethod
|
|
@@ -226,22 +232,27 @@ class LigerKLDivLossFunction(torch.autograd.Function):
|
|
|
226
232
|
grad_output (torch.Tensor): The gradient of the loss with respect to the output.
|
|
227
233
|
|
|
228
234
|
Returns:
|
|
229
|
-
tuple[torch.Tensor, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
|
|
235
|
+
tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
|
|
230
236
|
"""
|
|
231
|
-
|
|
237
|
+
(y_true,) = ctx.saved_tensors
|
|
238
|
+
|
|
239
|
+
new_grads = torch.empty_like(y_true)
|
|
232
240
|
|
|
233
|
-
derivative = kldiv_backward_triton(
|
|
241
|
+
derivative = kldiv_backward_triton(
|
|
242
|
+
y_true, grad_output, new_grads, ctx.log_target
|
|
243
|
+
)
|
|
234
244
|
|
|
235
245
|
if ctx.reduction == "batchmean":
|
|
236
|
-
derivative = derivative /
|
|
246
|
+
derivative = derivative / y_true.shape[0]
|
|
237
247
|
elif ctx.reduction == "sum" or ctx.reduction == "none":
|
|
238
248
|
pass
|
|
239
249
|
elif ctx.reduction == "mean":
|
|
240
|
-
derivative = derivative / (
|
|
250
|
+
derivative = derivative / (y_true.shape[0] * y_true.shape[1])
|
|
241
251
|
|
|
242
252
|
return (
|
|
243
253
|
derivative,
|
|
244
254
|
None,
|
|
245
255
|
None,
|
|
246
256
|
None,
|
|
257
|
+
None,
|
|
247
258
|
)
|