liger-kernel 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +39 -11
  6. liger_kernel/ops/__init__.py +141 -0
  7. liger_kernel/ops/backends/README.md +151 -0
  8. liger_kernel/ops/backends/__init__.py +13 -0
  9. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  10. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
  11. liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
  12. liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
  13. liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
  14. liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
  15. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
  16. liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
  17. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  18. liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
  19. liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
  20. liger_kernel/ops/backends/registry.py +61 -0
  21. liger_kernel/ops/cross_entropy.py +71 -11
  22. liger_kernel/ops/dyt.py +5 -2
  23. liger_kernel/ops/fused_add_rms_norm.py +21 -23
  24. liger_kernel/ops/fused_linear_cross_entropy.py +32 -5
  25. liger_kernel/ops/geglu.py +5 -3
  26. liger_kernel/ops/group_norm.py +12 -8
  27. liger_kernel/ops/grpo_loss.py +3 -1
  28. liger_kernel/ops/kl_div.py +8 -11
  29. liger_kernel/ops/layer_norm.py +89 -69
  30. liger_kernel/ops/poly_norm.py +19 -21
  31. liger_kernel/ops/rms_norm.py +149 -71
  32. liger_kernel/ops/tiled_mlp.py +136 -0
  33. liger_kernel/ops/utils.py +25 -0
  34. liger_kernel/transformers/__init__.py +25 -0
  35. liger_kernel/transformers/auto_model.py +21 -0
  36. liger_kernel/transformers/cross_entropy.py +9 -4
  37. liger_kernel/transformers/dyt.py +1 -1
  38. liger_kernel/transformers/experimental/embedding.py +1 -1
  39. liger_kernel/transformers/functional.py +44 -26
  40. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  41. liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
  42. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  43. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  44. liger_kernel/transformers/geglu.py +1 -1
  45. liger_kernel/transformers/group_norm.py +1 -1
  46. liger_kernel/transformers/grpo_loss.py +57 -2
  47. liger_kernel/transformers/jsd.py +1 -1
  48. liger_kernel/transformers/kl_div.py +1 -1
  49. liger_kernel/transformers/layer_norm.py +1 -1
  50. liger_kernel/transformers/llama4_rope.py +1 -1
  51. liger_kernel/transformers/model/exaone4.py +136 -0
  52. liger_kernel/transformers/model/falcon_h1.py +19 -5
  53. liger_kernel/transformers/model/gemma.py +17 -6
  54. liger_kernel/transformers/model/gemma2.py +17 -8
  55. liger_kernel/transformers/model/gemma3.py +35 -16
  56. liger_kernel/transformers/model/glm4.py +16 -4
  57. liger_kernel/transformers/model/glm4v.py +16 -4
  58. liger_kernel/transformers/model/glm4v_moe.py +23 -4
  59. liger_kernel/transformers/model/gpt_oss.py +211 -0
  60. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  61. liger_kernel/transformers/model/internvl.py +12 -5
  62. liger_kernel/transformers/model/llama.py +14 -5
  63. liger_kernel/transformers/model/llama4.py +16 -4
  64. liger_kernel/transformers/model/llava.py +12 -4
  65. liger_kernel/transformers/model/loss_utils.py +37 -3
  66. liger_kernel/transformers/model/mistral.py +15 -6
  67. liger_kernel/transformers/model/mixtral.py +16 -7
  68. liger_kernel/transformers/model/mllama.py +12 -4
  69. liger_kernel/transformers/model/olmo2.py +16 -4
  70. liger_kernel/transformers/model/olmo3.py +142 -0
  71. liger_kernel/transformers/model/output_classes.py +147 -0
  72. liger_kernel/transformers/model/paligemma.py +23 -5
  73. liger_kernel/transformers/model/phi3.py +14 -7
  74. liger_kernel/transformers/model/qwen2.py +16 -3
  75. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  76. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  77. liger_kernel/transformers/model/qwen3.py +20 -5
  78. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  79. liger_kernel/transformers/model/qwen3_next.py +17 -5
  80. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  81. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  82. liger_kernel/transformers/model/smollm3.py +15 -6
  83. liger_kernel/transformers/monkey_patch.py +584 -49
  84. liger_kernel/transformers/multi_token_attention.py +1 -1
  85. liger_kernel/transformers/poly_norm.py +1 -1
  86. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  87. liger_kernel/transformers/rms_norm.py +8 -3
  88. liger_kernel/transformers/rope.py +45 -1
  89. liger_kernel/transformers/softmax.py +1 -1
  90. liger_kernel/transformers/sparsemax.py +1 -1
  91. liger_kernel/transformers/swiglu.py +18 -1
  92. liger_kernel/transformers/tiled_mlp.py +125 -0
  93. liger_kernel/transformers/tvd.py +1 -1
  94. liger_kernel/utils.py +54 -0
  95. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +14 -4
  96. liger_kernel-0.6.5.dist-info/RECORD +134 -0
  97. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
  98. liger_kernel-0.6.3.dist-info/RECORD +0 -111
  99. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
  100. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
  101. {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
liger_kernel/ops/poly_norm.py
@@ -7,8 +7,11 @@ import triton.language as tl
  from liger_kernel.ops.utils import calculate_settings
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import get_npu_core_count
+ from liger_kernel.ops.utils import set_large_grf_mode
+ from liger_kernel.utils import is_npu_available

- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
  try:
  from triton.language.extra.libdevice import rsqrt
  except ModuleNotFoundError:
@@ -138,20 +141,19 @@ def _poly_norm_backward_kernel(
  w1 = tl.load(W_ptr + 1).to(tl.float32)
  w2 = tl.load(W_ptr + 2).to(tl.float32)

- dY_ptr += row_start * dY_row_stride
- dX_ptr += row_start * dX_row_stride
- X_ptr += row_start * X_row_stride
- RSTD_ptr += row_start * RSTD_row_stride
+ for row_idx in range(row_start, row_end):
+ dy_base = dY_ptr + row_idx * dY_row_stride
+ x_base = X_ptr + row_idx * X_row_stride
+ dx_base = dX_ptr + row_idx * dX_row_stride
+ rstd_base = RSTD_ptr + row_idx * RSTD_row_stride

- for _ in range(row_start, row_end):
- # Load input and gradient
- dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
- X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
+ dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0).to(tl.float32)
+ X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0).to(tl.float32)

  # Load cached rstd values
- rstd_3 = tl.load(RSTD_ptr + 0).to(tl.float32)
- rstd_2 = tl.load(RSTD_ptr + 1).to(tl.float32)
- rstd_1 = tl.load(RSTD_ptr + 2).to(tl.float32)
+ rstd_3 = tl.load(rstd_base + 0).to(tl.float32)
+ rstd_2 = tl.load(rstd_base + 1).to(tl.float32)
+ rstd_1 = tl.load(rstd_base + 2).to(tl.float32)

  # Compute powers
  X_pow3 = X_row * X_row * X_row
@@ -188,13 +190,7 @@ def _poly_norm_backward_kernel(
  dX_row = grad_x_3 + grad_x_2 + grad_x_1

  # Store gradient
- tl.store(dX_ptr + col_offsets, dX_row, mask=mask)
-
- # Update pointers
- dY_ptr += dY_row_stride
- dX_ptr += dX_row_stride
- X_ptr += X_row_stride
- RSTD_ptr += RSTD_row_stride
+ tl.store(dx_base + col_offsets, dX_row, mask=mask)

  # Store accumulated gradients (scalars)
  tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
@@ -237,7 +233,7 @@ def poly_norm_forward(X, W, B, eps=1e-6):
  # XPU-specific optimization
  kernel_args = {}
  if X.device.type == "xpu":
- kernel_args["grf_mode"] = "large"
+ set_large_grf_mode(kernel_args)

  # Launch kernel
  _poly_norm_forward_kernel[(n_rows,)](
@@ -290,6 +286,8 @@ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
  sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
  elif X.device.type == "xpu":
  sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+ elif X.device.type == "npu":
+ sm_count = get_npu_core_count()

  # Allocate or reuse gradients
  if in_place is True:
@@ -306,7 +304,7 @@ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
  # XPU-specific optimization
  kernel_args = {}
  if X.device.type == "xpu":
- kernel_args["grf_mode"] = "large"
+ set_large_grf_mode(kernel_args)

  # Launch backward kernel
  _poly_norm_backward_kernel[grid](
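
Both the poly_norm hunks above and the rms_norm hunks below size their persistent grid from the device's compute-unit count, with a new NPU branch that calls get_npu_core_count. A minimal sketch of that dispatch pattern (the NPU constant is a stand-in; the real helper lives in liger_kernel/ops/utils.py and is not part of this excerpt):

import torch


def _program_count(device: torch.device) -> int:
    # One persistent program per compute unit, mirroring the sm_count logic in the hunks.
    if device.type == "cuda":
        return torch.cuda.get_device_properties(device).multi_processor_count
    if device.type == "xpu":
        return torch.xpu.get_device_properties(device).gpu_eu_count
    if device.type == "npu":
        return 40  # stand-in value; liger's get_npu_core_count() queries the Ascend runtime
    return 1
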
liger_kernel/ops/rms_norm.py
@@ -20,9 +20,12 @@ import triton.language as tl
  from liger_kernel.ops.utils import calculate_settings
  from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import get_npu_core_count
+ from liger_kernel.ops.utils import set_large_grf_mode
  from liger_kernel.ops.utils import torch_to_triton_dtype
+ from liger_kernel.utils import is_npu_available

- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
  try:
  # typical import path with dispatch available
  from triton.language.extra.libdevice import rsqrt
@@ -52,6 +55,7 @@ def _rms_norm_forward_kernel(
  eps,
  offset,
  casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
+ elementwise_affine: tl.constexpr,
  BLOCK_SIZE: tl.constexpr,
  ):
  """
@@ -67,13 +71,14 @@ def _rms_norm_forward_kernel(
  col_offsets = tl.arange(0, BLOCK_SIZE)
  mask = col_offsets < n_cols

- Y_ptr += row_idx * Y_row_stride
- X_ptr += row_idx * X_row_stride
- RSTD_ptr += row_idx * RSTD_row_stride
+ y_base = Y_ptr + row_idx * Y_row_stride
+ x_base = X_ptr + row_idx * X_row_stride
+ rstd_base = RSTD_ptr + row_idx * RSTD_row_stride

- X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
+ X_row = tl.load(x_base + col_offsets, mask=mask, other=0)
  X_row_dtype = X_row.dtype
- W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+ if elementwise_affine:
+ W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)

  # On Llama, only rstd is computed on fp32
  if casting_mode == _CASTING_MODE_LLAMA:
@@ -81,7 +86,8 @@ def _rms_norm_forward_kernel(

  # Gemma computes everything on fp32, and then casts back the output to the original dtype
  if casting_mode == _CASTING_MODE_GEMMA:
- W_row = W_row.to(tl.float32)
+ if elementwise_affine:
+ W_row = W_row.to(tl.float32)
  X_row = X_row.to(tl.float32)

  if casting_mode == _CASTING_MODE_NONE:
@@ -94,7 +100,7 @@ def _rms_norm_forward_kernel(
  # We can save time by caching rms with minimal memory overhead
  # because rms is much smaller compared to X_row, as rms is for each row.
  # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
- tl.store(RSTD_ptr, rstd)
+ tl.store(rstd_base, rstd)

  X_row = X_row * rstd

@@ -102,12 +108,15 @@ def _rms_norm_forward_kernel(
  if casting_mode == _CASTING_MODE_LLAMA:
  X_row = X_row.to(X_row_dtype)

- Y_row = X_row * (offset + W_row)
+ if elementwise_affine:
+ Y_row = X_row * (offset + W_row)
+ else:
+ Y_row = X_row

  if casting_mode == _CASTING_MODE_GEMMA:
  Y_row = Y_row.to(X_row_dtype)

- tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
+ tl.store(y_base + col_offsets, Y_row, mask=mask)


  @triton.jit
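
For reference, the forward kernel above computes standard RMSNorm, and the new elementwise_affine flag skips the scale when no weight is supplied. A plain-PyTorch sketch of the "llama" casting mode (names are mine, not from the package):

from typing import Optional

import torch


def rms_norm_reference(x: torch.Tensor, w: Optional[torch.Tensor], eps: float, offset: float) -> torch.Tensor:
    # rstd is computed in fp32; the normalized activations are cast back to the
    # input dtype before the optional elementwise-affine scale, as in the kernel.
    rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    y = (x.float() * rstd).to(x.dtype)
    if w is not None:  # elementwise_affine path
        y = y * (offset + w)
    return y
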
@@ -128,8 +137,9 @@ def _rms_norm_backward_kernel(
  n_rows,
  n_cols,
  offset,
- rows_per_program: tl.constexpr,
+ rows_per_program,
  casting_mode: tl.constexpr,
+ elementwise_affine: tl.constexpr,
  BLOCK_SIZE: tl.constexpr,
  ):
  """
@@ -143,55 +153,63 @@ def _rms_norm_backward_kernel(
  col_offsets = tl.arange(0, BLOCK_SIZE)
  mask = col_offsets < n_cols

- dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+ if elementwise_affine:
+ dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

- dY_ptr += row_start * dY_row_stride
- dX_ptr += row_start * dX_row_stride
+ if elementwise_affine:
+ W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+ W_row = W_row + offset

- X_ptr += row_start * X_row_stride
- RSTD_ptr += row_start
+ for row_idx in range(row_start, row_end):
+ dy_base = dY_ptr + row_idx * dY_row_stride
+ dx_base = dX_ptr + row_idx * dX_row_stride

- W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
- W_row = W_row + offset
+ x_base = X_ptr + row_idx * X_row_stride
+ rstd_base = RSTD_ptr + row_idx * RSTD_row_stride

- for _ in range(row_start, row_end):
- dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
- X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
+ dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0)
+ X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0)

  # Get cached rms
- rstd_row = tl.load(RSTD_ptr)
+ rstd_row = tl.load(rstd_base)

  X_row = X_row.to(tl.float32)

  # Different bacward graphs for different casting modes
  if casting_mode == _CASTING_MODE_LLAMA:
- m = (dY_row * W_row).to(tl.float32)
+ if elementwise_affine:
+ m = (dY_row * W_row).to(tl.float32)
+ else:
+ m = dY_row.to(tl.float32)

  elif casting_mode == _CASTING_MODE_GEMMA:
  dY_row = dY_row.to(tl.float32)
- m = dY_row * W_row
+ if elementwise_affine:
+ m = dY_row * W_row
+ else:
+ m = dY_row
  else:
- m = dY_row * W_row
+ if elementwise_affine:
+ m = dY_row * W_row
+ else:
+ m = dY_row

  dX_row = rstd_row * m

  dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)

- # calculate the gradient of W
- if casting_mode == _CASTING_MODE_LLAMA:
- dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
- else:
- # here X_row is already in fp32 (see previous if block)
- dW_row += dY_row * (X_row * rstd_row)
+ if elementwise_affine:
+ # calculate the gradient of W
+ if casting_mode == _CASTING_MODE_LLAMA:
+ dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+ else:
+ # here X_row is already in fp32 (see previous if block)
+ dW_row += dY_row * (X_row * rstd_row)

- tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
+ tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask)

- dY_ptr += dY_row_stride
- dX_ptr += dX_row_stride
- X_ptr += X_row_stride
- RSTD_ptr += RSTD_row_stride
-
- tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
+ if elementwise_affine:
+ tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)


  @triton.jit
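
Written out, the row-wise backward kernel above computes the usual RMSNorm gradient (my notation: $r$ is the row's cached rstd, $n$ the hidden size; the weight term drops out on the new elementwise_affine=False path):

$$m_j = \frac{\partial L}{\partial y_j}\,(w_j + \mathrm{offset}) \quad\text{(or } m_j = \tfrac{\partial L}{\partial y_j}\text{ when there is no weight)}$$
$$\frac{\partial L}{\partial x_j} = r\,m_j - \frac{r^{3}}{n}\Big(\sum_{k=1}^{n} m_k x_k\Big)x_j, \qquad \frac{\partial L}{\partial w_j} = \sum_{\text{rows}} \frac{\partial L}{\partial y_j}\, r\, x_j$$
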
@@ -209,6 +227,7 @@ def _block_rms_norm_forward_kernel(
  eps,
  offset,
  casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
+ elementwise_affine: tl.constexpr,
  BLOCK_SIZE: tl.constexpr,
  BLOCK_ROW: tl.constexpr,
  ):
@@ -232,7 +251,8 @@ def _block_rms_norm_forward_kernel(
  other=0,
  )
  X_row_dtype = X_row.dtype
- W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
+ if elementwise_affine:
+ W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)

  # On Llama, only rstd is computed on fp32
  if casting_mode == _CASTING_MODE_LLAMA:
@@ -240,7 +260,8 @@ def _block_rms_norm_forward_kernel(

  # Gemma computes everything on fp32, and then casts back the output to the original dtype
  if casting_mode == _CASTING_MODE_GEMMA:
- W_row = W_row.to(tl.float32)
+ if elementwise_affine:
+ W_row = W_row.to(tl.float32)
  X_row = X_row.to(tl.float32)

  if casting_mode == _CASTING_MODE_NONE:
@@ -261,7 +282,10 @@ def _block_rms_norm_forward_kernel(
  if casting_mode == _CASTING_MODE_LLAMA:
  X_row = X_row.to(X_row_dtype)

- Y_row = X_row * (offset + W_row)[None, :]
+ if elementwise_affine:
+ Y_row = X_row * (offset + W_row)[None, :]
+ else:
+ Y_row = X_row

  if casting_mode == _CASTING_MODE_GEMMA:
  Y_row = Y_row.to(X_row_dtype)
@@ -291,8 +315,8 @@ def _block_rms_norm_backward_kernel(
  n_rows,
  n_cols,
  offset,
- rows_per_program: tl.constexpr,
  casting_mode: tl.constexpr,
+ elementwise_affine: tl.constexpr,
  BLOCK_SIZE: tl.constexpr,
  BLOCK_ROW: tl.constexpr,
  ):
@@ -307,10 +331,11 @@ def _block_rms_norm_backward_kernel(
  col_offsets = tl.arange(0, BLOCK_SIZE)
  col_mask = col_offsets < n_cols

- dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+ if elementwise_affine:
+ dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

- W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
- W_row = W_row + offset
+ W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+ W_row = W_row + offset

  for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
  row_idx = start + tl.arange(0, BLOCK_ROW)
@@ -333,13 +358,22 @@ def _block_rms_norm_backward_kernel(

  # Different bacward graphs for different casting modes
  if casting_mode == _CASTING_MODE_LLAMA:
- m = (dY_row * W_row[None, :]).to(tl.float32)
+ if elementwise_affine:
+ m = (dY_row * W_row[None, :]).to(tl.float32)
+ else:
+ m = dY_row.to(tl.float32)

  elif casting_mode == _CASTING_MODE_GEMMA:
  dY_row = dY_row.to(tl.float32)
- m = dY_row * W_row[None, :]
+ if elementwise_affine:
+ m = dY_row * W_row[None, :]
+ else:
+ m = dY_row
  else:
- m = dY_row * W_row[None, :]
+ if elementwise_affine:
+ m = dY_row * W_row[None, :]
+ else:
+ m = dY_row

  dX_row = rstd_row[:, None] * m

@@ -347,12 +381,13 @@ def _block_rms_norm_backward_kernel(
  -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
  )

- # calculate the gradient of W
- if casting_mode == _CASTING_MODE_LLAMA:
- dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]).to(X_dtype), 0)
- else:
- # here X_row is already in fp32 (see previous if block)
- dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+ if elementwise_affine:
+ if casting_mode == _CASTING_MODE_LLAMA:
+ # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+ dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
+ else:
+ # here X_row is already in fp32 (see previous if block)
+ dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)

  tl.store(
  dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
@@ -360,7 +395,8 @@ def _block_rms_norm_backward_kernel(
  mask=row_mask[:, None] & col_mask[None, :],
  )

- tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+ if elementwise_affine:
+ tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)


  _str_to_casting_mode = {
@@ -389,13 +425,19 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
  rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
  RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)

- # Check constraints.
- assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+ if W is not None:
+ # Check constraints.
+ assert X.shape[1] == W.shape[0], (
+ "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+ )
+ elementwise_affine = True
+ else:
+ elementwise_affine = False

  # XPU-specific optimization
  kernel_args = {}
  if X.device.type == "xpu":
- kernel_args["grf_mode"] = "large"
+ set_large_grf_mode(kernel_args)
  if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
  _rms_norm_forward_kernel[(n_rows,)](
  Y,
@@ -403,13 +445,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
  X,
  X.stride(0),
  W,
- W.stride(0),
+ W.stride(0) if elementwise_affine else 0,
  RSTD,
  RSTD.stride(0),
  n_cols,
  eps,
  offset,
  casting_mode,
+ elementwise_affine=elementwise_affine,
  BLOCK_SIZE=BLOCK_SIZE,
  num_warps=num_warps,
  **kernel_args, # XPU-specific optimization
@@ -423,7 +466,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
  X,
  X.stride(0),
  W,
- W.stride(0),
+ W.stride(0) if elementwise_affine else 0,
  RSTD,
  RSTD.stride(0),
  n_rows,
@@ -431,6 +474,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
  eps,
  offset,
  casting_mode,
+ elementwise_affine=elementwise_affine,
  BLOCK_SIZE=BLOCK_SIZE,
  num_warps=num_warps,
  **kernel_args, # XPU-specific optimization
@@ -449,9 +493,16 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
  sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
  elif X.device.type == "xpu":
  sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+ elif X.device.type == "npu":
+ sm_count = get_npu_core_count()

- # fp32 for numerical stability especially.
- _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+ if W is not None:
+ # fp32 for numerical stability especially.
+ _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+ elementwise_affine = True
+ else:
+ _dW = None
+ elementwise_affine = False

  if n_cols > BLOCK_SIZE:
  raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
@@ -466,7 +517,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
  # XPU-specific optimization
  kernel_args = {}
  if X.device.type == "xpu":
- kernel_args["grf_mode"] = "large"
+ set_large_grf_mode(kernel_args)

  if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
  _rms_norm_backward_kernel[grid](
@@ -478,16 +529,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
  X.stride(0),
  torch_to_triton_dtype[X.dtype],
  W,
- W.stride(0),
+ W.stride(0) if elementwise_affine else 0,
  RSTD,
  RSTD.stride(0),
  _dW,
- _dW.stride(0),
+ _dW.stride(0) if elementwise_affine else 0,
  n_rows,
  n_cols,
  offset,
  rows_per_program,
  casting_mode,
+ elementwise_affine=elementwise_affine,
  BLOCK_SIZE=BLOCK_SIZE,
  num_warps=num_warps,
  **kernel_args, # XPU-specific optimization
@@ -504,22 +556,26 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
  X.stride(0),
  torch_to_triton_dtype[X.dtype],
  W,
- W.stride(0),
+ W.stride(0) if elementwise_affine else 0,
  RSTD,
  RSTD.stride(0),
  _dW,
- _dW.stride(0),
+ _dW.stride(0) if elementwise_affine else 0,
  n_rows,
  n_cols,
  offset,
- rows_per_program,
  casting_mode,
+ elementwise_affine=elementwise_affine,
  BLOCK_SIZE=BLOCK_SIZE,
  num_warps=num_warps,
  **kernel_args, # XPU-specific optimization
  )
  dX = dX.view(*shape)
- dW = _dW.sum(dim=0).to(W.dtype)
+
+ if elementwise_affine:
+ dW = _dW.sum(dim=0).to(W.dtype)
+ else:
+ dW = None

  return dX, dW

@@ -553,6 +609,13 @@ class LigerRMSNormFunction(torch.autograd.Function):
  X: (B, T, H) or (BxT, H)
  W: (H,)
  """
+ if isinstance(X, torch.distributed.tensor.DTensor):
+ # Input tensor is output of a tensor parallel module and
+ # needs to be gathered to a local tensor to compute
+ # RMSE layer norm on each TP worker.
+ # TODO: support CP.
+ X = X.full_tensor()
+
  Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
  ctx.offset = offset
  ctx.casting_mode = casting_mode
@@ -560,7 +623,11 @@ class LigerRMSNormFunction(torch.autograd.Function):
  ctx.row_mode = row_mode
  ctx.BLOCK_SIZE = BLOCK_SIZE
  ctx.num_warps = num_warps
- ctx.save_for_backward(X, W, RSTD)
+ ctx.elementwise_affine = W is not None
+ if W is not None:
+ ctx.save_for_backward(X, W, RSTD)
+ else:
+ ctx.save_for_backward(X, RSTD)
  return Y

  @staticmethod
@@ -569,7 +636,18 @@ class LigerRMSNormFunction(torch.autograd.Function):
  """
  Y: (B, T, H) or (BxT, H)
  """
- X, W, RSTD = ctx.saved_tensors
+ if ctx.elementwise_affine:
+ X, W, RSTD = ctx.saved_tensors
+ else:
+ X, RSTD = ctx.saved_tensors
+ W = None
+
+ if isinstance(dY, torch.distributed.tensor.DTensor):
+ # Gradients are output of a tensor parallel module and
+ # needs to be gathered to a local tensor for computing RMSE layer.
+ # TODO: support CP.
+ dY = dY.full_tensor()
+
  dX, dW = rms_norm_backward(
  dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
  )
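
Taken together, the rms_norm.py changes let the same kernels run without a learnable scale: passing W=None selects the new elementwise_affine=False branches and rms_norm_backward returns dW=None. A hypothetical call against the functional entry point shown above (argument values are illustrative, and a CUDA, XPU, or NPU device is required):

import torch

from liger_kernel.ops.rms_norm import rms_norm_forward

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
# W=None takes the weight-free path added in this release.
y, x_saved, rstd, block_size, num_warps, casting_mode = rms_norm_forward(
    x, None, 1e-6, 0.0, "llama", False
)
print(y.shape, rstd.shape)  # torch.Size([8, 4096]) torch.Size([8])
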