liger-kernel 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/env_report.py +2 -0
- liger_kernel/ops/cross_entropy.py +143 -30
- liger_kernel/ops/fused_linear_cross_entropy.py +19 -2
- liger_kernel/ops/group_norm.py +322 -0
- liger_kernel/ops/rms_norm.py +27 -6
- liger_kernel/transformers/cross_entropy.py +44 -12
- liger_kernel/transformers/functional.py +34 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +31 -4
- liger_kernel/transformers/group_norm.py +56 -0
- liger_kernel/transformers/model/gemma2.py +277 -0
- liger_kernel/transformers/monkey_patch.py +101 -62
- liger_kernel/transformers/rms_norm.py +11 -3
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.1.dist-info}/METADATA +5 -3
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.1.dist-info}/RECORD +18 -15
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.1.dist-info}/WHEEL +1 -1
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.1.dist-info}/LICENSE +0 -0
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.1.dist-info}/NOTICE +0 -0
- {liger_kernel-0.4.0.dist-info → liger_kernel-0.4.1.dist-info}/top_level.txt +0 -0
liger_kernel/ops/group_norm.py ADDED

```diff
@@ -0,0 +1,322 @@
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import compare_version, ensure_contiguous
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import rsqrt
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import rsqrt
+else:
+    from triton.language.math import rsqrt
+
+MAX_FUSED_SIZE = 65536
+
+
+@triton.jit
+def _group_norm_forward_kernel(
+    Y_ptr,  # pointer to output, shape (n_rows, n_groups, hidden_size)
+    Y_row_stride,  # stride of each row in output
+    Y_col_stride,  # stride of each column in output
+    X_ptr,  # pointer to input, shape (n_rows, n_groups, hidden_size)
+    X_row_stride,  # stride of each row in input
+    X_col_stride,  # stride of each column in input
+    Mean_ptr,  # pointer to mean, shape (n_rows, n_groups)
+    Mean_row_stride,  # stride of each row in mean
+    Mean_col_stride,  # stride of each column in mean
+    RSTD_ptr,  # pointer to rstd, shape (n_rows, n_groups)
+    RSTD_row_stride,  # stride of each row in rstd
+    RSTD_col_stride,  # stride of each column in rstd
+    W_ptr,  # pointer to W
+    B_ptr,  # pointer to B
+    hidden_size,  # hidden size of X
+    channels_per_group,  # the number of channels per group
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    References:
+    https://nn.labml.ai/normalization/group_norm/index.html
+    """
+    batch_idx = tl.program_id(0)
+    group_idx = tl.program_id(1)
+
+    X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride
+    Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride
+
+    block_range = tl.arange(0, BLOCK_SIZE)
+
+    # Compute mean and variance using the online algorithm
+    s = 0.0
+    squared_sum = 0.0
+    for i in tl.range(0, hidden_size, BLOCK_SIZE):
+        hidden_size_offsets = i + block_range
+        mask = hidden_size_offsets < hidden_size
+        X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0)
+        s += tl.sum(X)
+        # X**2
+        squared_sum += tl.sum(X * X)
+
+    m = s / hidden_size
+
+    # variance = E[X**2] - E[X]**2
+    variance = (squared_sum / hidden_size) - (m * m)
+
+    # 1/std
+    rstd = rsqrt(variance + eps)
+
+    # Normalize
+    hidden_size_per_channel = hidden_size // channels_per_group
+    for channel_idx in tl.range(
+        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
+    ):
+        W = tl.load(W_ptr + channel_idx)
+        B = tl.load(B_ptr + channel_idx)
+        for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
+            hidden_size_offsets = i + block_range
+            mask = hidden_size_offsets < hidden_size_per_channel
+            X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
+            Y = (X - m) * rstd * W + B
+            tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
+
+        X_ptr += hidden_size_per_channel
+        Y_ptr += hidden_size_per_channel
+
+    tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
+    tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
+
+
+@triton.jit
+def _group_norm_backward_kernel(
+    X_ptr,  # pointer to input, shape (n_rows, n_channels, hidden_size)
+    X_row_stride,  # stride of each row in input
+    X_col_stride,  # stride of each column in input
+    W_ptr,  # pointer to weights, shape (n_channels)
+    Mean_ptr,  # pointer to mean, shape (n_rows, n_groups)
+    Mean_ptr_row_stride,  # stride of each row in mean
+    Mean_ptr_col_stride,  # stride of each column in mean
+    RSTD_ptr,  # pointer to rstd, shape (n_rows, n_groups)
+    DX_ptr,  # pointer to input grad, shape (n_rows, n_groups, hidden_size)
+    DW_ptr,  # pointer to weights grad, shape (n_channels)
+    DB_ptr,  # pointer to bias grad, shape (n_channels)
+    UPSTREAM_ptr,  # pointer to output grad, shape (n_rows, n_channels, hidden_size)
+    hidden_size: tl.constexpr,  # hidden size
+    channels_per_group: tl.constexpr,  # number of channels per group
+    BLOCK_SIZE: tl.constexpr,
+    dtype: tl.constexpr,
+):
+    """
+    References:
+    https://nn.labml.ai/normalization/group_norm/index.html
+    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
+
+    The backprop equations are the same for group_norm and layer_norm;
+    the only difference here is that we load the Mean and Rstd corresponding to the
+    group we're computing gradients for, and the mean and rstd are computed over n channels,
+    so the total number of elements we compute the mean over is num_channels_per_group * hidden_size.
+
+    We also need to load the Weights corresponding to the current channel to compute the gradients.
+    """
+    batch_idx = tl.program_id(0)
+    group_idx = tl.program_id(1)
+
+    # Move the pointers to the correct batch
+    X_ptr += batch_idx * X_row_stride
+    DX_ptr += batch_idx * X_row_stride
+    UPSTREAM_ptr += batch_idx * X_row_stride
+
+    # Mean and rstd are the same shape so have the same strides
+    mean = tl.load(
+        Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
+    )
+    rstd = tl.load(
+        RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride
+    )
+
+    c1 = 0.0
+    c2 = 0.0
+    block_range = tl.arange(0, BLOCK_SIZE)
+
+    # We need to compute the sum terms of the backprop equations across all channels in the group
+    for channel_idx in range(
+        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
+    ):
+        dW = 0.0
+        dB = 0.0
+        # Move the pointers to the correct channel
+        W = tl.load(W_ptr + channel_idx)
+        for i in tl.range(0, hidden_size, BLOCK_SIZE):
+            hidden_size_offsets = i + block_range
+            mask = hidden_size_offsets < hidden_size
+            X = tl.load(
+                X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+            UPSTREAM_grad = tl.load(
+                UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+
+            x_hat = (X - mean) * rstd
+            dW += tl.sum(UPSTREAM_grad * x_hat)
+            dB += tl.sum(UPSTREAM_grad)
+
+            wdy = W * UPSTREAM_grad
+            c1 += tl.sum(x_hat * wdy)
+            c2 += tl.sum(wdy)
+
+        # Need to ensure additions to the same channel are atomic
+        tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype))
+        tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype))
+
+    N = hidden_size * channels_per_group
+    c1 = c1 / N
+    c2 = c2 / N
+
+    for channel_idx in tl.range(
+        group_idx * channels_per_group, (group_idx + 1) * channels_per_group
+    ):
+        # Move the pointers to the correct channel
+        W = tl.load(W_ptr + channel_idx)
+        for i in range(0, hidden_size, BLOCK_SIZE):
+            hidden_size_offsets = i + block_range
+            mask = hidden_size_offsets < hidden_size
+            X = tl.load(
+                X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+            UPSTREAM_grad = tl.load(
+                UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+
+            x_hat = (X - mean) * rstd
+            wdy = W * UPSTREAM_grad
+            dx = (wdy - (x_hat * c1 + c2)) * rstd
+            tl.store(
+                DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask
+            )
+
+
+def group_norm_forward(X, num_channels, num_groups, W, B, eps):
+    shape = X.shape
+    batch_size = shape[0]
+    channels_per_group = num_channels // num_groups
+    # Reshape X so that the mean and std are computed across the groups
+    X = X.view(batch_size, num_groups, -1).contiguous()
+    hidden_size = X.shape[-1]
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
+    Y = torch.empty(
+        (batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device
+    )
+    Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
+    RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
+
+    _group_norm_forward_kernel[(batch_size, num_groups)](
+        Y,
+        Y.stride(0),
+        Y.stride(1),
+        X,
+        X.stride(0),
+        X.stride(1),
+        Mean,
+        Mean.stride(0),
+        Mean.stride(1),
+        RSTD,
+        RSTD.stride(0),
+        RSTD.stride(1),
+        W,
+        B,
+        hidden_size,
+        channels_per_group,
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    # Return tensors in the original shape
+    return Y.view(*shape), X.view(*shape), Mean, RSTD, BLOCK_SIZE
+
+
+def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
+    shape = dY.shape
+    batch_size = shape[0]
+    hidden_size = dY.shape[-1]
+    channels_per_group = num_channels // num_groups
+    dY = dY.view(batch_size, num_groups, -1)
+    DX = torch.empty(
+        (batch_size, num_groups, hidden_size * channels_per_group),
+        dtype=X.dtype,
+        device=X.device,
+    )
+    DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device)
+    DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device)
+    triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
+
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
+    _group_norm_backward_kernel[(batch_size, num_groups)](
+        X,
+        X.stride(0),
+        X.stride(1),
+        W,
+        Mean,
+        Mean.stride(0),
+        Mean.stride(1),
+        RSTD,
+        DX,
+        DW,
+        DB,
+        dY,
+        hidden_size,
+        channels_per_group,
+        BLOCK_SIZE=BLOCK_SIZE,
+        dtype=triton_dtype,
+    )
+
+    # Return tensors in the original shape
+    return DX.view(*shape), DW, DB
+
+
+class LigerGroupNormFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(
+        ctx,
+        X,
+        affine_scaling_weight,
+        affine_shifting_bias,
+        num_channels,
+        num_groups,
+        eps,
+    ):
+        Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward(
+            X,
+            num_channels,
+            num_groups,
+            affine_scaling_weight,
+            affine_shifting_bias,
+            eps,
+        )
+        ctx.num_channels = num_channels
+        ctx.num_groups = num_groups
+        ctx.save_for_backward(
+            X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD
+        )
+        return Y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, dY):
+        X, W, B, Mean, RSTD = ctx.saved_tensors
+        DX, DW, DB = group_norm_backward(
+            dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups
+        )
+        return DX, DW, DB, None, None, None
```
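The new op exposes group normalization as a single autograd Function: `group_norm_forward` reshapes the input to (batch, num_groups, -1), the forward kernel computes one mean/rstd pair per (batch, group) and applies the per-channel affine, and the backward kernel accumulates dW/dB with atomic adds. A minimal smoke test of the public entry point, assuming a CUDA device with Triton available and liger-kernel 0.4.1 installed (shapes and tolerances are illustrative, not part of the package):

```python
import torch
from liger_kernel.ops.group_norm import LigerGroupNormFunction

batch, channels, groups, hidden = 4, 8, 2, 256
x = torch.randn(batch, channels, hidden, device="cuda", requires_grad=True)
w = torch.ones(channels, device="cuda", requires_grad=True)   # affine scale per channel
b = torch.zeros(channels, device="cuda", requires_grad=True)  # affine shift per channel

# forward(ctx, X, affine_scaling_weight, affine_shifting_bias, num_channels, num_groups, eps)
y = LigerGroupNormFunction.apply(x, w, b, channels, groups, 1e-6)

# Reference: PyTorch's group_norm over the same grouping
y_ref = torch.nn.functional.group_norm(x, groups, weight=w, bias=b, eps=1e-6)
print(torch.allclose(y, y_ref, atol=1e-4, rtol=1e-4))
```

Since the kernel computes the variance as E[X**2] - E[X]**2, comparisons in lower precision or with very large hidden sizes may need looser tolerances.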
liger_kernel/ops/rms_norm.py CHANGED

```diff
@@ -116,6 +116,8 @@ def _rms_norm_forward_kernel(
 def _rms_norm_backward_kernel(
     dY_ptr,
     dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
     X_ptr,
     X_row_stride,
     X_dtype: tl.constexpr,
@@ -146,6 +148,8 @@ def _rms_norm_backward_kernel(
     dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
     dY_ptr += row_start * dY_row_stride
+    dX_ptr += row_start * dX_row_stride
+
     X_ptr += row_start * X_row_stride
     RSTD_ptr += row_start
 
@@ -184,9 +188,10 @@ def _rms_norm_backward_kernel(
         # here X_row is already in fp32 (see previous if block)
         dW_row += dY_row * (X_row * rstd_row)
 
-        tl.store(
+        tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
 
         dY_ptr += dY_row_stride
+        dX_ptr += dX_row_stride
         X_ptr += X_row_stride
         RSTD_ptr += RSTD_row_stride
 
@@ -251,7 +256,9 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
 
 
-def rms_norm_backward(
+def rms_norm_backward(
+    dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place
+):
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
@@ -265,10 +272,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
     rows_per_program = math.ceil(n_rows / sm_count)
     grid = (sm_count,)
-
+
+    if in_place is True:
+        dX = dY
+    else:
+        dX = torch.zeros_like(dY)
+
     _rms_norm_backward_kernel[grid](
         dY,
         dY.stride(0),
+        dX,
+        dX.stride(0),
         X,
         X.stride(0),
         torch_to_triton_dtype[X.dtype],
@@ -286,8 +300,9 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
     )
-    dX =
+    dX = dX.view(*shape)
     dW = _dW.sum(dim=0).to(W.dtype)
+
     return dX, dW
 
 
@@ -307,11 +322,15 @@ class LigerRMSNormFunction(torch.autograd.Function):
     - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
     - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
    - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
+
+    The `in_place` option controls whether dY is modified in place to store dX. It defaults to `True` to save memory, but in certain cases it can produce incorrect results.
+    For example, gemma2 uses two rmsnorms sequentially with a residual in between. The residual part needs dY, so it cannot be modified in-place.
+    Therefore, for the patching of RMSNorm in gemma2, we set `in_place` to `False`.
     """
 
     @staticmethod
     @ensure_contiguous
-    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"):
+    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
         """
         X: (B, T, H) or (BxT, H)
         W: (H,)
@@ -321,6 +340,7 @@ class LigerRMSNormFunction(torch.autograd.Function):
         )
         ctx.offset = offset
         ctx.casting_mode = casting_mode
+        ctx.in_place = in_place
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
         ctx.save_for_backward(X, W, RSTD)
@@ -342,5 +362,6 @@ class LigerRMSNormFunction(torch.autograd.Function):
             ctx.casting_mode,
             ctx.BLOCK_SIZE,
             ctx.num_warps,
+            ctx.in_place,
         )
-        return dX, dW, None, None, None
+        return dX, dW, None, None, None, None
```
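The change threads a separate dX buffer through the backward kernel so the gradient no longer has to overwrite the incoming dY. A minimal sketch of calling the Function with the new flag, assuming a CUDA device with Triton available (shapes and eps are illustrative):

```python
import torch
from liger_kernel.ops.rms_norm import LigerRMSNormFunction

x = torch.randn(2, 16, 64, device="cuda", dtype=torch.float32, requires_grad=True)
w = torch.ones(64, device="cuda", requires_grad=True)

# forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True)
# in_place=False allocates a fresh dX instead of reusing the dY buffer,
# which is what the gemma2 patch relies on (dY is still needed by the residual branch).
y = LigerRMSNormFunction.apply(x, w, 1e-6, 0.0, "llama", False)
y.sum().backward()
print(x.grad.shape, w.grad.shape)
```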
liger_kernel/transformers/cross_entropy.py CHANGED

```diff
@@ -1,21 +1,53 @@
-from
+from typing import Optional
+
+import torch
 
 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
 
 
-class LigerCrossEntropyLoss(
-    def __init__(
-
-
-
-
-
+class LigerCrossEntropyLoss(torch.nn.Module):
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
+        return_z_loss: bool = False,
+    ):
+        super().__init__()
+        assert (label_smoothing >= 0) and (
+            label_smoothing <= 1
+        ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+        assert (label_smoothing >= 0) and (
+            label_smoothing <= 1
+        ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+        assert reduction in {
             "mean",
             "sum",
             "none",
-        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {
+        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
+        assert (
+            softcap is None or softcap > 0
+        ), f"softcap must be greater than 0.0 or None. Got: {softcap}"
+        self.ignore_index = ignore_index
+        self.lse_square_scale = lse_square_scale
+        self.label_smoothing = label_smoothing
+        self.reduction = reduction
+        self.softcap = softcap
+        self.return_z_loss = return_z_loss
 
-    def forward(self, _input, target):
-
-        _input,
+    def forward(self, _input: torch.Tensor, target: torch.Tensor):
+        loss, z_loss = LigerCrossEntropyFunction.apply(
+            _input,
+            target,
+            self.ignore_index,
+            self.lse_square_scale,
+            self.label_smoothing,
+            self.reduction,
+            self.softcap,
+            self.return_z_loss,
         )
+        if not self.return_z_loss:
+            return loss
+        return loss, z_loss
```
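Now that the module subclasses `torch.nn.Module` directly, it forwards all options to `LigerCrossEntropyFunction` and optionally returns the z-loss alongside the loss. A usage sketch, assuming a CUDA device with Triton available (vocabulary size and batch are illustrative):

```python
import torch
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

ce = LigerCrossEntropyLoss(label_smoothing=0.1, return_z_loss=True)

logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
target = torch.randint(0, 32000, (8,), device="cuda")

loss, z_loss = ce(logits, target)  # only `loss` is returned when return_z_loss=False
loss.backward()
```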
liger_kernel/transformers/functional.py CHANGED

```diff
@@ -1,9 +1,12 @@
+from typing import Optional
+
 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
 from liger_kernel.ops.fused_linear_cross_entropy import (
     LigerFusedLinearCrossEntropyFunction,
 )
 from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
 from liger_kernel.ops.geglu import LigerGELUMulFunction
+from liger_kernel.ops.group_norm import LigerGroupNormFunction
 from liger_kernel.ops.jsd import LigerJSDFunction
 from liger_kernel.ops.kl_div import LigerKLDivLossFunction
 from liger_kernel.ops.layer_norm import LigerLayerNormFunction
@@ -12,7 +15,6 @@ from liger_kernel.ops.rope import LigerRopeFunction
 from liger_kernel.ops.swiglu import LigerSiLUMulFunction
 
 liger_swiglu = LigerSiLUMulFunction.apply
-liger_cross_entropy = LigerCrossEntropyFunction.apply
 liger_fused_linear_cross_entropy = LigerFusedLinearCrossEntropyFunction.apply
 liger_geglu = LigerGELUMulFunction.apply
 liger_rms_norm = LigerRMSNormFunction.apply
@@ -21,3 +23,34 @@ liger_layer_norm = LigerLayerNormFunction.apply
 liger_kl_div = LigerKLDivLossFunction.apply
 liger_jsd = LigerJSDFunction.apply
 liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
+liger_group_norm = LigerGroupNormFunction.apply
+
+
+# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
+# `weight` and `size_average` are placeholders and not implemented yet
+def liger_cross_entropy(
+    input,
+    target,
+    weight=None,
+    size_average=None,
+    ignore_index: int = -100,
+    reduce=None,
+    reduction: str = "mean",
+    label_smoothing: float = 0.0,
+    lse_square_scale: float = 0.0,
+    softcap: Optional[float] = None,
+    return_z_loss: bool = False,
+):
+    loss, z_loss = LigerCrossEntropyFunction.apply(
+        input,
+        target,
+        ignore_index,
+        lse_square_scale,
+        label_smoothing,
+        reduction,
+        softcap,
+        return_z_loss,
+    )
+    if not return_z_loss:
+        return loss
+    return loss, z_loss
```
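`liger_cross_entropy` changes from a bare `.apply` alias to a wrapper that mirrors the `torch.nn.functional.cross_entropy` signature; `weight`, `size_average`, and `reduce` are accepted only as placeholders, per the comments above. A usage sketch, assuming a CUDA device (shapes are illustrative):

```python
import torch
from liger_kernel.transformers.functional import liger_cross_entropy

logits = torch.randn(4, 128, device="cuda", requires_grad=True)
labels = torch.randint(0, 128, (4,), device="cuda")

loss = liger_cross_entropy(logits, labels, ignore_index=-100, label_smoothing=0.0)
loss_and_z = liger_cross_entropy(logits, labels, return_z_loss=True)  # -> (loss, z_loss)
```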
liger_kernel/transformers/fused_linear_cross_entropy.py CHANGED

```diff
@@ -1,13 +1,38 @@
-from
+from typing import Optional
+
+import torch
 
 from liger_kernel.ops.fused_linear_cross_entropy import (
     LigerFusedLinearCrossEntropyFunction,
 )
 
 
-class LigerFusedLinearCrossEntropyLoss(
-    def __init__(
-
+class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
+    ):
+        super().__init__()
+        assert (label_smoothing >= 0) and (
+            label_smoothing <= 1
+        ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+        assert reduction in {
+            "mean",
+            "sum",
+            "none",
+        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
+        assert (
+            softcap is None or softcap > 0
+        ), f"softcap must be greater than 0.0 or None. Got: {softcap}"
+        self.ignore_index = ignore_index
+        self.lse_square_scale = lse_square_scale
+        self.label_smoothing = label_smoothing
+        self.reduction = reduction
+        self.softcap = softcap
 
     def forward(self, lin_weight, _input, target, bias=None):
         return LigerFusedLinearCrossEntropyFunction.apply(
@@ -16,6 +41,8 @@ class LigerFusedLinearCrossEntropyLoss(CrossEntropyLoss):
             target,
             bias,
             self.ignore_index,
+            self.lse_square_scale,
             self.label_smoothing,
             self.reduction,
+            self.softcap,
         )
```
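The module gains `lse_square_scale` and `softcap` and passes them through to the fused kernel, whose `forward` takes the lm-head weight, the flattened hidden states, and the targets. A usage sketch, assuming a CUDA device; the fp32 dtypes and shapes here are illustrative only:

```python
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

hidden, vocab, tokens = 512, 32000, 16
flce = LigerFusedLinearCrossEntropyLoss(lse_square_scale=1e-4, softcap=30.0)

lm_head_weight = torch.randn(vocab, hidden, device="cuda", requires_grad=True)
hidden_states = torch.randn(tokens, hidden, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab, (tokens,), device="cuda")

# forward(self, lin_weight, _input, target, bias=None)
loss = flce(lm_head_weight, hidden_states, labels)
loss.backward()
```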
liger_kernel/transformers/group_norm.py ADDED

```diff
@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops.group_norm import LigerGroupNormFunction
+
+
+class LigerGroupNorm(nn.Module):
+    def __init__(self, num_channels, num_groups, eps=1e-6, bias=False, init_fn="ones"):
+        """
+        A Group Normalization layer.
+        Args:
+            num_channels (int): Number of channels in the input tensor.
+            num_groups (int): Number of groups to divide the channels into.
+            eps (float, optional): A value added to the denominator for numerical stability. Default: 1e-6.
+            bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``False``.
+            init_fn (str, optional): Initialization function for the learnable parameters. Default: "ones".
+        """
+        super().__init__()
+        assert init_fn in [
+            "ones",
+            "zeros",
+        ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
+
+        assert (
+            num_channels % num_groups == 0
+        ), f"Number of channels {num_channels} must be divisible by num_groups {num_groups}"
+        self.num_channels = num_channels
+        self.num_groups = num_groups
+        self.eps = eps
+        self.weight = nn.Parameter(
+            torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels)
+        )
+        self.bias = nn.Parameter(
+            torch.randn(num_channels) if bias else torch.zeros(num_channels)
+        )
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        # hidden_states: (batch_size, num_channels, *)
+        assert (
+            hidden_states.dim() >= 3
+        ), f"Input must have at least 3 dimensions, got {hidden_states.dim()}"
+        assert (
+            hidden_states.size(1) == self.num_channels
+        ), f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}"
+        return LigerGroupNormFunction.apply(
+            hidden_states,
+            self.weight,
+            self.bias,
+            self.num_channels,
+            self.num_groups,
+            self.variance_epsilon,
+        )
+
+    def extra_repr(self):
+        return f"{self.hidden_size}, num_channels={self.num_channels}, num_groups={self.num_groups}, eps={self.eps}"
```