quack-kernels 0.2.1-py3-none-any.whl → 0.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/autotuner.py +64 -5
  4. quack/broadcast_utils.py +29 -0
  5. quack/compile_utils.py +19 -0
  6. quack/copy_utils.py +487 -0
  7. quack/cross_entropy.py +157 -233
  8. quack/cute_dsl_utils.py +20 -35
  9. quack/gemm.py +194 -0
  10. quack/gemm_act.py +510 -0
  11. quack/gemm_config.py +72 -46
  12. quack/gemm_dact.py +215 -0
  13. quack/gemm_default_epi.py +259 -0
  14. quack/gemm_interface.py +615 -146
  15. quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
  16. quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
  17. quack/gemm_symmetric.py +330 -0
  18. quack/gemm_wrapper_utils.py +182 -23
  19. quack/layout_utils.py +287 -0
  20. quack/linear.py +24 -16
  21. quack/pipeline.py +158 -3
  22. quack/reduce.py +88 -49
  23. quack/reduction_base.py +25 -36
  24. quack/rmsnorm.py +508 -624
  25. quack/sm100_utils.py +62 -0
  26. quack/sm90_utils.py +127 -0
  27. quack/softmax.py +135 -203
  28. quack/sort/bitonic_sort.py +13 -10
  29. quack/sort/utils.py +6 -6
  30. quack/tile_scheduler.py +55 -61
  31. quack/topk.py +409 -85
  32. quack/utils.py +37 -172
  33. quack/varlen_utils.py +370 -6
  34. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  35. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  36. quack/gemm_act_sm90.py +0 -368
  37. quack/gemm_dact_sm90.py +0 -150
  38. quack/layernorm.py +0 -353
  39. quack/symmetric_dense_gemm_sm90.py +0 -2091
  40. quack_kernels-0.2.1.dist-info/RECORD +0 -37
  41. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  42. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  43. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
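The list above shows quack/layernorm.py being removed while quack/rmsnorm.py (diffed below) absorbs LayerNorm behind a new is_layernorm flag. As a reading aid for that diff, here is a reference-level PyTorch sketch of the two normalizations the unified kernel computes; it mirrors the rmsnorm_ref helper in the diff and is an illustration only, not the package's public API.

import torch

def rmsnorm_reference(x, w=None, b=None, eps=1e-6):
    # RMSNorm: scale by 1/sqrt(mean(x^2) + eps); no mean subtraction.
    x_f32 = x.float()
    rstd = torch.rsqrt(x_f32.square().mean(dim=-1, keepdim=True) + eps)
    y = x_f32 * rstd
    if w is not None:
        y = y * w.float()
    if b is not None:
        y = y + b.float()
    return y.to(x.dtype)

def layernorm_reference(x, w=None, b=None, eps=1e-6):
    # LayerNorm: subtract the row mean first, then normalize by the variance.
    x_f32 = x.float()
    mean = x_f32.mean(dim=-1, keepdim=True)
    rstd = torch.rsqrt((x_f32 - mean).square().mean(dim=-1, keepdim=True) + eps)
    y = (x_f32 - mean) * rstd
    if w is not None:
        y = y * w.float()
    if b is not None:
        y = y + b.float()
    return y.to(x.dtype)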
quack/rmsnorm.py CHANGED
@@ -1,156 +1,91 @@
1
1
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
2
 
3
- from typing import Optional, Tuple
3
+ import math
4
+ from typing import Optional, Tuple, Type
4
5
  from functools import partial
5
6
 
6
7
  import cuda.bindings.driver as cuda
7
8
 
8
9
  import cutlass
9
10
  import cutlass.cute as cute
10
- from cutlass import Float32, Int32
11
- from cutlass import const_expr
12
- from cutlass.cute.runtime import from_dlpack
11
+ from cutlass import Float32, Int32, const_expr
13
12
 
14
13
  import torch
15
14
  from torch import Tensor
16
15
 
17
16
  import quack.utils as utils
17
+ import quack.copy_utils as copy_utils
18
+ import quack.layout_utils as layout_utils
19
+ from quack.compile_utils import make_fake_tensor as fake_tensor
18
20
  from quack.reduce import row_reduce
19
21
  from quack.reduction_base import ReductionBase
20
22
  from quack.cute_dsl_utils import torch2cute_dtype_map
21
23
 
22
24
 
23
25
  class RMSNorm(ReductionBase):
24
- def __init__(self, dtype: cutlass.Numeric, N: int):
25
- super().__init__(dtype, N, stage=1)
26
- self.reload_from = None if N <= 8192 else "smem"
26
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int, is_layernorm: bool = False):
27
+ super().__init__(dtype, N, stage=2 if is_layernorm else 1)
28
+ self.is_layernorm = is_layernorm
29
+ self.reload_from = None if N <= (16384 if is_layernorm else 8192) else "smem"
27
30
  self.delay_w_load = False
28
31
 
29
- def _calculate_threads_per_row(self):
30
- """Calculate the number of threads per row for the RMSNorm kernel."""
32
+ def _threads_per_row(self):
31
33
  N = self.N
32
- if N <= 64:
33
- return 8
34
- elif N <= 128:
35
- return 16
36
- elif N <= 3072:
37
- return 32
38
- elif N <= 6144:
39
- return 64
40
- elif N <= 16384:
41
- return 128
42
- else:
43
- return 256
34
+ for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
35
+ if N <= limit:
36
+ return threads
37
+ return 256
44
38
 
45
39
  def _set_cluster_n(self):
46
- """
47
- Set the number of clusters for the RMSNorm kernel.
48
- Stored in self.cluster_n.
49
- """
50
40
  N = self.N
51
-
52
41
  # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
53
42
  # Similarly cluster_n = 8 is faster for N=128k
54
43
  if const_expr(self.dtype.width == 16):
55
- # 16-bit types (fp16, bf16)
56
- if N <= 16 * 1024:
57
- cluster_n = 1
58
- elif N <= 32 * 1024:
59
- cluster_n = 2
60
- elif N <= 64 * 1024:
61
- cluster_n = 4
62
- elif N <= 128 * 1024:
63
- cluster_n = 8
64
- else:
65
- cluster_n = 16
44
+ thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
66
45
  else:
67
- # 32-bit types (fp32)
68
- if N <= 32 * 1024:
69
- cluster_n = 1
70
- elif N <= 64 * 1024:
71
- cluster_n = 2
72
- elif N <= 128 * 1024:
73
- cluster_n = 4
74
- elif N <= 256 * 1024:
75
- cluster_n = 8
76
- else:
77
- cluster_n = 16
78
-
79
- self.cluster_n = cluster_n
80
-
81
- def _smem_size_in_bytes(self, tiler_mn, num_warps, dtype_res=None):
82
- return (
83
- cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
84
- + (
85
- cute.size_in_bytes(dtype_res, cute.make_layout(tiler_mn))
86
- if dtype_res is not None
87
- else 0
88
- )
89
- + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
90
- + self.stage * (cutlass.Int64.width // 8)
91
- )
46
+ thresholds = [(32 * 1024, 1), (64 * 1024, 2), (128 * 1024, 4), (256 * 1024, 8)]
47
+ for limit, cluster in thresholds:
48
+ if N <= limit:
49
+ self.cluster_n = cluster
50
+ return
51
+ self.cluster_n = 16
92
52
 
93
53
  @cute.jit
94
54
  def __call__(
95
55
  self,
96
56
  mX: cute.Tensor,
97
- mW: cute.Tensor,
57
+ mW: Optional[cute.Tensor],
98
58
  mB: Optional[cute.Tensor],
99
59
  mRes: Optional[cute.Tensor],
100
60
  mO: cute.Tensor,
101
61
  mResO: Optional[cute.Tensor],
102
62
  mRstd: Optional[cute.Tensor],
63
+ mMean: Optional[cute.Tensor],
64
+ eps: Float32,
103
65
  stream: cuda.CUstream,
104
- eps: Float32 = 1e-6,
105
66
  ):
106
- semistatic_shape = (*mX.shape[:-1], self.N) # Set last dimension to be statically N
107
- new_stride = lambda t: (
108
- cute.assume(t.stride[0], divby=128 // t.element_type.width),
109
- t.stride[1],
110
- )
111
- mX, mRes, mO, mResO = [
112
- cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
113
- if const_expr(t is not None)
114
- else None
115
- for t in (mX, mRes, mO, mResO)
116
- ]
117
67
  assert mX.element_type == self.dtype
118
- assert mO.element_type == self.dtype
119
68
  self._set_cluster_n()
120
69
  largest_dtype_width = const_expr(
121
- max(
122
- mX.element_type.width,
123
- mRes.element_type.width if mRes is not None else 0,
124
- mO.element_type.width,
125
- mResO.element_type.width if mResO is not None else 0,
126
- )
127
- )
128
- tiler_mn, tv_layout = self._get_tv_layout(
129
- num_copy_bits=128 // largest_dtype_width * mX.element_type.width
70
+ max(*(t.element_type.width for t in [mX, mRes, mW, mB, mO, mResO] if t is not None))
130
71
  )
131
- num_threads = cute.size(tv_layout, mode=[0])
132
- num_warps = num_threads // cute.arch.WARP_SIZE
133
- mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
134
- mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
135
- if const_expr(mB is not None):
136
- mB_expanded_layout = cute.prepend(
137
- mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
138
- )
139
- mB = cute.make_tensor(mB.iterator, mB_expanded_layout)
140
- if const_expr(mRstd is not None):
141
- mRstd_expanded_layout = cute.append(
142
- mRstd.layout, cute.make_layout((self.N,), stride=(0,))
143
- )
144
- mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
72
+ vecsize = math.gcd(self.N, 128 // largest_dtype_width)
73
+ tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize)
74
+ num_threads = tiled_copy.size
75
+ mW, mB = [
76
+ layout_utils.expand(mT, dim=0, size=tiler_mn[0]) if const_expr(mT is not None) else None
77
+ for mT in (mW, mB)
78
+ ]
79
+ mRstd, mMean = [
80
+ layout_utils.expand(mT, dim=1, size=self.N) if const_expr(mT is not None) else None
81
+ for mT in (mRstd, mMean)
82
+ ]
145
83
  self.kernel(
146
- mX, mW, mB, mRes, mO, mResO, mRstd, eps, tv_layout, tiler_mn, self.reload_from
84
+ mX, mW, mB, mRes, mO, mResO, mRstd, mMean, eps, tiler_mn, tiled_copy, threads_per_row
147
85
  ).launch(
148
86
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
149
87
  block=[num_threads, 1, 1],
150
- cluster=([1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None),
151
- smem=self._smem_size_in_bytes(
152
- tiler_mn, num_warps, dtype_res=mRes.element_type if mRes is not None else None
153
- ),
88
+ cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
154
89
  stream=stream,
155
90
  )
156
91
 
@@ -158,30 +93,26 @@ class RMSNorm(ReductionBase):
158
93
  def kernel(
159
94
  self,
160
95
  mX: cute.Tensor,
161
- mW: cute.Tensor,
96
+ mW: Optional[cute.Tensor],
162
97
  mB: Optional[cute.Tensor],
163
98
  mRes: Optional[cute.Tensor],
164
99
  mO: cute.Tensor,
165
100
  mResO: Optional[cute.Tensor],
166
101
  mRstd: Optional[cute.Tensor],
167
- eps: cute.Float32,
168
- tv_layout: cute.Layout,
102
+ mMean: Optional[cute.Tensor],
103
+ eps: Float32,
169
104
  tiler_mn: cute.Shape,
170
- reload_from: cutlass.Constexpr = None,
171
- delay_w_load: cutlass.Constexpr = False,
105
+ tiled_copy: cute.TiledCopy,
106
+ threads_per_row: cutlass.Constexpr[int],
172
107
  ):
173
108
  tidx, _, _ = cute.arch.thread_idx()
174
109
  bidx, _, _ = cute.arch.block_idx()
175
- if const_expr(self.cluster_n > 1):
176
- cluster_y = cute.arch.block_idx()[1]
177
- else:
178
- cluster_y = const_expr(0)
110
+ cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1]
111
+ tv_layout = tiled_copy.layout_tv_tiled
179
112
 
180
113
  smem = cutlass.utils.SmemAllocator()
181
114
  sX = smem.allocate_tensor(
182
- mX.element_type,
183
- cute.make_ordered_layout(tiler_mn, order=(1, 0)),
184
- byte_alignment=16,
115
+ mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
185
116
  )
186
117
  if const_expr(mRes is not None):
187
118
  sRes = smem.allocate_tensor(
@@ -194,73 +125,18 @@ class RMSNorm(ReductionBase):
194
125
  shape = mX.shape
195
126
  idX = cute.make_identity_tensor(shape)
196
127
  # slice for CTAs
197
- # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
198
- mX, mRes, mO, mResO = [
199
- utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) if mT is not None else None
200
- for mT in (mX, mRes, mO, mResO)
128
+ gX, gRes, gO, gResO, gRstd, gMean, cX = [
129
+ cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) if mT is not None else None
130
+ for mT in (mX, mRes, mO, mResO, mRstd, mMean, idX)
201
131
  ]
202
- gX, gRes, gO, gResO = [
203
- cute.local_tile(mT, tiler_mn, (0, cluster_y)) if mT is not None else None
204
- for mT in (mX, mRes, mO, mResO)
132
+ gW, gB = [
133
+ cute.local_tile(mT, tiler_mn, (0, cluster_y)) if const_expr(mT is not None) else None
134
+ for mT in (mW, mB)
205
135
  ]
206
- cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
207
- gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
208
- gB = cute.local_tile(mB, tiler_mn, (0, cluster_y)) if const_expr(mB is not None) else None
209
- gRstd = (
210
- cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
211
- if const_expr(mRstd is not None)
212
- else None
213
- )
214
136
 
215
- # declare the atoms which will be used later for memory copy
216
- num_copy_elems_X = tv_layout.shape[1][0]
217
- num_copy_bits_X = mX.element_type.width * num_copy_elems_X
218
- copy_atom_load_X = cute.make_copy_atom(
219
- cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
220
- )
221
- copy_atom_load_X_async = cute.make_copy_atom(
222
- cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
223
- )
224
- num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
225
- copy_atom_load_W = cute.make_copy_atom(
226
- cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
227
- )
228
- num_bits_per_copy_B = (
229
- cutlass.const_expr(min(128, num_copy_elems_X * mB.element_type.width))
230
- if const_expr(mB is not None)
231
- else 0
232
- )
233
- copy_atom_load_B = (
234
- cute.make_copy_atom(
235
- cute.nvgpu.CopyUniversalOp(), mB.element_type, num_bits_per_copy=num_bits_per_copy_B
236
- )
237
- if const_expr(mB is not None)
238
- else None
239
- )
240
- if const_expr(mRes is not None):
241
- num_copy_bits_Res = const_expr(min(128, num_copy_elems_X * mRes.element_type.width))
242
- copy_atom_load_Res_async = cute.make_copy_atom(
243
- cute.nvgpu.cpasync.CopyG2SOp(),
244
- mRes.element_type,
245
- num_bits_per_copy=num_copy_bits_Res,
246
- )
247
- num_copy_bits_O = const_expr(min(128, num_copy_elems_X * mO.element_type.width))
248
- copy_atom_store_O = cute.make_copy_atom(
249
- cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=num_copy_bits_O
250
- )
251
- if const_expr(mResO is not None):
252
- num_copy_bits_ResO = const_expr(min(128, num_copy_elems_X * mResO.element_type.width))
253
- copy_atom_store_ResO = cute.make_copy_atom(
254
- cute.nvgpu.CopyUniversalOp(),
255
- mResO.element_type,
256
- num_bits_per_copy=num_copy_bits_ResO,
257
- )
137
+ thr_copy_X = tiled_copy.get_slice(tidx)
258
138
 
259
- thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
260
- tidx
261
- )
262
-
263
- tXgW = thr_copy_X.partition_S(gW)
139
+ tXgW = thr_copy_X.partition_S(gW) if const_expr(mW is not None) else None
264
140
  tXgB = thr_copy_X.partition_S(gB) if const_expr(mB is not None) else None
265
141
  tXgX = thr_copy_X.partition_S(gX)
266
142
  tXsX = thr_copy_X.partition_D(sX)
@@ -271,34 +147,40 @@ class RMSNorm(ReductionBase):
271
147
  if const_expr(mResO is not None):
272
148
  tXgResO = thr_copy_X.partition_D(gResO)
273
149
  tXrRstd = thr_copy_X.partition_D(gRstd) if const_expr(mRstd is not None) else None
150
+ tXrMean = thr_copy_X.partition_D(gMean) if const_expr(mMean is not None) else None
274
151
  tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
275
152
 
276
153
  # allocate fragments for gmem->rmem
277
- tXrW = cute.make_fragment_like(tXgW)
278
- tXrW.fill(0.0)
154
+ tXrW = cute.make_fragment_like(tXgW) if const_expr(mW is not None) else None
279
155
  tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None
280
156
  tXrX, tXrO = [cute.make_fragment_like(t) for t in (tXgX, tXgO)]
281
157
  if const_expr(mRes is not None):
282
158
  tXrRes = cute.make_fragment_like(tXgRes)
283
159
 
284
- num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
160
+ num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE
285
161
  self._initialize_cluster(tidx, mbar_ptr, num_warps)
286
162
 
287
- is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
163
+ is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
288
164
  tXpX = (
289
- utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
165
+ copy_utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
166
+ if not is_even_N
167
+ else None
290
168
  )
169
+ # Each copy will use the same predicate
170
+ copy = partial(copy_utils.copy, pred=tXpX)
171
+
291
172
  row = tXcX[0][0]
292
173
  if row < shape[0]:
293
- cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
174
+ copy(tXgX, tXsX, is_async=True)
294
175
  if const_expr(mRes is not None):
295
- cute.copy(copy_atom_load_Res_async, tXgRes, tXsRes, pred=tXpX)
176
+ copy(tXgRes, tXsRes, is_async=True)
296
177
  cute.arch.cp_async_commit_group()
297
178
 
298
- if const_expr(not delay_w_load):
299
- cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpX)
179
+ if const_expr(not self.delay_w_load):
180
+ if const_expr(mW is not None):
181
+ copy(tXgW, tXrW)
300
182
  if const_expr(mB is not None):
301
- cute.copy(copy_atom_load_B, tXgB, tXrB, pred=tXpX)
183
+ copy(tXgB, tXrB)
302
184
 
303
185
  cute.arch.cp_async_wait_group(0)
304
186
  cute.autovec_copy(tXsX, tXrX)
@@ -310,19 +192,63 @@ class RMSNorm(ReductionBase):
310
192
  tXrResO = cute.make_fragment_like(tXgResO)
311
193
  tXrResO.store(x.to(tXrResO.element_type))
312
194
  if row < shape[0]:
313
- cute.copy(copy_atom_store_ResO, tXrResO, tXgResO, pred=tXpX)
314
-
315
- threads_per_row = tv_layout.shape[0][0]
316
- sum_sq_x = row_reduce(
317
- x * x,
318
- cute.ReductionOp.ADD,
319
- threads_per_row,
320
- reduction_buffer[None, None, 0],
321
- mbar_ptr,
322
- init_val=0.0,
323
- hook_fn=(cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None),
324
- )
325
- rstd = cute.math.rsqrt(sum_sq_x / shape[1] + eps, fastmath=True)
195
+ copy(tXrResO, tXgResO)
196
+
197
+ mean, rstd = None, None
198
+ if const_expr(self.is_layernorm):
199
+ # LayerNorm: compute mean first, then variance
200
+ sum_x = row_reduce(
201
+ x,
202
+ cute.ReductionOp.ADD,
203
+ threads_per_row,
204
+ reduction_buffer[None, None, 0],
205
+ mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None,
206
+ init_val=0.0,
207
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
208
+ )
209
+ mean = sum_x / shape[1]
210
+ if const_expr(mMean is not None):
211
+ # Only the thread corresponding to column 0 writes out the mean to gmem
212
+ if (
213
+ tXcX[0][1] == 0
214
+ and row < shape[0]
215
+ and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
216
+ ):
217
+ tXrMean[0] = mean
218
+ if const_expr(self.reload_from == "smem"):
219
+ cute.autovec_copy(tXsX, tXrX)
220
+ x = tXrX.load().to(cute.Float32)
221
+ if const_expr(mRes is not None):
222
+ cute.autovec_copy(tXsRes, tXrRes)
223
+ x += tXrRes.load().to(cute.Float32)
224
+ elif const_expr(self.reload_from == "gmem"):
225
+ copy(tXgX, tXrX)
226
+ x = tXrX.load().to(cute.Float32)
227
+ if const_expr(mRes is not None):
228
+ copy(tXgRes, tXrRes)
229
+ x += tXrRes.load().to(cute.Float32)
230
+ sum_sq_x_sub_mean = row_reduce(
231
+ (x - mean) * (x - mean),
232
+ cute.ReductionOp.ADD,
233
+ threads_per_row,
234
+ reduction_buffer[None, None, 1],
235
+ mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None,
236
+ init_val=0.0,
237
+ )
238
+ rstd = cute.math.rsqrt(sum_sq_x_sub_mean / shape[1] + eps, fastmath=True)
239
+ else:
240
+ # RMSNorm: compute sum of squares directly
241
+ mean = const_expr(0.0)
242
+ sum_sq_x = row_reduce(
243
+ x * x,
244
+ cute.ReductionOp.ADD,
245
+ threads_per_row,
246
+ reduction_buffer[None, None, 0],
247
+ mbar_ptr,
248
+ init_val=0.0,
249
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
250
+ )
251
+ rstd = cute.math.rsqrt(sum_sq_x / shape[1] + eps, fastmath=True)
326
252
  if const_expr(mRstd is not None):
327
253
  # Only the thread corresponding to column 0 writes out the rstd to gmem
328
254
  if (
@@ -331,139 +257,114 @@ class RMSNorm(ReductionBase):
331
257
  and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
332
258
  ):
333
259
  tXrRstd[0] = rstd
334
- if const_expr(delay_w_load):
335
- cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpX)
260
+ if const_expr(self.delay_w_load):
261
+ if const_expr(mW is not None):
262
+ copy(tXgW, tXrW)
336
263
  if const_expr(mB is not None):
337
- cute.copy(copy_atom_load_B, tXgB, tXrB, pred=tXpX)
338
- if const_expr(reload_from == "smem" or reload_from == "gmem"):
339
- if const_expr(reload_from == "smem"):
264
+ copy(tXgB, tXrB)
265
+ if const_expr(self.reload_from == "smem" or self.reload_from == "gmem"):
266
+ if const_expr(self.reload_from == "smem"):
340
267
  cute.autovec_copy(tXsX, tXrX)
268
+ if const_expr(mRes is not None):
269
+ cute.autovec_copy(tXsRes, tXrRes)
341
270
  else:
342
- cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
271
+ copy(tXgX, tXrX)
272
+ if const_expr(mRes is not None):
273
+ copy(tXgRes, tXrRes)
343
274
  x = tXrX.load().to(cute.Float32)
344
275
  if const_expr(mRes is not None):
345
- cute.autovec_copy(tXsRes, tXrRes)
346
276
  x += tXrRes.load().to(cute.Float32)
347
- x_hat = x * rstd
348
- w = tXrW.load().to(cute.Float32)
349
- y = x_hat * w
277
+ x_hat = (x - mean) * rstd if const_expr(self.is_layernorm) else x * rstd
278
+ y = x_hat
279
+ if const_expr(mW is not None):
280
+ y *= tXrW.load().to(cute.Float32)
350
281
  if const_expr(mB is not None):
351
- b = tXrB.load().to(cute.Float32)
352
- y = y + b
282
+ y += tXrB.load().to(cute.Float32)
353
283
  tXrO.store(y.to(tXrO.element_type))
354
284
  if row < shape[0]:
355
- cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tXpX)
285
+ copy(tXrO, tXgO)
356
286
 
357
287
 
358
288
  @torch.library.custom_op(
359
289
  "quack::_rmsnorm_fwd",
360
- mutates_args=("out", "rstd", "residual_out"),
290
+ mutates_args=("out", "rstd", "mean", "residual_out"),
361
291
  device_types="cuda",
362
292
  # We need to specify the schema manually since we're mutating an optional tensor
363
- schema="(Tensor x, Tensor weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor? residual, Tensor(a6!)? residual_out, float eps=1e-6) -> ()",
293
+ schema="(Tensor x, Tensor? weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor(a5!)? mean, Tensor? residual, Tensor(a7!)? residual_out, float eps=1e-6, bool is_layernorm=False) -> ()",
364
294
  )
365
295
  def _rmsnorm_fwd(
366
296
  x: Tensor,
367
- weight: Tensor,
297
+ weight: Optional[Tensor],
368
298
  out: Tensor,
369
299
  bias: Optional[Tensor] = None,
370
300
  rstd: Optional[Tensor] = None,
301
+ mean: Optional[Tensor] = None,
371
302
  residual: Optional[Tensor] = None,
372
303
  residual_out: Optional[Tensor] = None,
373
304
  eps: float = 1e-6,
305
+ is_layernorm: bool = False,
374
306
  ) -> None:
375
- """RMSNorm forward pass.
307
+ """RMSNorm/LayerNorm forward pass.
376
308
  Args:
377
309
  x: Input tensor of shape (M, N)
378
- weight: Weight tensor of shape (N,)
310
+ weight: Optional weight tensor of shape (N,)
379
311
  eps: Small value for numerical stability
312
+ is_layernorm: If True, compute LayerNorm instead of RMSNorm
380
313
  Returns:
381
314
  Normalized output tensor of same shape as x
382
315
  """
383
- assert x.dim() == 2, "Input must be 2D"
384
- assert weight.dim() == 1, "Weight must be 1D"
385
- assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
386
- assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
387
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
388
- assert weight.dtype in [
389
- torch.float32,
390
- torch.bfloat16,
391
- torch.float16,
392
- ], "Weight must be float32, float16 or bfloat16"
316
+ # Don't need to check is_cuda since torch.library ensures that
317
+ supported_types = {torch.float16, torch.bfloat16, torch.float32}
318
+ assert x.dtype in supported_types, "Unsupported dtype"
319
+ if weight is not None:
320
+ assert weight.dtype in supported_types, "Weight must be float32, float16 or bfloat16"
393
321
  if residual is not None:
394
- assert residual.shape == x.shape
395
- assert residual.is_cuda
396
- assert residual.dtype in [
397
- torch.float16,
398
- torch.bfloat16,
399
- torch.float32,
400
- ], "Residual must be float16, bfloat16, or float32"
322
+ assert residual.dtype in supported_types, "Residual must be float16, bfloat16, or float32"
401
323
 
402
324
  _, N = x.shape
403
- device = x.device
404
- dtype = torch2cute_dtype_map[x.dtype]
405
- # convert_from_dlpack = lambda x: (
406
- # from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
407
- # mode=0, divisibility=128 // dtype.width
408
- # )
409
- # )
410
- convert_from_dlpack = lambda x: (
411
- from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
412
- )
413
- x_tensor, res_tensor, out_tensor, res_out_tensor = [
414
- convert_from_dlpack(t) if t is not None else None for t in (x, residual, out, residual_out)
325
+ dtype, out_dtype, weight_dtype, bias_dtype, res_dtype, res_out_dtype = [
326
+ torch2cute_dtype_map[t.dtype] if t is not None else None
327
+ for t in [x, out, weight, bias, residual, residual_out]
415
328
  ]
416
- # handle weight divisibility based on weight dtype
417
- weight_dtype = torch2cute_dtype_map[weight.dtype]
418
- weight_tensor = utils.convert_from_dlpack(
419
- weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
420
- )
421
- if bias is not None:
422
- bias_dtype = torch2cute_dtype_map[bias.dtype]
423
- bias_tensor = utils.convert_from_dlpack(
424
- bias.detach(), leading_dim=0, divisibility=128 // bias_dtype.width
425
- )
426
- else:
427
- bias_tensor = None
428
- rstd_tensor = (
429
- from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
430
- if rstd is not None
431
- else None
432
- )
433
- current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
434
329
  compile_key = (
435
- N,
436
330
  dtype,
437
- res_tensor.element_type if residual is not None else None,
438
- weight_tensor.element_type,
439
- bias_tensor.element_type if bias is not None else None,
440
- res_out_tensor.element_type if residual_out is not None else None,
331
+ out_dtype,
332
+ res_dtype,
333
+ weight_dtype,
334
+ bias_dtype,
335
+ res_out_dtype,
336
+ N,
441
337
  rstd is not None,
338
+ mean is not None,
339
+ is_layernorm,
442
340
  )
443
341
  if compile_key not in _rmsnorm_fwd.compile_cache:
444
- rmsnorm_op = RMSNorm(dtype, N)
342
+ batch_sym = cute.sym_int()
343
+ all_dtypes = [dtype, out_dtype, res_dtype, weight_dtype, bias_dtype, res_out_dtype]
344
+ div = math.gcd(N, *(128 // dt.width for dt in all_dtypes if dt is not None))
345
+ x_cute, out_cute, res_cute, res_out_cute = [
346
+ fake_tensor(dt, (batch_sym, N), div)
347
+ for dt in [dtype, out_dtype, res_dtype, res_out_dtype]
348
+ ]
349
+ weight_cute, bias_cute = [fake_tensor(dt, (N,), div) for dt in [weight_dtype, bias_dtype]]
350
+ rstd_cute = fake_tensor(Float32, (batch_sym,)) if rstd is not None else None
351
+ mean_cute = fake_tensor(Float32, (batch_sym,)) if mean is not None else None
445
352
  _rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
446
- rmsnorm_op,
447
- x_tensor,
448
- weight_tensor,
449
- bias_tensor,
450
- res_tensor,
451
- out_tensor,
452
- res_out_tensor,
453
- rstd_tensor,
454
- current_stream,
455
- eps,
353
+ RMSNorm(dtype, N, is_layernorm=is_layernorm),
354
+ x_cute,
355
+ weight_cute,
356
+ bias_cute,
357
+ res_cute,
358
+ out_cute,
359
+ res_out_cute,
360
+ rstd_cute,
361
+ mean_cute,
362
+ Float32(0), # eps, just for compilation
363
+ cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
364
+ options="--enable-tvm-ffi",
456
365
  )
457
366
  _rmsnorm_fwd.compile_cache[compile_key](
458
- x_tensor,
459
- weight_tensor,
460
- bias_tensor,
461
- res_tensor,
462
- out_tensor,
463
- res_out_tensor,
464
- rstd_tensor,
465
- current_stream,
466
- eps,
367
+ x, weight, bias, residual, out, residual_out, rstd, mean, eps
467
368
  )
468
369
 
469
370
 
@@ -472,7 +373,7 @@ _rmsnorm_fwd.compile_cache = {}
472
373
 
473
374
  def rmsnorm_fwd(
474
375
  x: Tensor,
475
- weight: Tensor,
376
+ weight: Optional[Tensor] = None,
476
377
  bias: Optional[Tensor] = None,
477
378
  residual: Optional[Tensor] = None,
478
379
  out_dtype: Optional[torch.dtype] = None,
@@ -494,19 +395,20 @@ def rmsnorm_fwd(
494
395
  )
495
396
  else:
496
397
  residual_out = None
497
- _rmsnorm_fwd(x, weight, out, bias, rstd, residual, residual_out, eps=eps)
398
+ _rmsnorm_fwd(x, weight, out, bias, rstd, None, residual, residual_out, eps, False)
498
399
  # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
499
400
  if residual_out is None:
500
401
  residual_out = x
501
402
  return out, residual_out, rstd
502
403
 
503
404
 
504
- def rmsnorm_ref(x, w, bias=None, residual=None, eps=1e-6):
405
+ def rmsnorm_ref(x, w=None, bias=None, residual=None, eps=1e-6):
505
406
  x_f32 = x.float()
506
407
  if residual is not None:
507
408
  residual_f32 = residual.float()
508
409
  x_f32 += residual_f32
509
- out = x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) * w
410
+ x_norm = x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps))
411
+ out = x_norm * w if w is not None else x_norm
510
412
  if bias is not None:
511
413
  out = out + bias.float()
512
414
  if residual is None:
@@ -519,13 +421,19 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
519
421
  """Reference implementation for RMSNorm backward pass."""
520
422
  x_f32 = x.float()
521
423
  x_hat = x_f32 * rstd.unsqueeze(1)
522
- wdy = dout * w
424
+ if w is not None:
425
+ wdy = dout * w
426
+ else:
427
+ wdy = dout
523
428
  c1 = (x_hat * wdy).mean(dim=-1, keepdim=True)
524
429
  dx = (wdy - x_hat * c1) * rstd.unsqueeze(1)
525
430
 
526
431
  # dL/dW
527
- dw = (dout * x_hat).sum(dim=0)
528
- return dx.to(x.dtype), dw.to(w.dtype)
432
+ if w is not None:
433
+ dw = (dout * x_hat).sum(dim=0)
434
+ return dx.to(x.dtype), dw.to(w.dtype)
435
+ else:
436
+ return dx.to(x.dtype), None
529
437
 
530
438
 
531
439
  class RMSNormBackward(ReductionBase):
@@ -537,91 +445,57 @@ class RMSNormBackward(ReductionBase):
537
445
  # Not enough smem
538
446
  raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits")
539
447
 
540
- def _get_num_threads(self):
448
+ def _num_threads(self):
541
449
  return 128 if self.N <= 4096 else 256
542
450
 
543
- def _calculate_threads_per_row(self):
451
+ def _threads_per_row(self):
544
452
  N = self.N
545
- return (
546
- 8
547
- if N <= 64
548
- else (
549
- 16
550
- if N <= 128
551
- else (32 if N <= 256 else (64 if N <= 512 else (128 if N <= 4096 else 256)))
552
- )
553
- )
453
+ for limit, threads in [(64, 8), (128, 16), (256, 32), (512, 64), (4096, 128)]:
454
+ if N <= limit:
455
+ return threads
456
+ return 256
554
457
 
555
458
  def _set_cluster_n(self):
556
459
  N = self.N
557
- cluster_n = (
558
- 1
559
- if N <= 8 * 1024
560
- else (2 if N <= 16 * 1024 else (4 if N <= 32 * 1024 else (8 if N <= 64 * 1024 else 16)))
561
- )
562
- self.cluster_n = cluster_n
563
-
564
- def _smem_size_in_bytes(self, tiler_mn, num_warps, do_dtype=None):
565
- if do_dtype is None:
566
- do_dtype = self.dtype
567
- return (
568
- # We need space for X and dO, and multiply by 2 due to double buffering
569
- cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
570
- + cute.size_in_bytes(do_dtype, cute.make_layout(tiler_mn)) * 2
571
- + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
572
- + self.stage * (cutlass.Int64.width // 8) * 2 # mult 2 as we need 2 mbar per stage
573
- )
460
+ for limit, cluster in [(8 * 1024, 1), (16 * 1024, 2), (32 * 1024, 4), (64 * 1024, 8)]:
461
+ if N <= limit:
462
+ self.cluster_n = cluster
463
+ return
464
+ self.cluster_n = 16
574
465
 
575
466
  @cute.jit
576
467
  def __call__(
577
468
  self,
578
469
  mX: cute.Tensor,
579
- mW: cute.Tensor,
470
+ mW: Optional[cute.Tensor],
580
471
  mdO: cute.Tensor,
581
472
  mdResO: Optional[cute.Tensor],
582
473
  mRstd: cute.Tensor,
583
474
  mdX: cute.Tensor,
584
- mdW: cute.Tensor,
475
+ mdW: Optional[cute.Tensor],
585
476
  mdRes: Optional[cute.Tensor],
586
477
  mdB: Optional[cute.Tensor],
587
478
  sm_count: Int32,
588
479
  stream: cuda.CUstream,
589
480
  ):
590
- semistatic_shape = (*mX.shape[:-1], self.N) # Set last dimension to be statically N
591
- new_stride = lambda t: (
592
- cute.assume(t.stride[0], divby=128 // t.element_type.width),
593
- t.stride[1],
594
- )
595
- mX, mdO, mdResO, mdX, mdRes = [
596
- cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
597
- if const_expr(t is not None)
598
- else None
599
- for t in (mX, mdO, mdResO, mdX, mdRes)
600
- ]
481
+ assert mX.element_type == self.dtype
601
482
  self._set_cluster_n()
602
483
  largest_dtype_width = const_expr(
603
- max(
604
- mX.element_type.width,
605
- mdO.element_type.width,
606
- mdX.element_type.width,
607
- mdResO.element_type.width if mdResO is not None else 0,
608
- mdRes.element_type.width if mdRes is not None else 0,
609
- )
484
+ max(*(t.element_type.width for t in [mX, mW, mdO, mdResO, mdX, mdRes] if t is not None))
610
485
  )
611
- tiler_mn, tv_layout = self._get_tv_layout(
612
- num_copy_bits=128 // largest_dtype_width * mX.element_type.width
486
+ vecsize = math.gcd(self.N, 128 // largest_dtype_width)
487
+ tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize)
488
+ num_threads = tiled_copy.size
489
+ mW = (
490
+ layout_utils.expand(mW, dim=0, size=tiler_mn[0]) if const_expr(mW is not None) else None
613
491
  )
614
- num_threads = cute.size(tv_layout, mode=[0])
615
- num_warps = num_threads // cute.arch.WARP_SIZE
616
- mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
617
- mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
618
-
619
492
  num_blocks = sm_count
620
- self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn).launch(
493
+ self.kernel(
494
+ mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tiler_mn, tiled_copy, threads_per_row
495
+ ).launch(
621
496
  grid=[num_blocks, self.cluster_n, 1],
622
497
  block=[num_threads, 1, 1],
623
498
  cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
624
- smem=self._smem_size_in_bytes(tiler_mn, num_warps, do_dtype=mdO.element_type),
625
499
  stream=stream,
626
500
  )
627
501
 
@@ -629,24 +503,23 @@ class RMSNormBackward(ReductionBase):
629
503
  def kernel(
630
504
  self,
631
505
  mX: cute.Tensor,
632
- mW: cute.Tensor,
506
+ mW: Optional[cute.Tensor],
633
507
  mdO: cute.Tensor,
634
508
  mdResO: Optional[cute.Tensor],
635
509
  mRstd: cute.Tensor,
636
510
  mdX: cute.Tensor,
637
- mdW: cute.Tensor,
511
+ mdW: Optional[cute.Tensor],
638
512
  mdB: Optional[cute.Tensor],
639
513
  mdRes: Optional[cute.Tensor],
640
- tv_layout: cute.Layout,
641
514
  tiler_mn: cute.Shape,
515
+ tiled_copy: cute.TiledCopy,
516
+ threads_per_row: cutlass.Constexpr[int],
642
517
  ):
643
518
  tidx, _, _ = cute.arch.thread_idx()
644
519
  bidx_start, _, _ = cute.arch.block_idx()
645
520
  gdim, _, _ = cute.arch.grid_dim()
646
- if const_expr(self.cluster_n > 1):
647
- cluster_y = cute.arch.block_idx()[1]
648
- else:
649
- cluster_y = const_expr(0)
521
+ cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1]
522
+ tv_layout = tiled_copy.layout_tv_tiled
650
523
 
651
524
  shape = mX.shape
652
525
  M, N = shape[0], shape[1]
@@ -666,103 +539,20 @@ class RMSNormBackward(ReductionBase):
666
539
  else:
667
540
  mbar_full_ptr, mbar_empty_ptr = None, None
668
541
 
669
- num_copy_elems_X = tv_layout.shape[1][0]
670
- num_copy_bits_X = mX.element_type.width * num_copy_elems_X
671
- copy_atom_load_X = cute.make_copy_atom(
672
- cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
673
- )
674
- copy_atom_load_X_async = cute.make_copy_atom(
675
- cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
676
- )
677
- num_copy_bits_dO = const_expr(min(128, num_copy_elems_X * mdO.element_type.width))
678
- copy_atom_load_dO_async = cute.make_copy_atom(
679
- cute.nvgpu.cpasync.CopyG2SOp(), mdO.element_type, num_bits_per_copy=num_copy_bits_dO
680
- )
681
- num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
682
- copy_atom_load_W = cute.make_copy_atom(
683
- cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
684
- )
685
- if const_expr(mdResO is not None):
686
- num_copy_bits_dResO = const_expr(min(128, num_copy_elems_X * mdResO.element_type.width))
687
- copy_atom_load_dResO = cute.make_copy_atom(
688
- cute.nvgpu.CopyUniversalOp(),
689
- mdResO.element_type,
690
- num_bits_per_copy=num_copy_bits_dResO,
691
- )
692
- num_copy_bits_dX = const_expr(min(128, num_copy_elems_X * mdX.element_type.width))
693
- copy_atom_store_dX = cute.make_copy_atom(
694
- cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_copy_bits_dX
695
- )
696
- num_copy_bits_dW = const_expr(min(128, num_copy_elems_X * mdW.element_type.width))
697
- copy_atom_store_dW = cute.make_copy_atom(
698
- cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=num_copy_bits_dW
699
- )
700
- if const_expr(mdB is not None):
701
- num_copy_bits_dB = const_expr(min(128, num_copy_elems_X * mdB.element_type.width))
702
- copy_atom_store_dB = cute.make_copy_atom(
703
- cute.nvgpu.CopyUniversalOp(), mdB.element_type, num_bits_per_copy=num_copy_bits_dB
704
- )
705
- if const_expr(mdRes is not None):
706
- num_copy_bits_dRes = const_expr(min(128, num_copy_elems_X * mdRes.element_type.width))
707
- copy_atom_load_dRes = cute.make_copy_atom(
708
- cute.nvgpu.CopyUniversalOp(),
709
- mdRes.element_type,
710
- num_bits_per_copy=num_copy_bits_dRes,
711
- )
712
-
713
- thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
714
-
715
- gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
716
- tXgW = thr_copy_X.partition_S(gW)
717
- tXrW = cute.make_fragment_like(tXgW)
718
- # Need this, otherwise rW can have arbitrary values that changes the reduction
719
- if not is_even_N:
720
- tXrW.fill(0.0)
721
-
722
- gW_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
723
- tXpW = (
724
- utils.predicate_k(thr_copy_X.partition_S(gW_coord), limit=shape[1])
725
- if not is_even_N
726
- else None
727
- )
728
- cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
729
- weight = tXrW.load().to(cute.Float32)
730
-
731
- num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
732
-
733
- self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)
734
-
735
- dw_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
736
- tXpdW = (
737
- utils.predicate_k(thr_copy_X.partition_S(dw_coord), limit=shape[1])
738
- if not is_even_N
739
- else None
740
- )
741
- if const_expr(mdB is not None):
742
- db_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
743
- tXpdB = (
744
- utils.predicate_k(thr_copy_X.partition_S(db_coord), limit=shape[1])
745
- if not is_even_N
746
- else None
747
- )
748
-
749
- gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
750
- tXgdW = thr_copy_X.partition_S(gdW)
751
- # Always compute partial weight gradients in fp32
752
- tXrdW = cute.make_fragment_like(tXgdW, Float32)
753
-
754
- gdB = (
755
- cute.local_tile(mdB, (1, tiler_mn[1]), (bidx_start, cluster_y))
756
- if const_expr(mdB is not None)
757
- else None
758
- )
759
- tXgdB = thr_copy_X.partition_S(gdB) if const_expr(mdB is not None) else None
760
- tXrdB = cute.make_fragment_like(tXgdB, Float32) if const_expr(mdB is not None) else None
542
+ thr_copy_X = tiled_copy.get_slice(tidx)
761
543
 
762
544
  gX, gdO, gdResO, gdX, gdRes, cX = [
763
545
  cute.local_tile(mT, tiler_mn, (None, cluster_y)) if mT is not None else None
764
546
  for mT in (mX, mdO, mdResO, mdX, mdRes, idX)
765
547
  ]
548
+ gW = cute.local_tile(mW, tiler_mn, (0, cluster_y)) if mW is not None else None
549
+ gdW, gdB = [
550
+ cute.local_tile(mT, (1, tiler_mn[1]), (bidx_start, cluster_y))
551
+ if const_expr(mT is not None)
552
+ else None
553
+ for mT in (mdW, mdB)
554
+ ]
555
+
766
556
  tXgX = thr_copy_X.partition_S(gX)
767
557
  tXsX = thr_copy_X.partition_D(sX)
768
558
  tXgdO = thr_copy_X.partition_S(gdO)
@@ -773,12 +563,6 @@ class RMSNormBackward(ReductionBase):
773
563
  if const_expr(mdRes is not None):
774
564
  tXgdRes = thr_copy_X.partition_D(gdRes)
775
565
  tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
776
- # This doesn't change across iterations
777
- tXpX = (
778
- utils.predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1])
779
- if not is_even_N
780
- else None
781
- )
782
566
 
783
567
  tXrX, tXrdO, tXrdX = [
784
568
  cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdO, tXgdX)
@@ -790,28 +574,57 @@ class RMSNormBackward(ReductionBase):
790
574
  if const_expr(mdRes is not None):
791
575
  tXrdRes = cute.make_fragment_like(tXgdRes[None, None, None, 0])
792
576
 
793
- copy_X = partial(cute.copy, copy_atom_load_X_async, pred=tXpX)
794
- copy_dO = partial(cute.copy, copy_atom_load_dO_async, pred=tXpX)
577
+ # This doesn't change across iterations
578
+ tXpX = (
579
+ None
580
+ if is_even_N
581
+ else copy_utils.predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1])
582
+ )
583
+ # Each copy will use the same number of elements as X
584
+ copy = partial(copy_utils.copy, pred=tXpX)
585
+
586
+ tXgdW, tXrdW = None, None
587
+ tXgdB, tXrdB = None, None
588
+ if const_expr(mdW is not None):
589
+ tXgdW = thr_copy_X.partition_S(gdW)
590
+ # Always compute partial weight gradients in fp32
591
+ tXrdW = cute.make_fragment_like(tXgdW, Float32)
592
+ if const_expr(mdB is not None):
593
+ tXgdB = thr_copy_X.partition_S(gdB)
594
+ # Always compute partial bias gradients in fp32
595
+ tXrdB = cute.make_fragment_like(tXgdB, Float32)
596
+
597
+ num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE
598
+
599
+ self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)
600
+
601
+ tXrW = None
602
+ if const_expr(mW is not None):
603
+ tXgW = thr_copy_X.partition_S(gW)
604
+ tXrW = cute.make_fragment_like(tXgW)
605
+ # Need this, otherwise rW can have arbitrary values that changes the reduction
606
+ if const_expr(not is_even_N):
607
+ tXrW.fill(0.0)
608
+ copy(tXgW, tXrW)
795
609
 
796
610
  # Prefetch the first batch
797
611
  row = tXcX[None, None, None, bidx_start][0][0]
798
612
  if row < M:
799
- tXgX_cur = utils.coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0]
800
- tXgdO_cur = utils.coord_offset_i64(bidx_start, tXgdO, dim=3)[None, None, None, 0]
801
- copy_X(tXgX_cur, tXsX[None, None, None, 0])
802
- copy_dO(tXgdO_cur, tXsdO[None, None, None, 0])
803
- elif tiler_mn[0] > 1:
804
- # Fill with zero, otherwise smem will be uninitialized, and we could read this back
805
- # later into registers, causing wrong dW.
806
- utils.fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero)
807
- utils.fill_oob(tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero)
613
+ copy(tXgX[None, None, None, bidx_start], tXsX[None, None, None, 0], is_async=True)
614
+ copy(tXgdO[None, None, None, bidx_start], tXsdO[None, None, None, 0], is_async=True)
615
+ else:
616
+ if const_expr(tiler_mn[0] > 1):
617
+ # Fill with zero, otherwise smem will be uninitialized, and we could read this back
618
+ # later into registers, causing wrong dW.
619
+ utils.fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero)
620
+ utils.fill_oob(tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero)
808
621
  cute.arch.cp_async_commit_group()
809
622
 
810
623
  if const_expr(self.cluster_n > 1):
811
624
  cute.arch.cluster_wait()
812
625
 
813
- threads_per_row = tv_layout.shape[0][0]
814
- tXrdW.fill(0.0)
626
+ if const_expr(mdW is not None):
627
+ tXrdW.fill(0.0)
815
628
  if const_expr(mdB is not None):
816
629
  tXrdB.fill(0.0)
817
630
  stage = Int32(0)
@@ -820,29 +633,31 @@ class RMSNormBackward(ReductionBase):
820
633
  for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
821
634
  row = tXcX[None, None, None, bidx][0][0]
822
635
  if row + gdim * tiler_mn[0] < M: # Prefetch the next batch
823
- tXgX_cur = utils.coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0]
824
- tXgdO_cur = utils.coord_offset_i64(bidx + gdim, tXgdO, dim=3)[None, None, None, 0]
825
- copy_X(tXgX_cur, tXsX[None, None, None, stage ^ 1])
826
- copy_dO(tXgdO_cur, tXsdO[None, None, None, stage ^ 1])
827
- elif tiler_mn[0] > 1:
828
- utils.fill_oob(
636
+ copy(
637
+ tXgX[None, None, None, bidx + gdim],
829
638
  tXsX[None, None, None, stage ^ 1],
830
- None,
831
- fill_value=mX.element_type.zero,
639
+ is_async=True,
832
640
  )
833
- utils.fill_oob(
641
+ copy(
642
+ tXgdO[None, None, None, bidx + gdim],
834
643
  tXsdO[None, None, None, stage ^ 1],
835
- None,
836
- fill_value=mdO.element_type.zero,
644
+ is_async=True,
837
645
  )
646
+ else:
647
+ if const_expr(tiler_mn[0] > 1):
648
+ utils.fill_oob(
649
+ tXsX[None, None, None, stage ^ 1], None, fill_value=mX.element_type.zero
650
+ )
651
+ utils.fill_oob(
652
+ tXsdO[None, None, None, stage ^ 1], None, fill_value=mdO.element_type.zero
653
+ )
838
654
  cute.arch.cp_async_commit_group()
839
655
  rstd = cutlass.Float.zero
840
656
  if row < M or tiler_mn[0] == 1:
841
657
  rstd = mRstd[row]
842
658
  if const_expr(mdResO is not None):
843
- tXgdResO_cur = utils.coord_offset_i64(bidx, tXgdResO, dim=3)[None, None, None, 0]
844
659
  if row < M or tiler_mn[0] == 1:
845
- cute.copy(copy_atom_load_dResO, tXgdResO_cur, tXrdResO, pred=tXpX)
660
+ copy(tXgdResO[None, None, None, bidx], tXrdResO)
846
661
  elif tiler_mn[0] > 1:
847
662
  tXrdResO.fill(0.0)
848
663
  cute.arch.cp_async_wait_group(1)
@@ -850,10 +665,10 @@ class RMSNormBackward(ReductionBase):
850
665
  x = tXrX.load().to(cute.Float32)
851
666
  cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
852
667
  dout = tXrdO.load().to(cute.Float32)
853
- if const_expr(mdResO is not None):
854
- dout += tXrdResO.load().to(cute.Float32)
855
668
  x_hat = x * rstd
856
- wdy = dout * weight
669
+ wdy = dout
670
+ if const_expr(mW is not None):
671
+ wdy *= tXrW.load().to(Float32)
857
672
  if const_expr(self.cluster_n > 1):
858
673
  cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
859
674
  mean_xhat_wdy = (
@@ -870,6 +685,10 @@ class RMSNormBackward(ReductionBase):
870
685
  )
871
686
 
872
687
  if const_expr(self.cluster_n > 1):
688
+ # Need this fence since the STAS from the producer is using the async proxy.
689
+ cute.arch.fence_proxy(
690
+ cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
691
+ )
873
692
  # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
874
693
  # Requires adjusting the thread_count when initializing the mbar
875
694
  cute.arch.sync_warp()
@@ -882,22 +701,22 @@ class RMSNormBackward(ReductionBase):
882
701
  if const_expr(self.reload_wdy == "smem"):
883
702
  cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
884
703
  dout = tXrdO.load().to(cute.Float32)
885
- if const_expr(mdResO is not None):
886
- dout += tXrdResO.load().to(cute.Float32)
887
- wdy = dout * weight
704
+ wdy = dout
705
+ if const_expr(mW is not None):
706
+ wdy *= tXrW.load().to(Float32)
888
707
 
889
708
  dx = (wdy - x_hat * mean_xhat_wdy) * rstd
709
+ if const_expr(mdResO is not None):
710
+ dx += tXrdResO.load().to(cute.Float32)
890
711
  tXrdX.store(dx.to(tXrdX.element_type))
891
712
  if row < M or tiler_mn[0] == 1:
892
- tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
893
- cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
713
+ copy(tXrdX, tXgdX[None, None, None, bidx])
894
714
  if const_expr(mdRes is not None):
895
715
  tXrdRes.store(dx.to(tXrdRes.element_type))
896
- tXgdRes_cur = utils.coord_offset_i64(bidx, tXgdRes, dim=3)[None, None, None, 0]
897
716
  if row < M or tiler_mn[0] == 1:
898
- cute.copy(copy_atom_load_dRes, tXrdRes, tXgdRes_cur, pred=tXpX)
899
- # Accumulate weight gradients in fp32
900
- tXrdW.store(tXrdW.load() + dout * x_hat)
717
+ copy(tXrdRes, tXgdRes[None, None, None, bidx])
718
+ if const_expr(mdW is not None):
719
+ tXrdW.store(tXrdW.load() + dout * x_hat)
901
720
  if const_expr(mdB is not None):
902
721
  tXrdB.store(tXrdB.load() + dout)
903
722
 
@@ -906,29 +725,29 @@ class RMSNormBackward(ReductionBase):
906
725
  consumer_phase ^= 1
907
726
  producer_phase ^= 1
908
727
 
909
- if const_expr(self.cluster_n > 1): # Prevent cluster from exiting early
910
- cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
911
-
912
728
  if const_expr(tiler_mn[0] > 1):
913
- # reduction of dw_partial within the same threadblock
914
- sdW = cute.make_tensor(
915
- cute.recast_ptr(sX.iterator, dtype=cute.Float32),
916
- cute.make_ordered_layout(tiler_mn, order=(1, 0)),
917
- )
918
- tXsdW = thr_copy_X.partition_D(sdW)
919
- cute.arch.barrier()
920
- row = tXcX[None, None, None, 0][0][0]
921
- if row > 0:
922
- cute.autovec_copy(tXrdW, tXsdW)
923
- cute.arch.barrier()
924
- if row == 0:
925
- for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
926
- tXrdW_other = cute.make_fragment_like(tXrdW)
927
- tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
928
- cute.autovec_copy(tXsdW_other, tXrdW_other)
929
- tXrdW.store(tXrdW.load() + tXrdW_other.load())
930
- cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
931
- cute.arch.barrier()
729
+ if const_expr(mdW is not None):
730
+ # reduction of dw_partial within the same threadblock
731
+ sdW = cute.make_tensor(
732
+ cute.recast_ptr(sX.iterator, dtype=cute.Float32),
733
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
734
+ )
735
+ tXsdW = thr_copy_X.partition_D(sdW)
736
+ cute.arch.barrier()
737
+ row = tXcX[None, None, None, 0][0][0]
738
+ if row > 0:
739
+ cute.autovec_copy(tXrdW, tXsdW)
740
+ cute.arch.barrier()
741
+ if row == 0:
742
+ for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
743
+ tXrdW_other = cute.make_fragment_like(tXrdW)
744
+ tXsdW_other = cute.make_tensor(
745
+ tXsdW.iterator + i * sdW.stride[0], tXsdW.layout
746
+ )
747
+ cute.autovec_copy(tXsdW_other, tXrdW_other)
748
+ tXrdW.store(tXrdW.load() + tXrdW_other.load())
749
+ copy(tXrdW, tXgdW)
750
+ cute.arch.barrier()
932
751
  if const_expr(mdB is not None):
933
752
  sdB = cute.make_tensor(
934
753
  cute.recast_ptr(sX.iterator, dtype=cute.Float32),
@@ -948,12 +767,21 @@ class RMSNormBackward(ReductionBase):
948
767
  )
949
768
  cute.autovec_copy(tXsdB_other, tXrdB_other)
950
769
  tXrdB.store(tXrdB.load() + tXrdB_other.load())
951
- cute.copy(copy_atom_store_dB, tXrdB, tXgdB, pred=tXpdB)
770
+ copy(tXrdB, tXgdB)
952
771
  else:
953
772
  # dw is already in fp32, so we can directly copy to global memory
954
- cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
773
+ if const_expr(mdW is not None):
774
+ copy(tXrdW, tXgdW)
955
775
  if const_expr(mdB is not None):
956
- cute.copy(copy_atom_store_dB, tXrdB, tXgdB, pred=tXpdB)
776
+ copy(tXrdB, tXgdB)
777
+
778
+ if const_expr(self.cluster_n > 1): # Prevent cluster from exiting early
779
+ # Assume state contains that next useful buffer
780
+ # So we only need to advance to num_stages - 1 times to last used buffer
781
+ stage ^= 1
782
+ if stage == 0:
783
+ producer_phase ^= 1
784
+ cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
957
785
 
958
786
 
959
787
  def _get_sm_count(N: int, device: torch.device) -> int:
@@ -978,120 +806,103 @@ def _get_sm_count(N: int, device: torch.device) -> int:
978
806
  mutates_args={"dx", "dw_partial", "db_partial", "dresidual"},
979
807
  device_types="cuda",
980
808
  # We need to specify the schema manually since we're mutating an optional tensor
981
- schema="(Tensor x, Tensor weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!) dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual) -> ()",
809
+ schema="(Tensor x, Tensor? weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!)? dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual, int? sm_count) -> ()",
982
810
  )
983
811
  def _rmsnorm_bwd(
984
812
  x: Tensor,
985
- weight: Tensor,
813
+ weight: Optional[Tensor],
986
814
  dout: Tensor,
987
815
  rstd: Tensor,
988
816
  dx: Tensor,
989
- dw_partial: Tensor,
817
+ dw_partial: Optional[Tensor],
990
818
  db_partial: Optional[Tensor] = None,
991
819
  dresidual_out: Optional[Tensor] = None,
992
820
  dresidual: Optional[Tensor] = None,
821
+ sm_count: Optional[int] = None,
993
822
  ) -> None:
994
823
  """RMSNorm backward pass.
995
824
  Args:
996
825
  x: Input tensor of shape (M, N)
997
- weight: Weight tensor of shape (N,)
826
+ weight: Optional weight tensor of shape (N,)
998
827
  dout: Upstream gradients tensor of shape (M, N)
999
828
  rstd: Reciprocal standard deviation tensor of shape (M,)
1000
829
  Returns:
1001
830
  Tuple of (dx, dw) where:
1002
831
  - dx: Input gradients tensor of same shape as x
1003
- - dw: Weight gradients tensor of same shape as weight
832
+ - dw: Weight gradients tensor of same shape as weight (or None if weight is None)
1004
833
  """
1005
834
  assert x.dim() == 2, "Input must be 2D"
1006
- assert weight.dim() == 1, "Weight must be 1D"
1007
- assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
1008
- assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
1009
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
1010
- assert weight.dtype in [
1011
- torch.float32,
1012
- torch.bfloat16,
1013
- torch.float16,
1014
- ], "Weight must be float32, float16 or bfloat16"
835
+ assert x.is_cuda, "Input tensor must be on CUDA device"
+ supported_types = {torch.float16, torch.bfloat16, torch.float32}
+ assert x.dtype in supported_types, "Unsupported dtype"
+ if weight is not None:
+ assert weight.dim() == 1, "Weight must be 1D"
+ assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
+ assert weight.is_cuda, "Weight tensor must be on CUDA device"
+ assert weight.dtype in supported_types, "Weight must be float32, float16 or bfloat16"
  if dresidual_out is not None:
  assert dresidual_out.shape == x.shape
  assert dresidual_out.is_cuda
- assert dresidual_out.dtype in [
- torch.float16,
- torch.bfloat16,
- torch.float32,
- ], "Residual must be float16, bfloat16, or float32"
+ assert dresidual_out.dtype in supported_types, (
+ "Residual must be float16, bfloat16, or float32"
+ )
  if dresidual is not None:
  assert dresidual.shape == x.shape
  assert dresidual.is_cuda
- assert dresidual.dtype in [
- torch.float16,
- torch.bfloat16,
- torch.float32,
- ], "Residual must be float16, bfloat16, or float32"
+ assert dresidual.dtype in supported_types, "Residual must be float16, bfloat16, or float32"
 
  N = x.size(1)
- device = x.device
- sm_count = dw_partial.shape[0]
- convert_from_dlpack = lambda x: (
- from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
- )
- x_tensor, dout_tensor, dres_out_tensor, dx_tensor, dres_tensor = [
- convert_from_dlpack(t) if t is not None else None
- for t in (x, dout, dresidual_out, dx, dresidual)
+ if dw_partial is None and db_partial is None:
+ assert sm_count is not None
+ else:
+ sm_count = dw_partial.shape[0] if dw_partial is not None else db_partial.shape[0]
+ dtype, dout_dtype, dx_dtype, weight_dtype, dres_dtype, dres_out_dtype = [
+ torch2cute_dtype_map[t.dtype] if t is not None else None
+ for t in [x, dout, dx, weight, dresidual, dresidual_out]
  ]
- # Handle weight div based on weight dtype
- weight_dtype = torch2cute_dtype_map[weight.dtype]
- weight_tensor = utils.convert_from_dlpack(
- weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
- )
-
- dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
- db_partial_tensor = (
- from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
- if db_partial is not None
- else None
- )
- rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
-
- current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-
  compile_key = (
  N,
- x_tensor.element_type,
- weight_tensor.element_type,
- db_partial.dtype if db_partial is not None else None,
- dresidual.dtype if dresidual is not None else None,
- dresidual_out.dtype if dresidual_out is not None else None,
+ dtype,
+ dout_dtype,
+ dx_dtype,
+ weight_dtype,
+ db_partial is not None,
+ dres_dtype,
+ dres_out_dtype,
  )
  if compile_key not in _rmsnorm_bwd.compile_cache:
- rmsnorm_backward_op = RMSNormBackward(x_tensor.element_type, N)
+ batch_sym, batch_partial_sym = cute.sym_int(), cute.sym_int()
+ all_dtypes = [dtype, dout_dtype, dx_dtype, dres_dtype, dres_out_dtype]
+ div = math.gcd(N, *(128 // dt.width for dt in all_dtypes if dt is not None))
+ x_cute, dout_cute, dx_cute, dres_out_cute, dres_cute = [
+ fake_tensor(dt, (batch_sym, N), div)
+ for dt in [dtype, dout_dtype, dx_dtype, dres_out_dtype, dres_dtype]
+ ]
+ weight_cute = fake_tensor(weight_dtype, (N,), div)
+ rstd_cute = fake_tensor(Float32, (batch_sym,))
+ dw_partial_cute = (
+ fake_tensor(Float32, (batch_partial_sym, N), div) if dw_partial is not None else None
+ )
+ db_partial_cute = (
+ fake_tensor(Float32, (batch_partial_sym, N), div) if db_partial is not None else None
+ )
  _rmsnorm_bwd.compile_cache[compile_key] = cute.compile(
- rmsnorm_backward_op,
- x_tensor,
- weight_tensor,
- dout_tensor,
- dres_out_tensor,
- rstd_tensor,
- dx_tensor,
- dw_partial_tensor,
- dres_tensor,
- db_partial_tensor,
+ RMSNormBackward(dtype, N),
+ x_cute,
+ weight_cute,
+ dout_cute,
+ dres_out_cute,
+ rstd_cute,
+ dx_cute,
+ dw_partial_cute,
+ dres_cute,
+ db_partial_cute,
  sm_count,
- current_stream,
+ cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
+ options="--enable-tvm-ffi",
  )
-
  _rmsnorm_bwd.compile_cache[compile_key](
- x_tensor,
- weight_tensor,
- dout_tensor,
- dres_out_tensor,
- rstd_tensor,
- dx_tensor,
- dw_partial_tensor,
- dres_tensor,
- db_partial_tensor,
- sm_count,
- current_stream,
+ x, weight, dout, dresidual_out, rstd, dx, dw_partial, dresidual, db_partial, sm_count
  )
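Note on the new compile path above: the cache key holds only the feature size N, the element dtypes, and whether a bias gradient is requested, while the batch and partial-reduction extents stay symbolic (cute.sym_int()), so one compiled kernel is reused across batch sizes. A minimal sketch of that caching pattern, with hypothetical names (get_or_compile, compile_fn) that are not part of quack:

    _cache = {}

    def get_or_compile(N, dtypes, has_bias, compile_fn):
        # The batch size is deliberately absent from the key: the kernel is
        # compiled against a symbolic batch dimension and reused across sizes.
        key = (N, *dtypes, has_bias)
        if key not in _cache:
            _cache[key] = compile_fn(N, dtypes, has_bias)
        return _cache[key]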
 
 
@@ -1100,30 +911,37 @@ _rmsnorm_bwd.compile_cache = {}
 
  def rmsnorm_bwd(
  x: Tensor,
- weight: Tensor,
+ weight: Optional[Tensor],
  dout: Tensor,
  rstd: Tensor,
  dresidual_out: Optional[Tensor] = None, # grad wrt residual_out
  has_bias: bool = False,
- ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+ has_residual: bool = False,
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
  device = x.device
  N = x.size(1)
- sm_count = _get_sm_count(N, device)
  dx = torch.empty_like(x)
-
  if dresidual_out is not None and dresidual_out.dtype != dx.dtype:
  dresidual = torch.empty_like(x, dtype=dresidual_out.dtype)
  else:
  dresidual = None
- # Always store partial gradients in fp32 for numerical accuracy
- dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
+ sm_count = _get_sm_count(N, device)
+ if weight is not None:
+ # Always store partial gradients in fp32 for numerical accuracy
+ dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
+ else:
+ dw_partial = None
  db_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32) if has_bias else None
- _rmsnorm_bwd(x, weight, dout, rstd, dx, dw_partial, db_partial, dresidual_out, dresidual)
+
+ _rmsnorm_bwd(
+ x, weight, dout, rstd, dx, dw_partial, db_partial, dresidual_out, dresidual, sm_count
+ )
+
  # we have summed the partial gradients in fp32, now we convert back to the weight dtype
- dw = dw_partial.sum(dim=0).to(weight.dtype)
+ dw = dw_partial.sum(dim=0).to(weight.dtype) if weight is not None else None
  db = db_partial.sum(dim=0).to(weight.dtype) if has_bias else None
  # dresidual is the same as dx in this case
- if dresidual_out is not None and dresidual_out.dtype == dx.dtype:
+ if has_residual and dresidual is None:
  dresidual = dx
  return dx, dw, db, dresidual
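The weight/bias gradients above are produced by a split reduction: the kernel writes one float32 partial row per SM into dw_partial / db_partial of shape (sm_count, N), and the host finishes with a single sum over the SM axis before casting to the parameter dtype. A rough sketch of that final step under assumed sizes (illustrative only):

    import torch

    sm_count, N = 132, 4096  # assumed values for illustration
    dw_partial = torch.zeros(sm_count, N, device="cuda", dtype=torch.float32)
    # ... kernel accumulates per-SM partial weight gradients into dw_partial ...
    dw = dw_partial.sum(dim=0).to(torch.bfloat16)  # accumulate in fp32, cast once at the end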
 
@@ -1180,11 +998,16 @@ class RMSNormFunction(torch.autograd.Function):
  x_shape_og = ctx.x_shape_og
  # Reshape dout to match the flattened shape used in forward
  dout = dout.view(-1, dout.shape[-1])
-
- dx, dw, db, dresidual = rmsnorm_bwd(x, weight, dout, rstd, dresidual_out, has_bias)
+ dx, dw, db, dresidual = rmsnorm_bwd(
+ x,
+ weight,
+ dout,
+ rstd,
+ dresidual_out,
+ has_bias,
+ has_residual=ctx.residual_dtype is not None,
+ )
  dx = dx.view(x_shape_og)
- if dresidual_out is not None:
- dresidual_out = dresidual_out.reshape(x_shape_og)
  if dresidual is not None:
  dresidual = dresidual.reshape(x_shape_og)
 
@@ -1193,7 +1016,7 @@ class RMSNormFunction(torch.autograd.Function):
 
  def rmsnorm(
  x: Tensor,
- weight: Tensor,
+ weight: Optional[Tensor] = None,
  bias: Optional[Tensor] = None,
  residual: Optional[Tensor] = None,
  out_dtype: Optional[torch.dtype] = None,
@@ -1205,7 +1028,7 @@ def rmsnorm(
 
  Args:
  x: Input tensor of shape (M, N)
- weight: Weight tensor of shape (N,)
+ weight: Optional weight tensor of shape (N,)
  eps: Small value for numerical stability
 
  Returns:
@@ -1214,7 +1037,7 @@ def rmsnorm(
  return RMSNormFunction.apply(x, weight, bias, residual, out_dtype, residual_dtype, eps, prenorm)
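With weight now optional, the functional entry point covers both the weighted and unweighted cases. A hedged usage sketch (the import path quack.rmsnorm is assumed from this file's location; shapes and dtypes are arbitrary):

    import torch
    from quack.rmsnorm import rmsnorm  # assumed import path

    x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
    w = torch.ones(4096, device="cuda", dtype=torch.bfloat16)
    y = rmsnorm(x, w, eps=1e-6)     # weighted RMSNorm
    y_plain = rmsnorm(x, eps=1e-6)  # weight=None is now accepted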
 
 
- class QuackRMSNorm(torch.nn.Module):
+ class QuackRMSNorm(torch.nn.RMSNorm):
  """RMSNorm module that behaves like torch.nn.RMSNorm.
 
  This class provides a drop-in replacement for torch.nn.RMSNorm that uses
@@ -1229,10 +1052,10 @@ class QuackRMSNorm(torch.nn.Module):
  eps (float): A small constant for numerical stability
  """
 
- def __init__(self, dim: int, eps: float = 1e-6):
- super().__init__()
- self.weight = torch.nn.Parameter(torch.ones(dim))
- self.eps = eps
+ def __init__(
+ self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True, device=None, dtype=None
+ ):
+ super().__init__(dim, eps, elementwise_affine, device=device, dtype=dtype)
 
  def forward(self, x: Tensor) -> Tensor:
  """Apply RMSNorm to the input tensor.
@@ -1245,6 +1068,67 @@ class QuackRMSNorm(torch.nn.Module):
  """
  return rmsnorm(x, self.weight, eps=self.eps)
 
- def reset_parameters(self):
- """Reset the weight parameter to ones."""
- torch.nn.init.ones_(self.weight)
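Since QuackRMSNorm now subclasses torch.nn.RMSNorm, it inherits that module's parameter handling (including reset_parameters, which is why the custom one is removed) and can be constructed exactly like the stock module. A brief sketch (illustrative; import path and sizes are assumptions):

    import torch
    from quack.rmsnorm import QuackRMSNorm  # assumed import path

    norm = QuackRMSNorm(4096, eps=1e-6).to("cuda", torch.bfloat16)
    x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
    y = norm(x)  # forward dispatches to the fused rmsnorm kernel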
+
+ def layernorm_fwd(
+ x: Tensor,
+ weight: Tensor,
+ bias: Optional[Tensor] = None,
+ eps: float = 1e-6,
+ return_rstd: bool = False,
+ return_mean: bool = False,
+ ):
+ """LayerNorm forward pass using the unified RMSNorm/LayerNorm kernel.
+
+ Args:
+ x: Input tensor of shape (M, N)
+ weight: Weight tensor of shape (N,). Must be float32.
+ bias: Optional bias tensor of shape (N,). Must be float32.
+ eps: Small value for numerical stability
+ return_rstd: Whether to return the reciprocal standard deviation
+ return_mean: Whether to return the mean
+
+ Returns:
+ Normalized output tensor of same shape as x
+ If return_rstd is True, also returns rstd tensor of shape (M,)
+ If return_mean is True, also returns mean tensor of shape (M,)
+ """
+ assert x.dim() == 2, "Input must be 2D"
+ assert weight.dim() == 1, "Weight must be 1D"
+ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+ assert weight.dtype == torch.float32, "Weight must be float32"
+ if bias is not None:
+ assert bias.dim() == 1, "Bias must be 1D"
+ assert bias.dtype == torch.float32, "Bias must be float32"
+
+ M, N = x.shape
+ device = x.device
+ out = torch.empty_like(x)
+ rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
+ mean = torch.empty(M, device=device, dtype=torch.float32) if return_mean else None
+
+ _rmsnorm_fwd(x, weight, out, bias, rstd, mean, None, None, eps, True)
+
+ if return_rstd and return_mean:
+ return out, rstd, mean
+ elif return_rstd:
+ return out, rstd
+ elif return_mean:
+ return out, mean
+ return out
+
+
+ def layernorm_ref(x: Tensor, w: Tensor, eps: float = 1e-6) -> Tensor:
+ """Reference implementation for LayerNorm."""
+ x_f32 = x.float()
+ return torch.nn.functional.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype)
+
+
+ def layernorm_rstd_ref(x: torch.Tensor, eps: float = 1e-6):
+ x_f32 = x.float()
+ mean = x_f32.mean(dim=-1, keepdim=True)
+ var = ((x_f32 - mean) ** 2).mean(dim=-1)
+ return 1.0 / torch.sqrt(var + eps)
+
+
+ def layernorm_mean_ref(x: torch.Tensor) -> torch.Tensor:
+ return x.float().mean(dim=-1)
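The reference helpers above make it straightforward to sanity-check the fused LayerNorm path. A small hedged test sketch (import path, sizes, and tolerances are assumptions; note layernorm_fwd asserts a float32 weight):

    import torch
    from quack.rmsnorm import (  # assumed import path
        layernorm_fwd, layernorm_ref, layernorm_rstd_ref, layernorm_mean_ref,
    )

    x = torch.randn(256, 4096, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(4096, device="cuda", dtype=torch.float32)
    out, rstd, mean = layernorm_fwd(x, w, eps=1e-6, return_rstd=True, return_mean=True)
    torch.testing.assert_close(out.float(), layernorm_ref(x, w).float(), atol=1e-2, rtol=1e-2)
    torch.testing.assert_close(rstd, layernorm_rstd_ref(x), atol=1e-3, rtol=1e-3)
    torch.testing.assert_close(mean, layernorm_mean_ref(x), atol=1e-3, rtol=1e-3)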