quack-kernels 0.1.4__tar.gz → 0.1.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. {quack_kernels-0.1.4/quack_kernels.egg-info → quack_kernels-0.1.5}/PKG-INFO +1 -1
  2. {quack_kernels-0.1.4 → quack_kernels-0.1.5}/README.md +10 -3
  3. {quack_kernels-0.1.4 → quack_kernels-0.1.5}/quack/__init__.py +1 -1
  4. quack_kernels-0.1.5/quack/cross_entropy.py +542 -0
  5. {quack_kernels-0.1.4 → quack_kernels-0.1.5}/quack/reduction_base.py +1 -4
  6. quack_kernels-0.1.5/quack/rmsnorm.py +662 -0
  7. {quack_kernels-0.1.4 → quack_kernels-0.1.5}/quack/utils.py +3 -18
  8. {quack_kernels-0.1.4 → quack_kernels-0.1.5/quack_kernels.egg-info}/PKG-INFO +1 -1
  9. {quack_kernels-0.1.4 → quack_kernels-0.1.5}/tests/test_cross_entropy.py +54 -10
  10. quack_kernels-0.1.4/quack/cross_entropy.py +0 -255
  11. quack_kernels-0.1.4/quack/rmsnorm.py +0 -285
  12. {quack_kernels-0.1.4 → quack_kernels-0.1.5}/LICENSE +0 -0
  13. {quack_kernels-0.1.4 → quack_kernels-0.1.5}/pyproject.toml +0 -0
  14. {quack_kernels-0.1.4 → quack_kernels-0.1.5}/quack/softmax.py +0 -0
  15. {quack_kernels-0.1.4 → quack_kernels-0.1.5}/quack_kernels.egg-info/SOURCES.txt +0 -0
  16. {quack_kernels-0.1.4 → quack_kernels-0.1.5}/quack_kernels.egg-info/dependency_links.txt +0 -0
  17. {quack_kernels-0.1.4 → quack_kernels-0.1.5}/quack_kernels.egg-info/requires.txt +0 -0
  18. {quack_kernels-0.1.4 → quack_kernels-0.1.5}/quack_kernels.egg-info/top_level.txt +0 -0
  19. {quack_kernels-0.1.4 → quack_kernels-0.1.5}/setup.cfg +0 -0
  20. {quack_kernels-0.1.4 → quack_kernels-0.1.5}/setup.py +0 -0
  21. {quack_kernels-0.1.4 → quack_kernels-0.1.5}/tests/test_rmsnorm.py +0 -0
  22. {quack_kernels-0.1.4 → quack_kernels-0.1.5}/tests/test_softmax.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.1.4
+ Version: 0.1.5
  Requires-Python: >=3.9
  License-File: LICENSE
  Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
@@ -17,11 +17,10 @@ pip install quack-kernels
  ## Kernels 🐥

  - 🦆 RMSNorm forward
- - 🦆 Softmax forward and backward
- - 🦆 Cross entropy forward
+ - 🦆 Softmax forward + backward
+ - 🦆 Cross entropy forward + backward

  Upcoming:
- - 🦆 Cross entropy backward
  - 🦆 RMSNorm backward
  - 🦆 Rotary forward + backward

@@ -31,6 +30,14 @@ Upcoming:
  from quack import rmsnorm, softmax, cross_entropy
  ```

+ ## Caveats 🦆⚠️
+
+ **Tensor Size Limitation**: We currently only support tensors ≤ 4GB due to CuTe-DSL using int32 for indexing.
+
+ 🦆 **Workaround**: For larger tensors, split your input tensors into chunks of
+ size ≤ 4GB each. We will implement this automatic chunking in the PyTorch part
+ of the code in the near future, but if you need it in the meantime, we welcome contributions!
+
  ## Development

  To set up the development environment:
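The manual chunking workaround described in the new README caveat could look roughly like the sketch below. This is a hypothetical helper (`chunked_cross_entropy` is not part of the package) that assumes the `quack.cross_entropy` API added later in this diff; it splits along the row dimension so each kernel launch sees a tensor under the 4GB int32-indexing limit.

```python
# Hypothetical sketch, not part of quack-kernels: split rows so each
# cross_entropy launch stays under the 4GB CuTe-DSL indexing limit.
import torch
from quack import cross_entropy


def chunked_cross_entropy(
    x: torch.Tensor, target: torch.Tensor, max_bytes: int = 4 * 1024**3
) -> torch.Tensor:
    bytes_per_row = x.shape[1] * x.element_size()
    rows_per_chunk = max(1, max_bytes // bytes_per_row)
    # Slices are views, so no extra memory is allocated for the inputs.
    return torch.cat(
        [
            cross_entropy(x[i : i + rows_per_chunk], target[i : i + rows_per_chunk])
            for i in range(0, x.shape[0], rows_per_chunk)
        ]
    )
```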
@@ -1,4 +1,4 @@
- __version__ = "0.1.4"
+ __version__ = "0.1.5"

  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax
@@ -0,0 +1,542 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+ import math
+ import torch
+ from typing import Optional, Type
+
+ import cuda.bindings.driver as cuda
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass.cute.runtime import from_dlpack
+
+ import quack.utils as utils
+ from quack.reduction_base import ReductionBase, torch2cute_dtype_map
+
+
+ class CrossEntropy(ReductionBase):
+     def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
+         # 2 stages: 1 for max, 1 for sum
+         super().__init__(
+             dtype,
+             N,
+             stage=2 if not online_softmax else 1,
+             reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
+         )
+         self.online_softmax = online_softmax
+         self.reload_from = None if N <= 16384 or online_softmax else "smem"
+
+     def _calculate_threads_per_row(self):
+         N = self.N
+         return (
+             8
+             if N <= 64
+             else (
+                 16
+                 if N <= 128
+                 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+             )
+         )
+
+     def _set_cluster_n(self):
+         N = self.N
+         if cutlass.const_expr(self.dtype.width == 16):
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 32 * 1024
+                     else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                 )
+             )
+         else:  # fp32
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 64 * 1024
+                     else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                 )
+             )
+         self.cluster_n = cluster_n
+
+     @cute.jit
+     def __call__(
+         self,
+         mX: cute.Tensor,
+         mTarget: cute.Tensor,
+         mLoss: cute.Tensor,
+         mLSE: Optional[cute.Tensor],
+         stream: cuda.CUstream,
+     ):
+         assert mX.element_type == self.dtype
+         self._set_cluster_n()
+         tiler_mn, tv_layout = self._get_tv_layout()
+         num_threads = cute.size(tv_layout, mode=[0])
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
+             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+             block=[num_threads, 1, 1],
+             cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+             stream=stream,
+         )
+
+     @cute.kernel
+     def kernel(
+         self,
+         mX: cute.Tensor,  # (M, N)
+         mTarget: cute.Tensor,  # (M,)
+         mLoss: cute.Tensor,  # (M,)
+         mLSE: Optional[cute.Tensor],  # (M,)
+         tv_layout: cute.Layout,
+         tiler_mn: cute.Shape,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         bidx, _, _ = cute.arch.block_idx()
+         if cutlass.const_expr(self.cluster_n > 1):
+             cluster_y = cute.arch.block_idx()[1]
+         else:
+             cluster_y = cutlass.const_expr(0)
+
+         shape: cute.Shape = mX.shape
+         idX = cute.make_identity_tensor(shape)
+         # slice for CTAs
+         gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, idX)]
+
+         smem = cutlass.utils.SmemAllocator()
+         sX = smem.allocate_tensor(
+             mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+         )
+         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+         # declare the atoms which will be used later for memory copy
+         copy_atom_load_X = cute.make_copy_atom(
+             cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
+         )
+         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+
+         #### Thread View
+         tXgX = thr_copy_X.partition_S(gX)
+         tXsX = thr_copy_X.partition_D(sX)
+         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+         tXrX = cute.make_fragment_like(tXgX)
+
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+         row = tXcX[0][0]
+         target = cute.Int32.zero
+         if row < shape[0] and tXcX[0][1] == 0:
+             target = cute.Int32(mTarget[row])
+
+         is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+         tXpX = (
+             utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+             if cutlass.const_expr(not is_even_N)
+             else None
+         )
+         if row < shape[0]:
+             cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+         cute.arch.cp_async_commit_group()
+         cute.arch.cp_async_wait_group(0)
+         # Fill OOB values with -inf
+         if cutlass.const_expr(not is_even_N):
+             utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+         cute.autovec_copy(tXsX, tXrX)
+         x = tXrX.load().to(cute.Float32)
+
+         target_logit = cute.Float32.zero
+         if row < shape[0] and tXcX[0][1] == 0:
+             target_logit = cute.Float32(mX[row, target])
+
+         threads_per_row = tv_layout.shape[0][0]
+         if cutlass.const_expr(not self.online_softmax):
+             max_x = utils.row_reduce(
+                 x,
+                 cute.ReductionOp.MAX,
+                 threads_per_row,
+                 reduction_buffer[None, None, 0],
+                 mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+                 init_val=-cutlass.Float32.inf,
+                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+             )
+             if cutlass.const_expr(self.reload_from == "smem"):
+                 cute.autovec_copy(tXsX, tXrX)
+                 x = tXrX.load().to(cute.Float32)
+             log2_e = math.log2(math.e)
+             # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
+             # a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
+             # exp_x = utils.exp2f((x - max_x) * log2_e)
+             # This would use ffma instead of fadd then fmul
+             exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
+             denom = utils.row_reduce(
+                 exp_x,
+                 cute.ReductionOp.ADD,
+                 threads_per_row,
+                 reduction_buffer[None, None, 1],
+                 mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
+                 init_val=0.0,
+             )
+         else:
+             max_x, denom, _ = utils.online_softmax_reduce(
+                 x,
+                 threads_per_row,
+                 reduction_buffer[None, None, 0],
+                 mbar_ptr,
+                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+             )
+
+         if (
+             tXcX[0][1] == 0
+             and row < shape[0]
+             and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+         ):
+             ln_2 = math.log(2.0)
+             lse = max_x + utils.log2f(denom) * ln_2
+             loss_val = lse - target_logit
+             mLoss[row] = loss_val.to(mLoss.element_type)
+             if cutlass.const_expr(mLSE is not None):
+                 mLSE[row] = lse
+
+
+ def _cross_entropy(
+     x: torch.Tensor,
+     target: torch.Tensor,
+     return_lse: bool = False,
+ ) -> torch.Tensor:
+     """Cross entropy forward pass.
+
+     Args:
+         x: Input logits tensor of shape (M, N)
+         target: Target class indices tensor of shape (M,)
+
+     Returns:
+         Cross entropy loss tensor of shape (M,)
+     """
+     assert x.dim() == 2, "Input must be 2D"
+     assert target.dim() == 1, "Target must be 1D"
+     assert x.shape[0] == target.shape[0], "Batch dimensions must match"
+     assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
+     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+     assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
+     M, N = x.shape
+     device = x.device
+     loss = torch.empty(M, device=device, dtype=torch.float32)
+     lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
+     dtype = torch2cute_dtype_map[x.dtype]
+     convert_from_dlpack = lambda tensor: (
+         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+             mode=0, stride_order=(0, 1)
+         )
+     )
+     x_tensor = convert_from_dlpack(x)
+     loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
+     lse_tensor = (
+         from_dlpack(lse.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
+         if lse is not None
+         else None
+     )
+     target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_compact_shape_dynamic(mode=0)
+     stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+     compile_key = (dtype, N, lse is not None)
+     if compile_key not in _cross_entropy.compile_cache:
+         cross_entropy_op = CrossEntropy(dtype, N)
+         _cross_entropy.compile_cache[compile_key] = cute.compile(
+             cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
+         )
+     _cross_entropy.compile_cache[compile_key](
+         x_tensor, target_tensor, loss_tensor, lse_tensor, stream
+     )
+     return loss if not return_lse else (loss, lse)
+
+
+ _cross_entropy.compile_cache = {}
+
+
+ class CrossEntropyBackward:
+     def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+         self.dtype = dtype
+         self.N = N
+         self.vecsize = 128 // dtype.width
+
+     def _calculate_threads_per_row(self):
+         N = self.N
+         return (
+             8
+             if N <= 64
+             else (
+                 16
+                 if N <= 128
+                 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+             )
+         )
+
+     def _get_tv_layout(self):
+         N = self.N
+         vecsize = self.vecsize
+         num_threads = 128 if N <= 16384 else 256
+         threads_per_row = self._calculate_threads_per_row()
+         cols_per_block = num_threads // threads_per_row
+         num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
+         tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+         tv_layout = cute.make_layout(
+             ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+             stride=(
+                 (vecsize * cols_per_block, 1),
+                 (cols_per_block, cols_per_block * vecsize * threads_per_row),
+             ),
+         )
+         return tiler_mn, tv_layout
+
+     @cute.jit
+     def __call__(
+         self,
+         mX: cute.Tensor,
+         mTarget: cute.Tensor,
+         mDLoss: cute.Tensor,
+         mdX: cute.Tensor,
+         mLSE: cute.Tensor,
+         stream: cuda.CUstream,
+     ):
+         assert mX.element_type == self.dtype
+         assert mdX.element_type == self.dtype
+
+         tiler_mn, tv_layout = self._get_tv_layout()
+         num_threads = cute.size(tv_layout, mode=[0])
+
+         mDLoss = cute.make_tensor(
+             mDLoss.iterator, cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,)))
+         )
+         mTarget = cute.make_tensor(
+             mTarget.iterator, cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,)))
+         )
+         mLSE = cute.make_tensor(
+             mLSE.iterator, cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,)))
+         )
+
+         smem_size = cute.size_in_bytes(
+             mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0))
+         )
+
+         self.kernel(
+             mX,
+             mTarget,
+             mDLoss,
+             mdX,
+             mLSE,
+             mX.shape,
+             tv_layout,
+             tiler_mn,
+         ).launch(
+             grid=[
+                 cute.ceil_div(mX.shape[0], tiler_mn[0]),
+                 cute.ceil_div(mX.shape[1], tiler_mn[1]),
+                 1,
+             ],
+             block=[num_threads, 1, 1],
+             smem=smem_size,
+             stream=stream,
+         )
+
+     @cute.kernel
+     def kernel(
+         self,
+         mX: cute.Tensor,  # (M, N)
+         mTarget: cute.Tensor,  # (M,)
+         mDLoss: cute.Tensor,  # (M,)
+         mdX: cute.Tensor,  # (M, N)
+         mLSE: cute.Tensor,  # (M,)
+         shape: cute.Shape,
+         tv_layout: cute.Layout,
+         tiler_mn: cute.Shape,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         bidx, bidy, _ = cute.arch.block_idx()
+
+         smem = cutlass.utils.SmemAllocator()
+         sX = smem.allocate_tensor(
+             mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+         )
+
+         idX = cute.make_identity_tensor(shape)
+
+         gX, gdX, cX, gTarget, gDLoss, gLse = [
+             cute.local_tile(mT, tiler_mn, (bidx, bidy))
+             for mT in (mX, mdX, idX, mTarget, mDLoss, mLSE)
+         ]
+
+         copy_atom_load_X = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
+         )
+         copy_atom_load_X_async = cute.make_copy_atom(
+             cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
+         )
+         copy_atom_store_O = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=128
+         )
+
+         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+         thr_copy_X_async = cute.make_tiled_copy(
+             copy_atom_load_X_async, tv_layout, tiler_mn
+         ).get_slice(tidx)
+         thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+         #### Thread View
+         tXgX = thr_copy_X_async.partition_S(gX)
+         tXsX = thr_copy_X_async.partition_S(sX)
+
+         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+         tXcFull = thr_copy_X.partition_S(cX)  # improve
+
+         tXgO = thr_copy_O.partition_D(gdX)
+
+         # allocate fragments for gmem->rmem
+         tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+         is_even_N = cutlass.const_expr(shape[1] % tiler_mn[1] == 0)
+         row = tXcX[0][0]
+
+         tXpX = (
+             utils.predicate_k(thr_copy_X_async.partition_S(cX), limit=shape[1])
+             if not is_even_N
+             else None
+         )
+
+         if row < shape[0]:
+             cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+         cute.arch.cp_async_commit_group()
+         cute.arch.cp_async_wait_group(0)
+         if cutlass.const_expr(not is_even_N):
+             utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+
+         cute.autovec_copy(tXsX, tXrX)
+         x = tXrX.load().to(cute.Float32)
+
+         label = cute.Int32.zero
+         dloss = cute.Float32.zero
+         lse = cute.Float32.zero
+         if row < shape[0]:
+             label = cute.Int32(mTarget[row])
+             dloss = cute.Float32(mDLoss[row])
+             lse = cute.Float32(mLSE[row])
+
+         log2_e = math.log2(math.e)
+         probs = utils.exp2f((x - lse) * log2_e)
+         prob_shifted = probs - 1.0
+
+         mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
+         for i in cutlass.range_constexpr(cute.size(tXcFull)):
+             mask[i] = tXcFull[i][1] == label
+
+         mask = mask.load()
+         grad = cute.where(mask, prob_shifted, probs)
+         grad = grad * dloss
+
+         tXrO.store(grad.to(tXrO.element_type))
+         tOpO = (
+             utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
+         )
+         if row < shape[0]:
+             cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+
+
+ def _cross_entropy_backward(
+     x: torch.Tensor,
+     target: torch.Tensor,
+     dloss: torch.Tensor,
+     lse: torch.Tensor,
+     inplace_backward: bool = False,
+ ) -> torch.Tensor:
+     """Cross entropy backward pass.
+     Args:
+         x: Input logits tensor of shape (M, N)
+         target: Target class indices tensor of shape (M,)
+         dloss: Upstream gradients tensor of shape (M,)
+         lse: Log-sum-exp values tensor of shape (M,)
+     Returns:
+         Input gradients tensor of shape (M, N)
+     """
+     assert x.dim() == 2, "Input must be 2D"
+     assert target.dim() == 1, "Target must be 1D"
+     assert dloss.dim() == 1, "dloss must be 1D"
+     assert lse.dim() == 1, "lse must be 1D"
+     assert x.shape[0] == target.shape[0], "Batch dimensions must match"
+     assert x.shape[0] == dloss.shape[0], "Batch dimensions must match"
+     assert x.shape[0] == lse.shape[0], "Batch dimensions must match"
+     assert (
+         x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
+     ), "Tensors must be on CUDA device"
+     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+     assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
+
+     M, N = x.shape
+     dx = torch.empty_like(x) if not inplace_backward else x
+     dtype = torch2cute_dtype_map[x.dtype]
+
+     convert_from_dlpack = lambda tensor: (
+         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+             mode=0, stride_order=(0, 1)
+         )
+     )
+     x_tensor = convert_from_dlpack(x)
+     dx_tensor = convert_from_dlpack(dx)
+     dloss_tensor = from_dlpack(dloss.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
+     lse_tensor = from_dlpack(lse.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
+     target_tensor = from_dlpack(target.detach(), assumed_align=32).mark_compact_shape_dynamic(
+         mode=0
+     )
+     stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+     compile_key = (dtype, N)
+     if compile_key not in _cross_entropy_backward.compile_cache:
+         cross_entropy_backward_op = CrossEntropyBackward(dtype, N)
+         _cross_entropy_backward.compile_cache[compile_key] = cute.compile(
+             cross_entropy_backward_op,
+             x_tensor,
+             target_tensor,
+             dloss_tensor,
+             dx_tensor,
+             lse_tensor,
+             stream,
+         )
+     _cross_entropy_backward.compile_cache[compile_key](
+         x_tensor, target_tensor, dloss_tensor, dx_tensor, lse_tensor, stream
+     )
+     return dx
+
+
+ _cross_entropy_backward.compile_cache = {}
+
+
+ class CrossEntropyFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, x, target, inplace_backward=False):
+         loss, lse = _cross_entropy(x, target, return_lse=True)
+         ctx.save_for_backward(x, target, lse)
+         ctx.inplace_backward = inplace_backward
+         return loss
+
+     @staticmethod
+     def backward(ctx, dloss):
+         x, target, lse = ctx.saved_tensors
+         dx = _cross_entropy_backward(x, target, dloss, lse, inplace_backward=ctx.inplace_backward)
+         return dx, None, None
+
+
+ def cross_entropy(
+     x: torch.Tensor, target: torch.Tensor, inplace_backward: bool = False
+ ) -> torch.Tensor:
+     """Cross entropy loss with automatic differentiation support.
+
+     Args:
+         x: Input logits tensor of shape (M, N)
+         target: Target class indices tensor of shape (M,)
+
+     Returns:
+         Cross entropy loss tensor of shape (M,)
+     """
+     return CrossEntropyFunction.apply(x, target, inplace_backward)
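
For orientation, a minimal end-to-end usage sketch of the new module, based on the `cross_entropy` docstring above (shapes and dtypes are illustrative; requires a CUDA device):

```python
# Minimal usage sketch for the new cross_entropy; shapes/dtypes are illustrative.
import torch
from quack import cross_entropy

M, N = 4096, 32768
x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, N, (M,), device="cuda", dtype=torch.int64)

loss = cross_entropy(x, target)  # per-row losses, shape (M,), float32
loss.sum().backward()            # dispatches the new CrossEntropyBackward kernel
assert x.grad.shape == (M, N)
```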
@@ -6,8 +6,6 @@ from typing import Type, Tuple, Optional
  import cutlass
  import cutlass.cute as cute

- import quack.utils as utils
-

  torch2cute_dtype_map = {
      torch.float16: cutlass.Float16,
@@ -39,7 +37,6 @@ class ReductionBase:
          vecsize = copy_bits // self.dtype.width
          assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
          num_threads = self._get_num_threads()
-         num_warps = num_threads // cute.arch.WARP_SIZE
          assert num_threads % cute.arch.WARP_SIZE == 0

          threads_per_row = self._calculate_threads_per_row()
@@ -64,7 +61,7 @@ class ReductionBase:

      def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int):
          num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-         warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
+         warps_per_row = max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
          return cute.make_ordered_layout(
              (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
              order=(1, 0, 2),