liger-kernel-nightly 0.6.2.dev20251011154427__py3-none-any.whl → 0.6.4.dev20260107111351__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
- liger_kernel/chunked_loss/grpo_loss.py +8 -5
- liger_kernel/chunked_loss/jsd_loss.py +39 -11
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +43 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +244 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +75 -12
- liger_kernel/ops/dyt.py +5 -2
- liger_kernel/ops/fused_add_rms_norm.py +5 -1
- liger_kernel/ops/fused_linear_cross_entropy.py +45 -14
- liger_kernel/ops/geglu.py +5 -3
- liger_kernel/ops/group_norm.py +2 -1
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/layer_norm.py +86 -66
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +131 -49
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +30 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +48 -25
- liger_kernel/transformers/fused_add_rms_norm.py +1 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +57 -2
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +1 -1
- liger_kernel/transformers/model/falcon_h1.py +19 -5
- liger_kernel/transformers/model/gemma.py +17 -6
- liger_kernel/transformers/model/gemma2.py +14 -5
- liger_kernel/transformers/model/gemma3.py +26 -12
- liger_kernel/transformers/model/glm4.py +16 -4
- liger_kernel/transformers/model/glm4v.py +16 -4
- liger_kernel/transformers/model/glm4v_moe.py +23 -4
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +12 -5
- liger_kernel/transformers/model/llama.py +14 -5
- liger_kernel/transformers/model/llama4.py +16 -4
- liger_kernel/transformers/model/llava.py +12 -4
- liger_kernel/transformers/model/loss_utils.py +31 -3
- liger_kernel/transformers/model/mistral.py +15 -6
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +12 -4
- liger_kernel/transformers/model/olmo2.py +16 -4
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +23 -5
- liger_kernel/transformers/model/phi3.py +14 -7
- liger_kernel/transformers/model/qwen2.py +16 -3
- liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
- liger_kernel/transformers/model/qwen2_vl.py +16 -4
- liger_kernel/transformers/model/qwen3.py +20 -5
- liger_kernel/transformers/model/qwen3_moe.py +19 -5
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +15 -6
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +702 -48
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +15 -3
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +18 -1
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +52 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/METADATA +12 -3
- liger_kernel_nightly-0.6.4.dev20260107111351.dist-info/RECORD +130 -0
- liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/RECORD +0 -107
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py
CHANGED
@@ -21,8 +21,10 @@ from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import torch_to_triton_dtype
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
@@ -52,6 +54,7 @@ def _rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -73,7 +76,8 @@ def _rms_norm_forward_kernel(
 
     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
     X_row_dtype = X_row.dtype
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
 
     # On Llama, only rstd is computed on fp32
     if casting_mode == _CASTING_MODE_LLAMA:
@@ -81,7 +85,8 @@ def _rms_norm_forward_kernel(
 
     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-        W_row = W_row.to(tl.float32)
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)
 
     if casting_mode == _CASTING_MODE_NONE:
@@ -102,7 +107,10 @@ def _rms_norm_forward_kernel(
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)
 
-    Y_row = X_row * (offset + W_row)
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)
+    else:
+        Y_row = X_row
 
     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)
@@ -128,8 +136,9 @@ def _rms_norm_backward_kernel(
     n_rows,
     n_cols,
     offset,
-    rows_per_program
+    rows_per_program,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -143,7 +152,8 @@ def _rms_norm_backward_kernel(
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
     dY_ptr += row_start * dY_row_stride
     dX_ptr += row_start * dX_row_stride
@@ -151,8 +161,9 @@ def _rms_norm_backward_kernel(
     X_ptr += row_start * X_row_stride
     RSTD_ptr += row_start
 
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
-    W_row = W_row + offset
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+        W_row = W_row + offset
 
     for _ in range(row_start, row_end):
         dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
@@ -165,24 +176,34 @@ def _rms_norm_backward_kernel(
 
         # Different bacward graphs for different casting modes
        if casting_mode == _CASTING_MODE_LLAMA:
-            m = (dY_row * W_row).to(tl.float32)
+            if elementwise_affine:
+                m = (dY_row * W_row).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)
 
         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-            m = dY_row * W_row
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
         else:
-            m = dY_row * W_row
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
 
         dX_row = rstd_row * m
 
         dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
 
-        # calculate the gradient of W
-        if casting_mode == _CASTING_MODE_LLAMA:
-            dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
-        else:
-            # here X_row is already in fp32 (see previous if block)
-            dW_row += dY_row * (X_row * rstd_row)
+        if elementwise_affine:
+            # calculate the gradient of W
+            if casting_mode == _CASTING_MODE_LLAMA:
+                dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += dY_row * (X_row * rstd_row)
 
         tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
 
@@ -191,7 +212,8 @@ def _rms_norm_backward_kernel(
         X_ptr += X_row_stride
         RSTD_ptr += RSTD_row_stride
 
-    tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
+    if elementwise_affine:
+        tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
 
 
 @triton.jit
@@ -209,6 +231,7 @@ def _block_rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):
@@ -232,7 +255,8 @@ def _block_rms_norm_forward_kernel(
         other=0,
     )
     X_row_dtype = X_row.dtype
-    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
 
     # On Llama, only rstd is computed on fp32
     if casting_mode == _CASTING_MODE_LLAMA:
@@ -240,7 +264,8 @@ def _block_rms_norm_forward_kernel(
 
     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-        W_row = W_row.to(tl.float32)
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)
 
     if casting_mode == _CASTING_MODE_NONE:
@@ -261,7 +286,10 @@ def _block_rms_norm_forward_kernel(
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)
 
-    Y_row = X_row * (offset + W_row)[None, :]
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)[None, :]
+    else:
+        Y_row = X_row
 
     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)
@@ -291,8 +319,8 @@ def _block_rms_norm_backward_kernel(
     n_rows,
     n_cols,
     offset,
-    rows_per_program: tl.constexpr,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):
@@ -307,10 +335,11 @@ def _block_rms_norm_backward_kernel(
     col_offsets = tl.arange(0, BLOCK_SIZE)
     col_mask = col_offsets < n_cols
 
-    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
-    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
-    W_row = W_row + offset
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+        W_row = W_row + offset
 
     for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
         row_idx = start + tl.arange(0, BLOCK_ROW)
@@ -333,13 +362,22 @@ def _block_rms_norm_backward_kernel(
 
         # Different bacward graphs for different casting modes
         if casting_mode == _CASTING_MODE_LLAMA:
-            m = (dY_row * W_row[None, :]).to(tl.float32)
+            if elementwise_affine:
+                m = (dY_row * W_row[None, :]).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)
 
         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-            m = dY_row * W_row[None, :]
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
         else:
-            m = dY_row * W_row[None, :]
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
 
         dX_row = rstd_row[:, None] * m
 
@@ -347,12 +385,13 @@ def _block_rms_norm_backward_kernel(
             -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
         )
 
-        if casting_mode == _CASTING_MODE_LLAMA:
-            # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
-            dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
-        else:
-            # here X_row is already in fp32 (see previous if block)
-            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+        if elementwise_affine:
+            if casting_mode == _CASTING_MODE_LLAMA:
+                # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+                dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
 
         tl.store(
             dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
@@ -360,7 +399,8 @@ def _block_rms_norm_backward_kernel(
             mask=row_mask[:, None] & col_mask[None, :],
         )
 
-    tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+    if elementwise_affine:
+        tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
 
 
 _str_to_casting_mode = {
@@ -389,8 +429,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
     rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
     RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
 
-    # Check constraints.
-    assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+    if W is not None:
+        # Check constraints.
+        assert X.shape[1] == W.shape[0], (
+            "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+        )
+        elementwise_affine = True
+    else:
+        elementwise_affine = False
 
     # XPU-specific optimization
     kernel_args = {}
@@ -403,13 +449,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             X,
             X.stride(0),
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             n_cols,
             eps,
             offset,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
@@ -423,7 +470,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             X,
             X.stride(0),
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             n_rows,
@@ -431,6 +478,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             eps,
             offset,
             casting_mode,
+            elementwise_affine=elementwise_affine,
            BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
@@ -449,9 +497,16 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+    elif X.device.type == "npu":
+        sm_count = get_npu_multi_processor_count()
 
-    # fp32 for numerical stability especially.
-    _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    if W is not None:
+        # fp32 for numerical stability especially.
+        _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+        elementwise_affine = True
+    else:
+        _dW = None
+        elementwise_affine = False
 
     if n_cols > BLOCK_SIZE:
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
@@ -478,16 +533,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
             X.stride(0),
             torch_to_triton_dtype[X.dtype],
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             _dW,
-            _dW.stride(0),
+            _dW.stride(0) if elementwise_affine else 0,
             n_rows,
             n_cols,
             offset,
             rows_per_program,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
@@ -504,22 +560,26 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
             X.stride(0),
             torch_to_triton_dtype[X.dtype],
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
            RSTD,
             RSTD.stride(0),
             _dW,
-            _dW.stride(0),
+            _dW.stride(0) if elementwise_affine else 0,
             n_rows,
             n_cols,
             offset,
-            rows_per_program,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
         )
         dX = dX.view(*shape)
-    dW = _dW.sum(dim=0).to(W.dtype)
+
+    if elementwise_affine:
+        dW = _dW.sum(dim=0).to(W.dtype)
+    else:
+        dW = None
 
     return dX, dW
 
@@ -553,6 +613,13 @@ class LigerRMSNormFunction(torch.autograd.Function):
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
+        if isinstance(X, torch.distributed.tensor.DTensor):
+            # Input tensor is output of a tensor parallel module and
+            # needs to be gathered to a local tensor to compute
+            # RMSE layer norm on each TP worker.
+            # TODO: support CP.
+            X = X.full_tensor()
+
         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
         ctx.offset = offset
         ctx.casting_mode = casting_mode
@@ -560,7 +627,11 @@ class LigerRMSNormFunction(torch.autograd.Function):
         ctx.row_mode = row_mode
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
-        ctx.save_for_backward(X, W, RSTD)
+        ctx.elementwise_affine = W is not None
+        if W is not None:
+            ctx.save_for_backward(X, W, RSTD)
+        else:
+            ctx.save_for_backward(X, RSTD)
         return Y
 
     @staticmethod
@@ -569,7 +640,18 @@ class LigerRMSNormFunction(torch.autograd.Function):
         """
         Y: (B, T, H) or (BxT, H)
         """
-        X, W, RSTD = ctx.saved_tensors
+        if ctx.elementwise_affine:
+            X, W, RSTD = ctx.saved_tensors
+        else:
+            X, RSTD = ctx.saved_tensors
+            W = None
+
+        if isinstance(dY, torch.distributed.tensor.DTensor):
+            # Gradients are output of a tensor parallel module and
+            # needs to be gathered to a local tensor for computing RMSE layer.
+            # TODO: support CP.
+            dY = dY.full_tensor()
+
         dX, dW = rms_norm_backward(
             dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
         )
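For orientation, a minimal pure-PyTorch sketch of what the updated kernel computes with and without the optional weight follows. The helper name `rms_norm_reference` and its defaults are illustrative only and are not part of the package; it only approximates the kernel's Llama-style casting behavior.

import torch


def rms_norm_reference(X, W=None, eps=1e-6, offset=0.0):
    # rstd is computed in fp32 for stability, mirroring the kernel's general approach
    rstd = torch.rsqrt(X.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    Y = (X.float() * rstd).to(X.dtype)
    if W is not None:
        # elementwise_affine path: rescale by (offset + W), as the kernel does when W is passed
        Y = Y * (offset + W)
    # W is None: plain normalization; the kernel then also skips the dW gradient entirely
    return Y

With W set to None this corresponds to the new no-weight path that rms_norm_forward selects by setting elementwise_affine = False.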
liger_kernel/ops/tiled_mlp.py
ADDED
@@ -0,0 +1,136 @@
+import math
+
+from typing import Callable
+from typing import List
+from typing import Optional
+
+import torch
+
+from liger_kernel.ops.utils import ensure_contiguous
+
+
+class LigerTiledMLPFunction(torch.autograd.Function):
+    """
+    Based on DeepSpeed's TiledMLP:
+    https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838
+
+    Perform a tiled MLP computation to massively reduce memory usage needed to compute MLP
+    when using very long sequence lengths.
+
+    This module re-computes `forward` in the `backward`. So the `forward` occurs twice each iteration.
+    And if you're using activation checkpointing it then occurs thrice.
+
+    Args:
+        fn: the function to call on sharded inputs (e.g., mlp.forward)
+        mlp_module: the MLP nn.Module object
+        x: the input to MLP.forward (hidden_states)
+        shards: how many shards to use
+        compute_params: a list of weights engaged in the compute
+
+    Returns:
+        the computed hidden_states
+    """
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(
+        ctx,
+        fn: Callable,
+        mlp_module: torch.nn.Module,
+        x: torch.Tensor,
+        shards: int,
+        compute_params: Optional[List[torch.nn.Parameter]] = None,
+    ) -> torch.Tensor:
+        ctx.fn = fn
+        ctx.mlp_module = mlp_module
+        ctx.shards = shards
+        ctx.save_for_backward(x)
+
+        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
+        x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
+        with torch.no_grad():
+            output_shards = [fn(mlp_module, x_shard) for x_shard in x_shards]
+        output_unsharded = torch.cat(output_shards, dim=-2)
+
+        return output_unsharded
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, *grads) -> tuple:
+        fn = ctx.fn
+        (x,) = ctx.saved_tensors
+        mlp_module = ctx.mlp_module
+        shards = ctx.shards
+
+        x_requires_grad = x.requires_grad
+        x = x.detach()
+        # detach() unsets x.requires_grad, so restore it
+        x.requires_grad_(x_requires_grad)
+
+        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
+        hidden_size = x.shape[-1]
+        x_shape_orig = x.shape
+
+        # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1
+        x = x.view(-1, hidden_size)
+        incoming_grad = grads[0].view(-1, hidden_size)
+        x_grad = torch.zeros_like(x)
+
+        x_shards = list(torch.chunk(x, chunks=shards, dim=0))
+
+        for i, x_shard in enumerate(x_shards):
+            x_shard.requires_grad_(x_requires_grad)
+
+            # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step
+            shard_step = x_shards[i].shape[0]
+            shard_offset = i * x_shards[0].shape[0]
+
+            x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
+            incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
+
+            with torch.enable_grad():
+                output = fn(mlp_module, x_shard)
+            torch.autograd.backward(output, incoming_grad_shard)
+
+        # unflatten
+        x_grad = x_grad.view(x_shape_orig)
+
+        return (None, None, x_grad, None, None)
+
+
+def apply_tiled_mlp(
+    fn: Callable,
+    mlp_module: torch.nn.Module,
+    x: torch.Tensor,
+    num_shards: Optional[int] = None,
+    compute_params: Optional[List[torch.nn.Parameter]] = None,
+) -> torch.Tensor:
+    """
+    Apply tiled MLP computation for memory efficiency.
+
+    Args:
+        fn: the function to call on sharded inputs (e.g., lambda module, x: module(x))
+        mlp_module: the MLP nn.Module object
+        x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
+        num_shards: number of shards to use. If None, automatically calculated as ceil(seqlen / hidden_size)
+        compute_params: list of parameters for DeepSpeed ZeRO optimization
+
+    Returns:
+        output tensor with the same shape as input
+    """
+    if num_shards is None:
+        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size]
+        hidden_size = x.shape[-1]
+        seqlen = x.shape[-2]
+        num_shards = math.ceil(seqlen / hidden_size)
+
+    # Ensure num_shards is at least 1
+    num_shards = max(1, num_shards)
+
+    return LigerTiledMLPFunction.apply(
+        fn,
+        mlp_module,
+        x,
+        num_shards,
+        compute_params,
+    )
liger_kernel/ops/utils.py
CHANGED
@@ -78,6 +78,8 @@ def get_amp_custom_fwd_bwd() -> Callable:
             functools.partial(torch.amp.custom_fwd, device_type=device),
             functools.partial(torch.amp.custom_bwd, device_type=device),
         )
+    if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
+        return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
     return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
 
 
@@ -125,3 +127,15 @@ def element_mul_kernel(
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
         tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
+
+
+def get_npu_core_count(default: int = 20) -> int:
+    """Return NPU vector core count.
+    Fallback to `default` if Triton runtime or NPU device is unavailable.
+    """
+    try:
+        utils = triton.runtime.driver.active.utils
+        props = utils.get_device_properties(0)
+        return int(props.get("num_vectorcore", default))
+    except Exception:
+        return default