liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
- liger_kernel/chunked_loss/dpo_loss.py +61 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/grpo_loss.py +76 -5
- liger_kernel/chunked_loss/jsd_loss.py +46 -15
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +134 -65
- liger_kernel/ops/dyt.py +115 -180
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +6 -4
- liger_kernel/ops/group_norm.py +7 -7
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +2 -1
- liger_kernel/ops/kl_div.py +9 -5
- liger_kernel/ops/layer_norm.py +146 -78
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +398 -99
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +208 -17
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +6 -4
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +122 -20
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +57 -27
- liger_kernel/transformers/model/gemma2.py +65 -28
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +109 -27
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +111 -136
- liger_kernel/transformers/model/loss_utils.py +50 -12
- liger_kernel/transformers/model/mistral.py +51 -34
- liger_kernel/transformers/model/mixtral.py +50 -29
- liger_kernel/transformers/model/mllama.py +46 -24
- liger_kernel/transformers/model/olmo2.py +47 -22
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +50 -14
- liger_kernel/transformers/model/phi3.py +47 -172
- liger_kernel/transformers/model/qwen2.py +55 -23
- liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
- liger_kernel/transformers/model/qwen2_vl.py +59 -108
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2018 -244
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +54 -6
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +39 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +63 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
- liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
- liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py
CHANGED
|
@@ -21,8 +21,10 @@ from liger_kernel.ops.utils import calculate_settings
|
|
|
21
21
|
from liger_kernel.ops.utils import compare_version
|
|
22
22
|
from liger_kernel.ops.utils import ensure_contiguous
|
|
23
23
|
from liger_kernel.ops.utils import torch_to_triton_dtype
|
|
24
|
+
from liger_kernel.utils import get_npu_multi_processor_count
|
|
25
|
+
from liger_kernel.utils import is_npu_available
|
|
24
26
|
|
|
25
|
-
if compare_version("triton", operator.ge, "3.0.0"):
|
|
27
|
+
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
|
26
28
|
try:
|
|
27
29
|
# typical import path with dispatch available
|
|
28
30
|
from triton.language.extra.libdevice import rsqrt
|
|
@@ -52,6 +54,7 @@ def _rms_norm_forward_kernel(
|
|
|
52
54
|
eps,
|
|
53
55
|
offset,
|
|
54
56
|
casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
|
|
57
|
+
elementwise_affine: tl.constexpr,
|
|
55
58
|
BLOCK_SIZE: tl.constexpr,
|
|
56
59
|
):
|
|
57
60
|
"""
|
|
@@ -63,17 +66,18 @@ def _rms_norm_forward_kernel(
|
|
|
63
66
|
3. https://arxiv.org/pdf/1910.07467
|
|
64
67
|
"""
|
|
65
68
|
|
|
66
|
-
row_idx = tl.program_id(0)
|
|
69
|
+
row_idx = tl.program_id(0).to(tl.int64)
|
|
67
70
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
68
71
|
mask = col_offsets < n_cols
|
|
69
72
|
|
|
70
|
-
Y_ptr
|
|
71
|
-
X_ptr
|
|
72
|
-
RSTD_ptr
|
|
73
|
+
y_base = Y_ptr + row_idx * Y_row_stride
|
|
74
|
+
x_base = X_ptr + row_idx * X_row_stride
|
|
75
|
+
rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
|
|
73
76
|
|
|
74
|
-
X_row = tl.load(
|
|
77
|
+
X_row = tl.load(x_base + col_offsets, mask=mask, other=0)
|
|
75
78
|
X_row_dtype = X_row.dtype
|
|
76
|
-
|
|
79
|
+
if elementwise_affine:
|
|
80
|
+
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
|
|
77
81
|
|
|
78
82
|
# On Llama, only rstd is computed on fp32
|
|
79
83
|
if casting_mode == _CASTING_MODE_LLAMA:
|
|
@@ -81,7 +85,8 @@ def _rms_norm_forward_kernel(
|
|
|
81
85
|
|
|
82
86
|
# Gemma computes everything on fp32, and then casts back the output to the original dtype
|
|
83
87
|
if casting_mode == _CASTING_MODE_GEMMA:
|
|
84
|
-
|
|
88
|
+
if elementwise_affine:
|
|
89
|
+
W_row = W_row.to(tl.float32)
|
|
85
90
|
X_row = X_row.to(tl.float32)
|
|
86
91
|
|
|
87
92
|
if casting_mode == _CASTING_MODE_NONE:
|
|
@@ -94,7 +99,7 @@ def _rms_norm_forward_kernel(
|
|
|
94
99
|
# We can save time by caching rms with minimal memory overhead
|
|
95
100
|
# because rms is much smaller compared to X_row, as rms is for each row.
|
|
96
101
|
# However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
|
|
97
|
-
tl.store(
|
|
102
|
+
tl.store(rstd_base, rstd)
|
|
98
103
|
|
|
99
104
|
X_row = X_row * rstd
|
|
100
105
|
|
|
@@ -102,12 +107,15 @@ def _rms_norm_forward_kernel(
|
|
|
102
107
|
if casting_mode == _CASTING_MODE_LLAMA:
|
|
103
108
|
X_row = X_row.to(X_row_dtype)
|
|
104
109
|
|
|
105
|
-
|
|
110
|
+
if elementwise_affine:
|
|
111
|
+
Y_row = X_row * (offset + W_row)
|
|
112
|
+
else:
|
|
113
|
+
Y_row = X_row
|
|
106
114
|
|
|
107
115
|
if casting_mode == _CASTING_MODE_GEMMA:
|
|
108
116
|
Y_row = Y_row.to(X_row_dtype)
|
|
109
117
|
|
|
110
|
-
tl.store(
|
|
118
|
+
tl.store(y_base + col_offsets, Y_row, mask=mask)
|
|
111
119
|
|
|
112
120
|
|
|
113
121
|
@triton.jit
|
|
@@ -128,8 +136,9 @@ def _rms_norm_backward_kernel(
|
|
|
128
136
|
n_rows,
|
|
129
137
|
n_cols,
|
|
130
138
|
offset,
|
|
131
|
-
rows_per_program
|
|
139
|
+
rows_per_program,
|
|
132
140
|
casting_mode: tl.constexpr,
|
|
141
|
+
elementwise_affine: tl.constexpr,
|
|
133
142
|
BLOCK_SIZE: tl.constexpr,
|
|
134
143
|
):
|
|
135
144
|
"""
|
|
@@ -137,61 +146,256 @@ def _rms_norm_backward_kernel(
|
|
|
137
146
|
dw = sum(dy * (x / RMS)). summation over BxT dimension
|
|
138
147
|
"""
|
|
139
148
|
|
|
140
|
-
row_block_id = tl.program_id(0)
|
|
149
|
+
row_block_id = tl.program_id(0).to(tl.int64)
|
|
141
150
|
row_start = row_block_id * rows_per_program
|
|
142
151
|
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
|
143
152
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
144
153
|
mask = col_offsets < n_cols
|
|
145
154
|
|
|
146
|
-
|
|
155
|
+
if elementwise_affine:
|
|
156
|
+
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
|
147
157
|
|
|
148
|
-
|
|
149
|
-
|
|
158
|
+
if elementwise_affine:
|
|
159
|
+
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
|
|
160
|
+
W_row = W_row + offset
|
|
150
161
|
|
|
151
|
-
|
|
152
|
-
|
|
162
|
+
for row_idx in range(row_start, row_end):
|
|
163
|
+
dy_base = dY_ptr + row_idx * dY_row_stride
|
|
164
|
+
dx_base = dX_ptr + row_idx * dX_row_stride
|
|
153
165
|
|
|
154
|
-
|
|
155
|
-
|
|
166
|
+
x_base = X_ptr + row_idx * X_row_stride
|
|
167
|
+
rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
|
|
156
168
|
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
|
|
169
|
+
dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
|
|
170
|
+
X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
|
|
160
171
|
|
|
161
172
|
# Get cached rms
|
|
162
|
-
rstd_row = tl.load(
|
|
173
|
+
rstd_row = tl.load(rstd_base)
|
|
163
174
|
|
|
164
175
|
X_row = X_row.to(tl.float32)
|
|
165
176
|
|
|
166
177
|
# Different bacward graphs for different casting modes
|
|
167
178
|
if casting_mode == _CASTING_MODE_LLAMA:
|
|
168
|
-
|
|
179
|
+
if elementwise_affine:
|
|
180
|
+
m = (dY_row * W_row).to(tl.float32)
|
|
181
|
+
else:
|
|
182
|
+
m = dY_row.to(tl.float32)
|
|
169
183
|
|
|
170
184
|
elif casting_mode == _CASTING_MODE_GEMMA:
|
|
171
185
|
dY_row = dY_row.to(tl.float32)
|
|
172
|
-
|
|
186
|
+
if elementwise_affine:
|
|
187
|
+
m = dY_row * W_row
|
|
188
|
+
else:
|
|
189
|
+
m = dY_row
|
|
173
190
|
else:
|
|
174
|
-
|
|
191
|
+
if elementwise_affine:
|
|
192
|
+
m = dY_row * W_row
|
|
193
|
+
else:
|
|
194
|
+
m = dY_row
|
|
175
195
|
|
|
176
196
|
dX_row = rstd_row * m
|
|
177
197
|
|
|
178
198
|
dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
|
|
179
199
|
|
|
180
|
-
|
|
200
|
+
if elementwise_affine:
|
|
201
|
+
# calculate the gradient of W
|
|
202
|
+
if casting_mode == _CASTING_MODE_LLAMA:
|
|
203
|
+
dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
|
|
204
|
+
else:
|
|
205
|
+
# here X_row is already in fp32 (see previous if block)
|
|
206
|
+
dW_row += dY_row * (X_row * rstd_row)
|
|
207
|
+
|
|
208
|
+
tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
|
|
209
|
+
|
|
210
|
+
if elementwise_affine:
|
|
211
|
+
tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
@triton.jit
|
|
215
|
+
def _block_rms_norm_forward_kernel(
|
|
216
|
+
Y_ptr,
|
|
217
|
+
Y_row_stride,
|
|
218
|
+
X_ptr,
|
|
219
|
+
X_row_stride,
|
|
220
|
+
W_ptr,
|
|
221
|
+
W_row_stride,
|
|
222
|
+
RSTD_ptr,
|
|
223
|
+
RSTD_row_stride,
|
|
224
|
+
n_rows,
|
|
225
|
+
n_cols,
|
|
226
|
+
eps,
|
|
227
|
+
offset,
|
|
228
|
+
casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
|
|
229
|
+
elementwise_affine: tl.constexpr,
|
|
230
|
+
BLOCK_SIZE: tl.constexpr,
|
|
231
|
+
BLOCK_ROW: tl.constexpr,
|
|
232
|
+
):
|
|
233
|
+
"""
|
|
234
|
+
y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)
|
|
235
|
+
|
|
236
|
+
Reference:
|
|
237
|
+
1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
|
238
|
+
2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
|
|
239
|
+
3. https://arxiv.org/pdf/1910.07467
|
|
240
|
+
"""
|
|
241
|
+
|
|
242
|
+
row_idx = tl.program_id(0) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
|
|
243
|
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
244
|
+
row_mask = row_idx < n_rows
|
|
245
|
+
col_mask = col_offsets < n_cols
|
|
246
|
+
|
|
247
|
+
X_row = tl.load(
|
|
248
|
+
X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
|
|
249
|
+
mask=row_mask[:, None] & col_mask[None, :],
|
|
250
|
+
other=0,
|
|
251
|
+
)
|
|
252
|
+
X_row_dtype = X_row.dtype
|
|
253
|
+
if elementwise_affine:
|
|
254
|
+
W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
|
|
255
|
+
|
|
256
|
+
# On Llama, only rstd is computed on fp32
|
|
257
|
+
if casting_mode == _CASTING_MODE_LLAMA:
|
|
258
|
+
X_row = X_row.to(tl.float32)
|
|
259
|
+
|
|
260
|
+
# Gemma computes everything on fp32, and then casts back the output to the original dtype
|
|
261
|
+
if casting_mode == _CASTING_MODE_GEMMA:
|
|
262
|
+
if elementwise_affine:
|
|
263
|
+
W_row = W_row.to(tl.float32)
|
|
264
|
+
X_row = X_row.to(tl.float32)
|
|
265
|
+
|
|
266
|
+
if casting_mode == _CASTING_MODE_NONE:
|
|
267
|
+
eps = eps.to(X_row_dtype)
|
|
268
|
+
offset = offset.to(X_row_dtype)
|
|
269
|
+
|
|
270
|
+
mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
|
|
271
|
+
rstd = rsqrt(mean_square + eps)
|
|
272
|
+
|
|
273
|
+
# We can save time by caching rms with minimal memory overhead
|
|
274
|
+
# because rms is much smaller compared to X_row, as rms is for each row.
|
|
275
|
+
# However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
|
|
276
|
+
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)
|
|
277
|
+
|
|
278
|
+
X_row = X_row * rstd[:, None]
|
|
279
|
+
|
|
280
|
+
# On Llama, the multiplication with the weight is done on the original dtype
|
|
281
|
+
if casting_mode == _CASTING_MODE_LLAMA:
|
|
282
|
+
X_row = X_row.to(X_row_dtype)
|
|
283
|
+
|
|
284
|
+
if elementwise_affine:
|
|
285
|
+
Y_row = X_row * (offset + W_row)[None, :]
|
|
286
|
+
else:
|
|
287
|
+
Y_row = X_row
|
|
288
|
+
|
|
289
|
+
if casting_mode == _CASTING_MODE_GEMMA:
|
|
290
|
+
Y_row = Y_row.to(X_row_dtype)
|
|
291
|
+
|
|
292
|
+
tl.store(
|
|
293
|
+
Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
|
|
294
|
+
Y_row,
|
|
295
|
+
mask=row_mask[:, None] & col_mask[None, :],
|
|
296
|
+
)
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
@triton.jit
|
|
300
|
+
def _block_rms_norm_backward_kernel(
|
|
301
|
+
dY_ptr,
|
|
302
|
+
dY_row_stride,
|
|
303
|
+
dX_ptr,
|
|
304
|
+
dX_row_stride,
|
|
305
|
+
X_ptr,
|
|
306
|
+
X_row_stride,
|
|
307
|
+
X_dtype: tl.constexpr,
|
|
308
|
+
W_ptr,
|
|
309
|
+
W_row_stride,
|
|
310
|
+
RSTD_ptr,
|
|
311
|
+
RSTD_row_stride,
|
|
312
|
+
dW_ptr,
|
|
313
|
+
dW_row_stride,
|
|
314
|
+
n_rows,
|
|
315
|
+
n_cols,
|
|
316
|
+
offset,
|
|
317
|
+
casting_mode: tl.constexpr,
|
|
318
|
+
elementwise_affine: tl.constexpr,
|
|
319
|
+
BLOCK_SIZE: tl.constexpr,
|
|
320
|
+
BLOCK_ROW: tl.constexpr,
|
|
321
|
+
):
|
|
322
|
+
"""
|
|
323
|
+
dx = (1 / RMS) * [dy * (w + offset - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whileas dot means dot product
|
|
324
|
+
dw = sum(dy * (x / RMS)). summation over BxT dimension
|
|
325
|
+
"""
|
|
326
|
+
|
|
327
|
+
pid = tl.program_id(0).cast(tl.int64)
|
|
328
|
+
NUM_SMS = tl.num_programs(0)
|
|
329
|
+
|
|
330
|
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
331
|
+
col_mask = col_offsets < n_cols
|
|
332
|
+
|
|
333
|
+
if elementwise_affine:
|
|
334
|
+
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
|
335
|
+
|
|
336
|
+
W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
|
|
337
|
+
W_row = W_row + offset
|
|
338
|
+
|
|
339
|
+
for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
|
|
340
|
+
row_idx = start + tl.arange(0, BLOCK_ROW)
|
|
341
|
+
row_mask = row_idx < n_rows
|
|
342
|
+
dY_row = tl.load(
|
|
343
|
+
dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
|
|
344
|
+
mask=row_mask[:, None] & col_mask[None, :],
|
|
345
|
+
other=0.0,
|
|
346
|
+
)
|
|
347
|
+
X_row = tl.load(
|
|
348
|
+
X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
|
|
349
|
+
mask=row_mask[:, None] & col_mask[None, :],
|
|
350
|
+
other=0.0,
|
|
351
|
+
)
|
|
352
|
+
|
|
353
|
+
# Get cached rms
|
|
354
|
+
rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, row_mask)
|
|
355
|
+
|
|
356
|
+
X_row = X_row.to(tl.float32)
|
|
357
|
+
|
|
358
|
+
# Different bacward graphs for different casting modes
|
|
181
359
|
if casting_mode == _CASTING_MODE_LLAMA:
|
|
182
|
-
|
|
360
|
+
if elementwise_affine:
|
|
361
|
+
m = (dY_row * W_row[None, :]).to(tl.float32)
|
|
362
|
+
else:
|
|
363
|
+
m = dY_row.to(tl.float32)
|
|
364
|
+
|
|
365
|
+
elif casting_mode == _CASTING_MODE_GEMMA:
|
|
366
|
+
dY_row = dY_row.to(tl.float32)
|
|
367
|
+
if elementwise_affine:
|
|
368
|
+
m = dY_row * W_row[None, :]
|
|
369
|
+
else:
|
|
370
|
+
m = dY_row
|
|
183
371
|
else:
|
|
184
|
-
|
|
185
|
-
|
|
372
|
+
if elementwise_affine:
|
|
373
|
+
m = dY_row * W_row[None, :]
|
|
374
|
+
else:
|
|
375
|
+
m = dY_row
|
|
186
376
|
|
|
187
|
-
|
|
377
|
+
dX_row = rstd_row[:, None] * m
|
|
188
378
|
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
379
|
+
dX_row += (rstd_row[:, None]) * (
|
|
380
|
+
-(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
|
|
381
|
+
)
|
|
382
|
+
|
|
383
|
+
if elementwise_affine:
|
|
384
|
+
if casting_mode == _CASTING_MODE_LLAMA:
|
|
385
|
+
# TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
|
|
386
|
+
dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
|
|
387
|
+
else:
|
|
388
|
+
# here X_row is already in fp32 (see previous if block)
|
|
389
|
+
dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
|
|
390
|
+
|
|
391
|
+
tl.store(
|
|
392
|
+
dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
|
|
393
|
+
dX_row,
|
|
394
|
+
mask=row_mask[:, None] & col_mask[None, :],
|
|
395
|
+
)
|
|
193
396
|
|
|
194
|
-
|
|
397
|
+
if elementwise_affine:
|
|
398
|
+
tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
|
|
195
399
|
|
|
196
400
|
|
|
197
401
|
_str_to_casting_mode = {
|
|
@@ -201,7 +405,7 @@ _str_to_casting_mode = {
|
|
|
201
405
|
}
|
|
202
406
|
|
|
203
407
|
|
|
204
|
-
def rms_norm_forward(X, W, eps, offset, casting_mode):
|
|
408
|
+
def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
|
|
205
409
|
if not isinstance(casting_mode, int):
|
|
206
410
|
assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
|
|
207
411
|
casting_mode = _str_to_casting_mode[casting_mode]
|
|
@@ -220,29 +424,64 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
|
|
|
220
424
|
rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
|
|
221
425
|
RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
|
|
222
426
|
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
427
|
+
if W is not None:
|
|
428
|
+
# Check constraints.
|
|
429
|
+
assert X.shape[1] == W.shape[0], (
|
|
430
|
+
"Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
|
|
431
|
+
)
|
|
432
|
+
elementwise_affine = True
|
|
433
|
+
else:
|
|
434
|
+
elementwise_affine = False
|
|
435
|
+
|
|
436
|
+
# XPU-specific optimization
|
|
437
|
+
kernel_args = {}
|
|
438
|
+
if X.device.type == "xpu":
|
|
439
|
+
kernel_args["grf_mode"] = "large"
|
|
440
|
+
if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
|
|
441
|
+
_rms_norm_forward_kernel[(n_rows,)](
|
|
442
|
+
Y,
|
|
443
|
+
Y.stride(0),
|
|
444
|
+
X,
|
|
445
|
+
X.stride(0),
|
|
446
|
+
W,
|
|
447
|
+
W.stride(0) if elementwise_affine else 0,
|
|
448
|
+
RSTD,
|
|
449
|
+
RSTD.stride(0),
|
|
450
|
+
n_cols,
|
|
451
|
+
eps,
|
|
452
|
+
offset,
|
|
453
|
+
casting_mode,
|
|
454
|
+
elementwise_affine=elementwise_affine,
|
|
455
|
+
BLOCK_SIZE=BLOCK_SIZE,
|
|
456
|
+
num_warps=num_warps,
|
|
457
|
+
**kernel_args, # XPU-specific optimization
|
|
458
|
+
)
|
|
459
|
+
else:
|
|
460
|
+
BLOCK_ROW = 16
|
|
461
|
+
kernel_args["BLOCK_ROW"] = BLOCK_ROW
|
|
462
|
+
_block_rms_norm_forward_kernel[(triton.cdiv(n_rows, BLOCK_ROW),)](
|
|
463
|
+
Y,
|
|
464
|
+
Y.stride(0),
|
|
465
|
+
X,
|
|
466
|
+
X.stride(0),
|
|
467
|
+
W,
|
|
468
|
+
W.stride(0) if elementwise_affine else 0,
|
|
469
|
+
RSTD,
|
|
470
|
+
RSTD.stride(0),
|
|
471
|
+
n_rows,
|
|
472
|
+
n_cols,
|
|
473
|
+
eps,
|
|
474
|
+
offset,
|
|
475
|
+
casting_mode,
|
|
476
|
+
elementwise_affine=elementwise_affine,
|
|
477
|
+
BLOCK_SIZE=BLOCK_SIZE,
|
|
478
|
+
num_warps=num_warps,
|
|
479
|
+
**kernel_args, # XPU-specific optimization
|
|
480
|
+
)
|
|
242
481
|
return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
|
|
243
482
|
|
|
244
483
|
|
|
245
|
-
def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place):
|
|
484
|
+
def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
|
|
246
485
|
shape = dY.shape
|
|
247
486
|
dim = shape[-1]
|
|
248
487
|
dY = dY.view(-1, dim)
|
|
@@ -252,10 +491,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
|
|
|
252
491
|
if X.device.type == "cuda":
|
|
253
492
|
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
|
254
493
|
elif X.device.type == "xpu":
|
|
255
|
-
sm_count = torch.xpu.get_device_properties(X.device).
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
494
|
+
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
|
|
495
|
+
elif X.device.type == "npu":
|
|
496
|
+
sm_count = get_npu_multi_processor_count()
|
|
497
|
+
|
|
498
|
+
if W is not None:
|
|
499
|
+
# fp32 for numerical stability especially.
|
|
500
|
+
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
|
501
|
+
elementwise_affine = True
|
|
502
|
+
else:
|
|
503
|
+
_dW = None
|
|
504
|
+
elementwise_affine = False
|
|
259
505
|
|
|
260
506
|
if n_cols > BLOCK_SIZE:
|
|
261
507
|
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
|
@@ -267,30 +513,68 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
|
|
|
267
513
|
else:
|
|
268
514
|
dX = torch.zeros_like(dY)
|
|
269
515
|
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
516
|
+
# XPU-specific optimization
|
|
517
|
+
kernel_args = {}
|
|
518
|
+
if X.device.type == "xpu":
|
|
519
|
+
kernel_args["grf_mode"] = "large"
|
|
520
|
+
|
|
521
|
+
if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
|
|
522
|
+
_rms_norm_backward_kernel[grid](
|
|
523
|
+
dY,
|
|
524
|
+
dY.stride(0),
|
|
525
|
+
dX,
|
|
526
|
+
dX.stride(0),
|
|
527
|
+
X,
|
|
528
|
+
X.stride(0),
|
|
529
|
+
torch_to_triton_dtype[X.dtype],
|
|
530
|
+
W,
|
|
531
|
+
W.stride(0) if elementwise_affine else 0,
|
|
532
|
+
RSTD,
|
|
533
|
+
RSTD.stride(0),
|
|
534
|
+
_dW,
|
|
535
|
+
_dW.stride(0) if elementwise_affine else 0,
|
|
536
|
+
n_rows,
|
|
537
|
+
n_cols,
|
|
538
|
+
offset,
|
|
539
|
+
rows_per_program,
|
|
540
|
+
casting_mode,
|
|
541
|
+
elementwise_affine=elementwise_affine,
|
|
542
|
+
BLOCK_SIZE=BLOCK_SIZE,
|
|
543
|
+
num_warps=num_warps,
|
|
544
|
+
**kernel_args, # XPU-specific optimization
|
|
545
|
+
)
|
|
546
|
+
else:
|
|
547
|
+
BLOCK_ROW = 16
|
|
548
|
+
kernel_args["BLOCK_ROW"] = BLOCK_ROW
|
|
549
|
+
_block_rms_norm_backward_kernel[grid](
|
|
550
|
+
dY,
|
|
551
|
+
dY.stride(0),
|
|
552
|
+
dX,
|
|
553
|
+
dX.stride(0),
|
|
554
|
+
X,
|
|
555
|
+
X.stride(0),
|
|
556
|
+
torch_to_triton_dtype[X.dtype],
|
|
557
|
+
W,
|
|
558
|
+
W.stride(0) if elementwise_affine else 0,
|
|
559
|
+
RSTD,
|
|
560
|
+
RSTD.stride(0),
|
|
561
|
+
_dW,
|
|
562
|
+
_dW.stride(0) if elementwise_affine else 0,
|
|
563
|
+
n_rows,
|
|
564
|
+
n_cols,
|
|
565
|
+
offset,
|
|
566
|
+
casting_mode,
|
|
567
|
+
elementwise_affine=elementwise_affine,
|
|
568
|
+
BLOCK_SIZE=BLOCK_SIZE,
|
|
569
|
+
num_warps=num_warps,
|
|
570
|
+
**kernel_args, # XPU-specific optimization
|
|
571
|
+
)
|
|
292
572
|
dX = dX.view(*shape)
|
|
293
|
-
|
|
573
|
+
|
|
574
|
+
if elementwise_affine:
|
|
575
|
+
dW = _dW.sum(dim=0).to(W.dtype)
|
|
576
|
+
else:
|
|
577
|
+
dW = None
|
|
294
578
|
|
|
295
579
|
return dX, dW
|
|
296
580
|
|
|
@@ -319,18 +603,30 @@ class LigerRMSNormFunction(torch.autograd.Function):
|
|
|
319
603
|
|
|
320
604
|
@staticmethod
|
|
321
605
|
@ensure_contiguous
|
|
322
|
-
def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
|
|
606
|
+
def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
|
|
323
607
|
"""
|
|
324
608
|
X: (B, T, H) or (BxT, H)
|
|
325
609
|
W: (H,)
|
|
326
610
|
"""
|
|
327
|
-
|
|
611
|
+
if isinstance(X, torch.distributed.tensor.DTensor):
|
|
612
|
+
# Input tensor is output of a tensor parallel module and
|
|
613
|
+
# needs to be gathered to a local tensor to compute
|
|
614
|
+
# RMSE layer norm on each TP worker.
|
|
615
|
+
# TODO: support CP.
|
|
616
|
+
X = X.full_tensor()
|
|
617
|
+
|
|
618
|
+
Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
|
|
328
619
|
ctx.offset = offset
|
|
329
620
|
ctx.casting_mode = casting_mode
|
|
330
621
|
ctx.in_place = in_place
|
|
622
|
+
ctx.row_mode = row_mode
|
|
331
623
|
ctx.BLOCK_SIZE = BLOCK_SIZE
|
|
332
624
|
ctx.num_warps = num_warps
|
|
333
|
-
ctx.
|
|
625
|
+
ctx.elementwise_affine = W is not None
|
|
626
|
+
if W is not None:
|
|
627
|
+
ctx.save_for_backward(X, W, RSTD)
|
|
628
|
+
else:
|
|
629
|
+
ctx.save_for_backward(X, RSTD)
|
|
334
630
|
return Y
|
|
335
631
|
|
|
336
632
|
@staticmethod
|
|
@@ -339,16 +635,19 @@ class LigerRMSNormFunction(torch.autograd.Function):
|
|
|
339
635
|
"""
|
|
340
636
|
Y: (B, T, H) or (BxT, H)
|
|
341
637
|
"""
|
|
342
|
-
|
|
638
|
+
if ctx.elementwise_affine:
|
|
639
|
+
X, W, RSTD = ctx.saved_tensors
|
|
640
|
+
else:
|
|
641
|
+
X, RSTD = ctx.saved_tensors
|
|
642
|
+
W = None
|
|
643
|
+
|
|
644
|
+
if isinstance(dY, torch.distributed.tensor.DTensor):
|
|
645
|
+
# Gradients are output of a tensor parallel module and
|
|
646
|
+
# needs to be gathered to a local tensor for computing RMSE layer.
|
|
647
|
+
# TODO: support CP.
|
|
648
|
+
dY = dY.full_tensor()
|
|
649
|
+
|
|
343
650
|
dX, dW = rms_norm_backward(
|
|
344
|
-
dY,
|
|
345
|
-
X,
|
|
346
|
-
W,
|
|
347
|
-
RSTD,
|
|
348
|
-
ctx.offset,
|
|
349
|
-
ctx.casting_mode,
|
|
350
|
-
ctx.BLOCK_SIZE,
|
|
351
|
-
ctx.num_warps,
|
|
352
|
-
ctx.in_place,
|
|
651
|
+
dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
|
|
353
652
|
)
|
|
354
|
-
return dX, dW, None, None, None, None
|
|
653
|
+
return dX, dW, None, None, None, None, None
|
liger_kernel/ops/rope.py
CHANGED
|
@@ -32,7 +32,7 @@ def _triton_rope(
|
|
|
32
32
|
|
|
33
33
|
# cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
|
|
34
34
|
# stride: (seq_len * head_dim, head_dim, 1)
|
|
35
|
-
pid = tl.program_id(0)
|
|
35
|
+
pid = tl.program_id(0).to(tl.int64)
|
|
36
36
|
|
|
37
37
|
# locate start address
|
|
38
38
|
q_ptr = q_ptr + pid * q_row_stride
|