liger-kernel-nightly 0.6.0.dev20250718080702__py3-none-any.whl → 0.6.0.dev20250719041256__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/layer_norm.py +126 -88
- {liger_kernel_nightly-0.6.0.dev20250718080702.dist-info → liger_kernel_nightly-0.6.0.dev20250719041256.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.6.0.dev20250718080702.dist-info → liger_kernel_nightly-0.6.0.dev20250719041256.dist-info}/RECORD +7 -7
- {liger_kernel_nightly-0.6.0.dev20250718080702.dist-info → liger_kernel_nightly-0.6.0.dev20250719041256.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.0.dev20250718080702.dist-info → liger_kernel_nightly-0.6.0.dev20250719041256.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.0.dev20250718080702.dist-info → liger_kernel_nightly-0.6.0.dev20250719041256.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.0.dev20250718080702.dist-info → liger_kernel_nightly-0.6.0.dev20250719041256.dist-info}/top_level.txt +0 -0
liger_kernel/ops/layer_norm.py
CHANGED
@@ -1,4 +1,3 @@
-import math
 import operator
 
 import torch
@@ -43,30 +42,45 @@ def _layer_norm_forward_kernel(
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
     """
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-    [12 removed lines not captured in this diff view]
+    # Pre-load weights and bias in fp32 to avoid repeated conversions
+    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+    B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0)
+    W_f32 = W_row.to(tl.float32)
+    B_f32 = B_row.to(tl.float32)
+
+    # Calculate pointers for this row
+    row_X_ptr = X_ptr + row_idx * X_row_stride
+    row_Y_ptr = Y_ptr + row_idx * Y_row_stride
+    row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride
+    row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride
+
+    # Load input data and convert to fp32 for numerical stability
+    X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0)
+    X_f32 = X_row.to(tl.float32)
+
+    # Compute statistics in fp32 for numerical stability
+    n_cols_f32 = n_cols.to(tl.float32)
+    mean = tl.sum(X_f32, axis=0) / n_cols_f32
+    X_centered = X_f32 - mean
+    # Apply mask to variance calculation to exclude contributions from masked elements
+    X_centered_masked = tl.where(mask, X_centered, 0.0)
+    var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols_f32
     rstd = rsqrt(var + eps)
 
-    [1 removed line not captured in this diff view]
-    tl.store(
+    # Store statistics (convert back to original dtype only once)
+    tl.store(row_Mean_ptr, mean.to(X_row.dtype))
+    tl.store(row_RSTD_ptr, rstd.to(X_row.dtype))
 
-    [1 removed line not captured in this diff view]
+    # Fused normalization and affine transformation
+    # Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B
+    Y_f32 = X_centered * rstd * W_f32 + B_f32
 
-    [1 removed line not captured in this diff view]
+    # Store output (single conversion back to original dtype)
+    tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask)
 
 
 @triton.jit
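For orientation, the following is a minimal PyTorch sketch (not part of the wheel; the function name and eps value are illustrative) of what the rewritten forward kernel computes per row: statistics taken in fp32, a masked variance, and a fused normalize-plus-affine step with a single cast back to the input dtype.

import torch

def layer_norm_forward_reference(X, W, B, eps=1e-6):
    # Mirror the kernel's fp32 path: X_f32, mean, var, rstd, then the fused affine step.
    X_f32 = X.float()
    mean = X_f32.mean(dim=-1, keepdim=True)
    var = ((X_f32 - mean) ** 2).mean(dim=-1, keepdim=True)
    rstd = torch.rsqrt(var + eps)
    Y_f32 = (X_f32 - mean) * rstd * W.float() + B.float()
    # Single conversion back to the original dtype, as in the kernel's final stores.
    return Y_f32.to(X.dtype), mean.squeeze(-1).to(X.dtype), rstd.squeeze(-1).to(X.dtype)

x = torch.randn(4, 8).half()
w = torch.ones(8, dtype=torch.float16)
b = torch.zeros(8, dtype=torch.float16)
y, mean, rstd = layer_norm_forward_reference(x, w, b)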
@@ -81,73 +95,87 @@ def _layer_norm_backward_kernel(
     DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
     stride_x,  # stride of each row in input
     stride_dx,  # stride of each row in input grad
-    stride_dw,  # stride of each row in weights grad
-    stride_db,  # stride of each row in bias grad
     stride_dy,  # stride of each row in output grad
-    n_rows,
     n_cols,
-    rows_per_program: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     dtype: tl.constexpr,
+    atomic_dtype: tl.constexpr,
 ):
     """
     References:
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
-    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
-    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
     """
-
-    row_start = row_block_id * rows_per_program
-    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+    row_idx = tl.program_id(0).to(tl.int64)
     cols = tl.arange(0, BLOCK_SIZE)
     mask = cols < n_cols
 
-    [32 removed lines not captured in this diff view]
-    tl.store(
-    [1 removed line not captured in this diff view]
+    # Pre-load weights once (same optimization as forward pass)
+    w = tl.load(W_ptr + cols, mask=mask, other=0.0)
+    w_f32 = w.to(tl.float32)
+    n_cols_f32 = n_cols.to(tl.float32)
+
+    # Calculate pointers for this specific row
+    row_X_ptr = X_ptr + row_idx * stride_x
+    row_DX_ptr = DX_ptr + row_idx * stride_dx
+    row_DY_ptr = DY_ptr + row_idx * stride_dy
+    row_Mean_ptr = Mean_ptr + row_idx
+    row_RSTD_ptr = RSTD_ptr + row_idx
+
+    # Load data for this row
+    x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+    dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+    mean = tl.load(row_Mean_ptr)
+    rstd = tl.load(row_RSTD_ptr)
+
+    # Convert to fp32 for numerical stability
+    x_f32 = x.to(tl.float32)
+    dy_f32 = dy.to(tl.float32)
+    mean_f32 = mean.to(tl.float32)
+    rstd_f32 = rstd.to(tl.float32)
+
+    # Compute backward pass for this row
+    x_hat = (x_f32 - mean_f32) * rstd_f32
+    wdy = w_f32 * dy_f32
+    c1 = tl.sum(x_hat * wdy, axis=0) / n_cols_f32
+    c2 = tl.sum(wdy, axis=0) / n_cols_f32
+    dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+    # Store input gradient
+    tl.store(row_DX_ptr + cols, dx.to(dtype), mask=mask)
+
+    # Accumulate weight and bias gradients using atomic operations
+    dw = dy_f32 * x_hat
+    db = dy_f32
+    tl.atomic_add(DW_ptr + cols, dw.to(atomic_dtype), mask=mask)
+    tl.atomic_add(DB_ptr + cols, db.to(atomic_dtype), mask=mask)
 
 
 def layer_norm_forward(X, W, B, eps):
+    """
+    Args:
+        X: Input tensor of shape (..., hidden_size)
+        W: Weight tensor of shape (hidden_size,)
+        B: Bias tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        Tuple of (output, input, mean, rstd, block_size, num_warps)
+    """
     shape = X.shape
     dim = shape[-1]
     X = X.view(-1, dim)
     n_rows, n_cols = X.shape
+
+    # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    # Allocate output tensors
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
     RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+
+    # Validate input dimensions
     if X.shape[1] != W.shape[0]:
         raise ValueError(
             f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
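The per-row backward math above is the standard LayerNorm gradient. As a hedged, illustrative PyTorch reference (not part of the wheel), using the same variable names as the kernel; note that dW/dB are reduced over rows in one call here, whereas the kernel accumulates them across row-programs with tl.atomic_add:

import torch

def layer_norm_backward_reference(dY, X, W, mean, rstd):
    # x_hat, wdy, c1, c2 and dX mirror the kernel's names.
    x_hat = (X.float() - mean.unsqueeze(-1)) * rstd.unsqueeze(-1)
    wdy = W.float() * dY.float()
    c1 = (x_hat * wdy).mean(dim=-1, keepdim=True)
    c2 = wdy.mean(dim=-1, keepdim=True)
    dX = (wdy - (x_hat * c1 + c2)) * rstd.unsqueeze(-1)
    dW = (dY.float() * x_hat).sum(dim=0)  # kernel: tl.atomic_add(DW_ptr + cols, ...)
    dB = dY.float().sum(dim=0)            # kernel: tl.atomic_add(DB_ptr + cols, ...)
    return dX.to(X.dtype), dW.to(W.dtype), dB.to(W.dtype)

X = torch.randn(4, 8)
W = torch.randn(8)
dY = torch.randn(4, 8)
mean = X.mean(dim=-1)
rstd = torch.rsqrt(X.var(dim=-1, unbiased=False) + 1e-6)
dX, dW, dB = layer_norm_backward_reference(dY, X, W, mean, rstd)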
@@ -159,7 +187,9 @@ def layer_norm_forward(X, W, B, eps):
     if X.device.type == "xpu":
         kernel_args["grf_mode"] = "large"
 
-    [1 removed line not captured in this diff view]
+    # Launch kernel with one thread block per row for optimal performance
+    grid = (n_rows,)
+    _layer_norm_forward_kernel[grid](
         Y,
         Y.stride(0),
         X,
@@ -176,35 +206,43 @@ def layer_norm_forward(X, W, B, eps):
         eps,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
-        **kernel_args,
+        **kernel_args,
     )
+
     return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
 
 
 def layer_norm_backward(dY, X, W, B, Mean, RSTD):
+    """
+    Args:
+        dY: Gradient of output
+        X: Input tensor
+        W: Weight tensor
+        B: Bias tensor
+        Mean: Pre-computed mean
+        RSTD: Pre-computed reciprocal standard deviation
+
+    Returns:
+        Tuple of (input_grad, weight_grad, bias_grad)
+    """
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
     n_rows, n_cols = dY.shape
 
-
-    if X.device.type == "cuda":
-        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
-    elif X.device.type == "xpu":
-        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
-
+    # Allocate gradient tensors
     DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-
-
+    # Use float32 for weight/bias gradients if bfloat16 (due to atomic_add limitation)
+    grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype
+    DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
+    DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
 
+    # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
     if n_cols > BLOCK_SIZE:
-        raise RuntimeError(
-            f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
-        )
+        raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
 
-
-    grid = (sm_count,)
+    # Determine dtype for triton operations
     triton_dtype = (
         tl.float32
         if X.dtype == torch.float32
@@ -212,41 +250,41 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
         if X.dtype == torch.bfloat16
         else tl.float16
         if X.dtype == torch.float16
-        else tl.float32 # fallback
+        else tl.float32  # fallback
     )
 
+    # Use float32 for atomic operations if bfloat16 is not supported
+    atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
+
     # XPU-specific optimization
     kernel_args = {}
     if X.device.type == "xpu":
         kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
 
+    # Launch kernel with one thread block per row for optimal performance
+    grid = (n_rows,)
     _layer_norm_backward_kernel[grid](
         X,
         W,
         Mean,
         RSTD,
         DX,
-        [2 removed lines not captured in this diff view]
+        DW,
+        DB,
         dY,
         X.stride(0),
         DX.stride(0),
-        _DW.stride(0),
-        _DB.stride(0),
         dY.stride(0),
-        n_rows,
         n_cols,
-        rows_per_program,
         BLOCK_SIZE=BLOCK_SIZE,
         dtype=triton_dtype,
-        [1 removed line not captured in this diff view]
+        atomic_dtype=atomic_dtype,
+        num_warps=num_warps,
+        **kernel_args,
     )
 
-    DW = _DW.sum(dim=0).to(W.dtype)
-    DB = _DB.sum(dim=0).to(W.dtype)
-
     DX = DX.view(*shape)
-    return DX, DW, DB
+    return DX, DW.to(W.dtype), DB.to(W.dtype)
 
 
 class LigerLayerNormFunction(torch.autograd.Function):
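With the backward launch now using one program per row and tl.atomic_add for DW/DB (instead of an sm_count-sized grid with per-program partial buffers reduced on the host), a quick way to exercise the new path is to call the two helpers directly with the signatures shown above. A hedged sanity-check sketch (not part of the wheel; assumes a CUDA device with this nightly installed, and the tolerances are illustrative), using bfloat16 since that is the case that routes weight/bias gradients through float32 buffers:

import torch
from liger_kernel.ops.layer_norm import layer_norm_backward, layer_norm_forward

hidden = 1024
x = torch.randn(32, hidden, device="cuda", dtype=torch.bfloat16)
w = torch.randn(hidden, device="cuda", dtype=torch.bfloat16)
b = torch.randn(hidden, device="cuda", dtype=torch.bfloat16)
dy = torch.randn_like(x)

# Triton path: forward returns the flattened input plus per-row mean/rstd,
# which backward consumes.
y, x_flat, mean, rstd, _, _ = layer_norm_forward(x, w, b, 1e-6)
dx, dw, db = layer_norm_backward(dy, x_flat, w, b, mean, rstd)

# PyTorch autograd reference.
x_ref = x.detach().clone().requires_grad_(True)
w_ref = w.detach().clone().requires_grad_(True)
b_ref = b.detach().clone().requires_grad_(True)
y_ref = torch.nn.functional.layer_norm(x_ref, (hidden,), w_ref, b_ref, 1e-6)
y_ref.backward(dy)

torch.testing.assert_close(y, y_ref, rtol=2e-2, atol=2e-2)
torch.testing.assert_close(dx, x_ref.grad, rtol=2e-2, atol=2e-2)
torch.testing.assert_close(dw, w_ref.grad, rtol=2e-2, atol=2e-2)
torch.testing.assert_close(db, b_ref.grad, rtol=2e-2, atol=2e-2)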
{liger_kernel_nightly-0.6.0.dev20250718080702.dist-info → liger_kernel_nightly-0.6.0.dev20250719041256.dist-info}/RECORD
CHANGED
@@ -28,7 +28,7 @@ liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2wogg
 liger_kernel/ops/grpo_loss.py,sha256=anRnv7k1-AV3pCC6_TqP0GMg78YYUfRAJrbpx6PVhl0,9448
 liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
 liger_kernel/ops/kl_div.py,sha256=ZjGdDLKWksHT9dZ0xF_TDgAkj5cuMTwwT5tr9E-_24o,8734
-liger_kernel/ops/layer_norm.py,sha256=
+liger_kernel/ops/layer_norm.py,sha256=g7TfVMCoUg_zAl4nb4xPr_FTRGjxxDhFXdKEnU9NYhE,9908
 liger_kernel/ops/multi_token_attention.py,sha256=Oz_RXDp-OSS_R_HuGmaETHdAJ7Toda_70OfE7TXMUlY,7645
 liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
 liger_kernel/ops/rms_norm.py,sha256=-rcgHwWCxlA-Syec2XhdW4jfOeCDt2r7qwjslgXFYDU,18865
@@ -92,9 +92,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.6.0.
-liger_kernel_nightly-0.6.0.
-liger_kernel_nightly-0.6.0.
-liger_kernel_nightly-0.6.0.
-liger_kernel_nightly-0.6.0.
-liger_kernel_nightly-0.6.0.
+liger_kernel_nightly-0.6.0.dev20250719041256.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.6.0.dev20250719041256.dist-info/METADATA,sha256=fHiX-8I-gf3rBfmNt7LRh5ZwZcXoddbQ5rJzYxohaTg,24672
+liger_kernel_nightly-0.6.0.dev20250719041256.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.6.0.dev20250719041256.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.6.0.dev20250719041256.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.6.0.dev20250719041256.dist-info/RECORD,,
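Per the wheel RECORD format, each entry is "path,sha256=<digest>,<size>", where the digest is the urlsafe-base64 SHA-256 of the file with trailing "=" padding stripped. A hedged sketch for reproducing the layer_norm.py entry from an installed copy (assumes the script runs from the site-packages directory of the new nightly):

import base64
import hashlib

with open("liger_kernel/ops/layer_norm.py", "rb") as f:
    data = f.read()

# urlsafe base64 of the SHA-256 digest, padding removed, followed by the byte size.
digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
print(f"liger_kernel/ops/layer_norm.py,sha256={digest},{len(data)}")
# Expected for 0.6.0.dev20250719041256: sha256=g7TfVMCoUg_zAl4nb4xPr_FTRGjxxDhFXdKEnU9NYhE and size 9908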