liger-kernel 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- liger_kernel/env_report.py +46 -0
- liger_kernel/ops/cross_entropy.py +130 -63
- liger_kernel/ops/experimental/embedding.py +143 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +203 -126
- liger_kernel/ops/geglu.py +54 -42
- liger_kernel/ops/kl_div.py +247 -0
- liger_kernel/ops/layer_norm.py +236 -0
- liger_kernel/ops/rms_norm.py +220 -84
- liger_kernel/ops/rope.py +91 -84
- liger_kernel/ops/swiglu.py +48 -41
- liger_kernel/ops/utils.py +12 -0
- liger_kernel/transformers/__init__.py +22 -0
- liger_kernel/transformers/auto_model.py +33 -0
- liger_kernel/transformers/cross_entropy.py +11 -1
- liger_kernel/transformers/experimental/embedding.py +28 -0
- liger_kernel/transformers/functional.py +19 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +8 -2
- liger_kernel/transformers/geglu.py +4 -2
- liger_kernel/transformers/kl_div.py +13 -0
- liger_kernel/transformers/layer_norm.py +30 -0
- liger_kernel/transformers/model/gemma.py +138 -0
- liger_kernel/transformers/model/llama.py +1 -1
- liger_kernel/transformers/model/mistral.py +138 -0
- liger_kernel/transformers/model/mixtral.py +158 -0
- liger_kernel/transformers/model/phi3.py +136 -0
- liger_kernel/transformers/model/qwen2.py +135 -0
- liger_kernel/transformers/model/qwen2_vl.py +172 -0
- liger_kernel/transformers/monkey_patch.py +605 -14
- liger_kernel/transformers/rms_norm.py +23 -4
- liger_kernel/transformers/swiglu.py +24 -0
- liger_kernel/transformers/trainer_integration.py +2 -45
- liger_kernel-0.3.0.dist-info/METADATA +388 -0
- liger_kernel-0.3.0.dist-info/RECORD +42 -0
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/WHEEL +1 -1
- liger_kernel-0.1.0.dist-info/METADATA +0 -16
- liger_kernel-0.1.0.dist-info/RECORD +0 -27
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/LICENSE +0 -0
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/NOTICE +0 -0
- {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/top_level.txt +0 -0
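Most of the new surface area above sits in liger_kernel/transformers: per-model patch modules (gemma, mistral, mixtral, phi3, qwen2, qwen2_vl), a much larger monkey_patch.py, and a new auto_model.py. As rough orientation, usage looks like the following sketch; the entry-point names are assumptions inferred from the module names above and should be checked against the 0.3.0 package itself:

    # Hypothetical usage sketch; names assumed from the file list above.
    from liger_kernel.transformers import (
        AutoLigerKernelForCausalLM,
        apply_liger_kernel_to_llama,
    )

    # Patch the Hugging Face Llama modules in place with the Liger Triton kernels.
    apply_liger_kernel_to_llama()

    # Or let the auto model pick and apply the matching patch for a checkpoint.
    model = AutoLigerKernelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")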
liger_kernel/ops/kl_div.py
@@ -0,0 +1,247 @@
+from typing import Literal
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import ensure_contiguous
+
+
+def get_num_warps(BLOCK_SIZE):
+    num_warps = 4
+    if BLOCK_SIZE >= 32768:
+        num_warps = 32
+    elif BLOCK_SIZE >= 8192:
+        num_warps = 16
+    elif BLOCK_SIZE >= 2048:
+        num_warps = 8
+
+    return num_warps
+
+
+MAX_FUSED_SIZE = 65536 // 4  # 65536 // 4 or 8 works the best
+
+REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
+
+_REDUCTION_MODE_NONE = tl.constexpr(0)
+_REDUCTION_MODE_SUM = tl.constexpr(1)
+_REDUCTION_MODE_MEAN = tl.constexpr(2)
+_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
+
+_str_to_reduction_mode = {
+    "none": _REDUCTION_MODE_NONE.value,
+    "sum": _REDUCTION_MODE_SUM.value,
+    "mean": _REDUCTION_MODE_MEAN.value,
+    "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
+}
+
+
+@triton.jit
+def _kldiv_kernel_forward(
+    y_ptr,  # [B, S], prediction ptr; the kernel expects the prediction in log-space
+    y_stride,  # int, prediction stride
+    gt_ptr,  # [B, S], ground truth ptr
+    gt_stride,  # int, ground truth stride
+    loss_ptr,  # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr
+    loss_stride,  # int, output stride
+    n_cols,  # int, number of columns in the input tensor
+    BLOCK_SIZE: tl.constexpr,
+    log_target: tl.constexpr = False,
+    reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
+):
+    pid = tl.program_id(0).to(tl.int64)
+    y_ptr += pid * y_stride
+    gt_ptr += pid * gt_stride
+    loss_ptr += pid * loss_stride
+
+    base_offsets = tl.arange(0, BLOCK_SIZE)
+
+    for i in range(0, n_cols, BLOCK_SIZE):
+        offsets = i + base_offsets
+        mask = offsets < n_cols
+        y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
+        y_true = tl.load(gt_ptr + offsets, mask=mask, other=0.0)
+
+        # KL(y_true || y) = y_true * (log(y_true) - log(y))
+        # We compute KL(y_true || y) with y in log-space
+        if not log_target:
+            loss = y_true * (tl.log(y_true) - y)
+        else:
+            loss = tl.exp(y_true) * (y_true - y)
+
+        if reduction == _REDUCTION_MODE_NONE:
+            tl.store(loss_ptr + offsets, loss, mask=mask)
+        else:
+            loss = tl.sum(loss, axis=0)
+            tl.store(loss_ptr, loss)
+            loss_ptr += 1  # with reduction, the output tensor has shape [B,], so the stride is always 1
+
+
+@triton.jit
+def _kldiv_kernel_backward(
+    input_ptr,
+    input_stride,
+    target_ptr,
+    target_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+    log_target: tl.constexpr = False,
+):
+    pid = tl.program_id(0).to(tl.int64)
+
+    input_ptr += pid * input_stride
+    target_ptr += pid * target_stride
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_cols
+
+    for i in range(0, n_cols, BLOCK_SIZE):
+        offsets = i + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_cols
+
+        target = tl.load(target_ptr + offsets, mask=mask, other=0.0)
+
+        if not log_target:
+            res = target * -1
+        else:
+            res = -tl.exp(target)
+
+        tl.store(input_ptr + offsets, res, mask=mask)
+
+
+def kldiv_forward_triton(y_pred, y_true, log_target, reduction):  # [B, S]  # [B, S]
+    B, S = y_pred.shape
+
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(S))
+    num_warps = get_num_warps(BLOCK_SIZE)
+
+    grid = (B,)
+    reduction = _str_to_reduction_mode[reduction]
+
+    out_size = (B, S) if reduction == _REDUCTION_MODE_NONE.value else (B,)
+    output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32)
+
+    _kldiv_kernel_forward[grid](
+        y_pred,
+        y_pred.stride(0),
+        y_true,
+        y_true.stride(0),
+        output_tensor,
+        output_tensor.stride(0),
+        S,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+        log_target=log_target,
+        reduction=reduction,
+    )
+
+    # Reduced according to the reduction mode, same as in PyTorch. In later versions,
+    # `mean` will be changed to the same behavior as `batchmean`.
+    # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
+    # https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372
+    if reduction == _REDUCTION_MODE_BATCHMEAN.value:
+        return output_tensor.sum() / B
+    elif reduction == _REDUCTION_MODE_SUM.value:
+        return output_tensor.sum(dim=0)
+    elif reduction == _REDUCTION_MODE_MEAN.value:
+        return output_tensor.mean(dim=0)
+    else:
+        return output_tensor
+
+
+def kldiv_backward_triton(input, target, grad_output, log_target):
+    B, S = input.shape
+
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(S))
+    num_warps = get_num_warps(BLOCK_SIZE)
+
+    grid = (B,)
+
+    # We store the gradients in-place in the input tensor
+    _kldiv_kernel_backward[grid](
+        input,
+        input.stride(0),
+        target,
+        target.stride(0),
+        S,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+        log_target=log_target,
+    )
+
+    # If the KL divergence loss is the last layer, grad_output is 1.0. Skip the mul then.
+    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+        return input
+
+    return input * grad_output
+
+
+class LigerKLDivLossFunction(torch.autograd.Function):
+    """
+    Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
+
+    ```python
+    if not log_target:
+        loss = target * (target.log() - input)
+    else:
+        loss = target.exp() * (target - input)
+    ```
+
+    The loss is then reduced according to the `reduction` parameter,
+    as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
+    """
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(
+        ctx,
+        y_pred: torch.Tensor,
+        y_true: torch.Tensor,
+        reduction: REDUCTION_LITERAL = "batchmean",
+        log_target: bool = False,
+    ) -> torch.Tensor:
+        """A forward pass for the KL Divergence Loss.
+
+        Args:
+            ctx: Torch autograd context
+            y_pred (torch.Tensor): A tensor of shape (BT, V) containing the predicted values, expected to be log-probabilities.
+            y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`.
+            reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean".
+            log_target (bool, optional): If set to True, expects the ground truth to already be log-probabilities. Defaults to False.

+        Returns:
+            torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar.
+        """
+        ctx.save_for_backward(y_pred, y_true)
+        ctx.reduction = reduction
+        ctx.log_target = log_target
+        return kldiv_forward_triton(
+            y_pred, y_true, log_target=log_target, reduction=reduction
+        )
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+        """A backward pass for the KL Divergence Loss.
+
+        Args:
+            ctx: Torch autograd context
+            grad_output (torch.Tensor): The gradient of the loss with respect to the output.
+
+        Returns:
+            tuple[torch.Tensor, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
+        """
+        y_pred, y_true = ctx.saved_tensors
+
+        derivative = kldiv_backward_triton(y_pred, y_true, grad_output, ctx.log_target)
+
+        if ctx.reduction == "batchmean":
+            derivative = derivative / y_pred.shape[0]
+        elif ctx.reduction == "sum" or ctx.reduction == "none":
+            pass
+        elif ctx.reduction == "mean":
+            derivative = derivative / (y_pred.shape[0] * y_pred.shape[1])
+
+        return (
+            derivative,
+            None,
+            None,
+            None,
+        )
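Taken together, kldiv_forward_triton launches one Triton program per row of the (B, S) input, and LigerKLDivLossFunction wires the kernels into autograd. A minimal sketch of calling the op directly and comparing it against torch.nn.functional.kl_div (shapes, device, and tolerance are illustrative):

    import torch
    import torch.nn.functional as F

    from liger_kernel.ops.kl_div import LigerKLDivLossFunction

    B, V = 8, 4096
    # y_pred must be log-probabilities; y_true is probabilities when log_target=False.
    y_pred = torch.randn(B, V, device="cuda").log_softmax(dim=-1).requires_grad_()
    y_true = torch.randn(B, V, device="cuda").softmax(dim=-1)

    loss = LigerKLDivLossFunction.apply(y_pred, y_true, "batchmean", False)
    ref = F.kl_div(y_pred, y_true, reduction="batchmean", log_target=False)
    assert torch.allclose(loss, ref, atol=1e-4)

    loss.backward()  # gradient w.r.t. y_pred is computed (in place) by kldiv_backward_triton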
@@ -0,0 +1,236 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import operator
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import triton
|
|
6
|
+
import triton.language as tl
|
|
7
|
+
|
|
8
|
+
from liger_kernel.ops.utils import (
|
|
9
|
+
calculate_settings,
|
|
10
|
+
compare_version,
|
|
11
|
+
ensure_contiguous,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
if compare_version("triton", operator.ge, "3.0.0"):
|
|
15
|
+
try:
|
|
16
|
+
# typical import path with dispatch available
|
|
17
|
+
from triton.language.extra.libdevice import rsqrt
|
|
18
|
+
except ModuleNotFoundError:
|
|
19
|
+
# for working with NGC containers
|
|
20
|
+
from triton.language.extra.cuda.libdevice import rsqrt
|
|
21
|
+
else:
|
|
22
|
+
from triton.language.math import rsqrt
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@triton.jit
|
|
26
|
+
def _layer_norm_forward_kernel(
|
|
27
|
+
Y_ptr, # pointer to output, shape (n_rows, n_cols)
|
|
28
|
+
Y_row_stride, # stride of each row in output
|
|
29
|
+
X_ptr, # pointer to input, shape (n_rows, n_cols)
|
|
30
|
+
X_row_stride, # stride of each row in input
|
|
31
|
+
W_ptr, # pointer to weights, shape (n_cols,)
|
|
32
|
+
W_row_stride, # stride of each row in weights
|
|
33
|
+
B_ptr, # pointer to bias, shape (n_cols,)
|
|
34
|
+
B_row_stride, # stride of each row in bias
|
|
35
|
+
Mean_ptr, # pointer to mean, shape (n_rows,)
|
|
36
|
+
Mean_row_stride, # stride of each row in mean
|
|
37
|
+
RSTD_ptr, # pointer to rstd, shape (n_rows,)
|
|
38
|
+
RSTD_row_stride, # stride of each row in rstd
|
|
39
|
+
n_cols,
|
|
40
|
+
eps,
|
|
41
|
+
BLOCK_SIZE: tl.constexpr,
|
|
42
|
+
):
|
|
43
|
+
"""
|
|
44
|
+
References:
|
|
45
|
+
https://arxiv.org/abs/1607.06450
|
|
46
|
+
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
|
47
|
+
"""
|
|
48
|
+
row_idx = tl.program_id(0)
|
|
49
|
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
50
|
+
mask = col_offsets < n_cols
|
|
51
|
+
|
|
52
|
+
Y_ptr += row_idx * Y_row_stride
|
|
53
|
+
X_ptr += row_idx * X_row_stride
|
|
54
|
+
Mean_ptr += row_idx * Mean_row_stride
|
|
55
|
+
RSTD_ptr += row_idx * RSTD_row_stride
|
|
56
|
+
|
|
57
|
+
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
|
|
58
|
+
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
|
|
59
|
+
B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
|
|
60
|
+
|
|
61
|
+
mean = tl.sum(X_row, axis=0) / n_cols
|
|
62
|
+
var = tl.sum((X_row - mean) * (X_row - mean), axis=0) / n_cols
|
|
63
|
+
rstd = rsqrt(var + eps)
|
|
64
|
+
|
|
65
|
+
tl.store(Mean_ptr, mean)
|
|
66
|
+
tl.store(RSTD_ptr, rstd)
|
|
67
|
+
|
|
68
|
+
Y_row = (X_row - mean) * rstd * W_row + B_row
|
|
69
|
+
|
|
70
|
+
tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@triton.jit
|
|
74
|
+
def _layer_norm_backward_kernel(
|
|
75
|
+
X_ptr, # pointer to input, shape (n_rows, n_cols)
|
|
76
|
+
W_ptr, # pointer to weights, shape (n_cols,)
|
|
77
|
+
Mean_ptr, # pointer to mean, shape (n_rows,)
|
|
78
|
+
RSTD_ptr, # pointer to rstd, shape (n_rows,)
|
|
79
|
+
DX_ptr, # pointer to input grad, shape (n_rows, n_cols)
|
|
80
|
+
DW_ptr, # pointer to weights grad, shape (n_cols,)
|
|
81
|
+
DB_ptr, # pointer to bias grad, shape (n_cols,)
|
|
82
|
+
DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
|
|
83
|
+
stride_x, # stride of each row in input
|
|
84
|
+
stride_dx, # stride of each row in input grad
|
|
85
|
+
stride_dw, # stride of each row in weights grad
|
|
86
|
+
stride_db, # stride of each row in bias grad
|
|
87
|
+
stride_dy, # stride of each row in output grad
|
|
88
|
+
n_rows,
|
|
89
|
+
n_cols,
|
|
90
|
+
rows_per_program: tl.constexpr,
|
|
91
|
+
BLOCK_SIZE: tl.constexpr,
|
|
92
|
+
dtype: tl.constexpr,
|
|
93
|
+
):
|
|
94
|
+
"""
|
|
95
|
+
References:
|
|
96
|
+
https://arxiv.org/abs/1607.06450
|
|
97
|
+
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
|
98
|
+
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
|
99
|
+
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
|
|
100
|
+
"""
|
|
101
|
+
row_block_id = tl.program_id(0)
|
|
102
|
+
row_start = row_block_id * rows_per_program
|
|
103
|
+
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
|
104
|
+
cols = tl.arange(0, BLOCK_SIZE)
|
|
105
|
+
mask = cols < n_cols
|
|
106
|
+
|
|
107
|
+
dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
|
108
|
+
db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
|
109
|
+
|
|
110
|
+
X_ptr += row_start * stride_x
|
|
111
|
+
Mean_ptr += row_start
|
|
112
|
+
RSTD_ptr += row_start
|
|
113
|
+
DX_ptr += row_start * stride_dx
|
|
114
|
+
DY_ptr += row_start * stride_dy
|
|
115
|
+
|
|
116
|
+
for _ in range(row_start, row_end):
|
|
117
|
+
x = tl.load(X_ptr + cols, mask=mask, other=0.0)
|
|
118
|
+
w = tl.load(W_ptr + cols, mask=mask, other=0.0)
|
|
119
|
+
dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)
|
|
120
|
+
mean = tl.load(Mean_ptr)
|
|
121
|
+
rstd = tl.load(RSTD_ptr)
|
|
122
|
+
|
|
123
|
+
x_hat = (x - mean) * rstd
|
|
124
|
+
wdy = w * dy
|
|
125
|
+
c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
|
|
126
|
+
c2 = tl.sum(wdy, axis=0) / n_cols
|
|
127
|
+
dx = (wdy - (x_hat * c1 + c2)) * rstd
|
|
128
|
+
tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)
|
|
129
|
+
|
|
130
|
+
dw_row += dy * x_hat
|
|
131
|
+
db_row += dy
|
|
132
|
+
|
|
133
|
+
X_ptr += stride_x
|
|
134
|
+
Mean_ptr += 1
|
|
135
|
+
RSTD_ptr += 1
|
|
136
|
+
DX_ptr += stride_dx
|
|
137
|
+
DY_ptr += stride_dy
|
|
138
|
+
|
|
139
|
+
tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)
|
|
140
|
+
tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def layer_norm_forward(X, W, B, eps):
|
|
144
|
+
shape = X.shape
|
|
145
|
+
dim = shape[-1]
|
|
146
|
+
X = X.view(-1, dim)
|
|
147
|
+
n_rows, n_cols = X.shape
|
|
148
|
+
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
|
149
|
+
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
|
150
|
+
Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
|
151
|
+
RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
|
152
|
+
assert (
|
|
153
|
+
X.shape[1] == W.shape[0]
|
|
154
|
+
), f"Incompatible hidden size dimension between input tensor with shape[1] = {X.shape[1]} and weight tensor with shape[0] = {W.shape[0]}"
|
|
155
|
+
|
|
156
|
+
_layer_norm_forward_kernel[(n_rows,)](
|
|
157
|
+
Y,
|
|
158
|
+
Y.stride(0),
|
|
159
|
+
X,
|
|
160
|
+
X.stride(0),
|
|
161
|
+
W,
|
|
162
|
+
W.stride(0),
|
|
163
|
+
B,
|
|
164
|
+
B.stride(0),
|
|
165
|
+
Mean,
|
|
166
|
+
Mean.stride(0),
|
|
167
|
+
RSTD,
|
|
168
|
+
RSTD.stride(0),
|
|
169
|
+
n_cols,
|
|
170
|
+
eps,
|
|
171
|
+
BLOCK_SIZE=BLOCK_SIZE,
|
|
172
|
+
num_warps=num_warps,
|
|
173
|
+
)
|
|
174
|
+
return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
178
|
+
shape = dY.shape
|
|
179
|
+
dim = shape[-1]
|
|
180
|
+
dY = dY.view(-1, dim)
|
|
181
|
+
n_rows, n_cols = dY.shape
|
|
182
|
+
|
|
183
|
+
DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
|
184
|
+
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
|
185
|
+
_DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
|
|
186
|
+
_DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
|
|
187
|
+
|
|
188
|
+
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
|
189
|
+
if n_cols > BLOCK_SIZE:
|
|
190
|
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
|
191
|
+
|
|
192
|
+
rows_per_program = math.ceil(n_rows / sm_count)
|
|
193
|
+
grid = (sm_count,)
|
|
194
|
+
triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
|
|
195
|
+
_layer_norm_backward_kernel[grid](
|
|
196
|
+
X,
|
|
197
|
+
W,
|
|
198
|
+
Mean,
|
|
199
|
+
RSTD,
|
|
200
|
+
DX,
|
|
201
|
+
_DW,
|
|
202
|
+
_DB,
|
|
203
|
+
dY,
|
|
204
|
+
X.stride(0),
|
|
205
|
+
DX.stride(0),
|
|
206
|
+
_DW.stride(0),
|
|
207
|
+
_DB.stride(0),
|
|
208
|
+
dY.stride(0),
|
|
209
|
+
n_rows,
|
|
210
|
+
n_cols,
|
|
211
|
+
rows_per_program,
|
|
212
|
+
BLOCK_SIZE=BLOCK_SIZE,
|
|
213
|
+
dtype=triton_dtype,
|
|
214
|
+
)
|
|
215
|
+
|
|
216
|
+
DW = _DW.sum(dim=0).to(W.dtype)
|
|
217
|
+
DB = _DB.sum(dim=0).to(W.dtype)
|
|
218
|
+
|
|
219
|
+
DX = DX.view(*shape)
|
|
220
|
+
return DX, DW, DB
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
class LigerLayerNormFunction(torch.autograd.Function):
|
|
224
|
+
@staticmethod
|
|
225
|
+
@ensure_contiguous
|
|
226
|
+
def forward(ctx, X, W, B, eps):
|
|
227
|
+
Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps)
|
|
228
|
+
ctx.save_for_backward(X, W, B, Mean, RSTD)
|
|
229
|
+
return Y
|
|
230
|
+
|
|
231
|
+
@staticmethod
|
|
232
|
+
@ensure_contiguous
|
|
233
|
+
def backward(ctx, dY):
|
|
234
|
+
X, W, B, Mean, RSTD = ctx.saved_tensors
|
|
235
|
+
DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)
|
|
236
|
+
return DX, DW, DB, None
|