quack-kernels 0.1.2__tar.gz → 0.1.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. {quack_kernels-0.1.2/quack_kernels.egg-info → quack_kernels-0.1.3}/PKG-INFO +4 -1
  2. {quack_kernels-0.1.2 → quack_kernels-0.1.3}/README.md +10 -1
  3. {quack_kernels-0.1.2 → quack_kernels-0.1.3}/pyproject.toml +16 -1
  4. {quack_kernels-0.1.2 → quack_kernels-0.1.3}/quack/__init__.py +7 -1
  5. quack_kernels-0.1.3/quack/cross_entropy.py +252 -0
  6. quack_kernels-0.1.3/quack/reduction_base.py +98 -0
  7. quack_kernels-0.1.3/quack/rmsnorm.py +284 -0
  8. quack_kernels-0.1.3/quack/softmax.py +448 -0
  9. quack_kernels-0.1.3/quack/utils.py +404 -0
  10. {quack_kernels-0.1.2 → quack_kernels-0.1.3/quack_kernels.egg-info}/PKG-INFO +4 -1
  11. {quack_kernels-0.1.2 → quack_kernels-0.1.3}/quack_kernels.egg-info/SOURCES.txt +1 -0
  12. {quack_kernels-0.1.2 → quack_kernels-0.1.3}/quack_kernels.egg-info/requires.txt +4 -0
  13. {quack_kernels-0.1.2 → quack_kernels-0.1.3}/tests/test_cross_entropy.py +3 -3
  14. {quack_kernels-0.1.2 → quack_kernels-0.1.3}/tests/test_softmax.py +75 -16
  15. quack_kernels-0.1.2/quack/cross_entropy.py +0 -221
  16. quack_kernels-0.1.2/quack/rmsnorm.py +0 -254
  17. quack_kernels-0.1.2/quack/softmax.py +0 -195
  18. quack_kernels-0.1.2/quack/utils.py +0 -246
  19. {quack_kernels-0.1.2 → quack_kernels-0.1.3}/LICENSE +0 -0
  20. {quack_kernels-0.1.2 → quack_kernels-0.1.3}/quack_kernels.egg-info/dependency_links.txt +0 -0
  21. {quack_kernels-0.1.2 → quack_kernels-0.1.3}/quack_kernels.egg-info/top_level.txt +0 -0
  22. {quack_kernels-0.1.2 → quack_kernels-0.1.3}/setup.cfg +0 -0
  23. {quack_kernels-0.1.2 → quack_kernels-0.1.3}/setup.py +0 -0
  24. {quack_kernels-0.1.2 → quack_kernels-0.1.3}/tests/test_rmsnorm.py +0 -0
@@ -1,8 +1,11 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.1.2
+ Version: 0.1.3
  Requires-Python: >=3.9
  License-File: LICENSE
  Requires-Dist: nvidia-cutlass-dsl==4.0.0
  Requires-Dist: torch
+ Provides-Extra: dev
+ Requires-Dist: pre-commit; extra == "dev"
+ Requires-Dist: ruff; extra == "dev"
  Dynamic: license-file
@@ -17,7 +17,7 @@ pip install quack-kernels
  ## Kernels 🐥

  - 🦆 RMSNorm forward
- - 🦆 Softmax forward
+ - 🦆 Softmax forward and backward
  - 🦆 Cross entropy forward


@@ -26,3 +26,12 @@ pip install quack-kernels
  ```
  from quack import rmsnorm, softmax, cross_entropy
  ```
+
+ ## Development
+
+ To set up the development environment:
+
+ ```bash
+ pip install -e .[dev]
+ pre-commit install
+ ```
@@ -11,8 +11,23 @@ dependencies = [
      "torch",
  ]

+ [project.optional-dependencies]
+ dev = [
+     "pre-commit",
+     "ruff",
+ ]
+
  [tool.setuptools.packages.find]
  exclude = ["tests", "benchmarks"]

  [tool.setuptools.dynamic]
- version = {attr = "quack.__version__"}
+ version = {attr = "quack.__version__"}
+
+ [tool.ruff]
+ line-length = 100
+
+ [tool.ruff.lint]
+ ignore = [
+     "E731",  # do not assign a lambda expression, use a def
+     "F841",  # local variable is assigned to but never used
+ ]
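
With the new `dev` extra and ruff settings in place, the lint tooling can be run locally along the lines of:

```bash
pip install -e .[dev]       # pulls in pre-commit and ruff
pre-commit install          # set up the git hooks
ruff check .                # lints with the [tool.ruff.lint] settings above
```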
@@ -1,5 +1,11 @@
- __version__ = "0.1.2"
+ __version__ = "0.1.3"

  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax
  from quack.cross_entropy import cross_entropy
+
+ __all__ = [
+     "rmsnorm",
+     "softmax",
+     "cross_entropy",
+ ]
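
As a quick orientation for the re-exported API, a minimal usage sketch (shapes and dtypes chosen to satisfy the assertions in the `cross_entropy` wrapper added below; `rmsnorm` and `softmax` follow the same import path but their signatures are not shown in this hunk):

```python
import torch
from quack import cross_entropy

# (M, N) logits on the GPU; fp16/bf16/fp32 are the dtypes the wrapper accepts
x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
target = torch.randint(0, 4096, (8,), device="cuda")

loss = cross_entropy(x, target)                        # (8,) fp32 per-row losses
loss, lse = cross_entropy(x, target, return_lse=True)  # optionally also the log-sum-exp
```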
@@ -0,0 +1,252 @@
+ import math
+ import torch
+ from typing import Optional, Type
+
+ import cuda.bindings.driver as cuda
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass.cute.runtime import from_dlpack
+
+ import quack.utils as utils
+ from quack.reduction_base import ReductionBase, torch2cute_dtype_map
+
+
+ class CrossEntropy(ReductionBase):
+     def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
+         # 2 stages: 1 for max, 1 for sum
+         super().__init__(
+             dtype,
+             N,
+             stage=2 if not online_softmax else 1,
+             reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
+         )
+         self.online_softmax = online_softmax
+         self.reload_from = None if N <= 16384 or online_softmax else "smem"
+
+     def _calculate_threads_per_row(self):
+         N = self.N
+         return (
+             8
+             if N <= 64
+             else (
+                 16
+                 if N <= 128
+                 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+             )
+         )
+
+     def _set_cluster_n(self):
+         N = self.N
+         if cutlass.const_expr(self.dtype.width == 16):
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 32 * 1024
+                     else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                 )
+             )
+         else:  # fp32
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 64 * 1024
+                     else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                 )
+             )
+         self.cluster_n = cluster_n
+
+     @cute.jit
+     def __call__(
+         self,
+         mX: cute.Tensor,
+         mTarget: cute.Tensor,
+         mLoss: cute.Tensor,
+         mLSE: Optional[cute.Tensor],
+         stream: cuda.CUstream,
+     ):
+         assert mX.element_type == self.dtype
+         self._set_cluster_n()
+         tiler_mn, tv_layout = self._get_tv_layout()
+         num_threads = cute.size(tv_layout, mode=[0])
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
+             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+             block=[num_threads, 1, 1],
+             cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+             stream=stream,
+         )
+
+     @cute.kernel
+     def kernel(
+         self,
+         mX: cute.Tensor,  # (M, N)
+         mTarget: cute.Tensor,  # (M,)
+         mLoss: cute.Tensor,  # (M,)
+         mLSE: Optional[cute.Tensor],  # (M,)
+         tv_layout: cute.Layout,
+         tiler_mn: cute.Shape,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         bidx, cluster_y, _ = cute.arch.block_idx()
+
+         shape: cute.Shape = mX.shape
+         idX = cute.make_identity_tensor(shape)
+         # slice for CTAs
+         gX, cX = [
+             cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+             for mT in (mX, idX)
+         ]
+
+         smem = cutlass.utils.SmemAllocator()
+         sX = smem.allocate_tensor(
+             mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+         )
+         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+         # declare the atoms which will be used later for memory copy
+         copy_atom_load_X = cute.make_copy_atom(
+             cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
+         )
+         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+
+         #### Thread View
+         tXgX = thr_copy_X.partition_S(gX)
+         tXsX = thr_copy_X.partition_D(sX)
+         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+         tXrX = cute.make_fragment_like(tXgX)
+
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+         row = tXcX[0][0]
+         target = cute.Int32.zero
+         if row < shape[0] and tXcX[0][1] == 0:
+             target = cute.Int32(mTarget[row])
+
+         is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+         tXpX = (
+             utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
+         )
+         if row < shape[0]:
+             cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+         cute.arch.cp_async_commit_group()
+         cute.arch.cp_async_wait_group(0)
+         # Fill OOB values with -inf
+         if cutlass.const_expr(not is_even_N):
+             utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+         cute.autovec_copy(tXsX, tXrX)
+         x = tXrX.load().to(cute.Float32)
+
+         target_logit = cute.Float32.zero
+         if row < shape[0] and tXcX[0][1] == 0:
+             target_logit = cute.Float32(mX[row, target])
+
+         threads_per_row = tv_layout.shape[0][0]
+         if cutlass.const_expr(not self.online_softmax):
+             max_x = utils.row_reduce(
+                 x,
+                 cute.ReductionOp.MAX,
+                 threads_per_row,
+                 reduction_buffer[None, None, 0],
+                 mbar_ptr + 0 if self.cluster_n > 1 else None,
+                 init_val=-cutlass.Float32.inf,
+                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+             )
+             if cutlass.const_expr(self.reload_from == "smem"):
+                 cute.autovec_copy(tXsX, tXrX)
+                 x = tXrX.load().to(cute.Float32)
+             log2_e = math.log2(math.e)
+             # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
+             # a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
+             # exp_x = utils.exp2f((x - max_x) * log2_e)
+             # This would use ffma instead of fadd then fmul
+             exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
+             denom = utils.row_reduce(
+                 exp_x,
+                 cute.ReductionOp.ADD,
+                 threads_per_row,
+                 reduction_buffer[None, None, 1],
+                 mbar_ptr + 1 if self.cluster_n > 1 else None,
+                 init_val=0.0,
+             )
+         else:
+             max_x, denom, _ = utils.online_softmax_reduce(
+                 x,
+                 threads_per_row,
+                 reduction_buffer[None, None, 0],
+                 mbar_ptr,
+                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+             )
+
+         if (
+             tXcX[0][1] == 0
+             and row < shape[0]
+             and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+         ):
+             ln_2 = math.log(2.0)
+             lse = max_x + utils.log2f(denom) * ln_2
+             loss_val = lse - target_logit
+             mLoss[row] = loss_val.to(mLoss.element_type)
+             if cutlass.const_expr(mLSE is not None):
+                 mLSE[row] = lse
+
+
+ def cross_entropy(
+     x: torch.Tensor,
+     target: torch.Tensor,
+     return_lse: bool = False,
+ ) -> torch.Tensor:
+     """Cross entropy forward pass.
+
+     Args:
+         x: Input logits tensor of shape (M, N)
+         target: Target class indices tensor of shape (M,)
+
+     Returns:
+         Cross entropy loss tensor of shape (M,)
+     """
+     assert x.dim() == 2, "Input must be 2D"
+     assert target.dim() == 1, "Target must be 1D"
+     assert x.shape[0] == target.shape[0], "Batch dimensions must match"
+     assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
+     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+     assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
+     M, N = x.shape
+     device = x.device
+     loss = torch.empty(M, device=device, dtype=torch.float32)
+     lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
+     dtype = torch2cute_dtype_map[x.dtype]
+     convert_from_dlpack = lambda tensor: (
+         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+             mode=0, stride_order=(0, 1)
+         )
+     )
+     x_tensor = convert_from_dlpack(x)
+     loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
+     lse_tensor = (
+         from_dlpack(lse.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
+         if lse is not None
+         else None
+     )
+     target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_compact_shape_dynamic(mode=0)
+     stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+     compile_key = (dtype, N, lse is not None)
+     if compile_key not in cross_entropy.compile_cache:
+         cross_entropy_op = CrossEntropy(dtype, N)
+         cross_entropy.compile_cache[compile_key] = cute.compile(
+             cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
+         )
+     cross_entropy.compile_cache[compile_key](
+         x_tensor, target_tensor, loss_tensor, lse_tensor, stream
+     )
+     return loss if not return_lse else (loss, lse)
+
+
+ cross_entropy.compile_cache = {}
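
The kernel stays in base 2 throughout: it exponentiates with `exp2f(x · log2 e − max_x · log2 e)` (one ffma per element instead of a subtract followed by a multiply) and converts the reduced sum back with `lse = max_x + log2(denom) · ln 2`. A small PyTorch sketch verifying that this identity matches the usual log-sum-exp and cross entropy definitions (reference math only, not the kernel path):

```python
import math
import torch

# lse = max_x + log2(sum(exp2(x*log2e - max_x*log2e))) * ln2  ==  logsumexp(x)
x = torch.randn(4, 1024, dtype=torch.float32)
log2_e = math.log2(math.e)
max_x = x.amax(dim=-1, keepdim=True)
denom = torch.exp2(x * log2_e - max_x * log2_e).sum(dim=-1)
lse = max_x.squeeze(-1) + torch.log2(denom) * math.log(2.0)
assert torch.allclose(lse, torch.logsumexp(x, dim=-1), atol=1e-5)

# the per-row loss the kernel writes out is then lse - target logit
target = torch.randint(0, 1024, (4,))
loss = lse - x.gather(1, target[:, None]).squeeze(1)
ref = torch.nn.functional.cross_entropy(x, target, reduction="none")
assert torch.allclose(loss, ref, atol=1e-5)
```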
@@ -0,0 +1,98 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+ import torch
+ from typing import Type, Tuple, Optional
+
+ import cutlass
+ import cutlass.cute as cute
+
+ import quack.utils as utils
+
+
+ torch2cute_dtype_map = {
+     torch.float16: cutlass.Float16,
+     torch.bfloat16: cutlass.BFloat16,
+     torch.float32: cutlass.Float32,
+ }
+
+
+ class ReductionBase:
+     def __init__(
+         self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=cutlass.Float32
+     ):
+         self.dtype = dtype
+         self.N = N
+         self.stage = stage
+         self.reduction_dtype = reduction_dtype
+
+     def _calculate_threads_per_row(self):
+         raise NotImplementedError()
+
+     def _set_cluster_n(self):
+         self.cluster_n = 1
+
+     def _get_num_threads(self):
+         return 128 if self.N <= 16384 else 256
+
+     def _get_tv_layout(self):
+         copy_bits = 128
+         vecsize = copy_bits // self.dtype.width
+         assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
+         num_threads = self._get_num_threads()
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         assert num_threads % cute.arch.WARP_SIZE == 0
+
+         threads_per_row = self._calculate_threads_per_row()
+         num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n)
+         cols_per_block = num_threads // threads_per_row
+         tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+         tv_layout = cute.make_layout(
+             ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+             stride=(
+                 (vecsize * cols_per_block, 1),
+                 (cols_per_block, cols_per_block * vecsize * threads_per_row),
+             ),
+         )
+         return tiler_mn, tv_layout
+
+     def _smem_size_in_bytes(self, tiler_mn, num_warps):
+         return (
+             cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
+             + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+             + self.stage * (cutlass.Int64.width // 8)
+         )
+
+     def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
+         return cute.make_ordered_layout(
+             (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
+             order=(1, 0, 2),
+         )
+
+     def _allocate_reduction_buffer_and_mbar(
+         self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout
+     ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
+         reduction_buffer = smem.allocate_tensor(
+             self.reduction_dtype,
+             self._get_reduction_buffer_layout(tv_layout, self.cluster_n),
+             byte_alignment=4,
+         )
+         if cutlass.const_expr(self.cluster_n > 1):
+             mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage)
+         else:
+             mbar_ptr = None
+         return reduction_buffer, mbar_ptr
+
+     @cute.jit
+     def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_warps: int):
+         if cutlass.const_expr(self.cluster_n > 1):
+             if tidx < self.stage:
+                 cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + tidx, 1)
+             cute.arch.mbarrier_init_fence()
+             if tidx < self.stage:
+                 cute.arch.mbarrier_init_tx_bytes(
+                     mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
+                 )
+             # Cluster arrive after barrier init
+             cute.arch.cluster_arrive_relaxed()
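
To make the layout arithmetic in `_get_tv_layout` concrete, here is the same computation traced in plain Python for an assumed bf16 input with N = 8192 (single CTA, thresholds taken from `CrossEntropy` above):

```python
# Worked example of _get_tv_layout for dtype=bf16 (width 16), N=8192, cluster_n=1.
dtype_width, N, cluster_n = 16, 8192, 1

vecsize = 128 // dtype_width               # 8 elements per 128-bit vectorized copy
num_threads = 128 if N <= 16384 else 256   # -> 128
threads_per_row = 128                      # CrossEntropy: 128 for 6144 < N <= 16384

num_blocks_N = -(-(N // vecsize) // (threads_per_row * cluster_n))  # ceil_div -> 8
cols_per_block = num_threads // threads_per_row                     # -> 1
tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
assert tiler_mn == (1, 8192)  # one row per CTA; each thread handles 8 vectors of 8 elems
```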