liger-kernel 0.6.4__py3-none-any.whl → 0.6.5__py3-none-any.whl
- liger_kernel/chunked_loss/cosine_similarity_loss.py +7 -1
- liger_kernel/chunked_loss/fused_linear_distillation.py +10 -3
- liger_kernel/chunked_loss/jsd_loss.py +21 -6
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
- liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
- liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +14 -4
- liger_kernel/ops/dyt.py +5 -2
- liger_kernel/ops/fused_add_rms_norm.py +21 -23
- liger_kernel/ops/fused_linear_cross_entropy.py +2 -1
- liger_kernel/ops/geglu.py +5 -3
- liger_kernel/ops/group_norm.py +12 -8
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +17 -16
- liger_kernel/ops/poly_norm.py +19 -21
- liger_kernel/ops/rms_norm.py +149 -71
- liger_kernel/ops/utils.py +25 -0
- liger_kernel/transformers/__init__.py +6 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +1 -1
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +20 -20
- liger_kernel/transformers/fused_add_rms_norm.py +1 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +1 -1
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +1 -1
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +1 -1
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/gemma2.py +3 -3
- liger_kernel/transformers/model/gemma3.py +11 -5
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/loss_utils.py +6 -0
- liger_kernel/transformers/model/paligemma.py +1 -0
- liger_kernel/transformers/monkey_patch.py +196 -39
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +1 -1
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +8 -3
- liger_kernel/transformers/rope.py +28 -27
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +1 -1
- liger_kernel/transformers/tiled_mlp.py +5 -13
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +54 -0
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +11 -4
- liger_kernel-0.6.5.dist-info/RECORD +134 -0
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
- liger_kernel-0.6.4.dist-info/RECORD +0 -118
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py
CHANGED
@@ -20,9 +20,12 @@ import triton.language as tl
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import get_npu_core_count
+from liger_kernel.ops.utils import set_large_grf_mode
 from liger_kernel.ops.utils import torch_to_triton_dtype
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt

@@ -52,6 +55,7 @@ def _rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """

@@ -67,13 +71,14 @@ def _rms_norm_forward_kernel(
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols

-    Y_ptr += row_idx * Y_row_stride
-    X_ptr += row_idx * X_row_stride
-    RSTD_ptr += row_idx * RSTD_row_stride
+    y_base = Y_ptr + row_idx * Y_row_stride
+    x_base = X_ptr + row_idx * X_row_stride
+    rstd_base = RSTD_ptr + row_idx * RSTD_row_stride

-    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
+    X_row = tl.load(x_base + col_offsets, mask=mask, other=0)
     X_row_dtype = X_row.dtype
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)

     # On Llama, only rstd is computed on fp32
     if casting_mode == _CASTING_MODE_LLAMA:

@@ -81,7 +86,8 @@ def _rms_norm_forward_kernel(

     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-        W_row = W_row.to(tl.float32)
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)

     if casting_mode == _CASTING_MODE_NONE:

@@ -94,7 +100,7 @@ def _rms_norm_forward_kernel(
     # We can save time by caching rms with minimal memory overhead
     # because rms is much smaller compared to X_row, as rms is for each row.
     # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
-    tl.store(RSTD_ptr, rstd)
+    tl.store(rstd_base, rstd)

     X_row = X_row * rstd

@@ -102,12 +108,15 @@ def _rms_norm_forward_kernel(
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)

-    Y_row = X_row * (offset + W_row)
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)
+    else:
+        Y_row = X_row

     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)

-    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+    tl.store(y_base + col_offsets, Y_row, mask=mask)


 @triton.jit
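Taken together, the forward-kernel changes make the affine weight optional: when `elementwise_affine` is false the kernel never touches `W_ptr` and simply writes `x * rstd`. A minimal PyTorch reference of the intended semantics, as a sketch only (the helper name `rms_norm_reference` is ours, and only the Llama casting mode is mirrored):

from typing import Optional

import torch


def rms_norm_reference(x: torch.Tensor, w: Optional[torch.Tensor], eps: float = 1e-6, offset: float = 0.0) -> torch.Tensor:
    # rstd is computed in fp32, as in _CASTING_MODE_LLAMA
    rstd = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    y = (x.float() * rstd).to(x.dtype)
    # W is None (elementwise_affine=False) now skips the weight multiply entirely
    return y * (offset + w) if w is not None else y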
@@ -128,8 +137,9 @@ def _rms_norm_backward_kernel(
     n_rows,
     n_cols,
     offset,
-    rows_per_program: tl.constexpr,
+    rows_per_program,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """

@@ -143,55 +153,63 @@ def _rms_norm_backward_kernel(
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols

-    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

-    dY_ptr += row_start * dY_row_stride
-    dX_ptr += row_start * dX_row_stride
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+        W_row = W_row + offset

-    X_ptr += row_start * X_row_stride
-    RSTD_ptr += row_start * RSTD_row_stride
+    for row_idx in range(row_start, row_end):
+        dy_base = dY_ptr + row_idx * dY_row_stride
+        dx_base = dX_ptr + row_idx * dX_row_stride

-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
-    W_row = W_row + offset
+        x_base = X_ptr + row_idx * X_row_stride
+        rstd_base = RSTD_ptr + row_idx * RSTD_row_stride

-    for _ in range(row_start, row_end):
-        dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
-        X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
+        dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
+        X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)

         # Get cached rms
-        rstd_row = tl.load(RSTD_ptr)
+        rstd_row = tl.load(rstd_base)

         X_row = X_row.to(tl.float32)

         # Different bacward graphs for different casting modes
         if casting_mode == _CASTING_MODE_LLAMA:
-            m = (dY_row * W_row).to(tl.float32)
+            if elementwise_affine:
+                m = (dY_row * W_row).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)

         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-            m = dY_row * W_row
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
         else:
-            m = dY_row * W_row
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row

         dX_row = rstd_row * m

         dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)

-        # calculate the gradient of W
-        if casting_mode == _CASTING_MODE_LLAMA:
-            dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
-        else:
-            # here X_row is already in fp32 (see previous if block)
-            dW_row += dY_row * (X_row * rstd_row)
+        if elementwise_affine:
+            # calculate the gradient of W
+            if casting_mode == _CASTING_MODE_LLAMA:
+                dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += dY_row * (X_row * rstd_row)

-        tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
+        tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)

-        dY_ptr += dY_row_stride
-        dX_ptr += dX_row_stride
-        X_ptr += X_row_stride
-        RSTD_ptr += RSTD_row_stride
-
-    tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
+    if elementwise_affine:
+        tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
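The backward kernel keeps the same closed form as before, dX = rstd * m - rstd^3 * mean(m * x) * x with m = dY * (W + offset), only now m falls back to dY when there is no weight. A quick float64 autograd check of that closed form (a sketch; the names are ours, not the library's):

import torch


def rms_grad_reference(x, dy, w, eps=1e-6, offset=0.0):
    # closed form used by the kernel: dX = rstd * m - rstd**3 * mean(m * x) * x
    rstd = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    m = dy * (w + offset) if w is not None else dy
    return rstd * m - rstd.pow(3) * (m * x).mean(-1, keepdim=True) * x


eps = 1e-6
x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
w = torch.randn(8, dtype=torch.float64)
y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w
dy = torch.randn_like(y)
(dx,) = torch.autograd.grad(y, x, grad_outputs=dy)
assert torch.allclose(dx, rms_grad_reference(x, dy, w, eps))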
@@ -209,6 +227,7 @@ def _block_rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):

@@ -232,7 +251,8 @@ def _block_rms_norm_forward_kernel(
         other=0,
     )
     X_row_dtype = X_row.dtype
-    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)

     # On Llama, only rstd is computed on fp32
     if casting_mode == _CASTING_MODE_LLAMA:

@@ -240,7 +260,8 @@ def _block_rms_norm_forward_kernel(

     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-        W_row = W_row.to(tl.float32)
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)

     if casting_mode == _CASTING_MODE_NONE:

@@ -261,7 +282,10 @@ def _block_rms_norm_forward_kernel(
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)

-    Y_row = X_row * (offset + W_row)[None, :]
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)[None, :]
+    else:
+        Y_row = X_row

     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)
@@ -291,8 +315,8 @@ def _block_rms_norm_backward_kernel(
     n_rows,
     n_cols,
     offset,
-    rows_per_program: tl.constexpr,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):

@@ -307,10 +331,11 @@ def _block_rms_norm_backward_kernel(
     col_offsets = tl.arange(0, BLOCK_SIZE)
     col_mask = col_offsets < n_cols

-    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

-    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
-    W_row = W_row + offset
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+        W_row = W_row + offset

     for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
         row_idx = start + tl.arange(0, BLOCK_ROW)

@@ -333,13 +358,22 @@ def _block_rms_norm_backward_kernel(

         # Different bacward graphs for different casting modes
         if casting_mode == _CASTING_MODE_LLAMA:
-            m = (dY_row * W_row[None, :]).to(tl.float32)
+            if elementwise_affine:
+                m = (dY_row * W_row[None, :]).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)

         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-            m = dY_row * W_row[None, :]
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
         else:
-            m = dY_row * W_row[None, :]
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row

         dX_row = rstd_row[:, None] * m

@@ -347,12 +381,13 @@ def _block_rms_norm_backward_kernel(
             -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
         )

-        if casting_mode == _CASTING_MODE_LLAMA:
-            # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
-            dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
-        else:
-            # here X_row is already in fp32 (see previous if block)
-            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+        if elementwise_affine:
+            if casting_mode == _CASTING_MODE_LLAMA:
+                # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+                dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)

         tl.store(
             dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],

@@ -360,7 +395,8 @@ def _block_rms_norm_backward_kernel(
             mask=row_mask[:, None] & col_mask[None, :],
         )

-    tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+    if elementwise_affine:
+        tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)


 _str_to_casting_mode = {
@@ -389,13 +425,19 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
     rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
     RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)

-    # Check constraints.
-    assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+    if W is not None:
+        # Check constraints.
+        assert X.shape[1] == W.shape[0], (
+            "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+        )
+        elementwise_affine = True
+    else:
+        elementwise_affine = False

     # XPU-specific optimization
     kernel_args = {}
     if X.device.type == "xpu":
-        kernel_args["grf_mode"] = "large"
+        set_large_grf_mode(kernel_args)
     if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
         _rms_norm_forward_kernel[(n_rows,)](
             Y,

@@ -403,13 +445,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             X,
             X.stride(0),
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             n_cols,
             eps,
             offset,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization

@@ -423,7 +466,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             X,
             X.stride(0),
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             n_rows,

@@ -431,6 +474,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             eps,
             offset,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
@@ -449,9 +493,16 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+    elif X.device.type == "npu":
+        sm_count = get_npu_core_count()

-    # fp32 for numerical stability especially.
-    _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    if W is not None:
+        # fp32 for numerical stability especially.
+        _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+        elementwise_affine = True
+    else:
+        _dW = None
+        elementwise_affine = False

     if n_cols > BLOCK_SIZE:
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

@@ -466,7 +517,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
     # XPU-specific optimization
     kernel_args = {}
     if X.device.type == "xpu":
-        kernel_args["grf_mode"] = "large"
+        set_large_grf_mode(kernel_args)

     if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
         _rms_norm_backward_kernel[grid](

@@ -478,16 +529,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
             X.stride(0),
             torch_to_triton_dtype[X.dtype],
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             _dW,
-            _dW.stride(0),
+            _dW.stride(0) if elementwise_affine else 0,
             n_rows,
             n_cols,
             offset,
             rows_per_program,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization

@@ -504,22 +556,26 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
             X.stride(0),
             torch_to_triton_dtype[X.dtype],
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             _dW,
-            _dW.stride(0),
+            _dW.stride(0) if elementwise_affine else 0,
             n_rows,
             n_cols,
             offset,
-            rows_per_program,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
         )
     dX = dX.view(*shape)
-    dW = _dW.sum(dim=0).to(W.dtype)
+
+    if elementwise_affine:
+        dW = _dW.sum(dim=0).to(W.dtype)
+    else:
+        dW = None

     return dX, dW
@@ -553,6 +609,13 @@ class LigerRMSNormFunction(torch.autograd.Function):
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
+        if isinstance(X, torch.distributed.tensor.DTensor):
+            # Input tensor is output of a tensor parallel module and
+            # needs to be gathered to a local tensor to compute
+            # RMSE layer norm on each TP worker.
+            # TODO: support CP.
+            X = X.full_tensor()
+
         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
         ctx.offset = offset
         ctx.casting_mode = casting_mode

@@ -560,7 +623,11 @@ class LigerRMSNormFunction(torch.autograd.Function):
         ctx.row_mode = row_mode
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
-        ctx.save_for_backward(X, W, RSTD)
+        ctx.elementwise_affine = W is not None
+        if W is not None:
+            ctx.save_for_backward(X, W, RSTD)
+        else:
+            ctx.save_for_backward(X, RSTD)
         return Y

     @staticmethod

@@ -569,7 +636,18 @@ class LigerRMSNormFunction(torch.autograd.Function):
         """
         Y: (B, T, H) or (BxT, H)
         """
-        X, W, RSTD = ctx.saved_tensors
+        if ctx.elementwise_affine:
+            X, W, RSTD = ctx.saved_tensors
+        else:
+            X, RSTD = ctx.saved_tensors
+            W = None
+
+        if isinstance(dY, torch.distributed.tensor.DTensor):
+            # Gradients are output of a tensor parallel module and
+            # needs to be gathered to a local tensor for computing RMSE layer.
+            # TODO: support CP.
+            dY = dY.full_tensor()
+
         dX, dW = rms_norm_backward(
             dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
         )
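With the autograd wrapper updated, W=None runs a weight-free RMSNorm end to end (backward returns dW=None), and DTensor activations from tensor-parallel modules are gathered via full_tensor() before the kernel launches. A hedged usage sketch; the positional argument order (eps, offset, casting_mode, in_place, row_mode) is inferred from the surrounding diff, not from published API docs, and a Triton-capable device is assumed:

import torch
from liger_kernel.ops.rms_norm import LigerRMSNormFunction

x = torch.randn(2, 16, 64, device="cuda", requires_grad=True)
w = torch.randn(64, device="cuda", requires_grad=True)

# affine RMSNorm, as before
y = LigerRMSNormFunction.apply(x, w, 1e-6, 0.0, "llama", False, None)

# weight-free RMSNorm: the kernel never reads W and backward yields dW=None
y2 = LigerRMSNormFunction.apply(x, None, 1e-6, 0.0, "llama", False, None)
y2.sum().backward()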
liger_kernel/ops/utils.py
CHANGED
@@ -78,6 +78,8 @@ def get_amp_custom_fwd_bwd() -> Callable:
             functools.partial(torch.amp.custom_fwd, device_type=device),
             functools.partial(torch.amp.custom_bwd, device_type=device),
         )
+    if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
+        return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
     return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
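get_amp_custom_fwd_bwd now returns the NPU AMP decorators when torch.npu.amp is present, before falling back to CUDA. A sketch of the consumption pattern used by the ops modules (the Scale function here is illustrative, not part of the library):

import torch

from liger_kernel.ops.utils import get_amp_custom_fwd_bwd

amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()


class Scale(torch.autograd.Function):
    @staticmethod
    @amp_custom_fwd
    def forward(ctx, x, s):
        ctx.s = s
        return x * s

    @staticmethod
    @amp_custom_bwd
    def backward(ctx, g):
        return g * ctx.s, None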
@@ -125,3 +127,26 @@ def element_mul_kernel(
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
         tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
+
+
+def get_npu_core_count(default: int = 20) -> int:
+    """Return NPU vector core count.
+
+    Fallback to `default` if Triton runtime or NPU device is unavailable.
+    """
+    try:
+        utils = triton.runtime.driver.active.utils
+        props = utils.get_device_properties(0)
+        return int(props.get("num_vectorcore", default))
+    except Exception:
+        return default
+
+
+def set_large_grf_mode(kernel_args: dict):
+    """Set large GRF mode for XPU devices."""
+    # On XPU triton installed along with pytorch-xpu will be called `pytorch-triton-xpu`,
+    # triton XPU installed from source will be called `triton`.
+    if compare_version("pytorch-triton-xpu", operator.ge, "3.6.0") or compare_version("triton", operator.ge, "3.6.0"):
+        kernel_args["grf_mode"] = "256"
+    else:
+        # API was changed in https://github.com/intel/intel-xpu-backend-for-triton/pull/5430
+        kernel_args["grf_mode"] = "large"
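Both helpers feed device-specific values into Triton kernel launches, mirroring the dispatch added to rms_norm_forward/rms_norm_backward above. A sketch of that pattern (the launch_params_for wrapper is ours):

import torch

from liger_kernel.ops.utils import get_npu_core_count
from liger_kernel.ops.utils import set_large_grf_mode


def launch_params_for(device: torch.device):
    kernel_args = {}
    if device.type == "xpu":
        set_large_grf_mode(kernel_args)  # {"grf_mode": "256"} on triton>=3.6.0, else "large"
        sm_count = torch.xpu.get_device_properties(device).gpu_eu_count
    elif device.type == "npu":
        sm_count = get_npu_core_count()  # vector-core count; falls back to 20
    else:
        sm_count = torch.cuda.get_device_properties(device).multi_processor_count
    # kernel_args is later splatted into the launch: _kernel[grid](..., **kernel_args)
    return kernel_args, sm_count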
liger_kernel/transformers/__init__.py
CHANGED

@@ -33,6 +33,7 @@ if TYPE_CHECKING:
     from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM  # noqa: F401
     from liger_kernel.transformers.monkey_patch import _apply_liger_kernel  # noqa: F401
     from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_exaone4  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_falcon_h1  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401

@@ -41,6 +42,7 @@ if TYPE_CHECKING:
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_dense  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_moe  # noqa: F401

@@ -110,6 +112,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_glm4",
         "apply_liger_kernel_to_glm4v",
         "apply_liger_kernel_to_glm4v_moe",
+        "apply_liger_kernel_to_gpt_oss",
         "apply_liger_kernel_to_granite",
         "apply_liger_kernel_to_internvl",
         "apply_liger_kernel_to_llama",

@@ -134,6 +137,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_smolvlm",
         "apply_liger_kernel_to_hunyuan_v1_dense",
         "apply_liger_kernel_to_hunyuan_v1_moe",
+        "apply_liger_kernel_to_exaone4",
     }

     if name in monkey_patch_symbols:

@@ -187,6 +191,7 @@ if _TRANSFORMERS_AVAILABLE:
             "apply_liger_kernel_to_glm4",
             "apply_liger_kernel_to_glm4v",
             "apply_liger_kernel_to_glm4v_moe",
+            "apply_liger_kernel_to_gpt_oss",
             "apply_liger_kernel_to_granite",
             "apply_liger_kernel_to_internvl",
             "apply_liger_kernel_to_llama",

@@ -211,5 +216,6 @@ if _TRANSFORMERS_AVAILABLE:
             "apply_liger_kernel_to_smolvlm",
             "apply_liger_kernel_to_hunyuan_v1_dense",
             "apply_liger_kernel_to_hunyuan_v1_moe",
+            "apply_liger_kernel_to_exaone4",
         ]
     )
liger_kernel/transformers/auto_model.py
CHANGED

@@ -1,4 +1,5 @@
 import inspect
+import logging

 from transformers import AutoConfig
 from transformers import AutoModelForCausalLM

@@ -6,6 +7,8 @@ from transformers import AutoModelForCausalLM
 from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
 from liger_kernel.transformers.monkey_patch import _apply_liger_kernel

+logger = logging.getLogger(__name__)
+

 def _get_model_config(model_dir, **model_init_kwargs):
     config = AutoConfig.from_pretrained(model_dir, **model_init_kwargs)

@@ -36,3 +39,21 @@ class AutoLigerKernelForCausalLM(AutoModelForCausalLM):
         applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}

         return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)
+
+    @classmethod
+    def from_config(cls, config, **kwargs):
+        model_type = getattr(config, "model_type", None)
+        if not model_type:
+            logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
+            return
+        model_type = config.model_type
+
+        _apply_liger_kernel(model_type, **kwargs)
+
+        # Filter out kwargs that were passed to the apply_liger_* function, which will cause
+        # model initialization errors otherwise
+        apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
+        apply_fn_signature = inspect.signature(apply_fn)
+        applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
+
+        return super().from_config(config, **applicable_kwargs)
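A hedged usage sketch of the new classmethod: it applies the matching apply_liger_kernel_to_* patch, strips the patch's own kwargs, and then defers to AutoModelForCausalLM.from_config to build randomly initialized weights (the model id is just an example; rope/rms_norm are flags of the llama patch function):

from transformers import AutoConfig

from liger_kernel.transformers import AutoLigerKernelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoLigerKernelForCausalLM.from_config(config, rope=True, rms_norm=True)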
liger_kernel/transformers/functional.py
CHANGED
@@ -3,26 +3,26 @@ from typing import Optional

 import torch

-from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
-from liger_kernel.ops.dyt import LigerDyTFunction
-from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction
-from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
-from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
-from liger_kernel.ops.fused_neighborhood_attention import LigerFusedNeighborhoodAttentionFunction
-from liger_kernel.ops.geglu import LigerGELUMulFunction
-from liger_kernel.ops.group_norm import LigerGroupNormFunction
-from liger_kernel.ops.jsd import LigerJSDFunction
-from liger_kernel.ops.kl_div import LigerKLDivLossFunction
-from liger_kernel.ops.layer_norm import LigerLayerNormFunction
-from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
-from liger_kernel.ops.poly_norm import LigerPolyNormFunction
-from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
-from liger_kernel.ops.rms_norm import LigerRMSNormFunction
-from liger_kernel.ops.rope import LigerRopeFunction
-from liger_kernel.ops.softmax import LigerSoftmaxFunction
-from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
-from liger_kernel.ops.swiglu import LigerSiLUMulFunction
-from liger_kernel.ops.tvd import LigerTVDLossFunction
+from liger_kernel.ops import LigerCrossEntropyFunction
+from liger_kernel.ops import LigerDyTFunction
+from liger_kernel.ops import LigerFusedAddRMSNormFunction
+from liger_kernel.ops import LigerFusedLinearCrossEntropyFunction
+from liger_kernel.ops import LigerFusedLinearJSDFunction
+from liger_kernel.ops import LigerFusedNeighborhoodAttentionFunction
+from liger_kernel.ops import LigerGELUMulFunction
+from liger_kernel.ops import LigerGroupNormFunction
+from liger_kernel.ops import LigerJSDFunction
+from liger_kernel.ops import LigerKLDivLossFunction
+from liger_kernel.ops import LigerLayerNormFunction
+from liger_kernel.ops import LigerMultiTokenAttentionFunction
+from liger_kernel.ops import LigerPolyNormFunction
+from liger_kernel.ops import LigerQwen2VLMRopeFunction
+from liger_kernel.ops import LigerRMSNormFunction
+from liger_kernel.ops import LigerRopeFunction
+from liger_kernel.ops import LigerSiLUMulFunction
+from liger_kernel.ops import LigerSoftmaxFunction
+from liger_kernel.ops import LigerSparsemaxFunction
+from liger_kernel.ops import LigerTVDLossFunction


 @dataclass