liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +307 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +63 -0
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +383 -114
- liger_kernel/ops/dyt.py +160 -0
- liger_kernel/ops/experimental/embedding.py +141 -0
- liger_kernel/ops/experimental/mm_int8int2.py +349 -0
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
- liger_kernel/ops/fused_linear_jsd.py +228 -0
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +66 -64
- liger_kernel/ops/group_norm.py +306 -0
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +201 -0
- liger_kernel/ops/kl_div.py +262 -0
- liger_kernel/ops/layer_norm.py +320 -0
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +484 -88
- liger_kernel/ops/rope.py +122 -117
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +68 -65
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +82 -3
- liger_kernel/transformers/__init__.py +218 -6
- liger_kernel/transformers/auto_model.py +38 -0
- liger_kernel/transformers/cross_entropy.py +52 -7
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +26 -0
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +301 -0
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
- liger_kernel/transformers/fused_linear_jsd.py +95 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +6 -7
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +70 -0
- liger_kernel/transformers/kl_div.py +12 -0
- liger_kernel/transformers/layer_norm.py +24 -0
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +261 -0
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +332 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +221 -41
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +145 -0
- liger_kernel/transformers/model/mixtral.py +293 -0
- liger_kernel/transformers/model/mllama.py +269 -0
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +433 -0
- liger_kernel/transformers/model/phi3.py +120 -0
- liger_kernel/transformers/model/qwen2.py +259 -0
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +159 -0
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2816 -21
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +75 -5
- liger_kernel/transformers/rope.py +47 -3
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +62 -6
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/trainer_integration.py +2 -45
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -5
- liger_kernel/utils.py +96 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
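For orientation only: the kernels listed above are normally consumed through the monkey-patch entry points in `liger_kernel/transformers/monkey_patch.py`. The sketch below is a minimal, hedged example of that documented pattern; the checkpoint name, dtype, and device settings are placeholder assumptions and are not part of this diff.

```python
# Minimal sketch of the documented monkey-patch usage (assumes `transformers`
# is installed and a Triton-capable GPU is available). The checkpoint name is
# a placeholder, not something this diff prescribes.
import torch
import transformers

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch Hugging Face Llama modules (RMSNorm, RoPE, SwiGLU, fused linear
# cross-entropy, ...) with the Liger Triton kernels before building the model.
apply_liger_kernel_to_llama()

model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",  # placeholder checkpoint
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
```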
liger_kernel/ops/kl_div.py (new file)
@@ -0,0 +1,262 @@
+from typing import Literal
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
+
+
+def get_num_warps(BLOCK_SIZE):
+    num_warps = 4
+    if BLOCK_SIZE >= 32768:
+        num_warps = 32 if not is_hip() else 16
+    elif BLOCK_SIZE >= 8192:
+        num_warps = 16
+    elif BLOCK_SIZE >= 2048:
+        num_warps = 8
+
+    return num_warps
+
+
+MAX_FUSED_SIZE = 65536 // 4  # 65536 // 4 or 8 works the best
+
+REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
+
+_REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0)
+_REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1)
+_REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2)
+_REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3)
+
+_str_to_reduction_mode = {
+    "none": _REDUCTION_MODE_NONE.value,
+    "sum": _REDUCTION_MODE_SUM.value,
+    "mean": _REDUCTION_MODE_MEAN.value,
+    "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
+}
+
+
+@triton.jit
+def _kldiv_kernel_forward(
+    y_ptr,  # [B, S], prediction ptr, the kernel expects the prediction in log-space
+    y_stride,  # int, prediction stride
+    gt_ptr,  # [B, S], ground truth ptr
+    gt_stride,  # int, ground truth stride
+    loss_ptr,  # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr
+    loss_stride,  # int, output stride
+    n_cols,  # int, number of columns in the input tensor
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+    log_target: tl.constexpr = False,
+    reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
+):
+    pid = tl.program_id(0).to(tl.int64)
+    y_ptr += pid * y_stride
+    gt_ptr += pid * gt_stride
+    loss_ptr += pid * loss_stride
+
+    base_offsets = tl.arange(0, BLOCK_SIZE)
+
+    loss_sum = 0.0
+    for i in range(0, n_cols, BLOCK_SIZE):
+        offsets = i + base_offsets
+        mask = offsets < n_cols
+        y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
+        y_true = tl.load(gt_ptr + offsets, mask=mask, other=0.0)
+
+        # KL(y_true || y) = y_true * (log(y_true) - log(y))
+        # We compute KL(y_true || y) with y in the log-space
+        if not log_target:
+            loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y)
+        else:
+            loss = tl.exp(y_true) * (y_true - y)
+
+        if reduction == _REDUCTION_MODE_NONE:
+            tl.store(loss_ptr + offsets, loss, mask=mask)
+        else:
+            loss_sum += tl.sum(loss, axis=0)
+
+    if reduction != _REDUCTION_MODE_NONE:
+        tl.store(loss_ptr, loss_sum)
+
+
+@triton.jit
+def _kldiv_kernel_backward(
+    target_ptr,
+    target_stride,
+    new_grads_ptr,
+    new_grads_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+    log_target: tl.constexpr = False,
+):
+    pid = tl.program_id(0).to(tl.int64)
+
+    target_ptr += pid * target_stride
+    new_grads_ptr += pid * new_grads_stride
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_cols
+
+    for i in range(0, n_cols, BLOCK_SIZE):
+        offsets = i + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_cols
+
+        target = tl.load(target_ptr + offsets, mask=mask, other=0.0)
+
+        if not log_target:
+            res = target * -1
+        else:
+            res = -tl.exp(target)
+
+        tl.store(new_grads_ptr + offsets, res, mask=mask)
+
+
+def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
+    BT, V = y_pred.shape
+    BLOCK_SIZE = (
+        min(8192, triton.next_power_of_2(V))
+        if infer_device() == "xpu"
+        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    )
+    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
+
+    grid = (BT,)
+    reduction = _str_to_reduction_mode[reduction]
+
+    out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
+    output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32)
+
+    _kldiv_kernel_forward[grid](
+        y_pred,
+        y_pred.stride(0),
+        y_true,
+        y_true.stride(0),
+        output_tensor,
+        output_tensor.stride(0),
+        V,
+        eps=eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+        log_target=log_target,
+        reduction=reduction,
+    )
+
+    # calculated according to the reduction mode same as in Pytorch. In the later versions, `mean` will be changed to the same behavior as `batchmean`
+    # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
+    # https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372
+    if reduction == _REDUCTION_MODE_BATCHMEAN.value:
+        return output_tensor.sum() / BT
+    elif reduction == _REDUCTION_MODE_SUM.value:
+        return output_tensor.sum(dim=0)
+    elif reduction == _REDUCTION_MODE_MEAN.value:
+        return output_tensor.sum() / (BT * V)
+    else:
+        return output_tensor
+
+
+def kldiv_backward_triton(target, grad_output, new_grads, log_target):
+    BT, V = target.shape
+    BLOCK_SIZE = (
+        min(8192, triton.next_power_of_2(V))
+        if infer_device() == "xpu"
+        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    )
+    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
+
+    grid = (BT,)
+
+    # We store the gradients in-place in the input tensor
+    _kldiv_kernel_backward[grid](
+        target,
+        target.stride(0),
+        new_grads,
+        new_grads.stride(0),
+        V,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+        log_target=log_target,
+    )
+
+    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then.
+    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+        return new_grads
+
+    return new_grads * grad_output
+
+
+class LigerKLDivLossFunction(torch.autograd.Function):
+    """
+    Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
+    ```python
+    if log_target:
+        loss = target.exp() * (target - input)
+    else:
+        loss = target * (target.log() - input)
+    ```,
+    then the loss is reduced according to the `reduction` parameter.
+    as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
+    """
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(
+        ctx,
+        y_pred: torch.Tensor,
+        y_true: torch.Tensor,
+        reduction: REDUCTION_LITERAL = "batchmean",
+        log_target: bool = False,
+        eps: float = 1e-10,
+    ) -> torch.Tensor:
+        """A forward pass for the KL Divergence Loss.
+
+        Args:
+            ctx: Torch autograd context
+            y_pred (torch.Tensor): A tensor of shape (BT, V) containing the predicted values, expected to be log-probabilities.
+            y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`.
+            reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean".
+            log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False.
+            eps: (float, optional): A small value to avoid division by zero. Defaults to 1e-10.
+
+        Returns:
+            torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar.
+        """
+        ctx.save_for_backward(y_true)
+        ctx.reduction = reduction
+        ctx.log_target = log_target
+        return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+        """A backward pass for the KL Divergence Loss.
+
+        Args:
+            ctx: Torch autograd context
+            grad_output (torch.Tensor): The gradient of the loss with respect to the output.
+
+        Returns:
+            tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
+        """
+        (y_true,) = ctx.saved_tensors
+
+        new_grads = torch.empty_like(y_true)
+
+        derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)
+
+        if ctx.reduction == "batchmean":
+            derivative = derivative / y_true.shape[0]
+        elif ctx.reduction == "sum" or ctx.reduction == "none":
+            pass
+        elif ctx.reduction == "mean":
+            derivative = derivative / (y_true.shape[0] * y_true.shape[1])
+
+        return (
+            derivative,
+            None,
+            None,
+            None,
+            None,
+        )
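A minimal usage sketch for the `LigerKLDivLossFunction` added above; the shapes, device, and epsilon are illustrative assumptions and not part of the package diff.

```python
# Minimal sketch (assumes a Triton-capable CUDA device); shapes are illustrative.
import torch
import torch.nn.functional as F

from liger_kernel.ops.kl_div import LigerKLDivLossFunction

BT, V = 8, 4096
y_pred = F.log_softmax(torch.randn(BT, V, device="cuda", requires_grad=True), dim=-1)
y_true = torch.softmax(torch.randn(BT, V, device="cuda"), dim=-1)

# Same semantics as torch.nn.KLDivLoss(reduction="batchmean", log_target=False)
loss = LigerKLDivLossFunction.apply(y_pred, y_true, "batchmean", False, 1e-10)
loss.backward()  # exercises kldiv_backward_triton
```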
liger_kernel/ops/layer_norm.py (new file)
@@ -0,0 +1,320 @@
+import math
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available
+
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import rsqrt
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import rsqrt
+else:
+    from triton.language.math import rsqrt
+
+
+@triton.jit
+def _layer_norm_forward_kernel(
+    Y_ptr,  # pointer to output, shape (n_rows, n_cols)
+    Y_row_stride,  # stride of each row in output
+    X_ptr,  # pointer to input, shape (n_rows, n_cols)
+    X_row_stride,  # stride of each row in input
+    W_ptr,  # pointer to weights, shape (n_cols,)
+    W_row_stride,  # stride of each row in weights
+    B_ptr,  # pointer to bias, shape (n_cols,)
+    B_row_stride,  # stride of each row in bias
+    Mean_ptr,  # pointer to mean, shape (n_rows,)
+    Mean_row_stride,  # stride of each row in mean
+    RSTD_ptr,  # pointer to rstd, shape (n_rows,)
+    RSTD_row_stride,  # stride of each row in rstd
+    n_cols,
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    References:
+    https://arxiv.org/abs/1607.06450
+    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
+    """
+    row_idx = tl.program_id(0).to(tl.int64)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    # Pre-load weights and bias in fp32 to avoid repeated conversions
+    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+    B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0)
+    W_f32 = W_row.to(tl.float32)
+    B_f32 = B_row.to(tl.float32)
+
+    # Calculate pointers for this row
+    row_X_ptr = X_ptr + row_idx * X_row_stride
+    row_Y_ptr = Y_ptr + row_idx * Y_row_stride
+    row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride
+    row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride
+
+    # Load input data and convert to fp32 for numerical stability
+    X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0)
+    X_f32 = X_row.to(tl.float32)
+
+    # Compute statistics in fp32 for numerical stability
+    mean = tl.sum(X_f32, axis=0) / n_cols
+    X_centered = X_f32 - mean
+    # Apply mask to variance calculation to exclude contributions from masked elements
+    X_centered_masked = tl.where(mask, X_centered, 0.0)
+    var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
+    rstd = rsqrt(var + eps)
+
+    # Store statistics (convert back to original dtype only once)
+    tl.store(row_Mean_ptr, mean.to(X_row.dtype))
+    tl.store(row_RSTD_ptr, rstd.to(X_row.dtype))
+
+    # Fused normalization and affine transformation
+    # Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B
+    Y_f32 = X_centered * rstd * W_f32 + B_f32
+
+    # Store output (single conversion back to original dtype)
+    tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask)
+
+
+@triton.jit
+def _layer_norm_backward_kernel(
+    X_ptr,  # pointer to input, shape (n_rows, n_cols)
+    stride_x,  # stride of each row in input
+    W_ptr,  # pointer to weights, shape (n_cols,)
+    Mean_ptr,  # pointer to mean, shape (n_rows,)
+    stride_mean,  # stride of each row in mean
+    RSTD_ptr,  # pointer to rstd, shape (n_rows,)
+    stride_rstd,  # stride of each row in rstd
+    DX_ptr,  # pointer to input grad, shape (n_rows, n_cols)
+    stride_dx,  # stride of each row in input grad
+    DW_ptr,  # pointer to weights grad, shape (n_cols,)
+    stride_dw,  # stride of each row in weights grad
+    DB_ptr,  # pointer to bias grad, shape (n_cols,)
+    stride_db,  # stride of each row in bias grad
+    DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
+    stride_dy,  # stride of each row in output grad
+    n_rows,
+    n_cols,
+    rows_per_program: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    References:
+    https://arxiv.org/abs/1607.06450
+    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
+    """
+    row_block_id = tl.program_id(0).to(tl.int64)
+    row_start = row_block_id * rows_per_program
+    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+    cols = tl.arange(0, BLOCK_SIZE)
+    mask = cols < n_cols
+
+    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+    db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+    # Pre-load weights once (same optimization as forward pass)
+    w = tl.load(W_ptr + cols, mask=mask, other=0.0)
+    w_f32 = w.to(tl.float32)
+
+    # Calculate pointers for this specific row
+    row_X_ptr = X_ptr + row_start * stride_x
+    row_DX_ptr = DX_ptr + row_start * stride_dx
+    row_DY_ptr = DY_ptr + row_start * stride_dy
+    row_Mean_ptr = Mean_ptr + row_start
+    row_RSTD_ptr = RSTD_ptr + row_start
+
+    for _ in range(row_start, row_end):
+        # Load data for this row
+        x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+        dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+        mean = tl.load(row_Mean_ptr)
+        rstd = tl.load(row_RSTD_ptr)
+
+        # Convert to fp32 for numerical stability
+        x_f32 = x.to(tl.float32)
+        dy_f32 = dy.to(tl.float32)
+        mean_f32 = mean.to(tl.float32)
+        rstd_f32 = rstd.to(tl.float32)
+
+        # Compute backward pass for this row
+        x_hat = (x_f32 - mean_f32) * rstd_f32
+        wdy = w_f32 * dy_f32
+        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
+        c2 = tl.sum(wdy, axis=0) / n_cols
+        dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+        # Store input gradient
+        tl.store(row_DX_ptr + cols, dx, mask=mask)
+
+        # Accumulate weight and bias gradients for this thread block's assigned rows
+        dw = dy_f32 * x_hat
+        db = dy_f32
+        dW_row += dw
+        db_row += db
+
+        row_X_ptr += stride_x
+        row_DX_ptr += stride_dx
+        row_DY_ptr += stride_dy
+        row_Mean_ptr += stride_mean
+        row_RSTD_ptr += stride_rstd
+
+    tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
+    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
+
+
+def layer_norm_forward(X, W, B, eps):
+    """
+    Args:
+        X: Input tensor of shape (..., hidden_size)
+        W: Weight tensor of shape (hidden_size,)
+        B: Bias tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        Tuple of (output, input, mean, rstd, block_size, num_warps)
+    """
+    shape = X.shape
+    dim = shape[-1]
+    X = X.view(-1, dim)
+    n_rows, n_cols = X.shape
+
+    # Calculate optimal block size and warp configuration
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    # Allocate output tensors
+    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+    Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+    RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+
+    # Validate input dimensions
+    if X.shape[1] != W.shape[0]:
+        raise ValueError(
+            f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
+            f"must match weight size (W.shape[0]={W.shape[0]})"
+        )
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
+    # Launch kernel with one thread block per row for optimal performance
+    grid = (n_rows,)
+    _layer_norm_forward_kernel[grid](
+        Y,
+        Y.stride(0),
+        X,
+        X.stride(0),
+        W,
+        W.stride(0),
+        B,
+        B.stride(0),
+        Mean,
+        Mean.stride(0),
+        RSTD,
+        RSTD.stride(0),
+        n_cols,
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+        **kernel_args,
+    )
+
+    return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
+
+
+def layer_norm_backward(dY, X, W, B, Mean, RSTD):
+    """
+    Args:
+        dY: Gradient of output
+        X: Input tensor
+        W: Weight tensor
+        B: Bias tensor
+        Mean: Pre-computed mean
+        RSTD: Pre-computed reciprocal standard deviation
+
+    Returns:
+        Tuple of (input_grad, weight_grad, bias_grad)
+    """
+    shape = dY.shape
+    dim = shape[-1]
+    dY = dY.view(-1, dim)
+    n_rows, n_cols = dY.shape
+
+    sm_count = 1
+    if X.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    elif X.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+
+    # fp32 for numerical stability especially.
+    _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+    _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+
+    # Calculate optimal block size and warp configuration
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    if n_cols > BLOCK_SIZE:
+        raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
+    rows_per_program = math.ceil(n_rows / sm_count)
+    grid = (sm_count,)
+
+    # Allocate gradient tensors
+    DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+
+    kernel_args = {"num_warps": num_warps}
+    # XPU-specific optimization
+    if X.device.type == "xpu":
+        kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
+
+    # Launch kernel with one thread block per row for optimal performance
+    _layer_norm_backward_kernel[grid](
+        X,
+        X.stride(0),
+        W,
+        Mean,
+        Mean.stride(0),
+        RSTD,
+        RSTD.stride(0),
+        DX,
+        DX.stride(0),
+        _DW,
+        _DW.stride(0),
+        _DB,
+        _DB.stride(0),
+        dY,
+        dY.stride(0),
+        n_rows,
+        n_cols,
+        rows_per_program=rows_per_program,
+        BLOCK_SIZE=BLOCK_SIZE,
+        **kernel_args,
+    )
+
+    DX = DX.view(*shape)
+    DW = _DW.sum(dim=0).to(W.dtype)
+    DB = _DB.sum(dim=0).to(B.dtype)
+    return DX, DW, DB
+
+
+class LigerLayerNormFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, X, W, B, eps):
+        Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps)
+        ctx.save_for_backward(X, W, B, Mean, RSTD)
+        return Y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, dY):
+        X, W, B, Mean, RSTD = ctx.saved_tensors
+        DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)
+        return DX, DW, DB, None