quack-kernels 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/rmsnorm.py CHANGED
@@ -1,23 +1,29 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
-from typing import Optional
+from typing import Optional, Tuple
+from functools import partial
 
 import cuda.bindings.driver as cuda
 
 import cutlass
 import cutlass.cute as cute
 from cutlass import Float32, Int32
+from cutlass import const_expr
 from cutlass.cute.runtime import from_dlpack
 
-import quack.utils as utils
 import torch
-from quack.reduction_base import ReductionBase, torch2cute_dtype_map
+from torch import Tensor
+
+import quack.utils as utils
+from quack.reduce import row_reduce
+from quack.reduction_base import ReductionBase
+from quack.cute_dsl_utils import torch2cute_dtype_map
 
 
 class RMSNorm(ReductionBase):
     def __init__(self, dtype: cutlass.Numeric, N: int):
         super().__init__(dtype, N, stage=1)
-        self.reload_from = None if N <= 16384 else "smem"
+        self.reload_from = None if N <= 8192 else "smem"
         self.delay_w_load = False
 
     def _calculate_threads_per_row(self):
@@ -45,7 +51,7 @@ class RMSNorm(ReductionBase):
 
         # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
         # Similarly cluster_n = 8 is faster for N=128k
-        if cutlass.const_expr(self.dtype.width == 16):
+        if const_expr(self.dtype.width == 16):
             # 16-bit types (fp16, bf16)
             if N <= 16 * 1024:
                 cluster_n = 1
@@ -72,12 +78,27 @@ class RMSNorm(ReductionBase):
 
         self.cluster_n = cluster_n
 
+    def _smem_size_in_bytes(self, tiler_mn, num_warps, dtype_res=None):
+        return (
+            cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
+            + (
+                cute.size_in_bytes(dtype_res, cute.make_layout(tiler_mn))
+                if dtype_res is not None
+                else 0
+            )
+            + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+            + self.stage * (cutlass.Int64.width // 8)
+        )
+
     @cute.jit
     def __call__(
         self,
         mX: cute.Tensor,
         mW: cute.Tensor,
+        mB: Optional[cute.Tensor],
+        mRes: Optional[cute.Tensor],
         mO: cute.Tensor,
+        mResO: Optional[cute.Tensor],
         mRstd: Optional[cute.Tensor],
         stream: cuda.CUstream,
         eps: Float32 = 1e-6,
@@ -87,28 +108,49 @@ class RMSNorm(ReductionBase):
             cute.assume(t.stride[0], divby=128 // t.element_type.width),
             t.stride[1],
         )
-        mX, mO = [
+        mX, mRes, mO, mResO = [
             cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
-            for t in (mX, mO)
+            if const_expr(t is not None)
+            else None
+            for t in (mX, mRes, mO, mResO)
         ]
         assert mX.element_type == self.dtype
         assert mO.element_type == self.dtype
         self._set_cluster_n()
-        tiler_mn, tv_layout = self._get_tv_layout()
+        largest_dtype_width = const_expr(
+            max(
+                mX.element_type.width,
+                mRes.element_type.width if mRes is not None else 0,
+                mO.element_type.width,
+                mResO.element_type.width if mResO is not None else 0,
+            )
+        )
+        tiler_mn, tv_layout = self._get_tv_layout(
+            num_copy_bits=128 // largest_dtype_width * mX.element_type.width
+        )
         num_threads = cute.size(tv_layout, mode=[0])
         num_warps = num_threads // cute.arch.WARP_SIZE
         mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
         mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
-        if cutlass.const_expr(mRstd is not None):
+        if const_expr(mB is not None):
+            mB_expanded_layout = cute.prepend(
+                mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
+            )
+            mB = cute.make_tensor(mB.iterator, mB_expanded_layout)
+        if const_expr(mRstd is not None):
             mRstd_expanded_layout = cute.append(
                 mRstd.layout, cute.make_layout((self.N,), stride=(0,))
             )
             mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
-        self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
+        self.kernel(
+            mX, mW, mB, mRes, mO, mResO, mRstd, eps, tv_layout, tiler_mn, self.reload_from
+        ).launch(
             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
             block=[num_threads, 1, 1],
-            cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
-            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            cluster=([1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None),
+            smem=self._smem_size_in_bytes(
+                tiler_mn, num_warps, dtype_res=mRes.element_type if mRes is not None else None
+            ),
             stream=stream,
         )
 
@@ -117,7 +159,10 @@ class RMSNorm(ReductionBase):
         self,
         mX: cute.Tensor,
         mW: cute.Tensor,
+        mB: Optional[cute.Tensor],
+        mRes: Optional[cute.Tensor],
         mO: cute.Tensor,
+        mResO: Optional[cute.Tensor],
         mRstd: Optional[cute.Tensor],
         eps: cute.Float32,
         tv_layout: cute.Layout,
@@ -127,10 +172,10 @@ class RMSNorm(ReductionBase):
     ):
         tidx, _, _ = cute.arch.thread_idx()
         bidx, _, _ = cute.arch.block_idx()
-        if cutlass.const_expr(self.cluster_n > 1):
+        if const_expr(self.cluster_n > 1):
             cluster_y = cute.arch.block_idx()[1]
         else:
-            cluster_y = cutlass.const_expr(0)
+            cluster_y = const_expr(0)
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
@@ -138,86 +183,147 @@ class RMSNorm(ReductionBase):
             cute.make_ordered_layout(tiler_mn, order=(1, 0)),
             byte_alignment=16,
         )
+        if const_expr(mRes is not None):
+            sRes = smem.allocate_tensor(
+                mRes.element_type,
+                cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+                byte_alignment=16,
+            )
         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
 
         shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
         # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
-        mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
-        gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+        mX, mRes, mO, mResO = [
+            utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) if mT is not None else None
+            for mT in (mX, mRes, mO, mResO)
+        ]
+        gX, gRes, gO, gResO = [
+            cute.local_tile(mT, tiler_mn, (0, cluster_y)) if mT is not None else None
+            for mT in (mX, mRes, mO, mResO)
+        ]
         cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
         gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
+        gB = cute.local_tile(mB, tiler_mn, (0, cluster_y)) if const_expr(mB is not None) else None
         gRstd = (
             cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
-            if cutlass.const_expr(mRstd is not None)
+            if const_expr(mRstd is not None)
             else None
         )
 
         # declare the atoms which will be used later for memory copy
+        num_copy_elems_X = tv_layout.shape[1][0]
+        num_copy_bits_X = mX.element_type.width * num_copy_elems_X
         copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
+            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
         )
         copy_atom_load_X_async = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
-        )
-        num_bits_per_copy_W = cutlass.const_expr(
-            min(128, 128 // mX.element_type.width * mW.element_type.width)
+            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
        )
+        num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
        copy_atom_load_W = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_bits_per_copy_W
+            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
        )
-        num_bits_per_copy_O = cutlass.const_expr(
-            min(128, 128 // mX.element_type.width * mO.element_type.width)
+        num_bits_per_copy_B = (
+            cutlass.const_expr(min(128, num_copy_elems_X * mB.element_type.width))
+            if const_expr(mB is not None)
+            else 0
        )
+        copy_atom_load_B = (
+            cute.make_copy_atom(
+                cute.nvgpu.CopyUniversalOp(), mB.element_type, num_bits_per_copy=num_bits_per_copy_B
+            )
+            if const_expr(mB is not None)
+            else None
+        )
+        if const_expr(mRes is not None):
+            num_copy_bits_Res = const_expr(min(128, num_copy_elems_X * mRes.element_type.width))
+            copy_atom_load_Res_async = cute.make_copy_atom(
+                cute.nvgpu.cpasync.CopyG2SOp(),
+                mRes.element_type,
+                num_bits_per_copy=num_copy_bits_Res,
+            )
+        num_copy_bits_O = const_expr(min(128, num_copy_elems_X * mO.element_type.width))
        copy_atom_store_O = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=num_bits_per_copy_O
+            cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=num_copy_bits_O
        )
+        if const_expr(mResO is not None):
+            num_copy_bits_ResO = const_expr(min(128, num_copy_elems_X * mResO.element_type.width))
+            copy_atom_store_ResO = cute.make_copy_atom(
+                cute.nvgpu.CopyUniversalOp(),
+                mResO.element_type,
+                num_bits_per_copy=num_copy_bits_ResO,
+            )
 
         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
             tidx
         )
 
         tXgW = thr_copy_X.partition_S(gW)
+        tXgB = thr_copy_X.partition_S(gB) if const_expr(mB is not None) else None
         tXgX = thr_copy_X.partition_S(gX)
         tXsX = thr_copy_X.partition_D(sX)
+        if const_expr(mRes is not None):
+            tXgRes = thr_copy_X.partition_S(gRes)
+            tXsRes = thr_copy_X.partition_D(sRes)
         tXgO = thr_copy_X.partition_D(gO)
-        tXrRstd = thr_copy_X.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
+        if const_expr(mResO is not None):
+            tXgResO = thr_copy_X.partition_D(gResO)
+        tXrRstd = thr_copy_X.partition_D(gRstd) if const_expr(mRstd is not None) else None
         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
 
         # allocate fragments for gmem->rmem
         tXrW = cute.make_fragment_like(tXgW)
         tXrW.fill(0.0)
-        tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+        tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None
+        tXrX, tXrO = [cute.make_fragment_like(t) for t in (tXgX, tXgO)]
+        if const_expr(mRes is not None):
+            tXrRes = cute.make_fragment_like(tXgRes)
 
         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
         self._initialize_cluster(tidx, mbar_ptr, num_warps)
 
-        tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+        is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+        tXpX = (
+            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
+        )
         row = tXcX[0][0]
         if row < shape[0]:
             cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+            if const_expr(mRes is not None):
+                cute.copy(copy_atom_load_Res_async, tXgRes, tXsRes, pred=tXpX)
         cute.arch.cp_async_commit_group()
 
-        tXpW = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
-        if cutlass.const_expr(not delay_w_load):
-            cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
+        if const_expr(not delay_w_load):
+            cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpX)
+            if const_expr(mB is not None):
+                cute.copy(copy_atom_load_B, tXgB, tXrB, pred=tXpX)
 
         cute.arch.cp_async_wait_group(0)
         cute.autovec_copy(tXsX, tXrX)
         x = tXrX.load().to(cute.Float32)
+        if const_expr(mRes is not None):
+            cute.autovec_copy(tXsRes, tXrRes)
+            x += tXrRes.load().to(cute.Float32)
+        if const_expr(mResO is not None):
+            tXrResO = cute.make_fragment_like(tXgResO)
+            tXrResO.store(x.to(tXrResO.element_type))
+            if row < shape[0]:
+                cute.copy(copy_atom_store_ResO, tXrResO, tXgResO, pred=tXpX)
+
         threads_per_row = tv_layout.shape[0][0]
-        sum_sq_x = utils.row_reduce(
+        sum_sq_x = row_reduce(
             x * x,
             cute.ReductionOp.ADD,
             threads_per_row,
             reduction_buffer[None, None, 0],
             mbar_ptr,
             init_val=0.0,
-            hook_fn=(cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None),
+            hook_fn=(cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None),
         )
-        rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
-        if cutlass.const_expr(mRstd is not None):
+        rstd = cute.math.rsqrt(sum_sq_x / shape[1] + eps, fastmath=True)
+        if const_expr(mRstd is not None):
             # Only the thread corresponding to column 0 writes out the rstd to gmem
             if (
                 tXcX[0][1] == 0
@@ -225,59 +331,76 @@ class RMSNorm(ReductionBase):
                 and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
             ):
                 tXrRstd[0] = rstd
-        if cutlass.const_expr(delay_w_load):
-            cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
-        if cutlass.const_expr(reload_from == "smem"):
-            cute.autovec_copy(tXsX, tXrX)
-            x = tXrX.load().to(cute.Float32)
-        elif cutlass.const_expr(reload_from == "gmem"):
-            cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+        if const_expr(delay_w_load):
+            cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpX)
+            if const_expr(mB is not None):
+                cute.copy(copy_atom_load_B, tXgB, tXrB, pred=tXpX)
+        if const_expr(reload_from == "smem" or reload_from == "gmem"):
+            if const_expr(reload_from == "smem"):
+                cute.autovec_copy(tXsX, tXrX)
+            else:
+                cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
             x = tXrX.load().to(cute.Float32)
+            if const_expr(mRes is not None):
+                cute.autovec_copy(tXsRes, tXrRes)
+                x += tXrRes.load().to(cute.Float32)
         x_hat = x * rstd
         w = tXrW.load().to(cute.Float32)
         y = x_hat * w
+        if const_expr(mB is not None):
+            b = tXrB.load().to(cute.Float32)
+            y = y + b
         tXrO.store(y.to(tXrO.element_type))
-        tXpO = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
         if row < shape[0]:
-            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tXpO)
+            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tXpX)
 
 
+@torch.library.custom_op(
+    "quack::_rmsnorm_fwd",
+    mutates_args=("out", "rstd", "residual_out"),
+    device_types="cuda",
+    # We need to specify the schema manually since we're mutating an optional tensor
+    schema="(Tensor x, Tensor weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor? residual, Tensor(a6!)? residual_out, float eps=1e-6) -> ()",
+)
 def _rmsnorm_fwd(
-    x: torch.Tensor,
-    weight: torch.Tensor,
+    x: Tensor,
+    weight: Tensor,
+    out: Tensor,
+    bias: Optional[Tensor] = None,
+    rstd: Optional[Tensor] = None,
+    residual: Optional[Tensor] = None,
+    residual_out: Optional[Tensor] = None,
     eps: float = 1e-6,
-    return_rstd: bool = False,
-) -> torch.Tensor:
+) -> None:
     """RMSNorm forward pass.
     Args:
         x: Input tensor of shape (M, N)
         weight: Weight tensor of shape (N,)
         eps: Small value for numerical stability
-        return_rstd: Whether to return the reciprocal standard deviation
     Returns:
         Normalized output tensor of same shape as x
-        If return_rstd is True, also returns rstd tensor of shape (M,)
     """
     assert x.dim() == 2, "Input must be 2D"
     assert weight.dim() == 1, "Weight must be 1D"
     assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
     assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
-    assert x.dtype in [
-        torch.float16,
-        torch.bfloat16,
-        torch.float32,
-    ], "Unsupported dtype"
-
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
     assert weight.dtype in [
         torch.float32,
         torch.bfloat16,
         torch.float16,
     ], "Weight must be float32, float16 or bfloat16"
-
-    M, N = x.shape
+    if residual is not None:
+        assert residual.shape == x.shape
+        assert residual.is_cuda
+        assert residual.dtype in [
+            torch.float16,
+            torch.bfloat16,
+            torch.float32,
+        ], "Residual must be float16, bfloat16, or float32"
+
+    _, N = x.shape
     device = x.device
-    out = torch.empty_like(x)
-    rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
     dtype = torch2cute_dtype_map[x.dtype]
     # convert_from_dlpack = lambda x: (
     #     from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
@@ -287,43 +410,109 @@ def _rmsnorm_fwd(
     convert_from_dlpack = lambda x: (
         from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
     )
-    x_tensor, out_tensor = [convert_from_dlpack(t) for t in (x, out)]
+    x_tensor, res_tensor, out_tensor, res_out_tensor = [
+        convert_from_dlpack(t) if t is not None else None for t in (x, residual, out, residual_out)
+    ]
     # handle weight divisibility based on weight dtype
     weight_dtype = torch2cute_dtype_map[weight.dtype]
     weight_tensor = utils.convert_from_dlpack(
         weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
     )
+    if bias is not None:
+        bias_dtype = torch2cute_dtype_map[bias.dtype]
+        bias_tensor = utils.convert_from_dlpack(
+            bias.detach(), leading_dim=0, divisibility=128 // bias_dtype.width
+        )
+    else:
+        bias_tensor = None
     rstd_tensor = (
         from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
         if rstd is not None
         else None
     )
     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-    compile_key = (dtype, N, rstd is not None, weight.dtype)
+    compile_key = (
+        N,
+        dtype,
+        res_tensor.element_type if residual is not None else None,
+        weight_tensor.element_type,
+        bias_tensor.element_type if bias is not None else None,
+        res_out_tensor.element_type if residual_out is not None else None,
+        rstd is not None,
+    )
     if compile_key not in _rmsnorm_fwd.compile_cache:
         rmsnorm_op = RMSNorm(dtype, N)
         _rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
-            rmsnorm_op, x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream
+            rmsnorm_op,
+            x_tensor,
+            weight_tensor,
+            bias_tensor,
+            res_tensor,
+            out_tensor,
+            res_out_tensor,
+            rstd_tensor,
+            current_stream,
+            eps,
        )
    _rmsnorm_fwd.compile_cache[compile_key](
-        x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream, eps
+        x_tensor,
+        weight_tensor,
+        bias_tensor,
+        res_tensor,
+        out_tensor,
+        res_out_tensor,
+        rstd_tensor,
+        current_stream,
+        eps,
    )
-    return (out, rstd) if return_rstd else out
 
 
 _rmsnorm_fwd.compile_cache = {}
 
 
-def rmsnorm_ref(x, w, eps=1e-6):
-    x_f32 = x.float()
-    return (x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) * w).to(
-        x.dtype
-    )
+def rmsnorm_fwd(
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    residual: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    residual_dtype: Optional[torch.dtype] = None,
+    eps: float = 1e-6,
+    store_rstd: bool = False,
+) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+    # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library
+    # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None
+    # so that _layer_norm_fwd_impl doesn't have to return them.
+    out_dtype = x.dtype if out_dtype is None else out_dtype
+    out = torch.empty_like(x, dtype=out_dtype)
+    rstd = torch.empty(x.shape[0], device=x.device, dtype=torch.float32) if store_rstd else None
+    if residual is not None:
+        residual_dtype = residual.dtype
+    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
+        residual_out = torch.empty_like(
+            x, dtype=residual_dtype if residual_dtype is not None else x.dtype
+        )
+    else:
+        residual_out = None
+    _rmsnorm_fwd(x, weight, out, bias, rstd, residual, residual_out, eps=eps)
+    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
+    if residual_out is None:
+        residual_out = x
+    return out, residual_out, rstd
 
 
-def rstd_ref(x, eps=1e-6):
+def rmsnorm_ref(x, w, bias=None, residual=None, eps=1e-6):
     x_f32 = x.float()
-    return 1.0 / torch.sqrt(torch.mean(x_f32 * x_f32, dim=-1) + eps)
+    if residual is not None:
+        residual_f32 = residual.float()
+        x_f32 += residual_f32
+    out = x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) * w
+    if bias is not None:
+        out = out + bias.float()
+    if residual is None:
+        return out.to(x.dtype)
+    else:
+        return out.to(x.dtype), x_f32.to(residual.dtype)
 
 
 def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
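
A minimal usage sketch of the new rmsnorm_fwd entry point and the extended rmsnorm_ref reference added above. Assumptions (not part of the diff): a CUDA device, the module importing as quack.rmsnorm, and illustrative shapes and tolerances.

import torch
from quack.rmsnorm import rmsnorm_fwd, rmsnorm_ref

x = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)
res = torch.randn_like(x)
w = torch.randn(8192, device="cuda", dtype=torch.float32)

# Fused forward: returns (out, residual_out, rstd); rstd is None unless store_rstd=True.
out, residual_out, rstd = rmsnorm_fwd(x, w, residual=res, store_rstd=True)

# Reference computed in fp32; with a residual it also returns the updated residual stream.
out_ref, residual_out_ref = rmsnorm_ref(x, w, residual=res)
torch.testing.assert_close(out.float(), out_ref.float(), rtol=1e-2, atol=1e-2)
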
@@ -372,11 +561,13 @@ class RMSNormBackward(ReductionBase):
         )
         self.cluster_n = cluster_n
 
-    def _smem_size_in_bytes(self, tiler_mn, num_warps):
+    def _smem_size_in_bytes(self, tiler_mn, num_warps, do_dtype=None):
+        if do_dtype is None:
+            do_dtype = self.dtype
         return (
-            # Multiply by 2 since we need space for X and dOut,
-            # and multiply by another 2 due to double buffering
-            cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2 * 2
+            # We need space for X and dO, and multiply by 2 due to double buffering
+            cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
+            + cute.size_in_bytes(do_dtype, cute.make_layout(tiler_mn)) * 2
             + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
             + self.stage * (cutlass.Int64.width // 8) * 2  # mult 2 as we need 2 mbar per stage
         )
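
The reworked _smem_size_in_bytes above sizes the X and dO tiles separately (each double-buffered) instead of assuming both use self.dtype. A back-of-the-envelope sketch of the budget under assumed values (tiler_mn=(1, 4096), bf16 X and dO, fp32 reduction buffer, 8 warps, cluster_n=1, stage=2); the real values come from _get_tv_layout at compile time.

tile_bytes_x = 4096 * 2                      # bf16 X tile for one stage
tile_bytes_do = 4096 * 2                     # bf16 dO tile, now sized from do_dtype
smem = tile_bytes_x * 2 + tile_bytes_do * 2  # double buffering for X and dO
smem += 2 * 8 * 1 * 4                        # stage * num_warps * cluster_n * 4-byte reduction slots
smem += 2 * 8 * 2                            # 2 mbarriers per stage, 8 bytes each
print(smem)                                  # 32864 bytes under these assumptions
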
@@ -386,10 +577,13 @@ class RMSNormBackward(ReductionBase):
         self,
         mX: cute.Tensor,
         mW: cute.Tensor,
-        mdOut: cute.Tensor,
+        mdO: cute.Tensor,
+        mdResO: Optional[cute.Tensor],
         mRstd: cute.Tensor,
         mdX: cute.Tensor,
         mdW: cute.Tensor,
+        mdRes: Optional[cute.Tensor],
+        mdB: Optional[cute.Tensor],
         sm_count: Int32,
         stream: cuda.CUstream,
     ):
@@ -398,24 +592,36 @@ class RMSNormBackward(ReductionBase):
             cute.assume(t.stride[0], divby=128 // t.element_type.width),
             t.stride[1],
         )
-        mX, mdOut, mdX = [
+        mX, mdO, mdResO, mdX, mdRes = [
             cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
-            for t in (mX, mdOut, mdX)
+            if const_expr(t is not None)
+            else None
+            for t in (mX, mdO, mdResO, mdX, mdRes)
         ]
         self._set_cluster_n()
-        tiler_mn, tv_layout = self._get_tv_layout()
+        largest_dtype_width = const_expr(
+            max(
+                mX.element_type.width,
+                mdO.element_type.width,
+                mdX.element_type.width,
+                mdResO.element_type.width if mdResO is not None else 0,
+                mdRes.element_type.width if mdRes is not None else 0,
+            )
+        )
+        tiler_mn, tv_layout = self._get_tv_layout(
+            num_copy_bits=128 // largest_dtype_width * mX.element_type.width
+        )
         num_threads = cute.size(tv_layout, mode=[0])
         num_warps = num_threads // cute.arch.WARP_SIZE
-
         mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
         mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
 
         num_blocks = sm_count
-        self.kernel(mX, mW, mdOut, mRstd, mdX, mdW, tv_layout, tiler_mn).launch(
+        self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn).launch(
             grid=[num_blocks, self.cluster_n, 1],
             block=[num_threads, 1, 1],
             cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
-            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            smem=self._smem_size_in_bytes(tiler_mn, num_warps, do_dtype=mdO.element_type),
             stream=stream,
         )
 
@@ -424,63 +630,85 @@ class RMSNormBackward(ReductionBase):
         self,
         mX: cute.Tensor,
         mW: cute.Tensor,
-        mdOut: cute.Tensor,
+        mdO: cute.Tensor,
+        mdResO: Optional[cute.Tensor],
         mRstd: cute.Tensor,
         mdX: cute.Tensor,
         mdW: cute.Tensor,
+        mdB: Optional[cute.Tensor],
+        mdRes: Optional[cute.Tensor],
         tv_layout: cute.Layout,
         tiler_mn: cute.Shape,
     ):
         tidx, _, _ = cute.arch.thread_idx()
         bidx_start, _, _ = cute.arch.block_idx()
         gdim, _, _ = cute.arch.grid_dim()
-        if cutlass.const_expr(self.cluster_n > 1):
+        if const_expr(self.cluster_n > 1):
             cluster_y = cute.arch.block_idx()[1]
         else:
-            cluster_y = cutlass.const_expr(0)
+            cluster_y = const_expr(0)
 
         shape = mX.shape
         M, N = shape[0], shape[1]
-        is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+        is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
 
         idX = cute.make_identity_tensor(shape)
 
         smem = cutlass.utils.SmemAllocator()
         smem_layout = cute.make_ordered_layout((tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2))
         sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16)
-        sdOut = smem.allocate_tensor(mdOut.element_type, smem_layout, byte_alignment=16)
+        sdO = smem.allocate_tensor(mdO.element_type, smem_layout, byte_alignment=16)
         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
             smem, tv_layout, is_persistent=True
         )
-        if cutlass.const_expr(mbar_ptr is not None):
+        if const_expr(mbar_ptr is not None):
             mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + 2
         else:
             mbar_full_ptr, mbar_empty_ptr = None, None
 
+        num_copy_elems_X = tv_layout.shape[1][0]
+        num_copy_bits_X = mX.element_type.width * num_copy_elems_X
         copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
+            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
        )
        copy_atom_load_X_async = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
        )
-        num_bits_per_copy_W = cutlass.const_expr(
-            min(128, 128 // mX.element_type.width * mW.element_type.width)
+        num_copy_bits_dO = const_expr(min(128, num_copy_elems_X * mdO.element_type.width))
+        copy_atom_load_dO_async = cute.make_copy_atom(
+            cute.nvgpu.cpasync.CopyG2SOp(), mdO.element_type, num_bits_per_copy=num_copy_bits_dO
        )
+        num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
        copy_atom_load_W = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_bits_per_copy_W
-        )
-        num_bits_per_copy_dX = cutlass.const_expr(
-            min(128, 128 // mX.element_type.width * mdX.element_type.width)
+            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
        )
+        if const_expr(mdResO is not None):
+            num_copy_bits_dResO = const_expr(min(128, num_copy_elems_X * mdResO.element_type.width))
+            copy_atom_load_dResO = cute.make_copy_atom(
+                cute.nvgpu.CopyUniversalOp(),
+                mdResO.element_type,
+                num_bits_per_copy=num_copy_bits_dResO,
+            )
+        num_copy_bits_dX = const_expr(min(128, num_copy_elems_X * mdX.element_type.width))
        copy_atom_store_dX = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_bits_per_copy_dX
-        )
-        num_bits_per_copy_dW = cutlass.const_expr(
-            min(128, 128 // mX.element_type.width * mdW.element_type.width)
+            cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_copy_bits_dX
        )
+        num_copy_bits_dW = const_expr(min(128, num_copy_elems_X * mdW.element_type.width))
        copy_atom_store_dW = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=num_bits_per_copy_dW
+            cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=num_copy_bits_dW
        )
+        if const_expr(mdB is not None):
+            num_copy_bits_dB = const_expr(min(128, num_copy_elems_X * mdB.element_type.width))
+            copy_atom_store_dB = cute.make_copy_atom(
+                cute.nvgpu.CopyUniversalOp(), mdB.element_type, num_bits_per_copy=num_copy_bits_dB
+            )
+        if const_expr(mdRes is not None):
+            num_copy_bits_dRes = const_expr(min(128, num_copy_elems_X * mdRes.element_type.width))
+            copy_atom_load_dRes = cute.make_copy_atom(
+                cute.nvgpu.CopyUniversalOp(),
+                mdRes.element_type,
+                num_bits_per_copy=num_copy_bits_dRes,
+            )
 
         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
 
@@ -510,21 +738,40 @@ class RMSNormBackward(ReductionBase):
             if not is_even_N
             else None
         )
+        if const_expr(mdB is not None):
+            db_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
+            tXpdB = (
+                utils.predicate_k(thr_copy_X.partition_S(db_coord), limit=shape[1])
+                if not is_even_N
+                else None
+            )
 
         gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
         tXgdW = thr_copy_X.partition_S(gdW)
         # Always compute partial weight gradients in fp32
         tXrdW = cute.make_fragment_like(tXgdW, Float32)
 
-        gX = cute.local_tile(mX, tiler_mn, (None, cluster_y))
-        gdOut = cute.local_tile(mdOut, tiler_mn, (None, cluster_y))
-        gdX = cute.local_tile(mdX, tiler_mn, (None, cluster_y))
-        cX = cute.local_tile(idX, tiler_mn, (None, cluster_y))
+        gdB = (
+            cute.local_tile(mdB, (1, tiler_mn[1]), (bidx_start, cluster_y))
+            if const_expr(mdB is not None)
+            else None
+        )
+        tXgdB = thr_copy_X.partition_S(gdB) if const_expr(mdB is not None) else None
+        tXrdB = cute.make_fragment_like(tXgdB, Float32) if const_expr(mdB is not None) else None
+
+        gX, gdO, gdResO, gdX, gdRes, cX = [
+            cute.local_tile(mT, tiler_mn, (None, cluster_y)) if mT is not None else None
+            for mT in (mX, mdO, mdResO, mdX, mdRes, idX)
+        ]
         tXgX = thr_copy_X.partition_S(gX)
         tXsX = thr_copy_X.partition_D(sX)
-        tXgdOut = thr_copy_X.partition_S(gdOut)
-        tXsdOut = thr_copy_X.partition_D(sdOut)
+        tXgdO = thr_copy_X.partition_S(gdO)
+        tXsdO = thr_copy_X.partition_D(sdO)
         tXgdX = thr_copy_X.partition_D(gdX)
+        if const_expr(mdResO is not None):
+            tXgdResO = thr_copy_X.partition_S(gdResO)
+        if const_expr(mdRes is not None):
+            tXgdRes = thr_copy_X.partition_D(gdRes)
         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
         # This doesn't change across iterations
         tXpX = (
@@ -533,62 +780,50 @@ class RMSNormBackward(ReductionBase):
             else None
         )
 
-        tXrX, tXrdOut, tXrdX = [
-            cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdOut, tXgdX)
+        tXrX, tXrdO, tXrdX = [
+            cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdO, tXgdX)
         ]
+        tXrdResO = None
+        if const_expr(mdResO is not None):
+            tXrdResO = cute.make_fragment_like(tXgdResO[None, None, None, 0])
+        tXrdRes = None
+        if const_expr(mdRes is not None):
+            tXrdRes = cute.make_fragment_like(tXgdRes[None, None, None, 0])
+
+        copy_X = partial(cute.copy, copy_atom_load_X_async, pred=tXpX)
+        copy_dO = partial(cute.copy, copy_atom_load_dO_async, pred=tXpX)
 
         # Prefetch the first batch
         row = tXcX[None, None, None, bidx_start][0][0]
         if row < M:
             tXgX_cur = utils.coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0]
-            tXgdOut_cur = utils.coord_offset_i64(bidx_start, tXgdOut, dim=3)[None, None, None, 0]
-            cute.copy(
-                copy_atom_load_X_async,
-                tXgX_cur,
-                tXsX[None, None, None, 0],
-                pred=tXpX,
-            )
-            cute.copy(
-                copy_atom_load_X_async,
-                tXgdOut_cur,
-                tXsdOut[None, None, None, 0],
-                pred=tXpX,
-            )
+            tXgdO_cur = utils.coord_offset_i64(bidx_start, tXgdO, dim=3)[None, None, None, 0]
+            copy_X(tXgX_cur, tXsX[None, None, None, 0])
+            copy_dO(tXgdO_cur, tXsdO[None, None, None, 0])
         elif tiler_mn[0] > 1:
             # Fill with zero, otherwise smem will be uninitialized, and we could read this back
             # later into registers, causing wrong dW.
             utils.fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero)
-            utils.fill_oob(tXsdOut[None, None, None, 0], None, fill_value=mdOut.element_type.zero)
+            utils.fill_oob(tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero)
         cute.arch.cp_async_commit_group()
 
-        if cutlass.const_expr(self.cluster_n > 1):
+        if const_expr(self.cluster_n > 1):
             cute.arch.cluster_wait()
 
         threads_per_row = tv_layout.shape[0][0]
         tXrdW.fill(0.0)
+        if const_expr(mdB is not None):
+            tXrdB.fill(0.0)
         stage = Int32(0)
         producer_phase = Int32(1)
         consumer_phase = Int32(0)
         for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
             row = tXcX[None, None, None, bidx][0][0]
-            rstd = cutlass.Float.zero
             if row + gdim * tiler_mn[0] < M:  # Prefetch the next batch
                 tXgX_cur = utils.coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0]
-                tXgdOut_cur = utils.coord_offset_i64(bidx + gdim, tXgdOut, dim=3)[
-                    None, None, None, 0
-                ]
-                cute.copy(
-                    copy_atom_load_X_async,
-                    tXgX_cur,
-                    tXsX[None, None, None, stage ^ 1],
-                    pred=tXpX,
-                )
-                cute.copy(
-                    copy_atom_load_X_async,
-                    tXgdOut_cur,
-                    tXsdOut[None, None, None, stage ^ 1],
-                    pred=tXpX,
-                )
+                tXgdO_cur = utils.coord_offset_i64(bidx + gdim, tXgdO, dim=3)[None, None, None, 0]
+                copy_X(tXgX_cur, tXsX[None, None, None, stage ^ 1])
+                copy_dO(tXgdO_cur, tXsdO[None, None, None, stage ^ 1])
             elif tiler_mn[0] > 1:
                 utils.fill_oob(
                     tXsX[None, None, None, stage ^ 1],
@@ -596,36 +831,45 @@ class RMSNormBackward(ReductionBase):
                     fill_value=mX.element_type.zero,
                 )
                 utils.fill_oob(
-                    tXsdOut[None, None, None, stage ^ 1],
+                    tXsdO[None, None, None, stage ^ 1],
                     None,
-                    fill_value=mdOut.element_type.zero,
+                    fill_value=mdO.element_type.zero,
                 )
             cute.arch.cp_async_commit_group()
+            rstd = cutlass.Float.zero
             if row < M or tiler_mn[0] == 1:
                 rstd = mRstd[row]
+            if const_expr(mdResO is not None):
+                tXgdResO_cur = utils.coord_offset_i64(bidx, tXgdResO, dim=3)[None, None, None, 0]
+                if row < M or tiler_mn[0] == 1:
+                    cute.copy(copy_atom_load_dResO, tXgdResO_cur, tXrdResO, pred=tXpX)
+                elif tiler_mn[0] > 1:
+                    tXrdResO.fill(0.0)
             cute.arch.cp_async_wait_group(1)
             cute.autovec_copy(tXsX[None, None, None, stage], tXrX)
             x = tXrX.load().to(cute.Float32)
-            cute.autovec_copy(tXsdOut[None, None, None, stage], tXrdOut)
-            dout = tXrdOut.load().to(cute.Float32)
+            cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
+            dout = tXrdO.load().to(cute.Float32)
+            if const_expr(mdResO is not None):
+                dout += tXrdResO.load().to(cute.Float32)
             x_hat = x * rstd
             wdy = dout * weight
-            if cutlass.const_expr(self.cluster_n > 1):
+            if const_expr(self.cluster_n > 1):
                 cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
             mean_xhat_wdy = (
-                utils.row_reduce(
+                row_reduce(
                     x_hat * wdy,
                     cute.ReductionOp.ADD,
                     threads_per_row,
                     reduction_buffer[None, None, stage],
-                    (mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None),
+                    (mbar_full_ptr + stage if const_expr(self.cluster_n > 1) else None),
                     phase=consumer_phase,
                     init_val=0.0,
                 )
                 / shape[1]
             )
 
-            if cutlass.const_expr(self.cluster_n > 1):
+            if const_expr(self.cluster_n > 1):
                 # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
                 # Requires adjusting the thread_count when initializing the mbar
                 cute.arch.sync_warp()
@@ -635,28 +879,37 @@ class RMSNormBackward(ReductionBase):
                         mbar_empty_ptr + stage, peer_cta_rank_in_cluster=lane_idx
                    )
 
-            if cutlass.const_expr(self.reload_wdy == "smem"):
-                cute.autovec_copy(tXsdOut[None, None, None, stage], tXrdOut)
-                dout = tXrdOut.load().to(cute.Float32)
+            if const_expr(self.reload_wdy == "smem"):
+                cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
+                dout = tXrdO.load().to(cute.Float32)
+                if const_expr(mdResO is not None):
+                    dout += tXrdResO.load().to(cute.Float32)
                wdy = dout * weight
 
             dx = (wdy - x_hat * mean_xhat_wdy) * rstd
-            tXrdX.store(dx.to(tXrdOut.element_type))
+            tXrdX.store(dx.to(tXrdX.element_type))
             if row < M or tiler_mn[0] == 1:
                 tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
                 cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
+            if const_expr(mdRes is not None):
+                tXrdRes.store(dx.to(tXrdRes.element_type))
+                tXgdRes_cur = utils.coord_offset_i64(bidx, tXgdRes, dim=3)[None, None, None, 0]
+                if row < M or tiler_mn[0] == 1:
+                    cute.copy(copy_atom_load_dRes, tXrdRes, tXgdRes_cur, pred=tXpX)
             # Accumulate weight gradients in fp32
             tXrdW.store(tXrdW.load() + dout * x_hat)
+            if const_expr(mdB is not None):
+                tXrdB.store(tXrdB.load() + dout)
 
             stage ^= 1
             if stage == 0:
                 consumer_phase ^= 1
                 producer_phase ^= 1
 
-        if cutlass.const_expr(self.cluster_n > 1):  # Prevent cluster from exiting early
+        if const_expr(self.cluster_n > 1):  # Prevent cluster from exiting early
             cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
 
-        if cutlass.const_expr(tiler_mn[0] > 1):
+        if const_expr(tiler_mn[0] > 1):
             # reduction of dw_partial within the same threadblock
             sdW = cute.make_tensor(
                 cute.recast_ptr(sX.iterator, dtype=cute.Float32),
@@ -669,23 +922,75 @@ class RMSNormBackward(ReductionBase):
                 cute.autovec_copy(tXrdW, tXsdW)
             cute.arch.barrier()
             if row == 0:
-                for i in cutlass.range_constexpr(1, cutlass.const_expr(tiler_mn[0])):
+                for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
                     tXrdW_other = cute.make_fragment_like(tXrdW)
                     tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
                     cute.autovec_copy(tXsdW_other, tXrdW_other)
                     tXrdW.store(tXrdW.load() + tXrdW_other.load())
                cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
+            cute.arch.barrier()
+            if const_expr(mdB is not None):
+                sdB = cute.make_tensor(
+                    cute.recast_ptr(sX.iterator, dtype=cute.Float32),
+                    cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+                )
+                tXsdB = thr_copy_X.partition_D(sdB)
+                cute.arch.barrier()
+                row = tXcX[None, None, None, 0][0][0]
+                if row > 0:
+                    cute.autovec_copy(tXrdB, tXsdB)
+                cute.arch.barrier()
+                if row == 0:
+                    for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
+                        tXrdB_other = cute.make_fragment_like(tXrdB)
+                        tXsdB_other = cute.make_tensor(
+                            tXsdB.iterator + i * sdB.stride[0], tXsdB.layout
+                        )
+                        cute.autovec_copy(tXsdB_other, tXrdB_other)
+                        tXrdB.store(tXrdB.load() + tXrdB_other.load())
+                    cute.copy(copy_atom_store_dB, tXrdB, tXgdB, pred=tXpdB)
         else:
             # dw is already in fp32, so we can directly copy to global memory
             cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
+            if const_expr(mdB is not None):
+                cute.copy(copy_atom_store_dB, tXrdB, tXgdB, pred=tXpdB)
 
 
-def _rmsnorm_backward(
-    x: torch.Tensor,
-    weight: torch.Tensor,
-    dout: torch.Tensor,
-    rstd: torch.Tensor,
-) -> (torch.Tensor, torch.Tensor):
+def _get_sm_count(N: int, device: torch.device) -> int:
+    # This should be tuned on how many CTAs can be launched on each SM
+    sm_count_multiple = (
+        16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
+    )
+    sm_count = torch.cuda.get_device_properties(device).multi_processor_count
+    # By right, if we're using cluster, this should be cluster_count not sm_count.
+    # But for cluster >= 4, due to quantization we would need to query active max cluster.
+    # Instead we just do sm_count * 2, which is reasonably larger than active_cluster_count to
+    # avoid wave quantization.
+    sm_count = (
+        sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
+    )
+
+    return sm_count
+
+
+@torch.library.custom_op(
+    "quack::_rmsnorm_bwd",
+    mutates_args={"dx", "dw_partial", "db_partial", "dresidual"},
+    device_types="cuda",
+    # We need to specify the schema manually since we're mutating an optional tensor
+    schema="(Tensor x, Tensor weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!) dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual) -> ()",
+)
+def _rmsnorm_bwd(
+    x: Tensor,
+    weight: Tensor,
+    dout: Tensor,
+    rstd: Tensor,
+    dx: Tensor,
+    dw_partial: Tensor,
+    db_partial: Optional[Tensor] = None,
+    dresidual_out: Optional[Tensor] = None,
+    dresidual: Optional[Tensor] = None,
+) -> None:
     """RMSNorm backward pass.
     Args:
         x: Input tensor of shape (M, N)
@@ -701,46 +1006,39 @@ def _rmsnorm_backward(
     assert weight.dim() == 1, "Weight must be 1D"
     assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
     assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
-    assert x.dtype in [
-        torch.float16,
-        torch.bfloat16,
-        torch.float32,
-    ], "Unsupported dtype"
-
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
     assert weight.dtype in [
         torch.float32,
         torch.bfloat16,
         torch.float16,
     ], "Weight must be float32, float16 or bfloat16"
-
-    M, N = x.shape
-    dx = torch.empty_like(x)
-
+    if dresidual_out is not None:
+        assert dresidual_out.shape == x.shape
+        assert dresidual_out.is_cuda
+        assert dresidual_out.dtype in [
+            torch.float16,
+            torch.bfloat16,
+            torch.float32,
+        ], "Residual must be float16, bfloat16, or float32"
+    if dresidual is not None:
+        assert dresidual.shape == x.shape
+        assert dresidual.is_cuda
+        assert dresidual.dtype in [
+            torch.float16,
+            torch.bfloat16,
+            torch.float32,
+        ], "Residual must be float16, bfloat16, or float32"
+
+    N = x.size(1)
     device = x.device
-
-    # This should be tuned on how many CTAs can be launched on each SM
-    sm_count_multiple = (
-        16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
-    )
-    sm_count = torch.cuda.get_device_properties(device).multi_processor_count
-    # By right, if we're using cluster, this should be cluster_count not sm_count.
-    # But for cluster >= 4, due to quantization we would need to query active max cluster.
-    # Instead we just do sm_count * 2, which is reasonably larger than active_cluster_count to
-    # avoid wave quantization.
-    sm_count = (
-        sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
-    )
-
-    # Always store partial gradients in fp32 for numerical accuracy
-    dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
-
-    dtype = torch2cute_dtype_map[x.dtype]
-
+    sm_count = dw_partial.shape[0]
     convert_from_dlpack = lambda x: (
         from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
     )
-    x_tensor, dout_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (x, dout, dx)]
-
+    x_tensor, dout_tensor, dres_out_tensor, dx_tensor, dres_tensor = [
+        convert_from_dlpack(t) if t is not None else None
+        for t in (x, dout, dresidual_out, dx, dresidual)
+    ]
     # Handle weight div based on weight dtype
     weight_dtype = torch2cute_dtype_map[weight.dtype]
     weight_tensor = utils.convert_from_dlpack(
@@ -748,74 +1046,162 @@ def _rmsnorm_backward(
     )
 
     dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
+    db_partial_tensor = (
+        from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
+        if db_partial is not None
+        else None
+    )
     rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
 
     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
 
-    compile_key = (dtype, N, weight.dtype)
-    if compile_key not in _rmsnorm_backward.compile_cache:
-        rmsnorm_backward_op = RMSNormBackward(dtype, N)
-        _rmsnorm_backward.compile_cache[compile_key] = cute.compile(
+    compile_key = (
+        N,
+        x_tensor.element_type,
+        weight_tensor.element_type,
+        db_partial.dtype if db_partial is not None else None,
+        dresidual.dtype if dresidual is not None else None,
+        dresidual_out.dtype if dresidual_out is not None else None,
+    )
+    if compile_key not in _rmsnorm_bwd.compile_cache:
+        rmsnorm_backward_op = RMSNormBackward(x_tensor.element_type, N)
+        _rmsnorm_bwd.compile_cache[compile_key] = cute.compile(
             rmsnorm_backward_op,
             x_tensor,
             weight_tensor,
             dout_tensor,
+            dres_out_tensor,
             rstd_tensor,
             dx_tensor,
             dw_partial_tensor,
+            dres_tensor,
+            db_partial_tensor,
            sm_count,
            current_stream,
        )
 
-    _rmsnorm_backward.compile_cache[compile_key](
+    _rmsnorm_bwd.compile_cache[compile_key](
        x_tensor,
        weight_tensor,
        dout_tensor,
+        dres_out_tensor,
        rstd_tensor,
        dx_tensor,
        dw_partial_tensor,
+        dres_tensor,
+        db_partial_tensor,
        sm_count,
        current_stream,
    )
-    # we have summed the partial gradients in fp32, now we convert back to the weight dtype
-    dw = dw_partial.sum(dim=0).to(weight.dtype)
-    return dx, dw
 
 
-_rmsnorm_backward.compile_cache = {}
+_rmsnorm_bwd.compile_cache = {}
+
+
+def rmsnorm_bwd(
+    x: Tensor,
+    weight: Tensor,
+    dout: Tensor,
+    rstd: Tensor,
+    dresidual_out: Optional[Tensor] = None,  # grad wrt residual_out
+    has_bias: bool = False,
+) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
+    device = x.device
+    N = x.size(1)
+    sm_count = _get_sm_count(N, device)
+    dx = torch.empty_like(x)
+
+    if dresidual_out is not None and dresidual_out.dtype != dx.dtype:
+        dresidual = torch.empty_like(x, dtype=dresidual_out.dtype)
+    else:
+        dresidual = None
+    # Always store partial gradients in fp32 for numerical accuracy
+    dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
+    db_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32) if has_bias else None
+    _rmsnorm_bwd(x, weight, dout, rstd, dx, dw_partial, db_partial, dresidual_out, dresidual)
+    # we have summed the partial gradients in fp32, now we convert back to the weight dtype
+    dw = dw_partial.sum(dim=0).to(weight.dtype)
+    db = db_partial.sum(dim=0).to(weight.dtype) if has_bias else None
+    # dresidual is the same as dx in this case
+    if dresidual_out is not None and dresidual_out.dtype == dx.dtype:
+        dresidual = dx
+    return dx, dw, db, dresidual
 
 
 class RMSNormFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x, weight, eps):
-        x_shape_start = x.shape
-
+    def forward(
+        ctx,
+        x,
+        weight,
+        bias=None,
+        residual=None,
+        out_dtype=None,
+        residual_dtype=None,
+        eps=1e-6,
+        prenorm=False,
+    ):
+        x_shape_og = x.shape
         # Flatten input
-        x = x.view(-1, x.shape[-1])
-
-        out, rstd = _rmsnorm_fwd(x, weight, eps, return_rstd=True)
-        ctx.save_for_backward(x, weight, rstd)
+        x = x.reshape(-1, x.shape[-1])
+        if residual is not None:
+            residual = residual.reshape(-1, residual.shape[-1])
+        need_grad = any(ctx.needs_input_grad[:3])
+        out, residual_out, rstd = rmsnorm_fwd(
+            x,
+            weight,
+            bias=bias,
+            residual=residual,
+            out_dtype=out_dtype,
+            residual_dtype=residual_dtype,
+            eps=eps,
+            store_rstd=need_grad,
+        )
+        ctx.save_for_backward(x if residual is None else residual_out, weight, rstd)
+        ctx.has_bias = bias is not None
        ctx.eps = eps
-        ctx.x_shape_start = x_shape_start
-
-        return out.reshape(x_shape_start)
+        ctx.x_shape_og = x_shape_og
+        ctx.residual_dtype = residual.dtype if residual is not None else None
+        ctx.prenorm = prenorm
+        if residual_out is None or not prenorm:
+            return out.reshape(x_shape_og)
+        else:
+            return out.reshape(x_shape_og), residual_out.reshape(x_shape_og)
 
     @staticmethod
-    def backward(ctx, dout):
+    def backward(ctx, dout, *args):
        x, weight, rstd = ctx.saved_tensors
-        x_shape_start = ctx.x_shape_start
+        has_bias = ctx.has_bias
+        if ctx.prenorm and ctx.residual_dtype is not None:
+            dresidual_out = args[0]
+            dresidual_out = dresidual_out.reshape(-1, dresidual_out.shape[-1])
+        else:
+            dresidual_out = None
+        x_shape_og = ctx.x_shape_og
        # Reshape dout to match the flattened shape used in forward
        dout = dout.view(-1, dout.shape[-1])
-        dx, dw = _rmsnorm_backward(x, weight, dout, rstd)
-        dx = dx.view(x_shape_start)
-        # dx is returned for input gradient,
-        # dw is returned for weight gradient,
-        # None for eps gradient
-        return dx, dw, None
 
+        dx, dw, db, dresidual = rmsnorm_bwd(x, weight, dout, rstd, dresidual_out, has_bias)
+        dx = dx.view(x_shape_og)
+        if dresidual_out is not None:
+            dresidual_out = dresidual_out.reshape(x_shape_og)
+        if dresidual is not None:
+            dresidual = dresidual.reshape(x_shape_og)
+
+        return dx, dw, db, dresidual, *([None] * 4)
 
-def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
-    """RMSNorm forward pass with automatic differentiation support.
+
+def rmsnorm(
+    x: Tensor,
+    weight: Tensor,
+    bias: Optional[Tensor] = None,
+    residual: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    residual_dtype: Optional[torch.dtype] = None,
+    eps: float = 1e-6,
+    prenorm: bool = False,
+) -> Tensor:
+    """RMSNorm with automatic differentiation support.
 
     Args:
         x: Input tensor of shape (M, N)
@@ -825,7 +1211,7 @@ def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.T
     Returns:
         Normalized output tensor of same shape as x
     """
-    return RMSNormFunction.apply(x, weight, eps)
+    return RMSNormFunction.apply(x, weight, bias, residual, out_dtype, residual_dtype, eps, prenorm)
 
 
 class QuackRMSNorm(torch.nn.Module):
@@ -848,16 +1234,16 @@ class QuackRMSNorm(torch.nn.Module):
         self.weight = torch.nn.Parameter(torch.ones(dim))
         self.eps = eps
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: Tensor) -> Tensor:
         """Apply RMSNorm to the input tensor.
 
         Args:
-            x (torch.Tensor): Input tensor
+            x (Tensor): Input tensor
 
         Returns:
-            torch.Tensor: Normalized tensor
+            Tensor: Normalized tensor
         """
-        return rmsnorm(x, self.weight, self.eps)
+        return rmsnorm(x, self.weight, eps=self.eps)
 
     def reset_parameters(self):
         """Reset the weight parameter to ones."""