quack-kernels 0.1.4__tar.gz → 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. {quack_kernels-0.1.4/quack_kernels.egg-info → quack_kernels-0.1.6}/PKG-INFO +1 -1
  2. {quack_kernels-0.1.4 → quack_kernels-0.1.6}/README.md +21 -3
  3. {quack_kernels-0.1.4 → quack_kernels-0.1.6}/quack/__init__.py +1 -1
  4. quack_kernels-0.1.6/quack/cross_entropy.py +546 -0
  5. quack_kernels-0.1.4/quack/rmsnorm.py → quack_kernels-0.1.6/quack/layernorm.py +93 -27
  6. {quack_kernels-0.1.4 → quack_kernels-0.1.6}/quack/reduction_base.py +1 -4
  7. quack_kernels-0.1.6/quack/rmsnorm.py +665 -0
  8. {quack_kernels-0.1.4 → quack_kernels-0.1.6}/quack/softmax.py +8 -3
  9. {quack_kernels-0.1.4 → quack_kernels-0.1.6}/quack/utils.py +19 -18
  10. {quack_kernels-0.1.4 → quack_kernels-0.1.6/quack_kernels.egg-info}/PKG-INFO +1 -1
  11. {quack_kernels-0.1.4 → quack_kernels-0.1.6}/quack_kernels.egg-info/SOURCES.txt +2 -0
  12. {quack_kernels-0.1.4 → quack_kernels-0.1.6}/quack_kernels.egg-info/top_level.txt +1 -0
  13. {quack_kernels-0.1.4 → quack_kernels-0.1.6}/tests/test_cross_entropy.py +21 -13
  14. quack_kernels-0.1.6/tests/test_layernorm.py +162 -0
  15. {quack_kernels-0.1.4 → quack_kernels-0.1.6}/tests/test_rmsnorm.py +36 -5
  16. {quack_kernels-0.1.4 → quack_kernels-0.1.6}/tests/test_softmax.py +2 -3
  17. quack_kernels-0.1.4/quack/cross_entropy.py +0 -255
  18. {quack_kernels-0.1.4 → quack_kernels-0.1.6}/LICENSE +0 -0
  19. {quack_kernels-0.1.4 → quack_kernels-0.1.6}/pyproject.toml +0 -0
  20. {quack_kernels-0.1.4 → quack_kernels-0.1.6}/quack_kernels.egg-info/dependency_links.txt +0 -0
  21. {quack_kernels-0.1.4 → quack_kernels-0.1.6}/quack_kernels.egg-info/requires.txt +0 -0
  22. {quack_kernels-0.1.4 → quack_kernels-0.1.6}/setup.cfg +0 -0
  23. {quack_kernels-0.1.4 → quack_kernels-0.1.6}/setup.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.1.4
+ Version: 0.1.6
  Requires-Python: >=3.9
  License-File: LICENSE
  Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
@@ -17,11 +17,11 @@ pip install quack-kernels
  ## Kernels 🐥

  - 🦆 RMSNorm forward
- - 🦆 Softmax forward and backward
- - 🦆 Cross entropy forward
+ - 🦆 Softmax forward + backward
+ - 🦆 Cross entropy forward + backward
+ - 🦆 Layernorm forward

  Upcoming:
- - 🦆 Cross entropy backward
  - 🦆 RMSNorm backward
  - 🦆 Rotary forward + backward

@@ -31,6 +31,24 @@ Upcoming:
  from quack import rmsnorm, softmax, cross_entropy
  ```

+ ## Documentation
+
+ [2025-07-10] We have a comprehensive
+ [blogpost](media/2025-07-10-membound-sol.md) on how to get memory-bound kernels
+ to speed-of-light, right in the comfort of Python, thanks to the [CuTe-DSL](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.html).
+
+ ## Performance
+
+ <div align="center">
+ <figure>
+   <img
+   src="media/bf16_kernel_benchmarks_single_row.svg"
+   >
+ </figure>
+ </div>
+
+ See our [blogpost](media/2025-07-10-membound-sol.md) for the details.
+

  ## Development

  To set up the development environment:
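The import line in the README hunk above is the package's whole public surface. As a quick orientation, a hedged usage sketch; of the three, only `cross_entropy`'s call signature is confirmed by the source later in this diff:

```python
import torch
from quack import rmsnorm, softmax, cross_entropy  # public API per the README

logits = torch.randn(4096, 32768, device="cuda", dtype=torch.bfloat16)
labels = torch.randint(0, 32768, (4096,), device="cuda")

loss = cross_entropy(logits, labels)  # per-row losses, shape (4096,)
print(loss.mean().item())
```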
@@ -1,4 +1,4 @@
- __version__ = "0.1.4"
+ __version__ = "0.1.6"

  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax
@@ -0,0 +1,546 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+ import math
+ import torch
+ from typing import Optional, Type
+
+ import cuda.bindings.driver as cuda
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass.cute.runtime import from_dlpack
+
+ import quack.utils as utils
+ from quack.reduction_base import ReductionBase, torch2cute_dtype_map
+
+
+ class CrossEntropy(ReductionBase):
+     def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
+         # 2 stages: 1 for max, 1 for sum
+         super().__init__(
+             dtype,
+             N,
+             stage=2 if not online_softmax else 1,
+             reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
+         )
+         self.online_softmax = online_softmax
+         self.reload_from = None if N <= 16384 or online_softmax else "smem"
+
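The `stage` choice above is the point of `online_softmax`: the two-pass form needs one cross-CTA reduction for the row max and a second for the sum of exponentials, while the online form carries both through a single reduction. A minimal scalar sketch of that recurrence (illustration only; the real work happens in `utils.online_softmax_reduce` later in this file):

```python
import math

def online_max_and_denom(row):
    """One pass over a row, returning (max(row), sum(exp(row - max(row))))."""
    m, d = float("-inf"), 0.0
    for v in row:
        m_new = max(m, v)
        # rescale the running denominator to the new max, then add this element
        d = d * math.exp(m - m_new) + math.exp(v - m_new)
        m = m_new
    return m, d
```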
+     def _calculate_threads_per_row(self):
+         N = self.N
+         return (
+             8
+             if N <= 64
+             else (
+                 16
+                 if N <= 128
+                 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+             )
+         )
+
+     def _set_cluster_n(self):
+         N = self.N
+         if cutlass.const_expr(self.dtype.width == 16):
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 32 * 1024
+                     else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                 )
+             )
+         else:  # fp32
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 64 * 1024
+                     else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                 )
+             )
+         self.cluster_n = cluster_n
+
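Reading the two ladders above off by hand for a few concrete sizes (fp16/bf16 path, `dtype.width == 16`):

```python
# _calculate_threads_per_row / _set_cluster_n, evaluated by hand:
#   N = 1024   ->  32 threads per row, cluster_n = 1
#   N = 8192   -> 128 threads per row, cluster_n = 1
#   N = 24576  -> 256 threads per row, cluster_n = 2   (16K < N <= 32K)
#   N = 131072 -> 256 threads per row, cluster_n = 8   (64K < N <= 128K)
```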
+     @cute.jit
+     def __call__(
+         self,
+         mX: cute.Tensor,
+         mTarget: cute.Tensor,
+         mLoss: cute.Tensor,
+         mLSE: Optional[cute.Tensor],
+         stream: cuda.CUstream,
+     ):
+         assert mX.element_type == self.dtype
+         self._set_cluster_n()
+         tiler_mn, tv_layout = self._get_tv_layout()
+         num_threads = cute.size(tv_layout, mode=[0])
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
+             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+             block=[num_threads, 1, 1],
+             cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+             stream=stream,
+         )
+
+     @cute.kernel
+     def kernel(
+         self,
+         mX: cute.Tensor,  # (M, N)
+         mTarget: cute.Tensor,  # (M,)
+         mLoss: cute.Tensor,  # (M,)
+         mLSE: Optional[cute.Tensor],  # (M,)
+         tv_layout: cute.Layout,
+         tiler_mn: cute.Shape,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         bidx, _, _ = cute.arch.block_idx()
+         if cutlass.const_expr(self.cluster_n > 1):
+             cluster_y = cute.arch.block_idx()[1]
+         else:
+             cluster_y = cutlass.const_expr(0)
+
+         shape: cute.Shape = mX.shape
+         idX = cute.make_identity_tensor(shape)
+         # slice for CTAs
+         # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+         mX_off = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mX)
+         gX = cute.local_tile(mX_off, tiler_mn, (0, cluster_y))
+         cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
+
+         smem = cutlass.utils.SmemAllocator()
+         sX = smem.allocate_tensor(
+             mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+         )
+         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+         # declare the atoms which will be used later for memory copy
+         copy_atom_load_X = cute.make_copy_atom(
+             cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
+         )
+         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+
+         #### Thread View
+         tXgX = thr_copy_X.partition_S(gX)
+         tXsX = thr_copy_X.partition_D(sX)
+         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+         tXrX = cute.make_fragment_like(tXgX)
+
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+         row = tXcX[0][0]
+         target = cute.Int32.zero
+         if row < shape[0] and tXcX[0][1] == 0:
+             target = cute.Int32(mTarget[row])
+
+         is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+         tXpX = (
+             utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+             if cutlass.const_expr(not is_even_N)
+             else None
+         )
+         if row < shape[0]:
+             cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+         cute.arch.cp_async_commit_group()
+         cute.arch.cp_async_wait_group(0)
+         # Fill OOB values with -inf
+         if cutlass.const_expr(not is_even_N):
+             utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+         cute.autovec_copy(tXsX, tXrX)
+         x = tXrX.load().to(cute.Float32)
+
+         target_logit = cute.Float32.zero
+         if row < shape[0] and tXcX[0][1] == 0:
+             # Use Int64 for indexing to deal with large tensors
+             mX_off = utils.domain_offset_i64((row, 0), mX)
+             target_logit = cute.Float32(mX_off[0, target])
+
+         threads_per_row = tv_layout.shape[0][0]
+         if cutlass.const_expr(not self.online_softmax):
+             max_x = utils.row_reduce(
+                 x,
+                 cute.ReductionOp.MAX,
+                 threads_per_row,
+                 reduction_buffer[None, None, 0],
+                 mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+                 init_val=-cutlass.Float32.inf,
+                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+             )
+             if cutlass.const_expr(self.reload_from == "smem"):
+                 cute.autovec_copy(tXsX, tXrX)
+                 x = tXrX.load().to(cute.Float32)
+             log2_e = math.log2(math.e)
+             # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
+             # a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
+             # exp_x = utils.exp2f((x - max_x) * log2_e)
+             # This would use ffma instead of fadd then fmul
+             exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
+             denom = utils.row_reduce(
+                 exp_x,
+                 cute.ReductionOp.ADD,
+                 threads_per_row,
+                 reduction_buffer[None, None, 1],
+                 mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
+                 init_val=0.0,
+             )
+         else:
+             max_x, denom, _ = utils.online_softmax_reduce(
+                 x,
+                 threads_per_row,
+                 reduction_buffer[None, None, 0],
+                 mbar_ptr,
+                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+             )
+
+         if (
+             tXcX[0][1] == 0
+             and row < shape[0]
+             and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+         ):
+             ln_2 = math.log(2.0)
+             lse = max_x + utils.log2f(denom) * ln_2
+             loss_val = lse - target_logit
+             mLoss[row] = loss_val.to(mLoss.element_type)
+             if cutlass.const_expr(mLSE is not None):
+                 mLSE[row] = lse
+
+
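The epilogue above is the standard numerically stable formulation. In equations, for a row of logits $x$ with target class $t$:

$$
\mathrm{lse}(x) = \max_j x_j + \log \sum_j e^{x_j - \max_j x_j},
\qquad
\mathrm{loss} = \mathrm{lse}(x) - x_t .
$$

The kernel works in base 2 throughout, via $e^v = 2^{v \log_2 e}$ and $\log d = \log_2 d \cdot \ln 2$, which is why `exp2f` and `log2f` appear in place of `exp`/`log`.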
210
+ def _cross_entropy(
211
+ x: torch.Tensor,
212
+ target: torch.Tensor,
213
+ return_lse: bool = False,
214
+ ) -> torch.Tensor:
215
+ """Cross entropy forward pass.
216
+
217
+ Args:
218
+ x: Input logits tensor of shape (M, N)
219
+ target: Target class indices tensor of shape (M,)
220
+
221
+ Returns:
222
+ Cross entropy loss tensor of shape (M,)
223
+ """
224
+ assert x.dim() == 2, "Input must be 2D"
225
+ assert target.dim() == 1, "Target must be 1D"
226
+ assert x.shape[0] == target.shape[0], "Batch dimensions must match"
227
+ assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
228
+ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
229
+ assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
230
+ M, N = x.shape
231
+ device = x.device
232
+ loss = torch.empty(M, device=device, dtype=torch.float32)
233
+ lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
234
+ dtype = torch2cute_dtype_map[x.dtype]
235
+ convert_from_dlpack = lambda tensor: (
236
+ from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
237
+ mode=0, stride_order=(0, 1)
238
+ )
239
+ )
240
+ x_tensor = convert_from_dlpack(x)
241
+ loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
242
+ lse_tensor = (
243
+ from_dlpack(lse.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
244
+ if lse is not None
245
+ else None
246
+ )
247
+ target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_compact_shape_dynamic(mode=0)
248
+ stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
249
+
250
+ compile_key = (dtype, N, lse is not None)
251
+ if compile_key not in _cross_entropy.compile_cache:
252
+ cross_entropy_op = CrossEntropy(dtype, N)
253
+ _cross_entropy.compile_cache[compile_key] = cute.compile(
254
+ cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
255
+ )
256
+ _cross_entropy.compile_cache[compile_key](
257
+ x_tensor, target_tensor, loss_tensor, lse_tensor, stream
258
+ )
259
+ return loss if not return_lse else (loss, lse)
260
+
261
+
262
+ _cross_entropy.compile_cache = {}
263
+
264
+
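Since the batch dimension is marked dynamic (`mark_compact_shape_dynamic(mode=0)`), only `(dtype, N, return_lse)` force a recompile; varying M replays the cached kernel. A hedged smoke test for the wrapper above, checking against PyTorch's unreduced cross entropy (assumes a CUDA device with the pinned `nvidia-cutlass-dsl` installed; the tolerances are illustrative):

```python
import torch
from quack.cross_entropy import _cross_entropy

x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
t = torch.randint(0, 4096, (1024,), device="cuda")

loss, lse = _cross_entropy(x, t, return_lse=True)
# F.cross_entropy with reduction="none" is exactly lse - x[target], per row
ref = torch.nn.functional.cross_entropy(x.float(), t, reduction="none")
torch.testing.assert_close(loss, ref, rtol=1e-3, atol=1e-3)
```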
+ class CrossEntropyBackward:
+     def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+         self.dtype = dtype
+         self.N = N
+         self.vecsize = 128 // dtype.width
+
+     def _calculate_threads_per_row(self):
+         N = self.N
+         return (
+             8
+             if N <= 64
+             else (
+                 16
+                 if N <= 128
+                 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+             )
+         )
+
+     def _get_tv_layout(self):
+         N = self.N
+         vecsize = self.vecsize
+         num_threads = 128 if N <= 16384 else 256
+         threads_per_row = self._calculate_threads_per_row()
+         cols_per_block = num_threads // threads_per_row
+         num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
+         tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+         tv_layout = cute.make_layout(
+             ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+             stride=(
+                 (vecsize * cols_per_block, 1),
+                 (cols_per_block, cols_per_block * vecsize * threads_per_row),
+             ),
+         )
+         return tiler_mn, tv_layout
+
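Tracing `_get_tv_layout` for one concrete configuration, bf16 (`vecsize = 128 // 16 = 8`) with `N = 8192`:

```python
# num_threads     = 128                        (N <= 16384)
# threads_per_row = 128                        (6144 < N <= 16384)
# cols_per_block  = 128 // 128 = 1             (each CTA covers a single row)
# num_blocks_N    = ceil((8192 // 8) / 128) = 8
# tiler_mn        = (1, 8 * 8 * 128) = (1, 8192)  # one CTA tile spans the whole row
# -> every thread moves 8 vectors of 8 bf16 elements = 64 elements of its row
```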
+     @cute.jit
+     def __call__(
+         self,
+         mX: cute.Tensor,
+         mTarget: cute.Tensor,
+         mDLoss: cute.Tensor,
+         mdX: cute.Tensor,
+         mLSE: cute.Tensor,
+         stream: cuda.CUstream,
+     ):
+         assert mX.element_type == self.dtype
+         assert mdX.element_type == self.dtype
+
+         tiler_mn, tv_layout = self._get_tv_layout()
+         num_threads = cute.size(tv_layout, mode=[0])
+
+         # Append a stride-0 mode of size N to the per-row tensors, broadcasting
+         # them to (M, N) so they can be tiled with the same tiler as mX/mdX
+         mDLoss = cute.make_tensor(
+             mDLoss.iterator, cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,)))
+         )
+         mTarget = cute.make_tensor(
+             mTarget.iterator, cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,)))
+         )
+         mLSE = cute.make_tensor(
+             mLSE.iterator, cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,)))
+         )
+
+         smem_size = cute.size_in_bytes(
+             mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0))
+         )
+
+         self.kernel(
+             mX,
+             mTarget,
+             mDLoss,
+             mdX,
+             mLSE,
+             mX.shape,
+             tv_layout,
+             tiler_mn,
+         ).launch(
+             grid=[
+                 cute.ceil_div(mX.shape[0], tiler_mn[0]),
+                 cute.ceil_div(mX.shape[1], tiler_mn[1]),
+                 1,
+             ],
+             block=[num_threads, 1, 1],
+             smem=smem_size,
+             stream=stream,
+         )
+
+     @cute.kernel
+     def kernel(
+         self,
+         mX: cute.Tensor,  # (M, N)
+         mTarget: cute.Tensor,  # (M,)
+         mDLoss: cute.Tensor,  # (M,)
+         mdX: cute.Tensor,  # (M, N)
+         mLSE: cute.Tensor,  # (M,)
+         shape: cute.Shape,
+         tv_layout: cute.Layout,
+         tiler_mn: cute.Shape,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         bidx, bidy, _ = cute.arch.block_idx()
+
+         smem = cutlass.utils.SmemAllocator()
+         sX = smem.allocate_tensor(
+             mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+         )
+
+         idX = cute.make_identity_tensor(shape)
+         # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
+         mX, mdX = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mdX)]
+         gX, gdX = [cute.local_tile(mT, tiler_mn, (0, bidy)) for mT in (mX, mdX)]
+         cX = cute.local_tile(idX, tiler_mn, (bidx, bidy))
+
+         copy_atom_load_X = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
+         )
+         copy_atom_load_X_async = cute.make_copy_atom(
+             cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
+         )
+         copy_atom_store_O = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=128
+         )
+
+         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+         thr_copy_X_async = cute.make_tiled_copy(
+             copy_atom_load_X_async, tv_layout, tiler_mn
+         ).get_slice(tidx)
+         thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+         #### Thread View
+         tXgX = thr_copy_X_async.partition_S(gX)
+         tXsX = thr_copy_X_async.partition_S(sX)
+
+         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+         tXcFull = thr_copy_X.partition_S(cX)  # improve
+
+         tXgO = thr_copy_O.partition_D(gdX)
+
+         # allocate fragments for gmem->rmem
+         tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+         is_even_N = cutlass.const_expr(shape[1] % tiler_mn[1] == 0)
+         row = tXcX[0][0]
+
+         tXpX = (
+             utils.predicate_k(thr_copy_X_async.partition_S(cX), limit=shape[1])
+             if not is_even_N
+             else None
+         )
+
+         if row < shape[0]:
+             cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+         cute.arch.cp_async_commit_group()
+         cute.arch.cp_async_wait_group(0)
+         if cutlass.const_expr(not is_even_N):
+             utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+
+         cute.autovec_copy(tXsX, tXrX)
+         x = tXrX.load().to(cute.Float32)
+
+         label = cute.Int32.zero
+         dloss = cute.Float32.zero
+         lse = cute.Float32.zero
+         if row < shape[0]:
+             label = cute.Int32(mTarget[row])
+             dloss = cute.Float32(mDLoss[row])
+             lse = cute.Float32(mLSE[row])
+
+         log2_e = math.log2(math.e)
+         probs = utils.exp2f((x - lse) * log2_e)
+         prob_shifted = probs - 1.0
+
+         mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
+         for i in cutlass.range_constexpr(cute.size(tXcFull)):
+             mask[i] = tXcFull[i][1] == label
+
+         mask = mask.load()
+         grad = cute.where(mask, prob_shifted, probs)
+         grad = grad * dloss
+
+         tXrO.store(grad.to(tXrO.element_type))
+         tOpO = (
+             utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
+         )
+         if row < shape[0]:
+             cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+
+
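The fragment math above is the textbook softmax-cross-entropy gradient, scaled by the upstream gradient. In equations, for target class $t$:

$$
\frac{\partial\,\mathrm{loss}}{\partial x_j} = \big( p_j - \mathbf{1}[j = t] \big) \cdot d\mathrm{loss},
\qquad
p_j = e^{x_j - \mathrm{lse}(x)} ,
$$

so `probs` is $p$, `prob_shifted` covers the target column, and the `mask` fragment selects between the two before the `dloss` scaling.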
+ def _cross_entropy_backward(
+     x: torch.Tensor,
+     target: torch.Tensor,
+     dloss: torch.Tensor,
+     lse: torch.Tensor,
+     inplace_backward: bool = False,
+ ) -> torch.Tensor:
+     """Cross entropy backward pass.
+
+     Args:
+         x: Input logits tensor of shape (M, N)
+         target: Target class indices tensor of shape (M,)
+         dloss: Upstream gradients tensor of shape (M,)
+         lse: Log-sum-exp values tensor of shape (M,)
+         inplace_backward: If True, write the gradients into x instead of a new tensor
+
+     Returns:
+         Input gradients tensor of shape (M, N)
+     """
+     assert x.dim() == 2, "Input must be 2D"
+     assert target.dim() == 1, "Target must be 1D"
+     assert dloss.dim() == 1, "dloss must be 1D"
+     assert lse.dim() == 1, "lse must be 1D"
+     assert x.shape[0] == target.shape[0], "Batch dimensions must match"
+     assert x.shape[0] == dloss.shape[0], "Batch dimensions must match"
+     assert x.shape[0] == lse.shape[0], "Batch dimensions must match"
+     assert (
+         x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
+     ), "Tensors must be on CUDA device"
+     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
+     assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
+
+     M, N = x.shape
+     dx = torch.empty_like(x) if not inplace_backward else x
+     dtype = torch2cute_dtype_map[x.dtype]
+
+     convert_from_dlpack = lambda tensor: (
+         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+             mode=0, stride_order=(0, 1)
+         )
+     )
+     x_tensor = convert_from_dlpack(x)
+     dx_tensor = convert_from_dlpack(dx)
+     dloss_tensor = from_dlpack(dloss.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
+     lse_tensor = from_dlpack(lse.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
+     target_tensor = from_dlpack(target.detach(), assumed_align=32).mark_compact_shape_dynamic(
+         mode=0
+     )
+     stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+     compile_key = (dtype, N)
+     if compile_key not in _cross_entropy_backward.compile_cache:
+         cross_entropy_backward_op = CrossEntropyBackward(dtype, N)
+         _cross_entropy_backward.compile_cache[compile_key] = cute.compile(
+             cross_entropy_backward_op,
+             x_tensor,
+             target_tensor,
+             dloss_tensor,
+             dx_tensor,
+             lse_tensor,
+             stream,
+         )
+     _cross_entropy_backward.compile_cache[compile_key](
+         x_tensor, target_tensor, dloss_tensor, dx_tensor, lse_tensor, stream
+     )
+     return dx
+
+
+ _cross_entropy_backward.compile_cache = {}
+
+
+ class CrossEntropyFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, x, target, inplace_backward=False):
+         loss, lse = _cross_entropy(x, target, return_lse=True)
+         ctx.save_for_backward(x, target, lse)
+         ctx.inplace_backward = inplace_backward
+         return loss
+
+     @staticmethod
+     def backward(ctx, dloss):
+         x, target, lse = ctx.saved_tensors
+         dx = _cross_entropy_backward(x, target, dloss, lse, inplace_backward=ctx.inplace_backward)
+         return dx, None, None
+
+
+ def cross_entropy(
+     x: torch.Tensor, target: torch.Tensor, inplace_backward: bool = False
+ ) -> torch.Tensor:
+     """Cross entropy loss with automatic differentiation support.
+
+     Args:
+         x: Input logits tensor of shape (M, N)
+         target: Target class indices tensor of shape (M,)
+         inplace_backward: If True, reuse x's storage for the gradient in backward
+
+     Returns:
+         Cross entropy loss tensor of shape (M,)
+     """
+     return CrossEntropyFunction.apply(x, target, inplace_backward)
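Finally, a minimal end-to-end sketch of the autograd path this file defines (shapes and dtypes are illustrative):

```python
import torch
from quack import cross_entropy

x = torch.randn(2048, 32768, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, 32768, (2048,), device="cuda")

loss = cross_entropy(x, target)  # forward kernel; per-row losses of shape (2048,)
loss.mean().backward()           # backward kernel, driven by the saved lse
print(x.grad.shape)              # torch.Size([2048, 32768])
```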