liger_kernel-0.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/__init__.py +0 -0
- liger_kernel/ops/cross_entropy.py +277 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +161 -0
- liger_kernel/ops/geglu.py +129 -0
- liger_kernel/ops/rms_norm.py +167 -0
- liger_kernel/ops/rope.py +234 -0
- liger_kernel/ops/swiglu.py +113 -0
- liger_kernel/ops/utils.py +38 -0
- liger_kernel/transformers/__init__.py +5 -0
- liger_kernel/transformers/cross_entropy.py +11 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +15 -0
- liger_kernel/transformers/geglu.py +23 -0
- liger_kernel/transformers/model/__init__.py +0 -0
- liger_kernel/transformers/model/llama.py +143 -0
- liger_kernel/transformers/monkey_patch.py +103 -0
- liger_kernel/transformers/rms_norm.py +16 -0
- liger_kernel/transformers/rope.py +20 -0
- liger_kernel/transformers/swiglu.py +40 -0
- liger_kernel/triton/__init__.py +3 -0
- liger_kernel/triton/monkey_patch.py +44 -0
- liger_kernel-0.0.0.dist-info/METADATA +14 -0
- liger_kernel-0.0.0.dist-info/RECORD +24 -0
- liger_kernel-0.0.0.dist-info/WHEEL +5 -0
- liger_kernel-0.0.0.dist-info/top_level.txt +1 -0
liger_kernel/ops/rms_norm.py
ADDED
@@ -0,0 +1,167 @@
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
+
+
+@triton.jit
+def _rms_norm_forward(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    W_ptr,
+    W_row_stride,
+    r_ptr,
+    r_row_stride,
+    n_cols,
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    References:
+    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+    """
+    row_idx = tl.program_id(0)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    Y_ptr += row_idx * Y_row_stride
+    X_ptr += row_idx * X_row_stride
+    r_ptr += row_idx * r_row_stride
+
+    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
+    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+
+    row_var = tl.sum(X_row * X_row, axis=0) / n_cols
+    inv_var = tl.math.rsqrt(row_var + eps)
+
+    # trick: inv_var is tiny compared to X_row (just one scalar per row), so caching
+    # it saves four ops (*, sum, /, rsqrt) in the backward pass
+    tl.store(r_ptr, inv_var)
+
+    normed = X_row * inv_var
+
+    output = normed * W_row
+    tl.store(Y_ptr + col_offsets, output, mask=mask)
+
+
+@triton.jit
+def _rms_norm_backward(
+    dY_ptr,
+    dY_row_stride,
+    X_ptr,
+    X_row_stride,
+    W_ptr,
+    W_row_stride,
+    r_ptr,
+    r_row_stride,
+    dW_ptr,
+    dW_row_stride,
+    n_cols,
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    dx = (1 / rms) * [dy * w - (1 / N) * ((dy * w) dot (x / rms)) * (x / rms)]
+    dw = sum(dy * (x / rms)), summed over all rows
+    """
+    row_idx = tl.program_id(0)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    dY_ptr += row_idx * dY_row_stride
+    X_ptr += row_idx * X_row_stride
+    r_ptr += row_idx * r_row_stride
+    dW_ptr += row_idx * dW_row_stride
+
+    dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0)
+    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
+    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+
+    # get the inverse RMS cached by the forward pass
+    inv_var = tl.load(r_ptr)
+
+    normed = X_row * inv_var
+
+    dY_W = dY_row * W_row
+    dY_normed = dY_row * normed
+
+    rowsum_dY_normed = tl.sum(dY_W * normed, axis=0)
+    output = inv_var / n_cols * (n_cols * dY_W - normed * rowsum_dY_normed)
+    # dX is written in place over dY
+    tl.store(dY_ptr + col_offsets, output, mask=mask)
+
+    # this row's contribution to the gradient of W
+    tl.store(dW_ptr + col_offsets, dY_normed, mask=mask)
+
+
+class LigerRMSNormFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, X, W, eps):
+        shape = X.shape
+        dim = shape[-1]
+        X = X.view(-1, dim)
+        n_rows, n_cols = X.shape
+        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+        Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+        r = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+
+        # Check constraints.
+        assert (
+            X.shape[1] == W.shape[0]
+        ), "Incompatible hidden size dimension between X.shape[1] and W.shape[0]"
+
+        _rms_norm_forward[(n_rows,)](
+            Y,
+            Y.stride(0),
+            X,
+            X.stride(0),
+            W,
+            W.stride(0),
+            r,
+            r.stride(0),
+            n_cols,
+            eps,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+        )
+        ctx.eps = eps
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        ctx.num_warps = num_warps
+
+        ctx.save_for_backward(X, W, r)
+        return Y.view(*shape)
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, dY):
+        shape = dY.shape
+        dim = shape[-1]
+        dY = dY.view(-1, dim)
+        X, W, r = ctx.saved_tensors
+        n_rows, n_cols = dY.shape
+        dW = torch.zeros_like(X)
+
+        _rms_norm_backward[(n_rows,)](
+            dY,
+            dY.stride(0),
+            X,
+            X.stride(0),
+            W,
+            W.stride(0),
+            r,
+            r.stride(0),
+            dW,
+            dW.stride(0),
+            n_cols,
+            ctx.eps,
+            BLOCK_SIZE=ctx.BLOCK_SIZE,
+            num_warps=ctx.num_warps,
+        )
+        # the kernel stored dX in place over dY
+        dX = dY.view(*shape)
+        dW = torch.sum(dW, dim=0)
+        return dX, dW, None
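A minimal sketch of how this autograd function can be exercised end to end (assumes a CUDA device with Triton available; the shapes and tolerance are illustrative, not from the package):

import torch

from liger_kernel.ops.rms_norm import LigerRMSNormFunction

# activations shaped like a (batch, seq_len, hidden_size) tensor, plus a learned scale
X = torch.randn(2, 8, 64, device="cuda", requires_grad=True)
W = torch.randn(64, device="cuda", requires_grad=True)
eps = 1e-6

Y = LigerRMSNormFunction.apply(X, W, eps)

# reference RMSNorm in plain PyTorch, matching the kernel's math above
Y_ref = X * torch.rsqrt(X.pow(2).mean(-1, keepdim=True) + eps) * W
print(torch.allclose(Y, Y_ref, atol=1e-5))

Y.sum().backward()  # exercises _rms_norm_backward; populates X.grad and W.grad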
liger_kernel/ops/rope.py
ADDED
@@ -0,0 +1,234 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _triton_rope(
+    q_ptr,
+    q_row_stride,
+    k_ptr,
+    k_row_stride,
+    cos,
+    cos_row_stride,
+    sin,
+    sin_row_stride,
+    bs: tl.constexpr,
+    sl: tl.constexpr,
+    n_qh: tl.constexpr,
+    n_kh: tl.constexpr,
+    hd: tl.constexpr,
+    pad_n_qh: tl.constexpr,
+    pad_n_kh: tl.constexpr,
+    pad_hd: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BACKWARD_PASS: tl.constexpr = False,
+):
+    # q size: (bsz, seq_len, num_q_heads, head_dim)
+    # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1)
+    # k size: (bsz, seq_len, num_kv_heads, head_dim)
+    # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
+
+    # cos size: (1, seq_len, head_dim)
+    # cos stride: (seq_len * head_dim, head_dim, 1)
+    pid = tl.program_id(0)
+
+    # locate start address
+    q_ptr = q_ptr + pid * q_row_stride
+    k_ptr = k_ptr + pid * k_row_stride
+
+    # ####################################################################
+    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for the token
+    # position m of this program instance
+    # ####################################################################
+
+    # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
+    # effectively represents a 2D grid of size [bsz, seq_len] with the seq_len dimension
+    # being the fastest changing. Thus we can simply do pid // sl to get the batch index
+    # and pid % sl to get the sequence index.
+    # 2. We only need the left half of the cos and sin matrices because the right half
+    # is just a clone of the left half.
+    cos_row_idx = pid % sl
+    cos = cos + cos_row_idx * cos_row_stride
+    sin = sin + cos_row_idx * sin_row_stride
+    cos_offsets = tl.arange(0, pad_hd // 2)
+    cos_mask = cos_offsets < hd // 2
+    cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
+    sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0)
+
+    # ####################################################################
+    # Load the left and right halves of q and k for the current
+    # program instance (i.e. for the current token) separately
+    # ####################################################################
+    # left half of the head
+    first_half_q_offsets = (
+        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    )
+    first_half_k_offsets = (
+        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    )
+    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
+        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
+    )
+    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
+        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
+    )
+    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
+        sin_row.dtype
+    )
+    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
+        sin_row.dtype
+    )
+
+    # right half of the head
+    second_half_q_offsets = first_half_q_offsets + (hd // 2)
+    second_half_k_offsets = first_half_k_offsets + (hd // 2)
+    second_q_mask = first_q_mask
+    second_k_mask = first_k_mask
+    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
+        sin_row.dtype
+    )
+    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
+        sin_row.dtype
+    )
+
+    if not BACKWARD_PASS:
+        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
+        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
+        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
+        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
+        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+
+        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
+        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
+        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
+        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+    else:
+        # with some math, we can get:
+        # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
+        new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
+        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
+        new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
+        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+
+        new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
+        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
+        new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
+        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+
+
+class LigerRopeFunction(torch.autograd.Function):
+    """
+    Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that
+    this implements the HuggingFace Llama & Mistral version, whose rotation matrix is slightly
+    different from the one in the original RoPE paper.
+
+    Please find the corresponding HuggingFace implementation here:
+    https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184
+
+    For more details about the rotation matrix used here, please refer to:
+    https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2
+    """
+
+    @staticmethod
+    def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+        """
+        q size: (bsz, n_q_head, seq_len, head_dim)
+        k size: (bsz, n_kv_head, seq_len, head_dim)
+        cos size: (1, seq_len, head_dim)
+        sin size: (1, seq_len, head_dim)
+        """
+
+        # transpose back to the physical shape because Triton looks at the physical storage
+        # note: q and k are non-contiguous before the transformation and become contiguous after the transpose
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+
+        batch_size, seq_len, n_q_head, head_dim = q.shape
+        n_kv_head = k.shape[2]
+        pad_hd = triton.next_power_of_2(head_dim)
+        pad_n_q_head = triton.next_power_of_2(n_q_head)
+        pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+        BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
+
+        n_row = batch_size * seq_len
+
+        # ensure tensors passed into the kernel are contiguous; this is a no-op if they already are
+        q = q.contiguous()
+        k = k.contiguous()
+        cos = cos.contiguous()
+        sin = sin.contiguous()
+
+        _triton_rope[(n_row,)](
+            q,
+            q.stride(1),
+            k,
+            k.stride(1),
+            cos,
+            cos.stride(-2),
+            sin,
+            sin.stride(-2),
+            batch_size,
+            seq_len,
+            n_q_head,
+            n_kv_head,
+            head_dim,
+            pad_n_q_head,
+            pad_n_kv_head,
+            pad_hd,
+            BLOCK_SIZE=BLOCK_SIZE,
+            BACKWARD_PASS=False,
+        )
+
+        ctx.save_for_backward(cos, sin)
+        return q.transpose(1, 2), k.transpose(1, 2)
+
+    @staticmethod
+    def backward(ctx, dq, dk):
+        """
+        dq size: (bsz, n_q_head, seq_len, head_dim)
+        dk size: (bsz, n_kv_head, seq_len, head_dim)
+        cos size: (1, seq_len, head_dim)
+        sin size: (1, seq_len, head_dim)
+        """
+
+        cos, sin = ctx.saved_tensors
+
+        dq = dq.transpose(1, 2)
+        dk = dk.transpose(1, 2)
+
+        batch_size, seq_len, n_q_head, head_dim = dq.shape
+        n_kv_head = dk.shape[2]
+        pad_hd = triton.next_power_of_2(head_dim)
+        pad_n_q_head = triton.next_power_of_2(n_q_head)
+        pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+        BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
+
+        n_row = batch_size * seq_len
+
+        # ensure dq and dk are contiguous
+        dq = dq.contiguous()
+        dk = dk.contiguous()
+
+        # backward is similar to forward except for swapping a few ops
+        _triton_rope[(n_row,)](
+            dq,
+            dq.stride(1),
+            dk,
+            dk.stride(1),
+            cos,
+            cos.stride(-2),
+            sin,
+            sin.stride(-2),
+            batch_size,
+            seq_len,
+            n_q_head,
+            n_kv_head,
+            head_dim,
+            pad_n_q_head,
+            pad_n_kv_head,
+            pad_hd,
+            BLOCK_SIZE=BLOCK_SIZE,
+            BACKWARD_PASS=True,
+        )
+
+        return dq.transpose(1, 2), dk.transpose(1, 2), None, None, None, None
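To see that the forward kernel matches the HuggingFace-style rotation it references, here is a sketch that builds cos/sin the way the docstring describes and compares against the rotate_half formulation (assumes a CUDA device; the base of 10000 and the shapes are illustrative):

import torch

from liger_kernel.ops.rope import LigerRopeFunction

def rotate_half(x):
    # (-x2, x1) where x1, x2 are the two halves of the head dimension
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

bsz, n_q_head, n_kv_head, seq_len, head_dim = 1, 4, 2, 16, 32
q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda")
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda")

inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device="cuda").float() / head_dim))
freqs = torch.outer(torch.arange(seq_len, device="cuda").float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)  # right half clones the left half
cos, sin = emb.cos()[None], emb.sin()[None]  # (1, seq_len, head_dim)

q_out, k_out = LigerRopeFunction.apply(q, k, cos, sin)

# reference rotation, broadcasting cos/sin over the head axis
q_ref = q * cos.unsqueeze(1) + rotate_half(q) * sin.unsqueeze(1)
k_ref = k * cos.unsqueeze(1) + rotate_half(k) * sin.unsqueeze(1)
print(torch.allclose(q_out, q_ref, atol=1e-5), torch.allclose(k_out, k_ref, atol=1e-5))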
liger_kernel/ops/swiglu.py
ADDED
@@ -0,0 +1,113 @@
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
+
+
+@triton.jit
+def silu(x):
+    return x * tl.sigmoid(x)
+
+
+@triton.jit
+def _swiglu_forward_kernel(
+    a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
+):
+    program_id = tl.program_id(0)
+
+    # locate start index
+    a += program_id * stride
+    b += program_id * stride
+    c += program_id * stride
+
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    # sigmoid requires float32
+    a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
+    b_row = tl.load(b + col_offsets, mask=mask, other=0)
+    c_row = silu(a_row) * b_row
+    tl.store(c + col_offsets, c_row, mask=mask)
+
+
+@triton.jit
+def _swiglu_backward_kernel(
+    dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
+):
+    program_id = tl.program_id(0)
+
+    # locate start index
+    dc += program_id * stride
+    a += program_id * stride
+    b += program_id * stride
+
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    dc_row = tl.load(dc + col_offsets, mask=mask, other=0)
+    # sigmoid requires float32
+    a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)
+    b_row = tl.load(b + col_offsets, mask=mask, other=0)
+
+    # recompute silu(a) instead of caching it, to save memory
+    sig_a = tl.sigmoid(a_row)
+    silu_a = a_row * sig_a
+    db_row = dc_row * silu_a
+    da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row
+
+    # da and db are written in place over a and b
+    tl.store(a + col_offsets, da_row, mask=mask)
+    tl.store(b + col_offsets, db_row, mask=mask)
+
+
+class LigerSiLUMulFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, a, b):
+        ori_shape = a.shape
+
+        n_cols = ori_shape[-1]
+        a = a.view(-1, n_cols)
+        b = b.view(-1, n_cols)
+        c = torch.zeros_like(a)
+        n_rows = a.shape[0]
+
+        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+        _swiglu_forward_kernel[(n_rows,)](
+            a,
+            b,
+            c,
+            c.stride(-2),
+            n_cols=n_cols,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+        )
+
+        ctx.save_for_backward(a, b)
+
+        return c.view(*ori_shape)
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, dc):
+        ori_shape = dc.shape
+        n_cols = ori_shape[-1]
+        dc = dc.view(-1, n_cols)
+        a, b = ctx.saved_tensors
+        n_rows = dc.shape[0]
+
+        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+        _swiglu_backward_kernel[(n_rows,)](
+            dc,
+            a,
+            b,
+            dc.stride(-2),
+            n_cols=n_cols,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+        )
+
+        return a.view(*ori_shape), b.view(*ori_shape)
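The backward kernel relies on the SiLU derivative d/da [a * σ(a)] = σ(a) + a * σ(a) * (1 - σ(a)) = silu(a) * (1 - σ(a)) + σ(a), recomputing σ(a) rather than caching activations. A quick sketch checking the fused op against torch.nn.functional.silu (assumes a CUDA device; shapes are illustrative):

import torch
import torch.nn.functional as F

from liger_kernel.ops.swiglu import LigerSiLUMulFunction

a = torch.randn(4, 128, device="cuda", requires_grad=True)
b = torch.randn(4, 128, device="cuda", requires_grad=True)

c = LigerSiLUMulFunction.apply(a, b)
print(torch.allclose(c, F.silu(a) * b, atol=1e-5))

# gradients flow through the in-place backward kernel
c.sum().backward()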
liger_kernel/ops/utils.py
ADDED
@@ -0,0 +1,38 @@
+import functools
+
+import torch
+import triton
+
+
+def ensure_contiguous(fn):
+    @functools.wraps(fn)
+    def wrapper(ctx, *args, **kwargs):
+        def maybe_to_contiguous(x):
+            return x.contiguous() if isinstance(x, torch.Tensor) else x
+
+        args = [maybe_to_contiguous(arg) for arg in args]
+        kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()}
+        return fn(ctx, *args, **kwargs)
+
+    return wrapper
+
+
+def calculate_settings(n):
+    # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
+
+    MAX_FUSED_SIZE = 65536
+    BLOCK_SIZE = triton.next_power_of_2(n)
+    if BLOCK_SIZE > MAX_FUSED_SIZE:
+        raise RuntimeError(
+            f"Cannot launch Triton kernel since n = {n} exceeds "
+            f"the recommended Triton block size = {MAX_FUSED_SIZE}."
+        )
+
+    num_warps = 4
+    if BLOCK_SIZE >= 32768:
+        num_warps = 32
+    elif BLOCK_SIZE >= 8192:
+        num_warps = 16
+    elif BLOCK_SIZE >= 2048:
+        num_warps = 8
+    return BLOCK_SIZE, num_warps
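The thresholds above translate directly into block/warp choices; for example (the printed pairs are (BLOCK_SIZE, num_warps) and follow from the code):

from liger_kernel.ops.utils import calculate_settings

for n in (1024, 4096, 8192, 50000):
    print(n, calculate_settings(n))
# 1024  -> (1024, 4)
# 4096  -> (4096, 8)
# 8192  -> (8192, 16)
# 50000 -> (65536, 32)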
liger_kernel/transformers/cross_entropy.py
ADDED
@@ -0,0 +1,11 @@
+from torch.nn import CrossEntropyLoss
+
+from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
+
+
+class LigerCrossEntropyLoss(CrossEntropyLoss):
+    def __init__(self, *args, **kwargs):
+        super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)
+
+    def forward(self, _input, target):
+        return LigerCrossEntropyFunction.apply(_input, target, self.ignore_index)
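Because it subclasses torch.nn.CrossEntropyLoss, this is a drop-in replacement at call sites, although in this version only ignore_index is forwarded to the Triton function; other constructor options are accepted but have no effect. A usage sketch (assumes a CUDA device; the sizes are illustrative):

import torch

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

loss_fn = LigerCrossEntropyLoss(ignore_index=-100)
logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
labels = torch.randint(0, 32000, (8,), device="cuda")

loss = loss_fn(logits, labels)
loss.backward()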
liger_kernel/transformers/fused_linear_cross_entropy.py
ADDED
@@ -0,0 +1,15 @@
+from torch.nn import CrossEntropyLoss
+
+from liger_kernel.ops.fused_linear_cross_entropy import (
+    LigerFusedLinearCrossEntropyFunction,
+)
+
+
+class LigerFusedLinearCrossEntropyLoss(CrossEntropyLoss):
+    def __init__(self, *args, **kwargs):
+        super(LigerFusedLinearCrossEntropyLoss, self).__init__(*args, **kwargs)
+
+    def forward(self, lin_weight, _input, target):
+        return LigerFusedLinearCrossEntropyFunction.apply(
+            _input, lin_weight, target, self.ignore_index
+        )
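Note the argument order: the module takes the lm-head weight first, then the hidden states and targets, and swaps them back for the underlying function, whose fusion is designed to avoid materializing the full logits tensor. A usage sketch (assumes a CUDA device; the weight layout (vocab_size, hidden_size) mirrors nn.Linear, and the sizes are illustrative):

import torch

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

hidden_size, vocab_size = 512, 32000
lm_head_weight = torch.randn(vocab_size, hidden_size, device="cuda", requires_grad=True)
hidden_states = torch.randn(16, hidden_size, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab_size, (16,), device="cuda")

loss_fn = LigerFusedLinearCrossEntropyLoss(ignore_index=-100)
loss = loss_fn(lm_head_weight, hidden_states, labels)
loss.backward()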
liger_kernel/transformers/geglu.py
ADDED
@@ -0,0 +1,23 @@
+import torch.nn as nn
+
+from liger_kernel.ops.geglu import LigerGELUMulFunction
+
+
+class LigerGEGLUMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        # TODO: support exact GELU
+        if config.hidden_act not in ["gelu_pytorch_tanh"]:
+            raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+    def forward(self, x):
+        return self.down_proj(
+            LigerGELUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
+        )
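The module follows the HuggingFace gated-MLP layout (gate/up/down projections) and only accepts the tanh-approximated GELU. Any config object exposing the three attributes read above works; a sketch with a stand-in config (SimpleNamespace here is illustrative, and a CUDA device is assumed):

import torch
from types import SimpleNamespace

from liger_kernel.transformers.geglu import LigerGEGLUMLP

config = SimpleNamespace(
    hidden_size=64,
    intermediate_size=256,
    hidden_act="gelu_pytorch_tanh",
)
mlp = LigerGEGLUMLP(config).to("cuda")
x = torch.randn(2, 8, 64, device="cuda")
out = mlp(x)  # shape: (2, 8, 64)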