liger-kernel 0.5.10__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/ops/dyt.py +0 -2
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +1 -1
- liger_kernel/ops/layer_norm.py +126 -89
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/rms_norm.py +267 -56
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +62 -50
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/transformers/__init__.py +8 -0
- liger_kernel/transformers/functional.py +67 -0
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/model/gemma.py +25 -8
- liger_kernel/transformers/model/gemma2.py +27 -8
- liger_kernel/transformers/model/gemma3.py +63 -99
- liger_kernel/transformers/model/glm4.py +16 -7
- liger_kernel/transformers/model/llama.py +25 -7
- liger_kernel/transformers/model/llama4.py +108 -0
- liger_kernel/transformers/model/llava.py +95 -124
- liger_kernel/transformers/model/mistral.py +13 -8
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +16 -7
- liger_kernel/transformers/model/olmo2.py +16 -7
- liger_kernel/transformers/model/paligemma.py +8 -1
- liger_kernel/transformers/model/phi3.py +25 -8
- liger_kernel/transformers/model/qwen2.py +24 -7
- liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
- liger_kernel/transformers/model/qwen2_vl.py +38 -100
- liger_kernel/transformers/model/qwen3.py +11 -3
- liger_kernel/transformers/model/qwen3_moe.py +10 -6
- liger_kernel/transformers/model/smollm3.py +189 -0
- liger_kernel/transformers/monkey_patch.py +389 -82
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/rms_norm.py +40 -4
- liger_kernel/transformers/softmax.py +12 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/METADATA +18 -14
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/RECORD +47 -37
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/WHEEL +1 -1
- liger_kernel/transformers/gema3_rms.py +0 -8
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.1.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py
CHANGED
@@ -63,7 +63,7 @@ def _rms_norm_forward_kernel(
     3. https://arxiv.org/pdf/1910.07467
     """
 
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
@@ -137,7 +137,7 @@ def _rms_norm_backward_kernel(
     dw = sum(dy * (x / RMS)). summation over BxT dimension
     """
 
-    row_block_id = tl.program_id(0)
+    row_block_id = tl.program_id(0).to(tl.int64)
     row_start = row_block_id * rows_per_program
     row_end = min((row_block_id + 1) * rows_per_program, n_rows)
     col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -194,6 +194,175 @@ def _rms_norm_backward_kernel(
     tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
 
 
+@triton.jit
+def _block_rms_norm_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    n_rows,
+    n_cols,
+    eps,
+    offset,
+    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_ROW: tl.constexpr,
+):
+    """
+    y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)
+
+    Reference:
+    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+    3. https://arxiv.org/pdf/1910.07467
+    """
+
+    row_idx = tl.program_id(0) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    row_mask = row_idx < n_rows
+    col_mask = col_offsets < n_cols
+
+    X_row = tl.load(
+        X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+        mask=row_mask[:, None] & col_mask[None, :],
+        other=0,
+    )
+    X_row_dtype = X_row.dtype
+    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+
+    # On Llama, only rstd is computed on fp32
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(tl.float32)
+
+    # Gemma computes everything on fp32, and then casts back the output to the original dtype
+    if casting_mode == _CASTING_MODE_GEMMA:
+        W_row = W_row.to(tl.float32)
+        X_row = X_row.to(tl.float32)
+
+    if casting_mode == _CASTING_MODE_NONE:
+        eps = eps.to(X_row_dtype)
+        offset = offset.to(X_row_dtype)
+
+    mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
+    rstd = rsqrt(mean_square + eps)
+
+    # We can save time by caching rms with minimal memory overhead
+    # because rms is much smaller compared to X_row, as rms is for each row.
+    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
+    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)
+
+    X_row = X_row * rstd[:, None]
+
+    # On Llama, the multiplication with the weight is done on the original dtype
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(X_row_dtype)
+
+    Y_row = X_row * (offset + W_row)[None, :]
+
+    if casting_mode == _CASTING_MODE_GEMMA:
+        Y_row = Y_row.to(X_row_dtype)
+
+    tl.store(
+        Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
+        Y_row,
+        mask=row_mask[:, None] & col_mask[None, :],
+    )
+
+
+@triton.jit
+def _block_rms_norm_backward_kernel(
+    dY_ptr,
+    dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
+    X_ptr,
+    X_row_stride,
+    X_dtype: tl.constexpr,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    dW_ptr,
+    dW_row_stride,
+    n_rows,
+    n_cols,
+    offset,
+    rows_per_program: tl.constexpr,
+    casting_mode: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_ROW: tl.constexpr,
+):
+    """
+    dx = (1 / RMS) * [dy * (w + offset - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
+    dw = sum(dy * (x / RMS)). summation over BxT dimension
+    """
+
+    pid = tl.program_id(0).cast(tl.int64)
+    NUM_SMS = tl.num_programs(0)
+
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    col_mask = col_offsets < n_cols
+
+    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+    W_row = W_row + offset
+
+    for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
+        row_idx = start + tl.arange(0, BLOCK_ROW)
+        row_mask = row_idx < n_rows
+        dY_row = tl.load(
+            dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
+            mask=row_mask[:, None] & col_mask[None, :],
+            other=0.0,
+        )
+        X_row = tl.load(
+            X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+            mask=row_mask[:, None] & col_mask[None, :],
+            other=0.0,
+        )
+
+        # Get cached rms
+        rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, row_mask)
+
+        X_row = X_row.to(tl.float32)
+
+        # Different backward graphs for different casting modes
+        if casting_mode == _CASTING_MODE_LLAMA:
+            m = (dY_row * W_row[None, :]).to(tl.float32)
+
+        elif casting_mode == _CASTING_MODE_GEMMA:
+            dY_row = dY_row.to(tl.float32)
+            m = dY_row * W_row[None, :]
+        else:
+            m = dY_row * W_row[None, :]
+
+        dX_row = rstd_row[:, None] * m
+
+        dX_row += (rstd_row[:, None]) * (
+            -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
+        )
+
+        # calculate the gradient of W
+        if casting_mode == _CASTING_MODE_LLAMA:
+            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]).to(X_dtype), 0)
+        else:
+            # here X_row is already in fp32 (see previous if block)
+            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+
+        tl.store(
+            dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
+            dX_row,
+            mask=row_mask[:, None] & col_mask[None, :],
+        )
+
+    tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+
+
 _str_to_casting_mode = {
     "llama": _CASTING_MODE_LLAMA.value,
     "gemma": _CASTING_MODE_GEMMA.value,
@@ -201,7 +370,7 @@ _str_to_casting_mode = {
 }
 
 
-def rms_norm_forward(X, W, eps, offset, casting_mode):
+def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
     if not isinstance(casting_mode, int):
         assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
         casting_mode = _str_to_casting_mode[casting_mode]
@@ -227,27 +396,49 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     kernel_args = {}
     if X.device.type == "xpu":
         kernel_args["grf_mode"] = "large"
-    _rms_norm_forward_kernel[(n_rows,)](
-        Y,
-        Y.stride(0),
-        X,
-        X.stride(0),
-        W,
-        W.stride(0),
-        RSTD,
-        RSTD.stride(0),
-        n_cols,
-        eps,
-        offset,
-        casting_mode,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
-        **kernel_args,  # XPU-specific optimization
-    )
+    if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+        _rms_norm_forward_kernel[(n_rows,)](
+            Y,
+            Y.stride(0),
+            X,
+            X.stride(0),
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            n_cols,
+            eps,
+            offset,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
+    else:
+        BLOCK_ROW = 16
+        kernel_args["BLOCK_ROW"] = BLOCK_ROW
+        _block_rms_norm_forward_kernel[(triton.cdiv(n_rows, BLOCK_ROW),)](
+            Y,
+            Y.stride(0),
+            X,
+            X.stride(0),
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            n_rows,
+            n_cols,
+            eps,
+            offset,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
 
 
-def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
+def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
@@ -277,29 +468,56 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     if X.device.type == "xpu":
         kernel_args["grf_mode"] = "large"
 
-    _rms_norm_backward_kernel[grid](
-        dY,
-        dY.stride(0),
-        dX,
-        dX.stride(0),
-        X,
-        X.stride(0),
-        torch_to_triton_dtype[X.dtype],
-        W,
-        W.stride(0),
-        RSTD,
-        RSTD.stride(0),
-        _dW,
-        _dW.stride(0),
-        n_rows,
-        n_cols,
-        offset,
-        rows_per_program,
-        casting_mode,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
-        **kernel_args,  # XPU-specific optimization
-    )
+    if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+        _rms_norm_backward_kernel[grid](
+            dY,
+            dY.stride(0),
+            dX,
+            dX.stride(0),
+            X,
+            X.stride(0),
+            torch_to_triton_dtype[X.dtype],
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            _dW,
+            _dW.stride(0),
+            n_rows,
+            n_cols,
+            offset,
+            rows_per_program,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
+    else:
+        BLOCK_ROW = 16
+        kernel_args["BLOCK_ROW"] = BLOCK_ROW
+        _block_rms_norm_backward_kernel[grid](
+            dY,
+            dY.stride(0),
+            dX,
+            dX.stride(0),
+            X,
+            X.stride(0),
+            torch_to_triton_dtype[X.dtype],
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            _dW,
+            _dW.stride(0),
+            n_rows,
+            n_cols,
+            offset,
+            rows_per_program,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
     dX = dX.view(*shape)
     dW = _dW.sum(dim=0).to(W.dtype)
 
@@ -330,15 +548,16 @@ class LigerRMSNormFunction(torch.autograd.Function):
 
     @staticmethod
     @ensure_contiguous
-    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
+    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
         """
        X: (B, T, H) or (BxT, H)
        W: (H,)
        """
-        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
+        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
        ctx.offset = offset
        ctx.casting_mode = casting_mode
        ctx.in_place = in_place
+        ctx.row_mode = row_mode
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.save_for_backward(X, W, RSTD)
@@ -352,14 +571,6 @@ class LigerRMSNormFunction(torch.autograd.Function):
        """
        X, W, RSTD = ctx.saved_tensors
        dX, dW = rms_norm_backward(
-            dY,
-            X,
-            W,
-            RSTD,
-            ctx.offset,
-            ctx.casting_mode,
-            ctx.BLOCK_SIZE,
-            ctx.num_warps,
-            ctx.in_place,
+            dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
        )
-        return dX, dW, None, None, None, None
+        return dX, dW, None, None, None, None, None
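The net effect of the rms_norm.py changes: program ids are widened to int64 before any stride arithmetic, and a new row_mode argument (threaded through rms_norm_forward, rms_norm_backward, and LigerRMSNormFunction.forward) selects between the original one-row-per-program kernels and the new BLOCK_ROW = 16 blocked kernels; with row_mode=None, the heuristic BLOCK_SIZE > 256 or n_rows < 4096 * 8 keeps the row kernels for wide or short inputs. A minimal sketch of the extended signature (shapes, dtype, and device below are illustrative assumptions, not values from the diff):

# Hedged sketch: exercises the extended LigerRMSNormFunction signature.
# Shapes/dtype/device are assumptions; the positional argument order
# (X, W, eps, offset, casting_mode, in_place, row_mode) follows the diff.
import torch
from liger_kernel.ops.rms_norm import LigerRMSNormFunction

B, T, H = 8, 8192, 64  # n_rows = 65536 >= 4096 * 8 and BLOCK_SIZE <= 256,
                       # so row_mode=None should dispatch to the blocked kernels
X = torch.randn(B, T, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
W = torch.ones(H, device="cuda", dtype=torch.bfloat16)

# row_mode=True would force the original one-row-per-program kernels instead.
Y = LigerRMSNormFunction.apply(X, W, 1e-6, 0.0, "llama", False, None)
Y.sum().backward()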
liger_kernel/ops/rope.py
CHANGED
@@ -32,7 +32,7 @@ def _triton_rope(
 
     # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
     # stride: (seq_len * head_dim, head_dim, 1)
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
 
     # locate start address
     q_ptr = q_ptr + pid * q_row_stride
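As in rms_norm.py, the program id is widened to int64 before it is multiplied by a row stride. A back-of-the-envelope check of why this matters; the head count and head dim below are assumed values for illustration, not taken from the diff:

# When does pid * q_row_stride overflow a signed 32-bit index?
INT32_MAX = 2**31 - 1

n_heads, head_dim = 32, 128        # assumed layout
q_row_stride = n_heads * head_dim  # 4096 elements per (batch * seq_len) row

first_overflowing_row = INT32_MAX // q_row_stride + 1
print(first_overflowing_row)       # 524288: beyond ~0.5M rows the offset wraps int32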
liger_kernel/ops/softmax.py
ADDED
@@ -0,0 +1,201 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import ensure_contiguous
+
+
+@triton.jit
+def _softmax_single_block_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    offs = tl.arange(0, BLOCK_SIZE)
+    mask = offs < n_cols
+
+    x = tl.load(X_ptr + row_id * X_row_stride + offs, mask=mask, other=-float("inf"), cache_modifier=".ca")
+    m = tl.max(x, axis=0)
+    e = tl.exp(x - m)
+    d = tl.sum(e, axis=0)
+    y = e / d
+    tl.store(Y_ptr + row_id * Y_row_stride + offs, y, mask=mask, cache_modifier=".cs")
+
+
+@triton.jit
+def _softmax_multi_block_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    offs = tl.arange(0, BLOCK_SIZE)
+
+    m = tl.float32(-float("inf"))
+    d = tl.float32(0.0)
+    for start in tl.range(0, n_cols, BLOCK_SIZE):
+        idx = start + offs
+        mask = idx < n_cols
+        xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
+        blk_max = tl.max(xblk, axis=0)
+        new_m = tl.max(m, blk_max)
+        d = d * tl.exp(m - new_m) + tl.sum(tl.exp(xblk - new_m), axis=0)
+        m = new_m
+
+    for start in tl.range(0, n_cols, BLOCK_SIZE):
+        idx = start + offs
+        mask = idx < n_cols
+        xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
+        yblk = tl.exp(xblk - m) / d
+        tl.store(Y_ptr + row_id * Y_row_stride + idx, yblk, mask=mask, cache_modifier=".cs")
+
+
+@triton.jit
+def _softmax_single_block_backward_kernel(
+    dy_ptr,
+    dy_stride,
+    y_ptr,
+    y_stride,
+    dx_ptr,
+    dx_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    offs = tl.arange(0, BLOCK_SIZE)
+    mask = offs < n_cols
+
+    dy = tl.load(dy_ptr + row_id * dy_stride + offs, mask=mask, other=0.0)
+    y = tl.load(y_ptr + row_id * y_stride + offs, mask=mask, other=0.0, cache_modifier=".ca")
+    dot = tl.sum(dy * y, axis=0)
+    dx = y * (dy - dot)
+    tl.store(dx_ptr + row_id * dx_stride + offs, dx, mask=mask, cache_modifier=".wb")
+
+
+@triton.jit
+def _softmax_multi_block_backward_kernel(
+    dy_ptr,
+    dy_stride,
+    y_ptr,
+    y_stride,
+    dx_ptr,
+    dx_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    offs = tl.arange(0, BLOCK_SIZE)
+    acc = tl.float32(0.0)
+
+    for start in tl.range(0, n_cols, BLOCK_SIZE):
+        idx = start + offs
+        mask = idx < n_cols
+        dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
+        y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
+        acc += tl.sum(dy_blk * y_blk, axis=0)
+
+    for start in tl.range(0, n_cols, BLOCK_SIZE):
+        idx = start + offs
+        mask = idx < n_cols
+        dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
+        y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
+        dx_blk = y_blk * (dy_blk - acc)
+        tl.store(dx_ptr + row_id * dx_stride + idx, dx_blk, mask=mask, cache_modifier=".wb")
+
+
+def _softmax_forward(x: torch.Tensor) -> Tuple[torch.Tensor, int, int, bool]:
+    *batch, n_cols = x.shape
+    x2d = x.contiguous().view(-1, n_cols)
+    n_rows = x2d.shape[0]
+
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    y2d = torch.empty_like(x2d)
+
+    if n_cols <= BLOCK_SIZE:
+        _softmax_single_block_forward_kernel[(n_rows,)](
+            y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+        )
+        multi_block_launch = False
+    else:
+        _softmax_multi_block_forward_kernel[(n_rows,)](
+            y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+        )
+        multi_block_launch = True
+
+    return y2d.view(*batch, n_cols), BLOCK_SIZE, num_warps, multi_block_launch
+
+
+def _softmax_backward(
+    dy: torch.Tensor,
+    y: torch.Tensor,
+    BLOCK_SIZE: int,
+    num_warps: int,
+    multi_block_launch: bool,
+) -> torch.Tensor:
+    *batch, n_cols = dy.shape
+    dy2d = dy.contiguous().view(-1, n_cols)
+    y2d = y.contiguous().view(-1, n_cols)
+    n_rows = dy2d.shape[0]
+    dx2d = torch.empty_like(dy2d)
+
+    if not multi_block_launch and n_cols <= BLOCK_SIZE:
+        _softmax_single_block_backward_kernel[(n_rows,)](
+            dy2d,
+            dy2d.stride(0),
+            y2d,
+            y2d.stride(0),
+            dx2d,
+            dx2d.stride(0),
+            n_cols,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+        )
+    else:
+        _softmax_multi_block_backward_kernel[(n_rows,)](
+            dy2d,
+            dy2d.stride(0),
+            y2d,
+            y2d.stride(0),
+            dx2d,
+            dx2d.stride(0),
+            n_cols,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+        )
+
+    return dx2d.view(*batch, n_cols)
+
+
+class LigerSoftmaxFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, input_: torch.Tensor):
+        y, BLOCK_SIZE, num_warps, multi_block_launch = _softmax_forward(input_)
+        ctx.save_for_backward(y)
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        ctx.num_warps = num_warps
+        ctx.multi_block_launch = multi_block_launch
+        return y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output):
+        (y,) = ctx.saved_tensors
+        dx = _softmax_backward(
+            grad_output,
+            y,
+            ctx.BLOCK_SIZE,
+            ctx.num_warps,
+            ctx.multi_block_launch,
+        )
+        return dx