liger-kernel-nightly 0.5.6.dev20250403190551__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/dpo_loss.py +61 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +35 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/grpo_loss.py +76 -5
- liger_kernel/chunked_loss/jsd_loss.py +25 -9
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +124 -64
- liger_kernel/ops/dyt.py +115 -180
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +3 -2
- liger_kernel/ops/group_norm.py +2 -1
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +2 -1
- liger_kernel/ops/kl_div.py +13 -6
- liger_kernel/ops/layer_norm.py +146 -78
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +283 -56
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +205 -19
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +6 -4
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +122 -20
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +50 -25
- liger_kernel/transformers/model/gemma2.py +55 -23
- liger_kernel/transformers/model/gemma3.py +117 -120
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +102 -25
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +111 -136
- liger_kernel/transformers/model/loss_utils.py +50 -12
- liger_kernel/transformers/model/mistral.py +36 -23
- liger_kernel/transformers/model/mixtral.py +45 -25
- liger_kernel/transformers/model/mllama.py +39 -22
- liger_kernel/transformers/model/olmo2.py +40 -20
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +50 -14
- liger_kernel/transformers/model/phi3.py +47 -177
- liger_kernel/transformers/model/qwen2.py +48 -21
- liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
- liger_kernel/transformers/model/qwen2_vl.py +59 -108
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +1678 -160
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +48 -5
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +39 -1
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +36 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/METADATA +68 -38
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
- liger_kernel/transformers/gema3_rms.py +0 -8
- liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/RECORD +0 -82
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py
CHANGED
@@ -21,8 +21,10 @@ from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import torch_to_triton_dtype
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
@@ -63,7 +65,7 @@ def _rms_norm_forward_kernel(
     3. https://arxiv.org/pdf/1910.07467
     """
 
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
@@ -137,7 +139,7 @@ def _rms_norm_backward_kernel(
     dw = sum(dy * (x / RMS)). summation over BxT dimension
     """
 
-    row_block_id = tl.program_id(0)
+    row_block_id = tl.program_id(0).to(tl.int64)
     row_start = row_block_id * rows_per_program
     row_end = min((row_block_id + 1) * rows_per_program, n_rows)
     col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -194,6 +196,176 @@ def _rms_norm_backward_kernel(
     tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
 
 
+@triton.jit
+def _block_rms_norm_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    n_rows,
+    n_cols,
+    eps,
+    offset,
+    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_ROW: tl.constexpr,
+):
+    """
+    y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)
+
+    Reference:
+    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+    3. https://arxiv.org/pdf/1910.07467
+    """
+
+    row_idx = tl.program_id(0) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    row_mask = row_idx < n_rows
+    col_mask = col_offsets < n_cols
+
+    X_row = tl.load(
+        X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+        mask=row_mask[:, None] & col_mask[None, :],
+        other=0,
+    )
+    X_row_dtype = X_row.dtype
+    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+
+    # On Llama, only rstd is computed on fp32
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(tl.float32)
+
+    # Gemma computes everything on fp32, and then casts back the output to the original dtype
+    if casting_mode == _CASTING_MODE_GEMMA:
+        W_row = W_row.to(tl.float32)
+        X_row = X_row.to(tl.float32)
+
+    if casting_mode == _CASTING_MODE_NONE:
+        eps = eps.to(X_row_dtype)
+        offset = offset.to(X_row_dtype)
+
+    mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
+    rstd = rsqrt(mean_square + eps)
+
+    # We can save time by caching rms with minimal memory overhead
+    # because rms is much smaller compared to X_row, as rms is for each row.
+    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
+    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)
+
+    X_row = X_row * rstd[:, None]
+
+    # On Llama, the multiplication with the weight is done on the original dtype
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(X_row_dtype)
+
+    Y_row = X_row * (offset + W_row)[None, :]
+
+    if casting_mode == _CASTING_MODE_GEMMA:
+        Y_row = Y_row.to(X_row_dtype)
+
+    tl.store(
+        Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
+        Y_row,
+        mask=row_mask[:, None] & col_mask[None, :],
+    )
+
+
+@triton.jit
+def _block_rms_norm_backward_kernel(
+    dY_ptr,
+    dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
+    X_ptr,
+    X_row_stride,
+    X_dtype: tl.constexpr,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    dW_ptr,
+    dW_row_stride,
+    n_rows,
+    n_cols,
+    offset,
+    rows_per_program: tl.constexpr,
+    casting_mode: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_ROW: tl.constexpr,
+):
+    """
+    dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
+    dw = sum(dy * (x / RMS)). summation over BxT dimension
+    """
+
+    pid = tl.program_id(0).cast(tl.int64)
+    NUM_SMS = tl.num_programs(0)
+
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    col_mask = col_offsets < n_cols
+
+    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+    W_row = W_row + offset
+
+    for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
+        row_idx = start + tl.arange(0, BLOCK_ROW)
+        row_mask = row_idx < n_rows
+        dY_row = tl.load(
+            dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
+            mask=row_mask[:, None] & col_mask[None, :],
+            other=0.0,
+        )
+        X_row = tl.load(
+            X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+            mask=row_mask[:, None] & col_mask[None, :],
+            other=0.0,
+        )
+
+        # Get cached rms
+        rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, row_mask)
+
+        X_row = X_row.to(tl.float32)
+
+        # Different backward graphs for different casting modes
+        if casting_mode == _CASTING_MODE_LLAMA:
+            m = (dY_row * W_row[None, :]).to(tl.float32)
+
+        elif casting_mode == _CASTING_MODE_GEMMA:
+            dY_row = dY_row.to(tl.float32)
+            m = dY_row * W_row[None, :]
+        else:
+            m = dY_row * W_row[None, :]
+
+        dX_row = rstd_row[:, None] * m
+
+        dX_row += (rstd_row[:, None]) * (
+            -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
+        )
+
+        # calculate the gradient of W
+        if casting_mode == _CASTING_MODE_LLAMA:
+            # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+            dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
+        else:
+            # here X_row is already in fp32 (see previous if block)
+            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+
+        tl.store(
+            dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
+            dX_row,
+            mask=row_mask[:, None] & col_mask[None, :],
+        )
+
+    tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+
+
 _str_to_casting_mode = {
     "llama": _CASTING_MODE_LLAMA.value,
     "gemma": _CASTING_MODE_GEMMA.value,
@@ -201,7 +373,7 @@ _str_to_casting_mode = {
 }
 
 
-def rms_norm_forward(X, W, eps, offset, casting_mode):
+def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
     if not isinstance(casting_mode, int):
         assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
         casting_mode = _str_to_casting_mode[casting_mode]
@@ -223,26 +395,53 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     # Check constraints.
     assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
 
-    … (16 removed lines not captured in this extraction)
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+    if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+        _rms_norm_forward_kernel[(n_rows,)](
+            Y,
+            Y.stride(0),
+            X,
+            X.stride(0),
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            n_cols,
+            eps,
+            offset,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
+    else:
+        BLOCK_ROW = 16
+        kernel_args["BLOCK_ROW"] = BLOCK_ROW
+        _block_rms_norm_forward_kernel[(triton.cdiv(n_rows, BLOCK_ROW),)](
+            Y,
+            Y.stride(0),
+            X,
+            X.stride(0),
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            n_rows,
+            n_cols,
+            eps,
+            offset,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
 
 
-def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
+def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
@@ -252,7 +451,9 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     if X.device.type == "cuda":
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
-        sm_count = torch.xpu.get_device_properties(X.device).
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+    elif X.device.type == "npu":
+        sm_count = get_npu_multi_processor_count()
 
     # fp32 for numerical stability especially.
     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
@@ -267,28 +468,61 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     else:
        dX = torch.zeros_like(dY)
 
-    … (22 removed lines not captured in this extraction)
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
+    if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+        _rms_norm_backward_kernel[grid](
+            dY,
+            dY.stride(0),
+            dX,
+            dX.stride(0),
+            X,
+            X.stride(0),
+            torch_to_triton_dtype[X.dtype],
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            _dW,
+            _dW.stride(0),
+            n_rows,
+            n_cols,
+            offset,
+            rows_per_program,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
+    else:
+        BLOCK_ROW = 16
+        kernel_args["BLOCK_ROW"] = BLOCK_ROW
+        _block_rms_norm_backward_kernel[grid](
+            dY,
+            dY.stride(0),
+            dX,
+            dX.stride(0),
+            X,
+            X.stride(0),
+            torch_to_triton_dtype[X.dtype],
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            _dW,
+            _dW.stride(0),
+            n_rows,
+            n_cols,
+            offset,
+            rows_per_program,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
     dX = dX.view(*shape)
     dW = _dW.sum(dim=0).to(W.dtype)
 
@@ -319,15 +553,16 @@ class LigerRMSNormFunction(torch.autograd.Function):
 
     @staticmethod
     @ensure_contiguous
-    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
+    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
         """
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
-        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode)
+        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
         ctx.offset = offset
         ctx.casting_mode = casting_mode
         ctx.in_place = in_place
+        ctx.row_mode = row_mode
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
         ctx.save_for_backward(X, W, RSTD)
@@ -341,14 +576,6 @@ class LigerRMSNormFunction(torch.autograd.Function):
         """
         X, W, RSTD = ctx.saved_tensors
         dX, dW = rms_norm_backward(
-            dY,
-            X,
-            W,
-            RSTD,
-            ctx.offset,
-            ctx.casting_mode,
-            ctx.BLOCK_SIZE,
-            ctx.num_warps,
-            ctx.in_place,
+            dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
        )
-        return dX, dW, None, None, None, None
+        return dX, dW, None, None, None, None, None
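The new `row_mode` argument threads from `LigerRMSNormFunction.forward` down to the launch heuristic: when `BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode` holds, the original one-row-per-program kernel runs; otherwise the new `BLOCK_ROW = 16` block kernels do. A minimal usage sketch (assuming a CUDA device and this patched wheel; the shapes are illustrative and chosen so `row_mode=None` falls through to the block kernel, and the reference function is our own restatement of the llama casting mode, not liger code):

import torch

from liger_kernel.ops.rms_norm import LigerRMSNormFunction


def rms_norm_ref(x, w, eps=1e-6, offset=0.0):
    # Llama casting mode restated eagerly: rstd in fp32, weight multiply in the input dtype.
    rstd = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return (x.float() * rstd).to(x.dtype) * (offset + w)


x = torch.randn(4096 * 8, 256, device="cuda", dtype=torch.bfloat16)
w = torch.randn(256, device="cuda", dtype=torch.bfloat16)

# Positional args mirror forward(ctx, X, W, eps, offset, casting_mode, in_place, row_mode).
y_row = LigerRMSNormFunction.apply(x, w, 1e-6, 0.0, "llama", False, True)  # force the row kernel
y_blk = LigerRMSNormFunction.apply(x, w, 1e-6, 0.0, "llama", False, None)  # heuristic picks the block kernel here

torch.testing.assert_close(y_row, y_blk)
torch.testing.assert_close(y_row, rms_norm_ref(x, w), rtol=1e-2, atol=1e-2)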
liger_kernel/ops/rope.py
CHANGED
@@ -32,7 +32,7 @@ def _triton_rope(
 
     # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
     # stride: (seq_len * head_dim, head_dim, 1)
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
 
     # locate start address
     q_ptr = q_ptr + pid * q_row_stride
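As in rms_norm above, the only change is widening the program id before it scales a row stride: `tl.program_id(0)` is int32, so `pid * q_row_stride` can wrap once the product passes 2^31 - 1. A back-of-envelope sketch with hypothetical shapes (not taken from the source):

# Hypothetical shapes to show when a 32-bit element offset would wrap.
n_heads, head_dim = 32, 128
q_row_stride = n_heads * head_dim       # elements per (batch, seq) row: 4096
rows = 600_000                          # bsz * seq_len program instances
print(rows * q_row_stride > 2**31 - 1)  # True: the int64 cast avoids this overflow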
liger_kernel/ops/softmax.py
ADDED
@@ -0,0 +1,201 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import ensure_contiguous
+
+
+@triton.jit
+def _softmax_single_block_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    offs = tl.arange(0, BLOCK_SIZE)
+    mask = offs < n_cols
+
+    x = tl.load(X_ptr + row_id * X_row_stride + offs, mask=mask, other=-float("inf"), cache_modifier=".ca")
+    m = tl.max(x, axis=0)
+    e = tl.exp(x - m)
+    d = tl.sum(e, axis=0)
+    y = e / d
+    tl.store(Y_ptr + row_id * Y_row_stride + offs, y, mask=mask, cache_modifier=".cs")
+
+
+@triton.jit
+def _softmax_multi_block_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    offs = tl.arange(0, BLOCK_SIZE)
+
+    m = tl.float32(-float("inf"))
+    d = tl.float32(0.0)
+    for start in tl.range(0, n_cols, BLOCK_SIZE):
+        idx = start + offs
+        mask = idx < n_cols
+        xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
+        blk_max = tl.max(xblk, axis=0)
+        new_m = tl.max(m, blk_max)
+        d = d * tl.exp(m - new_m) + tl.sum(tl.exp(xblk - new_m), axis=0)
+        m = new_m
+
+    for start in tl.range(0, n_cols, BLOCK_SIZE):
+        idx = start + offs
+        mask = idx < n_cols
+        xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
+        yblk = tl.exp(xblk - m) / d
+        tl.store(Y_ptr + row_id * Y_row_stride + idx, yblk, mask=mask, cache_modifier=".cs")
+
+
+@triton.jit
+def _softmax_single_block_backward_kernel(
+    dy_ptr,
+    dy_stride,
+    y_ptr,
+    y_stride,
+    dx_ptr,
+    dx_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    offs = tl.arange(0, BLOCK_SIZE)
+    mask = offs < n_cols
+
+    dy = tl.load(dy_ptr + row_id * dy_stride + offs, mask=mask, other=0.0)
+    y = tl.load(y_ptr + row_id * y_stride + offs, mask=mask, other=0.0, cache_modifier=".ca")
+    dot = tl.sum(dy * y, axis=0)
+    dx = y * (dy - dot)
+    tl.store(dx_ptr + row_id * dx_stride + offs, dx, mask=mask, cache_modifier=".wb")
+
+
+@triton.jit
+def _softmax_multi_block_backward_kernel(
+    dy_ptr,
+    dy_stride,
+    y_ptr,
+    y_stride,
+    dx_ptr,
+    dx_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    offs = tl.arange(0, BLOCK_SIZE)
+    acc = tl.float32(0.0)
+
+    for start in tl.range(0, n_cols, BLOCK_SIZE):
+        idx = start + offs
+        mask = idx < n_cols
+        dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
+        y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
+        acc += tl.sum(dy_blk * y_blk, axis=0)
+
+    for start in tl.range(0, n_cols, BLOCK_SIZE):
+        idx = start + offs
+        mask = idx < n_cols
+        dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
+        y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
+        dx_blk = y_blk * (dy_blk - acc)
+        tl.store(dx_ptr + row_id * dx_stride + idx, dx_blk, mask=mask, cache_modifier=".wb")
+
+
+def _softmax_forward(x: torch.Tensor) -> Tuple[torch.Tensor, int, int, bool]:
+    *batch, n_cols = x.shape
+    x2d = x.contiguous().view(-1, n_cols)
+    n_rows = x2d.shape[0]
+
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    y2d = torch.empty_like(x2d)
+
+    if n_cols <= BLOCK_SIZE:
+        _softmax_single_block_forward_kernel[(n_rows,)](
+            y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+        )
+        multi_block_launch = False
+    else:
+        _softmax_multi_block_forward_kernel[(n_rows,)](
+            y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+        )
+        multi_block_launch = True
+
+    return y2d.view(*batch, n_cols), BLOCK_SIZE, num_warps, multi_block_launch
+
+
+def _softmax_backward(
+    dy: torch.Tensor,
+    y: torch.Tensor,
+    BLOCK_SIZE: int,
+    num_warps: int,
+    multi_block_launch: bool,
+) -> torch.Tensor:
+    *batch, n_cols = dy.shape
+    dy2d = dy.contiguous().view(-1, n_cols)
+    y2d = y.contiguous().view(-1, n_cols)
+    n_rows = dy2d.shape[0]
+    dx2d = torch.empty_like(dy2d)
+
+    if not multi_block_launch and n_cols <= BLOCK_SIZE:
+        _softmax_single_block_backward_kernel[(n_rows,)](
+            dy2d,
+            dy2d.stride(0),
+            y2d,
+            y2d.stride(0),
+            dx2d,
+            dx2d.stride(0),
+            n_cols,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+        )
+    else:
+        _softmax_multi_block_backward_kernel[(n_rows,)](
+            dy2d,
+            dy2d.stride(0),
+            y2d,
+            y2d.stride(0),
+            dx2d,
+            dx2d.stride(0),
+            n_cols,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+        )
+
+    return dx2d.view(*batch, n_cols)
+
+
+class LigerSoftmaxFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, input_: torch.Tensor):
+        y, BLOCK_SIZE, num_warps, multi_block_launch = _softmax_forward(input_)
+        ctx.save_for_backward(y)
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        ctx.num_warps = num_warps
+        ctx.multi_block_launch = multi_block_launch
+        return y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output):
+        (y,) = ctx.saved_tensors
+        dx = _softmax_backward(
+            grad_output,
+            y,
+            ctx.BLOCK_SIZE,
+            ctx.num_warps,
+            ctx.multi_block_launch,
+        )
+        return dx