liger-kernel-nightly 0.6.2.dev20250919191028__py3-none-any.whl → 0.6.2.dev20250922212712__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/layer_norm.py +4 -6
- {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.2.dev20250922212712.dist-info}/METADATA +6 -3
- {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.2.dev20250922212712.dist-info}/RECORD +7 -7
- {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.2.dev20250922212712.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.2.dev20250922212712.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.2.dev20250922212712.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.2.dev20250922212712.dist-info}/top_level.txt +0 -0
liger_kernel/ops/layer_norm.py
CHANGED
|
@@ -63,12 +63,11 @@ def _layer_norm_forward_kernel(
|
|
|
63
63
|
X_f32 = X_row.to(tl.float32)
|
|
64
64
|
|
|
65
65
|
# Compute statistics in fp32 for numerical stability
|
|
66
|
-
|
|
67
|
-
mean = tl.sum(X_f32, axis=0) / n_cols_f32
|
|
66
|
+
mean = tl.sum(X_f32, axis=0) / n_cols
|
|
68
67
|
X_centered = X_f32 - mean
|
|
69
68
|
# Apply mask to variance calculation to exclude contributions from masked elements
|
|
70
69
|
X_centered_masked = tl.where(mask, X_centered, 0.0)
|
|
71
|
-
var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols_f32
|
|
70
|
+
var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
|
|
72
71
|
rstd = rsqrt(var + eps)
|
|
73
72
|
|
|
74
73
|
# Store statistics (convert back to original dtype only once)
|
|
@@ -113,7 +112,6 @@ def _layer_norm_backward_kernel(
|
|
|
113
112
|
# Pre-load weights once (same optimization as forward pass)
|
|
114
113
|
w = tl.load(W_ptr + cols, mask=mask, other=0.0)
|
|
115
114
|
w_f32 = w.to(tl.float32)
|
|
116
|
-
n_cols_f32 = n_cols.to(tl.float32)
|
|
117
115
|
|
|
118
116
|
# Calculate pointers for this specific row
|
|
119
117
|
row_X_ptr = X_ptr + row_idx * stride_x
|
|
@@ -137,8 +135,8 @@ def _layer_norm_backward_kernel(
|
|
|
137
135
|
# Compute backward pass for this row
|
|
138
136
|
x_hat = (x_f32 - mean_f32) * rstd_f32
|
|
139
137
|
wdy = w_f32 * dy_f32
|
|
140
|
-
c1 = tl.sum(x_hat * wdy, axis=0) / n_cols_f32
|
|
141
|
-
c2 = tl.sum(wdy, axis=0) / n_cols_f32
|
|
138
|
+
c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
|
|
139
|
+
c2 = tl.sum(wdy, axis=0) / n_cols
|
|
142
140
|
dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
|
|
143
141
|
|
|
144
142
|
# Store input gradient
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: liger_kernel_nightly
|
|
3
|
-
Version: 0.6.2.dev20250919191028
|
|
3
|
+
Version: 0.6.2.dev20250922212712
|
|
4
4
|
Summary: Efficient Triton kernels for LLM Training
|
|
5
5
|
License: BSD 2-CLAUSE LICENSE
|
|
6
6
|
Copyright 2024 LinkedIn Corporation
|
|
@@ -177,8 +177,8 @@ y = orpo_loss(lm_head.weight, x, target)
|
|
|
177
177
|
- `triton >= 3.0.0` Install from pypi. (e.g. `pip install triton==3.0.0`)
|
|
178
178
|
|
|
179
179
|
```bash
|
|
180
|
-
|
|
181
|
-
|
|
180
|
+
pip install -e .[dev]
|
|
181
|
+
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3/
|
|
182
182
|
```
|
|
183
183
|
|
|
184
184
|
### Optional Dependencies
|
|
@@ -212,6 +212,9 @@ pip install -e .
|
|
|
212
212
|
|
|
213
213
|
# Setup Development Dependencies
|
|
214
214
|
pip install -e ".[dev]"
|
|
215
|
+
|
|
216
|
+
# NOTE -> For AMD users only
|
|
217
|
+
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3/
|
|
215
218
|
```
|
|
216
219
|
|
|
217
220
|
|
|
@@ -28,7 +28,7 @@ liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2wogg
|
|
|
28
28
|
liger_kernel/ops/grpo_loss.py,sha256=anRnv7k1-AV3pCC6_TqP0GMg78YYUfRAJrbpx6PVhl0,9448
|
|
29
29
|
liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
|
|
30
30
|
liger_kernel/ops/kl_div.py,sha256=ZjGdDLKWksHT9dZ0xF_TDgAkj5cuMTwwT5tr9E-_24o,8734
|
|
31
|
-
liger_kernel/ops/layer_norm.py,sha256=
|
|
31
|
+
liger_kernel/ops/layer_norm.py,sha256=WmiORsIyufOhazmYZTPjeSc5Z-xTAYwXAKqUcCv_dlY,9807
|
|
32
32
|
liger_kernel/ops/llama4_rope.py,sha256=-aqdZzllklTN8b9--e-TsWY_ntGCN8-tyseT4x0bd8s,8223
|
|
33
33
|
liger_kernel/ops/multi_token_attention.py,sha256=Oz_RXDp-OSS_R_HuGmaETHdAJ7Toda_70OfE7TXMUlY,7645
|
|
34
34
|
liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
|
|
@@ -97,9 +97,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
|
97
97
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
|
|
98
98
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
|
99
99
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
|
100
|
-
liger_kernel_nightly-0.6.2.dev20250919191028.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
|
101
|
-
liger_kernel_nightly-0.6.2.
|
|
102
|
-
liger_kernel_nightly-0.6.2.dev20250919191028.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
|
103
|
-
liger_kernel_nightly-0.6.2.dev20250919191028.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
104
|
-
liger_kernel_nightly-0.6.2.dev20250919191028.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
|
105
|
-
liger_kernel_nightly-0.6.2.dev20250919191028.dist-info/RECORD,,
|
|
100
|
+
liger_kernel_nightly-0.6.2.dev20250922212712.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
|
101
|
+
liger_kernel_nightly-0.6.2.dev20250922212712.dist-info/METADATA,sha256=xDJV_P4V9fNICAiAuDBJ2MFE9An_zn8kRDeZjgQL7DM,24605
|
|
102
|
+
liger_kernel_nightly-0.6.2.dev20250922212712.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
|
103
|
+
liger_kernel_nightly-0.6.2.dev20250922212712.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
|
104
|
+
liger_kernel_nightly-0.6.2.dev20250922212712.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
|
105
|
+
liger_kernel_nightly-0.6.2.dev20250922212712.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|