liger-kernel-nightly 0.6.2.dev20250921193116__py3-none-any.whl → 0.6.2.dev20250923161350__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- liger_kernel/ops/layer_norm.py +4 -6
- {liger_kernel_nightly-0.6.2.dev20250921193116.dist-info → liger_kernel_nightly-0.6.2.dev20250923161350.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.6.2.dev20250921193116.dist-info → liger_kernel_nightly-0.6.2.dev20250923161350.dist-info}/RECORD +7 -7
- {liger_kernel_nightly-0.6.2.dev20250921193116.dist-info → liger_kernel_nightly-0.6.2.dev20250923161350.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.2.dev20250921193116.dist-info → liger_kernel_nightly-0.6.2.dev20250923161350.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.2.dev20250921193116.dist-info → liger_kernel_nightly-0.6.2.dev20250923161350.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.2.dev20250921193116.dist-info → liger_kernel_nightly-0.6.2.dev20250923161350.dist-info}/top_level.txt +0 -0
liger_kernel/ops/layer_norm.py
CHANGED
|
@@ -63,12 +63,11 @@ def _layer_norm_forward_kernel(
|
|
|
63
63
|
X_f32 = X_row.to(tl.float32)
|
|
64
64
|
|
|
65
65
|
# Compute statistics in fp32 for numerical stability
|
|
66
|
-
n_cols_f32 = n_cols.to(tl.float32)
|
|
67
|
-
mean = tl.sum(X_f32, axis=0) / n_cols_f32
|
|
66
|
+
mean = tl.sum(X_f32, axis=0) / n_cols
|
|
68
67
|
X_centered = X_f32 - mean
|
|
69
68
|
# Apply mask to variance calculation to exclude contributions from masked elements
|
|
70
69
|
X_centered_masked = tl.where(mask, X_centered, 0.0)
|
|
71
|
-
var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols_f32
|
|
70
|
+
var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
|
|
72
71
|
rstd = rsqrt(var + eps)
|
|
73
72
|
|
|
74
73
|
# Store statistics (convert back to original dtype only once)
|
|
@@ -113,7 +112,6 @@ def _layer_norm_backward_kernel(
|
|
|
113
112
|
# Pre-load weights once (same optimization as forward pass)
|
|
114
113
|
w = tl.load(W_ptr + cols, mask=mask, other=0.0)
|
|
115
114
|
w_f32 = w.to(tl.float32)
|
|
116
|
-
n_cols_f32 = n_cols.to(tl.float32)
|
|
117
115
|
|
|
118
116
|
# Calculate pointers for this specific row
|
|
119
117
|
row_X_ptr = X_ptr + row_idx * stride_x
|
|
@@ -137,8 +135,8 @@ def _layer_norm_backward_kernel(
|
|
|
137
135
|
# Compute backward pass for this row
|
|
138
136
|
x_hat = (x_f32 - mean_f32) * rstd_f32
|
|
139
137
|
wdy = w_f32 * dy_f32
|
|
140
|
-
c1 = tl.sum(x_hat * wdy, axis=0) / n_cols_f32
|
|
141
|
-
c2 = tl.sum(wdy, axis=0) / n_cols_f32
|
|
138
|
+
c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
|
|
139
|
+
c2 = tl.sum(wdy, axis=0) / n_cols
|
|
142
140
|
dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
|
|
143
141
|
|
|
144
142
|
# Store input gradient
|
|
@@ -28,7 +28,7 @@ liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2wogg
|
|
|
28
28
|
liger_kernel/ops/grpo_loss.py,sha256=anRnv7k1-AV3pCC6_TqP0GMg78YYUfRAJrbpx6PVhl0,9448
|
|
29
29
|
liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
|
|
30
30
|
liger_kernel/ops/kl_div.py,sha256=ZjGdDLKWksHT9dZ0xF_TDgAkj5cuMTwwT5tr9E-_24o,8734
|
|
31
|
-
liger_kernel/ops/layer_norm.py,sha256=
|
|
31
|
+
liger_kernel/ops/layer_norm.py,sha256=WmiORsIyufOhazmYZTPjeSc5Z-xTAYwXAKqUcCv_dlY,9807
|
|
32
32
|
liger_kernel/ops/llama4_rope.py,sha256=-aqdZzllklTN8b9--e-TsWY_ntGCN8-tyseT4x0bd8s,8223
|
|
33
33
|
liger_kernel/ops/multi_token_attention.py,sha256=Oz_RXDp-OSS_R_HuGmaETHdAJ7Toda_70OfE7TXMUlY,7645
|
|
34
34
|
liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
|
|
@@ -97,9 +97,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
|
97
97
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
|
|
98
98
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
|
99
99
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
|
100
|
-
liger_kernel_nightly-0.6.2.dev20250921193116.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
|
101
|
-
liger_kernel_nightly-0.6.2.
|
|
102
|
-
liger_kernel_nightly-0.6.2.dev20250921193116.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
|
103
|
-
liger_kernel_nightly-0.6.2.dev20250921193116.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
104
|
-
liger_kernel_nightly-0.6.2.dev20250921193116.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
|
105
|
-
liger_kernel_nightly-0.6.2.dev20250921193116.dist-info/RECORD,,
|
|
100
|
+
liger_kernel_nightly-0.6.2.dev20250923161350.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
|
101
|
+
liger_kernel_nightly-0.6.2.dev20250923161350.dist-info/METADATA,sha256=WqiloVBMdQO2pUusvcTNwjrib3CUnO9Q_iBffY-JaM8,24605
|
|
102
|
+
liger_kernel_nightly-0.6.2.dev20250923161350.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
|
103
|
+
liger_kernel_nightly-0.6.2.dev20250923161350.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
104
|
+
liger_kernel_nightly-0.6.2.dev20250923161350.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
|
105
|
+
liger_kernel_nightly-0.6.2.dev20250923161350.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|