quack-kernels 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (40)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/broadcast_utils.py +29 -0
  4. quack/compile_utils.py +19 -0
  5. quack/copy_utils.py +487 -0
  6. quack/cross_entropy.py +157 -233
  7. quack/cute_dsl_utils.py +20 -34
  8. quack/gemm.py +194 -0
  9. quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
  10. quack/gemm_config.py +72 -46
  11. quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
  12. quack/gemm_default_epi.py +259 -0
  13. quack/gemm_interface.py +177 -31
  14. quack/gemm_sm100.py +729 -506
  15. quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
  16. quack/gemm_symmetric.py +330 -0
  17. quack/gemm_wrapper_utils.py +3 -1
  18. quack/layout_utils.py +287 -0
  19. quack/linear.py +24 -16
  20. quack/pipeline.py +158 -3
  21. quack/reduce.py +88 -49
  22. quack/reduction_base.py +25 -36
  23. quack/rmsnorm.py +476 -526
  24. quack/sm100_utils.py +62 -0
  25. quack/sm90_utils.py +127 -0
  26. quack/softmax.py +135 -203
  27. quack/sort/bitonic_sort.py +13 -10
  28. quack/sort/utils.py +6 -6
  29. quack/tile_scheduler.py +23 -16
  30. quack/topk.py +409 -85
  31. quack/utils.py +32 -220
  32. quack/varlen_utils.py +370 -1
  33. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/METADATA +4 -2
  34. quack_kernels-0.2.4.dist-info/RECORD +44 -0
  35. quack/layernorm.py +0 -353
  36. quack/symmetric_dense_gemm_sm90.py +0 -2091
  37. quack_kernels-0.2.2.dist-info/RECORD +0 -37
  38. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/WHEEL +0 -0
  39. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/licenses/LICENSE +0 -0
  40. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/top_level.txt +0 -0
quack/rmsnorm.py CHANGED
@@ -1,94 +1,54 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

- from typing import Optional, Tuple
+ import math
+ from typing import Optional, Tuple, Type
  from functools import partial

  import cuda.bindings.driver as cuda

  import cutlass
  import cutlass.cute as cute
- from cutlass import Float32, Int32
- from cutlass import const_expr
- from cutlass.cute.runtime import from_dlpack
+ from cutlass import Float32, Int32, const_expr

  import torch
  from torch import Tensor

  import quack.utils as utils
+ import quack.copy_utils as copy_utils
+ import quack.layout_utils as layout_utils
+ from quack.compile_utils import make_fake_tensor as fake_tensor
  from quack.reduce import row_reduce
  from quack.reduction_base import ReductionBase
  from quack.cute_dsl_utils import torch2cute_dtype_map


  class RMSNorm(ReductionBase):
- def __init__(self, dtype: cutlass.Numeric, N: int):
- super().__init__(dtype, N, stage=1)
- self.reload_from = None if N <= 8192 else "smem"
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int, is_layernorm: bool = False):
+ super().__init__(dtype, N, stage=2 if is_layernorm else 1)
+ self.is_layernorm = is_layernorm
+ self.reload_from = None if N <= (16384 if is_layernorm else 8192) else "smem"
  self.delay_w_load = False

- def _calculate_threads_per_row(self):
- """Calculate the number of threads per row for the RMSNorm kernel."""
+ def _threads_per_row(self):
  N = self.N
- if N <= 64:
- return 8
- elif N <= 128:
- return 16
- elif N <= 3072:
- return 32
- elif N <= 6144:
- return 64
- elif N <= 16384:
- return 128
- else:
- return 256
+ for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
+ if N <= limit:
+ return threads
+ return 256

  def _set_cluster_n(self):
- """
- Set the number of clusters for the RMSNorm kernel.
- Stored in self.cluster_n.
- """
  N = self.N
-
  # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
  # Similarly cluster_n = 8 is faster for N=128k
  if const_expr(self.dtype.width == 16):
- # 16-bit types (fp16, bf16)
- if N <= 16 * 1024:
- cluster_n = 1
- elif N <= 32 * 1024:
- cluster_n = 2
- elif N <= 64 * 1024:
- cluster_n = 4
- elif N <= 128 * 1024:
- cluster_n = 8
- else:
- cluster_n = 16
+ thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
  else:
- # 32-bit types (fp32)
- if N <= 32 * 1024:
- cluster_n = 1
- elif N <= 64 * 1024:
- cluster_n = 2
- elif N <= 128 * 1024:
- cluster_n = 4
- elif N <= 256 * 1024:
- cluster_n = 8
- else:
- cluster_n = 16
-
- self.cluster_n = cluster_n
-
- def _smem_size_in_bytes(self, tiler_mn, num_warps, dtype_res=None):
- return (
- cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
- + (
- cute.size_in_bytes(dtype_res, cute.make_layout(tiler_mn))
- if dtype_res is not None
- else 0
- )
- + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
- + self.stage * (cutlass.Int64.width // 8)
- )
+ thresholds = [(32 * 1024, 1), (64 * 1024, 2), (128 * 1024, 4), (256 * 1024, 8)]
+ for limit, cluster in thresholds:
+ if N <= limit:
+ self.cluster_n = cluster
+ return
+ self.cluster_n = 16
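Editor's note: 0.2.4 replaces the nested if/elif ladders above with a first-match threshold table. A minimal standalone sketch of that lookup pattern (hypothetical helper, not part of the package; shown only to illustrate the refactor):

def pick_by_threshold(n, table, default):
    # Return the value paired with the first limit that n does not exceed.
    for limit, value in table:
        if n <= limit:
            return value
    return default

# Same answer as the old if/elif chain in _threads_per_row for N = 4096:
assert pick_by_threshold(4096, [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)], 256) == 64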

  @cute.jit
  def __call__(
@@ -100,60 +60,32 @@ class RMSNorm(ReductionBase):
  mO: cute.Tensor,
  mResO: Optional[cute.Tensor],
  mRstd: Optional[cute.Tensor],
+ mMean: Optional[cute.Tensor],
+ eps: Float32,
  stream: cuda.CUstream,
- eps: Float32 = 1e-6,
  ):
- semistatic_shape = (*mX.shape[:-1], self.N) # Set last dimension to be statically N
- new_stride = lambda t: (
- cute.assume(t.stride[0], divby=128 // t.element_type.width),
- t.stride[1],
- )
- mX, mRes, mO, mResO = [
- cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
- if const_expr(t is not None)
- else None
- for t in (mX, mRes, mO, mResO)
- ]
  assert mX.element_type == self.dtype
- assert mO.element_type == self.dtype
  self._set_cluster_n()
  largest_dtype_width = const_expr(
- max(
- mX.element_type.width,
- mRes.element_type.width if mRes is not None else 0,
- mO.element_type.width,
- mResO.element_type.width if mResO is not None else 0,
- )
- )
- tiler_mn, tv_layout = self._get_tv_layout(
- num_copy_bits=128 // largest_dtype_width * mX.element_type.width
+ max(*(t.element_type.width for t in [mX, mRes, mW, mB, mO, mResO] if t is not None))
  )
- num_threads = cute.size(tv_layout, mode=[0])
- num_warps = num_threads // cute.arch.WARP_SIZE
- if const_expr(mW is not None):
- mW_expanded_layout = cute.prepend(
- mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
- )
- mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
- if const_expr(mB is not None):
- mB_expanded_layout = cute.prepend(
- mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
- )
- mB = cute.make_tensor(mB.iterator, mB_expanded_layout)
- if const_expr(mRstd is not None):
- mRstd_expanded_layout = cute.append(
- mRstd.layout, cute.make_layout((self.N,), stride=(0,))
- )
- mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
+ vecsize = math.gcd(self.N, 128 // largest_dtype_width)
+ tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize)
+ num_threads = tiled_copy.size
+ mW, mB = [
+ layout_utils.expand(mT, dim=0, size=tiler_mn[0]) if const_expr(mT is not None) else None
+ for mT in (mW, mB)
+ ]
+ mRstd, mMean = [
+ layout_utils.expand(mT, dim=1, size=self.N) if const_expr(mT is not None) else None
+ for mT in (mRstd, mMean)
+ ]
  self.kernel(
- mX, mW, mB, mRes, mO, mResO, mRstd, eps, tv_layout, tiler_mn, self.reload_from
+ mX, mW, mB, mRes, mO, mResO, mRstd, mMean, eps, tiler_mn, tiled_copy, threads_per_row
  ).launch(
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
  block=[num_threads, 1, 1],
- cluster=([1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None),
- smem=self._smem_size_in_bytes(
- tiler_mn, num_warps, dtype_res=mRes.element_type if mRes is not None else None
- ),
+ cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
  stream=stream,
  )
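Editor's note: the vecsize computed above is the number of elements moved per 128-bit copy, reduced with math.gcd so it always divides N. A quick worked check in plain Python (illustrative values):

import math

for N, dtype_width in [(3072, 16), (3072, 32), (1002, 16)]:
    vecsize = math.gcd(N, 128 // dtype_width)
    print(N, dtype_width, vecsize)
# 3072, fp16/bf16 -> gcd(3072, 8) = 8 elements per copy
# 3072, fp32      -> gcd(3072, 4) = 4
# 1002, fp16/bf16 -> gcd(1002, 8) = 2 (copy width shrinks so it divides N)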

@@ -167,24 +99,20 @@ class RMSNorm(ReductionBase):
  mO: cute.Tensor,
  mResO: Optional[cute.Tensor],
  mRstd: Optional[cute.Tensor],
- eps: cute.Float32,
- tv_layout: cute.Layout,
+ mMean: Optional[cute.Tensor],
+ eps: Float32,
  tiler_mn: cute.Shape,
- reload_from: cutlass.Constexpr = None,
- delay_w_load: cutlass.Constexpr = False,
+ tiled_copy: cute.TiledCopy,
+ threads_per_row: cutlass.Constexpr[int],
  ):
  tidx, _, _ = cute.arch.thread_idx()
  bidx, _, _ = cute.arch.block_idx()
- if const_expr(self.cluster_n > 1):
- cluster_y = cute.arch.block_idx()[1]
- else:
- cluster_y = const_expr(0)
+ cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1]
+ tv_layout = tiled_copy.layout_tv_tiled

  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
- mX.element_type,
- cute.make_ordered_layout(tiler_mn, order=(1, 0)),
- byte_alignment=16,
+ mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
  )
  if const_expr(mRes is not None):
  sRes = smem.allocate_tensor(
@@ -197,34 +125,16 @@ class RMSNorm(ReductionBase):
  shape = mX.shape
  idX = cute.make_identity_tensor(shape)
  # slice for CTAs
- # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
- mX, mRes, mO, mResO = [
- utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) if mT is not None else None
- for mT in (mX, mRes, mO, mResO)
+ gX, gRes, gO, gResO, gRstd, gMean, cX = [
+ cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) if mT is not None else None
+ for mT in (mX, mRes, mO, mResO, mRstd, mMean, idX)
  ]
- gX, gRes, gO, gResO = [
- cute.local_tile(mT, tiler_mn, (0, cluster_y)) if mT is not None else None
- for mT in (mX, mRes, mO, mResO)
- ]
- cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
  gW, gB = [
  cute.local_tile(mT, tiler_mn, (0, cluster_y)) if const_expr(mT is not None) else None
  for mT in (mW, mB)
  ]
- gRstd = (
- cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
- if const_expr(mRstd is not None)
- else None
- )

- # declare the atoms which will be used later for memory copy
- num_copy_elems_X = tv_layout.shape[1][0]
- copy_atom_load_X_async = utils.get_copy_atom(
- mX.element_type, num_copy_elems_X, is_async=True
- )
- thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
- tidx
- )
+ thr_copy_X = tiled_copy.get_slice(tidx)

  tXgW = thr_copy_X.partition_S(gW) if const_expr(mW is not None) else None
  tXgB = thr_copy_X.partition_S(gB) if const_expr(mB is not None) else None
@@ -237,26 +147,27 @@ class RMSNorm(ReductionBase):
  if const_expr(mResO is not None):
  tXgResO = thr_copy_X.partition_D(gResO)
  tXrRstd = thr_copy_X.partition_D(gRstd) if const_expr(mRstd is not None) else None
+ tXrMean = thr_copy_X.partition_D(gMean) if const_expr(mMean is not None) else None
  tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]

  # allocate fragments for gmem->rmem
  tXrW = cute.make_fragment_like(tXgW) if const_expr(mW is not None) else None
- if const_expr(mW is not None):
- tXrW.fill(0.0)
  tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None
  tXrX, tXrO = [cute.make_fragment_like(t) for t in (tXgX, tXgO)]
  if const_expr(mRes is not None):
  tXrRes = cute.make_fragment_like(tXgRes)

- num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+ num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE
  self._initialize_cluster(tidx, mbar_ptr, num_warps)

- is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+ is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
  tXpX = (
- utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
+ copy_utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+ if not is_even_N
+ else None
  )
- # Each copy will use the same number of elements as X and same predicate
- copy = partial(utils.copy, pred=tXpX, num_copy_elems=num_copy_elems_X)
+ # Each copy will use the same predicate
+ copy = partial(copy_utils.copy, pred=tXpX)

  row = tXcX[0][0]
  if row < shape[0]:
@@ -265,7 +176,7 @@ class RMSNorm(ReductionBase):
  copy(tXgRes, tXsRes, is_async=True)
  cute.arch.cp_async_commit_group()

- if const_expr(not delay_w_load):
+ if const_expr(not self.delay_w_load):
  if const_expr(mW is not None):
  copy(tXgW, tXrW)
  if const_expr(mB is not None):
@@ -283,17 +194,61 @@ class RMSNorm(ReductionBase):
  if row < shape[0]:
  copy(tXrResO, tXgResO)

- threads_per_row = tv_layout.shape[0][0]
- sum_sq_x = row_reduce(
- x * x,
- cute.ReductionOp.ADD,
- threads_per_row,
- reduction_buffer[None, None, 0],
- mbar_ptr,
- init_val=0.0,
- hook_fn=(cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None),
- )
- rstd = cute.math.rsqrt(sum_sq_x / shape[1] + eps, fastmath=True)
+ mean, rstd = None, None
+ if const_expr(self.is_layernorm):
+ # LayerNorm: compute mean first, then variance
+ sum_x = row_reduce(
+ x,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None,
+ init_val=0.0,
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
+ )
+ mean = sum_x / shape[1]
+ if const_expr(mMean is not None):
+ # Only the thread corresponding to column 0 writes out the mean to gmem
+ if (
+ tXcX[0][1] == 0
+ and row < shape[0]
+ and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+ ):
+ tXrMean[0] = mean
+ if const_expr(self.reload_from == "smem"):
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(cute.Float32)
+ if const_expr(mRes is not None):
+ cute.autovec_copy(tXsRes, tXrRes)
+ x += tXrRes.load().to(cute.Float32)
+ elif const_expr(self.reload_from == "gmem"):
+ copy(tXgX, tXrX)
+ x = tXrX.load().to(cute.Float32)
+ if const_expr(mRes is not None):
+ copy(tXgRes, tXrRes)
+ x += tXrRes.load().to(cute.Float32)
+ sum_sq_x_sub_mean = row_reduce(
+ (x - mean) * (x - mean),
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 1],
+ mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None,
+ init_val=0.0,
+ )
+ rstd = cute.math.rsqrt(sum_sq_x_sub_mean / shape[1] + eps, fastmath=True)
+ else:
+ # RMSNorm: compute sum of squares directly
+ mean = const_expr(0.0)
+ sum_sq_x = row_reduce(
+ x * x,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr,
+ init_val=0.0,
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
+ )
+ rstd = cute.math.rsqrt(sum_sq_x / shape[1] + eps, fastmath=True)
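Editor's note: the branch added above computes LayerNorm statistics with two row reductions (mean, then variance of x - mean) while keeping the single sum-of-squares reduction for RMSNorm. The same math in plain PyTorch, for reference only (not the kernel code):

import torch

def row_stats(x, eps, is_layernorm):
    x = x.float()
    if is_layernorm:
        mean = x.mean(dim=-1, keepdim=True)                    # first reduction
        var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)     # second reduction
        return mean, torch.rsqrt(var + eps)
    # RMSNorm: single reduction over x * x; mean is implicitly zero
    return torch.zeros_like(x[..., :1]), torch.rsqrt((x * x).mean(dim=-1, keepdim=True) + eps)

x = torch.randn(4, 3072)
mean, rstd = row_stats(x, 1e-6, is_layernorm=True)
x_hat = (x - mean) * rstd   # corresponds to x_hat = (x - mean) * rstd in the kernel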
  if const_expr(mRstd is not None):
  # Only the thread corresponding to column 0 writes out the rstd to gmem
  if (
@@ -302,21 +257,24 @@ class RMSNorm(ReductionBase):
  and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
  ):
  tXrRstd[0] = rstd
- if const_expr(delay_w_load):
+ if const_expr(self.delay_w_load):
  if const_expr(mW is not None):
  copy(tXgW, tXrW)
  if const_expr(mB is not None):
  copy(tXgB, tXrB)
- if const_expr(reload_from == "smem" or reload_from == "gmem"):
- if const_expr(reload_from == "smem"):
+ if const_expr(self.reload_from == "smem" or self.reload_from == "gmem"):
+ if const_expr(self.reload_from == "smem"):
  cute.autovec_copy(tXsX, tXrX)
+ if const_expr(mRes is not None):
+ cute.autovec_copy(tXsRes, tXrRes)
  else:
  copy(tXgX, tXrX)
+ if const_expr(mRes is not None):
+ copy(tXgRes, tXrRes)
  x = tXrX.load().to(cute.Float32)
  if const_expr(mRes is not None):
- cute.autovec_copy(tXsRes, tXrRes)
  x += tXrRes.load().to(cute.Float32)
- x_hat = x * rstd
+ x_hat = (x - mean) * rstd if const_expr(self.is_layernorm) else x * rstd
  y = x_hat
  if const_expr(mW is not None):
  y *= tXrW.load().to(cute.Float32)
@@ -329,10 +287,10 @@

  @torch.library.custom_op(
  "quack::_rmsnorm_fwd",
- mutates_args=("out", "rstd", "residual_out"),
+ mutates_args=("out", "rstd", "mean", "residual_out"),
  device_types="cuda",
  # We need to specify the schema manually since we're mutating an optional tensor
- schema="(Tensor x, Tensor? weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor? residual, Tensor(a6!)? residual_out, float eps=1e-6) -> ()",
+ schema="(Tensor x, Tensor? weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor(a5!)? mean, Tensor? residual, Tensor(a7!)? residual_out, float eps=1e-6, bool is_layernorm=False) -> ()",
  )
  def _rmsnorm_fwd(
  x: Tensor,
@@ -340,102 +298,73 @@ def _rmsnorm_fwd(
  out: Tensor,
  bias: Optional[Tensor] = None,
  rstd: Optional[Tensor] = None,
+ mean: Optional[Tensor] = None,
  residual: Optional[Tensor] = None,
  residual_out: Optional[Tensor] = None,
  eps: float = 1e-6,
+ is_layernorm: bool = False,
  ) -> None:
- """RMSNorm forward pass.
+ """RMSNorm/LayerNorm forward pass.
  Args:
  x: Input tensor of shape (M, N)
  weight: Optional weight tensor of shape (N,)
  eps: Small value for numerical stability
+ is_layernorm: If True, compute LayerNorm instead of RMSNorm
  Returns:
  Normalized output tensor of same shape as x
  """
- assert x.dim() == 2, "Input must be 2D"
- assert x.is_cuda, "Input tensor must be on CUDA device"
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+ # Don't need to check is_cuda since torch.library ensures that
+ supported_types = {torch.float16, torch.bfloat16, torch.float32}
+ assert x.dtype in supported_types, "Unsupported dtype"
  if weight is not None:
- assert weight.dim() == 1, "Weight must be 1D"
- assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
- assert weight.is_cuda, "Weight tensor must be on CUDA device"
- assert weight.dtype in [
- torch.float32,
- torch.bfloat16,
- torch.float16,
- ], "Weight must be float32, float16 or bfloat16"
+ assert weight.dtype in supported_types, "Weight must be float32, float16 or bfloat16"
  if residual is not None:
- assert residual.shape == x.shape
- assert residual.is_cuda
- assert residual.dtype in [
- torch.float16,
- torch.bfloat16,
- torch.float32,
- ], "Residual must be float16, bfloat16, or float32"
+ assert residual.dtype in supported_types, "Residual must be float16, bfloat16, or float32"

  _, N = x.shape
- device = x.device
- dtype = torch2cute_dtype_map[x.dtype]
- convert_from_dlpack = lambda x: (
- from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
- )
- x_tensor, res_tensor, out_tensor, res_out_tensor = [
- convert_from_dlpack(t) if t is not None else None for t in (x, residual, out, residual_out)
+ dtype, out_dtype, weight_dtype, bias_dtype, res_dtype, res_out_dtype = [
+ torch2cute_dtype_map[t.dtype] if t is not None else None
+ for t in [x, out, weight, bias, residual, residual_out]
  ]
- # handle weight divisibility based on weight dtype
- if weight is not None:
- weight_dtype = torch2cute_dtype_map[weight.dtype]
- weight_tensor = utils.convert_from_dlpack(
- weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
- )
- else:
- weight_tensor = None
- if bias is not None:
- bias_dtype = torch2cute_dtype_map[bias.dtype]
- bias_tensor = utils.convert_from_dlpack(
- bias.detach(), leading_dim=0, divisibility=128 // bias_dtype.width
- )
- else:
- bias_tensor = None
- rstd_tensor = (
- from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
- if rstd is not None
- else None
- )
- current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
  compile_key = (
- N,
  dtype,
- res_tensor.element_type if residual is not None else None,
- weight_tensor.element_type if weight is not None else None,
- bias_tensor.element_type if bias is not None else None,
- res_out_tensor.element_type if residual_out is not None else None,
+ out_dtype,
+ res_dtype,
+ weight_dtype,
+ bias_dtype,
+ res_out_dtype,
+ N,
  rstd is not None,
+ mean is not None,
+ is_layernorm,
  )
  if compile_key not in _rmsnorm_fwd.compile_cache:
- rmsnorm_op = RMSNorm(dtype, N)
+ batch_sym = cute.sym_int()
+ all_dtypes = [dtype, out_dtype, res_dtype, weight_dtype, bias_dtype, res_out_dtype]
+ div = math.gcd(N, *(128 // dt.width for dt in all_dtypes if dt is not None))
+ x_cute, out_cute, res_cute, res_out_cute = [
+ fake_tensor(dt, (batch_sym, N), div)
+ for dt in [dtype, out_dtype, res_dtype, res_out_dtype]
+ ]
+ weight_cute, bias_cute = [fake_tensor(dt, (N,), div) for dt in [weight_dtype, bias_dtype]]
+ rstd_cute = fake_tensor(Float32, (batch_sym,)) if rstd is not None else None
+ mean_cute = fake_tensor(Float32, (batch_sym,)) if mean is not None else None
  _rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
- rmsnorm_op,
- x_tensor,
- weight_tensor,
- bias_tensor,
- res_tensor,
- out_tensor,
- res_out_tensor,
- rstd_tensor,
- current_stream,
- eps,
+ RMSNorm(dtype, N, is_layernorm=is_layernorm),
+ x_cute,
+ weight_cute,
+ bias_cute,
+ res_cute,
+ out_cute,
+ res_out_cute,
+ rstd_cute,
+ mean_cute,
+ Float32(0), # eps, just for compilation
+ cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
+ options="--enable-tvm-ffi",
  )
  _rmsnorm_fwd.compile_cache[compile_key](
- x_tensor,
- weight_tensor,
- bias_tensor,
- res_tensor,
- out_tensor,
- res_out_tensor,
- rstd_tensor,
- current_stream,
- eps,
+ x, weight, bias, residual, out, residual_out, rstd, mean, eps
  )
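Editor's note: the custom op writes into preallocated buffers rather than returning tensors. A hedged usage sketch mirroring the positional call that rmsnorm_fwd makes below (assumes a CUDA device; _rmsnorm_fwd is a private helper in quack.rmsnorm):

import torch
from quack.rmsnorm import _rmsnorm_fwd  # private helper; illustrative use only

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
weight = torch.ones(4096, device="cuda", dtype=torch.float32)
out = torch.empty_like(x)
rstd = torch.empty(8, device="cuda", dtype=torch.float32)

# Positional order matches the registered schema:
# (x, weight, out, bias, rstd, mean, residual, residual_out, eps, is_layernorm)
_rmsnorm_fwd(x, weight, out, None, rstd, None, None, None, 1e-6, False)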


@@ -466,7 +395,7 @@ def rmsnorm_fwd(
  )
  else:
  residual_out = None
- _rmsnorm_fwd(x, weight, out, bias, rstd, residual, residual_out, eps=eps)
+ _rmsnorm_fwd(x, weight, out, bias, rstd, None, residual, residual_out, eps, False)
  # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
  if residual_out is None:
  residual_out = x
@@ -492,13 +421,19 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
  """Reference implementation for RMSNorm backward pass."""
  x_f32 = x.float()
  x_hat = x_f32 * rstd.unsqueeze(1)
- wdy = dout * w
+ if w is not None:
+ wdy = dout * w
+ else:
+ wdy = dout
  c1 = (x_hat * wdy).mean(dim=-1, keepdim=True)
  dx = (wdy - x_hat * c1) * rstd.unsqueeze(1)

  # dL/dW
- dw = (dout * x_hat).sum(dim=0)
- return dx.to(x.dtype), dw.to(w.dtype)
+ if w is not None:
+ dw = (dout * x_hat).sum(dim=0)
+ return dx.to(x.dtype), dw.to(w.dtype)
+ else:
+ return dx.to(x.dtype), None
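Editor's note: a quick shape-level sanity check of the reference backward above, now that weight may be None (illustrative only):

import torch
from quack.rmsnorm import rmsnorm_bwd_ref

x = torch.randn(4, 1024)
dout = torch.randn_like(x)
rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1) + 1e-6)

dx, dw = rmsnorm_bwd_ref(x, torch.ones(1024), dout, rstd)
assert dx.shape == x.shape and dw.shape == (1024,)

dx_no_w, dw_none = rmsnorm_bwd_ref(x, None, dout, rstd)
assert dw_none is None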


  class RMSNormBackward(ReductionBase):
@@ -510,94 +445,57 @@ class RMSNormBackward(ReductionBase):
  # Not enough smem
  raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits")

- def _get_num_threads(self):
+ def _num_threads(self):
  return 128 if self.N <= 4096 else 256

- def _calculate_threads_per_row(self):
+ def _threads_per_row(self):
  N = self.N
- return (
- 8
- if N <= 64
- else (
- 16
- if N <= 128
- else (32 if N <= 256 else (64 if N <= 512 else (128 if N <= 4096 else 256)))
- )
- )
+ for limit, threads in [(64, 8), (128, 16), (256, 32), (512, 64), (4096, 128)]:
+ if N <= limit:
+ return threads
+ return 256

  def _set_cluster_n(self):
  N = self.N
- cluster_n = (
- 1
- if N <= 8 * 1024
- else (2 if N <= 16 * 1024 else (4 if N <= 32 * 1024 else (8 if N <= 64 * 1024 else 16)))
- )
- self.cluster_n = cluster_n
-
- def _smem_size_in_bytes(self, tiler_mn, num_warps, do_dtype=None):
- if do_dtype is None:
- do_dtype = self.dtype
- return (
- # We need space for X and dO, and multiply by 2 due to double buffering
- cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
- + cute.size_in_bytes(do_dtype, cute.make_layout(tiler_mn)) * 2
- + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
- + self.stage * (cutlass.Int64.width // 8) * 2 # mult 2 as we need 2 mbar per stage
- )
+ for limit, cluster in [(8 * 1024, 1), (16 * 1024, 2), (32 * 1024, 4), (64 * 1024, 8)]:
+ if N <= limit:
+ self.cluster_n = cluster
+ return
+ self.cluster_n = 16

  @cute.jit
  def __call__(
  self,
  mX: cute.Tensor,
- mW: cute.Tensor,
+ mW: Optional[cute.Tensor],
  mdO: cute.Tensor,
  mdResO: Optional[cute.Tensor],
  mRstd: cute.Tensor,
  mdX: cute.Tensor,
- mdW: cute.Tensor,
+ mdW: Optional[cute.Tensor],
  mdRes: Optional[cute.Tensor],
  mdB: Optional[cute.Tensor],
  sm_count: Int32,
  stream: cuda.CUstream,
  ):
- semistatic_shape = (*mX.shape[:-1], self.N) # Set last dimension to be statically N
- new_stride = lambda t: (
- cute.assume(t.stride[0], divby=128 // t.element_type.width),
- t.stride[1],
- )
- mX, mdO, mdResO, mdX, mdRes = [
- cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
- if const_expr(t is not None)
- else None
- for t in (mX, mdO, mdResO, mdX, mdRes)
- ]
+ assert mX.element_type == self.dtype
  self._set_cluster_n()
  largest_dtype_width = const_expr(
- max(
- mX.element_type.width,
- mdO.element_type.width,
- mdX.element_type.width,
- mdResO.element_type.width if mdResO is not None else 0,
- mdRes.element_type.width if mdRes is not None else 0,
- )
+ max(*(t.element_type.width for t in [mX, mW, mdO, mdResO, mdX, mdRes] if t is not None))
  )
- tiler_mn, tv_layout = self._get_tv_layout(
- num_copy_bits=128 // largest_dtype_width * mX.element_type.width
+ vecsize = math.gcd(self.N, 128 // largest_dtype_width)
+ tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize)
+ num_threads = tiled_copy.size
+ mW = (
+ layout_utils.expand(mW, dim=0, size=tiler_mn[0]) if const_expr(mW is not None) else None
  )
- num_threads = cute.size(tv_layout, mode=[0])
- num_warps = num_threads // cute.arch.WARP_SIZE
- if const_expr(mW is not None):
- mW_expanded_layout = cute.prepend(
- mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
- )
- mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
-
  num_blocks = sm_count
- self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn).launch(
+ self.kernel(
+ mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tiler_mn, tiled_copy, threads_per_row
+ ).launch(
  grid=[num_blocks, self.cluster_n, 1],
  block=[num_threads, 1, 1],
  cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
- smem=self._smem_size_in_bytes(tiler_mn, num_warps, do_dtype=mdO.element_type),
  stream=stream,
  )

@@ -605,24 +503,23 @@ class RMSNormBackward(ReductionBase):
  def kernel(
  self,
  mX: cute.Tensor,
- mW: cute.Tensor,
+ mW: Optional[cute.Tensor],
  mdO: cute.Tensor,
  mdResO: Optional[cute.Tensor],
  mRstd: cute.Tensor,
  mdX: cute.Tensor,
- mdW: cute.Tensor,
+ mdW: Optional[cute.Tensor],
  mdB: Optional[cute.Tensor],
  mdRes: Optional[cute.Tensor],
- tv_layout: cute.Layout,
  tiler_mn: cute.Shape,
+ tiled_copy: cute.TiledCopy,
+ threads_per_row: cutlass.Constexpr[int],
  ):
  tidx, _, _ = cute.arch.thread_idx()
  bidx_start, _, _ = cute.arch.block_idx()
  gdim, _, _ = cute.arch.grid_dim()
- if const_expr(self.cluster_n > 1):
- cluster_y = cute.arch.block_idx()[1]
- else:
- cluster_y = const_expr(0)
+ cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1]
+ tv_layout = tiled_copy.layout_tv_tiled

  shape = mX.shape
  M, N = shape[0], shape[1]
@@ -642,63 +539,20 @@ class RMSNormBackward(ReductionBase):
  else:
  mbar_full_ptr, mbar_empty_ptr = None, None

- num_copy_elems_X = tv_layout.shape[1][0]
- copy_atom_load_X = utils.get_copy_atom(mX.element_type, num_copy_elems_X, is_async=False)
- thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
- # Each copy will use the same number of elements as X
- copy = partial(utils.copy, num_copy_elems=num_copy_elems_X)
-
- gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
- tXgW = thr_copy_X.partition_S(gW)
- tXrW = cute.make_fragment_like(tXgW)
- # Need this, otherwise rW can have arbitrary values that changes the reduction
- if not is_even_N:
- tXrW.fill(0.0)
-
- gW_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
- tXpW = (
- utils.predicate_k(thr_copy_X.partition_S(gW_coord), limit=shape[1])
- if not is_even_N
- else None
- )
- copy(tXgW, tXrW, pred=tXpW)
- weight = tXrW.load().to(cute.Float32)
-
- num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-
- self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)
-
- dw_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
- tXpdW = (
- utils.predicate_k(thr_copy_X.partition_S(dw_coord), limit=shape[1])
- if not is_even_N
- else None
- )
- if const_expr(mdB is not None):
- db_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
- tXpdB = (
- utils.predicate_k(thr_copy_X.partition_S(db_coord), limit=shape[1])
- if not is_even_N
- else None
- )
-
- gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
- tXgdW = thr_copy_X.partition_S(gdW)
- # Always compute partial weight gradients in fp32
- tXrdW = cute.make_fragment_like(tXgdW, Float32)
-
- gdB = (
- cute.local_tile(mdB, (1, tiler_mn[1]), (bidx_start, cluster_y))
- if const_expr(mdB is not None)
- else None
- )
- tXgdB = thr_copy_X.partition_S(gdB) if const_expr(mdB is not None) else None
- tXrdB = cute.make_fragment_like(tXgdB, Float32) if const_expr(mdB is not None) else None
+ thr_copy_X = tiled_copy.get_slice(tidx)

  gX, gdO, gdResO, gdX, gdRes, cX = [
  cute.local_tile(mT, tiler_mn, (None, cluster_y)) if mT is not None else None
  for mT in (mX, mdO, mdResO, mdX, mdRes, idX)
  ]
+ gW = cute.local_tile(mW, tiler_mn, (0, cluster_y)) if mW is not None else None
+ gdW, gdB = [
+ cute.local_tile(mT, (1, tiler_mn[1]), (bidx_start, cluster_y))
+ if const_expr(mT is not None)
+ else None
+ for mT in (mdW, mdB)
+ ]
+
  tXgX = thr_copy_X.partition_S(gX)
  tXsX = thr_copy_X.partition_D(sX)
  tXgdO = thr_copy_X.partition_S(gdO)
@@ -709,12 +563,6 @@ class RMSNormBackward(ReductionBase):
  if const_expr(mdRes is not None):
  tXgdRes = thr_copy_X.partition_D(gdRes)
  tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
- # This doesn't change across iterations
- tXpX = (
- utils.predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1])
- if not is_even_N
- else None
- )

  tXrX, tXrdO, tXrdX = [
  cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdO, tXgdX)
@@ -726,25 +574,57 @@ class RMSNormBackward(ReductionBase):
  if const_expr(mdRes is not None):
  tXrdRes = cute.make_fragment_like(tXgdRes[None, None, None, 0])

+ # This doesn't change across iterations
+ tXpX = (
+ None
+ if is_even_N
+ else copy_utils.predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1])
+ )
+ # Each copy will use the same number of elements as X
+ copy = partial(copy_utils.copy, pred=tXpX)
+
+ tXgdW, tXrdW = None, None
+ tXgdB, tXrdB = None, None
+ if const_expr(mdW is not None):
+ tXgdW = thr_copy_X.partition_S(gdW)
+ # Always compute partial weight gradients in fp32
+ tXrdW = cute.make_fragment_like(tXgdW, Float32)
+ if const_expr(mdB is not None):
+ tXgdB = thr_copy_X.partition_S(gdB)
+ # Always compute partial bias gradients in fp32
+ tXrdB = cute.make_fragment_like(tXgdB, Float32)
+
+ num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE
+
+ self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)
+
+ tXrW = None
+ if const_expr(mW is not None):
+ tXgW = thr_copy_X.partition_S(gW)
+ tXrW = cute.make_fragment_like(tXgW)
+ # Need this, otherwise rW can have arbitrary values that changes the reduction
+ if const_expr(not is_even_N):
+ tXrW.fill(0.0)
+ copy(tXgW, tXrW)
+
  # Prefetch the first batch
  row = tXcX[None, None, None, bidx_start][0][0]
  if row < M:
- tXgX_cur = utils.coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0]
- tXgdO_cur = utils.coord_offset_i64(bidx_start, tXgdO, dim=3)[None, None, None, 0]
- copy(tXgX_cur, tXsX[None, None, None, 0], pred=tXpX, is_async=True)
- copy(tXgdO_cur, tXsdO[None, None, None, 0], pred=tXpX, is_async=True)
- elif tiler_mn[0] > 1:
- # Fill with zero, otherwise smem will be uninitialized, and we could read this back
- # later into registers, causing wrong dW.
- utils.fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero)
- utils.fill_oob(tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero)
+ copy(tXgX[None, None, None, bidx_start], tXsX[None, None, None, 0], is_async=True)
+ copy(tXgdO[None, None, None, bidx_start], tXsdO[None, None, None, 0], is_async=True)
+ else:
+ if const_expr(tiler_mn[0] > 1):
+ # Fill with zero, otherwise smem will be uninitialized, and we could read this back
+ # later into registers, causing wrong dW.
+ utils.fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero)
+ utils.fill_oob(tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero)
  cute.arch.cp_async_commit_group()

  if const_expr(self.cluster_n > 1):
  cute.arch.cluster_wait()

- threads_per_row = tv_layout.shape[0][0]
- tXrdW.fill(0.0)
+ if const_expr(mdW is not None):
+ tXrdW.fill(0.0)
  if const_expr(mdB is not None):
  tXrdB.fill(0.0)
  stage = Int32(0)
@@ -753,29 +633,31 @@ class RMSNormBackward(ReductionBase):
  for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
  row = tXcX[None, None, None, bidx][0][0]
  if row + gdim * tiler_mn[0] < M: # Prefetch the next batch
- tXgX_cur = utils.coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0]
- tXgdO_cur = utils.coord_offset_i64(bidx + gdim, tXgdO, dim=3)[None, None, None, 0]
- copy(tXgX_cur, tXsX[None, None, None, stage ^ 1], pred=tXpX, is_async=True)
- copy(tXgdO_cur, tXsdO[None, None, None, stage ^ 1], pred=tXpX, is_async=True)
- elif tiler_mn[0] > 1:
- utils.fill_oob(
+ copy(
+ tXgX[None, None, None, bidx + gdim],
  tXsX[None, None, None, stage ^ 1],
- None,
- fill_value=mX.element_type.zero,
+ is_async=True,
  )
- utils.fill_oob(
+ copy(
+ tXgdO[None, None, None, bidx + gdim],
  tXsdO[None, None, None, stage ^ 1],
- None,
- fill_value=mdO.element_type.zero,
+ is_async=True,
  )
+ else:
+ if const_expr(tiler_mn[0] > 1):
+ utils.fill_oob(
+ tXsX[None, None, None, stage ^ 1], None, fill_value=mX.element_type.zero
+ )
+ utils.fill_oob(
+ tXsdO[None, None, None, stage ^ 1], None, fill_value=mdO.element_type.zero
+ )
  cute.arch.cp_async_commit_group()
  rstd = cutlass.Float.zero
  if row < M or tiler_mn[0] == 1:
  rstd = mRstd[row]
  if const_expr(mdResO is not None):
- tXgdResO_cur = utils.coord_offset_i64(bidx, tXgdResO, dim=3)[None, None, None, 0]
  if row < M or tiler_mn[0] == 1:
- copy(tXgdResO_cur, tXrdResO, pred=tXpX)
+ copy(tXgdResO[None, None, None, bidx], tXrdResO)
  elif tiler_mn[0] > 1:
  tXrdResO.fill(0.0)
  cute.arch.cp_async_wait_group(1)
@@ -783,10 +665,10 @@ class RMSNormBackward(ReductionBase):
  x = tXrX.load().to(cute.Float32)
  cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
  dout = tXrdO.load().to(cute.Float32)
- if const_expr(mdResO is not None):
- dout += tXrdResO.load().to(cute.Float32)
  x_hat = x * rstd
- wdy = dout * weight
+ wdy = dout
+ if const_expr(mW is not None):
+ wdy *= tXrW.load().to(Float32)
  if const_expr(self.cluster_n > 1):
  cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
  mean_xhat_wdy = (
@@ -803,6 +685,10 @@
  )

  if const_expr(self.cluster_n > 1):
+ # Need this fence since the STAS from the producer is using the async proxy.
+ cute.arch.fence_proxy(
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+ )
  # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
  # Requires adjusting the thread_count when initializing the mbar
  cute.arch.sync_warp()
@@ -815,22 +701,22 @@ class RMSNormBackward(ReductionBase):
  if const_expr(self.reload_wdy == "smem"):
  cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
  dout = tXrdO.load().to(cute.Float32)
- if const_expr(mdResO is not None):
- dout += tXrdResO.load().to(cute.Float32)
- wdy = dout * weight
+ wdy = dout
+ if const_expr(mW is not None):
+ wdy *= tXrW.load().to(Float32)

  dx = (wdy - x_hat * mean_xhat_wdy) * rstd
+ if const_expr(mdResO is not None):
+ dx += tXrdResO.load().to(cute.Float32)
  tXrdX.store(dx.to(tXrdX.element_type))
  if row < M or tiler_mn[0] == 1:
- tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
- copy(tXrdX, tXgdX_cur, pred=tXpX)
+ copy(tXrdX, tXgdX[None, None, None, bidx])
  if const_expr(mdRes is not None):
  tXrdRes.store(dx.to(tXrdRes.element_type))
- tXgdRes_cur = utils.coord_offset_i64(bidx, tXgdRes, dim=3)[None, None, None, 0]
  if row < M or tiler_mn[0] == 1:
- copy(tXrdRes, tXgdRes_cur, pred=tXpX)
- # Accumulate weight gradients in fp32
- tXrdW.store(tXrdW.load() + dout * x_hat)
+ copy(tXrdRes, tXgdRes[None, None, None, bidx])
+ if const_expr(mdW is not None):
+ tXrdW.store(tXrdW.load() + dout * x_hat)
  if const_expr(mdB is not None):
  tXrdB.store(tXrdB.load() + dout)

@@ -839,29 +725,29 @@ class RMSNormBackward(ReductionBase):
  consumer_phase ^= 1
  producer_phase ^= 1

- if const_expr(self.cluster_n > 1): # Prevent cluster from exiting early
- cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
-
  if const_expr(tiler_mn[0] > 1):
- # reduction of dw_partial within the same threadblock
- sdW = cute.make_tensor(
- cute.recast_ptr(sX.iterator, dtype=cute.Float32),
- cute.make_ordered_layout(tiler_mn, order=(1, 0)),
- )
- tXsdW = thr_copy_X.partition_D(sdW)
- cute.arch.barrier()
- row = tXcX[None, None, None, 0][0][0]
- if row > 0:
- cute.autovec_copy(tXrdW, tXsdW)
- cute.arch.barrier()
- if row == 0:
- for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
- tXrdW_other = cute.make_fragment_like(tXrdW)
- tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
- cute.autovec_copy(tXsdW_other, tXrdW_other)
- tXrdW.store(tXrdW.load() + tXrdW_other.load())
- copy(tXrdW, tXgdW, pred=tXpdW)
- cute.arch.barrier()
+ if const_expr(mdW is not None):
+ # reduction of dw_partial within the same threadblock
+ sdW = cute.make_tensor(
+ cute.recast_ptr(sX.iterator, dtype=cute.Float32),
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ )
+ tXsdW = thr_copy_X.partition_D(sdW)
+ cute.arch.barrier()
+ row = tXcX[None, None, None, 0][0][0]
+ if row > 0:
+ cute.autovec_copy(tXrdW, tXsdW)
+ cute.arch.barrier()
+ if row == 0:
+ for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
+ tXrdW_other = cute.make_fragment_like(tXrdW)
+ tXsdW_other = cute.make_tensor(
+ tXsdW.iterator + i * sdW.stride[0], tXsdW.layout
+ )
+ cute.autovec_copy(tXsdW_other, tXrdW_other)
+ tXrdW.store(tXrdW.load() + tXrdW_other.load())
+ copy(tXrdW, tXgdW)
+ cute.arch.barrier()
  if const_expr(mdB is not None):
  sdB = cute.make_tensor(
  cute.recast_ptr(sX.iterator, dtype=cute.Float32),
@@ -881,12 +767,21 @@ class RMSNormBackward(ReductionBase):
  )
  cute.autovec_copy(tXsdB_other, tXrdB_other)
  tXrdB.store(tXrdB.load() + tXrdB_other.load())
- copy(tXrdB, tXgdB, pred=tXpdB)
+ copy(tXrdB, tXgdB)
  else:
  # dw is already in fp32, so we can directly copy to global memory
- copy(tXrdW, tXgdW, pred=tXpdW)
+ if const_expr(mdW is not None):
+ copy(tXrdW, tXgdW)
  if const_expr(mdB is not None):
- copy(tXrdB, tXgdB, pred=tXpdB)
+ copy(tXrdB, tXgdB)
+
+ if const_expr(self.cluster_n > 1): # Prevent cluster from exiting early
+ # Assume state contains that next useful buffer
+ # So we only need to advance to num_stages - 1 times to last used buffer
+ stage ^= 1
+ if stage == 0:
+ producer_phase ^= 1
+ cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)


  def _get_sm_count(N: int, device: torch.device) -> int:
@@ -911,120 +806,103 @@ def _get_sm_count(N: int, device: torch.device) -> int:
  mutates_args={"dx", "dw_partial", "db_partial", "dresidual"},
  device_types="cuda",
  # We need to specify the schema manually since we're mutating an optional tensor
- schema="(Tensor x, Tensor weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!) dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual) -> ()",
+ schema="(Tensor x, Tensor? weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!)? dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual, int? sm_count) -> ()",
  )
  def _rmsnorm_bwd(
  x: Tensor,
- weight: Tensor,
+ weight: Optional[Tensor],
  dout: Tensor,
  rstd: Tensor,
  dx: Tensor,
- dw_partial: Tensor,
+ dw_partial: Optional[Tensor],
  db_partial: Optional[Tensor] = None,
  dresidual_out: Optional[Tensor] = None,
  dresidual: Optional[Tensor] = None,
+ sm_count: Optional[int] = None,
  ) -> None:
  """RMSNorm backward pass.
  Args:
  x: Input tensor of shape (M, N)
- weight: Weight tensor of shape (N,)
+ weight: Optional weight tensor of shape (N,)
  dout: Upstream gradients tensor of shape (M, N)
  rstd: Reciprocal standard deviation tensor of shape (M,)
  Returns:
  Tuple of (dx, dw) where:
  - dx: Input gradients tensor of same shape as x
- - dw: Weight gradients tensor of same shape as weight
+ - dw: Weight gradients tensor of same shape as weight (or None if weight is None)
  """
  assert x.dim() == 2, "Input must be 2D"
- assert weight.dim() == 1, "Weight must be 1D"
- assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
- assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
- assert weight.dtype in [
- torch.float32,
- torch.bfloat16,
- torch.float16,
- ], "Weight must be float32, float16 or bfloat16"
+ assert x.is_cuda, "Input tensor must be on CUDA device"
+ supported_types = {torch.float16, torch.bfloat16, torch.float32}
+ assert x.dtype in supported_types, "Unsupported dtype"
+ if weight is not None:
+ assert weight.dim() == 1, "Weight must be 1D"
+ assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
+ assert weight.is_cuda, "Weight tensor must be on CUDA device"
+ assert weight.dtype in supported_types, "Weight must be float32, float16 or bfloat16"
  if dresidual_out is not None:
  assert dresidual_out.shape == x.shape
  assert dresidual_out.is_cuda
- assert dresidual_out.dtype in [
- torch.float16,
- torch.bfloat16,
- torch.float32,
- ], "Residual must be float16, bfloat16, or float32"
+ assert dresidual_out.dtype in supported_types, (
+ "Residual must be float16, bfloat16, or float32"
+ )
  if dresidual is not None:
  assert dresidual.shape == x.shape
  assert dresidual.is_cuda
- assert dresidual.dtype in [
- torch.float16,
- torch.bfloat16,
- torch.float32,
- ], "Residual must be float16, bfloat16, or float32"
+ assert dresidual.dtype in supported_types, "Residual must be float16, bfloat16, or float32"

  N = x.size(1)
- device = x.device
- sm_count = dw_partial.shape[0]
- convert_from_dlpack = lambda x: (
- from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
- )
- x_tensor, dout_tensor, dres_out_tensor, dx_tensor, dres_tensor = [
- convert_from_dlpack(t) if t is not None else None
- for t in (x, dout, dresidual_out, dx, dresidual)
+ if dw_partial is None and db_partial is None:
+ assert sm_count is not None
+ else:
+ sm_count = dw_partial.shape[0] if dw_partial is not None else db_partial.shape[0]
+ dtype, dout_dtype, dx_dtype, weight_dtype, dres_dtype, dres_out_dtype = [
+ torch2cute_dtype_map[t.dtype] if t is not None else None
+ for t in [x, dout, dx, weight, dresidual, dresidual_out]
  ]
- # Handle weight div based on weight dtype
- weight_dtype = torch2cute_dtype_map[weight.dtype]
- weight_tensor = utils.convert_from_dlpack(
- weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
- )
-
- dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
- db_partial_tensor = (
- from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
- if db_partial is not None
- else None
- )
- rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
-
- current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-
  compile_key = (
  N,
- x_tensor.element_type,
- weight_tensor.element_type,
- db_partial.dtype if db_partial is not None else None,
- dresidual.dtype if dresidual is not None else None,
- dresidual_out.dtype if dresidual_out is not None else None,
+ dtype,
+ dout_dtype,
+ dx_dtype,
+ weight_dtype,
+ db_partial is not None,
+ dres_dtype,
+ dres_out_dtype,
  )
  if compile_key not in _rmsnorm_bwd.compile_cache:
- rmsnorm_backward_op = RMSNormBackward(x_tensor.element_type, N)
+ batch_sym, batch_partial_sym = cute.sym_int(), cute.sym_int()
+ all_dtypes = [dtype, dout_dtype, dx_dtype, dres_dtype, dres_out_dtype]
+ div = math.gcd(N, *(128 // dt.width for dt in all_dtypes if dt is not None))
+ x_cute, dout_cute, dx_cute, dres_out_cute, dres_cute = [
+ fake_tensor(dt, (batch_sym, N), div)
+ for dt in [dtype, dout_dtype, dx_dtype, dres_out_dtype, dres_dtype]
+ ]
+ weight_cute = fake_tensor(weight_dtype, (N,), div)
+ rstd_cute = fake_tensor(Float32, (batch_sym,))
+ dw_partial_cute = (
+ fake_tensor(Float32, (batch_partial_sym, N), div) if dw_partial is not None else None
+ )
+ db_partial_cute = (
+ fake_tensor(Float32, (batch_partial_sym, N), div) if db_partial is not None else None
+ )
  _rmsnorm_bwd.compile_cache[compile_key] = cute.compile(
- rmsnorm_backward_op,
- x_tensor,
- weight_tensor,
- dout_tensor,
- dres_out_tensor,
- rstd_tensor,
- dx_tensor,
- dw_partial_tensor,
- dres_tensor,
- db_partial_tensor,
+ RMSNormBackward(dtype, N),
+ x_cute,
+ weight_cute,
+ dout_cute,
+ dres_out_cute,
+ rstd_cute,
+ dx_cute,
+ dw_partial_cute,
+ dres_cute,
+ db_partial_cute,
  sm_count,
- current_stream,
+ cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
+ options="--enable-tvm-ffi",
  )
-
  _rmsnorm_bwd.compile_cache[compile_key](
- x_tensor,
- weight_tensor,
- dout_tensor,
- dres_out_tensor,
- rstd_tensor,
- dx_tensor,
- dw_partial_tensor,
- dres_tensor,
- db_partial_tensor,
- sm_count,
- current_stream,
+ x, weight, dout, dresidual_out, rstd, dx, dw_partial, dresidual, db_partial, sm_count
  )


@@ -1033,30 +911,37 @@ _rmsnorm_bwd.compile_cache = {}

  def rmsnorm_bwd(
  x: Tensor,
- weight: Tensor,
+ weight: Optional[Tensor],
  dout: Tensor,
  rstd: Tensor,
  dresidual_out: Optional[Tensor] = None, # grad wrt residual_out
  has_bias: bool = False,
- ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+ has_residual: bool = False,
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
  device = x.device
  N = x.size(1)
- sm_count = _get_sm_count(N, device)
  dx = torch.empty_like(x)
-
  if dresidual_out is not None and dresidual_out.dtype != dx.dtype:
  dresidual = torch.empty_like(x, dtype=dresidual_out.dtype)
  else:
  dresidual = None
- # Always store partial gradients in fp32 for numerical accuracy
- dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
+ sm_count = _get_sm_count(N, device)
+ if weight is not None:
+ # Always store partial gradients in fp32 for numerical accuracy
+ dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
+ else:
+ dw_partial = None
  db_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32) if has_bias else None
- _rmsnorm_bwd(x, weight, dout, rstd, dx, dw_partial, db_partial, dresidual_out, dresidual)
+
+ _rmsnorm_bwd(
+ x, weight, dout, rstd, dx, dw_partial, db_partial, dresidual_out, dresidual, sm_count
+ )
+
  # we have summed the partial gradients in fp32, now we convert back to the weight dtype
- dw = dw_partial.sum(dim=0).to(weight.dtype)
+ dw = dw_partial.sum(dim=0).to(weight.dtype) if weight is not None else None
  db = db_partial.sum(dim=0).to(weight.dtype) if has_bias else None
  # dresidual is the same as dx in this case
- if dresidual_out is not None and dresidual_out.dtype == dx.dtype:
+ if has_residual and dresidual is None:
  dresidual = dx
  return dx, dw, db, dresidual
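Editor's note: the weight gradient is accumulated as per-SM fp32 partials of shape (sm_count, N) and reduced once at the end. A small sketch of why the sum is taken on the fp32 partials before the final cast (illustrative numbers, not package code):

import torch

sm_count, N = 132, 4096
dw_partial = torch.randn(sm_count, N, dtype=torch.float32)

exact = dw_partial.double().sum(dim=0)
fp32_then_cast = dw_partial.sum(dim=0).to(torch.bfloat16)   # what rmsnorm_bwd does
cast_then_sum = dw_partial.to(torch.bfloat16).sum(dim=0)    # lower-precision alternative

print((fp32_then_cast.double() - exact).abs().max())
print((cast_then_sum.double() - exact).abs().max())   # typically larger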

@@ -1104,7 +989,6 @@ class RMSNormFunction(torch.autograd.Function):
  @staticmethod
  def backward(ctx, dout, *args):
  x, weight, rstd = ctx.saved_tensors
- assert weight is not None, "RMSNorm backward doesn't support weight=None yet"
  has_bias = ctx.has_bias
  if ctx.prenorm and ctx.residual_dtype is not None:
  dresidual_out = args[0]
@@ -1114,11 +998,16 @@ class RMSNormFunction(torch.autograd.Function):
  x_shape_og = ctx.x_shape_og
  # Reshape dout to match the flattened shape used in forward
  dout = dout.view(-1, dout.shape[-1])
-
- dx, dw, db, dresidual = rmsnorm_bwd(x, weight, dout, rstd, dresidual_out, has_bias)
+ dx, dw, db, dresidual = rmsnorm_bwd(
+ x,
+ weight,
+ dout,
+ rstd,
+ dresidual_out,
+ has_bias,
+ has_residual=ctx.residual_dtype is not None,
+ )
  dx = dx.view(x_shape_og)
- if dresidual_out is not None:
- dresidual_out = dresidual_out.reshape(x_shape_og)
  if dresidual is not None:
  dresidual = dresidual.reshape(x_shape_og)

@@ -1148,7 +1037,7 @@ def rmsnorm(
  return RMSNormFunction.apply(x, weight, bias, residual, out_dtype, residual_dtype, eps, prenorm)


- class QuackRMSNorm(torch.nn.Module):
+ class QuackRMSNorm(torch.nn.RMSNorm):
  """RMSNorm module that behaves like torch.nn.RMSNorm.

  This class provides a drop-in replacement for torch.nn.RMSNorm that uses
@@ -1163,10 +1052,10 @@ class QuackRMSNorm(torch.nn.Module):
  eps (float): A small constant for numerical stability
  """

- def __init__(self, dim: int, eps: float = 1e-6):
- super().__init__()
- self.weight = torch.nn.Parameter(torch.ones(dim))
- self.eps = eps
+ def __init__(
+ self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True, device=None, dtype=None
+ ):
+ super().__init__(dim, eps, elementwise_affine, device=device, dtype=dtype)

  def forward(self, x: Tensor) -> Tensor:
  """Apply RMSNorm to the input tensor.
@@ -1179,6 +1068,67 @@ class QuackRMSNorm(torch.nn.Module):
  """
  return rmsnorm(x, self.weight, eps=self.eps)
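Editor's note: since QuackRMSNorm now subclasses torch.nn.RMSNorm (so it requires a PyTorch version that provides that module), it keeps the standard constructor and parameter registration. A hedged drop-in usage sketch; tolerances are illustrative:

import torch
from quack.rmsnorm import QuackRMSNorm

norm = QuackRMSNorm(4096, eps=1e-6).to("cuda", torch.bfloat16)
x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
y = norm(x)   # dispatches to the quack rmsnorm kernel via forward() above

ref = torch.nn.RMSNorm(4096, eps=1e-6, device="cuda", dtype=torch.bfloat16)
torch.testing.assert_close(y, ref(x), rtol=2e-2, atol=2e-2)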

- def reset_parameters(self):
- """Reset the weight parameter to ones."""
- torch.nn.init.ones_(self.weight)
+
+ def layernorm_fwd(
+ x: Tensor,
+ weight: Tensor,
+ bias: Optional[Tensor] = None,
+ eps: float = 1e-6,
+ return_rstd: bool = False,
+ return_mean: bool = False,
+ ):
+ """LayerNorm forward pass using the unified RMSNorm/LayerNorm kernel.
+
+ Args:
+ x: Input tensor of shape (M, N)
+ weight: Weight tensor of shape (N,). Must be float32.
+ bias: Optional bias tensor of shape (N,). Must be float32.
+ eps: Small value for numerical stability
+ return_rstd: Whether to return the reciprocal standard deviation
+ return_mean: Whether to return the mean
+
+ Returns:
+ Normalized output tensor of same shape as x
+ If return_rstd is True, also returns rstd tensor of shape (M,)
+ If return_mean is True, also returns mean tensor of shape (M,)
+ """
+ assert x.dim() == 2, "Input must be 2D"
+ assert weight.dim() == 1, "Weight must be 1D"
+ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+ assert weight.dtype == torch.float32, "Weight must be float32"
+ if bias is not None:
+ assert bias.dim() == 1, "Bias must be 1D"
+ assert bias.dtype == torch.float32, "Bias must be float32"
+
+ M, N = x.shape
+ device = x.device
+ out = torch.empty_like(x)
+ rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
+ mean = torch.empty(M, device=device, dtype=torch.float32) if return_mean else None
+
+ _rmsnorm_fwd(x, weight, out, bias, rstd, mean, None, None, eps, True)
+
+ if return_rstd and return_mean:
+ return out, rstd, mean
+ elif return_rstd:
+ return out, rstd
+ elif return_mean:
+ return out, mean
+ return out
+
+
+ def layernorm_ref(x: Tensor, w: Tensor, eps: float = 1e-6) -> Tensor:
+ """Reference implementation for LayerNorm."""
+ x_f32 = x.float()
+ return torch.nn.functional.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype)
+
+
+ def layernorm_rstd_ref(x: torch.Tensor, eps: float = 1e-6):
+ x_f32 = x.float()
+ mean = x_f32.mean(dim=-1, keepdim=True)
+ var = ((x_f32 - mean) ** 2).mean(dim=-1)
+ return 1.0 / torch.sqrt(var + eps)
+
+
+ def layernorm_mean_ref(x: torch.Tensor) -> torch.Tensor:
+ return x.float().mean(dim=-1)
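Editor's note: a hedged usage sketch of the new layernorm_fwd entry point against the reference helpers defined above (CUDA device assumed; tolerances are illustrative):

import torch
from quack.rmsnorm import layernorm_fwd, layernorm_ref, layernorm_rstd_ref, layernorm_mean_ref

x = torch.randn(8, 2048, device="cuda", dtype=torch.bfloat16)
w = torch.ones(2048, device="cuda", dtype=torch.float32)

out, rstd, mean = layernorm_fwd(x, w, return_rstd=True, return_mean=True)
torch.testing.assert_close(out, layernorm_ref(x, w), rtol=2e-2, atol=2e-2)
torch.testing.assert_close(rstd, layernorm_rstd_ref(x), rtol=1e-3, atol=1e-3)
torch.testing.assert_close(mean, layernorm_mean_ref(x), rtol=1e-3, atol=1e-3)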