liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly may be problematic; see the package registry's advisory for this release for more details.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +304 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +21 -4
- liger_kernel/ops/cross_entropy.py +235 -84
- liger_kernel/ops/dyt.py +157 -0
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
- liger_kernel/ops/fused_linear_jsd.py +17 -34
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +7 -18
- liger_kernel/ops/group_norm.py +305 -0
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/jsd.py +46 -21
- liger_kernel/ops/kl_div.py +23 -19
- liger_kernel/ops/layer_norm.py +150 -86
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +314 -84
- liger_kernel/ops/rope.py +32 -34
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +5 -9
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +8 -4
- liger_kernel/transformers/__init__.py +199 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +33 -20
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +291 -13
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
- liger_kernel/transformers/fused_linear_jsd.py +1 -4
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/jsd.py +2 -7
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +77 -77
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +128 -79
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +68 -64
- liger_kernel/transformers/model/mixtral.py +75 -91
- liger_kernel/transformers/model/mllama.py +63 -68
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +432 -0
- liger_kernel/transformers/model/phi3.py +59 -213
- liger_kernel/transformers/model/qwen2.py +75 -72
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +78 -98
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2106 -289
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +57 -6
- liger_kernel/transformers/rope.py +45 -2
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +23 -8
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- liger_kernel/utils.py +71 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/ops/jsd.py
CHANGED
|
@@ -5,6 +5,7 @@ import triton
|
|
|
5
5
|
import triton.language as tl
|
|
6
6
|
|
|
7
7
|
from liger_kernel.ops.utils import ensure_contiguous
|
|
8
|
+
from liger_kernel.utils import infer_device
|
|
8
9
|
|
|
9
10
|
|
|
10
11
|
@triton.jit
|
|
@@ -18,7 +19,7 @@ def _jsd_kernel(
|
|
|
18
19
|
dX_ptr,
|
|
19
20
|
dX_stride,
|
|
20
21
|
label_ptr,
|
|
21
|
-
beta,
|
|
22
|
+
beta: tl.constexpr,
|
|
22
23
|
n_non_ignore: int,
|
|
23
24
|
ignore_index: tl.constexpr,
|
|
24
25
|
n_cols,
|
|
@@ -50,21 +51,49 @@ def _jsd_kernel(
|
|
|
50
51
|
X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
|
|
51
52
|
Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
|
|
52
53
|
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
54
|
+
if beta == 0.0: # forward KL
|
|
55
|
+
Y_max = tl.max(Y, axis=0)
|
|
56
|
+
Y_shifted = Y - Y_max
|
|
57
|
+
Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max) # Compensate for the shift
|
|
58
|
+
loss = Y_prob * (Y - X)
|
|
59
|
+
dX = -Y_prob
|
|
60
|
+
elif beta == 1.0: # reverse KL
|
|
61
|
+
X_max = tl.max(X, axis=0)
|
|
62
|
+
X_shifted = X - X_max
|
|
63
|
+
X_prob = tl.exp(X_shifted) * tl.exp(X_max) # Compensate for the shift
|
|
64
|
+
loss = X_prob * (X - Y)
|
|
65
|
+
dX = loss + X_prob
|
|
66
|
+
else:
|
|
67
|
+
max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
|
|
68
|
+
X_shifted = X - max_val
|
|
69
|
+
Y_shifted = Y - max_val
|
|
70
|
+
|
|
71
|
+
# Pre-compute exp(max_val) since it's used twice
|
|
72
|
+
exp_max = tl.exp(max_val)
|
|
73
|
+
|
|
74
|
+
# Compute exp terms with compensation
|
|
75
|
+
Q = tl.exp(X_shifted) * exp_max # = exp(X)
|
|
76
|
+
P = tl.exp(Y_shifted) * exp_max # = exp(Y)
|
|
77
|
+
|
|
78
|
+
# Pre-compute common terms
|
|
79
|
+
beta_P = beta * P
|
|
80
|
+
one_minus_beta_Q = (1 - beta) * Q
|
|
81
|
+
M = beta_P + one_minus_beta_Q
|
|
82
|
+
log_M = tl.log(M) # No need to compensate as M is already in original scale
|
|
83
|
+
|
|
84
|
+
loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
|
|
85
|
+
dX = one_minus_beta_Q * (X - log_M)
|
|
86
|
+
|
|
87
|
+
# Pre-compute scaling factor
|
|
88
|
+
scale = 1.0 / n_non_ignore
|
|
89
|
+
loss = loss * scale
|
|
90
|
+
dX = dX * scale
|
|
57
91
|
|
|
58
|
-
loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
|
|
59
|
-
# reduction == "batchmean"
|
|
60
|
-
loss = loss / n_non_ignore
|
|
61
92
|
tl.store(loss_ptr + offsets, loss, mask=mask)
|
|
62
|
-
|
|
63
|
-
dX = (1 - beta) * Q * (X - log_M) / n_non_ignore
|
|
64
93
|
tl.store(dX_ptr + offsets, dX, mask=mask)
|
|
65
94
|
|
|
66
95
|
|
|
67
|
-
MAX_FUSED_SIZE = 65536
|
|
96
|
+
MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536
|
|
68
97
|
|
|
69
98
|
|
|
70
99
|
def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
|
|
@@ -89,9 +118,7 @@ def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
|
|
|
89
118
|
loss_stride=loss.stride(-2),
|
|
90
119
|
dX_ptr=dX,
|
|
91
120
|
dX_stride=dX.stride(-2),
|
|
92
|
-
label_ptr=(
|
|
93
|
-
shift_labels if has_label else torch.empty(1, device=_input.device)
|
|
94
|
-
), # dummy ptr if no label
|
|
121
|
+
label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)), # dummy ptr if no label
|
|
95
122
|
beta=beta,
|
|
96
123
|
n_non_ignore=n_non_ignore,
|
|
97
124
|
ignore_index=ignore_index,
|
|
@@ -142,7 +169,7 @@ class LigerJSDFunction(torch.autograd.Function):
|
|
|
142
169
|
_input (torch.Tensor): predict values with shape (BT, V) in logspace
|
|
143
170
|
target (torch.Tensor): ground truth values with shape (BT, V) in logspace
|
|
144
171
|
shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
|
|
145
|
-
beta (float): coefficient beta of generalized JSD in the
|
|
172
|
+
beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
|
|
146
173
|
ignore_index (int): the index to ignore. Default: -100
|
|
147
174
|
|
|
148
175
|
Returns:
|
|
@@ -150,15 +177,13 @@ class LigerJSDFunction(torch.autograd.Function):
|
|
|
150
177
|
"""
|
|
151
178
|
has_label = False
|
|
152
179
|
if shift_labels is not None:
|
|
153
|
-
assert shift_labels.shape == (
|
|
154
|
-
|
|
155
|
-
)
|
|
180
|
+
assert shift_labels.shape == (_input.shape[0],), (
|
|
181
|
+
f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
|
|
182
|
+
)
|
|
156
183
|
shift_labels = shift_labels.contiguous()
|
|
157
184
|
has_label = True
|
|
158
185
|
|
|
159
|
-
loss, dX = jsd_forward(
|
|
160
|
-
_input, target, shift_labels, beta, ignore_index, has_label
|
|
161
|
-
)
|
|
186
|
+
loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
|
|
162
187
|
ctx.save_for_backward(dX)
|
|
163
188
|
return loss
|
|
164
189
|
|
liger_kernel/ops/kl_div.py
CHANGED
|
@@ -4,7 +4,9 @@ import torch
|
|
|
4
4
|
import triton
|
|
5
5
|
import triton.language as tl
|
|
6
6
|
|
|
7
|
-
from liger_kernel.ops.utils import ensure_contiguous
|
|
7
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
8
|
+
from liger_kernel.ops.utils import is_hip
|
|
9
|
+
from liger_kernel.utils import infer_device
|
|
8
10
|
|
|
9
11
|
|
|
10
12
|
def get_num_warps(BLOCK_SIZE):
|
|
@@ -23,10 +25,10 @@ MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
|
|
|
23
25
|
|
|
24
26
|
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
|
|
25
27
|
|
|
26
|
-
_REDUCTION_MODE_NONE = tl.constexpr(0)
|
|
27
|
-
_REDUCTION_MODE_SUM = tl.constexpr(1)
|
|
28
|
-
_REDUCTION_MODE_MEAN = tl.constexpr(2)
|
|
29
|
-
_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
|
|
28
|
+
_REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0)
|
|
29
|
+
_REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1)
|
|
30
|
+
_REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2)
|
|
31
|
+
_REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3)
|
|
30
32
|
|
|
31
33
|
_str_to_reduction_mode = {
|
|
32
34
|
"none": _REDUCTION_MODE_NONE.value,
|
|
@@ -114,9 +116,12 @@ def _kldiv_kernel_backward(
|
|
|
114
116
|
|
|
115
117
|
def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
|
|
116
118
|
BT, V = y_pred.shape
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
119
|
+
BLOCK_SIZE = (
|
|
120
|
+
min(8192, triton.next_power_of_2(V))
|
|
121
|
+
if infer_device() == "xpu"
|
|
122
|
+
else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
123
|
+
)
|
|
124
|
+
num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
|
|
120
125
|
|
|
121
126
|
grid = (BT,)
|
|
122
127
|
reduction = _str_to_reduction_mode[reduction]
|
|
@@ -154,9 +159,12 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
|
|
|
154
159
|
|
|
155
160
|
def kldiv_backward_triton(target, grad_output, new_grads, log_target):
|
|
156
161
|
BT, V = target.shape
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
162
|
+
BLOCK_SIZE = (
|
|
163
|
+
min(8192, triton.next_power_of_2(V))
|
|
164
|
+
if infer_device() == "xpu"
|
|
165
|
+
else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
|
166
|
+
)
|
|
167
|
+
num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
|
|
160
168
|
|
|
161
169
|
grid = (BT,)
|
|
162
170
|
|
|
@@ -184,9 +192,9 @@ class LigerKLDivLossFunction(torch.autograd.Function):
|
|
|
184
192
|
Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
|
|
185
193
|
```python
|
|
186
194
|
if log_target:
|
|
187
|
-
loss = target * (target.log() - input)
|
|
188
|
-
else:
|
|
189
195
|
loss = target.exp() * (target - input)
|
|
196
|
+
else:
|
|
197
|
+
loss = target * (target.log() - input)
|
|
190
198
|
```,
|
|
191
199
|
then the loss is reduced according to the `reduction` parameter.
|
|
192
200
|
as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
|
|
@@ -218,9 +226,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
|
|
|
218
226
|
ctx.save_for_backward(y_true)
|
|
219
227
|
ctx.reduction = reduction
|
|
220
228
|
ctx.log_target = log_target
|
|
221
|
-
return kldiv_forward_triton(
|
|
222
|
-
y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps
|
|
223
|
-
)
|
|
229
|
+
return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)
|
|
224
230
|
|
|
225
231
|
@staticmethod
|
|
226
232
|
@ensure_contiguous
|
|
@@ -238,9 +244,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
|
|
|
238
244
|
|
|
239
245
|
new_grads = torch.empty_like(y_true)
|
|
240
246
|
|
|
241
|
-
derivative = kldiv_backward_triton(
|
|
242
|
-
y_true, grad_output, new_grads, ctx.log_target
|
|
243
|
-
)
|
|
247
|
+
derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)
|
|
244
248
|
|
|
245
249
|
if ctx.reduction == "batchmean":
|
|
246
250
|
derivative = derivative / y_true.shape[0]
|
liger_kernel/ops/layer_norm.py
CHANGED
|
@@ -1,15 +1,12 @@
|
|
|
1
|
-
import math
|
|
2
1
|
import operator
|
|
3
2
|
|
|
4
3
|
import torch
|
|
5
4
|
import triton
|
|
6
5
|
import triton.language as tl
|
|
7
6
|
|
|
8
|
-
from liger_kernel.ops.utils import
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
ensure_contiguous,
|
|
12
|
-
)
|
|
7
|
+
from liger_kernel.ops.utils import calculate_settings
|
|
8
|
+
from liger_kernel.ops.utils import compare_version
|
|
9
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
13
10
|
|
|
14
11
|
if compare_version("triton", operator.ge, "3.0.0"):
|
|
15
12
|
try:
|
|
@@ -45,29 +42,44 @@ def _layer_norm_forward_kernel(
|
|
|
45
42
|
https://arxiv.org/abs/1607.06450
|
|
46
43
|
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
|
47
44
|
"""
|
|
48
|
-
row_idx = tl.program_id(0)
|
|
45
|
+
row_idx = tl.program_id(0).to(tl.int64)
|
|
49
46
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
50
47
|
mask = col_offsets < n_cols
|
|
51
48
|
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
49
|
+
# Pre-load weights and bias in fp32 to avoid repeated conversions
|
|
50
|
+
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
|
|
51
|
+
B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0)
|
|
52
|
+
W_f32 = W_row.to(tl.float32)
|
|
53
|
+
B_f32 = B_row.to(tl.float32)
|
|
54
|
+
|
|
55
|
+
# Calculate pointers for this row
|
|
56
|
+
row_X_ptr = X_ptr + row_idx * X_row_stride
|
|
57
|
+
row_Y_ptr = Y_ptr + row_idx * Y_row_stride
|
|
58
|
+
row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride
|
|
59
|
+
row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride
|
|
60
|
+
|
|
61
|
+
# Load input data and convert to fp32 for numerical stability
|
|
62
|
+
X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0)
|
|
63
|
+
X_f32 = X_row.to(tl.float32)
|
|
64
|
+
|
|
65
|
+
# Compute statistics in fp32 for numerical stability
|
|
66
|
+
mean = tl.sum(X_f32, axis=0) / n_cols
|
|
67
|
+
X_centered = X_f32 - mean
|
|
68
|
+
# Apply mask to variance calculation to exclude contributions from masked elements
|
|
69
|
+
X_centered_masked = tl.where(mask, X_centered, 0.0)
|
|
70
|
+
var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
|
|
63
71
|
rstd = rsqrt(var + eps)
|
|
64
72
|
|
|
65
|
-
|
|
66
|
-
tl.store(
|
|
73
|
+
# Store statistics (convert back to original dtype only once)
|
|
74
|
+
tl.store(row_Mean_ptr, mean.to(X_row.dtype))
|
|
75
|
+
tl.store(row_RSTD_ptr, rstd.to(X_row.dtype))
|
|
67
76
|
|
|
68
|
-
|
|
77
|
+
# Fused normalization and affine transformation
|
|
78
|
+
# Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B
|
|
79
|
+
Y_f32 = X_centered * rstd * W_f32 + B_f32
|
|
69
80
|
|
|
70
|
-
|
|
81
|
+
# Store output (single conversion back to original dtype)
|
|
82
|
+
tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask)
|
|
71
83
|
|
|
72
84
|
|
|
73
85
|
@triton.jit
|
|
@@ -82,78 +94,100 @@ def _layer_norm_backward_kernel(
|
|
|
82
94
|
DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
|
|
83
95
|
stride_x, # stride of each row in input
|
|
84
96
|
stride_dx, # stride of each row in input grad
|
|
85
|
-
stride_dw, # stride of each row in weights grad
|
|
86
|
-
stride_db, # stride of each row in bias grad
|
|
87
97
|
stride_dy, # stride of each row in output grad
|
|
88
|
-
n_rows,
|
|
89
98
|
n_cols,
|
|
90
|
-
rows_per_program: tl.constexpr,
|
|
91
99
|
BLOCK_SIZE: tl.constexpr,
|
|
92
100
|
dtype: tl.constexpr,
|
|
101
|
+
atomic_dtype: tl.constexpr,
|
|
93
102
|
):
|
|
94
103
|
"""
|
|
95
104
|
References:
|
|
96
105
|
https://arxiv.org/abs/1607.06450
|
|
97
106
|
https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
|
98
|
-
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
|
99
|
-
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
|
|
100
107
|
"""
|
|
101
|
-
|
|
102
|
-
row_start = row_block_id * rows_per_program
|
|
103
|
-
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
|
108
|
+
row_idx = tl.program_id(0).to(tl.int64)
|
|
104
109
|
cols = tl.arange(0, BLOCK_SIZE)
|
|
105
110
|
mask = cols < n_cols
|
|
106
111
|
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
DX_ptr
|
|
114
|
-
DY_ptr
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
112
|
+
# Pre-load weights once (same optimization as forward pass)
|
|
113
|
+
w = tl.load(W_ptr + cols, mask=mask, other=0.0)
|
|
114
|
+
w_f32 = w.to(tl.float32)
|
|
115
|
+
|
|
116
|
+
# Calculate pointers for this specific row
|
|
117
|
+
row_X_ptr = X_ptr + row_idx * stride_x
|
|
118
|
+
row_DX_ptr = DX_ptr + row_idx * stride_dx
|
|
119
|
+
row_DY_ptr = DY_ptr + row_idx * stride_dy
|
|
120
|
+
row_Mean_ptr = Mean_ptr + row_idx
|
|
121
|
+
row_RSTD_ptr = RSTD_ptr + row_idx
|
|
122
|
+
|
|
123
|
+
# Load data for this row
|
|
124
|
+
x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
|
|
125
|
+
dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
|
|
126
|
+
mean = tl.load(row_Mean_ptr)
|
|
127
|
+
rstd = tl.load(row_RSTD_ptr)
|
|
128
|
+
|
|
129
|
+
# Convert to fp32 for numerical stability
|
|
130
|
+
x_f32 = x.to(tl.float32)
|
|
131
|
+
dy_f32 = dy.to(tl.float32)
|
|
132
|
+
mean_f32 = mean.to(tl.float32)
|
|
133
|
+
rstd_f32 = rstd.to(tl.float32)
|
|
134
|
+
|
|
135
|
+
# Compute backward pass for this row
|
|
136
|
+
x_hat = (x_f32 - mean_f32) * rstd_f32
|
|
137
|
+
wdy = w_f32 * dy_f32
|
|
138
|
+
c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
|
|
139
|
+
c2 = tl.sum(wdy, axis=0) / n_cols
|
|
140
|
+
dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
|
|
141
|
+
|
|
142
|
+
# Store input gradient
|
|
143
|
+
tl.store(row_DX_ptr + cols, dx.to(dtype), mask=mask)
|
|
144
|
+
|
|
145
|
+
# Accumulate weight and bias gradients using atomic operations
|
|
146
|
+
dw = dy_f32 * x_hat
|
|
147
|
+
db = dy_f32
|
|
148
|
+
tl.atomic_add(DW_ptr + cols, dw.to(atomic_dtype), mask=mask)
|
|
149
|
+
tl.atomic_add(DB_ptr + cols, db.to(atomic_dtype), mask=mask)
|
|
141
150
|
|
|
142
151
|
|
|
143
152
|
def layer_norm_forward(X, W, B, eps):
|
|
153
|
+
"""
|
|
154
|
+
Args:
|
|
155
|
+
X: Input tensor of shape (..., hidden_size)
|
|
156
|
+
W: Weight tensor of shape (hidden_size,)
|
|
157
|
+
B: Bias tensor of shape (hidden_size,)
|
|
158
|
+
eps: Small constant for numerical stability
|
|
159
|
+
|
|
160
|
+
Returns:
|
|
161
|
+
Tuple of (output, input, mean, rstd, block_size, num_warps)
|
|
162
|
+
"""
|
|
144
163
|
shape = X.shape
|
|
145
164
|
dim = shape[-1]
|
|
146
165
|
X = X.view(-1, dim)
|
|
147
166
|
n_rows, n_cols = X.shape
|
|
167
|
+
|
|
168
|
+
# Calculate optimal block size and warp configuration
|
|
148
169
|
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
|
170
|
+
|
|
171
|
+
# Allocate output tensors
|
|
149
172
|
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
|
150
173
|
Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
|
151
174
|
RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
|
|
152
|
-
assert (
|
|
153
|
-
X.shape[1] == W.shape[0]
|
|
154
|
-
), f"Incompatible hidden size dimension between input tensor with shape[1] = {X.shape[1]} and weight tensor with shape[0] = {W.shape[0]}"
|
|
155
175
|
|
|
156
|
-
|
|
176
|
+
# Validate input dimensions
|
|
177
|
+
if X.shape[1] != W.shape[0]:
|
|
178
|
+
raise ValueError(
|
|
179
|
+
f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
|
|
180
|
+
f"must match weight size (W.shape[0]={W.shape[0]})"
|
|
181
|
+
)
|
|
182
|
+
|
|
183
|
+
# XPU-specific optimization
|
|
184
|
+
kernel_args = {}
|
|
185
|
+
if X.device.type == "xpu":
|
|
186
|
+
kernel_args["grf_mode"] = "large"
|
|
187
|
+
|
|
188
|
+
# Launch kernel with one thread block per row for optimal performance
|
|
189
|
+
grid = (n_rows,)
|
|
190
|
+
_layer_norm_forward_kernel[grid](
|
|
157
191
|
Y,
|
|
158
192
|
Y.stride(0),
|
|
159
193
|
X,
|
|
@@ -170,54 +204,84 @@ def layer_norm_forward(X, W, B, eps):
|
|
|
170
204
|
eps,
|
|
171
205
|
BLOCK_SIZE=BLOCK_SIZE,
|
|
172
206
|
num_warps=num_warps,
|
|
207
|
+
**kernel_args,
|
|
173
208
|
)
|
|
209
|
+
|
|
174
210
|
return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
|
|
175
211
|
|
|
176
212
|
|
|
177
213
|
def layer_norm_backward(dY, X, W, B, Mean, RSTD):
|
|
214
|
+
"""
|
|
215
|
+
Args:
|
|
216
|
+
dY: Gradient of output
|
|
217
|
+
X: Input tensor
|
|
218
|
+
W: Weight tensor
|
|
219
|
+
B: Bias tensor
|
|
220
|
+
Mean: Pre-computed mean
|
|
221
|
+
RSTD: Pre-computed reciprocal standard deviation
|
|
222
|
+
|
|
223
|
+
Returns:
|
|
224
|
+
Tuple of (input_grad, weight_grad, bias_grad)
|
|
225
|
+
"""
|
|
178
226
|
shape = dY.shape
|
|
179
227
|
dim = shape[-1]
|
|
180
228
|
dY = dY.view(-1, dim)
|
|
181
229
|
n_rows, n_cols = dY.shape
|
|
182
230
|
|
|
231
|
+
# Allocate gradient tensors
|
|
183
232
|
DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
233
|
+
# Use float32 for weight/bias gradients if bfloat16 (due to atomic_add limitation)
|
|
234
|
+
grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype
|
|
235
|
+
DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
|
|
236
|
+
DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
|
|
187
237
|
|
|
238
|
+
# Calculate optimal block size and warp configuration
|
|
188
239
|
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
|
189
240
|
if n_cols > BLOCK_SIZE:
|
|
190
|
-
raise RuntimeError("
|
|
241
|
+
raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
|
|
242
|
+
|
|
243
|
+
# Determine dtype for triton operations
|
|
244
|
+
triton_dtype = (
|
|
245
|
+
tl.float32
|
|
246
|
+
if X.dtype == torch.float32
|
|
247
|
+
else tl.bfloat16
|
|
248
|
+
if X.dtype == torch.bfloat16
|
|
249
|
+
else tl.float16
|
|
250
|
+
if X.dtype == torch.float16
|
|
251
|
+
else tl.float32 # fallback
|
|
252
|
+
)
|
|
253
|
+
|
|
254
|
+
# Use float32 for atomic operations if bfloat16 is not supported
|
|
255
|
+
atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
|
|
191
256
|
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
257
|
+
kernel_args = {"num_warps": num_warps}
|
|
258
|
+
# XPU-specific optimization
|
|
259
|
+
if X.device.type == "xpu":
|
|
260
|
+
kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
|
|
261
|
+
|
|
262
|
+
# Launch kernel with one thread block per row for optimal performance
|
|
263
|
+
grid = (n_rows,)
|
|
195
264
|
_layer_norm_backward_kernel[grid](
|
|
196
265
|
X,
|
|
197
266
|
W,
|
|
198
267
|
Mean,
|
|
199
268
|
RSTD,
|
|
200
269
|
DX,
|
|
201
|
-
|
|
202
|
-
|
|
270
|
+
DW,
|
|
271
|
+
DB,
|
|
203
272
|
dY,
|
|
204
273
|
X.stride(0),
|
|
205
274
|
DX.stride(0),
|
|
206
|
-
_DW.stride(0),
|
|
207
|
-
_DB.stride(0),
|
|
208
275
|
dY.stride(0),
|
|
209
|
-
n_rows,
|
|
210
276
|
n_cols,
|
|
211
|
-
rows_per_program,
|
|
212
277
|
BLOCK_SIZE=BLOCK_SIZE,
|
|
213
278
|
dtype=triton_dtype,
|
|
279
|
+
atomic_dtype=atomic_dtype,
|
|
280
|
+
**kernel_args,
|
|
214
281
|
)
|
|
215
282
|
|
|
216
|
-
DW = _DW.sum(dim=0).to(W.dtype)
|
|
217
|
-
DB = _DB.sum(dim=0).to(W.dtype)
|
|
218
|
-
|
|
219
283
|
DX = DX.view(*shape)
|
|
220
|
-
return DX, DW, DB
|
|
284
|
+
return DX, DW.to(W.dtype), DB.to(W.dtype)
|
|
221
285
|
|
|
222
286
|
|
|
223
287
|
class LigerLayerNormFunction(torch.autograd.Function):
|