liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +304 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +21 -4
- liger_kernel/ops/cross_entropy.py +235 -84
- liger_kernel/ops/dyt.py +157 -0
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
- liger_kernel/ops/fused_linear_jsd.py +17 -34
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +7 -18
- liger_kernel/ops/group_norm.py +305 -0
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/jsd.py +46 -21
- liger_kernel/ops/kl_div.py +23 -19
- liger_kernel/ops/layer_norm.py +150 -86
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +314 -84
- liger_kernel/ops/rope.py +32 -34
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +5 -9
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +8 -4
- liger_kernel/transformers/__init__.py +199 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +33 -20
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +291 -13
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
- liger_kernel/transformers/fused_linear_jsd.py +1 -4
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/jsd.py +2 -7
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +77 -77
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +128 -79
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +68 -64
- liger_kernel/transformers/model/mixtral.py +75 -91
- liger_kernel/transformers/model/mllama.py +63 -68
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +432 -0
- liger_kernel/transformers/model/phi3.py +59 -213
- liger_kernel/transformers/model/qwen2.py +75 -72
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +78 -98
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2106 -289
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +57 -6
- liger_kernel/transformers/rope.py +45 -2
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +23 -8
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- liger_kernel/utils.py +71 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py
CHANGED
@@ -17,12 +17,10 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-    torch_to_triton_dtype,
-)
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import torch_to_triton_dtype
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -35,9 +33,9 @@ else:
     from triton.language.math import rsqrt
 
 
-_CASTING_MODE_NONE = tl.constexpr(-1)
-_CASTING_MODE_LLAMA = tl.constexpr(0)
-_CASTING_MODE_GEMMA = tl.constexpr(1)
+_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
+_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
+_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)
 
 
 @triton.jit
