liger-kernel-nightly 0.5.10.dev20250611191801__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
- liger_kernel/chunked_loss/dpo_loss.py +54 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
- liger_kernel/chunked_loss/grpo_loss.py +46 -9
- liger_kernel/chunked_loss/jsd_loss.py +44 -13
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +130 -64
- liger_kernel/ops/dyt.py +5 -4
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
- liger_kernel/ops/geglu.py +6 -4
- liger_kernel/ops/group_norm.py +7 -7
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +135 -80
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +148 -71
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +65 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +56 -24
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +17 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +57 -2
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +28 -8
- liger_kernel/transformers/model/gemma2.py +34 -11
- liger_kernel/transformers/model/gemma3.py +102 -112
- liger_kernel/transformers/model/glm4.py +18 -5
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +26 -7
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +18 -6
- liger_kernel/transformers/model/loss_utils.py +34 -3
- liger_kernel/transformers/model/mistral.py +17 -10
- liger_kernel/transformers/model/mixtral.py +24 -9
- liger_kernel/transformers/model/mllama.py +18 -7
- liger_kernel/transformers/model/olmo2.py +18 -5
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +42 -5
- liger_kernel/transformers/model/phi3.py +24 -159
- liger_kernel/transformers/model/qwen2.py +26 -4
- liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
- liger_kernel/transformers/model/qwen2_vl.py +24 -7
- liger_kernel/transformers/model/qwen3.py +22 -6
- liger_kernel/transformers/model/qwen3_moe.py +27 -7
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +1423 -100
- liger_kernel/transformers/multi_token_attention.py +2 -2
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +15 -5
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +18 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +52 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +37 -25
- liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
- liger_kernel_nightly-0.5.10.dev20250611191801.dist-info/RECORD +0 -95
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py
CHANGED

@@ -21,8 +21,10 @@ from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import torch_to_triton_dtype
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt

@@ -52,6 +54,7 @@ def _rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """

@@ -63,17 +66,18 @@ def _rms_norm_forward_kernel(
     3. https://arxiv.org/pdf/1910.07467
     """
 
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-    Y_ptr
-    X_ptr
-    RSTD_ptr
+    y_base = Y_ptr + row_idx * Y_row_stride
+    x_base = X_ptr + row_idx * X_row_stride
+    rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
 
-    X_row = tl.load(
+    X_row = tl.load(x_base + col_offsets, mask=mask, other=0)
     X_row_dtype = X_row.dtype
-
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
 
     # On Llama, only rstd is computed on fp32
     if casting_mode == _CASTING_MODE_LLAMA:

@@ -81,7 +85,8 @@ def _rms_norm_forward_kernel(
 
     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)
 
     if casting_mode == _CASTING_MODE_NONE:

@@ -94,7 +99,7 @@ def _rms_norm_forward_kernel(
     # We can save time by caching rms with minimal memory overhead
     # because rms is much smaller compared to X_row, as rms is for each row.
     # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
-    tl.store(
+    tl.store(rstd_base, rstd)
 
     X_row = X_row * rstd
 

@@ -102,12 +107,15 @@ def _rms_norm_forward_kernel(
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)
 
-
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)
+    else:
+        Y_row = X_row
 
     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)
 
-    tl.store(
+    tl.store(y_base + col_offsets, Y_row, mask=mask)
 
 
 @triton.jit

@@ -128,8 +136,9 @@ def _rms_norm_backward_kernel(
     n_rows,
     n_cols,
     offset,
-    rows_per_program
+    rows_per_program,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """

@@ -137,61 +146,69 @@ def _rms_norm_backward_kernel(
     dw = sum(dy * (x / RMS)). summation over BxT dimension
     """
 
-    row_block_id = tl.program_id(0)
+    row_block_id = tl.program_id(0).to(tl.int64)
     row_start = row_block_id * rows_per_program
     row_end = min((row_block_id + 1) * rows_per_program, n_rows)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
-
-
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+        W_row = W_row + offset
 
-
-
+    for row_idx in range(row_start, row_end):
+        dy_base = dY_ptr + row_idx * dY_row_stride
+        dx_base = dX_ptr + row_idx * dX_row_stride
 
-
-
+        x_base = X_ptr + row_idx * X_row_stride
+        rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
 
-
-
-        X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
+        dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
+        X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
 
         # Get cached rms
-        rstd_row = tl.load(
+        rstd_row = tl.load(rstd_base)
 
         X_row = X_row.to(tl.float32)
 
         # Different bacward graphs for different casting modes
         if casting_mode == _CASTING_MODE_LLAMA:
-
+            if elementwise_affine:
+                m = (dY_row * W_row).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)
 
         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
         else:
-
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
 
         dX_row = rstd_row * m
 
         dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
 
-
-
-
-
-
-
+        if elementwise_affine:
+            # calculate the gradient of W
+            if casting_mode == _CASTING_MODE_LLAMA:
+                dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += dY_row * (X_row * rstd_row)
 
-        tl.store(
+        tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
 
-
-
-        X_ptr += X_row_stride
-        RSTD_ptr += RSTD_row_stride
-
-    tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
+    if elementwise_affine:
+        tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
 
 
 @triton.jit

@@ -209,6 +226,7 @@ def _block_rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):

@@ -232,7 +250,8 @@ def _block_rms_norm_forward_kernel(
         other=0,
     )
     X_row_dtype = X_row.dtype
-
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
 
     # On Llama, only rstd is computed on fp32
     if casting_mode == _CASTING_MODE_LLAMA:

@@ -240,7 +259,8 @@ def _block_rms_norm_forward_kernel(
 
     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)
 
     if casting_mode == _CASTING_MODE_NONE:

@@ -261,7 +281,10 @@ def _block_rms_norm_forward_kernel(
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)
 
-
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)[None, :]
+    else:
+        Y_row = X_row
 
     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)

@@ -291,8 +314,8 @@ def _block_rms_norm_backward_kernel(
     n_rows,
     n_cols,
     offset,
-    rows_per_program: tl.constexpr,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):

@@ -307,10 +330,11 @@ def _block_rms_norm_backward_kernel(
     col_offsets = tl.arange(0, BLOCK_SIZE)
     col_mask = col_offsets < n_cols
 
-
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
-
-
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+        W_row = W_row + offset
 
     for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
         row_idx = start + tl.arange(0, BLOCK_ROW)

@@ -333,13 +357,22 @@ def _block_rms_norm_backward_kernel(
 
         # Different bacward graphs for different casting modes
         if casting_mode == _CASTING_MODE_LLAMA:
-
+            if elementwise_affine:
+                m = (dY_row * W_row[None, :]).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)
 
         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
         else:
-
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
 
         dX_row = rstd_row[:, None] * m
 

@@ -347,12 +380,13 @@ def _block_rms_norm_backward_kernel(
             -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
         )
 
-
-
-
-
-
-
+        if elementwise_affine:
+            if casting_mode == _CASTING_MODE_LLAMA:
+                # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+                dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
 
         tl.store(
             dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],

@@ -360,7 +394,8 @@ def _block_rms_norm_backward_kernel(
             mask=row_mask[:, None] & col_mask[None, :],
         )
 
-
+    if elementwise_affine:
+        tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
 
 
 _str_to_casting_mode = {

@@ -389,8 +424,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
     rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
     RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
 
-
-
+    if W is not None:
+        # Check constraints.
+        assert X.shape[1] == W.shape[0], (
+            "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+        )
+        elementwise_affine = True
+    else:
+        elementwise_affine = False
 
     # XPU-specific optimization
     kernel_args = {}

@@ -403,13 +444,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             X,
             X.stride(0),
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             n_cols,
             eps,
             offset,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args, # XPU-specific optimization

@@ -423,7 +465,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             X,
             X.stride(0),
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
            RSTD,
             RSTD.stride(0),
             n_rows,

@@ -431,6 +473,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             eps,
             offset,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args, # XPU-specific optimization

@@ -449,9 +492,16 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+    elif X.device.type == "npu":
+        sm_count = get_npu_multi_processor_count()
 
-
-
+    if W is not None:
+        # fp32 for numerical stability especially.
+        _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+        elementwise_affine = True
+    else:
+        _dW = None
+        elementwise_affine = False
 
     if n_cols > BLOCK_SIZE:
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

@@ -478,16 +528,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
             X.stride(0),
             torch_to_triton_dtype[X.dtype],
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             _dW,
-            _dW.stride(0),
+            _dW.stride(0) if elementwise_affine else 0,
             n_rows,
             n_cols,
             offset,
             rows_per_program,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args, # XPU-specific optimization

@@ -504,22 +555,26 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
             X.stride(0),
             torch_to_triton_dtype[X.dtype],
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             _dW,
-            _dW.stride(0),
+            _dW.stride(0) if elementwise_affine else 0,
             n_rows,
             n_cols,
             offset,
-            rows_per_program,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args, # XPU-specific optimization
         )
     dX = dX.view(*shape)
-
+
+    if elementwise_affine:
+        dW = _dW.sum(dim=0).to(W.dtype)
+    else:
+        dW = None
 
     return dX, dW
 

@@ -553,6 +608,13 @@ class LigerRMSNormFunction(torch.autograd.Function):
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
+        if isinstance(X, torch.distributed.tensor.DTensor):
+            # Input tensor is output of a tensor parallel module and
+            # needs to be gathered to a local tensor to compute
+            # RMSE layer norm on each TP worker.
+            # TODO: support CP.
+            X = X.full_tensor()
+
         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
         ctx.offset = offset
         ctx.casting_mode = casting_mode

@@ -560,7 +622,11 @@ class LigerRMSNormFunction(torch.autograd.Function):
         ctx.row_mode = row_mode
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
-        ctx.
+        ctx.elementwise_affine = W is not None
+        if W is not None:
+            ctx.save_for_backward(X, W, RSTD)
+        else:
+            ctx.save_for_backward(X, RSTD)
         return Y
 
     @staticmethod

@@ -569,7 +635,18 @@ class LigerRMSNormFunction(torch.autograd.Function):
         """
         Y: (B, T, H) or (BxT, H)
         """
-
+        if ctx.elementwise_affine:
+            X, W, RSTD = ctx.saved_tensors
+        else:
+            X, RSTD = ctx.saved_tensors
+            W = None
+
+        if isinstance(dY, torch.distributed.tensor.DTensor):
+            # Gradients are output of a tensor parallel module and
+            # needs to be gathered to a local tensor for computing RMSE layer.
+            # TODO: support CP.
+            dY = dY.full_tensor()
+
         dX, dW = rms_norm_backward(
             dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
         )
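For readers tracking the new `elementwise_affine` flag through the rms_norm diff above: when `W` is None the kernels now skip both the weight load and the `(offset + W)` scaling, and the autograd function saves only `X` and `RSTD`. Below is a minimal eager-mode sketch of that semantics; the helper name and shapes are illustrative only and not part of the package API.

from typing import Optional

import torch


def rms_norm_reference(x: torch.Tensor, w: Optional[torch.Tensor], eps: float = 1e-6, offset: float = 0.0):
    # rstd is computed in fp32 for stability, mirroring the kernel's Llama/Gemma casting modes
    rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    y = (x.float() * rstd).to(x.dtype)
    if w is not None:  # elementwise_affine path: scale by (offset + W)
        y = y * (offset + w)
    return y  # W is None: the weight load and the scaling are skipped entirely


x = torch.randn(4, 128, dtype=torch.bfloat16)
w = torch.ones(128, dtype=torch.bfloat16)
# With a unit weight and zero offset the two paths coincide, which is the sanity check here.
torch.testing.assert_close(rms_norm_reference(x, None), rms_norm_reference(x, w))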
liger_kernel/ops/rope.py
CHANGED

@@ -32,7 +32,7 @@ def _triton_rope(
 
     # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
    # stride: (seq_len * head_dim, head_dim, 1)
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
 
     # locate start address
     q_ptr = q_ptr + pid * q_row_stride
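The `.to(tl.int64)` cast on the program id in the rope kernel (and likewise in rms_norm above) keeps the pointer offset `pid * q_row_stride` from wrapping around in 32-bit arithmetic. A rough back-of-envelope check, with purely illustrative shapes:

# Illustrative only: past how many (batch, position) rows does a 32-bit element offset overflow?
n_heads, head_dim = 32, 128
q_row_stride = n_heads * head_dim      # elements per row of q
int32_max = 2**31 - 1
print(int32_max // q_row_stride)       # 524287: beyond ~0.5M rows the offset needs int64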
liger_kernel/ops/swiglu.py
CHANGED

@@ -26,7 +26,7 @@ def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BL
     # sigmoid requires type float32
     a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
     b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
-    c_row = silu(a_row) * b_row
+    c_row = silu(a_row).cast(b_row.dtype) * b_row
     tl.store(c_ptr + col_offsets, c_row, mask=mask)
 
 
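The one-line swiglu change keeps `silu(a)` in fp32 only for the sigmoid itself and rounds it back to `b`'s dtype before the multiply, so the product is formed in the activation's own precision. A hedged eager-mode sketch of that ordering (shapes illustrative):

import torch

a = torch.randn(8, dtype=torch.bfloat16)
b = torch.randn(8, dtype=torch.bfloat16)

silu_fp32 = torch.nn.functional.silu(a.float())  # sigmoid wants fp32
c = silu_fp32.to(b.dtype) * b                    # round back before the product, as the new kernel does
assert c.dtype == torch.bfloat16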
liger_kernel/ops/tiled_mlp.py
ADDED

@@ -0,0 +1,136 @@
+import math
+
+from typing import Callable
+from typing import List
+from typing import Optional
+
+import torch
+
+from liger_kernel.ops.utils import ensure_contiguous
+
+
+class LigerTiledMLPFunction(torch.autograd.Function):
+    """
+    Based on DeepSpeed's TiledMLP:
+    https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838
+
+    Perform a tiled MLP computation to massively reduce memory usage needed to compute MLP
+    when using very long sequence lengths.
+
+    This module re-computes `forward` in the `backward`. So the `forward` occurs twice each iteration.
+    And if you're using activation checkpointing it then occurs thrice.
+
+    Args:
+        fn: the function to call on sharded inputs (e.g., mlp.forward)
+        mlp_module: the MLP nn.Module object
+        x: the input to MLP.forward (hidden_states)
+        shards: how many shards to use
+        compute_params: a list of weights engaged in the compute
+
+    Returns:
+        the computed hidden_states
+    """
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(
+        ctx,
+        fn: Callable,
+        mlp_module: torch.nn.Module,
+        x: torch.Tensor,
+        shards: int,
+        compute_params: Optional[List[torch.nn.Parameter]] = None,
+    ) -> torch.Tensor:
+        ctx.fn = fn
+        ctx.mlp_module = mlp_module
+        ctx.shards = shards
+        ctx.save_for_backward(x)
+
+        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
+        x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
+        with torch.no_grad():
+            output_shards = [fn(mlp_module, x_shard) for x_shard in x_shards]
+        output_unsharded = torch.cat(output_shards, dim=-2)
+
+        return output_unsharded
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, *grads) -> tuple:
+        fn = ctx.fn
+        (x,) = ctx.saved_tensors
+        mlp_module = ctx.mlp_module
+        shards = ctx.shards
+
+        x_requires_grad = x.requires_grad
+        x = x.detach()
+        # detach() unsets x.requires_grad, so restore it
+        x.requires_grad_(x_requires_grad)
+
+        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
+        hidden_size = x.shape[-1]
+        x_shape_orig = x.shape
+
+        # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1
+        x = x.view(-1, hidden_size)
+        incoming_grad = grads[0].view(-1, hidden_size)
+        x_grad = torch.zeros_like(x)
+
+        x_shards = list(torch.chunk(x, chunks=shards, dim=0))
+
+        for i, x_shard in enumerate(x_shards):
+            x_shard.requires_grad_(x_requires_grad)
+
+            # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step
+            shard_step = x_shards[i].shape[0]
+            shard_offset = i * x_shards[0].shape[0]
+
+            x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
+            incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
+
+            with torch.enable_grad():
+                output = fn(mlp_module, x_shard)
+            torch.autograd.backward(output, incoming_grad_shard)
+
+        # unflatten
+        x_grad = x_grad.view(x_shape_orig)
+
+        return (None, None, x_grad, None, None)
+
+
+def apply_tiled_mlp(
+    fn: Callable,
+    mlp_module: torch.nn.Module,
+    x: torch.Tensor,
+    num_shards: Optional[int] = None,
+    compute_params: Optional[List[torch.nn.Parameter]] = None,
+) -> torch.Tensor:
+    """
+    Apply tiled MLP computation for memory efficiency.
+
+    Args:
+        fn: the function to call on sharded inputs (e.g., lambda module, x: module(x))
+        mlp_module: the MLP nn.Module object
+        x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
+        num_shards: number of shards to use. If None, automatically calculated as ceil(seqlen / hidden_size)
+        compute_params: list of parameters for DeepSpeed ZeRO optimization
+
+    Returns:
+        output tensor with the same shape as input
+    """
+    if num_shards is None:
+        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size]
+        hidden_size = x.shape[-1]
+        seqlen = x.shape[-2]
+        num_shards = math.ceil(seqlen / hidden_size)
+
+    # Ensure num_shards is at least 1
+    num_shards = max(1, num_shards)
+
+    return LigerTiledMLPFunction.apply(
+        fn,
+        mlp_module,
+        x,
+        num_shards,
+        compute_params,
+    )
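A hedged usage sketch for the new `apply_tiled_mlp` helper: the toy module and the shard-forward lambda below are placeholders, not from the package; the call shape follows the function as added above, with sharding along the sequence dimension and the forward recomputed shard by shard during backward.

import torch
import torch.nn as nn

from liger_kernel.ops.tiled_mlp import apply_tiled_mlp


class ToyMLP(nn.Module):  # placeholder module for illustration
    def __init__(self, hidden_size: int = 256, intermediate_size: int = 1024):
        super().__init__()
        self.up = nn.Linear(hidden_size, intermediate_size)
        self.down = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(nn.functional.silu(self.up(x)))


mlp = ToyMLP()
x = torch.randn(2, 8192, 256, requires_grad=True)  # [bs, seqlen, hidden_size]

# fn receives (mlp_module, x_shard); with num_shards=None it defaults to ceil(seqlen / hidden_size)
out = apply_tiled_mlp(lambda module, shard: module(shard), mlp, x, num_shards=4)
out.sum().backward()  # x.grad and the MLP weight grads are filled in by the shard-wise backward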
liger_kernel/ops/utils.py
CHANGED

@@ -78,6 +78,8 @@ def get_amp_custom_fwd_bwd() -> Callable:
             functools.partial(torch.amp.custom_fwd, device_type=device),
             functools.partial(torch.amp.custom_bwd, device_type=device),
         )
+    if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
+        return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
     return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
 
 

@@ -125,3 +127,15 @@ def element_mul_kernel(
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
         tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
+
+
+def get_npu_core_count(default: int = 20) -> int:
+    """Return NPU vector core count.
+    Fallback to `default` if Triton runtime or NPU device is unavailable.
+    """
+    try:
+        utils = triton.runtime.driver.active.utils
+        props = utils.get_device_properties(0)
+        return int(props.get("num_vectorcore", default))
+    except Exception:
+        return default
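The helpers in `get_amp_custom_fwd_bwd` are returned as a (forward, backward) decorator pair, and with the new branch an NPU build resolves to `torch.npu.amp.custom_fwd` / `custom_bwd` ahead of the CUDA fallback. A hedged sketch of the usual call pattern; the autograd Function below is a stand-in for illustration, not from the package.

import torch

from liger_kernel.ops.utils import get_amp_custom_fwd_bwd

# Unpack the device-appropriate autocast decorators once at import time.
amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()


class ScaleByTwo(torch.autograd.Function):  # illustrative only
    @staticmethod
    @amp_custom_fwd
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return x * 2

    @staticmethod
    @amp_custom_bwd
    def backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
        return grad_out * 2


y = ScaleByTwo.apply(torch.randn(4, requires_grad=True))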