liger-kernel-nightly 0.5.10.dev20250611191801__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (107)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +54 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +25 -5
  7. liger_kernel/chunked_loss/grpo_loss.py +46 -9
  8. liger_kernel/chunked_loss/jsd_loss.py +44 -13
  9. liger_kernel/ops/__init__.py +141 -0
  10. liger_kernel/ops/backends/README.md +151 -0
  11. liger_kernel/ops/backends/__init__.py +13 -0
  12. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  13. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
  15. liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
  16. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  17. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  18. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  19. liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
  20. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  21. liger_kernel/ops/backends/registry.py +61 -0
  22. liger_kernel/ops/cross_entropy.py +130 -64
  23. liger_kernel/ops/dyt.py +5 -4
  24. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  25. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  26. liger_kernel/ops/geglu.py +6 -4
  27. liger_kernel/ops/group_norm.py +7 -7
  28. liger_kernel/ops/grpo_loss.py +3 -1
  29. liger_kernel/ops/kl_div.py +8 -11
  30. liger_kernel/ops/layer_norm.py +135 -80
  31. liger_kernel/ops/llama4_rope.py +225 -0
  32. liger_kernel/ops/poly_norm.py +390 -0
  33. liger_kernel/ops/rms_norm.py +148 -71
  34. liger_kernel/ops/rope.py +1 -1
  35. liger_kernel/ops/swiglu.py +1 -1
  36. liger_kernel/ops/tiled_mlp.py +136 -0
  37. liger_kernel/ops/utils.py +14 -0
  38. liger_kernel/transformers/__init__.py +65 -0
  39. liger_kernel/transformers/auto_model.py +21 -0
  40. liger_kernel/transformers/cross_entropy.py +9 -4
  41. liger_kernel/transformers/dyt.py +1 -1
  42. liger_kernel/transformers/experimental/__init__.py +5 -0
  43. liger_kernel/transformers/experimental/embedding.py +1 -1
  44. liger_kernel/transformers/functional.py +56 -24
  45. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  46. liger_kernel/transformers/fused_linear_cross_entropy.py +17 -5
  47. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  48. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  49. liger_kernel/transformers/geglu.py +1 -1
  50. liger_kernel/transformers/group_norm.py +1 -1
  51. liger_kernel/transformers/grpo_loss.py +57 -2
  52. liger_kernel/transformers/jsd.py +1 -1
  53. liger_kernel/transformers/kl_div.py +1 -1
  54. liger_kernel/transformers/layer_norm.py +1 -1
  55. liger_kernel/transformers/llama4_rope.py +93 -0
  56. liger_kernel/transformers/model/exaone4.py +136 -0
  57. liger_kernel/transformers/model/falcon_h1.py +122 -0
  58. liger_kernel/transformers/model/gemma.py +28 -8
  59. liger_kernel/transformers/model/gemma2.py +34 -11
  60. liger_kernel/transformers/model/gemma3.py +102 -112
  61. liger_kernel/transformers/model/glm4.py +18 -5
  62. liger_kernel/transformers/model/glm4v.py +163 -0
  63. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  64. liger_kernel/transformers/model/gpt_oss.py +211 -0
  65. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  66. liger_kernel/transformers/model/internvl.py +157 -0
  67. liger_kernel/transformers/model/llama.py +26 -7
  68. liger_kernel/transformers/model/llama4.py +121 -0
  69. liger_kernel/transformers/model/llava.py +18 -6
  70. liger_kernel/transformers/model/loss_utils.py +34 -3
  71. liger_kernel/transformers/model/mistral.py +17 -10
  72. liger_kernel/transformers/model/mixtral.py +24 -9
  73. liger_kernel/transformers/model/mllama.py +18 -7
  74. liger_kernel/transformers/model/olmo2.py +18 -5
  75. liger_kernel/transformers/model/olmo3.py +142 -0
  76. liger_kernel/transformers/model/output_classes.py +147 -0
  77. liger_kernel/transformers/model/paligemma.py +42 -5
  78. liger_kernel/transformers/model/phi3.py +24 -159
  79. liger_kernel/transformers/model/qwen2.py +26 -4
  80. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  81. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  82. liger_kernel/transformers/model/qwen3.py +22 -6
  83. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  84. liger_kernel/transformers/model/qwen3_next.py +146 -0
  85. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  86. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  87. liger_kernel/transformers/model/smollm3.py +199 -0
  88. liger_kernel/transformers/model/smolvlm.py +158 -0
  89. liger_kernel/transformers/monkey_patch.py +1423 -100
  90. liger_kernel/transformers/multi_token_attention.py +2 -2
  91. liger_kernel/transformers/poly_norm.py +42 -0
  92. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  93. liger_kernel/transformers/rms_norm.py +15 -5
  94. liger_kernel/transformers/rope.py +45 -1
  95. liger_kernel/transformers/softmax.py +1 -1
  96. liger_kernel/transformers/sparsemax.py +1 -1
  97. liger_kernel/transformers/swiglu.py +18 -1
  98. liger_kernel/transformers/tiled_mlp.py +125 -0
  99. liger_kernel/transformers/tvd.py +1 -1
  100. liger_kernel/utils.py +52 -0
  101. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +37 -25
  102. liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
  103. liger_kernel_nightly-0.5.10.dev20250611191801.dist-info/RECORD +0 -95
  104. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.10.dev20250611191801.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/ops/geglu.py CHANGED
@@ -7,8 +7,9 @@ import triton.language as tl
  from liger_kernel.ops.utils import calculate_settings
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.utils import is_npu_available
 
- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
  try:
  # typical import path with dispatch available
  from triton.language.extra.libdevice import tanh
@@ -40,7 +41,7 @@ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE
  tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
  tanh_result = tanh(tanh_arg)
  geglu_a = 0.5 * a_row * (1 + tanh_result)
- c_row = geglu_a * b_row
+ c_row = geglu_a.cast(b_row.dtype) * b_row
  tl.store(c + col_offsets, c_row, mask=mask)
 
 
@@ -66,8 +67,9 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
  tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
  tanh_result = tanh(tanh_arg)
  geglu_a = 0.5 * a_row * (1 + tanh_result)
+ geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)
 
- db_row = dc_row * geglu_a
+ db_row = dc_row.cast(tl.float32) * geglu_a
 
  # Gradient w.r.t. a can be computed with:
  # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
@@ -78,7 +80,7 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
  da_row = dc_row * b_row * (term1 + term2)
 
  tl.store(a + col_offsets, da_row, mask=mask)
- tl.store(b + col_offsets, db_row, mask=mask)
+ tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)
 
 
  def geglu_forward(a, b):
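
The is_npu_available() gate in this hunk (repeated in the group_norm and layer_norm hunks below) only changes which libdevice import path is taken. A minimal sketch of the selection logic, assuming the usual triton.language.math fallback that the hunk truncates before showing:

    import operator

    from liger_kernel.ops.utils import compare_version
    from liger_kernel.utils import is_npu_available

    # Prefer the libdevice tanh only when Triton >= 3.0.0 and no Ascend NPU is present;
    # the fallback import below is assumed from context (the hunk cuts off before it).
    if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
        try:
            from triton.language.extra.libdevice import tanh  # dispatching import path
        except ModuleNotFoundError:
            from triton.language.math import tanh  # NGC container layout
    else:
        from triton.language.math import tanh
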
liger_kernel/ops/group_norm.py CHANGED
@@ -6,8 +6,9 @@ import triton.language as tl
 
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.utils import is_npu_available
 
- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
  try:
  # typical import path with dispatch available
  from triton.language.extra.libdevice import rsqrt
@@ -77,15 +78,14 @@ def _group_norm_forward_kernel(
  for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
  W = tl.load(W_ptr + channel_idx)
  B = tl.load(B_ptr + channel_idx)
- for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
+ # Calculate channel offset within the group
+ channel_offset = (channel_idx - group_idx * channels_per_group) * hidden_size_per_channel
+ for i in tl.range(0, hidden_size_per_channel, BLOCK_SIZE):
  hidden_size_offsets = i + block_range
  mask = hidden_size_offsets < hidden_size_per_channel
- X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
+ X = tl.load(X_ptr + channel_offset + hidden_size_offsets, mask=mask, other=m)
  Y = (X - m) * rstd * W + B
- tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
-
- X_ptr += hidden_size_per_channel
- Y_ptr += hidden_size_per_channel
+ tl.store(Y_ptr + channel_offset + hidden_size_offsets, Y, mask=mask)
 
  tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
  tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
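
The rewritten loop addresses every element of the group relative to a fixed row base pointer instead of mutating X_ptr and Y_ptr as it goes. A plain-Python sketch of the new index arithmetic (illustrative, not Triton), checking that each element of the group is visited exactly once:

    # Illustrative check of the new addressing: channel_offset plus the in-channel
    # offset enumerates every element of the group exactly once.
    channels_per_group = 3
    hidden_size_per_channel = 5
    group_idx = 2  # arbitrary group

    visited = []
    for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
        channel_offset = (channel_idx - group_idx * channels_per_group) * hidden_size_per_channel
        for i in range(hidden_size_per_channel):
            visited.append(channel_offset + i)

    assert sorted(visited) == list(range(channels_per_group * hidden_size_per_channel))
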
liger_kernel/ops/grpo_loss.py CHANGED
@@ -128,7 +128,9 @@ def _grpo_loss_fwd_kernel(
  per_token_loss1 = coef_1 * advantage
  per_token_loss2 = coef_2 * advantage
  per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
- is_clipped = per_token_loss1 < per_token_loss2
+ is_low_clipped = (coef_1 < 1 - EPS_LOW) & (advantage < 0)
+ is_high_clipped = (coef_1 > 1 + EPS_HIGH) & (advantage > 0)
+ is_clipped = is_low_clipped | is_high_clipped
 
  if BETA != 0.0:
  REF_LOGP += off_b * L + off_l
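
The new is_clipped flag counts a token as clipped only when the importance ratio actually left the trust region on the side that binds for its advantage sign, rather than comparing the two surrogate losses. A hedged PyTorch sketch of the same bookkeeping outside the kernel (eps_low and eps_high stand in for EPS_LOW and EPS_HIGH):

    import torch

    def clip_fraction(coef_1: torch.Tensor, advantage: torch.Tensor,
                      eps_low: float = 0.2, eps_high: float = 0.2) -> torch.Tensor:
        # coef_1 is the per-token probability ratio; a token counts as clipped only
        # when the ratio left the trust region and the clip changes the objective:
        # low-side clips matter for negative advantages, high-side for positive ones.
        is_low_clipped = (coef_1 < 1 - eps_low) & (advantage < 0)
        is_high_clipped = (coef_1 > 1 + eps_high) & (advantage > 0)
        is_clipped = is_low_clipped | is_high_clipped
        return is_clipped.float().mean()
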
liger_kernel/ops/kl_div.py CHANGED
@@ -21,7 +21,12 @@ def get_num_warps(BLOCK_SIZE):
  return num_warps
 
 
- MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
+ if infer_device() == "xpu":
+ MAX_FUSED_SIZE = 8192
+ elif infer_device() == "npu":
+ MAX_FUSED_SIZE = 8192
+ else:
+ MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
 
  REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
 
@@ -116,11 +121,7 @@ def _kldiv_kernel_backward(
 
  def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
  BT, V = y_pred.shape
- BLOCK_SIZE = (
- min(8192, triton.next_power_of_2(V))
- if infer_device() == "xpu"
- else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
- )
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
  num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
 
  grid = (BT,)
@@ -159,11 +160,7 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
 
  def kldiv_backward_triton(target, grad_output, new_grads, log_target):
  BT, V = target.shape
- BLOCK_SIZE = (
- min(8192, triton.next_power_of_2(V))
- if infer_device() == "xpu"
- else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
- )
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
  num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
 
  grid = (BT,)
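
With MAX_FUSED_SIZE now chosen per device at import time, the forward and backward launchers share a single block-size expression. A small sketch of the resulting selection, with the device string standing in for whatever infer_device() returns:

    import triton

    def pick_block_size(V: int, device: str) -> int:
        # 8192 caps the fused block on XPU and NPU; other devices keep 65536 // 4.
        max_fused_size = 8192 if device in ("xpu", "npu") else 65536 // 4
        return min(max_fused_size, triton.next_power_of_2(V))
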
liger_kernel/ops/layer_norm.py CHANGED
@@ -8,8 +8,9 @@ import triton.language as tl
  from liger_kernel.ops.utils import calculate_settings
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.utils import is_npu_available
 
- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
  try:
  # typical import path with dispatch available
  from triton.language.extra.libdevice import rsqrt
@@ -43,111 +44,157 @@ def _layer_norm_forward_kernel(
  https://arxiv.org/abs/1607.06450
  https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
  """
- row_idx = tl.program_id(0)
+ row_idx = tl.program_id(0).to(tl.int64)
  col_offsets = tl.arange(0, BLOCK_SIZE)
  mask = col_offsets < n_cols
 
- Y_ptr += row_idx * Y_row_stride
- X_ptr += row_idx * X_row_stride
- Mean_ptr += row_idx * Mean_row_stride
- RSTD_ptr += row_idx * RSTD_row_stride
-
- X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
- W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
- B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
-
- mean = tl.sum(X_row, axis=0) / n_cols
- Xmm = tl.where(mask, X_row - mean, 0)
- var = tl.sum(Xmm * Xmm, axis=0) / n_cols
+ # Pre-load weights and bias in fp32 to avoid repeated conversions
+ W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+ B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0.0)
+ W_f32 = W_row.to(tl.float32)
+ B_f32 = B_row.to(tl.float32)
+
+ # Calculate pointers for this row
+ row_X_ptr = X_ptr + row_idx * X_row_stride
+ row_Y_ptr = Y_ptr + row_idx * Y_row_stride
+ row_Mean_ptr = Mean_ptr + row_idx * Mean_row_stride
+ row_RSTD_ptr = RSTD_ptr + row_idx * RSTD_row_stride
+
+ # Load input data and convert to fp32 for numerical stability
+ X_row = tl.load(row_X_ptr + col_offsets, mask=mask, other=0.0)
+ X_f32 = X_row.to(tl.float32)
+
+ # Compute statistics in fp32 for numerical stability
+ mean = tl.sum(X_f32, axis=0) / n_cols
+ X_centered = X_f32 - mean
+ # Apply mask to variance calculation to exclude contributions from masked elements
+ X_centered_masked = tl.where(mask, X_centered, 0.0)
+ var = tl.sum(X_centered_masked * X_centered_masked, axis=0) / n_cols
  rstd = rsqrt(var + eps)
 
- tl.store(Mean_ptr, mean)
- tl.store(RSTD_ptr, rstd)
+ # Store statistics (convert back to original dtype only once)
+ tl.store(row_Mean_ptr, mean.to(X_row.dtype))
+ tl.store(row_RSTD_ptr, rstd.to(X_row.dtype))
 
- Y_row = Xmm * rstd * W_row + B_row
+ # Fused normalization and affine transformation
+ # Y = (X - mean) * rstd * W + B = X_centered * rstd * W + B
+ Y_f32 = X_centered * rstd * W_f32 + B_f32
 
- tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+ # Store output (single conversion back to original dtype)
+ tl.store(row_Y_ptr + col_offsets, Y_f32.to(X_row.dtype), mask=mask)
 
 
  @triton.jit
  def _layer_norm_backward_kernel(
  X_ptr, # pointer to input, shape (n_rows, n_cols)
+ stride_x, # stride of each row in input
  W_ptr, # pointer to weights, shape (n_cols,)
  Mean_ptr, # pointer to mean, shape (n_rows,)
+ stride_mean, # stride of each row in mean
  RSTD_ptr, # pointer to rstd, shape (n_rows,)
+ stride_rstd, # stride of each row in rstd
  DX_ptr, # pointer to input grad, shape (n_rows, n_cols)
- DW_ptr, # pointer to weights grad, shape (n_cols,)
- DB_ptr, # pointer to bias grad, shape (n_cols,)
- DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
- stride_x, # stride of each row in input
  stride_dx, # stride of each row in input grad
+ DW_ptr, # pointer to weights grad, shape (n_cols,)
  stride_dw, # stride of each row in weights grad
+ DB_ptr, # pointer to bias grad, shape (n_cols,)
  stride_db, # stride of each row in bias grad
+ DY_ptr, # pointer to output grad, shape (n_rows, n_cols)
  stride_dy, # stride of each row in output grad
  n_rows,
  n_cols,
  rows_per_program: tl.constexpr,
  BLOCK_SIZE: tl.constexpr,
- dtype: tl.constexpr,
  ):
  """
  References:
  https://arxiv.org/abs/1607.06450
  https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
- https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
  """
- row_block_id = tl.program_id(0)
+ row_block_id = tl.program_id(0).to(tl.int64)
  row_start = row_block_id * rows_per_program
  row_end = min((row_block_id + 1) * rows_per_program, n_rows)
  cols = tl.arange(0, BLOCK_SIZE)
  mask = cols < n_cols
 
- dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+ dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
  db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
 
- X_ptr += row_start * stride_x
- Mean_ptr += row_start
- RSTD_ptr += row_start
- DX_ptr += row_start * stride_dx
- DY_ptr += row_start * stride_dy
+ # Pre-load weights once (same optimization as forward pass)
+ w = tl.load(W_ptr + cols, mask=mask, other=0.0)
+ w_f32 = w.to(tl.float32)
+
+ # Calculate pointers for this specific row
+ row_X_ptr = X_ptr + row_start * stride_x
+ row_DX_ptr = DX_ptr + row_start * stride_dx
+ row_DY_ptr = DY_ptr + row_start * stride_dy
+ row_Mean_ptr = Mean_ptr + row_start
+ row_RSTD_ptr = RSTD_ptr + row_start
 
  for _ in range(row_start, row_end):
- x = tl.load(X_ptr + cols, mask=mask, other=0.0)
- w = tl.load(W_ptr + cols, mask=mask, other=0.0)
- dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)
- mean = tl.load(Mean_ptr)
- rstd = tl.load(RSTD_ptr)
-
- x_hat = (x - mean) * rstd
- wdy = w * dy
+ # Load data for this row
+ x = tl.load(row_X_ptr + cols, mask=mask, other=0.0)
+ dy = tl.load(row_DY_ptr + cols, mask=mask, other=0.0)
+ mean = tl.load(row_Mean_ptr)
+ rstd = tl.load(row_RSTD_ptr)
+
+ # Convert to fp32 for numerical stability
+ x_f32 = x.to(tl.float32)
+ dy_f32 = dy.to(tl.float32)
+ mean_f32 = mean.to(tl.float32)
+ rstd_f32 = rstd.to(tl.float32)
+
+ # Compute backward pass for this row
+ x_hat = (x_f32 - mean_f32) * rstd_f32
+ wdy = w_f32 * dy_f32
  c1 = tl.sum(x_hat * wdy, axis=0) / n_cols
  c2 = tl.sum(wdy, axis=0) / n_cols
- dx = (wdy - (x_hat * c1 + c2)) * rstd
- tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)
+ dx = (wdy - (x_hat * c1 + c2)) * rstd_f32
+
+ # Store input gradient
+ tl.store(row_DX_ptr + cols, dx, mask=mask)
 
- dw_row += dy * x_hat
- db_row += dy
+ # Accumulate weight and bias gradients for this thread block's assigned rows
+ dw = dy_f32 * x_hat
+ db = dy_f32
+ dW_row += dw
+ db_row += db
 
- X_ptr += stride_x
- Mean_ptr += 1
- RSTD_ptr += 1
- DX_ptr += stride_dx
- DY_ptr += stride_dy
+ row_X_ptr += stride_x
+ row_DX_ptr += stride_dx
+ row_DY_ptr += stride_dy
+ row_Mean_ptr += stride_mean
+ row_RSTD_ptr += stride_rstd
 
- tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)
- tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)
+ tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask)
+ tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask)
 
 
  def layer_norm_forward(X, W, B, eps):
+ """
+ Args:
+ X: Input tensor of shape (..., hidden_size)
+ W: Weight tensor of shape (hidden_size,)
+ B: Bias tensor of shape (hidden_size,)
+ eps: Small constant for numerical stability
+
+ Returns:
+ Tuple of (output, input, mean, rstd, block_size, num_warps)
+ """
  shape = X.shape
  dim = shape[-1]
  X = X.view(-1, dim)
  n_rows, n_cols = X.shape
+
+ # Calculate optimal block size and warp configuration
  BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+ # Allocate output tensors
  Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
  Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
  RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+
+ # Validate input dimensions
  if X.shape[1] != W.shape[0]:
  raise ValueError(
  f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
@@ -159,7 +206,9 @@ def layer_norm_forward(X, W, B, eps):
  if X.device.type == "xpu":
  kernel_args["grf_mode"] = "large"
 
- _layer_norm_forward_kernel[(n_rows,)](
+ # Launch kernel with one thread block per row for optimal performance
+ grid = (n_rows,)
+ _layer_norm_forward_kernel[grid](
  Y,
  Y.stride(0),
  X,
@@ -176,12 +225,25 @@ def layer_norm_forward(X, W, B, eps):
  eps,
  BLOCK_SIZE=BLOCK_SIZE,
  num_warps=num_warps,
- **kernel_args, # XPU-specific optimization
+ **kernel_args,
  )
+
  return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
 
 
  def layer_norm_backward(dY, X, W, B, Mean, RSTD):
+ """
+ Args:
+ dY: Gradient of output
+ X: Input tensor
+ W: Weight tensor
+ B: Bias tensor
+ Mean: Pre-computed mean
+ RSTD: Pre-computed reciprocal standard deviation
+
+ Returns:
+ Tuple of (input_grad, weight_grad, bias_grad)
+ """
  shape = dY.shape
  dim = shape[-1]
  dY = dY.view(-1, dim)
@@ -193,59 +255,52 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
  elif X.device.type == "xpu":
  sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
 
- DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
- _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
- _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
+ # fp32 for numerical stability especially.
+ _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+ _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
 
+ # Calculate optimal block size and warp configuration
  BLOCK_SIZE, num_warps = calculate_settings(n_cols)
  if n_cols > BLOCK_SIZE:
- raise RuntimeError(
- f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
- )
-
+ raise RuntimeError(f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}.")
  rows_per_program = math.ceil(n_rows / sm_count)
  grid = (sm_count,)
- triton_dtype = (
- tl.float32
- if X.dtype == torch.float32
- else tl.bfloat16
- if X.dtype == torch.bfloat16
- else tl.float16
- if X.dtype == torch.float16
- else tl.float32 # fallback to float32 for other types
- )
 
+ # Allocate gradient tensors
+ DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+
+ kernel_args = {"num_warps": num_warps}
  # XPU-specific optimization
- kernel_args = {}
  if X.device.type == "xpu":
  kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
 
+ # Launch kernel with one thread block per row for optimal performance
  _layer_norm_backward_kernel[grid](
  X,
+ X.stride(0),
  W,
  Mean,
+ Mean.stride(0),
  RSTD,
+ RSTD.stride(0),
  DX,
- _DW,
- _DB,
- dY,
- X.stride(0),
  DX.stride(0),
+ _DW,
  _DW.stride(0),
+ _DB,
  _DB.stride(0),
+ dY,
  dY.stride(0),
  n_rows,
  n_cols,
- rows_per_program,
+ rows_per_program=rows_per_program,
  BLOCK_SIZE=BLOCK_SIZE,
- dtype=triton_dtype,
- **kernel_args, # XPU-specific optimization
+ **kernel_args,
  )
 
- DW = _DW.sum(dim=0).to(W.dtype)
- DB = _DB.sum(dim=0).to(W.dtype)
-
  DX = DX.view(*shape)
+ DW = _DW.sum(dim=0).to(W.dtype)
+ DB = _DB.sum(dim=0).to(B.dtype)
  return DX, DW, DB
 
 
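
The reworked layer_norm backward does all math in fp32, accumulates the per-program DW/DB partials in fp32 buffers, and casts only after the final sum over programs. A reference PyTorch sketch of the per-row math it implements (illustrative, not the kernel itself):

    import torch

    def layer_norm_backward_row(x, dy, w, mean, rstd):
        # Mirror of the kernel's per-row math: everything in fp32, partials returned
        # for a later sum over rows (the kernel sums per program, then over programs).
        x, dy, w = x.float(), dy.float(), w.float()
        n_cols = x.numel()
        x_hat = (x - mean) * rstd
        wdy = w * dy
        c1 = (x_hat * wdy).sum() / n_cols
        c2 = wdy.sum() / n_cols
        dx = (wdy - (x_hat * c1 + c2)) * rstd
        dw_partial = dy * x_hat
        db_partial = dy
        return dx, dw_partial, db_partial
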
liger_kernel/ops/llama4_rope.py ADDED
@@ -0,0 +1,225 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ def _prepare_freqs(freqs_cis: torch.Tensor, seq_len: int, head_dim_half: int):
+ # Split or unpack complex frequencies into real and imag parts
+ if freqs_cis.is_complex():
+ freqs_real = freqs_cis.real
+ freqs_imag = freqs_cis.imag
+ else:
+ # Already split: last dim should be 2*head_dim_half
+ if freqs_cis.shape[-1] == 2 * head_dim_half:
+ freqs_real = freqs_cis[..., :head_dim_half]
+ freqs_imag = freqs_cis[..., head_dim_half:]
+ else:
+ raise ValueError(
+ f"Unexpected freqs_cis shape for non-complex input: {freqs_cis.shape}, expected last dim = {2 * head_dim_half}"
+ )
+
+ # Canonicalize to shape (seq_len, head_dim_half):
+ # 1) Ensure the last dimension is head_dim_half
+ if freqs_real.shape[-1] != head_dim_half:
+ raise ValueError(f"Unexpected last dim for freqs: {freqs_real.shape[-1]} (expected {head_dim_half})")
+ # 2) Flatten all leading dims to a single row dimension
+ freqs_real = freqs_real.reshape(-1, head_dim_half)
+ freqs_imag = freqs_imag.reshape(-1, head_dim_half)
+ # 3) If we have fewer rows than seq_len, allow broadcasting when single row
+ if freqs_real.shape[0] < seq_len:
+ if freqs_real.shape[0] == 1:
+ freqs_real = freqs_real.expand(seq_len, -1)
+ freqs_imag = freqs_imag.expand(seq_len, -1)
+ else:
+ raise ValueError(f"Insufficient rows in freqs: {freqs_real.shape[0]} < seq_len={seq_len}")
+ # 4) If we have more rows than seq_len (e.g., batch present), take the first seq_len rows
+ elif freqs_real.shape[0] > seq_len:
+ freqs_real = freqs_real[:seq_len]
+ freqs_imag = freqs_imag[:seq_len]
+
+ return freqs_real, freqs_imag
+
+
+ def _maybe_to_dtype(t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ return t if t.dtype == dtype else t.to(dtype)
+
+
+ def _maybe_contiguous(t: torch.Tensor) -> torch.Tensor:
+ return t if t.is_contiguous() else t.contiguous()
+
+
+ def _cast_and_contiguous(q, k, freqs_real, freqs_imag):
+ # Choose compute dtype: use fp32 only when inputs are fp32; otherwise keep input dtype for performance
+ compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype
+
+ # Make sure q/k share the same dtype before casting to compute dtype
+ if k.dtype != q.dtype:
+ k = k.to(q.dtype)
+
+ q = _maybe_contiguous(_maybe_to_dtype(q, compute_dtype))
+ k = _maybe_contiguous(_maybe_to_dtype(k, compute_dtype))
+ freqs_real = _maybe_contiguous(_maybe_to_dtype(freqs_real, compute_dtype))
+ freqs_imag = _maybe_contiguous(_maybe_to_dtype(freqs_imag, compute_dtype))
+ return q, k, freqs_real, freqs_imag
+
+
+ @triton.jit
+ def _llama4_rope_kernel(
+ q_ptr,
+ k_ptr,
+ freqs_real_ptr,
+ freqs_imag_ptr,
+ q_row_stride,
+ k_row_stride,
+ q_head_stride,
+ k_head_stride,
+ freqs_row_stride,
+ seq_len,
+ batch_size,
+ imag_sign,
+ head_dim_half: tl.constexpr,
+ n_q_heads: tl.constexpr,
+ n_k_heads: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+ ):
+ """
+ H100-optimized RoPE kernel with improved parallelization across heads and dimensions.
+ Grid: (batch*seq, head)
+ """
+ # 2D grid
+ pid_bs = tl.program_id(0) # over batch*seq
+ pid_h = tl.program_id(1) # over heads
+
+ batch_idx = pid_bs // seq_len
+ seq_idx = pid_bs % seq_len
+
+ # Bounds check
+ if batch_idx >= batch_size or seq_idx >= seq_len:
+ return
+
+ # Base pointers for this (batch, seq) position
+ base_offset = batch_idx * seq_len + seq_idx
+ q_base = q_ptr + base_offset * q_row_stride
+ k_base = k_ptr + base_offset * k_row_stride
+
+ # Tiling over dim/2
+ for d_start in tl.static_range(0, head_dim_half, BLOCK_SIZE):
+ d_indices = d_start + tl.arange(0, BLOCK_SIZE)
+ mask_d = d_indices < head_dim_half
+
+ # Load frequencies once per tile (freqs layout: [seq_len, head_dim_half])
+ freq_idx = d_indices
+ freqs_real = tl.load(freqs_real_ptr + seq_idx * freqs_row_stride + freq_idx, mask=mask_d, other=0.0)
+ freqs_imag = tl.load(freqs_imag_ptr + seq_idx * freqs_row_stride + freq_idx, mask=mask_d, other=0.0)
+ freqs_imag = freqs_imag * imag_sign
+
+ # Process one query head per program in pid_h
+ if pid_h < n_q_heads:
+ q_head_ptr = q_base + pid_h * q_head_stride
+ q_real = tl.load(q_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
+ q_imag = tl.load(q_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)
+
+ # Complex multiply with FMAs: (a+ib)*(c+i d) = (a*c - b*d) + i(a*d + b*c)
+ new_q_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag))
+ new_q_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real)
+
+ tl.store(q_head_ptr + d_indices * 2, new_q_real, mask=mask_d)
+ tl.store(q_head_ptr + d_indices * 2 + 1, new_q_imag, mask=mask_d)
+
+ # Process one key head per program in pid_h
+ if pid_h < n_k_heads:
+ k_head_ptr = k_base + pid_h * k_head_stride
+ k_real = tl.load(k_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
+ k_imag = tl.load(k_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)
+
+ new_k_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag))
+ new_k_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real)
+
+ tl.store(k_head_ptr + d_indices * 2, new_k_real, mask=mask_d)
+ tl.store(k_head_ptr + d_indices * 2 + 1, new_k_imag, mask=mask_d)
+
+
+ def _select_kernel_meta(head_dim_half: int):
+ # Heuristic tuning for block size and num_warps
+ if head_dim_half >= 256:
+ return 128, 8
+ if head_dim_half >= 96:
+ return 128, 4
+ if head_dim_half >= 48:
+ return 64, 4
+ if head_dim_half >= 24:
+ return 32, 2
+ return 16, 2
+
+
+ def llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE: int = None, imag_sign: float = 1.0):
+ # Save original dtype for casting back
+ original_dtype = q.dtype
+
+ batch_size, seq_len, n_q_heads, head_dim = q.shape
+ _, _, n_k_heads, _ = k.shape
+ head_dim_half = head_dim // 2
+
+ # Prepare frequencies
+ freqs_real, freqs_imag = _prepare_freqs(freqs_cis, seq_len, head_dim_half)
+
+ # Cast to appropriate dtype and make contiguous only when needed
+ q, k, freqs_real, freqs_imag = _cast_and_contiguous(q, k, freqs_real, freqs_imag)
+
+ # H100-optimized meta-params
+ if BLOCK_SIZE is None:
+ BLOCK_SIZE, num_warps = _select_kernel_meta(head_dim_half)
+ else:
+ # Provide a default num_warps if caller pins BLOCK_SIZE
+ _, num_warps = _select_kernel_meta(head_dim_half)
+
+ # 2D grid: one program per (batch, seq, head)
+ n_heads_max = max(n_q_heads, n_k_heads)
+ grid = (batch_size * seq_len, n_heads_max)
+
+ # Launch kernel
+ _llama4_rope_kernel[grid](
+ q,
+ k,
+ freqs_real,
+ freqs_imag,
+ q.stride(1),
+ k.stride(1),
+ q.stride(2),
+ k.stride(2),
+ freqs_real.stride(0),
+ seq_len,
+ batch_size,
+ imag_sign,
+ head_dim_half,
+ n_q_heads,
+ n_k_heads,
+ BLOCK_SIZE,
+ num_warps=num_warps,
+ num_stages=2,
+ )
+
+ # Cast back to original dtype only if it differs from compute dtype
+ if q.dtype != original_dtype:
+ q = q.to(original_dtype)
+ if k.dtype != original_dtype:
+ k = k.to(original_dtype)
+
+ return q, k
+
+
+ class LigerLlama4RopeFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: int = None):
+ q_out, k_out = llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE, imag_sign=1.0)
+ ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis)
+ ctx.BLOCK_SIZE = BLOCK_SIZE
+ return q_out, k_out
+
+ @staticmethod
+ def backward(ctx, dq, dk):
+ (freqs_cis,) = ctx.saved_tensors
+ BLOCK_SIZE = getattr(ctx, "BLOCK_SIZE", None)
+ # Use imag_sign=-1.0 for conjugate without materializing a new tensor
+ dq_out, dk_out = llama4_rope_forward(dq, dk, freqs_cis, BLOCK_SIZE, imag_sign=-1.0)
+ return dq_out, dk_out, None
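
For orientation, LigerLlama4RopeFunction expects interleaved real/imag q and k of shape (batch, seq_len, n_heads, head_dim) plus complex freqs_cis with head_dim // 2 frequencies per position; the rotation is written back in place after dtype/contiguity normalization, and backward reuses the same kernel with the conjugate sign. A hypothetical usage sketch (shapes, device, and values are illustrative; a GPU is required for the Triton kernel):

    import torch

    from liger_kernel.ops.llama4_rope import LigerLlama4RopeFunction

    batch, seq_len, n_q_heads, n_kv_heads, head_dim = 2, 16, 8, 2, 64
    q = torch.randn(batch, seq_len, n_q_heads, head_dim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(batch, seq_len, n_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)

    # One complex frequency row per position, head_dim // 2 entries each.
    angles = torch.rand(seq_len, head_dim // 2, device="cuda")
    freqs_cis = torch.polar(torch.ones_like(angles), angles)

    # Applies the rotation to q and k; backward (not shown) runs with imag_sign=-1.0.
    q_rot, k_rot = LigerLlama4RopeFunction.apply(q, k, freqs_cis)
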