quack-kernels 0.1.11__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/rmsnorm.py CHANGED
@@ -1,23 +1,28 @@
1
1
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
2
 
3
- from typing import Optional
3
+ from typing import Optional, Tuple
4
+ from functools import partial
4
5
 
5
6
  import cuda.bindings.driver as cuda
6
7
 
7
8
  import cutlass
8
9
  import cutlass.cute as cute
9
10
  from cutlass import Float32, Int32
11
+ from cutlass import const_expr
10
12
  from cutlass.cute.runtime import from_dlpack
11
13
 
12
- import quack.utils as utils
13
14
  import torch
14
- from quack.reduction_base import ReductionBase, torch2cute_dtype_map
15
+ from torch import Tensor
15
16
 
17
+ import quack.utils as utils
18
+ from quack.reduce import row_reduce
19
+ from quack.reduction_base import ReductionBase
20
+ from quack.cute_dsl_utils import torch2cute_dtype_map
16
21
 
17
22
  class RMSNorm(ReductionBase):
18
23
  def __init__(self, dtype: cutlass.Numeric, N: int):
19
24
  super().__init__(dtype, N, stage=1)
20
- self.reload_from = None if N <= 16384 else "smem"
25
+ self.reload_from = None if N <= 8192 else "smem"
21
26
  self.delay_w_load = False
22
27
 
23
28
  def _calculate_threads_per_row(self):
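For orientation, the optional tensors threaded through the rest of this file (mB for a bias, mRes for a residual input, mResO for the residual output) turn the kernel into a fused residual-add RMSNorm. As a reference formulation, with the residual r and bias b treated as zero when the corresponding tensors are absent:

```latex
z = x + r, \qquad
\mathrm{rstd} = \Big(\tfrac{1}{N}\textstyle\sum_{j} z_j^2 + \varepsilon\Big)^{-1/2}, \qquad
y = (z \cdot \mathrm{rstd}) \odot w + b, \qquad
\mathrm{residual\_out} = z
```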
@@ -45,7 +50,7 @@ class RMSNorm(ReductionBase):
45
50
 
46
51
  # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
47
52
  # Similarly cluster_n = 8 is faster for N=128k
48
- if cutlass.const_expr(self.dtype.width == 16):
53
+ if const_expr(self.dtype.width == 16):
49
54
  # 16-bit types (fp16, bf16)
50
55
  if N <= 16 * 1024:
51
56
  cluster_n = 1
@@ -72,12 +77,27 @@ class RMSNorm(ReductionBase):
72
77
 
73
78
  self.cluster_n = cluster_n
74
79
 
80
+ def _smem_size_in_bytes(self, tiler_mn, num_warps, dtype_res=None):
81
+ return (
82
+ cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
83
+ + (
84
+ cute.size_in_bytes(dtype_res, cute.make_layout(tiler_mn))
85
+ if dtype_res is not None
86
+ else 0
87
+ )
88
+ + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
89
+ + self.stage * (cutlass.Int64.width // 8)
90
+ )
91
+
75
92
  @cute.jit
76
93
  def __call__(
77
94
  self,
78
95
  mX: cute.Tensor,
79
96
  mW: cute.Tensor,
97
+ mB: Optional[cute.Tensor],
98
+ mRes: Optional[cute.Tensor],
80
99
  mO: cute.Tensor,
100
+ mResO: Optional[cute.Tensor],
81
101
  mRstd: Optional[cute.Tensor],
82
102
  stream: cuda.CUstream,
83
103
  eps: Float32 = 1e-6,
@@ -87,28 +107,47 @@ class RMSNorm(ReductionBase):
87
107
  cute.assume(t.stride[0], divby=128 // t.element_type.width),
88
108
  t.stride[1],
89
109
  )
90
- mX, mO = [
110
+ mX, mRes, mO, mResO = [
91
111
  cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
92
- for t in (mX, mO)
112
+ if const_expr(t is not None)
113
+ else None
114
+ for t in (mX, mRes, mO, mResO)
93
115
  ]
94
116
  assert mX.element_type == self.dtype
95
117
  assert mO.element_type == self.dtype
96
118
  self._set_cluster_n()
97
- tiler_mn, tv_layout = self._get_tv_layout()
119
+ largest_dtype_width = const_expr(
120
+ max(
121
+ mX.element_type.width,
122
+ mRes.element_type.width if mRes is not None else 0,
123
+ mO.element_type.width,
124
+ mResO.element_type.width if mResO is not None else 0,
125
+ )
126
+ )
127
+ tiler_mn, tv_layout = self._get_tv_layout(
128
+ num_copy_bits=128 // largest_dtype_width * mX.element_type.width
129
+ )
98
130
  num_threads = cute.size(tv_layout, mode=[0])
99
131
  num_warps = num_threads // cute.arch.WARP_SIZE
100
132
  mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
101
133
  mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
102
- if cutlass.const_expr(mRstd is not None):
134
+ if const_expr(mB is not None):
135
+ mB_expanded_layout = cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
136
+ mB = cute.make_tensor(mB.iterator, mB_expanded_layout)
137
+ if const_expr(mRstd is not None):
103
138
  mRstd_expanded_layout = cute.append(
104
139
  mRstd.layout, cute.make_layout((self.N,), stride=(0,))
105
140
  )
106
141
  mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
107
- self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
142
+ self.kernel(
143
+ mX, mW, mB, mRes, mO, mResO, mRstd, eps, tv_layout, tiler_mn, self.reload_from
144
+ ).launch(
108
145
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
109
146
  block=[num_threads, 1, 1],
110
- cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
111
- smem=self._smem_size_in_bytes(tiler_mn, num_warps),
147
+ cluster=([1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None),
148
+ smem=self._smem_size_in_bytes(
149
+ tiler_mn, num_warps, dtype_res=mRes.element_type if mRes is not None else None
150
+ ),
112
151
  stream=stream,
113
152
  )
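The `num_copy_bits` passed to `_get_tv_layout` above is sized so that the widest participating dtype still gets 128-bit copies while every tensor moves the same number of elements per vectorized copy. A minimal sketch of that arithmetic, with an assumed dtype mix (bf16 x and out, fp32 residual):

```python
# Sketch of the vector-width choice made in __call__ above; the dtype widths
# here (bf16 x/out, fp32 residual) are assumptions for illustration.
x_bits, res_bits, out_bits = 16, 32, 16
largest = max(x_bits, res_bits, out_bits)                  # widest dtype: 32 bits
elems_per_copy = 128 // largest                            # 4 elements per copy
num_copy_bits_x = elems_per_copy * x_bits                  # 64-bit copies for bf16 x
num_copy_bits_res = min(128, elems_per_copy * res_bits)    # 128-bit copies for fp32 res
print(elems_per_copy, num_copy_bits_x, num_copy_bits_res)  # -> 4 64 128
```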
114
153
 
@@ -117,7 +156,10 @@ class RMSNorm(ReductionBase):
117
156
  self,
118
157
  mX: cute.Tensor,
119
158
  mW: cute.Tensor,
159
+ mB: Optional[cute.Tensor],
160
+ mRes: Optional[cute.Tensor],
120
161
  mO: cute.Tensor,
162
+ mResO: Optional[cute.Tensor],
121
163
  mRstd: Optional[cute.Tensor],
122
164
  eps: cute.Float32,
123
165
  tv_layout: cute.Layout,
@@ -127,10 +169,10 @@ class RMSNorm(ReductionBase):
127
169
  ):
128
170
  tidx, _, _ = cute.arch.thread_idx()
129
171
  bidx, _, _ = cute.arch.block_idx()
130
- if cutlass.const_expr(self.cluster_n > 1):
172
+ if const_expr(self.cluster_n > 1):
131
173
  cluster_y = cute.arch.block_idx()[1]
132
174
  else:
133
- cluster_y = cutlass.const_expr(0)
175
+ cluster_y = const_expr(0)
134
176
 
135
177
  smem = cutlass.utils.SmemAllocator()
136
178
  sX = smem.allocate_tensor(
@@ -138,86 +180,145 @@ class RMSNorm(ReductionBase):
138
180
  cute.make_ordered_layout(tiler_mn, order=(1, 0)),
139
181
  byte_alignment=16,
140
182
  )
183
+ if const_expr(mRes is not None):
184
+ sRes = smem.allocate_tensor(
185
+ mRes.element_type,
186
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
187
+ byte_alignment=16,
188
+ )
141
189
  reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
142
190
 
143
191
  shape = mX.shape
144
192
  idX = cute.make_identity_tensor(shape)
145
193
  # slice for CTAs
146
194
  # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
147
- mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
148
- gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
195
+ mX, mRes, mO, mResO = [
196
+ utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) if mT is not None else None
197
+ for mT in (mX, mRes, mO, mResO)
198
+ ]
199
+ gX, gRes, gO, gResO = [
200
+ cute.local_tile(mT, tiler_mn, (0, cluster_y)) if mT is not None else None
201
+ for mT in (mX, mRes, mO, mResO)
202
+ ]
149
203
  cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
150
204
  gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
205
+ gB = (
206
+ cute.local_tile(mB, tiler_mn, (0, cluster_y))
207
+ if const_expr(mB is not None)
208
+ else None
209
+ )
151
210
  gRstd = (
152
211
  cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
153
- if cutlass.const_expr(mRstd is not None)
212
+ if const_expr(mRstd is not None)
154
213
  else None
155
214
  )
156
215
 
157
216
  # declare the atoms which will be used later for memory copy
217
+ num_copy_elems_X = tv_layout.shape[1][0]
218
+ num_copy_bits_X = mX.element_type.width * num_copy_elems_X
158
219
  copy_atom_load_X = cute.make_copy_atom(
159
- cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
220
+ cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
160
221
  )
161
222
  copy_atom_load_X_async = cute.make_copy_atom(
162
- cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
163
- )
164
- num_bits_per_copy_W = cutlass.const_expr(
165
- min(128, 128 // mX.element_type.width * mW.element_type.width)
223
+ cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
166
224
  )
225
+ num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
167
226
  copy_atom_load_W = cute.make_copy_atom(
168
- cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_bits_per_copy_W
169
- )
170
- num_bits_per_copy_O = cutlass.const_expr(
171
- min(128, 128 // mX.element_type.width * mO.element_type.width)
227
+ cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
172
228
  )
229
+ num_bits_per_copy_B = cutlass.const_expr(
230
+ min(128, num_copy_elems_X * mB.element_type.width)
231
+ ) if const_expr(mB is not None) else 0
232
+ copy_atom_load_B = cute.make_copy_atom(
233
+ cute.nvgpu.CopyUniversalOp(), mB.element_type, num_bits_per_copy=num_bits_per_copy_B
234
+ ) if const_expr(mB is not None) else None
235
+ if const_expr(mRes is not None):
236
+ num_copy_bits_Res = const_expr(min(128, num_copy_elems_X * mRes.element_type.width))
237
+ copy_atom_load_Res_async = cute.make_copy_atom(
238
+ cute.nvgpu.cpasync.CopyG2SOp(),
239
+ mRes.element_type,
240
+ num_bits_per_copy=num_copy_bits_Res,
241
+ )
242
+ num_copy_bits_O = const_expr(min(128, num_copy_elems_X * mO.element_type.width))
173
243
  copy_atom_store_O = cute.make_copy_atom(
174
- cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=num_bits_per_copy_O
244
+ cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=num_copy_bits_O
175
245
  )
246
+ if const_expr(mResO is not None):
247
+ num_copy_bits_ResO = const_expr(min(128, num_copy_elems_X * mResO.element_type.width))
248
+ copy_atom_store_ResO = cute.make_copy_atom(
249
+ cute.nvgpu.CopyUniversalOp(),
250
+ mResO.element_type,
251
+ num_bits_per_copy=num_copy_bits_ResO,
252
+ )
176
253
 
177
254
  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
178
255
  tidx
179
256
  )
180
257
 
181
258
  tXgW = thr_copy_X.partition_S(gW)
259
+ tXgB = thr_copy_X.partition_S(gB) if const_expr(mB is not None) else None
182
260
  tXgX = thr_copy_X.partition_S(gX)
183
261
  tXsX = thr_copy_X.partition_D(sX)
262
+ if const_expr(mRes is not None):
263
+ tXgRes = thr_copy_X.partition_S(gRes)
264
+ tXsRes = thr_copy_X.partition_D(sRes)
184
265
  tXgO = thr_copy_X.partition_D(gO)
185
- tXrRstd = thr_copy_X.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
266
+ if const_expr(mResO is not None):
267
+ tXgResO = thr_copy_X.partition_D(gResO)
268
+ tXrRstd = thr_copy_X.partition_D(gRstd) if const_expr(mRstd is not None) else None
186
269
  tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
187
270
 
188
271
  # allocate fragments for gmem->rmem
189
272
  tXrW = cute.make_fragment_like(tXgW)
190
273
  tXrW.fill(0.0)
191
- tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
274
+ tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None
275
+ tXrX, tXrO = [cute.make_fragment_like(t) for t in (tXgX, tXgO)]
276
+ if const_expr(mRes is not None):
277
+ tXrRes = cute.make_fragment_like(tXgRes)
192
278
 
193
279
  num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
194
280
  self._initialize_cluster(tidx, mbar_ptr, num_warps)
195
281
 
196
- tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
282
+ is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
283
+ tXpX = (
284
+ utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
285
+ )
197
286
  row = tXcX[0][0]
198
287
  if row < shape[0]:
199
288
  cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
289
+ if const_expr(mRes is not None):
290
+ cute.copy(copy_atom_load_Res_async, tXgRes, tXsRes, pred=tXpX)
200
291
  cute.arch.cp_async_commit_group()
201
292
 
202
- tXpW = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
203
- if cutlass.const_expr(not delay_w_load):
204
- cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
293
+ if const_expr(not delay_w_load):
294
+ cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpX)
295
+ if const_expr(mB is not None):
296
+ cute.copy(copy_atom_load_B, tXgB, tXrB, pred=tXpX)
205
297
 
206
298
  cute.arch.cp_async_wait_group(0)
207
299
  cute.autovec_copy(tXsX, tXrX)
208
300
  x = tXrX.load().to(cute.Float32)
301
+ if const_expr(mRes is not None):
302
+ cute.autovec_copy(tXsRes, tXrRes)
303
+ x += tXrRes.load().to(cute.Float32)
304
+ if const_expr(mResO is not None):
305
+ tXrResO = cute.make_fragment_like(tXgResO)
306
+ tXrResO.store(x.to(tXrResO.element_type))
307
+ if row < shape[0]:
308
+ cute.copy(copy_atom_store_ResO, tXrResO, tXgResO, pred=tXpX)
309
+
209
310
  threads_per_row = tv_layout.shape[0][0]
210
- sum_sq_x = utils.row_reduce(
311
+ sum_sq_x = row_reduce(
211
312
  x * x,
212
313
  cute.ReductionOp.ADD,
213
314
  threads_per_row,
214
315
  reduction_buffer[None, None, 0],
215
316
  mbar_ptr,
216
317
  init_val=0.0,
217
- hook_fn=(cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None),
318
+ hook_fn=(cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None),
218
319
  )
219
320
  rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
220
- if cutlass.const_expr(mRstd is not None):
321
+ if const_expr(mRstd is not None):
221
322
  # Only the thread corresponding to column 0 writes out the rstd to gmem
222
323
  if (
223
324
  tXcX[0][1] == 0
@@ -225,59 +326,76 @@ class RMSNorm(ReductionBase):
225
326
  and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
226
327
  ):
227
328
  tXrRstd[0] = rstd
228
- if cutlass.const_expr(delay_w_load):
229
- cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
230
- if cutlass.const_expr(reload_from == "smem"):
231
- cute.autovec_copy(tXsX, tXrX)
232
- x = tXrX.load().to(cute.Float32)
233
- elif cutlass.const_expr(reload_from == "gmem"):
234
- cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
329
+ if const_expr(delay_w_load):
330
+ cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpX)
331
+ if const_expr(mB is not None):
332
+ cute.copy(copy_atom_load_B, tXgB, tXrB, pred=tXpX)
333
+ if const_expr(reload_from == "smem" or reload_from == "gmem"):
334
+ if const_expr(reload_from == "smem"):
335
+ cute.autovec_copy(tXsX, tXrX)
336
+ else:
337
+ cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
235
338
  x = tXrX.load().to(cute.Float32)
339
+ if const_expr(mRes is not None):
340
+ cute.autovec_copy(tXsRes, tXrRes)
341
+ x += tXrRes.load().to(cute.Float32)
236
342
  x_hat = x * rstd
237
343
  w = tXrW.load().to(cute.Float32)
238
344
  y = x_hat * w
345
+ if const_expr(mB is not None):
346
+ b = tXrB.load().to(cute.Float32)
347
+ y = y + b
239
348
  tXrO.store(y.to(tXrO.element_type))
240
- tXpO = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
241
349
  if row < shape[0]:
242
- cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tXpO)
350
+ cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tXpX)
243
351
 
244
352
 
353
+ @torch.library.custom_op(
354
+ "quack::_rmsnorm_fwd",
355
+ mutates_args=("out", "rstd", "residual_out"),
356
+ device_types="cuda",
357
+ # We need to specify the schema manually since we're mutating an optional tensor
358
+ schema="(Tensor x, Tensor weight, Tensor(a!) out, Tensor? bias, Tensor(a!)? rstd, Tensor? residual, Tensor(a!)? residual_out, float eps=1e-6) -> ()",
359
+ )
245
360
  def _rmsnorm_fwd(
246
- x: torch.Tensor,
247
- weight: torch.Tensor,
361
+ x: Tensor,
362
+ weight: Tensor,
363
+ out: Tensor,
364
+ bias: Optional[Tensor] = None,
365
+ rstd: Optional[Tensor] = None,
366
+ residual: Optional[Tensor] = None,
367
+ residual_out: Optional[Tensor] = None,
248
368
  eps: float = 1e-6,
249
- return_rstd: bool = False,
250
- ) -> torch.Tensor:
369
+ ) -> None:
251
370
  """RMSNorm forward pass.
252
371
  Args:
253
372
  x: Input tensor of shape (M, N)
254
373
  weight: Weight tensor of shape (N,)
255
374
  eps: Small value for numerical stability
256
- return_rstd: Whether to return the reciprocal standard deviation
257
375
  Returns:
258
376
  Normalized output tensor of same shape as x
259
- If return_rstd is True, also returns rstd tensor of shape (M,)
260
377
  """
261
378
  assert x.dim() == 2, "Input must be 2D"
262
379
  assert weight.dim() == 1, "Weight must be 1D"
263
380
  assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
264
381
  assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
265
- assert x.dtype in [
266
- torch.float16,
267
- torch.bfloat16,
268
- torch.float32,
269
- ], "Unsupported dtype"
270
-
382
+ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
271
383
  assert weight.dtype in [
272
384
  torch.float32,
273
385
  torch.bfloat16,
274
386
  torch.float16,
275
387
  ], "Weight must be float32, float16 or bfloat16"
276
-
277
- M, N = x.shape
388
+ if residual is not None:
389
+ assert residual.shape == x.shape
390
+ assert residual.is_cuda
391
+ assert residual.dtype in [
392
+ torch.float16,
393
+ torch.bfloat16,
394
+ torch.float32,
395
+ ], "Residual must be float16, bfloat16, or float32"
396
+
397
+ _, N = x.shape
278
398
  device = x.device
279
- out = torch.empty_like(x)
280
- rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
281
399
  dtype = torch2cute_dtype_map[x.dtype]
282
400
  # convert_from_dlpack = lambda x: (
283
401
  # from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
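The forward entry point `_rmsnorm_fwd` above is registered as a mutating `torch.library.custom_op` with a hand-written schema, since optional tensors (`rstd`, `residual_out`) are mutated in place. A self-contained toy registration in the same style; the op name, body, and shapes are made up for the example:

```python
import torch
from typing import Optional

# Toy registration mirroring the quack::_rmsnorm_fwd pattern above: the result is
# written into a preallocated `out`, and `rstd` is an optional mutated tensor,
# which is why the schema is spelled out with (a!) annotations.
@torch.library.custom_op(
    "demo::row_rmsnorm",
    mutates_args=("out", "rstd"),
    device_types="cuda",
    schema="(Tensor x, Tensor weight, Tensor(a!) out, Tensor(a!)? rstd, float eps=1e-6) -> ()",
)
def demo_row_rmsnorm(x: torch.Tensor, weight: torch.Tensor, out: torch.Tensor,
                     rstd: Optional[torch.Tensor] = None, eps: float = 1e-6) -> None:
    r = torch.rsqrt(x.float().square().mean(dim=-1, keepdim=True) + eps)
    out.copy_((x.float() * r * weight.float()).to(out.dtype))
    if rstd is not None:
        rstd.copy_(r.squeeze(-1))
```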
@@ -287,44 +405,109 @@ def _rmsnorm_fwd(
287
405
  convert_from_dlpack = lambda x: (
288
406
  from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
289
407
  )
290
- x_tensor, out_tensor = [convert_from_dlpack(t) for t in (x, out)]
408
+ x_tensor, res_tensor, out_tensor, res_out_tensor = [
409
+ convert_from_dlpack(t) if t is not None else None for t in (x, residual, out, residual_out)
410
+ ]
291
411
  # handle weight divisibility based on weight dtype
292
412
  weight_dtype = torch2cute_dtype_map[weight.dtype]
293
413
  weight_tensor = utils.convert_from_dlpack(
294
414
  weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
295
415
  )
416
+ if bias is not None:
417
+ bias_dtype = torch2cute_dtype_map[bias.dtype]
418
+ bias_tensor = utils.convert_from_dlpack(
419
+ bias.detach(), leading_dim=0, divisibility=128 // bias_dtype.width
420
+ )
421
+ else:
422
+ bias_tensor = None
296
423
  rstd_tensor = (
297
424
  from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
298
425
  if rstd is not None
299
426
  else None
300
427
  )
301
428
  current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
302
- compile_key = (dtype, N, rstd is not None, weight.dtype)
429
+ compile_key = (
430
+ N,
431
+ dtype,
432
+ res_tensor.element_type if residual is not None else None,
433
+ weight_tensor.element_type,
434
+ bias_tensor.element_type if bias is not None else None,
435
+ res_out_tensor.element_type if residual_out is not None else None,
436
+ rstd is not None,
437
+ )
303
438
  if compile_key not in _rmsnorm_fwd.compile_cache:
304
439
  rmsnorm_op = RMSNorm(dtype, N)
305
440
  _rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
306
- rmsnorm_op, x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream
441
+ rmsnorm_op,
442
+ x_tensor,
443
+ weight_tensor,
444
+ bias_tensor,
445
+ res_tensor,
446
+ out_tensor,
447
+ res_out_tensor,
448
+ rstd_tensor,
449
+ current_stream,
450
+ eps,
307
451
  )
308
452
  _rmsnorm_fwd.compile_cache[compile_key](
309
- x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream, eps
453
+ x_tensor,
454
+ weight_tensor,
455
+ bias_tensor,
456
+ res_tensor,
457
+ out_tensor,
458
+ res_out_tensor,
459
+ rstd_tensor,
460
+ current_stream,
461
+ eps,
310
462
  )
311
- return (out, rstd) if return_rstd else out
312
463
 
313
464
 
314
465
  _rmsnorm_fwd.compile_cache = {}
315
466
 
316
467
 
317
- def rmsnorm_ref(x, w, eps=1e-6):
318
- x_f32 = x.float()
319
- return (x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) * w).to(
320
- x.dtype
321
- )
468
+ def rmsnorm_fwd(
469
+ x: Tensor,
470
+ weight: Tensor,
471
+ bias: Optional[Tensor] = None,
472
+ residual: Optional[Tensor] = None,
473
+ out_dtype: Optional[torch.dtype] = None,
474
+ residual_dtype: Optional[torch.dtype] = None,
475
+ eps: float = 1e-6,
476
+ store_rstd: bool = False,
477
+ ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
478
+ # Need to wrap to handle the case where residual_out is an alias of x, which makes torch.library
479
+ # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None
480
+ # so that _rmsnorm_fwd doesn't have to return them.
481
+ out_dtype = x.dtype if out_dtype is None else out_dtype
482
+ out = torch.empty_like(x, dtype=out_dtype)
483
+ rstd = torch.empty(x.shape[0], device=x.device, dtype=torch.float32) if store_rstd else None
484
+ if residual is not None:
485
+ residual_dtype = residual.dtype
486
+ if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
487
+ residual_out = torch.empty_like(
488
+ x, dtype=residual_dtype if residual_dtype is not None else x.dtype
489
+ )
490
+ else:
491
+ residual_out = None
492
+ _rmsnorm_fwd(x, weight, out, bias, rstd, residual, residual_out, eps=eps)
493
+ # residual_out is None only when residual is None and residual_dtype (if given) matches x.dtype
494
+ if residual_out is None:
495
+ residual_out = x
496
+ return out, residual_out, rstd
322
497
 
323
498
 
324
- def rstd_ref(x, eps=1e-6):
499
+ def rmsnorm_ref(x, w, bias=None, residual=None, eps=1e-6):
325
500
  x_f32 = x.float()
326
- return 1.0 / torch.sqrt(torch.mean(x_f32 * x_f32, dim=-1) + eps)
327
-
501
+ if residual is not None:
502
+ residual_f32 = residual.float()
503
+ x_f32 += residual_f32
504
+ out = x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) * w
505
+ if bias is not None:
506
+ out = out + bias.float()
507
+ if residual is None:
508
+ return out.to(x.dtype)
509
+ else:
510
+ return out.to(x.dtype), x_f32.to(residual.dtype)
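A usage sketch of `rmsnorm_fwd` as defined above; the shapes, dtypes, and CUDA device are assumptions for the example:

```python
import torch
from quack.rmsnorm import rmsnorm_fwd

x = torch.randn(4, 4096, device="cuda", dtype=torch.bfloat16)
res = torch.randn(4, 4096, device="cuda", dtype=torch.float32)   # optional fp32 residual
w = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4096, device="cuda", dtype=torch.bfloat16)

# out keeps x's dtype (or out_dtype), residual_out carries x + res in the residual's
# dtype, and rstd is only materialized when store_rstd=True (e.g. for a backward pass).
out, residual_out, rstd = rmsnorm_fwd(x, w, bias=b, residual=res, store_rstd=True)
```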
328
511
 
329
512
  def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
330
513
  """Reference implementation for RMSNorm backward pass."""
@@ -338,7 +521,6 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
338
521
  dw = (dout * x_hat).sum(dim=0)
339
522
  return dx.to(x.dtype), dw.to(w.dtype)
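For reference, the `dx` expression used here (and by the backward kernel below) comes from differentiating y_i = w_i x_i * rstd with rstd = (mean_j x_j^2 + eps)^{-1/2}; writing x_hat = x * rstd and wdy = w ⊙ dy, and using an overline for the row mean:

```latex
\frac{\partial\,\mathrm{rstd}}{\partial x_j} = -\frac{\mathrm{rstd}^3\, x_j}{N}
\;\;\Longrightarrow\;\;
\frac{\partial L}{\partial x_j}
  = \mathrm{rstd}\cdot\mathrm{wdy}_j - \frac{\mathrm{rstd}^3 x_j}{N}\sum_i \mathrm{wdy}_i\, x_i
  = \big(\mathrm{wdy}_j - \hat{x}_j\,\overline{\hat{x}\,\mathrm{wdy}}\big)\,\mathrm{rstd},
\qquad
\frac{\partial L}{\partial w} = \sum_{\text{rows}} dy \odot \hat{x}
```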
340
523
 
341
-
342
524
  class RMSNormBackward(ReductionBase):
343
525
  def __init__(self, dtype: cutlass.Numeric, N: int):
344
526
  # 2 stages for double buffering when computing mean of x_hat * wdy
@@ -372,11 +554,13 @@ class RMSNormBackward(ReductionBase):
372
554
  )
373
555
  self.cluster_n = cluster_n
374
556
 
375
- def _smem_size_in_bytes(self, tiler_mn, num_warps):
557
+ def _smem_size_in_bytes(self, tiler_mn, num_warps, do_dtype=None):
558
+ if do_dtype is None:
559
+ do_dtype = self.dtype
376
560
  return (
377
- # Multiply by 2 since we need space for X and dOut,
378
- # and multiply by another 2 due to double buffering
379
- cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2 * 2
561
+ # We need space for X and dO, and multiply by 2 due to double buffering
562
+ cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
563
+ + cute.size_in_bytes(do_dtype, cute.make_layout(tiler_mn)) * 2
380
564
  + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
381
565
  + self.stage * (cutlass.Int64.width // 8) * 2 # mult 2 as we need 2 mbar per stage
382
566
  )
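A back-of-the-envelope instance of the budget above, recomputed with plain integers; the tile size, warp count, and dtypes are illustrative assumptions (bf16 x and dO, tiler_mn = (1, 4096), 8 warps, cluster_n = 1, stage = 2, fp32 reduction dtype):

```python
# Rough recomputation of the backward smem budget for one assumed configuration.
tile_rows, tile_cols = 1, 4096
x_bytes = tile_rows * tile_cols * 16 // 8 * 2     # x tile, double buffered (bf16)
do_bytes = tile_rows * tile_cols * 16 // 8 * 2    # dO tile, double buffered (bf16)
stage, num_warps, cluster_n = 2, 8, 1
reduction = stage * num_warps * cluster_n * 32 // 8   # fp32 per-warp partials
mbars = stage * 64 // 8 * 2                           # two Int64 mbarriers per stage
print(x_bytes + do_bytes + reduction + mbars)         # 16384 + 16384 + 64 + 32 = 32864
```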
@@ -386,10 +570,13 @@ class RMSNormBackward(ReductionBase):
386
570
  self,
387
571
  mX: cute.Tensor,
388
572
  mW: cute.Tensor,
389
- mdOut: cute.Tensor,
573
+ mdO: cute.Tensor,
574
+ mdResO: Optional[cute.Tensor],
390
575
  mRstd: cute.Tensor,
391
576
  mdX: cute.Tensor,
392
577
  mdW: cute.Tensor,
578
+ mdRes: Optional[cute.Tensor],
579
+ mdB: Optional[cute.Tensor],
393
580
  sm_count: Int32,
394
581
  stream: cuda.CUstream,
395
582
  ):
@@ -398,24 +585,36 @@ class RMSNormBackward(ReductionBase):
398
585
  cute.assume(t.stride[0], divby=128 // t.element_type.width),
399
586
  t.stride[1],
400
587
  )
401
- mX, mdOut, mdX = [
588
+ mX, mdO, mdResO, mdX, mdRes = [
402
589
  cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
403
- for t in (mX, mdOut, mdX)
590
+ if const_expr(t is not None)
591
+ else None
592
+ for t in (mX, mdO, mdResO, mdX, mdRes)
404
593
  ]
405
594
  self._set_cluster_n()
406
- tiler_mn, tv_layout = self._get_tv_layout()
595
+ largest_dtype_width = const_expr(
596
+ max(
597
+ mX.element_type.width,
598
+ mdO.element_type.width,
599
+ mdX.element_type.width,
600
+ mdResO.element_type.width if mdResO is not None else 0,
601
+ mdRes.element_type.width if mdRes is not None else 0,
602
+ )
603
+ )
604
+ tiler_mn, tv_layout = self._get_tv_layout(
605
+ num_copy_bits=128 // largest_dtype_width * mX.element_type.width
606
+ )
407
607
  num_threads = cute.size(tv_layout, mode=[0])
408
608
  num_warps = num_threads // cute.arch.WARP_SIZE
409
-
410
609
  mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
411
610
  mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
412
611
 
413
612
  num_blocks = sm_count
414
- self.kernel(mX, mW, mdOut, mRstd, mdX, mdW, tv_layout, tiler_mn).launch(
613
+ self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn).launch(
415
614
  grid=[num_blocks, self.cluster_n, 1],
416
615
  block=[num_threads, 1, 1],
417
616
  cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
418
- smem=self._smem_size_in_bytes(tiler_mn, num_warps),
617
+ smem=self._smem_size_in_bytes(tiler_mn, num_warps, do_dtype=mdO.element_type),
419
618
  stream=stream,
420
619
  )
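Unlike the forward, which launches one CTA row-block per tile of rows, the backward launches a fixed persistent grid of `sm_count` row-blocks (see `_get_sm_count` further down) and each block strides through the rows, keeping its own fp32 dw/db partials. A schematic of that row assignment in plain Python; the concrete numbers are assumptions:

```python
# Schematic of the persistent-grid row assignment used by the backward kernel;
# M, tile_rows and sm_count below are made-up numbers for illustration.
M, tile_rows, sm_count = 100_000, 1, 264      # e.g. 132 SMs with a 2x multiplier
num_row_blocks = -(-M // tile_rows)           # ceil_div(M, tiler_mn[0])
block0_rows = list(range(0, num_row_blocks, sm_count))[:4]
print(block0_rows)                            # block 0 visits row-blocks 0, 264, 528, ...
# Each block accumulates a (1, N) fp32 dw (and optional db) partial; the host-side
# wrapper later reduces the (sm_count, N) partials with dw_partial.sum(dim=0).
```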
421
620
 
@@ -424,63 +623,85 @@ class RMSNormBackward(ReductionBase):
424
623
  self,
425
624
  mX: cute.Tensor,
426
625
  mW: cute.Tensor,
427
- mdOut: cute.Tensor,
626
+ mdO: cute.Tensor,
627
+ mdResO: Optional[cute.Tensor],
428
628
  mRstd: cute.Tensor,
429
629
  mdX: cute.Tensor,
430
630
  mdW: cute.Tensor,
631
+ mdB: Optional[cute.Tensor],
632
+ mdRes: Optional[cute.Tensor],
431
633
  tv_layout: cute.Layout,
432
634
  tiler_mn: cute.Shape,
433
635
  ):
434
636
  tidx, _, _ = cute.arch.thread_idx()
435
637
  bidx_start, _, _ = cute.arch.block_idx()
436
638
  gdim, _, _ = cute.arch.grid_dim()
437
- if cutlass.const_expr(self.cluster_n > 1):
639
+ if const_expr(self.cluster_n > 1):
438
640
  cluster_y = cute.arch.block_idx()[1]
439
641
  else:
440
- cluster_y = cutlass.const_expr(0)
642
+ cluster_y = const_expr(0)
441
643
 
442
644
  shape = mX.shape
443
645
  M, N = shape[0], shape[1]
444
- is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
646
+ is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
445
647
 
446
648
  idX = cute.make_identity_tensor(shape)
447
649
 
448
650
  smem = cutlass.utils.SmemAllocator()
449
651
  smem_layout = cute.make_ordered_layout((tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2))
450
652
  sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16)
451
- sdOut = smem.allocate_tensor(mdOut.element_type, smem_layout, byte_alignment=16)
653
+ sdO = smem.allocate_tensor(mdO.element_type, smem_layout, byte_alignment=16)
452
654
  reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
453
655
  smem, tv_layout, is_persistent=True
454
656
  )
455
- if cutlass.const_expr(mbar_ptr is not None):
657
+ if const_expr(mbar_ptr is not None):
456
658
  mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + 2
457
659
  else:
458
660
  mbar_full_ptr, mbar_empty_ptr = None, None
459
661
 
662
+ num_copy_elems_X = tv_layout.shape[1][0]
663
+ num_copy_bits_X = mX.element_type.width * num_copy_elems_X
460
664
  copy_atom_load_X = cute.make_copy_atom(
461
- cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
665
+ cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
462
666
  )
463
667
  copy_atom_load_X_async = cute.make_copy_atom(
464
- cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
668
+ cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
465
669
  )
466
- num_bits_per_copy_W = cutlass.const_expr(
467
- min(128, 128 // mX.element_type.width * mW.element_type.width)
670
+ num_copy_bits_dO = const_expr(min(128, num_copy_elems_X * mdO.element_type.width))
671
+ copy_atom_load_dO_async = cute.make_copy_atom(
672
+ cute.nvgpu.cpasync.CopyG2SOp(), mdO.element_type, num_bits_per_copy=num_copy_bits_dO
468
673
  )
674
+ num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
469
675
  copy_atom_load_W = cute.make_copy_atom(
470
- cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_bits_per_copy_W
471
- )
472
- num_bits_per_copy_dX = cutlass.const_expr(
473
- min(128, 128 // mX.element_type.width * mdX.element_type.width)
676
+ cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
474
677
  )
678
+ if const_expr(mdResO is not None):
679
+ num_copy_bits_dResO = const_expr(min(128, num_copy_elems_X * mdResO.element_type.width))
680
+ copy_atom_load_dResO = cute.make_copy_atom(
681
+ cute.nvgpu.CopyUniversalOp(),
682
+ mdResO.element_type,
683
+ num_bits_per_copy=num_copy_bits_dResO,
684
+ )
685
+ num_copy_bits_dX = const_expr(min(128, num_copy_elems_X * mdX.element_type.width))
475
686
  copy_atom_store_dX = cute.make_copy_atom(
476
- cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_bits_per_copy_dX
477
- )
478
- num_bits_per_copy_dW = cutlass.const_expr(
479
- min(128, 128 // mX.element_type.width * mdW.element_type.width)
687
+ cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_copy_bits_dX
480
688
  )
689
+ num_copy_bits_dW = const_expr(min(128, num_copy_elems_X * mdW.element_type.width))
481
690
  copy_atom_store_dW = cute.make_copy_atom(
482
- cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=num_bits_per_copy_dW
691
+ cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=num_copy_bits_dW
483
692
  )
693
+ if const_expr(mdB is not None):
694
+ num_copy_bits_dB = const_expr(min(128, num_copy_elems_X * mdB.element_type.width))
695
+ copy_atom_store_dB = cute.make_copy_atom(
696
+ cute.nvgpu.CopyUniversalOp(), mdB.element_type, num_bits_per_copy=num_copy_bits_dB
697
+ )
698
+ if const_expr(mdRes is not None):
699
+ num_copy_bits_dRes = const_expr(min(128, num_copy_elems_X * mdRes.element_type.width))
700
+ copy_atom_load_dRes = cute.make_copy_atom(
701
+ cute.nvgpu.CopyUniversalOp(),
702
+ mdRes.element_type,
703
+ num_bits_per_copy=num_copy_bits_dRes,
704
+ )
484
705
 
485
706
  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
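The persistent loop in the hunks that follow double-buffers X and dO in shared memory with cp.async: `stage` flips between the two buffers every iteration, and (with clusters) producer/consumer phase bits track mbarrier completion across the two-deep pipeline. A stripped-down sketch of that bookkeeping, with the copies, reductions, and stores replaced by placeholder callables:

```python
# Minimal sketch of the stage/phase bookkeeping in the persistent backward loop.
# prefetch(row_block, buf) and compute(row_block, buf, phase) are placeholders for
# the cp.async copies and the per-row math; they are not part of the package.
def persistent_loop(bidx_start, num_row_blocks, grid_stride, prefetch, compute):
    stage = 0
    consumer_phase, producer_phase = 0, 1
    prefetch(bidx_start, stage)                    # first tile: gmem -> smem[stage]
    for bidx in range(bidx_start, num_row_blocks, grid_stride):
        prefetch(bidx + grid_stride, stage ^ 1)    # overlap the next load with compute
        compute(bidx, stage, consumer_phase)       # waits on the in-flight cp.async
        stage ^= 1
        if stage == 0:                             # both buffers recycled: flip phases
            consumer_phase ^= 1
            producer_phase ^= 1                    # gates reuse of the reduction buffer
```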
486
707
 
@@ -510,21 +731,36 @@ class RMSNormBackward(ReductionBase):
510
731
  if not is_even_N
511
732
  else None
512
733
  )
734
+ if const_expr(mdB is not None):
735
+ db_coord = cute.local_tile(idX, tiler_mn, (0, cluster_y))
736
+ tXpdB = (
737
+ utils.predicate_k(thr_copy_X.partition_S(db_coord), limit=shape[1])
738
+ if not is_even_N
739
+ else None
740
+ )
513
741
 
514
742
  gdW = cute.local_tile(mdW, (1, tiler_mn[1]), (bidx_start, cluster_y))
515
743
  tXgdW = thr_copy_X.partition_S(gdW)
516
744
  # Always compute partial weight gradients in fp32
517
745
  tXrdW = cute.make_fragment_like(tXgdW, Float32)
518
746
 
519
- gX = cute.local_tile(mX, tiler_mn, (None, cluster_y))
520
- gdOut = cute.local_tile(mdOut, tiler_mn, (None, cluster_y))
521
- gdX = cute.local_tile(mdX, tiler_mn, (None, cluster_y))
522
- cX = cute.local_tile(idX, tiler_mn, (None, cluster_y))
747
+ gdB = cute.local_tile(mdB, (1, tiler_mn[1]), (bidx_start, cluster_y)) if const_expr(mdB is not None) else None
748
+ tXgdB = thr_copy_X.partition_S(gdB) if const_expr(mdB is not None) else None
749
+ tXrdB = cute.make_fragment_like(tXgdB, Float32) if const_expr(mdB is not None) else None
750
+
751
+ gX, gdO, gdResO, gdX, gdRes, cX = [
752
+ cute.local_tile(mT, tiler_mn, (None, cluster_y)) if mT is not None else None
753
+ for mT in (mX, mdO, mdResO, mdX, mdRes, idX)
754
+ ]
523
755
  tXgX = thr_copy_X.partition_S(gX)
524
756
  tXsX = thr_copy_X.partition_D(sX)
525
- tXgdOut = thr_copy_X.partition_S(gdOut)
526
- tXsdOut = thr_copy_X.partition_D(sdOut)
757
+ tXgdO = thr_copy_X.partition_S(gdO)
758
+ tXsdO = thr_copy_X.partition_D(sdO)
527
759
  tXgdX = thr_copy_X.partition_D(gdX)
760
+ if const_expr(mdResO is not None):
761
+ tXgdResO = thr_copy_X.partition_S(gdResO)
762
+ if const_expr(mdRes is not None):
763
+ tXgdRes = thr_copy_X.partition_D(gdRes)
528
764
  tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
529
765
  # This doesn't change across iterations
530
766
  tXpX = (
@@ -533,62 +769,48 @@ class RMSNormBackward(ReductionBase):
533
769
  else None
534
770
  )
535
771
 
536
- tXrX, tXrdOut, tXrdX = [
537
- cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdOut, tXgdX)
772
+ tXrX, tXrdO, tXrdX = [
773
+ cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdO, tXgdX)
538
774
  ]
775
+ if const_expr(mdResO is not None):
776
+ tXrdResO = cute.make_fragment_like(tXgdResO[None, None, None, 0])
777
+ if const_expr(mdRes is not None):
778
+ tXrdRes = cute.make_fragment_like(tXgdRes[None, None, None, 0])
779
+
780
+ copy_X = partial(cute.copy, copy_atom_load_X_async, pred=tXpX)
781
+ copy_dO = partial(cute.copy, copy_atom_load_dO_async, pred=tXpX)
539
782
 
540
783
  # Prefetch the first batch
541
784
  row = tXcX[None, None, None, bidx_start][0][0]
542
785
  if row < M:
543
786
  tXgX_cur = utils.coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0]
544
- tXgdOut_cur = utils.coord_offset_i64(bidx_start, tXgdOut, dim=3)[None, None, None, 0]
545
- cute.copy(
546
- copy_atom_load_X_async,
547
- tXgX_cur,
548
- tXsX[None, None, None, 0],
549
- pred=tXpX,
550
- )
551
- cute.copy(
552
- copy_atom_load_X_async,
553
- tXgdOut_cur,
554
- tXsdOut[None, None, None, 0],
555
- pred=tXpX,
556
- )
787
+ tXgdO_cur = utils.coord_offset_i64(bidx_start, tXgdO, dim=3)[None, None, None, 0]
788
+ copy_X(tXgX_cur, tXsX[None, None, None, 0])
789
+ copy_dO(tXgdO_cur, tXsdO[None, None, None, 0])
557
790
  elif tiler_mn[0] > 1:
558
791
  # Fill with zero, otherwise smem will be uninitialized, and we could read this back
559
792
  # later into registers, causing wrong dW.
560
793
  utils.fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero)
561
- utils.fill_oob(tXsdOut[None, None, None, 0], None, fill_value=mdOut.element_type.zero)
794
+ utils.fill_oob(tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero)
562
795
  cute.arch.cp_async_commit_group()
563
796
 
564
- if cutlass.const_expr(self.cluster_n > 1):
797
+ if const_expr(self.cluster_n > 1):
565
798
  cute.arch.cluster_wait()
566
799
 
567
800
  threads_per_row = tv_layout.shape[0][0]
568
801
  tXrdW.fill(0.0)
802
+ if const_expr(mdB is not None):
803
+ tXrdB.fill(0.0)
569
804
  stage = Int32(0)
570
805
  producer_phase = Int32(1)
571
806
  consumer_phase = Int32(0)
572
807
  for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
573
808
  row = tXcX[None, None, None, bidx][0][0]
574
- rstd = cutlass.Float.zero
575
809
  if row + gdim * tiler_mn[0] < M: # Prefetch the next batch
576
810
  tXgX_cur = utils.coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0]
577
- tXgdOut_cur = utils.coord_offset_i64(bidx + gdim, tXgdOut, dim=3)[
578
- None, None, None, 0
579
- ]
580
- cute.copy(
581
- copy_atom_load_X_async,
582
- tXgX_cur,
583
- tXsX[None, None, None, stage ^ 1],
584
- pred=tXpX,
585
- )
586
- cute.copy(
587
- copy_atom_load_X_async,
588
- tXgdOut_cur,
589
- tXsdOut[None, None, None, stage ^ 1],
590
- pred=tXpX,
591
- )
811
+ tXgdO_cur = utils.coord_offset_i64(bidx + gdim, tXgdO, dim=3)[None, None, None, 0]
812
+ copy_X(tXgX_cur, tXsX[None, None, None, stage ^ 1])
813
+ copy_dO(tXgdO_cur, tXsdO[None, None, None, stage ^ 1])
592
814
  elif tiler_mn[0] > 1:
593
815
  utils.fill_oob(
594
816
  tXsX[None, None, None, stage ^ 1],
@@ -596,36 +818,45 @@ class RMSNormBackward(ReductionBase):
596
818
  fill_value=mX.element_type.zero,
597
819
  )
598
820
  utils.fill_oob(
599
- tXsdOut[None, None, None, stage ^ 1],
821
+ tXsdO[None, None, None, stage ^ 1],
600
822
  None,
601
- fill_value=mdOut.element_type.zero,
823
+ fill_value=mdO.element_type.zero,
602
824
  )
603
825
  cute.arch.cp_async_commit_group()
826
+ rstd = cutlass.Float.zero
604
827
  if row < M or tiler_mn[0] == 1:
605
828
  rstd = mRstd[row]
829
+ if const_expr(mdResO is not None):
830
+ tXgdResO_cur = utils.coord_offset_i64(bidx, tXgdResO, dim=3)[None, None, None, 0]
831
+ if row < M or tiler_mn[0] == 1:
832
+ cute.copy(copy_atom_load_dResO, tXgdResO_cur, tXrdResO, pred=tXpX)
833
+ elif tiler_mn[0] > 1:
834
+ tXrdResO.fill(0.0)
606
835
  cute.arch.cp_async_wait_group(1)
607
836
  cute.autovec_copy(tXsX[None, None, None, stage], tXrX)
608
837
  x = tXrX.load().to(cute.Float32)
609
- cute.autovec_copy(tXsdOut[None, None, None, stage], tXrdOut)
610
- dout = tXrdOut.load().to(cute.Float32)
838
+ cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
839
+ dout = tXrdO.load().to(cute.Float32)
840
+ if const_expr(mdResO is not None):
841
+ dout += tXrdResO.load().to(cute.Float32)
611
842
  x_hat = x * rstd
612
843
  wdy = dout * weight
613
- if cutlass.const_expr(self.cluster_n > 1):
844
+ if const_expr(self.cluster_n > 1):
614
845
  cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
615
846
  mean_xhat_wdy = (
616
- utils.row_reduce(
847
+ row_reduce(
617
848
  x_hat * wdy,
618
849
  cute.ReductionOp.ADD,
619
850
  threads_per_row,
620
851
  reduction_buffer[None, None, stage],
621
- (mbar_full_ptr + stage if cutlass.const_expr(self.cluster_n > 1) else None),
852
+ (mbar_full_ptr + stage if const_expr(self.cluster_n > 1) else None),
622
853
  phase=consumer_phase,
623
854
  init_val=0.0,
624
855
  )
625
856
  / shape[1]
626
857
  )
627
858
 
628
- if cutlass.const_expr(self.cluster_n > 1):
859
+ if const_expr(self.cluster_n > 1):
629
860
  # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
630
861
  # Requires adjusting the thread_count when initializing the mbar
631
862
  cute.arch.sync_warp()
@@ -635,28 +866,37 @@ class RMSNormBackward(ReductionBase):
635
866
  mbar_empty_ptr + stage, peer_cta_rank_in_cluster=lane_idx
636
867
  )
637
868
 
638
- if cutlass.const_expr(self.reload_wdy == "smem"):
639
- cute.autovec_copy(tXsdOut[None, None, None, stage], tXrdOut)
640
- dout = tXrdOut.load().to(cute.Float32)
869
+ if const_expr(self.reload_wdy == "smem"):
870
+ cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
871
+ dout = tXrdO.load().to(cute.Float32)
872
+ if const_expr(mdResO is not None):
873
+ dout += tXrdResO.load().to(cute.Float32)
641
874
  wdy = dout * weight
642
875
 
643
876
  dx = (wdy - x_hat * mean_xhat_wdy) * rstd
644
- tXrdX.store(dx.to(tXrdOut.element_type))
877
+ tXrdX.store(dx.to(tXrdX.element_type))
645
878
  if row < M or tiler_mn[0] == 1:
646
879
  tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
647
880
  cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
881
+ if const_expr(mdRes is not None):
882
+ tXrdRes.store(dx.to(tXrdRes.element_type))
883
+ tXgdRes_cur = utils.coord_offset_i64(bidx, tXgdRes, dim=3)[None, None, None, 0]
884
+ if row < M or tiler_mn[0] == 1:
885
+ cute.copy(copy_atom_load_dRes, tXrdRes, tXgdRes_cur, pred=tXpX)
648
886
  # Accumulate weight gradients in fp32
649
887
  tXrdW.store(tXrdW.load() + dout * x_hat)
888
+ if const_expr(mdB is not None):
889
+ tXrdB.store(tXrdB.load() + dout)
650
890
 
651
891
  stage ^= 1
652
892
  if stage == 0:
653
893
  consumer_phase ^= 1
654
894
  producer_phase ^= 1
655
895
 
656
- if cutlass.const_expr(self.cluster_n > 1): # Prevent cluster from exiting early
896
+ if const_expr(self.cluster_n > 1): # Prevent cluster from exiting early
657
897
  cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
658
898
 
659
- if cutlass.const_expr(tiler_mn[0] > 1):
899
+ if const_expr(tiler_mn[0] > 1):
660
900
  # reduction of dw_partial within the same threadblock
661
901
  sdW = cute.make_tensor(
662
902
  cute.recast_ptr(sX.iterator, dtype=cute.Float32),
@@ -669,23 +909,73 @@ class RMSNormBackward(ReductionBase):
669
909
  cute.autovec_copy(tXrdW, tXsdW)
670
910
  cute.arch.barrier()
671
911
  if row == 0:
672
- for i in cutlass.range_constexpr(1, cutlass.const_expr(tiler_mn[0])):
912
+ for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
673
913
  tXrdW_other = cute.make_fragment_like(tXrdW)
674
914
  tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
675
915
  cute.autovec_copy(tXsdW_other, tXrdW_other)
676
916
  tXrdW.store(tXrdW.load() + tXrdW_other.load())
677
917
  cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
918
+ cute.arch.barrier()
919
+ if const_expr(mdB is not None):
920
+ sdB = cute.make_tensor(
921
+ cute.recast_ptr(sX.iterator, dtype=cute.Float32),
922
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
923
+ )
924
+ tXsdB = thr_copy_X.partition_D(sdB)
925
+ cute.arch.barrier()
926
+ row = tXcX[None, None, None, 0][0][0]
927
+ if row > 0:
928
+ cute.autovec_copy(tXrdB, tXsdB)
929
+ cute.arch.barrier()
930
+ if row == 0:
931
+ for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
932
+ tXrdB_other = cute.make_fragment_like(tXrdB)
933
+ tXsdB_other = cute.make_tensor(tXsdB.iterator + i * sdB.stride[0], tXsdB.layout)
934
+ cute.autovec_copy(tXsdB_other, tXrdB_other)
935
+ tXrdB.store(tXrdB.load() + tXrdB_other.load())
936
+ cute.copy(copy_atom_store_dB, tXrdB, tXgdB, pred=tXpdB)
678
937
  else:
679
938
  # dw is already in fp32, so we can directly copy to global memory
680
939
  cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
940
+ if const_expr(mdB is not None):
941
+ cute.copy(copy_atom_store_dB, tXrdB, tXgdB, pred=tXpdB)
681
942
 
682
943
 
683
- def _rmsnorm_backward(
684
- x: torch.Tensor,
685
- weight: torch.Tensor,
686
- dout: torch.Tensor,
687
- rstd: torch.Tensor,
688
- ) -> (torch.Tensor, torch.Tensor):
944
+ def _get_sm_count(N: int, device: torch.device) -> int:
945
+ # This should be tuned on how many CTAs can be launched on each SM
946
+ sm_count_multiple = (
947
+ 16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
948
+ )
949
+ sm_count = torch.cuda.get_device_properties(device).multi_processor_count
950
+ # By right, if we're using cluster, this should be cluster_count not sm_count.
951
+ # But for cluster >= 4, due to quantization we would need to query active max cluster.
952
+ # Instead we just do sm_count * 2, which is reasonably larger than active_cluster_count to
953
+ # avoid wave quantization.
954
+ sm_count = (
955
+ sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
956
+ )
957
+
958
+ return sm_count
959
+
960
+
961
+ @torch.library.custom_op(
962
+ "quack::_rmsnorm_bwd",
963
+ mutates_args={"dx", "dw_partial", "db_partial", "dresidual"},
964
+ device_types="cuda",
965
+ # We need to specify the schema manually since we're mutating an optional tensor
966
+ schema="(Tensor x, Tensor weight, Tensor dout, Tensor rstd, Tensor(a!) dx, Tensor(a!) dw_partial, Tensor(a!)? db_partial, Tensor? dresidual_out, Tensor(a!)? dresidual) -> ()",
967
+ )
968
+ def _rmsnorm_bwd(
969
+ x: Tensor,
970
+ weight: Tensor,
971
+ dout: Tensor,
972
+ rstd: Tensor,
973
+ dx: Tensor,
974
+ dw_partial: Tensor,
975
+ db_partial: Optional[Tensor] = None,
976
+ dresidual_out: Optional[Tensor] = None,
977
+ dresidual: Optional[Tensor] = None,
978
+ ) -> None:
689
979
  """RMSNorm backward pass.
690
980
  Args:
691
981
  x: Input tensor of shape (M, N)
@@ -701,46 +991,39 @@ def _rmsnorm_backward(
701
991
  assert weight.dim() == 1, "Weight must be 1D"
702
992
  assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
703
993
  assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
704
- assert x.dtype in [
705
- torch.float16,
706
- torch.bfloat16,
707
- torch.float32,
708
- ], "Unsupported dtype"
709
-
994
+ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
710
995
  assert weight.dtype in [
711
996
  torch.float32,
712
997
  torch.bfloat16,
713
998
  torch.float16,
714
999
  ], "Weight must be float32, float16 or bfloat16"
715
-
716
- M, N = x.shape
717
- dx = torch.empty_like(x)
718
-
1000
+ if dresidual_out is not None:
1001
+ assert dresidual_out.shape == x.shape
1002
+ assert dresidual_out.is_cuda
1003
+ assert dresidual_out.dtype in [
1004
+ torch.float16,
1005
+ torch.bfloat16,
1006
+ torch.float32,
1007
+ ], "Residual must be float16, bfloat16, or float32"
1008
+ if dresidual is not None:
1009
+ assert dresidual.shape == x.shape
1010
+ assert dresidual.is_cuda
1011
+ assert dresidual.dtype in [
1012
+ torch.float16,
1013
+ torch.bfloat16,
1014
+ torch.float32,
1015
+ ], "Residual must be float16, bfloat16, or float32"
1016
+
1017
+ N = x.size(1)
719
1018
  device = x.device
720
-
721
- # This should be tuned on how many CTAs can be launched on each SM
722
- sm_count_multiple = (
723
- 16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
724
- )
725
- sm_count = torch.cuda.get_device_properties(device).multi_processor_count
726
- # By right, if we're using cluster, this should be cluster_count not sm_count.
727
- # But for cluster >= 4, due to quantization we would need to query active max cluster.
728
- # Instead we just do sm_count * 2, which is reasonably larger than active_cluster_count to
729
- # avoid wave quantization.
730
- sm_count = (
731
- sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
732
- )
733
-
734
- # Always store partial gradients in fp32 for numerical accuracy
735
- dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
736
-
737
- dtype = torch2cute_dtype_map[x.dtype]
738
-
1019
+ sm_count = dw_partial.shape[0]
739
1020
  convert_from_dlpack = lambda x: (
740
1021
  from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
741
1022
  )
742
- x_tensor, dout_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (x, dout, dx)]
743
-
1023
+ x_tensor, dout_tensor, dres_out_tensor, dx_tensor, dres_tensor = [
1024
+ convert_from_dlpack(t) if t is not None else None
1025
+ for t in (x, dout, dresidual_out, dx, dresidual)
1026
+ ]
744
1027
  # Handle weight div based on weight dtype
745
1028
  weight_dtype = torch2cute_dtype_map[weight.dtype]
746
1029
  weight_tensor = utils.convert_from_dlpack(
@@ -748,74 +1031,143 @@ def _rmsnorm_backward(
748
1031
  )
749
1032
 
750
1033
  dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
1034
+ db_partial_tensor = from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0) if db_partial is not None else None
751
1035
  rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
752
1036
 
753
1037
  current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
754
1038
 
755
- compile_key = (dtype, N, weight.dtype)
756
- if compile_key not in _rmsnorm_backward.compile_cache:
757
- rmsnorm_backward_op = RMSNormBackward(dtype, N)
758
- _rmsnorm_backward.compile_cache[compile_key] = cute.compile(
1039
+ compile_key = (N, x_tensor.element_type, weight_tensor.element_type, db_partial.dtype if db_partial is not None else None,
1040
+ dresidual.dtype if dresidual is not None else None,
1041
+ dresidual_out.dtype if dresidual_out is not None else None)
1042
+ if compile_key not in _rmsnorm_bwd.compile_cache:
1043
+ rmsnorm_backward_op = RMSNormBackward(x_tensor.element_type, N)
1044
+ _rmsnorm_bwd.compile_cache[compile_key] = cute.compile(
759
1045
  rmsnorm_backward_op,
760
1046
  x_tensor,
761
1047
  weight_tensor,
762
1048
  dout_tensor,
1049
+ dres_out_tensor,
763
1050
  rstd_tensor,
764
1051
  dx_tensor,
765
1052
  dw_partial_tensor,
1053
+ dres_tensor,
1054
+ db_partial_tensor,
766
1055
  sm_count,
767
1056
  current_stream,
768
1057
  )
769
1058
 
770
- _rmsnorm_backward.compile_cache[compile_key](
1059
+ _rmsnorm_bwd.compile_cache[compile_key](
771
1060
  x_tensor,
772
1061
  weight_tensor,
773
1062
  dout_tensor,
1063
+ dres_out_tensor,
774
1064
  rstd_tensor,
775
1065
  dx_tensor,
776
1066
  dw_partial_tensor,
1067
+ dres_tensor,
1068
+ db_partial_tensor,
777
1069
  sm_count,
778
1070
  current_stream,
779
1071
  )
780
- # we have summed the partial gradients in fp32, now we convert back to the weight dtype
781
- dw = dw_partial.sum(dim=0).to(weight.dtype)
782
- return dx, dw
783
1072
 
784
1073
 
785
- _rmsnorm_backward.compile_cache = {}
1074
+ _rmsnorm_bwd.compile_cache = {}
1075
+
1076
+
1077
+ def rmsnorm_bwd(
1078
+ x: Tensor,
1079
+ weight: Tensor,
1080
+ dout: Tensor,
1081
+ rstd: Tensor,
1082
+ dresidual_out: Optional[Tensor] = None, # grad wrt residual_out
1083
+ has_bias: bool = False,
1084
+ ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
1085
+ device = x.device
1086
+ N = x.size(1)
1087
+ sm_count = _get_sm_count(N, device)
1088
+ dx = torch.empty_like(x)
1089
+
1090
+ if dresidual_out is not None and dresidual_out.dtype != dx.dtype:
1091
+ dresidual = torch.empty_like(x, dtype=dresidual_out.dtype)
1092
+ else:
1093
+ dresidual = None
1094
+ # Always store partial gradients in fp32 for numerical accuracy
1095
+ dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
1096
+ db_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32) if has_bias else None
1097
+ _rmsnorm_bwd(x, weight, dout, rstd, dx, dw_partial, db_partial, dresidual_out, dresidual)
1098
+ # we have summed the partial gradients in fp32, now we convert back to the weight dtype
1099
+ dw = dw_partial.sum(dim=0).to(weight.dtype)
1100
+ db = db_partial.sum(dim=0).to(weight.dtype) if has_bias else None
1101
+ # dresidual is the same as dx in this case
1102
+ if dresidual_out is not None and dresidual_out.dtype == dx.dtype:
1103
+ dresidual = dx
1104
+ return dx, dw, db, dresidual
786
1105
 
787
1106
 
788
1107
  class RMSNormFunction(torch.autograd.Function):
789
1108
  @staticmethod
790
- def forward(ctx, x, weight, eps):
791
- x_shape_start = x.shape
792
-
1109
+ def forward(ctx, x, weight, bias=None, residual=None, out_dtype=None, residual_dtype=None, eps=1e-6, prenorm=False):
1110
+ x_shape_og = x.shape
793
1111
  # Flatten input
794
- x = x.view(-1, x.shape[-1])
795
-
796
- out, rstd = _rmsnorm_fwd(x, weight, eps, return_rstd=True)
797
- ctx.save_for_backward(x, weight, rstd)
1112
+ x = x.reshape(-1, x.shape[-1])
1113
+ if residual is not None:
1114
+ residual = residual.reshape(-1, residual.shape[-1])
1115
+ need_grad = any(ctx.needs_input_grad[:3])
1116
+ out, residual_out, rstd = rmsnorm_fwd(
1117
+ x,
1118
+ weight,
1119
+ bias=bias,
1120
+ residual=residual,
1121
+ out_dtype=out_dtype,
1122
+ residual_dtype=residual_dtype,
1123
+ eps=eps,
1124
+ store_rstd=need_grad,
1125
+ )
1126
+ ctx.save_for_backward(x if residual is None else residual_out, weight, rstd)
1127
+ ctx.has_bias = bias is not None
798
1128
  ctx.eps = eps
799
- ctx.x_shape_start = x_shape_start
800
-
801
- return out.reshape(x_shape_start)
1129
+ ctx.x_shape_og = x_shape_og
1130
+ ctx.residual_dtype = residual.dtype if residual is not None else None
1131
+ ctx.prenorm = prenorm
1132
+ if residual_out is None or not prenorm:
1133
+ return out.reshape(x_shape_og)
1134
+ else:
1135
+ return out.reshape(x_shape_og), residual_out.reshape(x_shape_og)
802
1136
 
803
1137
  @staticmethod
804
- def backward(ctx, dout):
1138
+ def backward(ctx, dout, *args):
805
1139
  x, weight, rstd = ctx.saved_tensors
806
- x_shape_start = ctx.x_shape_start
1140
+ has_bias = ctx.has_bias
1141
+ if ctx.prenorm and ctx.residual_dtype is not None:
1142
+ dresidual_out = args[0]
1143
+ dresidual_out = dresidual_out.reshape(-1, dresidual_out.shape[-1])
1144
+ else:
1145
+ dresidual_out = None
1146
+ x_shape_og = ctx.x_shape_og
807
1147
  # Reshape dout to match the flattened shape used in forward
808
1148
  dout = dout.view(-1, dout.shape[-1])
809
- dx, dw = _rmsnorm_backward(x, weight, dout, rstd)
810
- dx = dx.view(x_shape_start)
811
- # dx is returned for input gradient,
812
- # dw is returned for weight gradient,
813
- # None for eps gradient
814
- return dx, dw, None
815
1149
 
1150
+ dx, dw, db, dresidual = rmsnorm_bwd(x, weight, dout, rstd, dresidual_out, has_bias)
1151
+ dx = dx.view(x_shape_og)
1152
+ if dresidual_out is not None:
1153
+ dresidual_out = dresidual_out.reshape(x_shape_og)
1154
+ if dresidual is not None:
1155
+ dresidual = dresidual.reshape(x_shape_og)
1156
+
1157
+ return dx, dw, db, dresidual, *([None] * 4)
816
1158
 
817
- def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
818
- """RMSNorm forward pass with automatic differentiation support.
1159
+
1160
+ def rmsnorm(
1161
+ x: Tensor,
1162
+ weight: Tensor,
1163
+ bias: Optional[Tensor] = None,
1164
+ residual: Optional[Tensor] = None,
1165
+ out_dtype: Optional[torch.dtype] = None,
1166
+ residual_dtype: Optional[torch.dtype] = None,
1167
+ eps: float = 1e-6,
1168
+ prenorm: bool = False,
1169
+ ) -> Tensor:
1170
+ """RMSNorm with automatic differentiation support.
819
1171
 
820
1172
  Args:
821
1173
  x: Input tensor of shape (M, N)
@@ -825,7 +1177,7 @@ def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.T
825
1177
  Returns:
826
1178
  Normalized output tensor of same shape as x
827
1179
  """
828
- return RMSNormFunction.apply(x, weight, eps)
1180
+ return RMSNormFunction.apply(x, weight, bias, residual, out_dtype, residual_dtype, eps, prenorm)
829
1181
 
830
1182
 
831
1183
  class QuackRMSNorm(torch.nn.Module):
@@ -848,17 +1200,17 @@ class QuackRMSNorm(torch.nn.Module):
848
1200
  self.weight = torch.nn.Parameter(torch.ones(dim))
849
1201
  self.eps = eps
850
1202
 
851
- def forward(self, x: torch.Tensor) -> torch.Tensor:
1203
+ def forward(self, x: Tensor) -> Tensor:
852
1204
  """Apply RMSNorm to the input tensor.
853
1205
 
854
1206
  Args:
855
- x (torch.Tensor): Input tensor
1207
+ x (Tensor): Input tensor
856
1208
 
857
1209
  Returns:
858
- torch.Tensor: Normalized tensor
1210
+ Tensor: Normalized tensor
859
1211
  """
860
- return rmsnorm(x, self.weight, self.eps)
1212
+ return rmsnorm(x, self.weight, eps=self.eps)
861
1213
 
862
1214
  def reset_parameters(self):
863
1215
  """Reset the weight parameter to ones."""
864
- torch.nn.init.ones_(self.weight)
1216
+ torch.nn.init.ones_(self.weight)
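
Finally, a usage sketch of the autograd entry points; the shapes, dtypes, and CUDA device are assumptions for the example:

```python
import torch
from quack.rmsnorm import rmsnorm, QuackRMSNorm

x = torch.randn(8, 2048, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
residual = torch.randn(8, 2048, 4096, device="cuda", dtype=torch.float32, requires_grad=True)
weight = torch.randn(4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# With prenorm=True and a residual, forward returns (out, residual_out) so the
# updated residual stream can be fed to the next block; backward yields gradients
# for x, weight, and the residual.
out, residual_out = rmsnorm(x, weight, residual=residual, prenorm=True)
(out.float().mean() + residual_out.float().mean()).backward()

# Module wrapper around rmsnorm; weight is initialized to ones.
norm = QuackRMSNorm(4096).cuda()
y = norm(torch.randn(8, 2048, 4096, device="cuda", dtype=torch.bfloat16))
```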