quack-kernels 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/rmsnorm.py CHANGED
@@ -94,7 +94,7 @@ class RMSNorm(ReductionBase):
  def __call__(
  self,
  mX: cute.Tensor,
- mW: cute.Tensor,
+ mW: Optional[cute.Tensor],
  mB: Optional[cute.Tensor],
  mRes: Optional[cute.Tensor],
  mO: cute.Tensor,
@@ -130,8 +130,11 @@ class RMSNorm(ReductionBase):
  )
  num_threads = cute.size(tv_layout, mode=[0])
  num_warps = num_threads // cute.arch.WARP_SIZE
- mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
- mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+ if const_expr(mW is not None):
+ mW_expanded_layout = cute.prepend(
+ mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
+ )
+ mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
  if const_expr(mB is not None):
  mB_expanded_layout = cute.prepend(
  mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
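
The weight broadcast above works by prepending a size-tiler_mn[0] mode with stride 0 to the (N,)-shaped weight layout, so every row of the tile reads the same N weight elements; in 0.2.2 this expansion is simply skipped when no weight is passed. A rough NumPy analogy of the stride-0 trick (illustration only, not the CuTe DSL):

import numpy as np

# Hypothetical sizes for illustration.
tile_rows, N = 4, 8
w = np.arange(N, dtype=np.float32)  # the (N,) weight vector

# Prepending a mode of size tile_rows with stride 0 is the same idea as a
# zero-stride broadcast: every row aliases the same underlying weight storage.
w_tile = np.lib.stride_tricks.as_strided(
    w, shape=(tile_rows, N), strides=(0, w.strides[0])
)
assert (w_tile == w[None, :]).all()
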
@@ -158,7 +161,7 @@ class RMSNorm(ReductionBase):
  def kernel(
  self,
  mX: cute.Tensor,
- mW: cute.Tensor,
+ mW: Optional[cute.Tensor],
  mB: Optional[cute.Tensor],
  mRes: Optional[cute.Tensor],
  mO: cute.Tensor,
@@ -204,8 +207,10 @@ class RMSNorm(ReductionBase):
  for mT in (mX, mRes, mO, mResO)
  ]
  cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
- gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
- gB = cute.local_tile(mB, tiler_mn, (0, cluster_y)) if const_expr(mB is not None) else None
+ gW, gB = [
+ cute.local_tile(mT, tiler_mn, (0, cluster_y)) if const_expr(mT is not None) else None
+ for mT in (mW, mB)
+ ]
  gRstd = (
  cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
  if const_expr(mRstd is not None)
@@ -214,53 +219,14 @@ class RMSNorm(ReductionBase):
 
  # declare the atoms which will be used later for memory copy
  num_copy_elems_X = tv_layout.shape[1][0]
- num_copy_bits_X = mX.element_type.width * num_copy_elems_X
- copy_atom_load_X = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
- )
- copy_atom_load_X_async = cute.make_copy_atom(
- cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
- )
- num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
- copy_atom_load_W = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
- )
- num_bits_per_copy_B = (
- cutlass.const_expr(min(128, num_copy_elems_X * mB.element_type.width))
- if const_expr(mB is not None)
- else 0
- )
- copy_atom_load_B = (
- cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mB.element_type, num_bits_per_copy=num_bits_per_copy_B
- )
- if const_expr(mB is not None)
- else None
- )
- if const_expr(mRes is not None):
- num_copy_bits_Res = const_expr(min(128, num_copy_elems_X * mRes.element_type.width))
- copy_atom_load_Res_async = cute.make_copy_atom(
- cute.nvgpu.cpasync.CopyG2SOp(),
- mRes.element_type,
- num_bits_per_copy=num_copy_bits_Res,
- )
- num_copy_bits_O = const_expr(min(128, num_copy_elems_X * mO.element_type.width))
- copy_atom_store_O = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=num_copy_bits_O
+ copy_atom_load_X_async = utils.get_copy_atom(
+ mX.element_type, num_copy_elems_X, is_async=True
  )
- if const_expr(mResO is not None):
- num_copy_bits_ResO = const_expr(min(128, num_copy_elems_X * mResO.element_type.width))
- copy_atom_store_ResO = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(),
- mResO.element_type,
- num_bits_per_copy=num_copy_bits_ResO,
- )
-
  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
  tidx
  )
 
- tXgW = thr_copy_X.partition_S(gW)
+ tXgW = thr_copy_X.partition_S(gW) if const_expr(mW is not None) else None
  tXgB = thr_copy_X.partition_S(gB) if const_expr(mB is not None) else None
  tXgX = thr_copy_X.partition_S(gX)
  tXsX = thr_copy_X.partition_D(sX)
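
The per-tensor copy-atom boilerplate removed here is replaced by helpers in quack.utils that are not part of this diff. Judging from the deleted code and the new call sites, they presumably wrap cute.make_copy_atom and cute.copy roughly as follows (a sketch of the assumed behavior, not the actual quack.utils implementation):

import cutlass.cute as cute

def get_copy_atom(dtype, num_copy_elems, *, is_async=False):
    # Cap each copy at 128 bits, as the removed per-tensor code did.
    num_bits = min(128, num_copy_elems * dtype.width)
    op = cute.nvgpu.cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
    return cute.make_copy_atom(op, dtype, num_bits_per_copy=num_bits)

def copy(src, dst, *, pred=None, num_copy_elems=1, is_async=False):
    # Build the atom from the source element type and issue the copy.
    atom = get_copy_atom(src.element_type, num_copy_elems, is_async=is_async)
    cute.copy(atom, src, dst, pred=pred)
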
@@ -274,8 +240,9 @@ class RMSNorm(ReductionBase):
  tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
 
  # allocate fragments for gmem->rmem
- tXrW = cute.make_fragment_like(tXgW)
- tXrW.fill(0.0)
+ tXrW = cute.make_fragment_like(tXgW) if const_expr(mW is not None) else None
+ if const_expr(mW is not None):
+ tXrW.fill(0.0)
  tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None
  tXrX, tXrO = [cute.make_fragment_like(t) for t in (tXgX, tXgO)]
  if const_expr(mRes is not None):
@@ -288,17 +255,21 @@ class RMSNorm(ReductionBase):
  tXpX = (
  utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
  )
+ # Each copy will use the same number of elements as X and same predicate
+ copy = partial(utils.copy, pred=tXpX, num_copy_elems=num_copy_elems_X)
+
  row = tXcX[0][0]
  if row < shape[0]:
- cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+ copy(tXgX, tXsX, is_async=True)
  if const_expr(mRes is not None):
- cute.copy(copy_atom_load_Res_async, tXgRes, tXsRes, pred=tXpX)
+ copy(tXgRes, tXsRes, is_async=True)
  cute.arch.cp_async_commit_group()
 
  if const_expr(not delay_w_load):
- cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpX)
+ if const_expr(mW is not None):
+ copy(tXgW, tXrW)
  if const_expr(mB is not None):
- cute.copy(copy_atom_load_B, tXgB, tXrB, pred=tXpX)
+ copy(tXgB, tXrB)
 
  cute.arch.cp_async_wait_group(0)
  cute.autovec_copy(tXsX, tXrX)
@@ -310,7 +281,7 @@ class RMSNorm(ReductionBase):
  tXrResO = cute.make_fragment_like(tXgResO)
  tXrResO.store(x.to(tXrResO.element_type))
  if row < shape[0]:
- cute.copy(copy_atom_store_ResO, tXrResO, tXgResO, pred=tXpX)
+ copy(tXrResO, tXgResO)
 
  threads_per_row = tv_layout.shape[0][0]
  sum_sq_x = row_reduce(
@@ -332,27 +303,28 @@ class RMSNorm(ReductionBase):
  ):
  tXrRstd[0] = rstd
  if const_expr(delay_w_load):
- cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpX)
+ if const_expr(mW is not None):
+ copy(tXgW, tXrW)
  if const_expr(mB is not None):
- cute.copy(copy_atom_load_B, tXgB, tXrB, pred=tXpX)
+ copy(tXgB, tXrB)
  if const_expr(reload_from == "smem" or reload_from == "gmem"):
  if const_expr(reload_from == "smem"):
  cute.autovec_copy(tXsX, tXrX)
  else:
- cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+ copy(tXgX, tXrX)
  x = tXrX.load().to(cute.Float32)
  if const_expr(mRes is not None):
  cute.autovec_copy(tXsRes, tXrRes)
  x += tXrRes.load().to(cute.Float32)
  x_hat = x * rstd
- w = tXrW.load().to(cute.Float32)
- y = x_hat * w
+ y = x_hat
+ if const_expr(mW is not None):
+ y *= tXrW.load().to(cute.Float32)
  if const_expr(mB is not None):
- b = tXrB.load().to(cute.Float32)
- y = y + b
+ y += tXrB.load().to(cute.Float32)
  tXrO.store(y.to(tXrO.element_type))
  if row < shape[0]:
- cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tXpX)
+ copy(tXrO, tXgO)
 
 
  @torch.library.custom_op(
@@ -360,11 +332,11 @@ class RMSNorm(ReductionBase):
  mutates_args=("out", "rstd", "residual_out"),
  device_types="cuda",
  # We need to specify the schema manually since we're mutating an optional tensor
- schema="(Tensor x, Tensor weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor? residual, Tensor(a6!)? residual_out, float eps=1e-6) -> ()",
+ schema="(Tensor x, Tensor? weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor? residual, Tensor(a6!)? residual_out, float eps=1e-6) -> ()",
  )
  def _rmsnorm_fwd(
  x: Tensor,
- weight: Tensor,
+ weight: Optional[Tensor],
  out: Tensor,
  bias: Optional[Tensor] = None,
  rstd: Optional[Tensor] = None,
@@ -375,21 +347,23 @@ def _rmsnorm_fwd(
  """RMSNorm forward pass.
  Args:
  x: Input tensor of shape (M, N)
- weight: Weight tensor of shape (N,)
+ weight: Optional weight tensor of shape (N,)
  eps: Small value for numerical stability
  Returns:
  Normalized output tensor of same shape as x
  """
  assert x.dim() == 2, "Input must be 2D"
- assert weight.dim() == 1, "Weight must be 1D"
- assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
- assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
+ assert x.is_cuda, "Input tensor must be on CUDA device"
  assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
- assert weight.dtype in [
- torch.float32,
- torch.bfloat16,
- torch.float16,
- ], "Weight must be float32, float16 or bfloat16"
+ if weight is not None:
+ assert weight.dim() == 1, "Weight must be 1D"
+ assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
+ assert weight.is_cuda, "Weight tensor must be on CUDA device"
+ assert weight.dtype in [
+ torch.float32,
+ torch.bfloat16,
+ torch.float16,
+ ], "Weight must be float32, float16 or bfloat16"
  if residual is not None:
  assert residual.shape == x.shape
  assert residual.is_cuda
@@ -402,11 +376,6 @@ def _rmsnorm_fwd(
  _, N = x.shape
  device = x.device
  dtype = torch2cute_dtype_map[x.dtype]
- # convert_from_dlpack = lambda x: (
- # from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
- # mode=0, divisibility=128 // dtype.width
- # )
- # )
  convert_from_dlpack = lambda x: (
  from_dlpack(x.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=1)
  )
@@ -414,10 +383,13 @@ def _rmsnorm_fwd(
  convert_from_dlpack(t) if t is not None else None for t in (x, residual, out, residual_out)
  ]
  # handle weight divisibility based on weight dtype
- weight_dtype = torch2cute_dtype_map[weight.dtype]
- weight_tensor = utils.convert_from_dlpack(
- weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
- )
+ if weight is not None:
+ weight_dtype = torch2cute_dtype_map[weight.dtype]
+ weight_tensor = utils.convert_from_dlpack(
+ weight.detach(), leading_dim=0, divisibility=128 // weight_dtype.width
+ )
+ else:
+ weight_tensor = None
  if bias is not None:
  bias_dtype = torch2cute_dtype_map[bias.dtype]
  bias_tensor = utils.convert_from_dlpack(
@@ -435,7 +407,7 @@ def _rmsnorm_fwd(
  N,
  dtype,
  res_tensor.element_type if residual is not None else None,
- weight_tensor.element_type,
+ weight_tensor.element_type if weight is not None else None,
  bias_tensor.element_type if bias is not None else None,
  res_out_tensor.element_type if residual_out is not None else None,
  rstd is not None,
@@ -472,7 +444,7 @@ _rmsnorm_fwd.compile_cache = {}
 
  def rmsnorm_fwd(
  x: Tensor,
- weight: Tensor,
+ weight: Optional[Tensor] = None,
  bias: Optional[Tensor] = None,
  residual: Optional[Tensor] = None,
  out_dtype: Optional[torch.dtype] = None,
@@ -501,12 +473,13 @@ def rmsnorm_fwd(
  return out, residual_out, rstd
 
 
- def rmsnorm_ref(x, w, bias=None, residual=None, eps=1e-6):
+ def rmsnorm_ref(x, w=None, bias=None, residual=None, eps=1e-6):
  x_f32 = x.float()
  if residual is not None:
  residual_f32 = residual.float()
  x_f32 += residual_f32
- out = x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps)) * w
+ x_norm = x_f32 / (torch.sqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + eps))
+ out = x_norm * w if w is not None else x_norm
  if bias is not None:
  out = out + bias.float()
  if residual is None:
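
With the weight optional end to end, the kernel and the pure-PyTorch reference can also be compared without a weight. A minimal parity check, assuming both helpers are importable from quack.rmsnorm, that rmsnorm_fwd with all-default optional arguments returns the (out, residual_out, rstd) triple shown above, and that rmsnorm_ref returns just the output when residual is None:

import torch
from quack.rmsnorm import rmsnorm_fwd, rmsnorm_ref

x = torch.randn(4, 1024, device="cuda", dtype=torch.bfloat16)
out, _, _ = rmsnorm_fwd(x)   # weight and bias now default to None
ref = rmsnorm_ref(x)         # weight-free reference path added in this release
torch.testing.assert_close(out.float(), ref.float(), rtol=1e-2, atol=1e-2)
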
@@ -613,8 +586,11 @@ class RMSNormBackward(ReductionBase):
  )
  num_threads = cute.size(tv_layout, mode=[0])
  num_warps = num_threads // cute.arch.WARP_SIZE
- mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
- mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+ if const_expr(mW is not None):
+ mW_expanded_layout = cute.prepend(
+ mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
+ )
+ mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
 
  num_blocks = sm_count
  self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn).launch(
@@ -667,50 +643,10 @@ class RMSNormBackward(ReductionBase):
  mbar_full_ptr, mbar_empty_ptr = None, None
 
  num_copy_elems_X = tv_layout.shape[1][0]
- num_copy_bits_X = mX.element_type.width * num_copy_elems_X
- copy_atom_load_X = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
- )
- copy_atom_load_X_async = cute.make_copy_atom(
- cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=num_copy_bits_X
- )
- num_copy_bits_dO = const_expr(min(128, num_copy_elems_X * mdO.element_type.width))
- copy_atom_load_dO_async = cute.make_copy_atom(
- cute.nvgpu.cpasync.CopyG2SOp(), mdO.element_type, num_bits_per_copy=num_copy_bits_dO
- )
- num_copy_bits_W = const_expr(min(128, num_copy_elems_X * mW.element_type.width))
- copy_atom_load_W = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
- )
- if const_expr(mdResO is not None):
- num_copy_bits_dResO = const_expr(min(128, num_copy_elems_X * mdResO.element_type.width))
- copy_atom_load_dResO = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(),
- mdResO.element_type,
- num_bits_per_copy=num_copy_bits_dResO,
- )
- num_copy_bits_dX = const_expr(min(128, num_copy_elems_X * mdX.element_type.width))
- copy_atom_store_dX = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_copy_bits_dX
- )
- num_copy_bits_dW = const_expr(min(128, num_copy_elems_X * mdW.element_type.width))
- copy_atom_store_dW = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mdW.element_type, num_bits_per_copy=num_copy_bits_dW
- )
- if const_expr(mdB is not None):
- num_copy_bits_dB = const_expr(min(128, num_copy_elems_X * mdB.element_type.width))
- copy_atom_store_dB = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(), mdB.element_type, num_bits_per_copy=num_copy_bits_dB
- )
- if const_expr(mdRes is not None):
- num_copy_bits_dRes = const_expr(min(128, num_copy_elems_X * mdRes.element_type.width))
- copy_atom_load_dRes = cute.make_copy_atom(
- cute.nvgpu.CopyUniversalOp(),
- mdRes.element_type,
- num_bits_per_copy=num_copy_bits_dRes,
- )
-
+ copy_atom_load_X = utils.get_copy_atom(mX.element_type, num_copy_elems_X, is_async=False)
  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+ # Each copy will use the same number of elements as X
+ copy = partial(utils.copy, num_copy_elems=num_copy_elems_X)
 
  gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
  tXgW = thr_copy_X.partition_S(gW)
@@ -725,7 +661,7 @@ class RMSNormBackward(ReductionBase):
  if not is_even_N
  else None
  )
- cute.copy(copy_atom_load_W, tXgW, tXrW, pred=tXpW)
+ copy(tXgW, tXrW, pred=tXpW)
  weight = tXrW.load().to(cute.Float32)
 
  num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
@@ -790,16 +726,13 @@ class RMSNormBackward(ReductionBase):
  if const_expr(mdRes is not None):
  tXrdRes = cute.make_fragment_like(tXgdRes[None, None, None, 0])
 
- copy_X = partial(cute.copy, copy_atom_load_X_async, pred=tXpX)
- copy_dO = partial(cute.copy, copy_atom_load_dO_async, pred=tXpX)
-
  # Prefetch the first batch
  row = tXcX[None, None, None, bidx_start][0][0]
  if row < M:
  tXgX_cur = utils.coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0]
  tXgdO_cur = utils.coord_offset_i64(bidx_start, tXgdO, dim=3)[None, None, None, 0]
- copy_X(tXgX_cur, tXsX[None, None, None, 0])
- copy_dO(tXgdO_cur, tXsdO[None, None, None, 0])
+ copy(tXgX_cur, tXsX[None, None, None, 0], pred=tXpX, is_async=True)
+ copy(tXgdO_cur, tXsdO[None, None, None, 0], pred=tXpX, is_async=True)
  elif tiler_mn[0] > 1:
  # Fill with zero, otherwise smem will be uninitialized, and we could read this back
  # later into registers, causing wrong dW.
@@ -822,8 +755,8 @@ class RMSNormBackward(ReductionBase):
  if row + gdim * tiler_mn[0] < M: # Prefetch the next batch
  tXgX_cur = utils.coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0]
  tXgdO_cur = utils.coord_offset_i64(bidx + gdim, tXgdO, dim=3)[None, None, None, 0]
- copy_X(tXgX_cur, tXsX[None, None, None, stage ^ 1])
- copy_dO(tXgdO_cur, tXsdO[None, None, None, stage ^ 1])
+ copy(tXgX_cur, tXsX[None, None, None, stage ^ 1], pred=tXpX, is_async=True)
+ copy(tXgdO_cur, tXsdO[None, None, None, stage ^ 1], pred=tXpX, is_async=True)
  elif tiler_mn[0] > 1:
  utils.fill_oob(
  tXsX[None, None, None, stage ^ 1],
@@ -842,7 +775,7 @@ class RMSNormBackward(ReductionBase):
  if const_expr(mdResO is not None):
  tXgdResO_cur = utils.coord_offset_i64(bidx, tXgdResO, dim=3)[None, None, None, 0]
  if row < M or tiler_mn[0] == 1:
- cute.copy(copy_atom_load_dResO, tXgdResO_cur, tXrdResO, pred=tXpX)
+ copy(tXgdResO_cur, tXrdResO, pred=tXpX)
  elif tiler_mn[0] > 1:
  tXrdResO.fill(0.0)
  cute.arch.cp_async_wait_group(1)
@@ -890,12 +823,12 @@ class RMSNormBackward(ReductionBase):
  tXrdX.store(dx.to(tXrdX.element_type))
  if row < M or tiler_mn[0] == 1:
  tXgdX_cur = utils.coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
- cute.copy(copy_atom_store_dX, tXrdX, tXgdX_cur, pred=tXpX)
+ copy(tXrdX, tXgdX_cur, pred=tXpX)
  if const_expr(mdRes is not None):
  tXrdRes.store(dx.to(tXrdRes.element_type))
  tXgdRes_cur = utils.coord_offset_i64(bidx, tXgdRes, dim=3)[None, None, None, 0]
  if row < M or tiler_mn[0] == 1:
- cute.copy(copy_atom_load_dRes, tXrdRes, tXgdRes_cur, pred=tXpX)
+ copy(tXrdRes, tXgdRes_cur, pred=tXpX)
  # Accumulate weight gradients in fp32
  tXrdW.store(tXrdW.load() + dout * x_hat)
  if const_expr(mdB is not None):
@@ -927,7 +860,7 @@ class RMSNormBackward(ReductionBase):
  tXsdW_other = cute.make_tensor(tXsdW.iterator + i * sdW.stride[0], tXsdW.layout)
  cute.autovec_copy(tXsdW_other, tXrdW_other)
  tXrdW.store(tXrdW.load() + tXrdW_other.load())
- cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
+ copy(tXrdW, tXgdW, pred=tXpdW)
  cute.arch.barrier()
  if const_expr(mdB is not None):
  sdB = cute.make_tensor(
@@ -948,12 +881,12 @@ class RMSNormBackward(ReductionBase):
  )
  cute.autovec_copy(tXsdB_other, tXrdB_other)
  tXrdB.store(tXrdB.load() + tXrdB_other.load())
- cute.copy(copy_atom_store_dB, tXrdB, tXgdB, pred=tXpdB)
+ copy(tXrdB, tXgdB, pred=tXpdB)
  else:
  # dw is already in fp32, so we can directly copy to global memory
- cute.copy(copy_atom_store_dW, tXrdW, tXgdW, pred=tXpdW)
+ copy(tXrdW, tXgdW, pred=tXpdW)
  if const_expr(mdB is not None):
- cute.copy(copy_atom_store_dB, tXrdB, tXgdB, pred=tXpdB)
+ copy(tXrdB, tXgdB, pred=tXpdB)
 
 
  def _get_sm_count(N: int, device: torch.device) -> int:
@@ -1171,6 +1104,7 @@ class RMSNormFunction(torch.autograd.Function):
  @staticmethod
  def backward(ctx, dout, *args):
  x, weight, rstd = ctx.saved_tensors
+ assert weight is not None, "RMSNorm backward doesn't support weight=None yet"
  has_bias = ctx.has_bias
  if ctx.prenorm and ctx.residual_dtype is not None:
  dresidual_out = args[0]
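
The new assert makes the limitation explicit: the weight-free path is forward-only for now, while the weighted path keeps its backward. A hedged sketch, assuming rmsnorm returns just the output tensor when no residual is involved:

import torch
from quack.rmsnorm import rmsnorm

x = torch.randn(4, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.ones(1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)

out = rmsnorm(x, w)        # weighted: forward and backward both supported
out_nw = rmsnorm(x)        # weight-free: forward only in 0.2.2
# out_nw.sum().backward()  # would trip the assert above; no weight-free backward yet
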
@@ -1193,7 +1127,7 @@ class RMSNormFunction(torch.autograd.Function):
 
  def rmsnorm(
  x: Tensor,
- weight: Tensor,
+ weight: Optional[Tensor] = None,
  bias: Optional[Tensor] = None,
  residual: Optional[Tensor] = None,
  out_dtype: Optional[torch.dtype] = None,
@@ -1205,7 +1139,7 @@ def rmsnorm(
 
  Args:
  x: Input tensor of shape (M, N)
- weight: Weight tensor of shape (N,)
+ weight: Optional weight tensor of shape (N,)
  eps: Small value for numerical stability
 
  Returns:
quack/tile_scheduler.py CHANGED
@@ -135,7 +135,7 @@ class TileScheduler:
  ip=None,
  ):
  self._current_work_linear_idx = current_work_linear_idx
- self._num_tiles_executed = num_tiles_executed
+ self.num_tiles_executed = num_tiles_executed
  self._tile_count = tile_count
  self._scheduler_pipeline = scheduler_pipeline
  self._pipeline_state = pipeline_state
@@ -251,7 +251,7 @@ class TileScheduler:
  )
  tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
  if const_expr(not params.is_persistent):
- is_valid = self._num_tiles_executed == 0
+ is_valid = self.num_tiles_executed == 0
  else:
  is_valid = self._current_work_linear_idx < cute.size(params.problem_shape_ncluster_mnl)
  return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
@@ -276,38 +276,6 @@ class TileScheduler:
  current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0)
  self._current_work_linear_idx = current_work_linear_idx
 
- # We have to split broadcast_next_work and advance_to_next_work into two functions
- # due to a bug in cute-dsl 4.2: https://github.com/NVIDIA/cutlass/issues/2647
- @cute.jit
- def broadcast_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None):
- """is_scheduler_warp should only be true for one warp in the whole cluster"""
- params = self.params
- if const_expr(params.is_persistent and params.tile_count_semaphore is not None):
- current_work_linear_idx = self._current_work_linear_idx
- if is_scheduler_warp:
- self._scheduler_pipeline.producer_acquire(self._pipeline_state)
- lane_idx = cute.arch.lane_idx()
- if lane_idx < cute.size(params.cluster_shape_mn):
- # cute.printf("Producer bidx = {}, tidx = {}, after empty wait, idx = {}", bidx, tidx, current_work_linear_idx)
- if const_expr(cute.size(params.cluster_shape_mn) == 1):
- self._tile_count[self._pipeline_state.index] = current_work_linear_idx
- self._scheduler_pipeline.producer_commit(self._pipeline_state)
- else:
- peer_cta_rank_in_cluster = lane_idx
- mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
- self._pipeline_state
- )
- cute.arch.mbarrier_arrive_and_expect_tx(
- mbar_ptr, 4, peer_cta_rank_in_cluster
- )
- utils.store_shared_remote(
- val=current_work_linear_idx,
- smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
- mbar_ptr=mbar_ptr,
- peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
- )
- # cute.printf("Producer bidx = {}, tidx = {}, after full arrive", bidx, tidx)
-
  @cute.jit
  def advance_to_next_work(
  self,
@@ -328,7 +296,30 @@ class TileScheduler:
  if const_expr(advance_count > 1):
  self._pipeline_state.advance_iters(advance_count - 1)
  current_work_linear_idx = self._current_work_linear_idx
- if not is_scheduler_warp:
+ if is_scheduler_warp:
+ self._scheduler_pipeline.producer_acquire(self._pipeline_state)
+ lane_idx = cute.arch.lane_idx()
+ if lane_idx < cute.size(params.cluster_shape_mn):
+ # cute.printf("Producer bidx = {}, tidx = {}, after empty wait, idx = {}", bidx, tidx, current_work_linear_idx)
+ if const_expr(cute.size(params.cluster_shape_mn) == 1):
+ self._tile_count[self._pipeline_state.index] = current_work_linear_idx
+ self._scheduler_pipeline.producer_commit(self._pipeline_state)
+ else:
+ peer_cta_rank_in_cluster = lane_idx
+ mbar_ptr = self._scheduler_pipeline.producer_get_barrier(
+ self._pipeline_state
+ )
+ cute.arch.mbarrier_arrive_and_expect_tx(
+ mbar_ptr, 4, peer_cta_rank_in_cluster
+ )
+ utils.store_shared_remote(
+ val=current_work_linear_idx,
+ smem_ptr=self._tile_count.iterator + self._pipeline_state.index,
+ mbar_ptr=mbar_ptr,
+ peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
+ )
+ # cute.printf("Producer bidx = {}, tidx = {}, after full arrive", bidx, tidx)
+ else:
  # if tidx % 64 == 0: cute.printf("bidx = {},tidx = {}, before full wait, idx = {}", bidx, tidx, current_work_linear_idx)
  self._scheduler_pipeline.consumer_wait(self._pipeline_state)
  # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after full wait, idx = {}", bidx, tidx, current_work_linear_idx)
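
With broadcast_next_work folded into advance_to_next_work, the scheduler warp publishes the freshly fetched work index and the other warps wait for and consume it inside the same call, so a persistent kernel drives the scheduler with a single method per iteration. A rough sketch of the intended call pattern (hypothetical loop: get_current_work and process are stand-ins; only advance_to_next_work, producer_tail and num_tiles_executed appear in this diff):

# Hypothetical persistent-kernel loop, illustration only.
tile, is_valid = get_current_work(scheduler)      # hypothetical accessor
while is_valid:
    process(tile)                                 # hypothetical per-tile work
    # One call now covers both roles: the scheduler warp runs the former
    # broadcast_next_work body, every other warp waits on the pipeline.
    scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp)
    tile, is_valid = get_current_work(scheduler)
scheduler.producer_tail()  # no-op unless persistent with a tile-count semaphore
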
@@ -341,21 +332,17 @@ class TileScheduler:
  # if tidx % 64 == 0: cute.printf("bidx = {}, tidx = {}, after empty arrive", bidx, tidx)
  self._current_work_linear_idx = current_work_linear_idx
  self._pipeline_state.advance()
- self._num_tiles_executed += Int32(advance_count)
+ self.num_tiles_executed += Int32(advance_count)
 
  def producer_tail(self):
  if const_expr(self.params.is_persistent and self.params.tile_count_semaphore is not None):
  self._scheduler_pipeline.producer_tail(self._pipeline_state)
 
- @property
- def num_tiles_executed(self) -> Int32:
- return self._num_tiles_executed
-
  def __extract_mlir_values__(self):
  values, self._values_pos = [], []
  for obj in [
  self._current_work_linear_idx,
- self._num_tiles_executed,
+ self.num_tiles_executed,
  self._tile_count,
  self._scheduler_pipeline,
  self._pipeline_state,
@@ -371,7 +358,7 @@ class TileScheduler:
  for obj, n_items in zip(
  [
  self._current_work_linear_idx,
- self._num_tiles_executed,
+ self.num_tiles_executed,
  self._tile_count,
  self._scheduler_pipeline,
  self._pipeline_state,
@@ -562,7 +549,7 @@ class TriangularTileScheduler(TileScheduler):
  pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
  tile_coord_mnkl = (pid_m, pid_n, None, bidz)
  if const_expr(not params.is_persistent):
- is_valid = self._num_tiles_executed == 0
+ is_valid = self.num_tiles_executed == 0
  else:
  is_valid = (
  self._current_work_linear_idx
@@ -681,7 +668,7 @@ class VarlenMTileScheduler(TileScheduler):
  ip=None,
  ):
  self._current_work_linear_idx = current_work_linear_idx
- self._num_tiles_executed = num_tiles_executed
+ self.num_tiles_executed = num_tiles_executed
  self._current_batch_idx = current_batch_idx
  self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch
  self._tile_count = tile_count
@@ -878,7 +865,7 @@ class VarlenMTileScheduler(TileScheduler):
  pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1]
  tile_coord_mnkl = (pid_m, pid_n, None, batch_idx)
  if const_expr(not params.is_persistent):
- is_valid = self._num_tiles_executed == 0 and batch_idx < num_batch
+ is_valid = self.num_tiles_executed == 0 and batch_idx < num_batch
  else:
  is_valid = batch_idx < num_batch
  return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid)
@@ -905,7 +892,7 @@ class VarlenMTileScheduler(TileScheduler):
  values, self._values_pos = [], []
  for obj in [
  self._current_work_linear_idx,
- self._num_tiles_executed,
+ self.num_tiles_executed,
  self._current_batch_idx,
  self._num_work_idx_before_cur_batch,
  self._tile_count,
@@ -923,7 +910,7 @@ class VarlenMTileScheduler(TileScheduler):
  for obj, n_items in zip(
  [
  self._current_work_linear_idx,
- self._num_tiles_executed,
+ self.num_tiles_executed,
  self._current_batch_idx,
  self._num_work_idx_before_cur_batch,
  self._tile_count,