liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +307 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +63 -0
  18. liger_kernel/ops/__init__.py +141 -0
  19. liger_kernel/ops/backends/README.md +151 -0
  20. liger_kernel/ops/backends/__init__.py +13 -0
  21. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  22. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  23. liger_kernel/ops/backends/registry.py +61 -0
  24. liger_kernel/ops/cross_entropy.py +383 -114
  25. liger_kernel/ops/dyt.py +160 -0
  26. liger_kernel/ops/experimental/embedding.py +141 -0
  27. liger_kernel/ops/experimental/mm_int8int2.py +349 -0
  28. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  29. liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
  30. liger_kernel/ops/fused_linear_jsd.py +228 -0
  31. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  32. liger_kernel/ops/geglu.py +66 -64
  33. liger_kernel/ops/group_norm.py +306 -0
  34. liger_kernel/ops/grpo_loss.py +312 -0
  35. liger_kernel/ops/jsd.py +201 -0
  36. liger_kernel/ops/kl_div.py +262 -0
  37. liger_kernel/ops/layer_norm.py +320 -0
  38. liger_kernel/ops/llama4_rope.py +225 -0
  39. liger_kernel/ops/multi_token_attention.py +207 -0
  40. liger_kernel/ops/poly_norm.py +390 -0
  41. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  42. liger_kernel/ops/rms_norm.py +484 -88
  43. liger_kernel/ops/rope.py +122 -117
  44. liger_kernel/ops/softmax.py +201 -0
  45. liger_kernel/ops/sparsemax.py +179 -0
  46. liger_kernel/ops/swiglu.py +68 -65
  47. liger_kernel/ops/tiled_mlp.py +136 -0
  48. liger_kernel/ops/tvd.py +207 -0
  49. liger_kernel/ops/utils.py +82 -3
  50. liger_kernel/transformers/__init__.py +218 -6
  51. liger_kernel/transformers/auto_model.py +38 -0
  52. liger_kernel/transformers/cross_entropy.py +52 -7
  53. liger_kernel/transformers/dyt.py +22 -0
  54. liger_kernel/transformers/experimental/__init__.py +5 -0
  55. liger_kernel/transformers/experimental/embedding.py +26 -0
  56. liger_kernel/transformers/fsdp.py +55 -0
  57. liger_kernel/transformers/functional.py +301 -0
  58. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  59. liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
  60. liger_kernel/transformers/fused_linear_jsd.py +95 -0
  61. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  62. liger_kernel/transformers/geglu.py +6 -7
  63. liger_kernel/transformers/group_norm.py +50 -0
  64. liger_kernel/transformers/grpo_loss.py +153 -0
  65. liger_kernel/transformers/jsd.py +70 -0
  66. liger_kernel/transformers/kl_div.py +12 -0
  67. liger_kernel/transformers/layer_norm.py +24 -0
  68. liger_kernel/transformers/llama4_rope.py +93 -0
  69. liger_kernel/transformers/model/falcon_h1.py +122 -0
  70. liger_kernel/transformers/model/gemma.py +261 -0
  71. liger_kernel/transformers/model/gemma2.py +283 -0
  72. liger_kernel/transformers/model/gemma3.py +332 -0
  73. liger_kernel/transformers/model/glm4.py +141 -0
  74. liger_kernel/transformers/model/glm4v.py +163 -0
  75. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  76. liger_kernel/transformers/model/gpt_oss.py +211 -0
  77. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  78. liger_kernel/transformers/model/internvl.py +157 -0
  79. liger_kernel/transformers/model/llama.py +221 -41
  80. liger_kernel/transformers/model/llama4.py +121 -0
  81. liger_kernel/transformers/model/llava.py +344 -0
  82. liger_kernel/transformers/model/loss_utils.py +95 -0
  83. liger_kernel/transformers/model/mistral.py +145 -0
  84. liger_kernel/transformers/model/mixtral.py +293 -0
  85. liger_kernel/transformers/model/mllama.py +269 -0
  86. liger_kernel/transformers/model/olmo2.py +141 -0
  87. liger_kernel/transformers/model/olmo3.py +142 -0
  88. liger_kernel/transformers/model/output_classes.py +147 -0
  89. liger_kernel/transformers/model/paligemma.py +433 -0
  90. liger_kernel/transformers/model/phi3.py +120 -0
  91. liger_kernel/transformers/model/qwen2.py +259 -0
  92. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  93. liger_kernel/transformers/model/qwen2_vl.py +159 -0
  94. liger_kernel/transformers/model/qwen3.py +136 -0
  95. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  96. liger_kernel/transformers/model/qwen3_next.py +146 -0
  97. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  98. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  99. liger_kernel/transformers/model/smollm3.py +199 -0
  100. liger_kernel/transformers/model/smolvlm.py +158 -0
  101. liger_kernel/transformers/monkey_patch.py +2816 -21
  102. liger_kernel/transformers/multi_token_attention.py +64 -0
  103. liger_kernel/transformers/poly_norm.py +42 -0
  104. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  105. liger_kernel/transformers/rms_norm.py +75 -5
  106. liger_kernel/transformers/rope.py +47 -3
  107. liger_kernel/transformers/softmax.py +12 -0
  108. liger_kernel/transformers/sparsemax.py +16 -0
  109. liger_kernel/transformers/swiglu.py +62 -6
  110. liger_kernel/transformers/tiled_mlp.py +133 -0
  111. liger_kernel/transformers/trainer/__init__.py +4 -0
  112. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  113. liger_kernel/transformers/trainer_integration.py +2 -45
  114. liger_kernel/transformers/tvd.py +13 -0
  115. liger_kernel/triton/__init__.py +1 -3
  116. liger_kernel/triton/monkey_patch.py +1 -5
  117. liger_kernel/utils.py +96 -0
  118. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
  119. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
  120. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  121. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
  122. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
  123. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
  124. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
  125. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  126. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,262 @@
+ from typing import Literal
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import is_hip
+ from liger_kernel.utils import infer_device
+
+
+ def get_num_warps(BLOCK_SIZE):
+     num_warps = 4
+     if BLOCK_SIZE >= 32768:
+         num_warps = 32 if not is_hip() else 16
+     elif BLOCK_SIZE >= 8192:
+         num_warps = 16
+     elif BLOCK_SIZE >= 2048:
+         num_warps = 8
+
+     return num_warps
+
+
+ MAX_FUSED_SIZE = 65536 // 4  # 65536 // 4 or 8 works the best
+
+ REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
+
+ _REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0)
+ _REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1)
+ _REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2)
+ _REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3)
+
+ _str_to_reduction_mode = {
+     "none": _REDUCTION_MODE_NONE.value,
+     "sum": _REDUCTION_MODE_SUM.value,
+     "mean": _REDUCTION_MODE_MEAN.value,
+     "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
+ }
+
+
+ @triton.jit
+ def _kldiv_kernel_forward(
+     y_ptr,  # [B, S], prediction ptr, the kernel expects the prediction in log-space
+     y_stride,  # int, prediction stride
+     gt_ptr,  # [B, S], ground truth ptr
+     gt_stride,  # int, ground truth stride
+     loss_ptr,  # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr
+     loss_stride,  # int, output stride
+     n_cols,  # int, number of columns in the input tensor
+     eps,
+     BLOCK_SIZE: tl.constexpr,
+     log_target: tl.constexpr = False,
+     reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
+ ):
+     pid = tl.program_id(0).to(tl.int64)
+     y_ptr += pid * y_stride
+     gt_ptr += pid * gt_stride
+     loss_ptr += pid * loss_stride
+
+     base_offsets = tl.arange(0, BLOCK_SIZE)
+
+     loss_sum = 0.0
+     for i in range(0, n_cols, BLOCK_SIZE):
+         offsets = i + base_offsets
+         mask = offsets < n_cols
+         y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
+         y_true = tl.load(gt_ptr + offsets, mask=mask, other=0.0)
+
+         # KL(y_true || y) = y_true * (log(y_true) - log(y))
+         # We compute KL(y_true || y) with y in the log-space
+         if not log_target:
+             loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y)
+         else:
+             loss = tl.exp(y_true) * (y_true - y)
+
+         if reduction == _REDUCTION_MODE_NONE:
+             tl.store(loss_ptr + offsets, loss, mask=mask)
+         else:
+             loss_sum += tl.sum(loss, axis=0)
+
+     if reduction != _REDUCTION_MODE_NONE:
+         tl.store(loss_ptr, loss_sum)
+
+
+ @triton.jit
+ def _kldiv_kernel_backward(
+     target_ptr,
+     target_stride,
+     new_grads_ptr,
+     new_grads_stride,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+     log_target: tl.constexpr = False,
+ ):
+     pid = tl.program_id(0).to(tl.int64)
+
+     target_ptr += pid * target_stride
+     new_grads_ptr += pid * new_grads_stride
+
+     offsets = tl.arange(0, BLOCK_SIZE)
+     mask = offsets < n_cols
+
+     for i in range(0, n_cols, BLOCK_SIZE):
+         offsets = i + tl.arange(0, BLOCK_SIZE)
+         mask = offsets < n_cols
+
+         target = tl.load(target_ptr + offsets, mask=mask, other=0.0)
+
+         if not log_target:
+             res = target * -1
+         else:
+             res = -tl.exp(target)
+
+         tl.store(new_grads_ptr + offsets, res, mask=mask)
+
+
+ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
+     BT, V = y_pred.shape
+     BLOCK_SIZE = (
+         min(8192, triton.next_power_of_2(V))
+         if infer_device() == "xpu"
+         else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+     )
+     num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
+
+     grid = (BT,)
+     reduction = _str_to_reduction_mode[reduction]
+
+     out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
+     output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32)
+
+     _kldiv_kernel_forward[grid](
+         y_pred,
+         y_pred.stride(0),
+         y_true,
+         y_true.stride(0),
+         output_tensor,
+         output_tensor.stride(0),
+         V,
+         eps=eps,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+         log_target=log_target,
+         reduction=reduction,
+     )
+
+     # Reduce according to the reduction mode, matching PyTorch. In later PyTorch versions, `mean` will change to behave like `batchmean`.
+     # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
+     # https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372
+     if reduction == _REDUCTION_MODE_BATCHMEAN.value:
+         return output_tensor.sum() / BT
+     elif reduction == _REDUCTION_MODE_SUM.value:
+         return output_tensor.sum(dim=0)
+     elif reduction == _REDUCTION_MODE_MEAN.value:
+         return output_tensor.sum() / (BT * V)
+     else:
+         return output_tensor
+
+
+ def kldiv_backward_triton(target, grad_output, new_grads, log_target):
+     BT, V = target.shape
+     BLOCK_SIZE = (
+         min(8192, triton.next_power_of_2(V))
+         if infer_device() == "xpu"
+         else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+     )
+     num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
+
+     grid = (BT,)
+
+     # The per-element gradients are written into the preallocated new_grads tensor
+     _kldiv_kernel_backward[grid](
+         target,
+         target.stride(0),
+         new_grads,
+         new_grads.stride(0),
+         V,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+         log_target=log_target,
+     )
+
+     # If this loss is the last layer, grad_output is 1.0. Skip the multiplication in that case.
+     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+         return new_grads
+
+     return new_grads * grad_output
+
+
+ class LigerKLDivLossFunction(torch.autograd.Function):
+     """
+     Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
+     ```python
+     if log_target:
+         loss = target.exp() * (target - input)
+     else:
+         loss = target * (target.log() - input)
+     ```
+     then the loss is reduced according to the `reduction` parameter,
+     as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
+     """
+
+     @staticmethod
+     @ensure_contiguous
+     def forward(
+         ctx,
+         y_pred: torch.Tensor,
+         y_true: torch.Tensor,
+         reduction: REDUCTION_LITERAL = "batchmean",
+         log_target: bool = False,
+         eps: float = 1e-10,
+     ) -> torch.Tensor:
+         """A forward pass for the KL Divergence Loss.
+
+         Args:
+             ctx: Torch autograd context
+             y_pred (torch.Tensor): A tensor of shape (BT, V) containing the predicted values, expected to be log-probabilities.
+             y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`.
+             reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean".
+             log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False.
+             eps (float, optional): A small value to avoid division by zero. Defaults to 1e-10.
+
+         Returns:
+             torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar.
+         """
+         ctx.save_for_backward(y_true)
+         ctx.reduction = reduction
+         ctx.log_target = log_target
+         return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+         """A backward pass for the KL Divergence Loss.
+
+         Args:
+             ctx: Torch autograd context
+             grad_output (torch.Tensor): The gradient of the loss with respect to the output.
+
+         Returns:
+             tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method.
+         """
+         (y_true,) = ctx.saved_tensors
+
+         new_grads = torch.empty_like(y_true)
+
+         derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)
+
+         if ctx.reduction == "batchmean":
+             derivative = derivative / y_true.shape[0]
+         elif ctx.reduction == "sum" or ctx.reduction == "none":
+             pass
+         elif ctx.reduction == "mean":
+             derivative = derivative / (y_true.shape[0] * y_true.shape[1])
+
+         return (
+             derivative,
+             None,
+             None,
+             None,
+             None,
+         )
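
For orientation, here is a minimal usage sketch of the new kl_div.py (illustrative, not part of the diff). It assumes a Triton-capable CUDA device; the shapes and tolerance are arbitrary. The positional arguments to `.apply` mirror forward(ctx, y_pred, y_true, reduction, log_target, eps), and the result is checked against torch.nn.functional.kl_div:

    import torch
    import torch.nn.functional as F

    from liger_kernel.ops.kl_div import LigerKLDivLossFunction

    BT, V = 4, 4096
    # y_pred must be log-probabilities; y_true is probabilities since log_target=False
    y_pred = F.log_softmax(torch.randn(BT, V, device="cuda"), dim=-1).requires_grad_(True)
    y_true = torch.softmax(torch.randn(BT, V, device="cuda"), dim=-1)

    loss = LigerKLDivLossFunction.apply(y_pred, y_true, "batchmean", False, 1e-10)
    ref = F.kl_div(y_pred.detach(), y_true, reduction="batchmean", log_target=False)
    print(torch.allclose(loss, ref, atol=1e-4))  # should print True (within tolerance)

    loss.backward()  # gradient flows only to y_pred; y_true is treated as a constant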
@@ -0,0 +1,320 @@
+ import math
+ import operator
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.utils import is_npu_available
+
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
+     try:
+         # typical import path with dispatch available
+         from triton.language.extra.libdevice import rsqrt
+     except ModuleNotFoundError:
+         # for working with NGC containers
+         from triton.language.extra.cuda.libdevice import rsqrt
+ else:
+     from triton.language.math import rsqrt
+
+
+ @triton.jit
+ def _layer_norm_forward_kernel(
+     Y_ptr,  # pointer to output, shape (n_rows, n_cols)
+     Y_row_stride,  # stride of each row in output
+     X_ptr,  # pointer to input, shape (n_rows, n_cols)
+     X_row_stride,  # stride of each row in input
+     W_ptr,  # pointer to weights, shape (n_cols,)
+     W_row_stride,  # stride of each row in weights
+     B_ptr,  # pointer to bias, shape (n_cols,)
+     B_row_stride,  # stride of each row in bias
+     Mean_ptr,  # pointer to mean, shape (n_rows,)
+     Mean_row_stride,  # stride of each row in mean
+     RSTD_ptr,  # pointer to rstd, shape (n_rows,)
+     RSTD_row_stride,  # stride of each row in rstd
+     n_cols,
+     eps,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     """
+     References:
+     https://arxiv.org/abs/1607.06450
+     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
+     """
+     row_idx = tl.program_id(0).to(tl.int64)
+     col_offsets = tl.arange(0, BLOCK_SIZE)
+     mask = col_offsets < n_cols
+
+     # Pre-load weights and bias in fp32 to avoid repeated conversions
+     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+     B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0)
+     W_f32 = W_row.to(tl.float32)
+     B_f32 = B_row.to(tl.float32)
+
+     # Calculate pointers for this row
+     row_X_ptr = X_ptr + row_idx * X_row_stride
+     row_Y_ptr = Y_ptr + row_idx * Y_row_stride
+     row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride
+     row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride
+
+     # Load input data and convert to fp32 for numerical stability
+     X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0)
+     X_f32 = X_row.to(tl.float32)
+
+     # Compute statistics in fp32 for numerical stability
+     mean = tl.sum(X_f32, axis=0) / n_cols
+     X_centered = X_f32 - mean
+     # Apply mask to variance calculation to exclude contributions from masked elements
+     X_centered_masked = tl.where(mask, X_centered, 0.0)
+     var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
+     rstd = rsqrt(var + eps)
+
+     # Store statistics (convert back to original dtype only once)
+     tl.store(row_Mean_ptr, mean.to(X_row.dtype))
+     tl.store(row_RSTD_ptr, rstd.to(X_row.dtype))
+
+     # Fused normalization and affine transformation
+     # Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B
+     Y_f32 = X_centered * rstd * W_f32 + B_f32
+
+     # Store output (single conversion back to original dtype)
+     tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask)
+
+
+ @triton.jit
+ def _layer_norm_backward_kernel(
+     X_ptr,  # pointer to input, shape (n_rows, n_cols)
+     stride_x,  # stride of each row in input
+     W_ptr,  # pointer to weights, shape (n_cols,)
+     Mean_ptr,  # pointer to mean, shape (n_rows,)
+     stride_mean,  # stride of each row in mean
+     RSTD_ptr,  # pointer to rstd, shape (n_rows,)
+     stride_rstd,  # stride of each row in rstd
+     DX_ptr,  # pointer to input grad, shape (n_rows, n_cols)
+     stride_dx,  # stride of each row in input grad
+     DW_ptr,  # pointer to weights grad, shape (n_cols,)
+     stride_dw,  # stride of each row in weights grad
+     DB_ptr,  # pointer to bias grad, shape (n_cols,)
+     stride_db,  # stride of each row in bias grad
+     DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
+     stride_dy,  # stride of each row in output grad
+     n_rows,
+     n_cols,
+     rows_per_program: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     """
+     References:
+     https://arxiv.org/abs/1607.06450
+     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
+     """
+     row_block_id = tl.program_id(0).to(tl.int64)
+     row_start = row_block_id * rows_per_program
+     row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+     cols = tl.arange(0, BLOCK_SIZE)
+     mask = cols < n_cols
+
+     dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+     db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+     # Pre-load weights once (same optimization as forward pass)
+     w = tl.load(W_ptr + cols, mask=mask, other=0.0)
+     w_f32 = w.to(tl.float32)
+
+     # Pointers to the first row assigned to this program
+     row_X_ptr = X_ptr + row_start * stride_x
+     row_DX_ptr = DX_ptr + row_start * stride_dx
+     row_DY_ptr = DY_ptr + row_start * stride_dy
+     row_Mean_ptr = Mean_ptr + row_start
+     row_RSTD_ptr = RSTD_ptr + row_start
+
+     for _ in range(row_start, row_end):
+         # Load data for this row
+         x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+         dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+         mean = tl.load(row_Mean_ptr)
+         rstd = tl.load(row_RSTD_ptr)
+
+         # Convert to fp32 for numerical stability
+         x_f32 = x.to(tl.float32)
+         dy_f32 = dy.to(tl.float32)
+         mean_f32 = mean.to(tl.float32)
+         rstd_f32 = rstd.to(tl.float32)
+
+         # Compute backward pass for this row
+         x_hat = (x_f32 - mean_f32) * rstd_f32
+         wdy = w_f32 * dy_f32
+         c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
+         c2 = tl.sum(wdy, axis=0) / n_cols
+         dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+         # Store input gradient
+         tl.store(row_DX_ptr + cols, dx, mask=mask)
+
+         # Accumulate weight and bias gradients for this thread block's assigned rows
+         dw = dy_f32 * x_hat
+         db = dy_f32
+         dW_row += dw
+         db_row += db
+
+         row_X_ptr += stride_x
+         row_DX_ptr += stride_dx
+         row_DY_ptr += stride_dy
+         row_Mean_ptr += stride_mean
+         row_RSTD_ptr += stride_rstd
+
+     tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
+     tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
+
+
+ def layer_norm_forward(X, W, B, eps):
+     """
+     Args:
+         X: Input tensor of shape (..., hidden_size)
+         W: Weight tensor of shape (hidden_size,)
+         B: Bias tensor of shape (hidden_size,)
+         eps: Small constant for numerical stability
+
+     Returns:
+         Tuple of (output, input, mean, rstd, block_size, num_warps)
+     """
+     shape = X.shape
+     dim = shape[-1]
+     X = X.view(-1, dim)
+     n_rows, n_cols = X.shape
+
+     # Calculate optimal block size and warp configuration
+     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+     # Allocate output tensors
+     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+     Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+     RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+
+     # Validate input dimensions
+     if X.shape[1] != W.shape[0]:
+         raise ValueError(
+             f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
+             f"must match weight size (W.shape[0]={W.shape[0]})"
+         )
+
+     # XPU-specific optimization
+     kernel_args = {}
+     if X.device.type == "xpu":
+         kernel_args["grf_mode"] = "large"
+
+     # Launch kernel with one thread block per row for optimal performance
+     grid = (n_rows,)
+     _layer_norm_forward_kernel[grid](
+         Y,
+         Y.stride(0),
+         X,
+         X.stride(0),
+         W,
+         W.stride(0),
+         B,
+         B.stride(0),
+         Mean,
+         Mean.stride(0),
+         RSTD,
+         RSTD.stride(0),
+         n_cols,
+         eps,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+         **kernel_args,
+     )
+
+     return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
+
+
+ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
+     """
+     Args:
+         dY: Gradient of output
+         X: Input tensor
+         W: Weight tensor
+         B: Bias tensor
+         Mean: Pre-computed mean
+         RSTD: Pre-computed reciprocal standard deviation
+
+     Returns:
+         Tuple of (input_grad, weight_grad, bias_grad)
+     """
+     shape = dY.shape
+     dim = shape[-1]
+     dY = dY.view(-1, dim)
+     n_rows, n_cols = dY.shape
+
+     sm_count = 1
+     if X.device.type == "cuda":
+         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+     elif X.device.type == "xpu":
+         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+
+     # Accumulate partial weight/bias gradients in fp32 for numerical stability
+     _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+     _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+
+     # Calculate optimal block size and warp configuration
+     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+     if n_cols > BLOCK_SIZE:
+         raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
+     rows_per_program = math.ceil(n_rows / sm_count)
+     grid = (sm_count,)
+
+     # Allocate gradient tensors
+     DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+
+     kernel_args = {"num_warps": num_warps}
+     # XPU-specific optimization
+     if X.device.type == "xpu":
+         kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
+
+     # Launch one program per SM; each program accumulates gradients over its assigned rows
+     _layer_norm_backward_kernel[grid](
+         X,
+         X.stride(0),
+         W,
+         Mean,
+         Mean.stride(0),
+         RSTD,
+         RSTD.stride(0),
+         DX,
+         DX.stride(0),
+         _DW,
+         _DW.stride(0),
+         _DB,
+         _DB.stride(0),
+         dY,
+         dY.stride(0),
+         n_rows,
+         n_cols,
+         rows_per_program=rows_per_program,
+         BLOCK_SIZE=BLOCK_SIZE,
+         **kernel_args,
+     )
+
+     DX = DX.view(*shape)
+     DW = _DW.sum(dim=0).to(W.dtype)
+     DB = _DB.sum(dim=0).to(B.dtype)
+     return DX, DW, DB
+
+
+ class LigerLayerNormFunction(torch.autograd.Function):
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, X, W, B, eps):
+         Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps)
+         ctx.save_for_backward(X, W, B, Mean, RSTD)
+         return Y
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, dY):
+         X, W, B, Mean, RSTD = ctx.saved_tensors
+         DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)
+         return DX, DW, DB, None
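
Similarly, a minimal usage sketch for the new layer_norm.py (illustrative, not part of the diff; assumes a CUDA device, with arbitrary shapes and tolerances). `.apply` mirrors forward(ctx, X, W, B, eps) and is checked against torch.nn.functional.layer_norm; the backward pass routes through layer_norm_backward, so x.grad, w.grad, and b.grad correspond to DX, DW, and DB:

    import torch
    import torch.nn.functional as F

    from liger_kernel.ops.layer_norm import LigerLayerNormFunction

    hidden = 2048
    x = torch.randn(8, 128, hidden, device="cuda", requires_grad=True)
    w = torch.randn(hidden, device="cuda", requires_grad=True)
    b = torch.randn(hidden, device="cuda", requires_grad=True)

    y = LigerLayerNormFunction.apply(x, w, b, 1e-5)
    y_ref = F.layer_norm(x, (hidden,), weight=w, bias=b, eps=1e-5)
    print(torch.allclose(y, y_ref, atol=1e-4))  # should print True (within tolerance)

    y.sum().backward()  # populates x.grad (DX), w.grad (DW), b.grad (DB)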