liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.

Files changed (114)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +304 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +21 -4
  18. liger_kernel/ops/cross_entropy.py +235 -84
  19. liger_kernel/ops/dyt.py +157 -0
  20. liger_kernel/ops/experimental/embedding.py +1 -3
  21. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  22. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  23. liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
  24. liger_kernel/ops/fused_linear_jsd.py +17 -34
  25. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  26. liger_kernel/ops/geglu.py +7 -18
  27. liger_kernel/ops/group_norm.py +305 -0
  28. liger_kernel/ops/grpo_loss.py +310 -0
  29. liger_kernel/ops/jsd.py +46 -21
  30. liger_kernel/ops/kl_div.py +23 -19
  31. liger_kernel/ops/layer_norm.py +150 -86
  32. liger_kernel/ops/llama4_rope.py +225 -0
  33. liger_kernel/ops/multi_token_attention.py +207 -0
  34. liger_kernel/ops/poly_norm.py +386 -0
  35. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  36. liger_kernel/ops/rms_norm.py +314 -84
  37. liger_kernel/ops/rope.py +32 -34
  38. liger_kernel/ops/softmax.py +201 -0
  39. liger_kernel/ops/sparsemax.py +179 -0
  40. liger_kernel/ops/swiglu.py +5 -9
  41. liger_kernel/ops/tiled_mlp.py +136 -0
  42. liger_kernel/ops/tvd.py +207 -0
  43. liger_kernel/ops/utils.py +8 -4
  44. liger_kernel/transformers/__init__.py +199 -24
  45. liger_kernel/transformers/auto_model.py +6 -13
  46. liger_kernel/transformers/cross_entropy.py +33 -20
  47. liger_kernel/transformers/dyt.py +22 -0
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -3
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +291 -13
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -4
  57. liger_kernel/transformers/group_norm.py +50 -0
  58. liger_kernel/transformers/grpo_loss.py +98 -0
  59. liger_kernel/transformers/jsd.py +2 -7
  60. liger_kernel/transformers/kl_div.py +1 -3
  61. liger_kernel/transformers/layer_norm.py +3 -9
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/falcon_h1.py +122 -0
  64. liger_kernel/transformers/model/gemma.py +77 -77
  65. liger_kernel/transformers/model/gemma2.py +283 -0
  66. liger_kernel/transformers/model/gemma3.py +331 -0
  67. liger_kernel/transformers/model/glm4.py +141 -0
  68. liger_kernel/transformers/model/glm4v.py +163 -0
  69. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  70. liger_kernel/transformers/model/internvl.py +157 -0
  71. liger_kernel/transformers/model/llama.py +128 -79
  72. liger_kernel/transformers/model/llama4.py +121 -0
  73. liger_kernel/transformers/model/llava.py +344 -0
  74. liger_kernel/transformers/model/loss_utils.py +95 -0
  75. liger_kernel/transformers/model/mistral.py +68 -64
  76. liger_kernel/transformers/model/mixtral.py +75 -91
  77. liger_kernel/transformers/model/mllama.py +63 -68
  78. liger_kernel/transformers/model/olmo2.py +141 -0
  79. liger_kernel/transformers/model/output_classes.py +147 -0
  80. liger_kernel/transformers/model/paligemma.py +432 -0
  81. liger_kernel/transformers/model/phi3.py +59 -213
  82. liger_kernel/transformers/model/qwen2.py +75 -72
  83. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  84. liger_kernel/transformers/model/qwen2_vl.py +78 -98
  85. liger_kernel/transformers/model/qwen3.py +136 -0
  86. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  87. liger_kernel/transformers/model/qwen3_next.py +146 -0
  88. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  89. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  90. liger_kernel/transformers/model/smollm3.py +199 -0
  91. liger_kernel/transformers/model/smolvlm.py +158 -0
  92. liger_kernel/transformers/monkey_patch.py +2106 -289
  93. liger_kernel/transformers/multi_token_attention.py +64 -0
  94. liger_kernel/transformers/poly_norm.py +42 -0
  95. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  96. liger_kernel/transformers/rms_norm.py +57 -6
  97. liger_kernel/transformers/rope.py +45 -2
  98. liger_kernel/transformers/softmax.py +12 -0
  99. liger_kernel/transformers/sparsemax.py +16 -0
  100. liger_kernel/transformers/swiglu.py +23 -8
  101. liger_kernel/transformers/tiled_mlp.py +133 -0
  102. liger_kernel/transformers/trainer/__init__.py +4 -0
  103. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  104. liger_kernel/transformers/tvd.py +13 -0
  105. liger_kernel/triton/__init__.py +1 -3
  106. liger_kernel/triton/monkey_patch.py +1 -3
  107. liger_kernel/utils.py +71 -0
  108. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
  109. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  110. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
  111. liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
  112. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
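
The file list shows that the 0.6.x line routes its Hugging Face integration through `liger_kernel/transformers/__init__.py` and `monkey_patch.py`. As a rough usage sketch, assuming the `apply_liger_kernel_to_llama` entry point that `liger_kernel.transformers` exports (the checkpoint name below is only an example):

```python
# Hedged sketch: patch a Hugging Face Llama model with Liger's fused kernels
# before instantiating it. `apply_liger_kernel_to_llama` is assumed to be the
# export from liger_kernel.transformers (see monkey_patch.py above).
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama()  # swaps in fused RMSNorm, RoPE, SwiGLU, CE loss
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
```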
liger_kernel/ops/jsd.py CHANGED
@@ -5,6 +5,7 @@ import triton
 import triton.language as tl
 
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import infer_device
 
 
 @triton.jit
@@ -18,7 +19,7 @@ def _jsd_kernel(
     dX_ptr,
     dX_stride,
     label_ptr,
-    beta,
+    beta: tl.constexpr,
     n_non_ignore: int,
     ignore_index: tl.constexpr,
     n_cols,
@@ -50,21 +51,49 @@ def _jsd_kernel(
     X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
     Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
 
-    Q = tl.exp(X)
-    P = tl.exp(Y)
-    M = beta * P + (1 - beta) * Q
-    log_M = tl.log(M)
+    if beta == 0.0:  # forward KL
+        Y_max = tl.max(Y, axis=0)
+        Y_shifted = Y - Y_max
+        Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max)  # Compensate for the shift
+        loss = Y_prob * (Y - X)
+        dX = -Y_prob
+    elif beta == 1.0:  # reverse KL
+        X_max = tl.max(X, axis=0)
+        X_shifted = X - X_max
+        X_prob = tl.exp(X_shifted) * tl.exp(X_max)  # Compensate for the shift
+        loss = X_prob * (X - Y)
+        dX = loss + X_prob
+    else:
+        max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
+        X_shifted = X - max_val
+        Y_shifted = Y - max_val
+
+        # Pre-compute exp(max_val) since it's used twice
+        exp_max = tl.exp(max_val)
+
+        # Compute exp terms with compensation
+        Q = tl.exp(X_shifted) * exp_max  # = exp(X)
+        P = tl.exp(Y_shifted) * exp_max  # = exp(Y)
+
+        # Pre-compute common terms
+        beta_P = beta * P
+        one_minus_beta_Q = (1 - beta) * Q
+        M = beta_P + one_minus_beta_Q
+        log_M = tl.log(M)  # No need to compensate as M is already in original scale
+
+        loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
+        dX = one_minus_beta_Q * (X - log_M)
+
+    # Pre-compute scaling factor
+    scale = 1.0 / n_non_ignore
+    loss = loss * scale
+    dX = dX * scale
 
-    loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
-    # reduction == "batchmean"
-    loss = loss / n_non_ignore
     tl.store(loss_ptr + offsets, loss, mask=mask)
-
-    dX = (1 - beta) * Q * (X - log_M) / n_non_ignore
     tl.store(dX_ptr + offsets, dX, mask=mask)
 
 
-MAX_FUSED_SIZE = 65536
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536
 
 
 def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
@@ -89,9 +118,7 @@ def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
         loss_stride=loss.stride(-2),
         dX_ptr=dX,
         dX_stride=dX.stride(-2),
-        label_ptr=(
-            shift_labels if has_label else torch.empty(1, device=_input.device)
-        ),  # dummy ptr if no label
+        label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)),  # dummy ptr if no label
         beta=beta,
         n_non_ignore=n_non_ignore,
         ignore_index=ignore_index,
@@ -142,7 +169,7 @@ class LigerJSDFunction(torch.autograd.Function):
         _input (torch.Tensor): predict values with shape (BT, V) in logspace
         target (torch.Tensor): ground truth values with shape (BT, V) in logspace
         shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
-        beta (float): coefficient beta of generalized JSD in the open interval (0, 1)
+        beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
         ignore_index (int): the index to ignore. Default: -100
 
     Returns:
@@ -150,15 +177,13 @@ class LigerJSDFunction(torch.autograd.Function):
         """
        has_label = False
        if shift_labels is not None:
-            assert shift_labels.shape == (
-                _input.shape[0],
-            ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            assert shift_labels.shape == (_input.shape[0],), (
+                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            )
            shift_labels = shift_labels.contiguous()
            has_label = True
 
-        loss, dX = jsd_forward(
-            _input, target, shift_labels, beta, ignore_index, has_label
-        )
+        loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
        ctx.save_for_backward(dX)
        return loss
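
For orientation, the math the new `_jsd_kernel` computes can be written unfused in PyTorch. This is a minimal sketch of the "batchmean" loss, assuming no ignored labels; the function and argument names are illustrative:

```python
import torch

def generalized_jsd_reference(student_logp, teacher_logp, beta=0.5):
    """Unfused sketch of the kernel's generalized JSD (names illustrative).

    Inputs are log-probabilities of shape (BT, V). beta=0 gives forward
    KL(teacher || student), beta=1 gives reverse KL(student || teacher),
    otherwise beta * KL(P || M) + (1 - beta) * KL(Q || M) with
    M = beta * P + (1 - beta) * Q, P = exp(teacher_logp), Q = exp(student_logp).
    """
    Q = student_logp.exp()  # student probabilities, exp(X)
    P = teacher_logp.exp()  # teacher probabilities, exp(Y)
    if beta == 0.0:  # forward KL
        loss = P * (teacher_logp - student_logp)
    elif beta == 1.0:  # reverse KL
        loss = Q * (student_logp - teacher_logp)
    else:
        M = beta * P + (1 - beta) * Q
        log_M = M.log()
        loss = beta * P * (teacher_logp - log_M) + (1 - beta) * Q * (student_logp - log_M)
    return loss.sum() / student_logp.shape[0]  # reduction == "batchmean"
```

The kernel arrives at the same values but max-shifts X and Y before exponentiating and multiplies back by exp(max) afterwards, per the "Compensate for the shift" comments in the diff.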
liger_kernel/ops/kl_div.py CHANGED
@@ -4,7 +4,9 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import ensure_contiguous, is_hip
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 
 def get_num_warps(BLOCK_SIZE):
@@ -23,10 +25,10 @@ MAX_FUSED_SIZE = 65536 // 4  # 65536 // 4 or 8 works the best
 
 REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
 
-_REDUCTION_MODE_NONE = tl.constexpr(0)
-_REDUCTION_MODE_SUM = tl.constexpr(1)
-_REDUCTION_MODE_MEAN = tl.constexpr(2)
-_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
+_REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0)
+_REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1)
+_REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2)
+_REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3)
 
 _str_to_reduction_mode = {
     "none": _REDUCTION_MODE_NONE.value,
@@ -114,9 +116,12 @@ def _kldiv_kernel_backward(
 
 def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
     BT, V = y_pred.shape
-
-    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-    num_warps = get_num_warps(BLOCK_SIZE)
+    BLOCK_SIZE = (
+        min(8192, triton.next_power_of_2(V))
+        if infer_device() == "xpu"
+        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    )
+    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
 
     grid = (BT,)
     reduction = _str_to_reduction_mode[reduction]
@@ -154,9 +159,12 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
 
 def kldiv_backward_triton(target, grad_output, new_grads, log_target):
     BT, V = target.shape
-
-    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-    num_warps = get_num_warps(BLOCK_SIZE)
+    BLOCK_SIZE = (
+        min(8192, triton.next_power_of_2(V))
+        if infer_device() == "xpu"
+        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    )
+    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
 
     grid = (BT,)
 
@@ -184,9 +192,9 @@ class LigerKLDivLossFunction(torch.autograd.Function):
     Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
     ```python
     if log_target:
-        loss = target * (target.log() - input)
-    else:
         loss = target.exp() * (target - input)
+    else:
+        loss = target * (target.log() - input)
     ```,
     then the loss is reduced according to the `reduction` parameter.
     as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
@@ -218,9 +226,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
         ctx.save_for_backward(y_true)
         ctx.reduction = reduction
         ctx.log_target = log_target
-        return kldiv_forward_triton(
-            y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps
-        )
+        return kldiv_forward_triton(y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps)
 
     @staticmethod
     @ensure_contiguous
@@ -238,9 +244,7 @@ class LigerKLDivLossFunction(torch.autograd.Function):
 
         new_grads = torch.empty_like(y_true)
 
-        derivative = kldiv_backward_triton(
-            y_true, grad_output, new_grads, ctx.log_target
-        )
+        derivative = kldiv_backward_triton(y_true, grad_output, new_grads, ctx.log_target)
 
         if ctx.reduction == "batchmean":
             derivative = derivative / y_true.shape[0]
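
Note that this diff swaps the two branches of the docstring formula, which previously had them reversed relative to `torch.nn.KLDivLoss`. A quick sanity check of the corrected branches against PyTorch (shapes here are arbitrary):

```python
import torch
import torch.nn.functional as F

y_pred = torch.randn(4, 8).log_softmax(dim=-1)  # input is always in log-space
y_true = torch.randn(4, 8).softmax(dim=-1)

# log_target=False: target holds probabilities
manual = (y_true * (y_true.log() - y_pred)).sum() / y_pred.shape[0]
torch.testing.assert_close(manual, F.kl_div(y_pred, y_true, reduction="batchmean"))

# log_target=True: target holds log-probabilities
y_true_log = y_true.log()
manual_log = (y_true_log.exp() * (y_true_log - y_pred)).sum() / y_pred.shape[0]
torch.testing.assert_close(manual_log, F.kl_div(y_pred, y_true_log, reduction="batchmean", log_target=True))
```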
liger_kernel/ops/layer_norm.py CHANGED
@@ -1,15 +1,12 @@
-import math
 import operator
 
 import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-)
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -45,29 +42,44 @@ def _layer_norm_forward_kernel(
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
     """
-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
 
-    Y_ptr += row_idx * Y_row_stride
-    X_ptr += row_idx * X_row_stride
-    Mean_ptr += row_idx * Mean_row_stride
-    RSTD_ptr += row_idx * RSTD_row_stride
-
-    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
-    B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
-
-    mean = tl.sum(X_row, axis=0) / n_cols
-    var = tl.sum((X_row - mean) * (X_row - mean), axis=0) / n_cols
+    # Pre-load weights and bias in fp32 to avoid repeated conversions
+    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+    B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0)
+    W_f32 = W_row.to(tl.float32)
+    B_f32 = B_row.to(tl.float32)
+
+    # Calculate pointers for this row
+    row_X_ptr = X_ptr + row_idx * X_row_stride
+    row_Y_ptr = Y_ptr + row_idx * Y_row_stride
+    row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride
+    row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride
+
+    # Load input data and convert to fp32 for numerical stability
+    X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0)
+    X_f32 = X_row.to(tl.float32)
+
+    # Compute statistics in fp32 for numerical stability
+    mean = tl.sum(X_f32, axis=0) / n_cols
+    X_centered = X_f32 - mean
+    # Apply mask to variance calculation to exclude contributions from masked elements
+    X_centered_masked = tl.where(mask, X_centered, 0.0)
+    var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
     rstd = rsqrt(var + eps)
 
-    tl.store(Mean_ptr, mean)
-    tl.store(RSTD_ptr, rstd)
+    # Store statistics (convert back to original dtype only once)
+    tl.store(row_Mean_ptr, mean.to(X_row.dtype))
+    tl.store(row_RSTD_ptr, rstd.to(X_row.dtype))
 
-    Y_row = (X_row - mean) * rstd * W_row + B_row
+    # Fused normalization and affine transformation
+    # Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B
+    Y_f32 = X_centered * rstd * W_f32 + B_f32
 
-    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+    # Store output (single conversion back to original dtype)
+    tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask)
 
 
 @triton.jit
@@ -82,78 +94,100 @@ def _layer_norm_backward_kernel(
     DY_ptr,  # pointer to output grad, shape (n_rows, n_cols)
     stride_x,  # stride of each row in input
     stride_dx,  # stride of each row in input grad
-    stride_dw,  # stride of each row in weights grad
-    stride_db,  # stride of each row in bias grad
     stride_dy,  # stride of each row in output grad
-    n_rows,
     n_cols,
-    rows_per_program: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     dtype: tl.constexpr,
+    atomic_dtype: tl.constexpr,
 ):
     """
     References:
     https://arxiv.org/abs/1607.06450
     https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
-    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
-    https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
     """
-    row_block_id = tl.program_id(0)
-    row_start = row_block_id * rows_per_program
-    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+    row_idx = tl.program_id(0).to(tl.int64)
     cols = tl.arange(0, BLOCK_SIZE)
     mask = cols < n_cols
 
-    dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
-    db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
-
-    X_ptr += row_start * stride_x
-    Mean_ptr += row_start
-    RSTD_ptr += row_start
-    DX_ptr += row_start * stride_dx
-    DY_ptr += row_start * stride_dy
-
-    for _ in range(row_start, row_end):
-        x = tl.load(X_ptr + cols, mask=mask, other=0.0)
-        w = tl.load(W_ptr + cols, mask=mask, other=0.0)
-        dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)
-        mean = tl.load(Mean_ptr)
-        rstd = tl.load(RSTD_ptr)
-
-        x_hat = (x - mean) * rstd
-        wdy = w * dy
-        c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
-        c2 = tl.sum(wdy, axis=0) / n_cols
-        dx = (wdy - (x_hat * c1 + c2)) * rstd
-        tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)
-
-        dw_row += dy * x_hat
-        db_row += dy
-
-        X_ptr += stride_x
-        Mean_ptr += 1
-        RSTD_ptr += 1
-        DX_ptr += stride_dx
-        DY_ptr += stride_dy
-
-    tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)
-    tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)
+    # Pre-load weights once (same optimization as forward pass)
+    w = tl.load(W_ptr + cols, mask=mask, other=0.0)
+    w_f32 = w.to(tl.float32)
+
+    # Calculate pointers for this specific row
+    row_X_ptr = X_ptr + row_idx * stride_x
+    row_DX_ptr = DX_ptr + row_idx * stride_dx
+    row_DY_ptr = DY_ptr + row_idx * stride_dy
+    row_Mean_ptr = Mean_ptr + row_idx
+    row_RSTD_ptr = RSTD_ptr + row_idx
+
+    # Load data for this row
+    x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+    dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+    mean = tl.load(row_Mean_ptr)
+    rstd = tl.load(row_RSTD_ptr)
+
+    # Convert to fp32 for numerical stability
+    x_f32 = x.to(tl.float32)
+    dy_f32 = dy.to(tl.float32)
+    mean_f32 = mean.to(tl.float32)
+    rstd_f32 = rstd.to(tl.float32)
+
+    # Compute backward pass for this row
+    x_hat = (x_f32 - mean_f32) * rstd_f32
+    wdy = w_f32 * dy_f32
+    c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
+    c2 = tl.sum(wdy, axis=0) / n_cols
+    dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+    # Store input gradient
+    tl.store(row_DX_ptr + cols, dx.to(dtype), mask=mask)
+
+    # Accumulate weight and bias gradients using atomic operations
+    dw = dy_f32 * x_hat
+    db = dy_f32
+    tl.atomic_add(DW_ptr + cols, dw.to(atomic_dtype), mask=mask)
+    tl.atomic_add(DB_ptr + cols, db.to(atomic_dtype), mask=mask)
 
 
 def layer_norm_forward(X, W, B, eps):
+    """
+    Args:
+        X: Input tensor of shape (..., hidden_size)
+        W: Weight tensor of shape (hidden_size,)
+        B: Bias tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        Tuple of (output, input, mean, rstd, block_size, num_warps)
+    """
     shape = X.shape
     dim = shape[-1]
     X = X.view(-1, dim)
     n_rows, n_cols = X.shape
+
+    # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    # Allocate output tensors
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
     RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
-    assert (
-        X.shape[1] == W.shape[0]
-    ), f"Incompatible hidden size dimension between input tensor with shape[1] = {X.shape[1]} and weight tensor with shape[0] = {W.shape[0]}"
 
-    _layer_norm_forward_kernel[(n_rows,)](
+    # Validate input dimensions
+    if X.shape[1] != W.shape[0]:
+        raise ValueError(
+            f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
+            f"must match weight size (W.shape[0]={W.shape[0]})"
+        )
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
+    # Launch kernel with one thread block per row for optimal performance
+    grid = (n_rows,)
+    _layer_norm_forward_kernel[grid](
         Y,
         Y.stride(0),
         X,
@@ -170,54 +204,84 @@ def layer_norm_forward(X, W, B, eps):
         eps,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
+        **kernel_args,
     )
+
     return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
 
 
 def layer_norm_backward(dY, X, W, B, Mean, RSTD):
+    """
+    Args:
+        dY: Gradient of output
+        X: Input tensor
+        W: Weight tensor
+        B: Bias tensor
+        Mean: Pre-computed mean
+        RSTD: Pre-computed reciprocal standard deviation
+
+    Returns:
+        Tuple of (input_grad, weight_grad, bias_grad)
+    """
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
     n_rows, n_cols = dY.shape
 
+    # Allocate gradient tensors
     DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-    sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
-    _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
-    _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
+    # Use float32 for weight/bias gradients if bfloat16 (due to atomic_add limitation)
+    grad_dtype = torch.float32 if W.dtype == torch.bfloat16 else W.dtype
+    DW = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
+    DB = torch.zeros(n_cols, dtype=grad_dtype, device=W.device)
 
+    # Calculate optimal block size and warp configuration
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
     if n_cols > BLOCK_SIZE:
-        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+        raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
+
+    # Determine dtype for triton operations
+    triton_dtype = (
+        tl.float32
+        if X.dtype == torch.float32
+        else tl.bfloat16
+        if X.dtype == torch.bfloat16
+        else tl.float16
+        if X.dtype == torch.float16
+        else tl.float32  # fallback
+    )
+
+    # Use float32 for atomic operations if bfloat16 is not supported
+    atomic_dtype = tl.float32 if triton_dtype == tl.bfloat16 else triton_dtype
 
-    rows_per_program = math.ceil(n_rows / sm_count)
-    grid = (sm_count,)
-    triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
+    kernel_args = {"num_warps": num_warps}
+    # XPU-specific optimization
+    if X.device.type == "xpu":
+        kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
+
+    # Launch kernel with one thread block per row for optimal performance
+    grid = (n_rows,)
     _layer_norm_backward_kernel[grid](
         X,
         W,
         Mean,
         RSTD,
         DX,
-        _DW,
-        _DB,
+        DW,
+        DB,
         dY,
         X.stride(0),
         DX.stride(0),
-        _DW.stride(0),
-        _DB.stride(0),
         dY.stride(0),
-        n_rows,
         n_cols,
-        rows_per_program,
         BLOCK_SIZE=BLOCK_SIZE,
         dtype=triton_dtype,
+        atomic_dtype=atomic_dtype,
+        **kernel_args,
     )
 
-    DW = _DW.sum(dim=0).to(W.dtype)
-    DB = _DB.sum(dim=0).to(W.dtype)
-
     DX = DX.view(*shape)
-    return DX, DW, DB
+    return DX, DW.to(W.dtype), DB.to(W.dtype)
 
 
 class LigerLayerNormFunction(torch.autograd.Function):
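
The backward rewrite replaces the per-SM partial buffers (`_DW`/`_DB` plus a final `sum(dim=0)`) with direct `tl.atomic_add` accumulation into single `DW`/`DB` vectors, launching one program per row instead of looping `rows_per_program` rows per SM. As a reference for the math each row contributes, here is a minimal unfused PyTorch sketch (names illustrative):

```python
import torch

def layer_norm_backward_reference(dY, X, W, Mean, RSTD):
    """Unfused sketch of the backward math in the kernel (names illustrative).

    dY, X: (n_rows, n_cols); W: (n_cols,); Mean, RSTD: (n_rows,).
    """
    x_hat = (X - Mean[:, None]) * RSTD[:, None]  # normalized input
    wdy = W[None, :] * dY
    c1 = (x_hat * wdy).sum(dim=1, keepdim=True) / X.shape[1]
    c2 = wdy.sum(dim=1, keepdim=True) / X.shape[1]
    DX = (wdy - (x_hat * c1 + c2)) * RSTD[:, None]
    # Per-row contributions the kernel accumulates with tl.atomic_add:
    DW = (dY * x_hat).sum(dim=0)
    DB = dY.sum(dim=0)
    return DX, DW, DB
```

Because `tl.atomic_add` does not support bfloat16, the wrapper accumulates `DW`/`DB` in float32 for bf16 weights and casts back at the end; a known trade-off of atomic accumulation is that the summation order, and hence the low-order bits of these two gradients, can vary from run to run.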