@@ -65,7 +63,7 @@ def _rms_norm_forward_kernel(
     3. https://arxiv.org/pdf/1910.07467
     """
 
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
@@ -116,6 +114,8 @@ def _rms_norm_forward_kernel(
 def _rms_norm_backward_kernel(
     dY_ptr,
     dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
     X_ptr,
     X_row_stride,
     X_dtype: tl.constexpr,
@@ -137,7 +137,7 @@ def _rms_norm_backward_kernel(
     dw = sum(dy * (x / RMS)). summation over BxT dimension
     """
 
-    row_block_id = tl.program_id(0)
+    row_block_id = tl.program_id(0).to(tl.int64)
     row_start = row_block_id * rows_per_program
     row_end = min((row_block_id + 1) * rows_per_program, n_rows)
     col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -146,6 +146,8 @@ def _rms_norm_backward_kernel(
     dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
     dY_ptr += row_start * dY_row_stride
+    dX_ptr += row_start * dX_row_stride
+
     X_ptr += row_start * X_row_stride
     RSTD_ptr += row_start
 
@@ -173,9 +175,7 @@ def _rms_norm_backward_kernel(
 
         dX_row = rstd_row * m
 
-        dX_row += (rstd_row) * (
-            -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
-        )
+        dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
 
         # calculate the gradient of W
         if casting_mode == _CASTING_MODE_LLAMA:
@@ -184,15 +184,185 @@ def _rms_norm_backward_kernel(
             # here X_row is already in fp32 (see previous if block)
             dW_row += dY_row * (X_row * rstd_row)
 
-        tl.store(dY_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
+        tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
 
         dY_ptr += dY_row_stride
+        dX_ptr += dX_row_stride
         X_ptr += X_row_stride
         RSTD_ptr += RSTD_row_stride
 
     tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
 
 
+@triton.jit
+def _block_rms_norm_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    n_rows,
+    n_cols,
+    eps,
+    offset,
+    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_ROW: tl.constexpr,
+):
+    """
+    y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)
+
+    Reference:
+    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+    3. https://arxiv.org/pdf/1910.07467
+    """
+
+    row_idx = tl.program_id(0) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    row_mask = row_idx < n_rows
+    col_mask = col_offsets < n_cols
+
+    X_row = tl.load(
+        X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+        mask=row_mask[:, None] & col_mask[None, :],
+        other=0,
+    )
+    X_row_dtype = X_row.dtype
+    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+
+    # On Llama, only rstd is computed on fp32
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(tl.float32)
+
+    # Gemma computes everything on fp32, and then casts back the output to the original dtype
+    if casting_mode == _CASTING_MODE_GEMMA:
+        W_row = W_row.to(tl.float32)
+        X_row = X_row.to(tl.float32)
+
+    if casting_mode == _CASTING_MODE_NONE:
+        eps = eps.to(X_row_dtype)
+        offset = offset.to(X_row_dtype)
+
+    mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
+    rstd = rsqrt(mean_square + eps)
+
+    # We can save time by caching rms with minimal memory overhead
+    # because rms is much smaller compared to X_row, as rms is for each row.
+    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
+    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)
+
+    X_row = X_row * rstd[:, None]
+
+    # On Llama, the multiplication with the weight is done on the original dtype
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(X_row_dtype)
+
+    Y_row = X_row * (offset + W_row)[None, :]
+
+    if casting_mode == _CASTING_MODE_GEMMA:
+        Y_row = Y_row.to(X_row_dtype)
+
+    tl.store(
+        Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
+        Y_row,
+        mask=row_mask[:, None] & col_mask[None, :],
+    )
+
+
+@triton.jit
+def _block_rms_norm_backward_kernel(
+    dY_ptr,
+    dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
+    X_ptr,
+    X_row_stride,
+    X_dtype: tl.constexpr,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    dW_ptr,
+    dW_row_stride,
+    n_rows,
+    n_cols,
+    offset,
+    rows_per_program: tl.constexpr,
+    casting_mode: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_ROW: tl.constexpr,
+):
+    """
+    dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
+    dw = sum(dy * (x / RMS)). summation over BxT dimension
+    """
+
+    pid = tl.program_id(0).cast(tl.int64)
+    NUM_SMS = tl.num_programs(0)
+
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    col_mask = col_offsets < n_cols
+
+    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+    W_row = W_row + offset
+
+    for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
+        row_idx = start + tl.arange(0, BLOCK_ROW)
+        row_mask = row_idx < n_rows
+        dY_row = tl.load(
+            dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
+            mask=row_mask[:, None] & col_mask[None, :],
+            other=0.0,
+        )
+        X_row = tl.load(
+            X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+            mask=row_mask[:, None] & col_mask[None, :],
+            other=0.0,
+        )
+
+        # Get cached rms
+        rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, row_mask)
+
+        X_row = X_row.to(tl.float32)
+
+        # Different backward graphs for different casting modes
+        if casting_mode == _CASTING_MODE_LLAMA:
+            m = (dY_row * W_row[None, :]).to(tl.float32)
+
+        elif casting_mode == _CASTING_MODE_GEMMA:
+            dY_row = dY_row.to(tl.float32)
+            m = dY_row * W_row[None, :]
+        else:
+            m = dY_row * W_row[None, :]
+
+        dX_row = rstd_row[:, None] * m
+
+        dX_row += (rstd_row[:, None]) * (
+            -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
+        )
+
+        # calculate the gradient of W
+        if casting_mode == _CASTING_MODE_LLAMA:
+            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]).to(X_dtype), 0)
+        else:
+            # here X_row is already in fp32 (see previous if block)
+            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+
+        tl.store(
+            dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
+            dX_row,
+            mask=row_mask[:, None] & col_mask[None, :],
+        )
+
+    tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+
+
 _str_to_casting_mode = {
     "llama": _CASTING_MODE_LLAMA.value,
     "gemma": _CASTING_MODE_GEMMA.value,
@@ -200,16 +370,12 @@ _str_to_casting_mode = {
 }
 
 
-def rms_norm_forward(X, W, eps, offset, casting_mode):
+def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
     if not isinstance(casting_mode, int):
-        assert (
-            casting_mode in _str_to_casting_mode
-        ), f"Invalid casting mode: {casting_mode}"
+        assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
         casting_mode = _str_to_casting_mode[casting_mode]
     else:
-        assert (
-            casting_mode in _str_to_casting_mode.values()
-        ), f"Invalid casting mode: {casting_mode}"
+        assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"
 
     shape = X.shape
     dim = shape[-1]
@@ -220,44 +386,70 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     # RSTD is to cache rstd for each row
     # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
-    rstd_dtype = (
-        torch.float32
-        if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
-        else X.dtype
-    )
+    rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
     RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
 
     # Check constraints.
-    assert (
-        X.shape[1] == W.shape[0]
-    ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
-
-    _rms_norm_forward_kernel[(n_rows,)](
-        Y,
-        Y.stride(0),
-        X,
-        X.stride(0),
-        W,
-        W.stride(0),
-        RSTD,
-        RSTD.stride(0),
-        n_cols,
-        eps,
-        offset,
-        casting_mode,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
-    )
+    assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+    if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+        _rms_norm_forward_kernel[(n_rows,)](
+            Y,
+            Y.stride(0),
+            X,
+            X.stride(0),
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            n_cols,
+            eps,
+            offset,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
+    else:
+        BLOCK_ROW = 16
+        kernel_args["BLOCK_ROW"] = BLOCK_ROW
+        _block_rms_norm_forward_kernel[(triton.cdiv(n_rows, BLOCK_ROW),)](
+            Y,
+            Y.stride(0),
+            X,
+            X.stride(0),
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            n_rows,
+            n_cols,
+            eps,
+            offset,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
 
 
-def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps):
+def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
     n_rows, n_cols = dY.shape
 
-    sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    sm_count = 1
+    if X.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    elif X.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+
     # fp32 for numerical stability especially.
     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
 
@@ -265,29 +457,70 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
     rows_per_program = math.ceil(n_rows / sm_count)
     grid = (sm_count,)
-    # Here we use dY to store the value of dX to save memory
-    _rms_norm_backward_kernel[grid](
-        dY,
-        dY.stride(0),
-        X,
-        X.stride(0),
-        torch_to_triton_dtype[X.dtype],
-        W,
-        W.stride(0),
-        RSTD,
-        RSTD.stride(0),
-        _dW,
-        _dW.stride(0),
-        n_rows,
-        n_cols,
-        offset,
-        rows_per_program,
-        casting_mode,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
-    )
-    dX = dY.view(*shape)
+
+    if in_place is True:
+        dX = dY
+    else:
+        dX = torch.zeros_like(dY)
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
+    if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+        _rms_norm_backward_kernel[grid](
+            dY,
+            dY.stride(0),
+            dX,
+            dX.stride(0),
+            X,
+            X.stride(0),
+            torch_to_triton_dtype[X.dtype],
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            _dW,
+            _dW.stride(0),
+            n_rows,
+            n_cols,
+            offset,
+            rows_per_program,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
+    else:
+        BLOCK_ROW = 16
+        kernel_args["BLOCK_ROW"] = BLOCK_ROW
+        _block_rms_norm_backward_kernel[grid](
+            dY,
+            dY.stride(0),
+            dX,
+            dX.stride(0),
+            X,
+            X.stride(0),
+            torch_to_triton_dtype[X.dtype],
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            _dW,
+            _dW.stride(0),
+            n_rows,
+            n_cols,
+            offset,
+            rows_per_program,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
+    dX = dX.view(*shape)
     dW = _dW.sum(dim=0).to(W.dtype)
+
     return dX, dW
 
 
@@ -307,20 +540,24 @@ class LigerRMSNormFunction(torch.autograd.Function):
     - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
     - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
     - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
+
+    The `in_place` option controls whether dY is modified in place to store dX. It defaults to `True` to save memory. However, in certain cases it can produce incorrect gradients.
+        For example, gemma2 uses two rmsnorms sequentially with a residual in between. The residual part needs dY, so it cannot be modified in-place.
+        Therefore, for the patching of RMSNorm in gemma2, we set `in_place` to `False`.
     """
 
     @staticmethod
     @ensure_contiguous
-    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"):
+    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
         """
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
-        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
-            X, W, eps, offset, casting_mode
-        )
+        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
         ctx.offset = offset
         ctx.casting_mode = casting_mode
+        ctx.in_place = in_place
+        ctx.row_mode = row_mode
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
         ctx.save_for_backward(X, W, RSTD)
@@ -334,13 +571,6 @@ class LigerRMSNormFunction(torch.autograd.Function):
         """
         X, W, RSTD = ctx.saved_tensors
         dX, dW = rms_norm_backward(
-            dY,
-            X,
-            W,
-            RSTD,
-            ctx.offset,
-            ctx.casting_mode,
-            ctx.BLOCK_SIZE,
-            ctx.num_warps,
+            dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
         )
-        return dX, dW, None, None, None
+        return dX, dW, None, None, None, None, None
liger_kernel/ops/rope.py
CHANGED
@@ -15,6 +15,7 @@ def _triton_rope(
     sin_row_stride,
     sl,
     bs: tl.constexpr,
+    cos_bs: tl.constexpr,
     n_qh: tl.constexpr,
     n_kh: tl.constexpr,
     hd: tl.constexpr,
@@ -29,9 +30,9 @@ def _triton_rope(
     # k size: (bsz, seq_len, num_kv_heads, head_dim)
     # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
 
-    # cos size: (1, seq_len, head_dim)
+    # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
     # stride: (seq_len * head_dim, head_dim, 1)
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
 
     # locate start address
     q_ptr = q_ptr + pid * q_row_stride
@@ -48,9 +49,19 @@ def _triton_rope(
     # and pid % sl to get the sequence index.
     # 2. We only need the left half of cos and sin matrix because the right half is just
     # a clone of the left half.
-    cos_row_idx = pid % (sl)
-    cos = cos + cos_row_idx * cos_row_stride
-    sin = sin + cos_row_idx * sin_row_stride
+    batch_idx = pid // sl
+    cos_row_idx = pid % sl
+    cos = cos + tl.where(
+        cos_bs == 1,
+        cos_row_idx * cos_row_stride,
+        batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
+    )
+    sin = sin + tl.where(
+        cos_bs == 1,
+        cos_row_idx * sin_row_stride,
+        batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
+    )
+
     cos_offsets = tl.arange(0, pad_hd // 2)
     cos_mask = cos_offsets < hd // 2
     cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
@@ -61,36 +72,20 @@ def _triton_rope(
     # program instance (i.e. for the current token) separately
     # ####################################################################
     # left half of the head
-    first_half_q_offsets = (
-        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_half_k_offsets = (
-        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
-    )
-    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
+    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
+    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)
 
     # right half of the head
     second_half_q_offsets = first_half_q_offsets + (hd // 2)
     second_half_k_offsets = first_half_k_offsets + (hd // 2)
     second_q_mask = first_q_mask
     second_k_mask = first_k_mask
-    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-        sin_row.dtype
-    )
+    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
+    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)
 
     if not BACKWARD_PASS:
         # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
@@ -118,7 +113,6 @@ def _triton_rope(
 
 
 def rope_forward(q, k, cos, sin):
-
     # transpose it back to the physical shape because Triton looks at the physical storage
     # note: q and k are incontiguous before the transformation and will become contiguous after transpose
     q = q.transpose(1, 2)
@@ -138,6 +132,7 @@ def rope_forward(q, k, cos, sin):
     k = k.contiguous()
     cos = cos.contiguous()
     sin = sin.contiguous()
+    cos_batch_size = cos.shape[0]
 
     _triton_rope[(n_row,)](
         q,
@@ -150,6 +145,7 @@ def rope_forward(q, k, cos, sin):
         sin.stride(-2),
         seq_len,
         batch_size,
+        cos_batch_size,
         n_q_head,
         n_kv_head,
         head_dim,
@@ -167,6 +163,7 @@ def rope_backward(dq, dk, cos, sin):
     dk = dk.transpose(1, 2)
 
     batch_size, seq_len, n_q_head, head_dim = dq.shape
+    cos_batch_size = cos.shape[0]
     n_kv_head = dk.shape[2]
     pad_hd = triton.next_power_of_2(head_dim)
     pad_n_q_head = triton.next_power_of_2(n_q_head)
@@ -191,6 +188,7 @@ def rope_backward(dq, dk, cos, sin):
         sin.stride(-2),
         seq_len,
         batch_size,
+        cos_batch_size,
         n_q_head,
         n_kv_head,
         head_dim,
@@ -221,8 +219,8 @@ class LigerRopeFunction(torch.autograd.Function):
         """
         q size: (bsz, n_q_head, seq_len, head_dim)
        k size: (bsz, n_kv_head, seq_len, head_dim)
-        cos size: (1, seq_len, head_dim)
-        sin size: (1, seq_len, head_dim)
+        cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+        sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
         """
         q, k, cos, sin = rope_forward(q, k, cos, sin)
         ctx.save_for_backward(cos, sin)
@@ -232,8 +230,8 @@ class LigerRopeFunction(torch.autograd.Function):
         """
         dq size: (bsz, n_q_head, seq_len, head_dim)
         dk size: (bsz, n_kv_head, seq_len, head_dim)
-        cos size: (1, seq_len, head_dim)
-        sin size: (1, seq_len, head_dim)
+        cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
+        sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
         """
 
         cos, sin = ctx.saved_tensors