quack-kernels 0.1.7.tar.gz → 0.1.9.tar.gz
This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {quack_kernels-0.1.7/quack_kernels.egg-info → quack_kernels-0.1.9}/PKG-INFO +1 -1
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack/__init__.py +1 -1
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack/cross_entropy.py +56 -15
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack/rmsnorm.py +232 -106
- {quack_kernels-0.1.7 → quack_kernels-0.1.9/quack_kernels.egg-info}/PKG-INFO +1 -1
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack_kernels.egg-info/top_level.txt +1 -0
- quack_kernels-0.1.9/tests/test_rmsnorm.py +398 -0
- quack_kernels-0.1.7/tests/test_rmsnorm.py +0 -183
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/LICENSE +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/README.md +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/pyproject.toml +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack/layernorm.py +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack/reduction_base.py +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack/softmax.py +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack/utils.py +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack_kernels.egg-info/SOURCES.txt +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack_kernels.egg-info/dependency_links.txt +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/quack_kernels.egg-info/requires.txt +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/setup.cfg +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/setup.py +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/tests/test_cross_entropy.py +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/tests/test_layernorm.py +0 -0
- {quack_kernels-0.1.7 → quack_kernels-0.1.9}/tests/test_softmax.py +0 -0
--- quack_kernels-0.1.7/quack/cross_entropy.py
+++ quack_kernels-0.1.9/quack/cross_entropy.py
@@ -1,16 +1,16 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
 import math
-import torch
 from typing import Optional, Type
 
 import cuda.bindings.driver as cuda
 
 import cutlass
 import cutlass.cute as cute
-from cutlass.cute.runtime import from_dlpack
 
 import quack.utils as utils
+import torch
+from cutlass.cute.runtime import from_dlpack
 from quack.reduction_base import ReductionBase, torch2cute_dtype_map
 
 
@@ -79,7 +79,7 @@ class CrossEntropy(ReductionBase):
         self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
             block=[num_threads, 1, 1],
-            cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+            cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
             stream=stream,
         )
@@ -111,7 +111,9 @@ class CrossEntropy(ReductionBase):
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
-            mX.element_type,
+            mX.element_type,
+            cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+            byte_alignment=16,
         )
         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
 
@@ -166,7 +168,9 @@ class CrossEntropy(ReductionBase):
             reduction_buffer[None, None, 0],
             mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
             init_val=-cutlass.Float32.inf,
-            hook_fn=
+            hook_fn=(
+                cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
+            ),
         )
         if cutlass.const_expr(self.reload_from == "smem"):
             cute.autovec_copy(tXsX, tXrX)
@@ -191,7 +195,9 @@ class CrossEntropy(ReductionBase):
             threads_per_row,
             reduction_buffer[None, None, 0],
             mbar_ptr,
-            hook_fn=
+            hook_fn=(
+                cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
+            ),
         )
 
         if (
@@ -225,7 +231,11 @@ def _cross_entropy(
     assert target.dim() == 1, "Target must be 1D"
     assert x.shape[0] == target.shape[0], "Batch dimensions must match"
    assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
-    assert x.dtype in [
+    assert x.dtype in [
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+    ], "Unsupported input dtype"
     assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
     M, N = x.shape
     device = x.device
@@ -314,13 +324,16 @@ class CrossEntropyBackward:
         num_threads = cute.size(tv_layout, mode=[0])
 
         mDLoss = cute.make_tensor(
-            mDLoss.iterator,
+            mDLoss.iterator,
+            cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,))),
         )
         mTarget = cute.make_tensor(
-            mTarget.iterator,
+            mTarget.iterator,
+            cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,))),
         )
         mLSE = cute.make_tensor(
-            mLSE.iterator,
+            mLSE.iterator,
+            cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,))),
         )
 
         smem_size = cute.size_in_bytes(
@@ -364,7 +377,9 @@ class CrossEntropyBackward:
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
-            mX.element_type,
+            mX.element_type,
+            cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+            byte_alignment=16,
         )
 
         idX = cute.make_identity_tensor(shape)
@@ -474,7 +489,11 @@ def _cross_entropy_backward(
     assert (
         x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
     ), "Tensors must be on CUDA device"
-    assert x.dtype in [
+    assert x.dtype in [
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+    ], "Unsupported input dtype"
     assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
 
     M, N = x.shape
@@ -532,15 +551,37 @@ class CrossEntropyFunction(torch.autograd.Function):
 
 
 def cross_entropy(
-    x: torch.Tensor,
+    x: torch.Tensor,
+    target: torch.Tensor,
+    inplace_backward: bool = True,
+    reduction: str = "none",
 ) -> torch.Tensor:
     """Cross entropy loss with automatic differentiation support.
 
     Args:
         x: Input logits tensor of shape (M, N)
         target: Target class indices tensor of shape (M,)
+        inplace_backward: Whether to perform backward pass in-place
+        reduction: Specifies the reduction to apply to the output:
+            'none': no reduction will be applied (default)
+            'mean': the sum of the output will be divided by the number of elements
+            'sum': the output will be summed
 
     Returns:
-        Cross entropy loss tensor
+        Cross entropy loss tensor:
+        - If reduction='none': tensor of shape (M,) with per-example losses
+        - If reduction='mean': scalar tensor with mean loss
+        - If reduction='sum': scalar tensor with sum of losses
     """
-
+    loss = CrossEntropyFunction.apply(x, target, inplace_backward)
+
+    if reduction == "mean":
+        return loss.mean()
+    elif reduction == "sum":
+        return loss.sum()
+    elif reduction == "none":
+        return loss
+    else:
+        raise ValueError(
+            f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', or 'sum'"
+        )
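
The `cute.append(layout, cute.make_layout((self.N,), stride=(0,)))` edits in `CrossEntropyBackward` give the 1-D `dloss`, `target`, and `lse` tensors a broadcast column dimension: with stride 0, every column index maps back to the same element, so the backward kernel can index them like the (M, N) logits without materializing copies. A rough PyTorch analogy of the same stride-0 trick (not the kernel's actual code):

```python
import torch

# View a (M,) tensor as (M, N) without copying: the expanded
# dimension gets stride 0, mirroring the stride=(0,) CuTe layout.
lse = torch.randn(8)
lse_2d = lse.unsqueeze(1).expand(8, 4)  # shape (8, 4), no data copied
print(lse_2d.stride())                  # (1, 0): column index never moves memory
```

The user-facing change is the new `reduction` argument on `cross_entropy`; per the docstring above it defaults to `'none'` (unlike `torch.nn.functional.cross_entropy`, which defaults to `'mean'`). A minimal usage sketch, assuming quack-kernels 0.1.9 on a CUDA device; the shapes are illustrative, and the import path follows the `quack/cross_entropy.py` layout above:

```python
import torch
from quack.cross_entropy import cross_entropy

x = torch.randn(8, 32768, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, 32768, (8,), device="cuda")  # int64, as the asserts require

per_example = cross_entropy(x, target, reduction="none")  # shape (8,), the 0.1.7 behavior
loss = cross_entropy(x, target, reduction="mean")         # scalar
loss.backward()  # with inplace_backward=True (the default), the backward pass runs in-place
```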
|