liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +307 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +63 -0
  18. liger_kernel/ops/__init__.py +141 -0
  19. liger_kernel/ops/backends/README.md +151 -0
  20. liger_kernel/ops/backends/__init__.py +13 -0
  21. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  22. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  23. liger_kernel/ops/backends/registry.py +61 -0
  24. liger_kernel/ops/cross_entropy.py +383 -114
  25. liger_kernel/ops/dyt.py +160 -0
  26. liger_kernel/ops/experimental/embedding.py +141 -0
  27. liger_kernel/ops/experimental/mm_int8int2.py +349 -0
  28. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  29. liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
  30. liger_kernel/ops/fused_linear_jsd.py +228 -0
  31. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  32. liger_kernel/ops/geglu.py +66 -64
  33. liger_kernel/ops/group_norm.py +306 -0
  34. liger_kernel/ops/grpo_loss.py +312 -0
  35. liger_kernel/ops/jsd.py +201 -0
  36. liger_kernel/ops/kl_div.py +262 -0
  37. liger_kernel/ops/layer_norm.py +320 -0
  38. liger_kernel/ops/llama4_rope.py +225 -0
  39. liger_kernel/ops/multi_token_attention.py +207 -0
  40. liger_kernel/ops/poly_norm.py +390 -0
  41. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  42. liger_kernel/ops/rms_norm.py +484 -88
  43. liger_kernel/ops/rope.py +122 -117
  44. liger_kernel/ops/softmax.py +201 -0
  45. liger_kernel/ops/sparsemax.py +179 -0
  46. liger_kernel/ops/swiglu.py +68 -65
  47. liger_kernel/ops/tiled_mlp.py +136 -0
  48. liger_kernel/ops/tvd.py +207 -0
  49. liger_kernel/ops/utils.py +82 -3
  50. liger_kernel/transformers/__init__.py +218 -6
  51. liger_kernel/transformers/auto_model.py +38 -0
  52. liger_kernel/transformers/cross_entropy.py +52 -7
  53. liger_kernel/transformers/dyt.py +22 -0
  54. liger_kernel/transformers/experimental/__init__.py +5 -0
  55. liger_kernel/transformers/experimental/embedding.py +26 -0
  56. liger_kernel/transformers/fsdp.py +55 -0
  57. liger_kernel/transformers/functional.py +301 -0
  58. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  59. liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
  60. liger_kernel/transformers/fused_linear_jsd.py +95 -0
  61. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  62. liger_kernel/transformers/geglu.py +6 -7
  63. liger_kernel/transformers/group_norm.py +50 -0
  64. liger_kernel/transformers/grpo_loss.py +153 -0
  65. liger_kernel/transformers/jsd.py +70 -0
  66. liger_kernel/transformers/kl_div.py +12 -0
  67. liger_kernel/transformers/layer_norm.py +24 -0
  68. liger_kernel/transformers/llama4_rope.py +93 -0
  69. liger_kernel/transformers/model/falcon_h1.py +122 -0
  70. liger_kernel/transformers/model/gemma.py +261 -0
  71. liger_kernel/transformers/model/gemma2.py +283 -0
  72. liger_kernel/transformers/model/gemma3.py +332 -0
  73. liger_kernel/transformers/model/glm4.py +141 -0
  74. liger_kernel/transformers/model/glm4v.py +163 -0
  75. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  76. liger_kernel/transformers/model/gpt_oss.py +211 -0
  77. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  78. liger_kernel/transformers/model/internvl.py +157 -0
  79. liger_kernel/transformers/model/llama.py +221 -41
  80. liger_kernel/transformers/model/llama4.py +121 -0
  81. liger_kernel/transformers/model/llava.py +344 -0
  82. liger_kernel/transformers/model/loss_utils.py +95 -0
  83. liger_kernel/transformers/model/mistral.py +145 -0
  84. liger_kernel/transformers/model/mixtral.py +293 -0
  85. liger_kernel/transformers/model/mllama.py +269 -0
  86. liger_kernel/transformers/model/olmo2.py +141 -0
  87. liger_kernel/transformers/model/olmo3.py +142 -0
  88. liger_kernel/transformers/model/output_classes.py +147 -0
  89. liger_kernel/transformers/model/paligemma.py +433 -0
  90. liger_kernel/transformers/model/phi3.py +120 -0
  91. liger_kernel/transformers/model/qwen2.py +259 -0
  92. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  93. liger_kernel/transformers/model/qwen2_vl.py +159 -0
  94. liger_kernel/transformers/model/qwen3.py +136 -0
  95. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  96. liger_kernel/transformers/model/qwen3_next.py +146 -0
  97. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  98. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  99. liger_kernel/transformers/model/smollm3.py +199 -0
  100. liger_kernel/transformers/model/smolvlm.py +158 -0
  101. liger_kernel/transformers/monkey_patch.py +2816 -21
  102. liger_kernel/transformers/multi_token_attention.py +64 -0
  103. liger_kernel/transformers/poly_norm.py +42 -0
  104. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  105. liger_kernel/transformers/rms_norm.py +75 -5
  106. liger_kernel/transformers/rope.py +47 -3
  107. liger_kernel/transformers/softmax.py +12 -0
  108. liger_kernel/transformers/sparsemax.py +16 -0
  109. liger_kernel/transformers/swiglu.py +62 -6
  110. liger_kernel/transformers/tiled_mlp.py +133 -0
  111. liger_kernel/transformers/trainer/__init__.py +4 -0
  112. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  113. liger_kernel/transformers/trainer_integration.py +2 -45
  114. liger_kernel/transformers/tvd.py +13 -0
  115. liger_kernel/triton/__init__.py +1 -3
  116. liger_kernel/triton/monkey_patch.py +1 -5
  117. liger_kernel/utils.py +96 -0
  118. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
  119. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
  120. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  121. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
  122. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
  123. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
  124. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
  125. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  126. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py
@@ -1,26 +1,63 @@
+"""
+This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
+See the original Unsloth repository at https://github.com/unslothai/unsloth.
+
+The following line
+https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/rms_norm.py#L30
+is based on code from Unsloth, located at:
+https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+
+Modifications made by Yanning Chen, 2024.
+"""
+
+import math
+import operator
+
 import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import torch_to_triton_dtype
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available
+
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import rsqrt
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import rsqrt
+else:
+    from triton.language.math import rsqrt
+
+
+_CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
+_CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
+_CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)


 @triton.jit
-def _rms_norm_forward(
+def _rms_norm_forward_kernel(
     Y_ptr,
     Y_row_stride,
     X_ptr,
     X_row_stride,
     W_ptr,
     W_row_stride,
-    r_ptr,
-    r_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
     n_cols,
     eps,
+    offset,
+    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
     BLOCK_SIZE: tl.constexpr,
 ):
     """
-    y_i = (x_i / (RMS)) * wi, RMS = sqrt(sum(x_i^2) / N)
+    y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)

     Reference:
     1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
@@ -28,158 +65,517 @@ def _rms_norm_forward(
     3. https://arxiv.org/pdf/1910.07467
     """

-    row_idx = tl.program_id(0)
+    row_idx = tl.program_id(0).to(tl.int64)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols

     Y_ptr += row_idx * Y_row_stride
     X_ptr += row_idx * X_row_stride
-    r_ptr += row_idx * r_row_stride
+    RSTD_ptr += row_idx * RSTD_row_stride

     X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
+    X_row_dtype = X_row.dtype
     W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)

+    # On Llama, only rstd is computed on fp32
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(tl.float32)
+
+    # Gemma computes everything on fp32, and then casts back the output to the original dtype
+    if casting_mode == _CASTING_MODE_GEMMA:
+        W_row = W_row.to(tl.float32)
+        X_row = X_row.to(tl.float32)
+
+    if casting_mode == _CASTING_MODE_NONE:
+        eps = eps.to(X_row_dtype)
+        offset = offset.to(X_row_dtype)
+
     mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
-    inv_rms = tl.math.rsqrt(mean_square + eps)
+    rstd = rsqrt(mean_square + eps)

     # We can save time by caching rms with minimal memory overhead
     # because rms is much smaller compared to X_row, as rms is for each row.
     # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
-    tl.store(r_ptr, inv_rms)
+    tl.store(RSTD_ptr, rstd)
+
+    X_row = X_row * rstd

-    Y_row = X_row * inv_rms * W_row
+    # On Llama, the multiplication with the weight is done on the original dtype
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(X_row_dtype)
+
+    Y_row = X_row * (offset + W_row)
+
+    if casting_mode == _CASTING_MODE_GEMMA:
+        Y_row = Y_row.to(X_row_dtype)

     tl.store(Y_ptr + col_offsets, Y_row, mask=mask)


 @triton.jit
-def _rms_norm_backward(
+def _rms_norm_backward_kernel(
     dY_ptr,
     dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
     X_ptr,
     X_row_stride,
+    X_dtype: tl.constexpr,
     W_ptr,
     W_row_stride,
-    r_ptr,
-    r_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
     dW_ptr,
     dW_row_stride,
+    n_rows,
     n_cols,
-    eps,
+    offset,
+    rows_per_program: tl.constexpr,
+    casting_mode: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
-    dx = (1 / RMS) * [dy * w - (1 / N) * (1 / RMS^2) * ((dy * w) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
+    dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
     dw = sum(dy * (x / RMS)). summation over BxT dimension
     """

-    row_idx = tl.program_id(0)
+    row_block_id = tl.program_id(0).to(tl.int64)
+    row_start = row_block_id * rows_per_program
+    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols

-    dY_ptr += row_idx * dY_row_stride
-    X_ptr += row_idx * X_row_stride
-    r_ptr += row_idx * r_row_stride
-    dW_ptr += row_idx * dW_row_stride
+    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

-    dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0)
-    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
-    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+    dY_ptr += row_start * dY_row_stride
+    dX_ptr += row_start * dX_row_stride
+
+    X_ptr += row_start * X_row_stride
+    RSTD_ptr += row_start
+
+    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
+    W_row = W_row + offset
+
+    for _ in range(row_start, row_end):
+        dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
+        X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
+
+        # Get cached rms
+        rstd_row = tl.load(RSTD_ptr)
+
+        X_row = X_row.to(tl.float32)

-    # Get cached rms
-    inv_rms_row = tl.load(r_ptr)
+        # Different backward graphs for different casting modes
+        if casting_mode == _CASTING_MODE_LLAMA:
+            m = (dY_row * W_row).to(tl.float32)
+
+        elif casting_mode == _CASTING_MODE_GEMMA:
+            dY_row = dY_row.to(tl.float32)
+            m = dY_row * W_row
+        else:
+            m = dY_row * W_row
+
+        dX_row = rstd_row * m
+
+        dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
+
+        # calculate the gradient of W
+        if casting_mode == _CASTING_MODE_LLAMA:
+            dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
+        else:
+            # here X_row is already in fp32 (see previous if block)
+            dW_row += dY_row * (X_row * rstd_row)
+
+        tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)
+
+        dY_ptr += dY_row_stride
+        dX_ptr += dX_row_stride
+        X_ptr += X_row_stride
+        RSTD_ptr += RSTD_row_stride
+
+    tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
+
+
+@triton.jit
+def _block_rms_norm_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    n_rows,
+    n_cols,
+    eps,
+    offset,
+    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_ROW: tl.constexpr,
+):
+    """
+    y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)

-    dX_row = (inv_rms_row) * (
-        dY_row * W_row
-        - (1 / n_cols)
-        * inv_rms_row
-        * inv_rms_row
-        * tl.sum(dY_row * W_row * X_row, axis=0)
-        * X_row
+    Reference:
+    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+    3. https://arxiv.org/pdf/1910.07467
+    """
+
+    row_idx = tl.program_id(0) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    row_mask = row_idx < n_rows
+    col_mask = col_offsets < n_cols
+
+    X_row = tl.load(
+        X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+        mask=row_mask[:, None] & col_mask[None, :],
+        other=0,
     )
-    tl.store(dY_ptr + col_offsets, dX_row, mask=mask)
+    X_row_dtype = X_row.dtype
+    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)

-    # calculate the gradient of W
-    dW_row = dY_row * X_row * inv_rms_row
-    tl.store(dW_ptr + col_offsets, dW_row, mask=mask)
+    # On Llama, only rstd is computed on fp32
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(tl.float32)

+    # Gemma computes everything on fp32, and then casts back the output to the original dtype
+    if casting_mode == _CASTING_MODE_GEMMA:
+        W_row = W_row.to(tl.float32)
+        X_row = X_row.to(tl.float32)

-class LigerRMSNormFunction(torch.autograd.Function):
-    @staticmethod
-    @ensure_contiguous
-    def forward(ctx, X, W, eps):
-        """
-        X: (B, T, H) or (BxT, H)
-        W: (H,)
-        """
+    if casting_mode == _CASTING_MODE_NONE:
+        eps = eps.to(X_row_dtype)
+        offset = offset.to(X_row_dtype)
+
+    mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
+    rstd = rsqrt(mean_square + eps)
+
+    # We can save time by caching rms with minimal memory overhead
+    # because rms is much smaller compared to X_row, as rms is for each row.
+    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
+    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)
+
+    X_row = X_row * rstd[:, None]
+
+    # On Llama, the multiplication with the weight is done on the original dtype
+    if casting_mode == _CASTING_MODE_LLAMA:
+        X_row = X_row.to(X_row_dtype)
+
+    Y_row = X_row * (offset + W_row)[None, :]
+
+    if casting_mode == _CASTING_MODE_GEMMA:
+        Y_row = Y_row.to(X_row_dtype)
+
+    tl.store(
+        Y_ptr + row_idx[:, None] * Y_row_stride + col_offsets[None, :],
+        Y_row,
+        mask=row_mask[:, None] & col_mask[None, :],
+    )
+
+
+@triton.jit
+def _block_rms_norm_backward_kernel(
+    dY_ptr,
+    dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
+    X_ptr,
+    X_row_stride,
+    X_dtype: tl.constexpr,
+    W_ptr,
+    W_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    dW_ptr,
+    dW_row_stride,
+    n_rows,
+    n_cols,
+    offset,
+    rows_per_program: tl.constexpr,
+    casting_mode: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_ROW: tl.constexpr,
+):
+    """
+    dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
+    dw = sum(dy * (x / RMS)). summation over BxT dimension
+    """
+
+    pid = tl.program_id(0).cast(tl.int64)
+    NUM_SMS = tl.num_programs(0)
+
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    col_mask = col_offsets < n_cols
+
+    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+    W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
+    W_row = W_row + offset
+
+    for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
+        row_idx = start + tl.arange(0, BLOCK_ROW)
+        row_mask = row_idx < n_rows
+        dY_row = tl.load(
+            dY_ptr + row_idx[:, None] * dY_row_stride + col_offsets[None, :],
+            mask=row_mask[:, None] & col_mask[None, :],
+            other=0.0,
+        )
+        X_row = tl.load(
+            X_ptr + row_idx[:, None] * X_row_stride + col_offsets[None, :],
+            mask=row_mask[:, None] & col_mask[None, :],
+            other=0.0,
+        )
+
+        # Get cached rms
+        rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride, row_mask)
+
+        X_row = X_row.to(tl.float32)
+
+        # Different backward graphs for different casting modes
+        if casting_mode == _CASTING_MODE_LLAMA:
+            m = (dY_row * W_row[None, :]).to(tl.float32)
+
+        elif casting_mode == _CASTING_MODE_GEMMA:
+            dY_row = dY_row.to(tl.float32)
+            m = dY_row * W_row[None, :]
+        else:
+            m = dY_row * W_row[None, :]
+
+        dX_row = rstd_row[:, None] * m
+
+        dX_row += (rstd_row[:, None]) * (
+            -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
+        )
+
+        # calculate the gradient of W
+        if casting_mode == _CASTING_MODE_LLAMA:
+            # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+            dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
+        else:
+            # here X_row is already in fp32 (see previous if block)
+            dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
+
+        tl.store(
+            dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
+            dX_row,
+            mask=row_mask[:, None] & col_mask[None, :],
+        )
+
+    tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
+
+
+_str_to_casting_mode = {
+    "llama": _CASTING_MODE_LLAMA.value,
+    "gemma": _CASTING_MODE_GEMMA.value,
+    "none": _CASTING_MODE_NONE.value,
+}

-        shape = X.shape
-        dim = shape[-1]
-        X = X.view(-1, dim)
-        n_rows, n_cols = X.shape
-        BLOCK_SIZE, num_warps = calculate_settings(n_cols)

-        Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-        # r is to cache (1/rms) for each row
-        r = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
+    if not isinstance(casting_mode, int):
+        assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
+        casting_mode = _str_to_casting_mode[casting_mode]
+    else:
+        assert casting_mode in _str_to_casting_mode.values(), f"Invalid casting mode: {casting_mode}"

-        # Check constraints.
-        assert (
-            X.shape[1] == W.shape[0]
-        ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+    shape = X.shape
+    dim = shape[-1]
+    X = X.view(-1, dim)
+    n_rows, n_cols = X.shape
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

-        _rms_norm_forward[(n_rows,)](
+    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+    # RSTD is to cache rstd for each row
+    # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
+    rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
+    RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
+
+    # Check constraints.
+    assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+    if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+        _rms_norm_forward_kernel[(n_rows,)](
             Y,
             Y.stride(0),
             X,
             X.stride(0),
             W,
             W.stride(0),
-            r,
-            r.stride(0),
+            RSTD,
+            RSTD.stride(0),
             n_cols,
             eps,
+            offset,
+            casting_mode,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
         )
-        ctx.eps = eps
-        ctx.BLOCK_SIZE = BLOCK_SIZE
-        ctx.num_warps = num_warps
+    else:
+        BLOCK_ROW = 16
+        kernel_args["BLOCK_ROW"] = BLOCK_ROW
+        _block_rms_norm_forward_kernel[(triton.cdiv(n_rows, BLOCK_ROW),)](
+            Y,
+            Y.stride(0),
+            X,
+            X.stride(0),
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            n_rows,
+            n_cols,
+            eps,
+            offset,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
+    return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode

-        ctx.save_for_backward(X, W, r)
-        return Y.view(*shape)

-    @staticmethod
-    @ensure_contiguous
-    def backward(ctx, dY):
-        """
-        Y: (B, T, H) or (BxT, H)
-        """
+def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
+    shape = dY.shape
+    dim = shape[-1]
+    dY = dY.view(-1, dim)
+    n_rows, n_cols = dY.shape

-        shape = dY.shape
-        dim = shape[-1]
-        dY = dY.view(-1, dim)
-        X, W, r = ctx.saved_tensors
-        n_rows, n_cols = dY.shape
-        dW = torch.zeros_like(X)
+    sm_count = 1
+    if X.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    elif X.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+    elif X.device.type == "npu":
+        sm_count = get_npu_multi_processor_count()

-        # Here we use dY to store the value of dX to save memory
-        _rms_norm_backward[(n_rows,)](
+    # fp32 for numerical stability especially.
+    _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
+
+    if n_cols > BLOCK_SIZE:
+        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+    rows_per_program = math.ceil(n_rows / sm_count)
+    grid = (sm_count,)
+
+    if in_place is True:
+        dX = dY
+    else:
+        dX = torch.zeros_like(dY)
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
+    if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
+        _rms_norm_backward_kernel[grid](
             dY,
             dY.stride(0),
+            dX,
+            dX.stride(0),
             X,
             X.stride(0),
+            torch_to_triton_dtype[X.dtype],
             W,
             W.stride(0),
-            r,
-            r.stride(0),
-            dW,
-            dW.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            _dW,
+            _dW.stride(0),
+            n_rows,
             n_cols,
-            ctx.eps,
-            BLOCK_SIZE=ctx.BLOCK_SIZE,
-            num_warps=ctx.num_warps,
+            offset,
+            rows_per_program,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
+    else:
+        BLOCK_ROW = 16
+        kernel_args["BLOCK_ROW"] = BLOCK_ROW
+        _block_rms_norm_backward_kernel[grid](
+            dY,
+            dY.stride(0),
+            dX,
+            dX.stride(0),
+            X,
+            X.stride(0),
+            torch_to_triton_dtype[X.dtype],
+            W,
+            W.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            _dW,
+            _dW.stride(0),
+            n_rows,
+            n_cols,
+            offset,
+            rows_per_program,
+            casting_mode,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+            **kernel_args,  # XPU-specific optimization
+        )
+    dX = dX.view(*shape)
+    dW = _dW.sum(dim=0).to(W.dtype)
+
+    return dX, dW
+
+
+class LigerRMSNormFunction(torch.autograd.Function):
+    """
+    Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the
+    weight tensor `W`, with an optional offset and casting mode.
+
+    Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
+    uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
+    `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.
+
+    In addition, different models cast their inputs at different places during RMSNorm computation. For
+    example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the
+    inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
+    support the following casting modes (they match HuggingFace Transformers' implementations):
+    - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
+    - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
+    - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
+
+    The `in_place` option controls whether dY is modified in place to store dX. It defaults to `True` to save memory. However, in certain cases it can produce incorrect inputs.
+    For example, gemma2 uses two rmsnorms sequentially with a residual in between. The residual part needs dY, so it cannot be modified in-place.
+    Therefore, for the patching of RMSNorm in gemma2, we set `in_place` to `False`.
+    """
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
+        """
+        X: (B, T, H) or (BxT, H)
+        W: (H,)
+        """
+        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
+        ctx.offset = offset
+        ctx.casting_mode = casting_mode
+        ctx.in_place = in_place
+        ctx.row_mode = row_mode
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        ctx.num_warps = num_warps
+        ctx.save_for_backward(X, W, RSTD)
+        return Y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, dY):
+        """
+        Y: (B, T, H) or (BxT, H)
+        """
+        X, W, RSTD = ctx.saved_tensors
+        dX, dW = rms_norm_backward(
+            dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
         )
-        dX = dY.view(*shape)
-        dW = torch.sum(dW, dim=0)
-        return dX, dW, None
+        return dX, dW, None, None, None, None, None
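
The new LigerRMSNormFunction shown in this diff adds the offset, casting_mode, in_place, and row_mode arguments to forward. The following minimal usage sketch is not part of the diff; it assumes a CUDA GPU is available and uses the liger_kernel.ops.rms_norm module path listed in the file changes above. Values such as the tensor sizes and the 1e-6 epsilon are illustrative only.

# Illustrative sketch: exercising the offset / casting_mode arguments documented
# in the LigerRMSNormFunction docstring above. Assumes a CUDA device.
import torch

from liger_kernel.ops.rms_norm import LigerRMSNormFunction

B, T, H = 2, 16, 64
X = torch.randn(B, T, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
W = torch.ones(H, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Llama-style RMSNorm: no weight offset, only the inverse RMS is computed in fp32.
Y_llama = LigerRMSNormFunction.apply(X, W, 1e-6, 0.0, "llama")

# Gemma-style RMSNorm: weight offset of 1.0 and full fp32 casting; in_place=False
# keeps dY untouched in backward (as the docstring notes is required for gemma2).
Y_gemma = LigerRMSNormFunction.apply(X, W, 1e-6, 1.0, "gemma", False)

Y_llama.sum().backward()  # dX and dW come from the Triton backward kernels above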