liger-kernel 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
- liger_kernel/chunked_loss/grpo_loss.py +8 -5
- liger_kernel/chunked_loss/jsd_loss.py +39 -11
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
- liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
- liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +71 -11
- liger_kernel/ops/dyt.py +5 -2
- liger_kernel/ops/fused_add_rms_norm.py +21 -23
- liger_kernel/ops/fused_linear_cross_entropy.py +32 -5
- liger_kernel/ops/geglu.py +5 -3
- liger_kernel/ops/group_norm.py +12 -8
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +89 -69
- liger_kernel/ops/poly_norm.py +19 -21
- liger_kernel/ops/rms_norm.py +149 -71
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +25 -0
- liger_kernel/transformers/__init__.py +25 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +44 -26
- liger_kernel/transformers/fused_add_rms_norm.py +1 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +57 -2
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +1 -1
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +19 -5
- liger_kernel/transformers/model/gemma.py +17 -6
- liger_kernel/transformers/model/gemma2.py +17 -8
- liger_kernel/transformers/model/gemma3.py +35 -16
- liger_kernel/transformers/model/glm4.py +16 -4
- liger_kernel/transformers/model/glm4v.py +16 -4
- liger_kernel/transformers/model/glm4v_moe.py +23 -4
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +12 -5
- liger_kernel/transformers/model/llama.py +14 -5
- liger_kernel/transformers/model/llama4.py +16 -4
- liger_kernel/transformers/model/llava.py +12 -4
- liger_kernel/transformers/model/loss_utils.py +37 -3
- liger_kernel/transformers/model/mistral.py +15 -6
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +12 -4
- liger_kernel/transformers/model/olmo2.py +16 -4
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +23 -5
- liger_kernel/transformers/model/phi3.py +14 -7
- liger_kernel/transformers/model/qwen2.py +16 -3
- liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
- liger_kernel/transformers/model/qwen2_vl.py +16 -4
- liger_kernel/transformers/model/qwen3.py +20 -5
- liger_kernel/transformers/model/qwen3_moe.py +19 -5
- liger_kernel/transformers/model/qwen3_next.py +17 -5
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +15 -6
- liger_kernel/transformers/monkey_patch.py +584 -49
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +1 -1
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +8 -3
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +18 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +54 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +14 -4
- liger_kernel-0.6.5.dist-info/RECORD +134 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
- liger_kernel-0.6.3.dist-info/RECORD +0 -111
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
liger_kernel/ops/poly_norm.py
CHANGED
@@ -7,8 +7,11 @@ import triton.language as tl
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import get_npu_core_count
+from liger_kernel.ops.utils import set_large_grf_mode
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         from triton.language.extra.libdevice import rsqrt
     except ModuleNotFoundError:
@@ -138,20 +141,19 @@ def _poly_norm_backward_kernel(
     w1 = tl.load(W_ptr + 1).to(tl.float32)
     w2 = tl.load(W_ptr + 2).to(tl.float32)
 
-
-
-
-
+    for row_idx in range(row_start, row_end):
+        dy_base = dY_ptr + row_idx * dY_row_stride
+        x_base = X_ptr + row_idx * X_row_stride
+        dx_base = dX_ptr + row_idx * dX_row_stride
+        rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
 
-
-
-        dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
-        X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+        dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0).to(tl.float32)
+        X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0).to(tl.float32)
 
         # Load cached rstd values
-        rstd_3 = tl.load(
-        rstd_2 = tl.load(
-        rstd_1 = tl.load(
+        rstd_3 = tl.load(rstd_base + 0).to(tl.float32)
+        rstd_2 = tl.load(rstd_base + 1).to(tl.float32)
+        rstd_1 = tl.load(rstd_base + 2).to(tl.float32)
 
         # Compute powers
         X_pow3 = X_row * X_row * X_row
@@ -188,13 +190,7 @@ def _poly_norm_backward_kernel(
         dX_row = grad_x_3 + grad_x_2 + grad_x_1
 
         # Store gradient
-        tl.store(
-
-        # Update pointers
-        dY_ptr += dY_row_stride
-        dX_ptr += dX_row_stride
-        X_ptr += X_row_stride
-        RSTD_ptr += RSTD_row_stride
+        tl.store(dx_base + col_offsets, dX_row, mask=mask)
 
     # Store accumulated gradients (scalars)
     tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
@@ -237,7 +233,7 @@ def poly_norm_forward(X, W, B, eps=1e-6):
     # XPU-specific optimization
     kernel_args = {}
     if X.device.type == "xpu":
-        kernel_args
+        set_large_grf_mode(kernel_args)
 
     # Launch kernel
     _poly_norm_forward_kernel[(n_rows,)](
@@ -290,6 +286,8 @@ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+    elif X.device.type == "npu":
+        sm_count = get_npu_core_count()
 
     # Allocate or reuse gradients
     if in_place is True:
@@ -306,7 +304,7 @@ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
     # XPU-specific optimization
     kernel_args = {}
    if X.device.type == "xpu":
-        kernel_args
+        set_large_grf_mode(kernel_args)
 
     # Launch backward kernel
     _poly_norm_backward_kernel[grid](
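Both backward-kernel hunks above follow the same refactor: instead of advancing `dY_ptr`/`dX_ptr`/`X_ptr`/`RSTD_ptr` in place at the end of every loop iteration, the kernel now derives fresh base pointers from `row_idx`. A minimal, self-contained Triton sketch of that pattern is shown below; the kernel and argument names are illustrative only and are not part of liger-kernel.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _copy_rows_kernel(X_ptr, Y_ptr, x_stride, y_stride, n_cols, row_start, row_end, BLOCK_SIZE: tl.constexpr):
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    for row_idx in range(row_start, row_end):
        # Per-row base pointers: the loop body carries no mutated pointer state
        # across iterations, mirroring the refactor in the diff above.
        x_base = X_ptr + row_idx * x_stride
        y_base = Y_ptr + row_idx * y_stride
        row = tl.load(x_base + cols, mask=mask, other=0.0)
        tl.store(y_base + cols, row, mask=mask)


# Example launch (requires a Triton-supported GPU): one program copies all rows.
X = torch.randn(8, 100, device="cuda")
Y = torch.empty_like(X)
_copy_rows_kernel[(1,)](X, Y, X.stride(0), Y.stride(0), X.size(1), 0, X.size(0), BLOCK_SIZE=128)
```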
liger_kernel/ops/rms_norm.py
CHANGED
@@ -20,9 +20,12 @@ import triton.language as tl
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import get_npu_core_count
+from liger_kernel.ops.utils import set_large_grf_mode
 from liger_kernel.ops.utils import torch_to_triton_dtype
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
@@ -52,6 +55,7 @@ def _rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -67,13 +71,14 @@ def _rms_norm_forward_kernel(
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-    Y_ptr
-    X_ptr
-    RSTD_ptr
+    y_base = Y_ptr + row_idx * Y_row_stride
+    x_base = X_ptr + row_idx * X_row_stride
+    rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
 
-    X_row = tl.load(
+    X_row = tl.load(x_base + col_offsets, mask=mask, other=0)
     X_row_dtype = X_row.dtype
-
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
 
     # On Llama, only rstd is computed on fp32
     if casting_mode == _CASTING_MODE_LLAMA:
@@ -81,7 +86,8 @@ def _rms_norm_forward_kernel(
 
     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)
 
     if casting_mode == _CASTING_MODE_NONE:
@@ -94,7 +100,7 @@ def _rms_norm_forward_kernel(
     # We can save time by caching rms with minimal memory overhead
     # because rms is much smaller compared to X_row, as rms is for each row.
     # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
-    tl.store(
+    tl.store(rstd_base, rstd)
 
     X_row = X_row * rstd
 
@@ -102,12 +108,15 @@ def _rms_norm_forward_kernel(
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)
 
-
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)
+    else:
+        Y_row = X_row
 
     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)
 
-    tl.store(
+    tl.store(y_base + col_offsets, Y_row, mask=mask)
 
 
 @triton.jit
@@ -128,8 +137,9 @@ def _rms_norm_backward_kernel(
     n_rows,
     n_cols,
     offset,
-    rows_per_program
+    rows_per_program,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -143,55 +153,63 @@ def _rms_norm_backward_kernel(
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
-
-
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+        W_row = W_row + offset
 
-
-
+    for row_idx in range(row_start, row_end):
+        dy_base = dY_ptr + row_idx * dY_row_stride
+        dx_base = dX_ptr + row_idx * dX_row_stride
 
-
-
+        x_base = X_ptr + row_idx * X_row_stride
+        rstd_base = RSTD_ptr + row_idx * RSTD_row_stride
 
-
-
-        X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
+        dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
+        X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)
 
         # Get cached rms
-        rstd_row = tl.load(
+        rstd_row = tl.load(rstd_base)
 
         X_row = X_row.to(tl.float32)
 
         # Different bacward graphs for different casting modes
         if casting_mode == _CASTING_MODE_LLAMA:
-
+            if elementwise_affine:
+                m = (dY_row * W_row).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)
 
         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
         else:
-
+            if elementwise_affine:
+                m = dY_row * W_row
+            else:
+                m = dY_row
 
         dX_row = rstd_row * m
 
         dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
 
-
-
-
-
-
-
+        if elementwise_affine:
+            # calculate the gradient of W
+            if casting_mode == _CASTING_MODE_LLAMA:
+                dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += dY_row * (X_row * rstd_row)
 
-        tl.store(
+        tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)
 
-
-
-        X_ptr += X_row_stride
-        RSTD_ptr += RSTD_row_stride
-
-    tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
+    if elementwise_affine:
+        tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
 
 
 @triton.jit
@@ -209,6 +227,7 @@ def _block_rms_norm_forward_kernel(
     eps,
     offset,
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):
@@ -232,7 +251,8 @@ def _block_rms_norm_forward_kernel(
         other=0,
     )
     X_row_dtype = X_row.dtype
-
+    if elementwise_affine:
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
 
     # On Llama, only rstd is computed on fp32
     if casting_mode == _CASTING_MODE_LLAMA:
@@ -240,7 +260,8 @@ def _block_rms_norm_forward_kernel(
 
     # Gemma computes everything on fp32, and then casts back the output to the original dtype
     if casting_mode == _CASTING_MODE_GEMMA:
-
+        if elementwise_affine:
+            W_row = W_row.to(tl.float32)
         X_row = X_row.to(tl.float32)
 
     if casting_mode == _CASTING_MODE_NONE:
@@ -261,7 +282,10 @@ def _block_rms_norm_forward_kernel(
     if casting_mode == _CASTING_MODE_LLAMA:
         X_row = X_row.to(X_row_dtype)
 
-
+    if elementwise_affine:
+        Y_row = X_row * (offset + W_row)[None, :]
+    else:
+        Y_row = X_row
 
     if casting_mode == _CASTING_MODE_GEMMA:
         Y_row = Y_row.to(X_row_dtype)
@@ -291,8 +315,8 @@ def _block_rms_norm_backward_kernel(
     n_rows,
     n_cols,
     offset,
-    rows_per_program: tl.constexpr,
     casting_mode: tl.constexpr,
+    elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):
@@ -307,10 +331,11 @@ def _block_rms_norm_backward_kernel(
     col_offsets = tl.arange(0, BLOCK_SIZE)
     col_mask = col_offsets < n_cols
 
-
+    if elementwise_affine:
+        dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
-
-
+        W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+        W_row = W_row + offset
 
     for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
         row_idx = start + tl.arange(0, BLOCK_ROW)
@@ -333,13 +358,22 @@ def _block_rms_norm_backward_kernel(
 
         # Different bacward graphs for different casting modes
        if casting_mode == _CASTING_MODE_LLAMA:
-
+            if elementwise_affine:
+                m = (dY_row * W_row[None, :]).to(tl.float32)
+            else:
+                m = dY_row.to(tl.float32)
 
         elif casting_mode == _CASTING_MODE_GEMMA:
             dY_row = dY_row.to(tl.float32)
-
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
         else:
-
+            if elementwise_affine:
+                m = dY_row * W_row[None, :]
+            else:
+                m = dY_row
 
         dX_row = rstd_row[:, None] * m
 
@@ -347,12 +381,13 @@ def _block_rms_norm_backward_kernel(
             -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
         )
 
-
-
-
-
-
-
+        if elementwise_affine:
+            if casting_mode == _CASTING_MODE_LLAMA:
+                # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+                dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
+            else:
+                # here X_row is already in fp32 (see previous if block)
+                dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
 
         tl.store(
             dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
@@ -360,7 +395,8 @@ def _block_rms_norm_backward_kernel(
             mask=row_mask[:, None] & col_mask[None, :],
         )
 
-
+    if elementwise_affine:
+        tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
 
 
 _str_to_casting_mode = {
@@ -389,13 +425,19 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
     rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
     RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
 
-
-
+    if W is not None:
+        # Check constraints.
+        assert X.shape[1] == W.shape[0], (
+            "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+        )
+        elementwise_affine = True
+    else:
+        elementwise_affine = False
 
     # XPU-specific optimization
     kernel_args = {}
     if X.device.type == "xpu":
-        kernel_args
+        set_large_grf_mode(kernel_args)
     if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
         _rms_norm_forward_kernel[(n_rows,)](
             Y,
@@ -403,13 +445,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             X,
             X.stride(0),
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             n_cols,
             eps,
             offset,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
@@ -423,7 +466,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             X,
             X.stride(0),
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             n_rows,
@@ -431,6 +474,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             eps,
             offset,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
@@ -449,9 +493,16 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+    elif X.device.type == "npu":
+        sm_count = get_npu_core_count()
 
-
-
+    if W is not None:
+        # fp32 for numerical stability especially.
+        _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+        elementwise_affine = True
+    else:
+        _dW = None
+        elementwise_affine = False
 
     if n_cols > BLOCK_SIZE:
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
@@ -466,7 +517,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     # XPU-specific optimization
     kernel_args = {}
     if X.device.type == "xpu":
-        kernel_args
+        set_large_grf_mode(kernel_args)
 
     if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
         _rms_norm_backward_kernel[grid](
@@ -478,16 +529,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
             X.stride(0),
             torch_to_triton_dtype[X.dtype],
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
             RSTD,
             RSTD.stride(0),
             _dW,
-            _dW.stride(0),
+            _dW.stride(0) if elementwise_affine else 0,
             n_rows,
             n_cols,
             offset,
             rows_per_program,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
@@ -504,22 +556,26 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
             X.stride(0),
             torch_to_triton_dtype[X.dtype],
             W,
-            W.stride(0),
+            W.stride(0) if elementwise_affine else 0,
            RSTD,
             RSTD.stride(0),
             _dW,
-            _dW.stride(0),
+            _dW.stride(0) if elementwise_affine else 0,
             n_rows,
             n_cols,
             offset,
-            rows_per_program,
             casting_mode,
+            elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
             **kernel_args,  # XPU-specific optimization
         )
     dX = dX.view(*shape)
-
+
+    if elementwise_affine:
+        dW = _dW.sum(dim=0).to(W.dtype)
+    else:
+        dW = None
 
     return dX, dW
 
@@ -553,6 +609,13 @@ class LigerRMSNormFunction(torch.autograd.Function):
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
+        if isinstance(X, torch.distributed.tensor.DTensor):
+            # Input tensor is output of a tensor parallel module and
+            # needs to be gathered to a local tensor to compute
+            # RMSE layer norm on each TP worker.
+            # TODO: support CP.
+            X = X.full_tensor()
+
         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
         ctx.offset = offset
         ctx.casting_mode = casting_mode
@@ -560,7 +623,11 @@ class LigerRMSNormFunction(torch.autograd.Function):
         ctx.row_mode = row_mode
         ctx.BLOCK_SIZE = BLOCK_SIZE
         ctx.num_warps = num_warps
-        ctx.
+        ctx.elementwise_affine = W is not None
+        if W is not None:
+            ctx.save_for_backward(X, W, RSTD)
+        else:
+            ctx.save_for_backward(X, RSTD)
         return Y
 
     @staticmethod
@@ -569,7 +636,18 @@ class LigerRMSNormFunction(torch.autograd.Function):
         """
         Y: (B, T, H) or (BxT, H)
        """
-
+        if ctx.elementwise_affine:
+            X, W, RSTD = ctx.saved_tensors
+        else:
+            X, RSTD = ctx.saved_tensors
+            W = None
+
+        if isinstance(dY, torch.distributed.tensor.DTensor):
+            # Gradients are output of a tensor parallel module and
+            # needs to be gathered to a local tensor for computing RMSE layer.
+            # TODO: support CP.
+            dY = dY.full_tensor()
+
        dX, dW = rms_norm_backward(
             dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
         )