quack-kernels 0.1.7__tar.gz → 0.1.9__tar.gz

This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (23)
  1. {quack_kernels-0.1.7/quack_kernels.egg-info → quack_kernels-0.1.9}/PKG-INFO +1 -1
  2. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack/__init__.py +1 -1
  3. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack/cross_entropy.py +56 -15
  4. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack/rmsnorm.py +232 -106
  5. {quack_kernels-0.1.7 → quack_kernels-0.1.9/quack_kernels.egg-info}/PKG-INFO +1 -1
  6. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack_kernels.egg-info/top_level.txt +1 -0
  7. quack_kernels-0.1.9/tests/test_rmsnorm.py +398 -0
  8. quack_kernels-0.1.7/tests/test_rmsnorm.py +0 -183
  9. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/LICENSE +0 -0
  10. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/README.md +0 -0
  11. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/pyproject.toml +0 -0
  12. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack/layernorm.py +0 -0
  13. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack/reduction_base.py +0 -0
  14. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack/softmax.py +0 -0
  15. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack/utils.py +0 -0
  16. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack_kernels.egg-info/SOURCES.txt +0 -0
  17. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack_kernels.egg-info/dependency_links.txt +0 -0
  18. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack_kernels.egg-info/requires.txt +0 -0
  19. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/setup.cfg +0 -0
  20. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/setup.py +0 -0
  21. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/tests/test_cross_entropy.py +0 -0
  22. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/tests/test_layernorm.py +0 -0
  23. {quack_kernels-0.1.7 → quack_kernels-0.1.9}/tests/test_softmax.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: quack-kernels
-Version: 0.1.7
+Version: 0.1.9
 Requires-Python: >=3.9
 License-File: LICENSE
 Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
quack/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.1.7"
+__version__ = "0.1.9"
 
 from quack.rmsnorm import rmsnorm
 from quack.softmax import softmax
quack/cross_entropy.py
@@ -1,16 +1,16 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
 import math
-import torch
 from typing import Optional, Type
 
 import cuda.bindings.driver as cuda
 
 import cutlass
 import cutlass.cute as cute
-from cutlass.cute.runtime import from_dlpack
 
 import quack.utils as utils
+import torch
+from cutlass.cute.runtime import from_dlpack
 from quack.reduction_base import ReductionBase, torch2cute_dtype_map
 
 
@@ -79,7 +79,7 @@ class CrossEntropy(ReductionBase):
         self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
             block=[num_threads, 1, 1],
-            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+            cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
             stream=stream,
         )
@@ -111,7 +111,9 @@ class CrossEntropy(ReductionBase):
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
-            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+            mX.element_type,
+            cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+            byte_alignment=16,
         )
         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
 
@@ -166,7 +168,9 @@ class CrossEntropy(ReductionBase):
             reduction_buffer[None, None, 0],
             mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
             init_val=-cutlass.Float32.inf,
-            hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+            hook_fn=(
+                cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
+            ),
         )
         if cutlass.const_expr(self.reload_from == "smem"):
             cute.autovec_copy(tXsX, tXrX)
@@ -191,7 +195,9 @@ class CrossEntropy(ReductionBase):
             threads_per_row,
             reduction_buffer[None, None, 0],
             mbar_ptr,
-            hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+            hook_fn=(
+                cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
+            ),
         )
 
         if (
@@ -225,7 +231,11 @@ def _cross_entropy(
     assert target.dim() == 1, "Target must be 1D"
     assert x.shape[0] == target.shape[0], "Batch dimensions must match"
     assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
-    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+    assert x.dtype in [
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+    ], "Unsupported input dtype"
     assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
     M, N = x.shape
     device = x.device
@@ -314,13 +324,16 @@ class CrossEntropyBackward:
         num_threads = cute.size(tv_layout, mode=[0])
 
         mDLoss = cute.make_tensor(
-            mDLoss.iterator, cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,)))
+            mDLoss.iterator,
+            cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,))),
         )
         mTarget = cute.make_tensor(
-            mTarget.iterator, cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,)))
+            mTarget.iterator,
+            cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,))),
         )
         mLSE = cute.make_tensor(
-            mLSE.iterator, cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,)))
+            mLSE.iterator,
+            cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,))),
         )
 
         smem_size = cute.size_in_bytes(
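Note on this hunk (a formatting-only change): the `cute.make_layout((self.N,), stride=(0,))` being appended builds a stride-0 layout, which broadcasts the 1-D per-row tensors (dloss, target, lse) across all N columns without materializing any copies. A minimal sketch of the same stride-0 broadcast idea in plain PyTorch, with made-up shapes (this is illustrative, not the kernel's actual code):

# Stride-0 broadcast sketch: view a (M,) tensor as (M, N) where the N
# dimension has stride 0, so every column of a row aliases the same
# underlying element -- no data is copied.
import torch

M, N = 4, 8
dloss = torch.arange(M, dtype=torch.float32)    # per-row values, shape (M,)
view = dloss.as_strided((M, N), (1, 0))         # stride 0 along the N axis
assert view.shape == (M, N)
assert torch.equal(view[:, 0], view[:, N - 1])  # all columns read the same value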
@@ -364,7 +377,9 @@ class CrossEntropyBackward:
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
-            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+            mX.element_type,
+            cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+            byte_alignment=16,
         )
 
         idX = cute.make_identity_tensor(shape)
@@ -474,7 +489,11 @@ def _cross_entropy_backward(
     assert (
         x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
     ), "Tensors must be on CUDA device"
-    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+    assert x.dtype in [
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+    ], "Unsupported input dtype"
     assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
 
     M, N = x.shape
@@ -532,15 +551,37 @@ class CrossEntropyFunction(torch.autograd.Function):
 
 
 def cross_entropy(
-    x: torch.Tensor, target: torch.Tensor, inplace_backward: bool = False
+    x: torch.Tensor,
+    target: torch.Tensor,
+    inplace_backward: bool = True,
+    reduction: str = "none",
 ) -> torch.Tensor:
     """Cross entropy loss with automatic differentiation support.
 
     Args:
         x: Input logits tensor of shape (M, N)
         target: Target class indices tensor of shape (M,)
+        inplace_backward: Whether to perform backward pass in-place
+        reduction: Specifies the reduction to apply to the output:
+            'none': no reduction will be applied (default)
+            'mean': the sum of the output will be divided by the number of elements
+            'sum': the output will be summed
 
     Returns:
-        Cross entropy loss tensor of shape (M,)
+        Cross entropy loss tensor:
+        - If reduction='none': tensor of shape (M,) with per-example losses
+        - If reduction='mean': scalar tensor with mean loss
+        - If reduction='sum': scalar tensor with sum of losses
     """
-    return CrossEntropyFunction.apply(x, target, inplace_backward)
+    loss = CrossEntropyFunction.apply(x, target, inplace_backward)
+
+    if reduction == "mean":
+        return loss.mean()
+    elif reduction == "sum":
+        return loss.sum()
+    elif reduction == "none":
+        return loss
+    else:
+        raise ValueError(
+            f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', or 'sum'"
+        )
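This last hunk is the one user-visible API change in 0.1.9: `cross_entropy` gains a `reduction` argument and flips the default of `inplace_backward` from False to True. A hypothetical usage sketch of the new signature (assumes a CUDA device; the import path follows the module shown in this diff, and the shapes are made up for illustration):

# Usage sketch of the updated cross_entropy API from quack/cross_entropy.py.
import torch
from quack.cross_entropy import cross_entropy

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, 4096, (8,), device="cuda", dtype=torch.int64)

per_example = cross_entropy(x, target)                  # reduction="none": shape (8,)
mean_loss = cross_entropy(x, target, reduction="mean")  # scalar
mean_loss.backward()  # note: inplace_backward now defaults to True

Since the reduction is applied outside `CrossEntropyFunction.apply`, autograd handles the `mean()`/`sum()` step itself and the custom backward kernel still only ever sees per-example losses.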