quack-kernels 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/__init__.py CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.1.4"
+__version__ = "0.1.6"
 
 from quack.rmsnorm import rmsnorm
 from quack.softmax import softmax
quack/cross_entropy.py CHANGED
@@ -1,3 +1,5 @@
+# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
 import math
 import torch
 from typing import Optional, Type
@@ -102,7 +104,10 @@ class CrossEntropy(ReductionBase):
         shape: cute.Shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-        gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, idX)]
+        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+        mX_off = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mX)
+        gX = cute.local_tile(mX_off, tiler_mn, (0, cluster_y))
+        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
@@ -148,7 +153,9 @@ class CrossEntropy(ReductionBase):
 
         target_logit = cute.Float32.zero
         if row < shape[0] and tXcX[0][1] == 0:
-            target_logit = cute.Float32(mX[row, target])
+            # Use Int64 for indexing to deal with large tensors
+            mX_off = utils.domain_offset_i64((row, 0), mX)
+            target_logit = cute.Float32(mX_off[0, target])
 
         threads_per_row = tv_layout.shape[0][0]
         if cutlass.const_expr(not self.online_softmax):
@@ -200,7 +207,7 @@ class CrossEntropy(ReductionBase):
             mLSE[row] = lse
 
 
-def cross_entropy(
+def _cross_entropy(
     x: torch.Tensor,
     target: torch.Tensor,
     return_lse: bool = False,
@@ -241,15 +248,299 @@ def cross_entropy(
     stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
 
     compile_key = (dtype, N, lse is not None)
-    if compile_key not in cross_entropy.compile_cache:
+    if compile_key not in _cross_entropy.compile_cache:
         cross_entropy_op = CrossEntropy(dtype, N)
-        cross_entropy.compile_cache[compile_key] = cute.compile(
+        _cross_entropy.compile_cache[compile_key] = cute.compile(
             cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
         )
-    cross_entropy.compile_cache[compile_key](
+    _cross_entropy.compile_cache[compile_key](
        x_tensor, target_tensor, loss_tensor, lse_tensor, stream
    )
    return loss if not return_lse else (loss, lse)
 
 
-cross_entropy.compile_cache = {}
+_cross_entropy.compile_cache = {}
+
+
+class CrossEntropyBackward:
+    def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+        self.dtype = dtype
+        self.N = N
+        self.vecsize = 128 // dtype.width
+
+    def _calculate_threads_per_row(self):
+        N = self.N
+        return (
+            8
+            if N <= 64
+            else (
+                16
+                if N <= 128
+                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+            )
+        )
+
+    def _get_tv_layout(self):
+        N = self.N
+        vecsize = self.vecsize
+        num_threads = 128 if N <= 16384 else 256
+        threads_per_row = self._calculate_threads_per_row()
+        cols_per_block = num_threads // threads_per_row
+        num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
+        tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+        tv_layout = cute.make_layout(
+            ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+            stride=(
+                (vecsize * cols_per_block, 1),
+                (cols_per_block, cols_per_block * vecsize * threads_per_row),
+            ),
+        )
+        return tiler_mn, tv_layout
+
+    @cute.jit
+    def __call__(
+        self,
+        mX: cute.Tensor,
+        mTarget: cute.Tensor,
+        mDLoss: cute.Tensor,
+        mdX: cute.Tensor,
+        mLSE: cute.Tensor,
+        stream: cuda.CUstream,
+    ):
+        assert mX.element_type == self.dtype
+        assert mdX.element_type == self.dtype
+
+        tiler_mn, tv_layout = self._get_tv_layout()
+        num_threads = cute.size(tv_layout, mode=[0])
+
+        mDLoss = cute.make_tensor(
+            mDLoss.iterator, cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,)))
+        )
+        mTarget = cute.make_tensor(
+            mTarget.iterator, cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,)))
+        )
+        mLSE = cute.make_tensor(
+            mLSE.iterator, cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,)))
+        )
+
+        smem_size = cute.size_in_bytes(
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0))
+        )
+
+        self.kernel(
+            mX,
+            mTarget,
+            mDLoss,
+            mdX,
+            mLSE,
+            mX.shape,
+            tv_layout,
+            tiler_mn,
+        ).launch(
+            grid=[
+                cute.ceil_div(mX.shape[0], tiler_mn[0]),
+                cute.ceil_div(mX.shape[1], tiler_mn[1]),
+                1,
+            ],
+            block=[num_threads, 1, 1],
+            smem=smem_size,
+            stream=stream,
+        )
+
+    @cute.kernel
+    def kernel(
+        self,
+        mX: cute.Tensor,  # (M, N)
+        mTarget: cute.Tensor,  # (M,)
+        mDLoss: cute.Tensor,  # (M,)
+        mdX: cute.Tensor,  # (M, N)
+        mLSE: cute.Tensor,  # (M,)
+        shape: cute.Shape,
+        tv_layout: cute.Layout,
+        tiler_mn: cute.Shape,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, bidy, _ = cute.arch.block_idx()
+
+        smem = cutlass.utils.SmemAllocator()
+        sX = smem.allocate_tensor(
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+        )
+
+        idX = cute.make_identity_tensor(shape)
+        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+        mX, mdX = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mdX)]
+        gX, gdX = [cute.local_tile(mT, tiler_mn, (0, bidy)) for mT in (mX, mdX)]
+        cX = cute.local_tile(idX, tiler_mn, (bidx, bidy))
+
+        copy_atom_load_X = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
+        )
+        copy_atom_load_X_async = cute.make_copy_atom(
+            cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
+        )
+        copy_atom_store_O = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=128
+        )
+
+        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_X_async = cute.make_tiled_copy(
+            copy_atom_load_X_async, tv_layout, tiler_mn
+        ).get_slice(tidx)
+        thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+        #### Thread View
+        tXgX = thr_copy_X_async.partition_S(gX)
+        tXsX = thr_copy_X_async.partition_S(sX)
+
+        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+        tXcFull = thr_copy_X.partition_S(cX)  # improve
+
+        tXgO = thr_copy_O.partition_D(gdX)
+
+        # allocate fragments for gmem->rmem
+        tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+        is_even_N = cutlass.const_expr(shape[1] % tiler_mn[1] == 0)
+        row = tXcX[0][0]
+
+        tXpX = (
+            utils.predicate_k(thr_copy_X_async.partition_S(cX), limit=shape[1])
+            if not is_even_N
+            else None
+        )
+
+        if row < shape[0]:
+            cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+        cute.arch.cp_async_commit_group()
+        cute.arch.cp_async_wait_group(0)
+        if cutlass.const_expr(not is_even_N):
+            utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+
+        cute.autovec_copy(tXsX, tXrX)
+        x = tXrX.load().to(cute.Float32)
+
+        label = cute.Int32.zero
+        dloss = cute.Float32.zero
+        lse = cute.Float32.zero
+        if row < shape[0]:
+            label = cute.Int32(mTarget[row])
+            dloss = cute.Float32(mDLoss[row])
+            lse = cute.Float32(mLSE[row])
+
+        log2_e = math.log2(math.e)
+        probs = utils.exp2f((x - lse) * log2_e)
+        prob_shifted = probs - 1.0
+
+        mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
+        for i in cutlass.range_constexpr(cute.size(tXcFull)):
+            mask[i] = tXcFull[i][1] == label
+
+        mask = mask.load()
+        grad = cute.where(mask, prob_shifted, probs)
+        grad = grad * dloss
+
+        tXrO.store(grad.to(tXrO.element_type))
+        tOpO = (
+            utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
+        )
+        if row < shape[0]:
+            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+
+
+def _cross_entropy_backward(
+    x: torch.Tensor,
+    target: torch.Tensor,
+    dloss: torch.Tensor,
+    lse: torch.Tensor,
+    inplace_backward: bool = False,
+) -> torch.Tensor:
+    """Cross entropy backward pass.
+    Args:
+        x: Input logits tensor of shape (M, N)
+        target: Target class indices tensor of shape (M,)
+        dloss: Upstream gradients tensor of shape (M,)
+        lse: Log-sum-exp values tensor of shape (M,)
+    Returns:
+        Input gradients tensor of shape (M, N)
+    """
+    assert x.dim() == 2, "Input must be 2D"
+    assert target.dim() == 1, "Target must be 1D"
+    assert dloss.dim() == 1, "dloss must be 1D"
+    assert lse.dim() == 1, "lse must be 1D"
+    assert x.shape[0] == target.shape[0], "Batch dimensions must match"
+    assert x.shape[0] == dloss.shape[0], "Batch dimensions must match"
+    assert x.shape[0] == lse.shape[0], "Batch dimensions must match"
+    assert (
+        x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
+    ), "Tensors must be on CUDA device"
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+    assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
+
+    M, N = x.shape
+    dx = torch.empty_like(x) if not inplace_backward else x
+    dtype = torch2cute_dtype_map[x.dtype]
+
+    convert_from_dlpack = lambda tensor: (
+        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+            mode=0, stride_order=(0, 1)
+        )
+    )
+    x_tensor = convert_from_dlpack(x)
+    dx_tensor = convert_from_dlpack(dx)
+    dloss_tensor = from_dlpack(dloss.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
+    lse_tensor = from_dlpack(lse.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
+    target_tensor = from_dlpack(target.detach(), assumed_align=32).mark_compact_shape_dynamic(
+        mode=0
+    )
+    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+    compile_key = (dtype, N)
+    if compile_key not in _cross_entropy_backward.compile_cache:
+        cross_entropy_backward_op = CrossEntropyBackward(dtype, N)
+        _cross_entropy_backward.compile_cache[compile_key] = cute.compile(
+            cross_entropy_backward_op,
+            x_tensor,
+            target_tensor,
+            dloss_tensor,
+            dx_tensor,
+            lse_tensor,
+            stream,
+        )
+    _cross_entropy_backward.compile_cache[compile_key](
+        x_tensor, target_tensor, dloss_tensor, dx_tensor, lse_tensor, stream
+    )
+    return dx
+
+
+_cross_entropy_backward.compile_cache = {}
+
+
+class CrossEntropyFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, target, inplace_backward=False):
+        loss, lse = _cross_entropy(x, target, return_lse=True)
+        ctx.save_for_backward(x, target, lse)
+        ctx.inplace_backward = inplace_backward
+        return loss
+
+    @staticmethod
+    def backward(ctx, dloss):
+        x, target, lse = ctx.saved_tensors
+        dx = _cross_entropy_backward(x, target, dloss, lse, inplace_backward=ctx.inplace_backward)
+        return dx, None, None
+
+
+def cross_entropy(
+    x: torch.Tensor, target: torch.Tensor, inplace_backward: bool = False
+) -> torch.Tensor:
+    """Cross entropy loss with automatic differentiation support.
+
+    Args:
+        x: Input logits tensor of shape (M, N)
+        target: Target class indices tensor of shape (M,)
+
+    Returns:
+        Cross entropy loss tensor of shape (M,)
+    """
+    return CrossEntropyFunction.apply(x, target, inplace_backward)
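
The net effect of this file's changes is that cross entropy now has a fused backward and plugs into autograd: the old public entry point is renamed to _cross_entropy, CrossEntropyBackward computes (softmax(x) - onehot(target)) * dloss in one pass, and cross_entropy dispatches through CrossEntropyFunction. A minimal usage sketch, assuming a CUDA build of quack-kernels; the shapes, dtypes, and tolerances are illustrative, not taken from the diff:

    import torch
    from quack.cross_entropy import cross_entropy

    M, N = 4096, 8192
    x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    target = torch.randint(0, N, (M,), device="cuda")

    loss = cross_entropy(x, target)   # per-row losses, shape (M,)
    loss.sum().backward()             # runs the CrossEntropyBackward kernel

    # Reference: grad of the summed per-row CE is softmax(x) - onehot(target)
    x_ref = x.detach().float().requires_grad_()
    torch.nn.functional.cross_entropy(x_ref, target, reduction="none").sum().backward()
    print(torch.allclose(x.grad.float(), x_ref.grad, atol=1e-2, rtol=1e-2))

Passing inplace_backward=True lets the backward kernel overwrite the saved logits with the gradient; the diff exposes this but leaves it off by default.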
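
For a concrete sense of CrossEntropyBackward._get_tv_layout, the arithmetic can be replayed by hand. The sketch below just evaluates the expressions from the diff for an assumed N = 4096 and a 16-bit dtype (the numbers are illustrative, not from the diff):

    # Hand-evaluation of _get_tv_layout for an assumed N = 4096, dtype width 16 bits
    N = 4096
    vecsize = 128 // 16                                  # 8 elements per 128-bit copy
    num_threads = 128                                    # N <= 16384
    threads_per_row = 64                                 # 3072 < N <= 6144
    cols_per_block = num_threads // threads_per_row      # 2 rows handled per CTA
    num_blocks_N = (min(N, 16384) // vecsize + threads_per_row - 1) // threads_per_row  # 8
    tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
    print(tiler_mn)                    # (2, 4096): one CTA covers two full rows
    print(vecsize * num_blocks_N)      # 64 elements per thread, in 8 vectorized chunks of 8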
quack/layernorm.py ADDED
@@ -0,0 +1,351 @@
+# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+
+import torch
+from typing import Optional
+
+import cuda.bindings.driver as cuda
+
+import cutlass
+import cutlass.cute as cute
+from cutlass.cute.runtime import from_dlpack
+import quack.utils as utils
+from quack.reduction_base import ReductionBase, torch2cute_dtype_map
+
+
+class LayerNorm(ReductionBase):
+    def __init__(self, dtype: cutlass.Numeric, N: int):
+        super().__init__(dtype, N, stage=2)  # 2 stages for mean and var
+        self.reload_from = None if N <= 16384 else "smem"
+        self.delay_w_load = False
+
+    def _calculate_threads_per_row(self):
+        N = self.N
+        return (
+            8
+            if N <= 64
+            else (
+                16
+                if N <= 128
+                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+            )
+        )
+
+    def _set_cluster_n(self):
+        N = self.N
+        # cluster_n = 4 is faster and cluster_n = 2 for N=64k for some reason
+        # Similarly cluster_n = 8 is faster for N=128k
+        if cutlass.const_expr(self.dtype.width == 16):
+            cluster_n = (
+                1
+                if N <= 16 * 1024
+                else (
+                    2
+                    if N <= 32 * 1024
+                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                )
+            )
+        else:  # fp32
+            cluster_n = (
+                1
+                if N <= 32 * 1024
+                else (
+                    2
+                    if N <= 64 * 1024
+                    else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                )
+            )
+        self.cluster_n = cluster_n
+
+    @cute.jit
+    def __call__(
+        self,
+        mX: cute.Tensor,
+        mW: cute.Tensor,
+        mO: cute.Tensor,
+        mRstd: Optional[cute.Tensor],
+        mMean: Optional[cute.Tensor],
+        stream: cuda.CUstream,
+        eps: cutlass.Float32 = 1e-6,
+    ):
+        assert mX.element_type == self.dtype
+        assert mO.element_type == self.dtype
+        self._set_cluster_n()
+        tiler_mn, tv_layout = self._get_tv_layout()
+        num_threads = cute.size(tv_layout, mode=[0])
+        num_warps = num_threads // cute.arch.WARP_SIZE
+        mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+        mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+        if cutlass.const_expr(mRstd is not None):
+            mRstd_expanded_layout = cute.append(
+                mRstd.layout, cute.make_layout((self.N,), stride=(0,))
+            )
+            mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
+        if cutlass.const_expr(mMean is not None):
+            mMean_expanded_layout = cute.append(
+                mMean.layout, cute.make_layout((self.N,), stride=(0,))
+            )
+            mMean = cute.make_tensor(mMean.iterator, mMean_expanded_layout)
+        self.kernel(mX, mW, mO, mRstd, mMean, eps, tv_layout, tiler_mn, self.reload_from).launch(
+            grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+            block=[num_threads, 1, 1],
+            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            stream=stream,
+        )
+
+    @cute.kernel
+    def kernel(
+        self,
+        mX: cute.Tensor,
+        mW: cute.Tensor,
+        mO: cute.Tensor,
+        mRstd: Optional[cute.Tensor],
+        mMean: Optional[cute.Tensor],
+        eps: cute.Float32,
+        tv_layout: cute.Layout,
+        tiler_mn: cute.Shape,
+        reload_from: cutlass.Constexpr = None,
+        delay_w_load: cutlass.Constexpr = False,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, _, _ = cute.arch.block_idx()
+        if cutlass.const_expr(self.cluster_n > 1):
+            cluster_y = cute.arch.block_idx()[1]
+        else:
+            cluster_y = cutlass.const_expr(0)
+
+        smem = cutlass.utils.SmemAllocator()
+        sX = smem.allocate_tensor(
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+        )
+        reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+        shape = mX.shape
+        idX = cute.make_identity_tensor(shape)
+        # slice for CTAs
+        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+        mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
+        gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
+        gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
+        gRstd = (
+            cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
+            if cutlass.const_expr(mRstd is not None)
+            else None
+        )
+        gMean = (
+            cute.local_tile(mMean, tiler_mn, (bidx, cluster_y))
+            if cutlass.const_expr(mMean is not None)
+            else None
+        )
+
+        # declare the atoms which will be used later for memory copy
+        copy_atom_load_X = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
+        )
+        copy_atom_load_X_async = cute.make_copy_atom(
+            cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+        )
+        copy_atom_load_W = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
+        )
+        copy_atom_store_O = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mO.element_type, num_bits_per_copy=128
+        )
+
+        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(
+            tidx
+        )
+        thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+        tWgW = thr_copy_W.partition_S(gW)
+        tXgX = thr_copy_X.partition_S(gX)
+        tXsX = thr_copy_X.partition_D(sX)
+        tXgO = thr_copy_O.partition_D(gO)
+        tXrRstd = thr_copy_O.partition_D(gRstd) if cutlass.const_expr(mRstd is not None) else None
+        tXrMean = thr_copy_O.partition_D(gMean) if cutlass.const_expr(mMean is not None) else None
+        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+        # allocate fragments for gmem->rmem
+        tWrW = cute.make_fragment_like(tWgW)
+        tXrW = thr_copy_X.retile(tWrW)
+        tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+        self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+        tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+        row = tXcX[0][0]
+        if row < shape[0]:
+            cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+        cute.arch.cp_async_commit_group()
+
+        tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
+        if cutlass.const_expr(not delay_w_load):
+            cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+
+        cute.arch.cp_async_wait_group(0)
+        cute.autovec_copy(tXsX, tXrX)
+        x = tXrX.load().to(cute.Float32)
+        threads_per_row = tv_layout.shape[0][0]
+        sum_x = utils.row_reduce(
+            x,
+            cute.ReductionOp.ADD,
+            threads_per_row,
+            reduction_buffer[None, None, 0],
+            mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+            init_val=0.0,
+            hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+        )
+        mean = sum_x / shape[1]
+        if cutlass.const_expr(reload_from == "smem"):
+            cute.autovec_copy(tXsX, tXrX)
+            x = tXrX.load().to(cute.Float32)
+        elif cutlass.const_expr(reload_from == "gmem"):
+            cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+            x = tXrX.load().to(cute.Float32)
+
+        sum_sq_x_sub_mean = utils.row_reduce(
+            (x - mean) * (x - mean),
+            cute.ReductionOp.ADD,
+            threads_per_row,
+            reduction_buffer[None, None, 1],
+            mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
+            init_val=0.0,
+        )
+        rstd = utils.rsqrt(sum_sq_x_sub_mean / shape[1] + eps)
+        if cutlass.const_expr(mRstd is not None):
+            # Only the thread corresponding to column 0 writes out the rstd to gmem
+            if (
+                tXcX[0][1] == 0
+                and row < shape[0]
+                and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+            ):
+                tXrRstd[0] = rstd
+        if cutlass.const_expr(mMean is not None):
+            # Only the thread corresponding to column 0 writes out the mean to gmem
+            if (
+                tXcX[0][1] == 0
+                and row < shape[0]
+                and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+            ):
+                tXrMean[0] = mean
+        if cutlass.const_expr(delay_w_load):
+            cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+        if cutlass.const_expr(reload_from == "smem"):
+            cute.autovec_copy(tXsX, tXrX)
+            x = tXrX.load().to(cute.Float32)
+        elif cutlass.const_expr(reload_from == "gmem"):
+            cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+            x = tXrX.load().to(cute.Float32)
+        x_hat = (x - mean) * rstd
+        w = tXrW.load().to(cute.Float32)
+        y = x_hat * w
+        tXrO.store(y.to(tXrO.element_type))
+        tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
+        if row < shape[0]:
+            cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+
+
+def layernorm(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    return_rstd: bool = False,
+    return_mean: bool = False,
+) -> torch.Tensor:
+    """LayerNorm forward pass.
+
+    Args:
+        x: Input tensor of shape (M, N)
+        weight: Weight tensor of shape (N,)
+        eps: Small value for numerical stability
+        return_rstd: Whether to return the reciprocal standard deviation
+        return_mean: Whether to return the mean
+
+    Returns:
+        Normalized output tensor of same shape as x
+        If return_rstd is True, also returns rstd tensor of shape (M,)
+        If return_mean is True, also returns mean tensor of shape (M,)
+    """
+    assert x.dim() == 2, "Input must be 2D"
+    assert weight.dim() == 1, "Weight must be 1D"
+    assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
+    assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+    assert weight.dtype == torch.float32, "Weight must be float32"
+    M, N = x.shape
+    device = x.device
+    out = torch.empty_like(x)
+    rstd = torch.empty(M, device=device, dtype=torch.float32) if return_rstd else None
+    mean = torch.empty(M, device=device, dtype=torch.float32) if return_mean else None
+    dtype = torch2cute_dtype_map[x.dtype]
+    convert_from_dlpack = lambda x: (
+        from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
+            mode=0, stride_order=(0, 1)
+        )
+    )
+    x_tensor, out_tensor = [
+        # utils.convert_from_dlpack(t, leading_dim=t.ndim - 1, divisibility=128 // dtype.width)
+        convert_from_dlpack(t)
+        for t in (x, out)
+    ]
+    weight_tensor = utils.convert_from_dlpack(
+        weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+    )
+    rstd_tensor = (
+        from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+        if rstd is not None
+        else None
+    )
+    mean_tensor = (
+        from_dlpack(mean.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+        if mean is not None
+        else None
+    )
+    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+    compile_key = (dtype, N, rstd is not None, mean is not None)
+    if compile_key not in layernorm.compile_cache:
+        rmsnorm_op = LayerNorm(dtype, N)
+        layernorm.compile_cache[compile_key] = cute.compile(
+            rmsnorm_op,
+            x_tensor,
+            weight_tensor,
+            out_tensor,
+            rstd_tensor,
+            mean_tensor,
+            current_stream,
+        )
+    layernorm.compile_cache[compile_key](
+        x_tensor, weight_tensor, out_tensor, rstd_tensor, mean_tensor, current_stream, eps
+    )
+    return (
+        (out, rstd, mean)
+        if return_mean and return_rstd
+        else (
+            (out, rstd)
+            if return_rstd and not return_mean
+            else ((out, mean) if return_mean and not return_rstd else (out))
+        )
+    )
+
+
+layernorm.compile_cache = {}
+
+
+def layernorm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
+    x_f32 = x.float()
+    return torch.nn.functional.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype)
+
+
+def rstd_ref(x: torch.Tensor, eps: float = 1e-6):
+    x_f32 = x.float()
+    mean = x_f32.mean(dim=-1, keepdim=True)
+    var = ((x_f32 - mean) ** 2).mean(dim=-1)
+    return 1.0 / torch.sqrt(var + eps)
+
+
+def mean_ref(x: torch.Tensor) -> torch.Tensor:
+    return x.float().mean(dim=-1)
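
quack/layernorm.py is a new, forward-only LayerNorm built on the same ReductionBase machinery (two reduction stages: mean, then variance). A hedged usage sketch checked against the reference helpers defined at the bottom of the file; sizes and tolerances are illustrative, not from the diff:

    import torch
    from quack.layernorm import layernorm, layernorm_ref, rstd_ref, mean_ref

    x = torch.randn(2048, 8192, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(8192, device="cuda", dtype=torch.float32)   # weight must be float32

    out, rstd, mean = layernorm(x, w, eps=1e-6, return_rstd=True, return_mean=True)

    print(torch.allclose(out.float(), layernorm_ref(x, w).float(), atol=1e-2, rtol=1e-2))
    print(torch.allclose(rstd, rstd_ref(x), atol=1e-3, rtol=1e-3))
    print(torch.allclose(mean, mean_ref(x), atol=1e-3, rtol=1e-3))

Unlike the updated rmsnorm below, this entry point has no autograd.Function wrapper in this release and applies only a weight (no bias); returning rstd and mean is what a hand-written backward would need.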
quack/reduction_base.py CHANGED
@@ -6,8 +6,6 @@ from typing import Type, Tuple, Optional
 import cutlass
 import cutlass.cute as cute
 
-import quack.utils as utils
-
 
 torch2cute_dtype_map = {
     torch.float16: cutlass.Float16,
@@ -39,7 +37,6 @@ class ReductionBase:
         vecsize = copy_bits // self.dtype.width
         assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
         num_threads = self._get_num_threads()
-        num_warps = num_threads // cute.arch.WARP_SIZE
         assert num_threads % cute.arch.WARP_SIZE == 0
 
         threads_per_row = self._calculate_threads_per_row()
@@ -64,7 +61,7 @@ class ReductionBase:
 
     def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-        warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
+        warps_per_row = max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
         return cute.make_ordered_layout(
             (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
             order=(1, 0, 2),
quack/rmsnorm.py CHANGED
@@ -9,7 +9,6 @@ import cuda.bindings.driver as cuda
 import cutlass
 import cutlass.cute as cute
 from cutlass.cute.runtime import from_dlpack
-
 import quack.utils as utils
 from quack.reduction_base import ReductionBase, torch2cute_dtype_map
 
@@ -118,7 +117,10 @@ class RMSNorm(ReductionBase):
         shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-        gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]
+        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+        mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
+        gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
         gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
         gRstd = (
             cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
@@ -210,20 +212,18 @@ class RMSNorm(ReductionBase):
             cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
 
 
-def rmsnorm(
+def _rmsnorm_fwd(
     x: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
     return_rstd: bool = False,
 ) -> torch.Tensor:
     """RMSNorm forward pass.
-
     Args:
         x: Input tensor of shape (M, N)
         weight: Weight tensor of shape (N,)
         eps: Small value for numerical stability
         return_rstd: Whether to return the reciprocal standard deviation
-
     Returns:
         Normalized output tensor of same shape as x
         If return_rstd is True, also returns rstd tensor of shape (M,)
@@ -259,18 +259,18 @@ def rmsnorm(
     )
     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
     compile_key = (dtype, N, rstd is not None)
-    if compile_key not in rmsnorm.compile_cache:
+    if compile_key not in _rmsnorm_fwd.compile_cache:
         rmsnorm_op = RMSNorm(dtype, N)
-        rmsnorm.compile_cache[compile_key] = cute.compile(
+        _rmsnorm_fwd.compile_cache[compile_key] = cute.compile(
             rmsnorm_op, x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream
         )
-    rmsnorm.compile_cache[compile_key](
+    _rmsnorm_fwd.compile_cache[compile_key](
        x_tensor, weight_tensor, out_tensor, rstd_tensor, current_stream, eps
    )
    return (out, rstd) if return_rstd else out
 
 
-rmsnorm.compile_cache = {}
+_rmsnorm_fwd.compile_cache = {}
 
 
 def rmsnorm_ref(x, w, eps=1e-6):
@@ -283,3 +283,383 @@ def rmsnorm_ref(x, w, eps=1e-6):
 def rstd_ref(x, eps=1e-6):
     x_f32 = x.float()
     return 1.0 / torch.sqrt(torch.mean(x_f32 * x_f32, dim=-1) + eps)
+
+
+def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
+    """Reference implementation for RMSNorm backward pass."""
+    x_f32 = x.float()
+    x_hat = x_f32 * rstd.unsqueeze(1)
+    wdy = dout * w
+    c1 = (x_hat * wdy).mean(dim=-1, keepdim=True)
+    dx = (wdy - x_hat * c1) * rstd.unsqueeze(1)
+
+    # dL/dW
+    dw = (dout * x_hat).sum(dim=0)
+    return dx.to(x.dtype), dw.to(w.dtype)
+
+
+class RMSNormBackward(ReductionBase):
+    def __init__(self, dtype: cutlass.Numeric, N: int):
+        # 1 stage for computing mean of x_hat * wdy
+        super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)
+
+    def _calculate_threads_per_row(self):
+        N = self.N
+        return (
+            8
+            if N <= 64
+            else (
+                16
+                if N <= 128
+                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+            )
+        )
+
+    def _set_cluster_n(self):
+        N = self.N
+        if cutlass.const_expr(self.dtype.width == 16):
+            cluster_n = (
+                1
+                if N <= 16 * 1024
+                else (
+                    2
+                    if N <= 32 * 1024
+                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                )
+            )
+        else:  # fp32
+            cluster_n = (
+                1
+                if N <= 32 * 1024
+                else (
+                    2
+                    if N <= 64 * 1024
+                    else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                )
+            )
+        self.cluster_n = cluster_n
+
+    @cute.jit
+    def __call__(
+        self,
+        mX: cute.Tensor,
+        mW: cute.Tensor,
+        mDout: cute.Tensor,
+        mRstd: cute.Tensor,
+        mDx: cute.Tensor,
+        mDw: cute.Tensor,
+        sm_count: cutlass.Constexpr,
+        stream: cuda.CUstream,
+    ):
+        self._set_cluster_n()
+        tiler_mn, tv_layout = self._get_tv_layout()
+        num_threads = cute.size(tv_layout, mode=[0])
+        num_warps = num_threads // cute.arch.WARP_SIZE
+
+        mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+        mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+
+        mRstd_expanded_layout = cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,)))
+        mRstd = cute.make_tensor(mRstd.iterator, mRstd_expanded_layout)
+
+        num_blocks = (
+            sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
+        )
+
+        self.kernel(mX, mW, mDout, mRstd, mDx, mDw, sm_count, tv_layout, tiler_mn).launch(
+            grid=[num_blocks, self.cluster_n, 1],
+            block=[num_threads, 1, 1],
+            cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            stream=stream,
+        )
+
+    @cute.kernel
+    def kernel(
+        self,
+        mX: cute.Tensor,
+        mW: cute.Tensor,
+        mDout: cute.Tensor,
+        mRstd: cute.Tensor,
+        mDx: cute.Tensor,
+        mDw: cute.Tensor,
+        sm_count: cutlass.Constexpr,
+        tv_layout: cute.Layout,
+        tiler_mn: cute.Shape,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, cluster_y, _ = cute.arch.block_idx()
+        gdim, _, _ = cute.arch.grid_dim()
+
+        shape = mX.shape
+        M, N = shape[0], shape[1]
+
+        idX = cute.make_identity_tensor(shape)
+
+        smem = cutlass.utils.SmemAllocator()
+        reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+        copy_atom_load_X = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128
+        )
+
+        copy_atom_load_W = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=128
+        )
+
+        copy_atom_store_dX = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mDx.element_type, num_bits_per_copy=128
+        )
+
+        copy_atom_dw = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), mDw.element_type, num_bits_per_copy=128
+        )
+
+        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_W = cute.make_tiled_copy(copy_atom_load_W, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy_dw = cute.make_tiled_copy(copy_atom_dw, tv_layout, tiler_mn).get_slice(tidx)
+        thr_store_dx = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
+
+        gW = cute.local_tile(mW, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+        tWgW = thr_copy_W.partition_S(gW)
+        tWrW = cute.make_fragment_like(tWgW)
+        tXrW = thr_copy_X.retile(tWrW)
+
+        gW_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
+
+        tWpW = utils.predicate_k(thr_copy_W.partition_S(gW_coord), limit=shape[1])
+        cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
+        weight = tXrW.load().to(cute.Float32)
+
+        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+
+        self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+        dw_coord = cute.local_tile(idX, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
+        tDwpDw = utils.predicate_k(thr_copy_dw.partition_S(dw_coord), limit=shape[1])
+
+        gDw = cute.local_tile(mDw, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+        tDwgDw = thr_copy_dw.partition_D(gDw)
+        tDwrDw = cute.make_fragment_like(tDwgDw)
+        dw_accumulator = thr_copy_X.retile(tDwrDw)
+        dw_accumulator.fill(0.0)
+
+        M_pad = ((M + sm_count - 1) // sm_count) * sm_count
+
+        jump = sm_count if tiler_mn[0] == 1 else min(sm_count, cute.ceil_div(1024, tiler_mn[0]))
+
+        if cutlass.const_expr(self.cluster_n > 1):
+            cute.arch.cluster_arrive()
+            cute.arch.cluster_wait()
+
+        ## need to update range_dynamic since it will be deprecated soon
+        for row_offset in cutlass.range_dynamic(bidx, M_pad, jump):
+            gX = cute.local_tile(
+                mX, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+            )
+            gDout = cute.local_tile(
+                mDout, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+            )
+            gRstd = cute.local_tile(
+                mRstd, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+            )
+            gDx = cute.local_tile(
+                mDx, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+            )
+            cX = cute.local_tile(
+                idX, tiler_mn, (row_offset, 0 if self.cluster_n == 1 else cluster_y)
+            )
+
+            tXgX = thr_copy_X.partition_S(gX)
+            thrDout = thr_copy_X.partition_S(gDout)
+            tXrRstd = thr_copy_W.partition_S(gRstd)
+            thrDx = thr_store_dx.partition_D(gDx)
+            tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+            tXrX, frgDout, frgDx = [cute.make_fragment_like(thr) for thr in (tXgX, thrDout, thrDx)]
+
+            tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+
+            if tXcX[0][0] < shape[0]:
+                cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+                cute.copy(copy_atom_load_X, thrDout, frgDout, pred=tXpX)
+
+            x = tXrX.load().to(cute.Float32)
+            dout = frgDout.load().to(cute.Float32)
+
+            rstd = tXrRstd[0]
+            x_hat = x * rstd
+            wdy = dout * weight
+
+            threads_per_row = tv_layout.shape[0][0]
+
+            row = tXcX[0][0]
+            if cutlass.const_expr(self.cluster_n > 1):
+                cute.arch.cluster_arrive()
+                cute.arch.cluster_wait()
+            else:
+                cute.arch.barrier()
+
+            mean_xhat_wdy = (
+                utils.row_reduce(
+                    x_hat * wdy,
+                    cute.ReductionOp.ADD,
+                    threads_per_row,
+                    reduction_buffer[None, None, 0],
+                    mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+                    init_val=0.0,
+                    hook_fn=cute.arch.cluster_wait
+                    if cutlass.const_expr(self.cluster_n > 1)
+                    else None,
+                )
+                / shape[1]
+            )
+
+            dx = (wdy - x_hat * mean_xhat_wdy) * rstd
+            frgDx.store(dx.to(frgDout.element_type))
+
+            if row < M:
+                cute.copy(copy_atom_store_dX, frgDx, thrDx, pred=tXpX)
+
+            if cutlass.const_expr(self.cluster_n > 1):
+                cute.arch.cluster_arrive()
+                cute.arch.cluster_wait()
+            else:
+                cute.arch.barrier()
+
+            if row < M:
+                dw_row = dout * x_hat
+                current_dw = dw_accumulator.load().to(cute.Float32)
+                updated_dw = current_dw + dw_row
+                dw_accumulator.store(updated_dw.to(dw_accumulator.element_type))
+
+            """
+            if cutlass.const_expr(self.cluster_n > 1):
+                cute.arch.cluster_arrive()
+                cute.arch.cluster_wait()
+            else:
+                cute.arch.barrier()
+            """
+            """
+            if cutlass.const_expr(self.cluster_n > 1):
+                cute.arch.cluster_arrive()
+                cute.arch.cluster_wait()
+            else:
+                cute.arch.barrier()
+            """
+
+        cute.autovec_copy(dw_accumulator, tDwrDw)
+        cute.copy(copy_atom_dw, tDwrDw, tDwgDw, pred=tDwpDw)
+
+
+def _rmsnorm_backward(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    dout: torch.Tensor,
+    rstd: torch.Tensor,
+) -> (torch.Tensor, torch.Tensor):
+    """RMSNorm backward pass.
+    Args:
+        x: Input tensor of shape (M, N)
+        weight: Weight tensor of shape (N,)
+        dout: Upstream gradients tensor of shape (M, N)
+        rstd: Reciprocal standard deviation tensor of shape (M,)
+    Returns:
+        Tuple of (dx, dw) where:
+            - dx: Input gradients tensor of same shape as x
+            - dw: Weight gradients tensor of same shape as weight
+    """
+    assert x.dim() == 2, "Input must be 2D"
+    assert weight.dim() == 1, "Weight must be 1D"
+    assert x.shape[-1] == weight.shape[0], "Last dimension of input must match weight dimension"
+    assert x.is_cuda and weight.is_cuda, "Tensors must be on CUDA device"
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+    assert weight.dtype == torch.float32, "Weight must be float32"
+
+    M, N = x.shape
+    dx = torch.empty_like(x)
+
+    device = x.device
+
+    sm_count = torch.cuda.get_device_properties(device).multi_processor_count * 8
+    dw_partial = torch.zeros((sm_count, N), device=device, dtype=weight.dtype)
+
+    dtype = torch2cute_dtype_map[x.dtype]
+
+    convert_from_dlpack = lambda tensor: (
+        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+            mode=0, stride_order=(0, 1)
+        )
+    )
+
+    x_tensor, dout_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (x, dout, dx)]
+
+    weight_tensor = utils.convert_from_dlpack(
+        weight.detach(), leading_dim=0, divisibility=128 // cutlass.Float32.width
+    )
+
+    dw_partial_tensor = convert_from_dlpack(dw_partial)
+    rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+
+    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+    compile_key = (dtype, N)
+    if compile_key not in _rmsnorm_backward.compile_cache:
+        rmsnorm_backward_op = RMSNormBackward(dtype, N)
+        _rmsnorm_backward.compile_cache[compile_key] = cute.compile(
+            rmsnorm_backward_op,
+            x_tensor,
+            weight_tensor,
+            dout_tensor,
+            rstd_tensor,
+            dx_tensor,
+            dw_partial_tensor,
+            sm_count,
+            current_stream,
+        )
+
+    _rmsnorm_backward.compile_cache[compile_key](
+        x_tensor,
+        weight_tensor,
+        dout_tensor,
+        rstd_tensor,
+        dx_tensor,
+        dw_partial_tensor,
+        current_stream,
+    )
+
+    dw = dw_partial.sum(dim=0).to(weight.dtype)
+    return dx, dw
+
+
+_rmsnorm_backward.compile_cache = {}
+
+
+class RMSNormFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, weight, eps):
+        out, rstd = _rmsnorm_fwd(x, weight, eps, return_rstd=True)
+        ctx.save_for_backward(x, weight, rstd)
+        ctx.eps = eps
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        x, weight, rstd = ctx.saved_tensors
+        dx, dw = _rmsnorm_backward(x, weight, dout, rstd)
+        # dw is returned for weight gradient, None for eps gradient
+        return dx, dw, None
+
+
+def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+    """RMSNorm forward pass with automatic differentiation support.
+
+    Args:
+        x: Input tensor of shape (M, N)
+        weight: Weight tensor of shape (N,)
+        eps: Small value for numerical stability
+
+    Returns:
+        Normalized output tensor of same shape as x
+    """
+    return RMSNormFunction.apply(x, weight, eps)
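
With _rmsnorm_fwd, RMSNormBackward, and RMSNormFunction in place, rmsnorm is now differentiable end to end: the forward caches rstd, the backward kernel produces dx plus per-block partial weight gradients, and the wrapper sums dw_partial on the host. A usage sketch checked against the module's own reference helpers; shapes, dtypes, and tolerances are assumptions, not from the diff:

    import torch
    from quack.rmsnorm import rmsnorm, rmsnorm_bwd_ref, rstd_ref

    x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    w = torch.randn(4096, device="cuda", dtype=torch.float32, requires_grad=True)
    dout = torch.randn_like(x)

    y = rmsnorm(x, w)      # RMSNormFunction.forward saves (x, w, rstd)
    y.backward(dout)       # RMSNormBackward kernel, then dw_partial.sum(dim=0)

    dx_ref, dw_ref = rmsnorm_bwd_ref(x.detach(), w.detach(), dout, rstd_ref(x.detach()))
    print(torch.allclose(x.grad.float(), dx_ref.float(), atol=1e-2, rtol=1e-2))
    print(torch.allclose(w.grad, dw_ref, atol=1e-1, rtol=1e-2))

The dw path is why RMSNormBackward.kernel is written as a persistent loop: each CTA strides over rows with step `jump`, accumulates `dout * x_hat` in registers, and writes its partial tile of dw_partial once at the end, so only a cheap row-sum remains on the host.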
quack/softmax.py CHANGED
@@ -98,7 +98,10 @@ class Softmax(ReductionBase):
         shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-        gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]
+        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+        mX, mO = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
+        gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
@@ -312,9 +315,11 @@ class SoftmaxBackward(ReductionBase):
         shape = mdY.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-        gdY, gY, gdX, cX = [
-            cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mdY, mY, mdX, idX)
+        mdY, mY, mdX = [
+            utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mdY, mY, mdX)
         ]
+        gdY, gY, gdX = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mdY, mY, mdX)]
+        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
 
         smem = cutlass.utils.SmemAllocator()
         sdY = smem.allocate_tensor(
quack/utils.py CHANGED
@@ -23,20 +23,6 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te
     )
 
 
-@cute.jit
-def max_constexpr(
-    a: cutlass.Constexpr[cute.Numeric], b: cutlass.Constexpr[cute.Numeric]
-) -> cutlass.Constexpr[cute.Numeric]:
-    return a if a > b else b
-
-
-@cute.jit
-def min_constexpr(
-    a: cutlass.Constexpr[cute.Numeric], b: cutlass.Constexpr[cute.Numeric]
-) -> cutlass.Constexpr[cute.Numeric]:
-    return a if a < b else b
-
-
 @cute.jit
 def warp_reduce(
     val: cute.TensorSSA | cute.Numeric,
@@ -196,7 +182,7 @@ def row_reduce(
     val = warp_reduce(
         val,
         warp_op,
-        width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
+        width=min(threads_per_row, cute.arch.WARP_SIZE),
    )
    if cutlass.const_expr(hook_fn is not None):
        hook_fn()
@@ -226,7 +212,7 @@ def online_softmax_reduce(
     max_x = warp_reduce(
         x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
         cute.arch.fmax,
-        width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
+        width=min(threads_per_row, cute.arch.WARP_SIZE),
    )
    log2_e = math.log2(math.e)
    exp_x = exp2f(x * log2_e - (max_x * log2_e))
@@ -234,7 +220,7 @@ def online_softmax_reduce(
     sum_exp_x = warp_reduce(
         exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
         operator.add,
-        width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
+        width=min(threads_per_row, cute.arch.WARP_SIZE),
    )
    if cutlass.const_expr(hook_fn is not None):
        hook_fn()
@@ -303,7 +289,6 @@ def online_softmax_reduce(
 @cute.jit
 def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
     """exp2f calculation for both vector and scalar.
-
     :param x: input value
     :type x: cute.TensorSSA or Float32
     :return: exp2 value
@@ -405,3 +390,19 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
         vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)
     )
     return res0, res1
+
+
+@dsl_user_op
+def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
+    flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
+    flat_stride = cute.flatten_to_tuple(tensor.stride)
+    offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
+    assert isinstance(tensor.iterator, cute.Pointer)
+    # HACK: we assume that applying the offset does not change the pointer alignment
+    new_ptr = cute.make_ptr(
+        tensor.element_type,
+        tensor.iterator.toint() + offset * tensor.element_type.width // 8,
+        tensor.memspace,
+        assumed_align=tensor.iterator.max_alignment,
+    )
+    return cute.make_tensor(new_ptr, tensor.layout)
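
domain_offset_i64 is the helper behind the "2^31" comments scattered through the kernels above: it converts a coordinate into an Int64 element offset, applies it to the raw pointer as a byte offset, and rebuilds a tensor with the same layout, so cute.local_tile afterwards only sees small per-CTA coordinates. A rough illustration of why the 64-bit path matters, with assumed sizes (not taken from the diff):

    # Assumed example: (M, N) = (524288, 8192) bf16 logits, compact and row-major.
    M, N = 524_288, 8_192
    print(M * N, M * N > 2**31)            # 4294967296 elements: a 32-bit index overflows

    # Offset that domain_offset_i64 would fold into the pointer for a CTA
    # starting at row 300_000 (row stride = N):
    offset_elems = 300_000 * N             # 2_457_600_000 > 2**31 - 1
    offset_bytes = offset_elems * 16 // 8  # element width 16 bits -> 2 bytes per element
    print(offset_elems, offset_bytes)      # both fit comfortably in Int64

The in-code HACK comment is the real caveat: the new pointer simply inherits max_alignment from the old one. The callers in this release appear to satisfy that because they offset by whole rows and N is asserted divisible by the 128-bit vector size, so the byte offset stays a multiple of the assumed 16-byte alignment.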
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: quack-kernels
-Version: 0.1.4
+Version: 0.1.6
 Requires-Python: >=3.9
 License-File: LICENSE
 Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
@@ -0,0 +1,12 @@
+quack/__init__.py,sha256=EV_43VfcxQUsFu_Nfq0944ImloZ2T9X94-8T9IQM0I8,203
+quack/cross_entropy.py,sha256=bg66wECki5I71SMPIRUa-6-oFJ93aIKpK1jqT__SCBM,19775
+quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
+quack/reduction_base.py,sha256=fFuGXPR3lDq2yw_m86ujmkni6R51jzNAzy_r9R6C8tA,3563
+quack/rmsnorm.py,sha256=rbRP_O-EJYvrvxYPFwfy0yCbG7Qs_DgbzgacxTLvih4,24159
+quack/softmax.py,sha256=b-QEQiGjOEPidolE2--K21kwEaLJ-3wQoGIz_BsEcSI,16742
+quack/utils.py,sha256=laz_lqeggiIOY_xCs3c3VvCPP8bJmSKBfjFpFR1Al80,15946
+quack_kernels-0.1.6.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+quack_kernels-0.1.6.dist-info/METADATA,sha256=NKhSoudW9lNdYryNiMkyjKXYl5n37cuFMbyLv3zr5js,289
+quack_kernels-0.1.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+quack_kernels-0.1.6.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+quack_kernels-0.1.6.dist-info/RECORD,,
@@ -1,11 +0,0 @@
1
- quack/__init__.py,sha256=cFLxO6nA_faFqHf4N-Fy7G0j8ykuYPB1uOt9uoJ2dkQ,203
2
- quack/cross_entropy.py,sha256=HnF2OErEzb10SWxY6HoYE42lnvlw2DsWCks7mylPwnI,9511
3
- quack/reduction_base.py,sha256=Rsj9ZeSHcKAXGn1p7mY1vrrBqxevi4feLjY0JJhKnmY,3663
4
- quack/rmsnorm.py,sha256=TkOZsXJwcsoZMLnmEWQ-pEF0r-iiZhGrCNLSFCXfv6s,10676
5
- quack/softmax.py,sha256=VfhlC2huRuv7olFSVFgS8LF1yF8TFV64yjjjQxYX9yk,16364
6
- quack/utils.py,sha256=zVc9U-5No19trE585KqDdXx9chAruXPRIPMZdO7mkRg,15603
7
- quack_kernels-0.1.4.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
8
- quack_kernels-0.1.4.dist-info/METADATA,sha256=xl62C5WFgiUbnOICAzjldsljJ9j1Fb_JxZVksHLCI8I,289
9
- quack_kernels-0.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
10
- quack_kernels-0.1.4.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
11
- quack_kernels-0.1.4.dist-info/RECORD,,