quack-kernels 0.1.3-py3-none-any.whl → 0.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.1.3"
+ __version__ = "0.1.5"

  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax
quack/cross_entropy.py CHANGED
@@ -1,3 +1,5 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
  import math
  import torch
  from typing import Optional, Type
@@ -77,7 +79,7 @@ class CrossEntropy(ReductionBase):
  self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
  block=[num_threads, 1, 1],
- cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+ cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
  stream=stream,
  )
@@ -93,15 +95,16 @@ class CrossEntropy(ReductionBase):
  tiler_mn: cute.Shape,
  ):
  tidx, _, _ = cute.arch.thread_idx()
- bidx, cluster_y, _ = cute.arch.block_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ if cutlass.const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = cutlass.const_expr(0)

  shape: cute.Shape = mX.shape
  idX = cute.make_identity_tensor(shape)
  # slice for CTAs
- gX, cX = [
- cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
- for mT in (mX, idX)
- ]
+ gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, idX)]

  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
@@ -131,7 +134,9 @@ class CrossEntropy(ReductionBase):

  is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
  tXpX = (
- utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
+ utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+ if cutlass.const_expr(not is_even_N)
+ else None
  )
  if row < shape[0]:
  cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
@@ -154,7 +159,7 @@ class CrossEntropy(ReductionBase):
  cute.ReductionOp.MAX,
  threads_per_row,
  reduction_buffer[None, None, 0],
- mbar_ptr + 0 if self.cluster_n > 1 else None,
+ mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
  init_val=-cutlass.Float32.inf,
  hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
  )
@@ -172,7 +177,7 @@ class CrossEntropy(ReductionBase):
  cute.ReductionOp.ADD,
  threads_per_row,
  reduction_buffer[None, None, 1],
- mbar_ptr + 1 if self.cluster_n > 1 else None,
+ mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
  init_val=0.0,
  )
  else:
@@ -197,7 +202,7 @@ class CrossEntropy(ReductionBase):
  mLSE[row] = lse


- def cross_entropy(
+ def _cross_entropy(
  x: torch.Tensor,
  target: torch.Tensor,
  return_lse: bool = False,
@@ -238,15 +243,300 @@ def cross_entropy(
  stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

  compile_key = (dtype, N, lse is not None)
- if compile_key not in cross_entropy.compile_cache:
+ if compile_key not in _cross_entropy.compile_cache:
  cross_entropy_op = CrossEntropy(dtype, N)
- cross_entropy.compile_cache[compile_key] = cute.compile(
+ _cross_entropy.compile_cache[compile_key] = cute.compile(
  cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
  )
- cross_entropy.compile_cache[compile_key](
+ _cross_entropy.compile_cache[compile_key](
  x_tensor, target_tensor, loss_tensor, lse_tensor, stream
  )
  return loss if not return_lse else (loss, lse)


- cross_entropy.compile_cache = {}
+ _cross_entropy.compile_cache = {}
+
+
+ class CrossEntropyBackward:
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+ self.dtype = dtype
+ self.N = N
+ self.vecsize = 128 // dtype.width
+
+ def _calculate_threads_per_row(self):
+ N = self.N
+ return (
+ 8
+ if N <= 64
+ else (
+ 16
+ if N <= 128
+ else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+ )
+ )
+
+ def _get_tv_layout(self):
+ N = self.N
+ vecsize = self.vecsize
+ num_threads = 128 if N <= 16384 else 256
+ threads_per_row = self._calculate_threads_per_row()
+ cols_per_block = num_threads // threads_per_row
+ num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
+ tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+ tv_layout = cute.make_layout(
+ ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+ stride=(
+ (vecsize * cols_per_block, 1),
+ (cols_per_block, cols_per_block * vecsize * threads_per_row),
+ ),
+ )
+ return tiler_mn, tv_layout
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mTarget: cute.Tensor,
+ mDLoss: cute.Tensor,
+ mdX: cute.Tensor,
+ mLSE: cute.Tensor,
+ stream: cuda.CUstream,
+ ):
+ assert mX.element_type == self.dtype
+ assert mdX.element_type == self.dtype
+
+ tiler_mn, tv_layout = self._get_tv_layout()
+ num_threads = cute.size(tv_layout, mode=[0])
+
+ mDLoss = cute.make_tensor(
+ mDLoss.iterator, cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,)))
+ )
+ mTarget = cute.make_tensor(
+ mTarget.iterator, cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,)))
+ )
+ mLSE = cute.make_tensor(
+ mLSE.iterator, cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,)))
+ )
+
+ smem_size = cute.size_in_bytes(
+ mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0))
+ )
+
+ self.kernel(
+ mX,
+ mTarget,
+ mDLoss,
+ mdX,
+ mLSE,
+ mX.shape,
+ tv_layout,
+ tiler_mn,
+ ).launch(
+ grid=[
+ cute.ceil_div(mX.shape[0], tiler_mn[0]),
+ cute.ceil_div(mX.shape[1], tiler_mn[1]),
+ 1,
+ ],
+ block=[num_threads, 1, 1],
+ smem=smem_size,
+ stream=stream,
+ )
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mDLoss: cute.Tensor, # (M,)
+ mdX: cute.Tensor, # (M, N)
+ mLSE: cute.Tensor, # (M,)
+ shape: cute.Shape,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ):
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, bidy, _ = cute.arch.block_idx()
+
+ smem = cutlass.utils.SmemAllocator()
+ sX = smem.allocate_tensor(
+ mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+ )
+
+ idX = cute.make_identity_tensor(shape)
+
+ gX, gdX, cX, gTarget, gDLoss, gLse = [
+ cute.local_tile(mT, tiler_mn, (bidx, bidy))
+ for mT in (mX, mdX, idX, mTarget, mDLoss, mLSE)
+ ]
+
+ copy_atom_load_X = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
+ )
+ copy_atom_load_X_async = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
+ )
+ copy_atom_store_O = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=128
+ )
+
+ thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_X_async = cute.make_tiled_copy(
+ copy_atom_load_X_async, tv_layout, tiler_mn
+ ).get_slice(tidx)
+ thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+ #### Thread View
+ tXgX = thr_copy_X_async.partition_S(gX)
+ tXsX = thr_copy_X_async.partition_S(sX)
+
+ tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+ tXcFull = thr_copy_X.partition_S(cX) # improve
+
+ tXgO = thr_copy_O.partition_D(gdX)
+
+ # allocate fragments for gmem->rmem
+ tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+ is_even_N = cutlass.const_expr(shape[1] % tiler_mn[1] == 0)
+ row = tXcX[0][0]
+
+ tXpX = (
+ utils.predicate_k(thr_copy_X_async.partition_S(cX), limit=shape[1])
+ if not is_even_N
+ else None
+ )
+
+ if row < shape[0]:
+ cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+ cute.arch.cp_async_commit_group()
+ cute.arch.cp_async_wait_group(0)
+ if cutlass.const_expr(not is_even_N):
+ utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(cute.Float32)
+
+ label = cute.Int32.zero
+ dloss = cute.Float32.zero
+ lse = cute.Float32.zero
+ if row < shape[0]:
+ label = cute.Int32(mTarget[row])
+ dloss = cute.Float32(mDLoss[row])
+ lse = cute.Float32(mLSE[row])
+
+ log2_e = math.log2(math.e)
+ probs = utils.exp2f((x - lse) * log2_e)
+ prob_shifted = probs - 1.0
+
+ mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
+ for i in cutlass.range_constexpr(cute.size(tXcFull)):
+ mask[i] = tXcFull[i][1] == label
+
+ mask = mask.load()
+ grad = cute.where(mask, prob_shifted, probs)
+ grad = grad * dloss
+
+ tXrO.store(grad.to(tXrO.element_type))
+ tOpO = (
+ utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
+ )
+ if row < shape[0]:
+ cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+
+
+ def _cross_entropy_backward(
+ x: torch.Tensor,
+ target: torch.Tensor,
+ dloss: torch.Tensor,
+ lse: torch.Tensor,
+ inplace_backward: bool = False,
+ ) -> torch.Tensor:
+ """Cross entropy backward pass.
+ Args:
+ x: Input logits tensor of shape (M, N)
+ target: Target class indices tensor of shape (M,)
+ dloss: Upstream gradients tensor of shape (M,)
+ lse: Log-sum-exp values tensor of shape (M,)
+ Returns:
+ Input gradients tensor of shape (M, N)
+ """
+ assert x.dim() == 2, "Input must be 2D"
+ assert target.dim() == 1, "Target must be 1D"
+ assert dloss.dim() == 1, "dloss must be 1D"
+ assert lse.dim() == 1, "lse must be 1D"
+ assert x.shape[0] == target.shape[0], "Batch dimensions must match"
+ assert x.shape[0] == dloss.shape[0], "Batch dimensions must match"
+ assert x.shape[0] == lse.shape[0], "Batch dimensions must match"
+ assert (
+ x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
+ ), "Tensors must be on CUDA device"
+ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+ assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
+
+ M, N = x.shape
+ dx = torch.empty_like(x) if not inplace_backward else x
+ dtype = torch2cute_dtype_map[x.dtype]
+
+ convert_from_dlpack = lambda tensor: (
+ from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+ mode=0, stride_order=(0, 1)
+ )
+ )
+ x_tensor = convert_from_dlpack(x)
+ dx_tensor = convert_from_dlpack(dx)
+ dloss_tensor = from_dlpack(dloss.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
+ lse_tensor = from_dlpack(lse.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
+ target_tensor = from_dlpack(target.detach(), assumed_align=32).mark_compact_shape_dynamic(
+ mode=0
+ )
+ stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+ compile_key = (dtype, N)
+ if compile_key not in _cross_entropy_backward.compile_cache:
+ cross_entropy_backward_op = CrossEntropyBackward(dtype, N)
+ _cross_entropy_backward.compile_cache[compile_key] = cute.compile(
+ cross_entropy_backward_op,
+ x_tensor,
+ target_tensor,
+ dloss_tensor,
+ dx_tensor,
+ lse_tensor,
+ stream,
+ )
+ _cross_entropy_backward.compile_cache[compile_key](
+ x_tensor, target_tensor, dloss_tensor, dx_tensor, lse_tensor, stream
+ )
+ return dx
+
+
+ _cross_entropy_backward.compile_cache = {}
+
+
+ class CrossEntropyFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, target, inplace_backward=False):
+ loss, lse = _cross_entropy(x, target, return_lse=True)
+ ctx.save_for_backward(x, target, lse)
+ ctx.inplace_backward = inplace_backward
+ return loss
+
+ @staticmethod
+ def backward(ctx, dloss):
+ x, target, lse = ctx.saved_tensors
+ dx = _cross_entropy_backward(x, target, dloss, lse, inplace_backward=ctx.inplace_backward)
+ return dx, None, None
+
+
+ def cross_entropy(
+ x: torch.Tensor, target: torch.Tensor, inplace_backward: bool = False
+ ) -> torch.Tensor:
+ """Cross entropy loss with automatic differentiation support.
+
+ Args:
+ x: Input logits tensor of shape (M, N)
+ target: Target class indices tensor of shape (M,)
+
+ Returns:
+ Cross entropy loss tensor of shape (M,)
+ """
+ return CrossEntropyFunction.apply(x, target, inplace_backward)
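
Usage sketch (not part of the published diff): the new top-level cross_entropy wraps CrossEntropyFunction, so the custom backward kernel runs under torch.autograd. A minimal example, assuming a CUDA device with quack-kernels 0.1.5 installed; the torch.nn.functional comparison is only an illustrative sanity check.

import torch
from quack.cross_entropy import cross_entropy

M, N = 4, 4096
x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, N, (M,), device="cuda", dtype=torch.int64)

loss = cross_entropy(x, target)  # per-row loss, shape (M,), per the docstring above
loss.sum().backward()            # runs the CrossEntropyBackward kernel
ref = torch.nn.functional.cross_entropy(x.detach().float(), target, reduction="none")
print(torch.allclose(loss.float(), ref, atol=1e-2), x.grad.shape)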
quack/reduction_base.py CHANGED
@@ -6,8 +6,6 @@ from typing import Type, Tuple, Optional
  import cutlass
  import cutlass.cute as cute

- import quack.utils as utils
-

  torch2cute_dtype_map = {
  torch.float16: cutlass.Float16,
@@ -39,7 +37,6 @@ class ReductionBase:
  vecsize = copy_bits // self.dtype.width
  assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
  num_threads = self._get_num_threads()
- num_warps = num_threads // cute.arch.WARP_SIZE
  assert num_threads % cute.arch.WARP_SIZE == 0

  threads_per_row = self._calculate_threads_per_row()
@@ -64,7 +61,7 @@ class ReductionBase:

  def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
  num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
- warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
+ warps_per_row = max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
  return cute.make_ordered_layout(
  (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
  order=(1, 0, 2),
@@ -88,10 +85,10 @@ class ReductionBase:
  def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_warps: int):
  if cutlass.const_expr(self.cluster_n > 1):
  if tidx < self.stage:
- cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + tidx, 1)
+ cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
  cute.arch.mbarrier_init_fence()
  if tidx < self.stage:
- cute.arch.mbarrier_init_tx_bytes(
+ cute.arch.mbarrier_arrive_and_expect_tx(
  mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
  )
  # Cluster arrive after barrier init
quack/rmsnorm.py CHANGED
@@ -9,7 +9,6 @@ import cuda.bindings.driver as cuda
  import cutlass
  import cutlass.cute as cute
  from cutlass.cute.runtime import from_dlpack
-
  import quack.utils as utils
  from quack.reduction_base import ReductionBase, torch2cute_dtype_map

@@ -84,7 +83,7 @@ class RMSNorm(ReductionBase):
  self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
  block=[num_threads, 1, 1],
- cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+ cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
  stream=stream,
  )
@@ -103,7 +102,11 @@ class RMSNorm(ReductionBase):
  delay_w_load: cutlass.Constexpr = False,
  ):
  tidx, _, _ = cute.arch.thread_idx()
- bidx, cluster_y, _ = cute.arch.block_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ if cutlass.const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = cutlass.const_expr(0)

  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
@@ -114,13 +117,10 @@ class RMSNorm(ReductionBase):
  shape = mX.shape
  idX = cute.make_identity_tensor(shape)
  # slice for CTAs
- gX, gO, cX = [
- cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
- for mT in (mX, mO, idX)
- ]
- gW = cute.local_tile(mW, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
+ gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]
+ gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
  gRstd = (
- cute.local_tile(mRstd, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+ cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
  if cutlass.const_expr(mRstd is not None)
  else None
  )
@@ -167,7 +167,7 @@ class RMSNorm(ReductionBase):
  cute.arch.cp_async_commit_group()

  tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
- if not delay_w_load:
+ if cutlass.const_expr(not delay_w_load):
  cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)

  cute.arch.cp_async_wait_group(0)
@@ -192,12 +192,12 @@ class RMSNorm(ReductionBase):
  and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
  ):
  tXrRstd[0] = rstd
- if delay_w_load:
+ if cutlass.const_expr(delay_w_load):
  cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
- if reload_from == "smem":
+ if cutlass.const_expr(reload_from == "smem"):
  cute.autovec_copy(tXsX, tXrX)
  x = tXrX.load().to(cute.Float32)
- elif reload_from == "gmem":
+ elif cutlass.const_expr(reload_from == "gmem"):
  cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
  x = tXrX.load().to(cute.Float32)
  x_hat = x * rstd
@@ -209,20 +209,18 @@ class RMSNorm(ReductionBase):
  cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)


- def rmsnorm(
+ def _rmsnorm_fwd(
  x: torch.Tensor,
  weight: torch.Tensor,
  eps: float = 1e-6,
  return_rstd: bool = False,
  ) -> torch.Tensor:
  """RMSNorm forward pass.
-
  Args:
  x: Input tensor of shape (M, N)
  weight: Weight tensor of shape (N,)
  eps: Small value for numerical stability
  return_rstd: Whether to return the reciprocal standard deviation
-
  Returns:
  Normalized output tensor of same shape as x
  If return_rstd is True, also returns rstd tensor of shape (M,)
@@ -258,18 +256,18 @@ def rmsnorm(
  )
  current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
  compile_key = (dtype, N, rstd is not None)
- if compile_key not in rmsnorm.compile_cache:
+ if compile_key not in _rmsnorm_fwd.compile_cache:
  rmsnorm_op = RMSNorm(dtype, N)
- rmsnorm.compile_cache[compile_key] = cute.compile(
+ _rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
  rmsnorm_op, x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream
  )
- rmsnorm.compile_cache[compile_key](
+ _rmsnorm_fwd.compile_cache[compile_key](
  x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream, eps
  )
  return (out, rstd) if return_rstd else out


- rmsnorm.compile_cache = {}
+ _rmsnorm_fwd.compile_cache = {}


  def rmsnorm_ref(x, w, eps=1e-6):
@@ -282,3 +280,383 @@ def rmsnorm_ref(x, w, eps=1e-6):
  def rstd_ref(x, eps=1e-6):
  x_f32 = x.float()
  return 1.0 / torch.sqrt(torch.mean(x_f32 * x_f32, dim=-1) + eps)
+
+
+ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
+ """Reference implementation for RMSNorm backward pass."""
+ x_f32 = x.float()
+ x_hat = x_f32 * rstd.unsqueeze(1)
+ wdy = dout * w
+ c1 = (x_hat * wdy).mean(dim=-1, keepdim=True)
+ dx = (wdy - x_hat * c1) * rstd.unsqueeze(1)
+
+ # dL/dW
+ dw = (dout * x_hat).sum(dim=0)
+ return dx.to(x.dtype), dw.to(w.dtype)
+
+
+ class RMSNormBackward(ReductionBase):
+ def __init__(self, dtype: cutlass.Numeric, N: int):
+ # 1 stage for computing mean of x_hat * wdy
+ super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)
+
+ def _calculate_threads_per_row(self):
+ N = self.N
+ return (
+ 8
+ if N <= 64
+ else (
+ 16
+ if N <= 128
+ else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+ )
+ )
+
+ def _set_cluster_n(self):
+ N = self.N
+ if cutlass.const_expr(self.dtype.width == 16):
+ cluster_n = (
+ 1
+ if N <= 16 * 1024
+ else (
+ 2
+ if N <= 32 * 1024
+ else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+ )
+ )
+ else: # fp32
+ cluster_n = (
+ 1
+ if N <= 32 * 1024
+ else (
+ 2
+ if N <= 64 * 1024
+ else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+ )
+ )
+ self.cluster_n = cluster_n
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mDout: cute.Tensor,
+ mRstd: cute.Tensor,
+ mDx: cute.Tensor,
+ mDw: cute.Tensor,
+ sm_count: cutlass.Constexpr,
+ stream: cuda.CUstream,
+ ):
+ self._set_cluster_n()
+ tiler_mn, tv_layout = self._get_tv_layout()
+ num_threads = cute.size(tv_layout, mode=[0])
+ num_warps = num_threads // cute.arch.WARP_SIZE
+
+ mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+ mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+
+ mRstd_expanded_layout = cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,)))
+ mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
+
+ num_blocks = (
+ sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
+ )
+
+ self.kernel(mX, mW, mDout, mRstd, mDx, mDw, sm_count, tv_layout, tiler_mn).launch(
+ grid=[num_blocks, self.cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+ smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+ stream=stream,
+ )
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mDout: cute.Tensor,
+ mRstd: cute.Tensor,
+ mDx: cute.Tensor,
+ mDw: cute.Tensor,
+ sm_count: cutlass.Constexpr,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ):
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, cluster_y, _ = cute.arch.block_idx()
+ gdim, _, _ = cute.arch.grid_dim()
+
+ shape = mX.shape
+ M, N = shape[0], shape[1]
+
+ idX = cute.make_identity_tensor(shape)
+
+ smem = cutlass.utils.SmemAllocator()
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+ copy_atom_load_X = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
+ )
+
+ copy_atom_load_W = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
+ )
+
+ copy_atom_store_dX = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), mDx.element_type, num_bits_per_copy=128
+ )
+
+ copy_atom_dw = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(), mDw.element_type, num_bits_per_copy=128
+ )
+
+ thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_dw = cute.make_tiled_copy(copy_atom_dw, tv_layout, tiler_mn).get_slice(tidx)
+ thr_store_dx = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
+
+ gW = cute.local_tile(mW, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+ tWgW = thr_copy_W.partition_S(gW)
+ tWrW = cute.make_fragment_like(tWgW)
+ tXrW = thr_copy_X.retile(tWrW)
+
+ gW_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
+
+ tWpW = utils.predicate_k(thr_copy_W.partition_S(gW_coord), limit=shape[1])
+ cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+ weight = tXrW.load().to(cute.Float32)
+
+ num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+
+ self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+ dw_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
+ tDwpDw = utils.predicate_k(thr_copy_dw.partition_S(dw_coord), limit=shape[1])
+
+ gDw = cute.local_tile(mDw, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+ tDwgDw = thr_copy_dw.partition_D(gDw)
+ tDwrDw = cute.make_fragment_like(tDwgDw)
+ dw_accumulator = thr_copy_X.retile(tDwrDw)
+ dw_accumulator.fill(0.0)
+
+ M_pad = ((M + sm_count - 1) // sm_count) * sm_count
+
+ jump = sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
+
+ if cutlass.const_expr(self.cluster_n > 1):
+ cute.arch.cluster_arrive()
+ cute.arch.cluster_wait()
+
+ ## need to update range_dynamic since it will be deprecated soon
+ for row_offset in cutlass.range_dynamic(bidx, M_pad, jump):
+ gX = cute.local_tile(
+ mX, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+ )
+ gDout = cute.local_tile(
+ mDout, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+ )
+ gRstd = cute.local_tile(
+ mRstd, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+ )
+ gDx = cute.local_tile(
+ mDx, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+ )
+ cX = cute.local_tile(
+ idX, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+ )
+
+ tXgX = thr_copy_X.partition_S(gX)
+ thrDout = thr_copy_X.partition_S(gDout)
+ tXrRstd = thr_copy_W.partition_S(gRstd)
+ thrDx = thr_store_dx.partition_D(gDx)
+ tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+ tXrX, frgDout, frgDx = [cute.make_fragment_like(thr) for thr in (tXgX, thrDout, thrDx)]
+
+ tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+
+ if tXcX[0][0] < shape[0]:
+ cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+ cute.copy(copy_atom_load_X, thrDout, frgDout, pred=tXpX)
+
+ x = tXrX.load().to(cute.Float32)
+ dout = frgDout.load().to(cute.Float32)
+
+ rstd = tXrRstd[0]
+ x_hat = x * rstd
+ wdy = dout * weight
+
+ threads_per_row = tv_layout.shape[0][0]
+
+ row = tXcX[0][0]
+ if cutlass.const_expr(self.cluster_n > 1):
+ cute.arch.cluster_arrive()
+ cute.arch.cluster_wait()
+ else:
+ cute.arch.barrier()
+
+ mean_xhat_wdy = (
+ utils.row_reduce(
+ x_hat * wdy,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+ init_val=0.0,
+ hook_fn=cute.arch.cluster_wait
+ if cutlass.const_expr(self.cluster_n > 1)
+ else None,
+ )
+ / shape[1]
+ )
+
+ dx = (wdy - x_hat * mean_xhat_wdy) * rstd
+ frgDx.store(dx.to(frgDout.element_type))
+
+ if row < M:
+ cute.copy(copy_atom_store_dX, frgDx, thrDx, pred=tXpX)
+
+ if cutlass.const_expr(self.cluster_n > 1):
+ cute.arch.cluster_arrive()
+ cute.arch.cluster_wait()
+ else:
+ cute.arch.barrier()
+
+ if row < M:
+ dw_row = dout * x_hat
+ current_dw = dw_accumulator.load().to(cute.Float32)
+ updated_dw = current_dw + dw_row
+ dw_accumulator.store(updated_dw.to(dw_accumulator.element_type))
+
+ """
+ if cutlass.const_expr(self.cluster_n > 1):
+ cute.arch.cluster_arrive()
+ cute.arch.cluster_wait()
+ else:
+ cute.arch.barrier()
+ """
+ """
+ if cutlass.const_expr(self.cluster_n > 1):
+ cute.arch.cluster_arrive()
+ cute.arch.cluster_wait()
+ else:
+ cute.arch.barrier()
+ """
+
+ cute.autovec_copy(dw_accumulator, tDwrDw)
+ cute.copy(copy_atom_dw, tDwrDw, tDwgDw, pred=tDwpDw)
+
+
+ def _rmsnorm_backward(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ dout: torch.Tensor,
+ rstd: torch.Tensor,
+ ) -> (torch.Tensor, torch.Tensor):
+ """RMSNorm backward pass.
+ Args:
+ x: Input tensor of shape (M, N)
+ weight: Weight tensor of shape (N,)
+ dout: Upstream gradients tensor of shape (M, N)
+ rstd: Reciprocal standard deviation tensor of shape (M,)
+ Returns:
+ Tuple of (dx, dw) where:
+ - dx: Input gradients tensor of same shape as x
+ - dw: Weight gradients tensor of same shape as weight
+ """
+ assert x.dim() == 2, "Input must be 2D"
+ assert weight.dim() == 1, "Weight must be 1D"
+ assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
+ assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
+ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+ assert weight.dtype == torch.float32, "Weight must be float32"
+
+ M, N = x.shape
+ dx = torch.empty_like(x)
+
+ device = x.device
+
+ sm_count = torch.cuda.get_device_properties(device).multi_processor_count * 8
+ dw_partial = torch.zeros((sm_count, N), device=device, dtype=weight.dtype)
+
+ dtype = torch2cute_dtype_map[x.dtype]
+
+ convert_from_dlpack = lambda tensor: (
+ from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+ mode=0, stride_order=(0, 1)
+ )
+ )
+
+ x_tensor, dout_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (x, dout, dx)]
+
+ weight_tensor = utils.convert_from_dlpack(
+ weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+ )
+
+ dw_partial_tensor = convert_from_dlpack(dw_partial)
+ rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+ compile_key = (dtype, N)
+ if compile_key not in _rmsnorm_backward.compile_cache:
+ rmsnorm_backward_op = RMSNormBackward(dtype, N)
+ _rmsnorm_backward.compile_cache[compile_key] = cute.compile(
+ rmsnorm_backward_op,
+ x_tensor,
+ weight_tensor,
+ dout_tensor,
+ rstd_tensor,
+ dx_tensor,
+ dw_partial_tensor,
+ sm_count,
+ current_stream,
+ )
+
+ _rmsnorm_backward.compile_cache[compile_key](
+ x_tensor,
+ weight_tensor,
+ dout_tensor,
+ rstd_tensor,
+ dx_tensor,
+ dw_partial_tensor,
+ current_stream,
+ )
+
+ dw = dw_partial.sum(dim=0).to(weight.dtype)
+ return dx, dw
+
+
+ _rmsnorm_backward.compile_cache = {}
+
+
+ class RMSNormFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, weight, eps):
+ out, rstd = _rmsnorm_fwd(x, weight, eps, return_rstd=True)
+ ctx.save_for_backward(x, weight, rstd)
+ ctx.eps = eps
+ return out
+
+ @staticmethod
+ def backward(ctx, dout):
+ x, weight, rstd = ctx.saved_tensors
+ dx, dw = _rmsnorm_backward(x, weight, dout, rstd)
+ # dw is returned for weight gradient, None for eps gradient
+ return dx, dw, None
+
+
+ def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+ """RMSNorm forward pass with automatic differentiation support.
+
+ Args:
+ x: Input tensor of shape (M, N)
+ weight: Weight tensor of shape (N,)
+ eps: Small value for numerical stability
+
+ Returns:
+ Normalized output tensor of same shape as x
+ """
+ return RMSNormFunction.apply(x, weight, eps)
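
Usage sketch (not part of the published diff): rmsnorm is now differentiable through RMSNormFunction; weight gradients are accumulated into a (sm_count, N) partial buffer and summed on the host. A minimal example, assuming a CUDA device; the float32 weight follows the assertion in _rmsnorm_backward, and the rmsnorm_ref comparison is only an illustrative sanity check.

import torch
from quack.rmsnorm import rmsnorm, rmsnorm_ref

M, N = 8, 8192
x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(N, device="cuda", dtype=torch.float32, requires_grad=True)

out = rmsnorm(x, w, eps=1e-6)      # forward via _rmsnorm_fwd; rstd is saved for backward
out.sum().backward()               # backward via _rmsnorm_backward -> (dx, dw)
print(out.shape, x.grad.shape, w.grad.shape)
print(torch.allclose(out.float(), rmsnorm_ref(x.detach(), w.detach()).float(), atol=1e-2))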
quack/softmax.py CHANGED
@@ -75,7 +75,7 @@ class Softmax(ReductionBase):
  self.kernel(mX, mO, tv_layout, tiler_mn).launch(
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
  block=[num_threads, 1, 1],
- cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+ cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
  stream=stream,
  )
@@ -89,15 +89,16 @@ class Softmax(ReductionBase):
  tiler_mn: cute.Shape,
  ):
  tidx, _, _ = cute.arch.thread_idx()
- bidx, cluster_y, _ = cute.arch.block_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ if cutlass.const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = cutlass.const_expr(0)

  shape = mX.shape
  idX = cute.make_identity_tensor(shape)
  # slice for CTAs
- gX, gO, cX = [
- cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
- for mT in (mX, mO, idX)
- ]
+ gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]

  smem = cutlass.utils.SmemAllocator()
  sX = smem.allocate_tensor(
@@ -129,7 +130,9 @@ class Softmax(ReductionBase):

  is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
  tXpX = (
- utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
+ utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+ if cutlass.const_expr(not is_even_N)
+ else None
  )
  if tXcX[0][0] < shape[0]:
  cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
@@ -148,7 +151,7 @@ class Softmax(ReductionBase):
  cute.ReductionOp.MAX,
  threads_per_row,
  reduction_buffer[None, None, 0],
- mbar_ptr + 0 if self.cluster_n > 1 else None,
+ mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
  init_val=-cutlass.Float32.inf,
  hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
  )
@@ -159,7 +162,7 @@ class Softmax(ReductionBase):
  cute.ReductionOp.ADD,
  threads_per_row,
  reduction_buffer[None, None, 1],
- mbar_ptr + 1 if self.cluster_n > 1 else None,
+ mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
  init_val=0.0,
  )
  else:
@@ -174,7 +177,9 @@ class Softmax(ReductionBase):
  y = exp_x * (1.0 / denom)
  tXrO.store(y.to(tXrO.element_type))
  tOpO = (
- utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
+ utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
+ if cutlass.const_expr(not is_even_N)
+ else None
  )
  if tXcX[0][0] < shape[0]:
  cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
@@ -283,7 +288,7 @@ class SoftmaxBackward(ReductionBase):
  self.kernel(mdY, mY, mdX, tv_layout, tiler_mn).launch(
  grid=[cute.ceil_div(mdY.shape[0], tiler_mn[0]), self.cluster_n, 1],
  block=[num_threads, 1, 1],
- cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+ cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
  stream=stream,
  )
@@ -298,14 +303,17 @@ class SoftmaxBackward(ReductionBase):
  tiler_mn: cute.Shape,
  ):
  tidx, _, _ = cute.arch.thread_idx()
- bidx, cluster_y, _ = cute.arch.block_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ if cutlass.const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = cutlass.const_expr(0)

  shape = mdY.shape
  idX = cute.make_identity_tensor(shape)
  # slice for CTAs
  gdY, gY, gdX, cX = [
- cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
- for mT in (mdY, mY, mdX, idX)
+ cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mdY, mY, mdX, idX)
  ]

  smem = cutlass.utils.SmemAllocator()
@@ -344,7 +352,7 @@ class SoftmaxBackward(ReductionBase):
  is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
  tdYpdY = (
  utils.predicate_k(thr_copy_load.partition_S(cX), limit=shape[1])
- if not is_even_N
+ if cutlass.const_expr(not is_even_N)
  else None
  )

@@ -366,7 +374,7 @@ class SoftmaxBackward(ReductionBase):
  cute.ReductionOp.ADD,
  threads_per_row,
  reduction_buffer[None, None, 0],
- mbar_ptr if self.cluster_n > 1 else None,
+ mbar_ptr if cutlass.const_expr(self.cluster_n > 1) else None,
  init_val=0.0,
  hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
  )
@@ -376,7 +384,7 @@ class SoftmaxBackward(ReductionBase):
  tdXrdX.store(dx.to(tdXrdX.element_type))
  tdXpdX = (
  utils.predicate_k(thr_copy_store.partition_S(cX), limit=shape[1])
- if not is_even_N
+ if cutlass.const_expr(not is_even_N)
  else None
  )
  if tXcX[0][0] < shape[0]:
quack/utils.py CHANGED
@@ -24,32 +24,19 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te


  @cute.jit
- def max_constexpr(
- a: cutlass.Constexpr[cute.Numeric], b: cutlass.Constexpr[cute.Numeric]
- ) -> cutlass.Constexpr[cute.Numeric]:
- return a if a > b else b
-
-
- @cute.jit
- def min_constexpr(
- a: cutlass.Constexpr[cute.Numeric], b: cutlass.Constexpr[cute.Numeric]
- ) -> cutlass.Constexpr[cute.Numeric]:
- return a if a < b else b
-
-
  def warp_reduce(
  val: cute.TensorSSA | cute.Numeric,
  op: Callable,
  width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
  ) -> cute.TensorSSA | cute.Numeric:
- if isinstance(val, cute.TensorSSA):
+ if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
  res = cute.make_fragment(val.shape, val.dtype)
  res.store(val)
- for i in range(cute.size(val.shape)):
+ for i in cutlass.range_constexpr(cute.size(val.shape)):
  res[i] = warp_reduce(res[i], op, width)
  return res.load()
  else:
- for i in range(int(math.log2(width))):
+ for i in cutlass.range_constexpr(int(math.log2(width))):
  val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
  return val

@@ -111,15 +98,15 @@ def store_shared_remote(
  remote_mbar_ptr_i32 = set_block_rank(
  mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
  ).ir_value()
- if isinstance(val, float):
+ if cutlass.const_expr(isinstance(val, float)):
  val = Float32(val)
  assert isinstance(val, (Float32, cutlass.Int64)), "val must be Float32 or Int64"
- suffix = "f32" if isinstance(val, Float32) else "s64"
+ suffix = "f32" if cutlass.const_expr(isinstance(val, Float32)) else "s64"
  llvm.inline_asm(
  None,
  [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
  f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];",
- f"r,{'f' if isinstance(val, Float32) else 'l'},r",
+ f"r,{'f' if cutlass.const_expr(isinstance(val, Float32)) else 'l'},r",
  has_side_effects=True,
  is_align_stack=False,
  asm_dialect=llvm.AsmDialect.AD_ATT,
@@ -195,7 +182,7 @@ def row_reduce(
  val = warp_reduce(
  val,
  warp_op,
- width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
+ width=min(threads_per_row, cute.arch.WARP_SIZE),
  )
  if cutlass.const_expr(hook_fn is not None):
  hook_fn()
@@ -225,7 +212,7 @@ def online_softmax_reduce(
  max_x = warp_reduce(
  x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
  cute.arch.fmax,
- width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
+ width=min(threads_per_row, cute.arch.WARP_SIZE),
  )
  log2_e = math.log2(math.e)
  exp_x = exp2f(x * log2_e - (max_x * log2_e))
@@ -233,7 +220,7 @@ def online_softmax_reduce(
  sum_exp_x = warp_reduce(
  exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
  operator.add,
- width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
+ width=min(threads_per_row, cute.arch.WARP_SIZE),
  )
  if cutlass.const_expr(hook_fn is not None):
  hook_fn()
@@ -299,18 +286,18 @@ def online_softmax_reduce(
  return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)


+ @cute.jit
  def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
  """exp2f calculation for both vector and scalar.
-
  :param x: input value
  :type x: cute.TensorSSA or Float32
  :return: exp2 value
  :rtype: cute.TensorSSA or Float32
  """
- if isinstance(x, cute.TensorSSA):
+ if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
  res = cute.make_fragment(x.shape, Float32)
  res.store(x)
- for i in range(cute.size(x.shape)):
+ for i in cutlass.range_constexpr(cute.size(x.shape)):
  res[i] = cute.arch.exp2(res[i])
  return res.load()
  else:
@@ -347,6 +334,7 @@ def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
  )


+ @cute.jit
  def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
  # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
  tApA = cute.make_fragment(
@@ -356,8 +344,8 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
  ),
  cutlass.Boolean,
  )
- for rest_v in range(tApA.shape[0]):
- for rest_k in range(tApA.shape[2]):
+ for rest_v in cutlass.range_constexpr(tApA.shape[0]):
+ for rest_k in cutlass.range_constexpr(tApA.shape[2]):
  tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
  return tApA

@@ -373,8 +361,8 @@ def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) ->
  """
  tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), 0, 0])
  tXrX_fill.fill(fill_value)
- for rest_v in range(tXpX.shape[0]):
- for rest_k in range(tXpX.shape[2]):
+ for rest_v in cutlass.range_constexpr(tXpX.shape[0]):
+ for rest_k in cutlass.range_constexpr(tXpX.shape[2]):
  if not tXpX[rest_v, 0, rest_k]:
  cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])

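Side note (a sketch, not part of the package): exp2f above is always fed values pre-scaled by log2(e), relying on the identity exp(v) = 2^(v * log2 e). A quick plain-PyTorch check of that identity:

import math
import torch

v = torch.randn(8)
m = v.max()
log2_e = math.log2(math.e)
# exp(v - m) computed via exp2, exactly as the softmax/cross-entropy kernels do
print(torch.allclose(torch.exp(v - m), torch.exp2((v - m) * log2_e)))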
@@ -1,9 +1,9 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.1.3
+ Version: 0.1.5
  Requires-Python: >=3.9
  License-File: LICENSE
- Requires-Dist: nvidia-cutlass-dsl==4.0.0
+ Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
  Requires-Dist: torch
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == "dev"
@@ -0,0 +1,11 @@
+ quack/__init__.py,sha256=GPoImcynY5-OkMep5RhQhXrnZyxgqZG3RoHhsYQFSL4,203
+ quack/cross_entropy.py,sha256=WkngPY8uk4RCjCFtHtB7h9GF_8xt4NnyvDzvw73gIL4,19320
+ quack/reduction_base.py,sha256=fFuGXPR3lDq2yw_m86ujmkni6R51jzNAzy_r9R6C8tA,3563
+ quack/rmsnorm.py,sha256=N9NavrR85ws4cZgkfpeRLjYkVSq2yfyzJQWvfKf98pY,23935
+ quack/softmax.py,sha256=VfhlC2huRuv7olFSVFgS8LF1yF8TFV64yjjjQxYX9yk,16364
+ quack/utils.py,sha256=6EyWgf0z3wcbhGUivHmWB8hVBnEzMyOhmAuZ2Te82k0,15226
+ quack_kernels-0.1.5.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ quack_kernels-0.1.5.dist-info/METADATA,sha256=WI-2CP1mRH05V9Fjdx7HsErNOkrc6fUhheoH4ynlo-U,289
+ quack_kernels-0.1.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ quack_kernels-0.1.5.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+ quack_kernels-0.1.5.dist-info/RECORD,,
@@ -1,11 +0,0 @@
- quack/__init__.py,sha256=aUR7drzgaqmbzw9H_eoFselMUVQVF3BHc9VOzZg5d-Q,203
- quack/cross_entropy.py,sha256=_Xlyifd_YS8LaYxYlZEsuBfsi8zTH4At3i9DDggGCf8,9319
- quack/reduction_base.py,sha256=nrRsXwTpLVQkPp2Gr_FgHRPnifqkMHRodve5ciHzx58,3667
- quack/rmsnorm.py,sha256=YqGTTKHHXYzw3xnnjBRfaN9TDlhG8D_fSI9CHKAU40A,10548
- quack/softmax.py,sha256=mWaUfaY6PBtO1ioYxXxS-yodQmcBNGasWVMUg9G066Y,15938
- quack/utils.py,sha256=1-HMcFTEvGdAtqC3ucQGZ3DLa_PoJQsqwYlKd9bcXO8,15347
- quack_kernels-0.1.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- quack_kernels-0.1.3.dist-info/METADATA,sha256=DDuEKHLjFx9dFTQV5YtXsnKVFZVoueO7NwhcwOtpw6g,284
- quack_kernels-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- quack_kernels-0.1.3.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
- quack_kernels-0.1.3.dist-info/RECORD,,