quack-kernels 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
quack/rmsnorm.py CHANGED
@@ -19,6 +19,7 @@ from quack.reduce import row_reduce
  from quack.reduction_base import ReductionBase
  from quack.cute_dsl_utils import torch2cute_dtype_map

+
  class RMSNorm(ReductionBase):
  def __init__(self, dtype: cutlass.Numeric, N: int):
  super().__init__(dtype, N, stage=1)
@@ -93,7 +94,7 @@ class RMSNorm(ReductionBase):
  def __call__(
  self,
  mX: cute.Tensor,
- mW: cute.Tensor,
+ mW: Optional[cute.Tensor],
  mB: Optional[cute.Tensor],
  mRes: Optional[cute.Tensor],
  mO: cute.Tensor,
@@ -129,10 +130,15 @@ class RMSNorm(ReductionBase):
  )
  num_threads = cute.size(tv_layout, mode=[0])
  num_warps = num_threads // cute.arch.WARP_SIZE
- mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
- mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+ if const_expr(mW is not None):
+ mW_expanded_layout = cute.prepend(
+ mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
+ )
+ mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
  if const_expr(mB is not None):
- mB_expanded_layout = cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+ mB_expanded_layout = cute.prepend(
+ mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
+ )
  mB = cute.make_tensor(mB.iterator, mB_expanded_layout)
  if const_expr(mRstd is not None):
  mRstd_expanded_layout = cute.append(
@@ -155,7 +161,7 @@ class RMSNorm(ReductionBase):
  def kernel(
  self,
  mX: cute.Tensor,
- mW: cute.Tensor,
+ mW: Optional[cute.Tensor],
  mB: Optional[cute.Tensor],
  mRes: Optional[cute.Tensor],
  mO: cute.Tensor,
@@ -201,12 +207,10 @@ class RMSNorm(ReductionBase):
  for mT in (mX, mRes, mO, mResO)
  ]
  cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
- gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
- gB = (
- cute.local_tile(mB, tiler_mn, (0, cluster_y))
- if const_expr(mB is not None)
- else None
- )
+ gW, gB = [
+ cute.local_tile(mT, tiler_mn, (0, cluster_y)) if const_expr(mT is not None) else None
+ for mT in (mW, mB)
+ ]
  gRstd = (
  cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
  if const_expr(mRstd is not None)
@@ -215,47 +219,14 @@ class RMSNorm(ReductionBase):

  # declare the atoms which will be used later for memory copy
  num_copy_elems_X = tv_layout.shape[1][0]
- num_copy_bits_X = mX.element_type.width * num_copy_elems_X
- copy_atom_load_X = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
+ copy_atom_load_X_async = utils.get_copy_atom(
+ mX.element_type, num_copy_elems_X, is_async=True
  )
- copy_atom_load_X_async = cute.make_copy_atom(
- cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
- )
- num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
- copy_atom_load_W = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
- )
- num_bits_per_copy_B = cutlass.const_expr(
- min(128, num_copy_elems_X * mB.element_type.width)
- ) if const_expr(mB is not None) else 0
- copy_atom_load_B = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mB.element_type, num_bits_per_copy=num_bits_per_copy_B
- ) if const_expr(mB is not None) else None
- if const_expr(mRes is not None):
- num_copy_bits_Res = const_expr(min(128, num_copy_elems_X * mRes.element_type.width))
- copy_atom_load_Res_async = cute.make_copy_atom(
- cute.nvgpu.cpasync.CopyG2SOp(),
- mRes.element_type,
- num_bits_per_copy=num_copy_bits_Res,
- )
- num_copy_bits_O = const_expr(min(128, num_copy_elems_X * mO.element_type.width))
- copy_atom_store_O = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=num_copy_bits_O
- )
- if const_expr(mResO is not None):
- num_copy_bits_ResO = const_expr(min(128, num_copy_elems_X * mResO.element_type.width))
- copy_atom_store_ResO = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(),
- mResO.element_type,
- num_bits_per_copy=num_copy_bits_ResO,
- )
-
  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
  tidx
  )

- tXgW = thr_copy_X.partition_S(gW)
+ tXgW = thr_copy_X.partition_S(gW) if const_expr(mW is not None) else None
  tXgB = thr_copy_X.partition_S(gB) if const_expr(mB is not None) else None
  tXgX = thr_copy_X.partition_S(gX)
  tXsX = thr_copy_X.partition_D(sX)
@@ -269,8 +240,9 @@ class RMSNorm(ReductionBase):
  tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]

  # allocate fragments for gmem->rmem
- tXrW = cute.make_fragment_like(tXgW)
- tXrW.fill(0.0)
+ tXrW = cute.make_fragment_like(tXgW) if const_expr(mW is not None) else None
+ if const_expr(mW is not None):
+ tXrW.fill(0.0)
  tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None
  tXrX, tXrO = [cute.make_fragment_like(t) for t in (tXgX, tXgO)]
  if const_expr(mRes is not None):
@@ -283,17 +255,21 @@ class RMSNorm(ReductionBase):
  tXpX = (
  utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
  )
+ # Each copy will use the same number of elements as X and same predicate
+ copy = partial(utils.copy, pred=tXpX, num_copy_elems=num_copy_elems_X)
+
  row = tXcX[0][0]
  if row < shape[0]:
- cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+ copy(tXgX, tXsX, is_async=True)
  if const_expr(mRes is not None):
- cute.copy(copy_atom_load_Res_async, tXgRes, tXsRes, pred=tXpX)
+ copy(tXgRes, tXsRes, is_async=True)
  cute.arch.cp_async_commit_group()

  if const_expr(not delay_w_load):
- cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpX)
+ if const_expr(mW is not None):
+ copy(tXgW, tXrW)
  if const_expr(mB is not None):
- cute.copy(copy_atom_load_B, tXgB, tXrB, pred=tXpX)
+ copy(tXgB, tXrB)

  cute.arch.cp_async_wait_group(0)
  cute.autovec_copy(tXsX, tXrX)
@@ -305,7 +281,7 @@ class RMSNorm(ReductionBase):
  tXrResO = cute.make_fragment_like(tXgResO)
  tXrResO.store(x.to(tXrResO.element_type))
  if row < shape[0]:
- cute.copy(copy_atom_store_ResO, tXrResO, tXgResO, pred=tXpX)
+ copy(tXrResO, tXgResO)

  threads_per_row = tv_layout.shape[0][0]
  sum_sq_x = row_reduce(
@@ -317,7 +293,7 @@ class RMSNorm(ReductionBase):
  init_val=0.0,
  hook_fn=(cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None),
  )
- rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
+ rstd = cute.math.rsqrt(sum_sq_x / shape[1] + eps, fastmath=True)
  if const_expr(mRstd is not None):
  # Only the thread corresponding to column 0 writes out the rstd to gmem
  if (
@@ -327,27 +303,28 @@ class RMSNorm(ReductionBase):
  ):
  tXrRstd[0] = rstd
  if const_expr(delay_w_load):
- cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpX)
+ if const_expr(mW is not None):
+ copy(tXgW, tXrW)
  if const_expr(mB is not None):
- cute.copy(copy_atom_load_B, tXgB, tXrB, pred=tXpX)
+ copy(tXgB, tXrB)
  if const_expr(reload_from == "smem" or reload_from == "gmem"):
  if const_expr(reload_from == "smem"):
  cute.autovec_copy(tXsX, tXrX)
  else:
- cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+ copy(tXgX, tXrX)
  x = tXrX.load().to(cute.Float32)
  if const_expr(mRes is not None):
  cute.autovec_copy(tXsRes, tXrRes)
  x += tXrRes.load().to(cute.Float32)
  x_hat = x * rstd
- w = tXrW.load().to(cute.Float32)
- y = x_hat * w
+ y = x_hat
+ if const_expr(mW is not None):
+ y *= tXrW.load().to(cute.Float32)
  if const_expr(mB is not None):
- b = tXrB.load().to(cute.Float32)
- y = y + b
+ y += tXrB.load().to(cute.Float32)
  tXrO.store(y.to(tXrO.element_type))
  if row < shape[0]:
- cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tXpX)
+ copy(tXrO, tXgO)


  @torch.library.custom_op(
@@ -355,11 +332,11 @@ class RMSNorm(ReductionBase):
  mutates_args=("out", "rstd", "residual_out"),
  device_types="cuda",
  # We need to specify the schema manually since we're mutating an optional tensor
- schema="(Tensor x, Tensor weight, Tensor(a!) out, Tensor? bias, Tensor(a!)? rstd, Tensor? residual, Tensor(a!)? residual_out, float eps=1e-6) -> ()",
+ schema="(Tensor x, Tensor? weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor? residual, Tensor(a6!)? residual_out, float eps=1e-6) -> ()",
  )
  def _rmsnorm_fwd(
  x: Tensor,
- weight: Tensor,
+ weight: Optional[Tensor],
  out: Tensor,
  bias: Optional[Tensor] = None,
  rstd: Optional[Tensor] = None,
@@ -370,21 +347,23 @@ def _rmsnorm_fwd(
  """RMSNorm forward pass.
  Args:
  x: Input tensor of shape (M, N)
- weight: Weight tensor of shape (N,)
+ weight: Optional weight tensor of shape (N,)
  eps: Small value for numerical stability
  Returns:
  Normalized output tensor of same shape as x
  """
  assert x.dim() == 2, "Input must be 2D"
- assert weight.dim() == 1, "Weight must be 1D"
- assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
- assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
+ assert x.is_cuda, "Input tensor must be on CUDA device"
  assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
- assert weight.dtype in [
- torch.float32,
- torch.bfloat16,
- torch.float16,
- ], "Weight must be float32, float16 or bfloat16"
+ if weight is not None:
+ assert weight.dim() == 1, "Weight must be 1D"
+ assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
+ assert weight.is_cuda, "Weight tensor must be on CUDA device"
+ assert weight.dtype in [
+ torch.float32,
+ torch.bfloat16,
+ torch.float16,
+ ], "Weight must be float32, float16 or bfloat16"
  if residual is not None:
  assert residual.shape == x.shape
  assert residual.is_cuda
@@ -397,11 +376,6 @@ def _rmsnorm_fwd(
  _, N = x.shape
  device = x.device
  dtype = torch2cute_dtype_map[x.dtype]
- # convert_from_dlpack = lambda x: (
- # from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
- # mode=0, divisibility=128 // dtype.width
- # )
- # )
  convert_from_dlpack = lambda x: (
  from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
  )
@@ -409,10 +383,13 @@ def _rmsnorm_fwd(
  convert_from_dlpack(t) if t is not None else None for t in (x, residual, out, residual_out)
  ]
  # handle weight divisibility based on weight dtype
- weight_dtype = torch2cute_dtype_map[weight.dtype]
- weight_tensor = utils.convert_from_dlpack(
- weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
- )
+ if weight is not None:
+ weight_dtype = torch2cute_dtype_map[weight.dtype]
+ weight_tensor = utils.convert_from_dlpack(
+ weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
+ )
+ else:
+ weight_tensor = None
  if bias is not None:
  bias_dtype = torch2cute_dtype_map[bias.dtype]
  bias_tensor = utils.convert_from_dlpack(
@@ -430,7 +407,7 @@ def _rmsnorm_fwd(
  N,
  dtype,
  res_tensor.element_type if residual is not None else None,
- weight_tensor.element_type,
+ weight_tensor.element_type if weight is not None else None,
  bias_tensor.element_type if bias is not None else None,
  res_out_tensor.element_type if residual_out is not None else None,
  rstd is not None,
@@ -467,7 +444,7 @@ _rmsnorm_fwd.compile_cache = {}

  def rmsnorm_fwd(
  x: Tensor,
- weight: Tensor,
+ weight: Optional[Tensor] = None,
  bias: Optional[Tensor] = None,
  residual: Optional[Tensor] = None,
  out_dtype: Optional[torch.dtype] = None,
@@ -496,12 +473,13 @@ def rmsnorm_fwd(
  return out, residual_out, rstd


- def rmsnorm_ref(x, w, bias=None, residual=None, eps=1e-6):
+ def rmsnorm_ref(x, w=None, bias=None, residual=None, eps=1e-6):
  x_f32 = x.float()
  if residual is not None:
  residual_f32 = residual.float()
  x_f32 += residual_f32
- out = x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) * w
+ x_norm = x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps))
+ out = x_norm * w if w is not None else x_norm
  if bias is not None:
  out = out + bias.float()
  if residual is None:
@@ -509,6 +487,7 @@ def rmsnorm_ref(x, w, bias=None, residual=None, eps=1e-6):
  else:
  return out.to(x.dtype), x_f32.to(residual.dtype)

+
  def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
  """Reference implementation for RMSNorm backward pass."""
  x_f32 = x.float()
@@ -521,6 +500,7 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
  dw = (dout * x_hat).sum(dim=0)
  return dx.to(x.dtype), dw.to(w.dtype)

+
  class RMSNormBackward(ReductionBase):
  def __init__(self, dtype: cutlass.Numeric, N: int):
  # 2 stages for double buffering when computing mean of x_hat * wdy
@@ -606,8 +586,11 @@ class RMSNormBackward(ReductionBase):
  )
  num_threads = cute.size(tv_layout, mode=[0])
  num_warps = num_threads // cute.arch.WARP_SIZE
- mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
- mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+ if const_expr(mW is not None):
+ mW_expanded_layout = cute.prepend(
+ mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
+ )
+ mW = cute.make_tensor(mW.iterator, mW_expanded_layout)

  num_blocks = sm_count
  self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn).launch(
@@ -660,50 +643,10 @@ class RMSNormBackward(ReductionBase):
  mbar_full_ptr, mbar_empty_ptr = None, None

  num_copy_elems_X = tv_layout.shape[1][0]
- num_copy_bits_X = mX.element_type.width * num_copy_elems_X
- copy_atom_load_X = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
- )
- copy_atom_load_X_async = cute.make_copy_atom(
- cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
- )
- num_copy_bits_dO = const_expr(min(128, num_copy_elems_X * mdO.element_type.width))
- copy_atom_load_dO_async = cute.make_copy_atom(
- cute.nvgpu.cpasync.CopyG2SOp(), mdO.element_type, num_bits_per_copy=num_copy_bits_dO
- )
- num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
- copy_atom_load_W = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
- )
- if const_expr(mdResO is not None):
- num_copy_bits_dResO = const_expr(min(128, num_copy_elems_X * mdResO.element_type.width))
- copy_atom_load_dResO = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(),
- mdResO.element_type,
- num_bits_per_copy=num_copy_bits_dResO,
- )
- num_copy_bits_dX = const_expr(min(128, num_copy_elems_X * mdX.element_type.width))
- copy_atom_store_dX = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_copy_bits_dX
- )
- num_copy_bits_dW = const_expr(min(128, num_copy_elems_X * mdW.element_type.width))
- copy_atom_store_dW = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=num_copy_bits_dW
- )
- if const_expr(mdB is not None):
- num_copy_bits_dB = const_expr(min(128, num_copy_elems_X * mdB.element_type.width))
- copy_atom_store_dB = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mdB.element_type, num_bits_per_copy=num_copy_bits_dB
- )
- if const_expr(mdRes is not None):
- num_copy_bits_dRes = const_expr(min(128, num_copy_elems_X * mdRes.element_type.width))
- copy_atom_load_dRes = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(),
- mdRes.element_type,
- num_bits_per_copy=num_copy_bits_dRes,
- )
-
+ copy_atom_load_X = utils.get_copy_atom(mX.element_type, num_copy_elems_X, is_async=False)
  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+ # Each copy will use the same number of elements as X
+ copy = partial(utils.copy, num_copy_elems=num_copy_elems_X)

  gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
  tXgW = thr_copy_X.partition_S(gW)
@@ -718,7 +661,7 @@ class RMSNormBackward(ReductionBase):
  if not is_even_N
  else None
  )
- cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
+ copy(tXgW, tXrW, pred=tXpW)
  weight = tXrW.load().to(cute.Float32)

  num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
@@ -744,7 +687,11 @@ class RMSNormBackward(ReductionBase):
  # Always compute partial weight gradients in fp32
  tXrdW = cute.make_fragment_like(tXgdW, Float32)

- gdB = cute.local_tile(mdB, (1, tiler_mn[1]), (bidx_start, cluster_y)) if const_expr(mdB is not None) else None
+ gdB = (
+ cute.local_tile(mdB, (1, tiler_mn[1]), (bidx_start, cluster_y))
+ if const_expr(mdB is not None)
+ else None
+ )
  tXgdB = thr_copy_X.partition_S(gdB) if const_expr(mdB is not None) else None
  tXrdB = cute.make_fragment_like(tXgdB, Float32) if const_expr(mdB is not None) else None

@@ -772,21 +719,20 @@ class RMSNormBackward(ReductionBase):
  tXrX, tXrdO, tXrdX = [
  cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdO, tXgdX)
  ]
+ tXrdResO = None
  if const_expr(mdResO is not None):
  tXrdResO = cute.make_fragment_like(tXgdResO[None, None, None, 0])
+ tXrdRes = None
  if const_expr(mdRes is not None):
  tXrdRes = cute.make_fragment_like(tXgdRes[None, None, None, 0])

- copy_X = partial(cute.copy, copy_atom_load_X_async, pred=tXpX)
- copy_dO = partial(cute.copy, copy_atom_load_dO_async, pred=tXpX)
-
  # Prefetch the first batch
  row = tXcX[None, None, None, bidx_start][0][0]
  if row < M:
  tXgX_cur = utils.coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0]
  tXgdO_cur = utils.coord_offset_i64(bidx_start, tXgdO, dim=3)[None, None, None, 0]
- copy_X(tXgX_cur, tXsX[None, None, None, 0])
- copy_dO(tXgdO_cur, tXsdO[None, None, None, 0])
+ copy(tXgX_cur, tXsX[None, None, None, 0], pred=tXpX, is_async=True)
+ copy(tXgdO_cur, tXsdO[None, None, None, 0], pred=tXpX, is_async=True)
  elif tiler_mn[0] > 1:
  # Fill with zero, otherwise smem will be uninitialized, and we could read this back
  # later into registers, causing wrong dW.
@@ -809,8 +755,8 @@ class RMSNormBackward(ReductionBase):
  if row + gdim * tiler_mn[0] < M: # Prefetch the next batch
  tXgX_cur = utils.coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0]
  tXgdO_cur = utils.coord_offset_i64(bidx + gdim, tXgdO, dim=3)[None, None, None, 0]
- copy_X(tXgX_cur, tXsX[None, None, None, stage ^ 1])
- copy_dO(tXgdO_cur, tXsdO[None, None, None, stage ^ 1])
+ copy(tXgX_cur, tXsX[None, None, None, stage ^ 1], pred=tXpX, is_async=True)
+ copy(tXgdO_cur, tXsdO[None, None, None, stage ^ 1], pred=tXpX, is_async=True)
  elif tiler_mn[0] > 1:
  utils.fill_oob(
  tXsX[None, None, None, stage ^ 1],
@@ -829,7 +775,7 @@ class RMSNormBackward(ReductionBase):
  if const_expr(mdResO is not None):
  tXgdResO_cur = utils.coord_offset_i64(bidx, tXgdResO, dim=3)[None, None, None, 0]
  if row < M or tiler_mn[0] == 1:
- cute.copy(copy_atom_load_dResO, tXgdResO_cur, tXrdResO, pred=tXpX)
+ copy(tXgdResO_cur, tXrdResO, pred=tXpX)
  elif tiler_mn[0] > 1:
  tXrdResO.fill(0.0)
  cute.arch.cp_async_wait_group(1)
@@ -877,12 +823,12 @@ class RMSNormBackward(ReductionBase):
  tXrdX.store(dx.to(tXrdX.element_type))
  if row < M or tiler_mn[0] == 1:
  tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
- cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
+ copy(tXrdX, tXgdX_cur, pred=tXpX)
  if const_expr(mdRes is not None):
  tXrdRes.store(dx.to(tXrdRes.element_type))
  tXgdRes_cur = utils.coord_offset_i64(bidx, tXgdRes, dim=3)[None, None, None, 0]
  if row < M or tiler_mn[0] == 1:
- cute.copy(copy_atom_load_dRes, tXrdRes, tXgdRes_cur, pred=tXpX)
+ copy(tXrdRes, tXgdRes_cur, pred=tXpX)
  # Accumulate weight gradients in fp32
  tXrdW.store(tXrdW.load() + dout * x_hat)
  if const_expr(mdB is not None):
@@ -914,7 +860,7 @@ class RMSNormBackward(ReductionBase):
  tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
  cute.autovec_copy(tXsdW_other, tXrdW_other)
  tXrdW.store(tXrdW.load() + tXrdW_other.load())
- cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
+ copy(tXrdW, tXgdW, pred=tXpdW)
  cute.arch.barrier()
  if const_expr(mdB is not None):
  sdB = cute.make_tensor(
@@ -930,15 +876,17 @@ class RMSNormBackward(ReductionBase):
  if row == 0:
  for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
  tXrdB_other = cute.make_fragment_like(tXrdB)
- tXsdB_other = cute.make_tensor(tXsdB.iterator + i * sdB.stride[0], tXsdB.layout)
+ tXsdB_other = cute.make_tensor(
+ tXsdB.iterator + i * sdB.stride[0], tXsdB.layout
+ )
  cute.autovec_copy(tXsdB_other, tXrdB_other)
  tXrdB.store(tXrdB.load() + tXrdB_other.load())
- cute.copy(copy_atom_store_dB, tXrdB, tXgdB, pred=tXpdB)
+ copy(tXrdB, tXgdB, pred=tXpdB)
  else:
  # dw is already in fp32, so we can directly copy to global memory
- cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
+ copy(tXrdW, tXgdW, pred=tXpdW)
  if const_expr(mdB is not None):
- cute.copy(copy_atom_store_dB, tXrdB, tXgdB, pred=tXpdB)
+ copy(tXrdB, tXgdB, pred=tXpdB)


  def _get_sm_count(N: int, device: torch.device) -> int:
@@ -963,7 +911,7 @@ def _get_sm_count(N: int, device: torch.device) -> int:
  mutates_args={"dx", "dw_partial", "db_partial", "dresidual"},
  device_types="cuda",
  # We need to specify the schema manually since we're mutating an optional tensor
- schema="(Tensor x, Tensor weight, Tensor dout, Tensor rstd, Tensor(a!) dx, Tensor(a!) dw_partial, Tensor(a!)? db_partial, Tensor? dresidual_out, Tensor(a!)? dresidual) -> ()",
+ schema="(Tensor x, Tensor weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!) dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual) -> ()",
  )
  def _rmsnorm_bwd(
  x: Tensor,
@@ -1031,14 +979,23 @@ def _rmsnorm_bwd(
  )

  dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
- db_partial_tensor = from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0) if db_partial is not None else None
+ db_partial_tensor = (
+ from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
+ if db_partial is not None
+ else None
+ )
  rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)

  current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

- compile_key = (N, x_tensor.element_type, weight_tensor.element_type, db_partial.dtype if db_partial is not None else None,
+ compile_key = (
+ N,
+ x_tensor.element_type,
+ weight_tensor.element_type,
+ db_partial.dtype if db_partial is not None else None,
  dresidual.dtype if dresidual is not None else None,
- dresidual_out.dtype if dresidual_out is not None else None)
+ dresidual_out.dtype if dresidual_out is not None else None,
+ )
  if compile_key not in _rmsnorm_bwd.compile_cache:
  rmsnorm_backward_op = RMSNormBackward(x_tensor.element_type, N)
  _rmsnorm_bwd.compile_cache[compile_key] = cute.compile(
@@ -1106,7 +1063,17 @@ def rmsnorm_bwd(

  class RMSNormFunction(torch.autograd.Function):
  @staticmethod
- def forward(ctx, x, weight, bias=None, residual=None, out_dtype=None, residual_dtype=None, eps=1e-6, prenorm=False):
+ def forward(
+ ctx,
+ x,
+ weight,
+ bias=None,
+ residual=None,
+ out_dtype=None,
+ residual_dtype=None,
+ eps=1e-6,
+ prenorm=False,
+ ):
  x_shape_og = x.shape
  # Flatten input
  x = x.reshape(-1, x.shape[-1])
@@ -1129,7 +1096,7 @@ class RMSNormFunction(torch.autograd.Function):
  ctx.x_shape_og = x_shape_og
  ctx.residual_dtype = residual.dtype if residual is not None else None
  ctx.prenorm = prenorm
- if residual_out is None or prenorm == False:
+ if residual_out is None or not prenorm:
  return out.reshape(x_shape_og)
  else:
  return out.reshape(x_shape_og), residual_out.reshape(x_shape_og)
@@ -1137,6 +1104,7 @@ class RMSNormFunction(torch.autograd.Function):
  @staticmethod
  def backward(ctx, dout, *args):
  x, weight, rstd = ctx.saved_tensors
+ assert weight is not None, "RMSNorm backward doesn't support weight=None yet"
  has_bias = ctx.has_bias
  if ctx.prenorm and ctx.residual_dtype is not None:
  dresidual_out = args[0]
@@ -1159,7 +1127,7 @@ class RMSNormFunction(torch.autograd.Function):

  def rmsnorm(
  x: Tensor,
- weight: Tensor,
+ weight: Optional[Tensor] = None,
  bias: Optional[Tensor] = None,
  residual: Optional[Tensor] = None,
  out_dtype: Optional[torch.dtype] = None,
@@ -1171,7 +1139,7 @@ def rmsnorm(

  Args:
  x: Input tensor of shape (M, N)
- weight: Weight tensor of shape (N,)
+ weight: Optional weight tensor of shape (N,)
  eps: Small value for numerical stability

  Returns:
@@ -1213,4 +1181,4 @@ class QuackRMSNorm(torch.nn.Module):

  def reset_parameters(self):
  """Reset the weight parameter to ones."""
- torch.nn.init.ones_(self.weight)
+ torch.nn.init.ones_(self.weight)
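Taken together, the rmsnorm.py changes make `weight` optional across `rmsnorm`, `rmsnorm_fwd`, and the underlying forward kernels, while the autograd backward still asserts that a weight is present. A minimal forward-only usage sketch based on the signatures shown above (the import path and tensor shapes here are assumptions for illustration, not part of the diff):

import torch
from quack.rmsnorm import rmsnorm  # assumed import path

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.ones(4096, device="cuda", dtype=torch.bfloat16)

y_weighted = rmsnorm(x, weight=w)  # weighted RMSNorm, as in 0.2.0
y_plain = rmsnorm(x)               # weight=None now runs the pure normalization path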
quack/softmax.py CHANGED
@@ -159,7 +159,7 @@ class Softmax(ReductionBase):
  hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
  )
  log2_e = math.log2(math.e)
- exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
+ exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
  denom = row_reduce(
  exp_x,
  cute.ReductionOp.ADD,
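The softmax.py change only reorders the shifted exponent: (x - max_x) * log2(e) and x * log2(e) - max_x * log2(e) are the same quantity algebraically, so exp2 of either reproduces exp(x - max_x). A small host-side sketch of that identity in plain Python (not the CuTe DSL, and only a check of the math, not of the kernel):

import math

log2_e = math.log2(math.e)
x, max_x = 3.7, 5.2
shifted = math.exp(x - max_x)
assert math.isclose(2 ** ((x - max_x) * log2_e), shifted)
assert math.isclose(2 ** (x * log2_e - max_x * log2_e), shifted)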