liger-kernel-nightly 0.6.4.dev20251202054858__py3-none-any.whl → 0.6.4.dev20260107181130__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly has been flagged by the registry as possibly problematic.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +7 -1
- liger_kernel/chunked_loss/fused_linear_distillation.py +10 -3
- liger_kernel/chunked_loss/jsd_loss.py +21 -6
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +43 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +244 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +12 -3
- liger_kernel/ops/fused_linear_cross_entropy.py +2 -1
- liger_kernel/ops/geglu.py +3 -2
- liger_kernel/ops/rms_norm.py +126 -49
- liger_kernel/ops/utils.py +12 -0
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +1 -1
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +20 -20
- liger_kernel/transformers/fused_add_rms_norm.py +1 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +1 -1
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +1 -1
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +1 -1
- liger_kernel/transformers/model/gemma3.py +1 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/paligemma.py +1 -0
- liger_kernel/transformers/monkey_patch.py +118 -39
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +1 -1
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +8 -3
- liger_kernel/transformers/rope.py +28 -27
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +1 -1
- liger_kernel/transformers/tiled_mlp.py +3 -3
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +27 -0
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/METADATA +9 -3
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/RECORD +58 -46
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107181130.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py
CHANGED

@@ -54,6 +54,7 @@ def _rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -75,7 +76,8 @@ def _rms_norm_forward_kernel(
 
     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
     X_row_dtype = X_row.dtype
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
 
     # On Llama, only rstd is computed on fp32
     if casting_mode == _CASTING_MODE_LLAMA:
@@ -83,7 +85,8 @@ def _rms_norm_forward_kernel(
 
     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-        W_row = W_row.to(tl.float32)
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)
 
     if casting_mode == _CASTING_MODE_NONE:
@@ -104,7 +107,10 @@ def _rms_norm_forward_kernel(
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)
 
-    Y_row = X_row * (offset + W_row)
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)
+    else:
+        Y_row = X_row
 
     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)
@@ -130,8 +136,9 @@ def _rms_norm_backward_kernel(
     n_rows,
     n_cols,
     offset,
-    rows_per_program: tl.constexpr,
+    rows_per_program,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -145,7 +152,8 @@ def _rms_norm_backward_kernel(
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
     dY_ptr += row_start * dY_row_stride
     dX_ptr += row_start * dX_row_stride
@@ -153,8 +161,9 @@ def _rms_norm_backward_kernel(
     X_ptr += row_start * X_row_stride
     RSTD_ptr += row_start
 
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
-    W_row = W_row + offset
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+        W_row = W_row + offset
 
     for _ in range(row_start, row_end):
         dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
@@ -167,24 +176,34 @@ def _rms_norm_backward_kernel(
 
         # Different bacward graphs for different casting modes
         if casting_mode == _CASTING_MODE_LLAMA:
-            m = (dY_row * W_row).to(tl.float32)
+            if elementwise_affine:
+                m = (dY_row * W_row).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)
 
         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-            m = dY_row * W_row
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
         else:
-            m = dY_row * W_row
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
 
         dX_row = rstd_row * m
 
         dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
 
-        # calculate the gradient of W
-        if casting_mode == _CASTING_MODE_LLAMA:
-            dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
-        else:
-            # here X_row is already in fp32 (see previous if block)
-            dW_row += dY_row * (X_row * rstd_row)
+        if elementwise_affine:
+            # calculate the gradient of W
+            if casting_mode == _CASTING_MODE_LLAMA:
+                dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += dY_row * (X_row * rstd_row)
 
         tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
 
@@ -193,7 +212,8 @@ def _rms_norm_backward_kernel(
         X_ptr += X_row_stride
         RSTD_ptr += RSTD_row_stride
 
-    tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
+    if elementwise_affine:
+        tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
 
 
 @triton.jit
@@ -211,6 +231,7 @@ def _block_rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):
@@ -234,7 +255,8 @@ def _block_rms_norm_forward_kernel(
         other=0,
     )
     X_row_dtype = X_row.dtype
-    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
 
     # On Llama, only rstd is computed on fp32
     if casting_mode == _CASTING_MODE_LLAMA:
@@ -242,7 +264,8 @@ def _block_rms_norm_forward_kernel(
 
     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-        W_row = W_row.to(tl.float32)
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)
 
     if casting_mode == _CASTING_MODE_NONE:
@@ -263,7 +286,10 @@ def _block_rms_norm_forward_kernel(
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)
 
-    Y_row = X_row * (offset + W_row)[None, :]
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)[None, :]
+    else:
+        Y_row = X_row
 
     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)
@@ -293,8 +319,8 @@ def _block_rms_norm_backward_kernel(
     n_rows,
     n_cols,
     offset,
-    rows_per_program: tl.constexpr,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):
@@ -309,10 +335,11 @@ def _block_rms_norm_backward_kernel(
     col_offsets = tl.arange(0, BLOCK_SIZE)
     col_mask = col_offsets < n_cols
 
-    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
-    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
-    W_row = W_row + offset
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+        W_row = W_row + offset
 
     for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
         row_idx = start + tl.arange(0, BLOCK_ROW)
@@ -335,13 +362,22 @@ def _block_rms_norm_backward_kernel(
 
         # Different bacward graphs for different casting modes
        if casting_mode == _CASTING_MODE_LLAMA:
-            m = (dY_row * W_row[None, :]).to(tl.float32)
+            if elementwise_affine:
+                m = (dY_row * W_row[None, :]).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)
 
         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-            m = dY_row * W_row[None, :]
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
         else:
-            m = dY_row * W_row[None, :]
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
 
         dX_row = rstd_row[:, None] * m
 
@@ -349,13 +385,13 @@ def _block_rms_norm_backward_kernel(
             -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
         )
 
-        # calculate the gradient of W
-        if casting_mode == _CASTING_MODE_LLAMA:
-            # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
-            dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
-        else:
-            # here X_row is already in fp32 (see previous if block)
-            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+        if elementwise_affine:
+            if casting_mode == _CASTING_MODE_LLAMA:
+                # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+                dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
 
         tl.store(
             dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
@@ -363,7 +399,8 @@ def _block_rms_norm_backward_kernel(
             mask=row_mask[:, None] & col_mask[None, :],
         )
 
-    tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+    if elementwise_affine:
+        tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
 
 
 _str_to_casting_mode = {
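The kernel edits above all follow one pattern: `elementwise_affine` is a `tl.constexpr`, so when it is false the `W_ptr` loads, the `(offset + W)` scaling, and the entire `dW` accumulation are compiled out rather than branched over at runtime. A minimal PyTorch sketch of the forward semantics the kernels now implement (casting-mode details omitted; `rms_norm_ref` is our illustrative name, not a library function):

```python
from typing import Optional

import torch


def rms_norm_ref(
    X: torch.Tensor,
    W: Optional[torch.Tensor],
    eps: float = 1e-6,
    offset: float = 0.0,
) -> torch.Tensor:
    # rstd is computed in fp32, as in the Llama/Gemma casting modes.
    rstd = torch.rsqrt(X.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    Y = X * rstd.to(X.dtype)
    if W is not None:  # elementwise_affine: scale by (offset + W)
        Y = Y * (offset + W)
    return Y  # W is None: plain normalization, no affine scale


x = torch.randn(4, 16)
# With a unit weight and zero offset the two paths coincide.
assert torch.allclose(rms_norm_ref(x, None), rms_norm_ref(x, torch.ones(16)))
```

The wrapper changes below thread the same flag through both launch paths.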
@@ -392,8 +429,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
     rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
     RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
 
-    # Check constraints.
-    assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+    if W is not None:
+        # Check constraints.
+        assert X.shape[1] == W.shape[0], (
+            "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+        )
+        elementwise_affine = True
+    else:
+        elementwise_affine = False
 
     # XPU-specific optimization
     kernel_args = {}
@@ -406,13 +449,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
         X,
         X.stride(0),
         W,
-        W.stride(0),
+        W.stride(0) if elementwise_affine else 0,
         RSTD,
         RSTD.stride(0),
         n_cols,
         eps,
         offset,
         casting_mode,
+        elementwise_affine=elementwise_affine,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
         **kernel_args,  # XPU-specific optimization
@@ -426,7 +470,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
         X,
         X.stride(0),
         W,
-        W.stride(0),
+        W.stride(0) if elementwise_affine else 0,
         RSTD,
         RSTD.stride(0),
         n_rows,
@@ -434,6 +478,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
         eps,
         offset,
         casting_mode,
+        elementwise_affine=elementwise_affine,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
         **kernel_args,  # XPU-specific optimization
@@ -455,8 +500,13 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     elif X.device.type == "npu":
         sm_count = get_npu_multi_processor_count()
 
-    # fp32 for numerical stability especially.
-    _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    if W is not None:
+        # fp32 for numerical stability especially.
+        _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+        elementwise_affine = True
+    else:
+        _dW = None
+        elementwise_affine = False
 
     if n_cols > BLOCK_SIZE:
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
@@ -483,16 +533,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
         X.stride(0),
         torch_to_triton_dtype[X.dtype],
         W,
-        W.stride(0),
+        W.stride(0) if elementwise_affine else 0,
         RSTD,
         RSTD.stride(0),
         _dW,
-        _dW.stride(0),
+        _dW.stride(0) if elementwise_affine else 0,
         n_rows,
         n_cols,
         offset,
         rows_per_program,
         casting_mode,
+        elementwise_affine=elementwise_affine,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
         **kernel_args,  # XPU-specific optimization
@@ -509,22 +560,26 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
         X.stride(0),
         torch_to_triton_dtype[X.dtype],
         W,
-        W.stride(0),
+        W.stride(0) if elementwise_affine else 0,
         RSTD,
         RSTD.stride(0),
         _dW,
-        _dW.stride(0),
+        _dW.stride(0) if elementwise_affine else 0,
         n_rows,
         n_cols,
         offset,
-        rows_per_program,
         casting_mode,
+        elementwise_affine=elementwise_affine,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
         **kernel_args,  # XPU-specific optimization
     )
     dX = dX.view(*shape)
-    dW = _dW.sum(dim=0).to(W.dtype)
+
+    if elementwise_affine:
+        dW = _dW.sum(dim=0).to(W.dtype)
+    else:
+        dW = None
 
     return dX, dW
 
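All four launch sites use the same dispatch idiom: derive `elementwise_affine` once from `W is not None`, pass the (possibly `None`) tensor straight through, and substitute a dummy `0` for its stride, which is safe because the specialized kernel never touches `W_ptr` or `dW_ptr` on that path. An illustrative sketch of the idiom (the helper name is hypothetical, not part of the package):

```python
from typing import Any, Dict, Optional

import torch


def weight_launch_args(W: Optional[torch.Tensor]) -> Dict[str, Any]:
    # Mirrors `W.stride(0) if elementwise_affine else 0` above: when the
    # weight is absent, the constexpr specialization removes every W_ptr
    # access, so the null pointer and dummy stride are never dereferenced.
    elementwise_affine = W is not None
    return {
        "W": W,
        "W_row_stride": W.stride(0) if elementwise_affine else 0,
        "elementwise_affine": elementwise_affine,
    }


print(weight_launch_args(None))
print(weight_launch_args(torch.ones(8)))
```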
@@ -558,6 +613,13 @@ class LigerRMSNormFunction(torch.autograd.Function):
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
+        if isinstance(X, torch.distributed.tensor.DTensor):
+            # Input tensor is output of a tensor parallel module and
+            # needs to be gathered to a local tensor to compute
+            # RMSE layer norm on each TP worker.
+            # TODO: support CP.
+            X = X.full_tensor()
+
         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
         ctx.offset = offset
         ctx.casting_mode = casting_mode
@@ -565,7 +627,11 @@ class LigerRMSNormFunction(torch.autograd.Function):
         ctx.row_mode = row_mode
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
-        ctx.save_for_backward(X, W, RSTD)
+        ctx.elementwise_affine = W is not None
+        if W is not None:
+            ctx.save_for_backward(X, W, RSTD)
+        else:
+            ctx.save_for_backward(X, RSTD)
         return Y
 
     @staticmethod
@@ -574,7 +640,18 @@ class LigerRMSNormFunction(torch.autograd.Function):
         """
         Y: (B, T, H) or (BxT, H)
         """
-        X, W, RSTD = ctx.saved_tensors
+        if ctx.elementwise_affine:
+            X, W, RSTD = ctx.saved_tensors
+        else:
+            X, RSTD = ctx.saved_tensors
+            W = None
+
+        if isinstance(dY, torch.distributed.tensor.DTensor):
+            # Gradients are output of a tensor parallel module and
+            # needs to be gathered to a local tensor for computing RMSE layer.
+            # TODO: support CP.
+            dY = dY.full_tensor()
+
         dX, dW = rms_norm_backward(
             dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
         )
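Beyond the optional weight, `forward` and `backward` now gather DTensor inputs from tensor-parallel modules into local tensors before the Triton kernels run. The guard reduces to the helper below (standalone for illustration; the patched code inlines the check):

```python
import torch
from torch.distributed.tensor import DTensor


def ensure_local(t: torch.Tensor) -> torch.Tensor:
    # Same check the patched forward/backward applies to X and dY: a
    # sharded DTensor is gathered to a full local tensor first (context
    # parallelism is still a TODO upstream).
    if isinstance(t, DTensor):
        return t.full_tensor()
    return t


x = torch.randn(2, 8)
assert ensure_local(x) is x  # plain tensors pass through untouched
```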
liger_kernel/ops/utils.py
CHANGED

@@ -127,3 +127,15 @@ def element_mul_kernel(
     X_offsets = i + tl.arange(0, BLOCK_SIZE)
     X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
     tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
+
+
+def get_npu_core_count(default: int = 20) -> int:
+    """Return NPU vector core count.
+    Fallback to `default` if Triton runtime or NPU device is unavailable.
+    """
+    try:
+        utils = triton.runtime.driver.active.utils
+        props = utils.get_device_properties(0)
+        return int(props.get("num_vectorcore", default))
+    except Exception:
+        return default
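The new helper probes the active Triton driver and degrades to a fixed default when no NPU (or no working Triton runtime) is present, matching the defensive style of the existing multiprocessor-count utilities. A usage sketch with this nightly installed (the row count is an arbitrary example):

```python
import triton

from liger_kernel.ops.utils import get_npu_core_count

sm_count = get_npu_core_count()  # falls back to 20 without an NPU
rows_per_program = triton.cdiv(4096, sm_count)  # split rows across cores
print(sm_count, rows_per_program)
```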
liger_kernel/transformers/__init__.py
CHANGED

@@ -41,6 +41,7 @@ if TYPE_CHECKING:
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_dense  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_moe  # noqa: F401
@@ -110,6 +111,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_glm4",
         "apply_liger_kernel_to_glm4v",
         "apply_liger_kernel_to_glm4v_moe",
+        "apply_liger_kernel_to_gpt_oss",
         "apply_liger_kernel_to_granite",
         "apply_liger_kernel_to_internvl",
         "apply_liger_kernel_to_llama",
@@ -187,6 +189,7 @@ if _TRANSFORMERS_AVAILABLE:
         "apply_liger_kernel_to_glm4",
         "apply_liger_kernel_to_glm4v",
         "apply_liger_kernel_to_glm4v_moe",
+        "apply_liger_kernel_to_gpt_oss",
         "apply_liger_kernel_to_granite",
         "apply_liger_kernel_to_internvl",
         "apply_liger_kernel_to_llama",
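The new entry point is re-exported lazily from the top-level `liger_kernel.transformers` namespace like its siblings, so the usual patch-before-instantiation flow applies. A sketch (no kwargs are passed because the `gpt_oss`-specific options are not visible in this diff):

```python
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss

# Monkey-patch Hugging Face's GPT-OSS modeling code in place with Liger
# kernels before the model is built; default options assumed.
apply_liger_kernel_to_gpt_oss()
```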
liger_kernel/transformers/auto_model.py
CHANGED

@@ -1,4 +1,5 @@
 import inspect
+import logging
 
 from transformers import AutoConfig
 from transformers import AutoModelForCausalLM
@@ -6,6 +7,8 @@ from transformers import AutoModelForCausalLM
 from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
 from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
 
+logger = logging.getLogger(__name__)
+
 
 def _get_model_config(model_dir, **model_init_kwargs):
     config = AutoConfig.from_pretrained(model_dir, **model_init_kwargs)
@@ -36,3 +39,21 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
         applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
 
         return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)
+
+    @classmethod
+    def from_config(cls, config, **kwargs):
+        model_type = getattr(config, "model_type", None)
+        if not model_type:
+            logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
+            return
+        model_type = config.model_type
+
+        _apply_liger_kernel(model_type, **kwargs)
+
+        # Filter out kwargs that were passed to the apply_liger_* function, which will cause
+        # model initialization errors otherwise
+        apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
+        apply_fn_signature = inspect.signature(apply_fn)
+        applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
+
+        return super().from_config(config, **applicable_kwargs)
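Previously `AutoLigerKernelForCausalLM` only hooked `from_pretrained`; `from_config` now applies the same model-type-keyed patching before building a randomly initialized model. A usage sketch (the model id is only an example):

```python
from transformers import AutoConfig

from liger_kernel.transformers import AutoLigerKernelForCausalLM

# Train-from-scratch path: Liger kernels are applied based on the config's
# model_type, then the unpretrained model is constructed as usual.
config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")  # example id
model = AutoLigerKernelForCausalLM.from_config(config)
```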
liger_kernel/transformers/functional.py
CHANGED

@@ -3,26 +3,26 @@ from typing import Optional
 
 import torch
 
-from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
-from liger_kernel.ops.dyt import LigerDyTFunction
-from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction
-from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
-from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
-from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
-from liger_kernel.ops.geglu import LigerGELUMulFunction
-from liger_kernel.ops.group_norm import LigerGroupNormFunction
-from liger_kernel.ops.jsd import LigerJSDFunction
-from liger_kernel.ops.kl_div import LigerKLDivLossFunction
-from liger_kernel.ops.layer_norm import LigerLayerNormFunction
-from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
-from liger_kernel.ops.poly_norm import LigerPolyNormFunction
-from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
-from liger_kernel.ops.rms_norm import LigerRMSNormFunction
-from liger_kernel.ops.rope import LigerRopeFunction
-from liger_kernel.ops.swiglu import LigerSiLUMulFunction
-from liger_kernel.ops.softmax import LigerSoftmaxFunction
-from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
-from liger_kernel.ops.tvd import LigerTVDLossFunction
+from liger_kernel.ops import LigerCrossEntropyFunction
+from liger_kernel.ops import LigerDyTFunction
+from liger_kernel.ops import LigerFusedAddRMSNormFunction
+from liger_kernel.ops import LigerFusedLinearCrossEntropyFunction
+from liger_kernel.ops import LigerFusedLinearJSDFunction
+from liger_kernel.ops import LigerFusedNeighborhoodAttentionFunction
+from liger_kernel.ops import LigerGELUMulFunction
+from liger_kernel.ops import LigerGroupNormFunction
+from liger_kernel.ops import LigerJSDFunction
+from liger_kernel.ops import LigerKLDivLossFunction
+from liger_kernel.ops import LigerLayerNormFunction
+from liger_kernel.ops import LigerMultiTokenAttentionFunction
+from liger_kernel.ops import LigerPolyNormFunction
+from liger_kernel.ops import LigerQwen2VLMRopeFunction
+from liger_kernel.ops import LigerRMSNormFunction
+from liger_kernel.ops import LigerRopeFunction
+from liger_kernel.ops import LigerSiLUMulFunction
+from liger_kernel.ops import LigerSoftmaxFunction
+from liger_kernel.ops import LigerSparsemaxFunction
+from liger_kernel.ops import LigerTVDLossFunction
 
 
 @dataclass
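These one-line rewrites, and the matching ones in the files below, rely on the new `liger_kernel/ops/__init__.py` (+141 lines) re-exporting every autograd Function at the package root. Whether it does so eagerly or with the PEP 562 `__getattr__` pattern the package already uses in `liger_kernel/transformers/__init__.py` is not visible in this diff; a minimal sketch of the lazy variant:

```python
# Sketch of a lazy package-root re-export; the actual mapping and
# mechanism inside liger_kernel/ops/__init__.py may differ.
import importlib

_FUNCTION_TO_MODULE = {
    "LigerRMSNormFunction": "liger_kernel.ops.rms_norm",
    "LigerCrossEntropyFunction": "liger_kernel.ops.cross_entropy",
    # ... one entry per re-exported autograd Function
}


def __getattr__(name: str):
    # Import the defining submodule only on first attribute access.
    module_path = _FUNCTION_TO_MODULE.get(name)
    if module_path is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return getattr(importlib.import_module(module_path), name)
```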
liger_kernel/transformers/fused_linear_cross_entropy.py
CHANGED

@@ -2,7 +2,7 @@ from typing import Optional
 
 import torch
 
-from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
+from liger_kernel.ops import LigerFusedLinearCrossEntropyFunction
 from liger_kernel.transformers.functional import CrossEntropyOutput
 
 
liger_kernel/transformers/fused_neighborhood_attention.py
CHANGED

@@ -5,7 +5,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
+from liger_kernel.ops import LigerFusedNeighborhoodAttentionFunction
 
 
 class LigerFusedNeighborhoodAttention(nn.Module):
liger_kernel/transformers/llama4_rope.py
CHANGED

@@ -5,7 +5,7 @@ Supports both text and vision RoPE variants with fused operations for optimal performance.
 
 import torch
 
-from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction
+from liger_kernel.ops import LigerLlama4RopeFunction
 
 
 def liger_llama4_text_rotary_pos_emb(