quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/cross_entropy.py CHANGED
@@ -1,17 +1,22 @@
1
1
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
2
 
3
3
  import math
4
- from typing import Optional, Type
4
+ from typing import Optional, Type, Literal
5
+
6
+ import torch
7
+ from torch import Tensor
5
8
 
6
9
  import cuda.bindings.driver as cuda
7
10
 
8
11
  import cutlass
9
12
  import cutlass.cute as cute
13
+ from cutlass import Int32, Float32, Boolean, const_expr
14
+ from cutlass.cute.runtime import from_dlpack
10
15
 
11
16
  import quack.utils as utils
12
- import torch
13
- from cutlass.cute.runtime import from_dlpack
14
- from quack.reduction_base import ReductionBase, torch2cute_dtype_map
17
+ from quack.reduce import row_reduce, online_softmax_reduce
18
+ from quack.reduction_base import ReductionBase
19
+ from quack.cute_dsl_utils import torch2cute_dtype_map
15
20
 
16
21
 
17
22
  class CrossEntropy(ReductionBase):
@@ -21,7 +26,7 @@ class CrossEntropy(ReductionBase):
21
26
  dtype,
22
27
  N,
23
28
  stage=2 if not online_softmax else 1,
24
- reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
29
+ reduction_dtype=Float32 if not online_softmax else cutlass.Int64,
25
30
  )
26
31
  self.online_softmax = online_softmax
27
32
  self.reload_from = None if N <= 16384 or online_softmax else "smem"
@@ -40,7 +45,7 @@ class CrossEntropy(ReductionBase):
40
45
 
41
46
  def _set_cluster_n(self):
42
47
  N = self.N
43
- if cutlass.const_expr(self.dtype.width == 16):
48
+ if const_expr(self.dtype.width == 16):
44
49
  cluster_n = (
45
50
  1
46
51
  if N <= 16 * 1024
@@ -65,21 +70,30 @@ class CrossEntropy(ReductionBase):
65
70
  @cute.jit
66
71
  def __call__(
67
72
  self,
68
- mX: cute.Tensor,
69
- mTarget: cute.Tensor,
70
- mLoss: cute.Tensor,
71
- mLSE: Optional[cute.Tensor],
73
+ mX: cute.Tensor, # (M, N)
74
+ mTarget: cute.Tensor, # (M,)
75
+ mTargetLogit: Optional[cute.Tensor], # (M, K) or (M,). If None, we use mX
76
+ mLoss: cute.Tensor, # (M,)
77
+ mLSE: Optional[cute.Tensor], # (M,)
78
+ mdX: Optional[cute.Tensor], # (M, N) - if provided, compute gradient
79
+ ignore_index: Int32, # Index to ignore in loss computation
72
80
  stream: cuda.CUstream,
73
81
  ):
74
82
  assert mX.element_type == self.dtype
83
+ if const_expr(mTargetLogit is None):
84
+ mTargetLogit = mX
75
85
  self._set_cluster_n()
76
- tiler_mn, tv_layout = self._get_tv_layout()
86
+ # e.g. if self.N isn't divisible by 8 for bf16, we might use a 64-bit (4-element) copy
87
+ num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
88
+ tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
77
89
  num_threads = cute.size(tv_layout, mode=[0])
78
90
  num_warps = num_threads // cute.arch.WARP_SIZE
79
- self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
91
+ self.kernel(
92
+ mX, mTarget, mTargetLogit, mLoss, mLSE, mdX, ignore_index, tv_layout, tiler_mn
93
+ ).launch(
80
94
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
81
95
  block=[num_threads, 1, 1],
82
- cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
96
+ cluster=([1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None),
83
97
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
84
98
  stream=stream,
85
99
  )
@@ -89,17 +103,20 @@ class CrossEntropy(ReductionBase):
89
103
  self,
90
104
  mX: cute.Tensor, # (M, N)
91
105
  mTarget: cute.Tensor, # (M,)
106
+ mTargetLogit: cute.Tensor, # (M, K) or (M,)
92
107
  mLoss: cute.Tensor, # (M,)
93
108
  mLSE: Optional[cute.Tensor], # (M,)
109
+ mdX: Optional[cute.Tensor], # (M, N) - if provided, compute gradient
110
+ ignore_index: Int32, # Index to ignore in loss computation
94
111
  tv_layout: cute.Layout,
95
112
  tiler_mn: cute.Shape,
96
113
  ):
97
114
  tidx, _, _ = cute.arch.thread_idx()
98
115
  bidx, _, _ = cute.arch.block_idx()
99
- if cutlass.const_expr(self.cluster_n > 1):
116
+ if const_expr(self.cluster_n > 1):
100
117
  cluster_y = cute.arch.block_idx()[1]
101
118
  else:
102
- cluster_y = cutlass.const_expr(0)
119
+ cluster_y = const_expr(0)
103
120
 
104
121
  shape: cute.Shape = mX.shape
105
122
  idX = cute.make_identity_tensor(shape)
@@ -118,12 +135,14 @@ class CrossEntropy(ReductionBase):
118
135
  reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
119
136
 
120
137
  # declare the atoms which will be used later for memory copy
138
+ num_copy_elems_X = tv_layout.shape[1][0]
139
+ num_copy_bits_X = mX.element_type.width * num_copy_elems_X
121
140
  copy_atom_load_X = cute.make_copy_atom(
122
- cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
141
+ cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=num_copy_bits_X
123
142
  )
124
143
  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
125
144
 
126
- #### Thread View
145
+ #### Partition to get thread view
127
146
  tXgX = thr_copy_X.partition_S(gX)
128
147
  tXsX = thr_copy_X.partition_D(sX)
129
148
  tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
@@ -133,14 +152,14 @@ class CrossEntropy(ReductionBase):
133
152
  self._initialize_cluster(tidx, mbar_ptr, num_warps)
134
153
 
135
154
  row = tXcX[0][0]
136
- target = cute.Int32.zero
137
- if row < shape[0] and tXcX[0][1] == 0:
138
- target = cute.Int32(mTarget[row])
155
+ target = Int32.zero
156
+ if row < shape[0]:
157
+ target = Int32(mTarget[row])
139
158
 
140
- is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
159
+ is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
141
160
  tXpX = (
142
161
  utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
143
- if cutlass.const_expr(not is_even_N)
162
+ if const_expr(not is_even_N)
144
163
  else None
145
164
  )
146
165
  if row < shape[0]:
@@ -148,58 +167,62 @@ class CrossEntropy(ReductionBase):
148
167
  cute.arch.cp_async_commit_group()
149
168
  cute.arch.cp_async_wait_group(0)
150
169
  # Fill OOB values with -inf
151
- if cutlass.const_expr(not is_even_N):
170
+ if const_expr(not is_even_N):
152
171
  utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
153
172
  cute.autovec_copy(tXsX, tXrX)
154
- x = tXrX.load().to(cute.Float32)
155
-
156
- target_logit = cute.Float32.zero
157
- if row < shape[0] and tXcX[0][1] == 0:
158
- # Use Int64 for indexing to deal with large tensors
159
- mX_off = utils.domain_offset_i64((row, 0), mX)
160
- target_logit = cute.Float32(mX_off[0, target])
173
+ x = tXrX.load().to(Float32)
174
+
175
+ target_logit = Float32.zero
176
+ should_ignore = Boolean(target == ignore_index)
177
+ if row < shape[0] and tXcX[0][1] == 0 and not should_ignore:
178
+ # Only load target logit if not ignoring this index
179
+ if const_expr(cute.rank(mTargetLogit.shape) == 2):
180
+ # Use Int64 for indexing to deal with large tensors
181
+ mTargetLogit_off = utils.domain_offset_i64((row, 0), mTargetLogit)
182
+ target_logit = Float32(mTargetLogit_off[0, target])
183
+ else:
184
+ assert cute.rank(mTargetLogit.shape) == 1
185
+ target_logit = Float32(mTargetLogit[row])
161
186
 
162
187
  threads_per_row = tv_layout.shape[0][0]
163
- if cutlass.const_expr(not self.online_softmax):
164
- max_x = utils.row_reduce(
188
+ if const_expr(not self.online_softmax):
189
+ max_x = row_reduce(
165
190
  x,
166
191
  cute.ReductionOp.MAX,
167
192
  threads_per_row,
168
193
  reduction_buffer[None, None, 0],
169
- mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
170
- init_val=-cutlass.Float32.inf,
171
- hook_fn=(
172
- cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
173
- ),
194
+ mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None,
195
+ init_val=-Float32.inf,
196
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
174
197
  )
175
- if cutlass.const_expr(self.reload_from == "smem"):
198
+ if const_expr(self.reload_from == "smem"):
176
199
  cute.autovec_copy(tXsX, tXrX)
177
- x = tXrX.load().to(cute.Float32)
200
+ x = tXrX.load().to(Float32)
178
201
  log2_e = math.log2(math.e)
179
202
  # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
180
203
  # a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
181
204
  # exp_x = utils.exp2f((x - max_x) * log2_e)
182
205
  # This would use ffma instead of fadd then fmul
183
206
  exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
184
- denom = utils.row_reduce(
207
+ denom = row_reduce(
185
208
  exp_x,
186
209
  cute.ReductionOp.ADD,
187
210
  threads_per_row,
188
211
  reduction_buffer[None, None, 1],
189
- mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
212
+ mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None,
190
213
  init_val=0.0,
191
214
  )
192
215
  else:
193
- max_x, denom, _ = utils.online_softmax_reduce(
216
+ max_x, denom, exp_x = online_softmax_reduce(
194
217
  x,
195
218
  threads_per_row,
196
219
  reduction_buffer[None, None, 0],
197
220
  mbar_ptr,
198
- hook_fn=(
199
- cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
200
- ),
221
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
222
+ return_exp_x=const_expr(mdX is not None),
201
223
  )
202
224
 
225
+ # Write loss and lse to gmem
203
226
  if (
204
227
  tXcX[0][1] == 0
205
228
  and row < shape[0]
@@ -207,40 +230,89 @@ class CrossEntropy(ReductionBase):
207
230
  ):
208
231
  ln_2 = math.log(2.0)
209
232
  lse = max_x + utils.log2f(denom) * ln_2
210
- loss_val = lse - target_logit
211
- mLoss[row] = loss_val.to(mLoss.element_type)
212
- if cutlass.const_expr(mLSE is not None):
233
+ # Set loss to 0 if this index should be ignored, otherwise compute normally
234
+ loss_val = (lse - target_logit) if not should_ignore else Float32.zero
235
+ mLoss[row] = mLoss.element_type(loss_val)
236
+ if const_expr(mLSE is not None):
213
237
  mLSE[row] = lse
214
238
 
215
-
216
- def _cross_entropy(
217
- x: torch.Tensor,
218
- target: torch.Tensor,
219
- return_lse: bool = False,
220
- ) -> torch.Tensor:
239
+ # Compute gradient if mdX is provided
240
+ if const_expr(mdX is not None):
241
+ # Compute probabilities: exp(x) / sum(exp(x))
242
+ # If ignored, gradient should be zero
243
+ denom_inv = (
244
+ 1.0 / denom
245
+ if not (denom == 0.0 or denom != denom or should_ignore)
246
+ else Float32.zero
247
+ )
248
+ probs = exp_x * denom_inv
249
+ mdX_off = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mdX)
250
+ gdX = cute.local_tile(mdX_off, tiler_mn, (0, cluster_y))
251
+ # Setup copy atom for storing gradient
252
+ copy_atom_store = cute.make_copy_atom(
253
+ cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_copy_bits_X
254
+ )
255
+ thr_copy_dX = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx)
256
+ tXgdX = thr_copy_dX.partition_D(gdX)
257
+ tXrdX = cute.make_fragment_like(tXgdX)
258
+ tXcFull = thr_copy_X.partition_S(cX)
259
+ # Compute gradient: probs for all classes, (probs - 1) for target class
260
+ # If ignored, gradient is already zero
261
+ tXrdX_f32 = cute.make_fragment_like(tXrX, Float32)
262
+ tXrdX_f32.store(probs)
263
+ if not should_ignore:
264
+ for i in cutlass.range(cute.size(tXrX), unroll_full=True):
265
+ tXrdX_f32[i] = tXrdX_f32[i] if tXcFull[i][1] != target else tXrdX_f32[i] - 1.0
266
+ tXrdX.store(tXrdX_f32.load().to(tXrdX.element_type))
267
+ tXpdX = (
268
+ utils.predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1])
269
+ if not is_even_N
270
+ else None
271
+ )
272
+ if row < shape[0]:
273
+ cute.copy(copy_atom_store, tXrdX, tXgdX, pred=tXpdX)
274
+
275
+
276
+ @torch.library.custom_op("quack::cross_entropy_fwd_out", mutates_args={"loss", "lse", "dx"})
277
+ def cross_entropy_fwd_out(
278
+ x: Tensor,
279
+ target: Tensor,
280
+ target_logit: Optional[Tensor],
281
+ loss: Tensor,
282
+ lse: Optional[Tensor],
283
+ dx: Optional[Tensor],
284
+ ignore_index: int = -100,
285
+ ) -> None:
221
286
  """Cross entropy forward pass.
222
287
 
223
288
  Args:
224
289
  x: Input logits tensor of shape (M, N)
225
290
  target: Target class indices tensor of shape (M,)
291
+ target_logit: (M, K) or (M,).
292
+ If provided, the target logit will be read from this tensor instead of x.
293
+ loss: Output loss tensor of shape (M,)
294
+ lse: Optional output log-sum-exp tensor of shape (M,)
295
+ dx: Optional output gradient tensor of shape (M, N)
296
+ ignore_index: Index to ignore in loss computation
226
297
 
227
298
  Returns:
228
- Cross entropy loss tensor of shape (M,)
299
+ None (mutates loss, lse, and optionally dx in-place)
229
300
  """
230
301
  assert x.dim() == 2, "Input must be 2D"
231
302
  assert target.dim() == 1, "Target must be 1D"
232
303
  assert x.shape[0] == target.shape[0], "Batch dimensions must match"
233
304
  assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
234
- assert x.dtype in [
235
- torch.float16,
236
- torch.bfloat16,
237
- torch.float32,
238
- ], "Unsupported input dtype"
305
+ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
239
306
  assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
240
- M, N = x.shape
241
- device = x.device
242
- loss = torch.empty(M, device=device, dtype=torch.float32)
243
- lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
307
+ if target_logit is not None:
308
+ assert target_logit.shape[0] == x.shape[0]
309
+ assert target_logit.is_cuda, "Target logits must be on CUDA device"
310
+ assert target_logit.dtype in [torch.float16, torch.bfloat16, torch.float32]
311
+ if dx is not None:
312
+ assert dx.shape == x.shape, "dx must have same shape as x"
313
+ assert dx.is_cuda, "dx must be on CUDA device"
314
+ assert dx.dtype == x.dtype, "dx must have same dtype as x"
315
+ N = x.size(1)
244
316
  dtype = torch2cute_dtype_map[x.dtype]
245
317
  convert_from_dlpack = lambda tensor: (
246
318
  from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
@@ -248,28 +320,86 @@ def _cross_entropy(
248
320
  )
249
321
  )
250
322
  x_tensor = convert_from_dlpack(x)
251
- loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
323
+ loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_layout_dynamic()
252
324
  lse_tensor = (
253
- from_dlpack(lse.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
325
+ from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic()
254
326
  if lse is not None
255
327
  else None
256
328
  )
257
- target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_compact_shape_dynamic(mode=0)
329
+ target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_layout_dynamic()
330
+ target_logit_tensor = (
331
+ from_dlpack(target_logit.detach(), assumed_align=4).mark_layout_dynamic(
332
+ leading_dim=target_logit.ndim - 1
333
+ )
334
+ if target_logit is not None
335
+ else None
336
+ )
337
+ dx_tensor = convert_from_dlpack(dx) if dx is not None else None
258
338
  stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
259
339
 
260
- compile_key = (dtype, N, lse is not None)
261
- if compile_key not in _cross_entropy.compile_cache:
262
- cross_entropy_op = CrossEntropy(dtype, N)
263
- _cross_entropy.compile_cache[compile_key] = cute.compile(
264
- cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
340
+ compile_key = (
341
+ dtype,
342
+ N,
343
+ target_logit.dtype if target_logit is not None else None,
344
+ lse.dtype if lse is not None else None,
345
+ dx is not None,
346
+ loss.stride(),
347
+ lse.stride() if lse is not None else None,
348
+ target.stride(),
349
+ target_logit.stride(-1) if target_logit is not None else None,
350
+ )
351
+ if compile_key not in cross_entropy_fwd_out.compile_cache:
352
+ # If there's dx, it's faster to not use online softmax since we want the exp(x - max)
353
+ cross_entropy_op = CrossEntropy(dtype, N, online_softmax=dx is None)
354
+ cross_entropy_fwd_out.compile_cache[compile_key] = cute.compile(
355
+ cross_entropy_op,
356
+ x_tensor,
357
+ target_tensor,
358
+ target_logit_tensor,
359
+ loss_tensor,
360
+ lse_tensor,
361
+ dx_tensor,
362
+ Int32(ignore_index),
363
+ stream,
265
364
  )
266
- _cross_entropy.compile_cache[compile_key](
267
- x_tensor, target_tensor, loss_tensor, lse_tensor, stream
365
+ cross_entropy_fwd_out.compile_cache[compile_key](
366
+ x_tensor,
367
+ target_tensor,
368
+ target_logit_tensor,
369
+ loss_tensor,
370
+ lse_tensor,
371
+ dx_tensor,
372
+ Int32(ignore_index),
373
+ stream,
268
374
  )
269
- return loss if not return_lse else (loss, lse)
270
375
 
271
376
 
272
- _cross_entropy.compile_cache = {}
377
+ cross_entropy_fwd_out.compile_cache = {}
378
+
379
+
380
+ def cross_entropy_fwd(
381
+ x: torch.Tensor,
382
+ target: torch.Tensor,
383
+ target_logit: Optional[torch.Tensor] = None,
384
+ ignore_index: int = -100,
385
+ return_lse: bool = False,
386
+ return_dx: bool = False,
387
+ inplace_backward: bool = False,
388
+ ) -> torch.Tensor | tuple[torch.Tensor]:
389
+ M = x.size(0)
390
+ device = x.device
391
+ loss = torch.empty(M, device=device, dtype=torch.float32)
392
+ lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
393
+ dx = (torch.empty_like(x) if not inplace_backward else x) if return_dx else None
394
+ cross_entropy_fwd_out(x, target, target_logit, loss, lse, dx, ignore_index)
395
+ if return_lse and return_dx:
396
+ return loss, lse, dx
397
+ elif return_lse:
398
+ return loss, lse
399
+ elif return_dx:
400
+ return loss, dx
401
+ else:
402
+ return loss
273
403
 
274
404
 
275
405
  class CrossEntropyBackward:
@@ -279,7 +409,7 @@ class CrossEntropyBackward:
279
409
  self.vecsize = 128 // dtype.width
280
410
 
281
411
  def _calculate_threads_per_row(self):
282
- N = self.N
412
+ N = min(self.N, 16384) # We split by blocks of 16k
283
413
  return (
284
414
  8
285
415
  if N <= 64
@@ -290,13 +420,14 @@ class CrossEntropyBackward:
290
420
  )
291
421
  )
292
422
 
293
- def _get_tv_layout(self):
294
- N = self.N
295
- vecsize = self.vecsize
423
+ def _get_tv_layout(self, num_copy_bits=128):
424
+ vecsize = num_copy_bits // self.dtype.width
425
+ assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
426
+ N = min(self.N, 16384)
296
427
  num_threads = 128 if N <= 16384 else 256
297
428
  threads_per_row = self._calculate_threads_per_row()
298
429
  cols_per_block = num_threads // threads_per_row
299
- num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
430
+ num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row)
300
431
  tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
301
432
  tv_layout = cute.make_layout(
302
433
  ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
@@ -315,40 +446,27 @@ class CrossEntropyBackward:
315
446
  mDLoss: cute.Tensor,
316
447
  mdX: cute.Tensor,
317
448
  mLSE: cute.Tensor,
449
+ ignore_index: Int32, # Index to ignore in gradient computation
318
450
  stream: cuda.CUstream,
319
451
  ):
320
452
  assert mX.element_type == self.dtype
321
453
  assert mdX.element_type == self.dtype
322
-
323
- tiler_mn, tv_layout = self._get_tv_layout()
454
+ # e.g. if self.N isn't divisible by 8 for bf16, we might use a 64-bit (4-element) copy
455
+ num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
456
+ tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
324
457
  num_threads = cute.size(tv_layout, mode=[0])
325
-
326
- mDLoss = cute.make_tensor(
327
- mDLoss.iterator,
328
- cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,))),
329
- )
330
- mTarget = cute.make_tensor(
331
- mTarget.iterator,
332
- cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,))),
333
- )
334
- mLSE = cute.make_tensor(
335
- mLSE.iterator,
336
- cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,))),
337
- )
338
-
458
+ # (M,) -> (M, N) with stride 0 in the N dimension
459
+ mDLoss, mTarget, mLSE = [
460
+ cute.make_tensor(
461
+ X.iterator, cute.append(X.layout, cute.make_layout((self.N,), stride=(0,)))
462
+ )
463
+ for X in (mDLoss, mTarget, mLSE)
464
+ ]
339
465
  smem_size = cute.size_in_bytes(
340
466
  mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0))
341
467
  )
342
-
343
468
  self.kernel(
344
- mX,
345
- mTarget,
346
- mDLoss,
347
- mdX,
348
- mLSE,
349
- mX.shape,
350
- tv_layout,
351
- tiler_mn,
469
+ mX, mTarget, mDLoss, mdX, mLSE, ignore_index, mX.shape, tv_layout, tiler_mn
352
470
  ).launch(
353
471
  grid=[
354
472
  cute.ceil_div(mX.shape[0], tiler_mn[0]),
@@ -368,6 +486,7 @@ class CrossEntropyBackward:
368
486
  mDLoss: cute.Tensor, # (M,)
369
487
  mdX: cute.Tensor, # (M, N)
370
488
  mLSE: cute.Tensor, # (M,)
489
+ ignore_index: Int32, # Index to ignore in gradient computation
371
490
  shape: cute.Shape,
372
491
  tv_layout: cute.Layout,
373
492
  tiler_mn: cute.Shape,
@@ -388,76 +507,67 @@ class CrossEntropyBackward:
388
507
  gX, gdX = [cute.local_tile(mT, tiler_mn, (0, bidy)) for mT in (mX, mdX)]
389
508
  cX = cute.local_tile(idX, tiler_mn, (bidx, bidy))
390
509
 
510
+ num_copy_elems_X = tv_layout.shape[1][0]
511
+ num_copy_bits_X = mX.element_type.width * num_copy_elems_X
391
512
  copy_atom_load_X = cute.make_copy_atom(
392
- cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
513
+ cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=num_copy_bits_X
393
514
  )
394
- copy_atom_load_X_async = cute.make_copy_atom(
395
- cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
515
+ copy_atom_store_dX = cute.make_copy_atom(
516
+ cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=num_copy_bits_X
396
517
  )
397
- copy_atom_store_O = cute.make_copy_atom(
398
- cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=128
399
- )
400
-
401
518
  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
402
- thr_copy_X_async = cute.make_tiled_copy(
403
- copy_atom_load_X_async, tv_layout, tiler_mn
404
- ).get_slice(tidx)
405
- thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
406
-
407
- #### Thread View
408
- tXgX = thr_copy_X_async.partition_S(gX)
409
- tXsX = thr_copy_X_async.partition_S(sX)
519
+ thr_copy_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
410
520
 
521
+ #### Partition to get thread view
522
+ tXgX = thr_copy_X.partition_S(gX)
523
+ tXsX = thr_copy_X.partition_S(sX)
411
524
  tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
412
- tXcFull = thr_copy_X.partition_S(cX) # improve
413
-
414
- tXgO = thr_copy_O.partition_D(gdX)
415
-
525
+ tXcFull = thr_copy_X.partition_S(cX)
526
+ tXgdX = thr_copy_dX.partition_D(gdX)
416
527
  # allocate fragments for gmem->rmem
417
- tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
528
+ tXrX, tXrdX = [cute.make_fragment_like(thr) for thr in (tXgX, tXgdX)]
418
529
 
419
- is_even_N = cutlass.const_expr(shape[1] % tiler_mn[1] == 0)
530
+ is_even_N = const_expr(shape[1] % tiler_mn[1] == 0)
420
531
  row = tXcX[0][0]
421
-
422
532
  tXpX = (
423
- utils.predicate_k(thr_copy_X_async.partition_S(cX), limit=shape[1])
424
- if not is_even_N
425
- else None
533
+ utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
426
534
  )
427
-
428
535
  if row < shape[0]:
429
- cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
536
+ cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
430
537
  cute.arch.cp_async_commit_group()
431
538
  cute.arch.cp_async_wait_group(0)
432
- if cutlass.const_expr(not is_even_N):
539
+ if const_expr(not is_even_N):
433
540
  utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
434
-
435
541
  cute.autovec_copy(tXsX, tXrX)
436
- x = tXrX.load().to(cute.Float32)
542
+ x = tXrX.load().to(Float32)
437
543
 
438
- label = cute.Int32.zero
439
- dloss = cute.Float32.zero
440
- lse = cute.Float32.zero
544
+ target = Int32.zero
545
+ dloss = Float32.zero
546
+ lse = Float32.zero
441
547
  if row < shape[0]:
442
- label = cute.Int32(mTarget[row])
443
- dloss = cute.Float32(mDLoss[row])
444
- lse = cute.Float32(mLSE[row])
548
+ target = Int32(mTarget[row])
549
+ should_ignore = Boolean(target == ignore_index)
550
+ # Set dloss to 0 if this index should be ignored
551
+ dloss = Float32(mDLoss[row]) if not should_ignore else Float32.zero
552
+ lse = Float32(mLSE[row])
445
553
 
446
554
  log2_e = math.log2(math.e)
447
- probs = utils.exp2f((x - lse) * log2_e)
555
+ probs = utils.exp2f(x * log2_e - lse * log2_e)
448
556
  prob_shifted = probs - 1.0
449
557
  mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
450
558
  for i in cutlass.range(cute.size(tXcFull), unroll_full=True):
451
- mask[i] = tXcFull[i][1] == label
559
+ mask[i] = tXcFull[i][1] == target
452
560
  grad = cute.where(mask.load(), prob_shifted, probs)
453
561
  grad = grad * dloss
454
562
 
455
- tXrO.store(grad.to(tXrO.element_type))
456
- tOpO = (
457
- utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
563
+ tXrdX.store(grad.to(tXrdX.element_type))
564
+ tXpdX = (
565
+ utils.predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1])
566
+ if not is_even_N
567
+ else None
458
568
  )
459
569
  if row < shape[0]:
460
- cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
570
+ cute.copy(copy_atom_store_dX, tXrdX, tXgdX, pred=tXpdX)
461
571
 
462
572
 
463
573
  def _cross_entropy_backward(
@@ -465,8 +575,9 @@ def _cross_entropy_backward(
465
575
  target: torch.Tensor,
466
576
  dloss: torch.Tensor,
467
577
  lse: torch.Tensor,
468
- inplace_backward: bool = False,
469
- ) -> torch.Tensor:
578
+ dx: torch.Tensor,
579
+ ignore_index=-100,
580
+ ) -> None:
470
581
  """Cross entropy backward pass.
471
582
  Args:
472
583
  x: Input logits tensor of shape (M, N)
@@ -486,15 +597,10 @@ def _cross_entropy_backward(
486
597
  assert (
487
598
  x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
488
599
  ), "Tensors must be on CUDA device"
489
- assert x.dtype in [
490
- torch.float16,
491
- torch.bfloat16,
492
- torch.float32,
493
- ], "Unsupported input dtype"
600
+ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
494
601
  assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
495
602
 
496
- M, N = x.shape
497
- dx = torch.empty_like(x) if not inplace_backward else x
603
+ N = x.size(1)
498
604
  dtype = torch2cute_dtype_map[x.dtype]
499
605
 
500
606
  convert_from_dlpack = lambda tensor: (
@@ -504,14 +610,12 @@ def _cross_entropy_backward(
504
610
  )
505
611
  x_tensor = convert_from_dlpack(x)
506
612
  dx_tensor = convert_from_dlpack(dx)
507
- dloss_tensor = from_dlpack(dloss.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
508
- lse_tensor = from_dlpack(lse.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
509
- target_tensor = from_dlpack(target.detach(), assumed_align=32).mark_compact_shape_dynamic(
510
- mode=0
511
- )
613
+ dloss_tensor = from_dlpack(dloss.detach(), assumed_align=4).mark_layout_dynamic()
614
+ lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic()
615
+ target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_layout_dynamic()
512
616
  stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
513
617
 
514
- compile_key = (dtype, N)
618
+ compile_key = (dtype, N, target.dtype, dloss.stride(), lse.stride(), target.stride())
515
619
  if compile_key not in _cross_entropy_backward.compile_cache:
516
620
  cross_entropy_backward_op = CrossEntropyBackward(dtype, N)
517
621
  _cross_entropy_backward.compile_cache[compile_key] = cute.compile(
@@ -521,48 +625,95 @@ def _cross_entropy_backward(
521
625
  dloss_tensor,
522
626
  dx_tensor,
523
627
  lse_tensor,
628
+ Int32(ignore_index),
524
629
  stream,
525
630
  )
526
631
  _cross_entropy_backward.compile_cache[compile_key](
527
- x_tensor, target_tensor, dloss_tensor, dx_tensor, lse_tensor, stream
632
+ x_tensor, target_tensor, dloss_tensor, dx_tensor, lse_tensor, Int32(ignore_index), stream
528
633
  )
529
- return dx
530
634
 
531
635
 
532
636
  _cross_entropy_backward.compile_cache = {}
533
637
 
534
638
 
639
+ @torch.library.custom_op("quack::cross_entropy_bwd_out", mutates_args={"dx"})
640
+ def cross_entropy_bwd_out(
641
+ x: torch.Tensor,
642
+ target: torch.Tensor,
643
+ dloss: torch.Tensor,
644
+ lse: torch.Tensor,
645
+ dx: torch.Tensor,
646
+ ignore_index: int = -100,
647
+ ) -> None:
648
+ _cross_entropy_backward(x, target, dloss, lse, dx, ignore_index)
649
+
650
+
651
+ def cross_entropy_bwd(
652
+ x: torch.Tensor,
653
+ target: torch.Tensor,
654
+ dloss: torch.Tensor,
655
+ lse: torch.Tensor,
656
+ ignore_index: int = -100,
657
+ inplace_backward: bool = False,
658
+ ) -> None:
659
+ if inplace_backward and not torch.compiler.is_compiling():
660
+ dx = x
661
+ _cross_entropy_backward(
662
+ x=x, target=target, dloss=dloss, lse=lse, dx=x, ignore_index=ignore_index
663
+ )
664
+ else:
665
+ dx = torch.empty_like(x)
666
+ cross_entropy_bwd_out(
667
+ x=x, target=target, dloss=dloss, lse=lse, dx=dx, ignore_index=ignore_index
668
+ )
669
+ return dx
670
+
671
+
535
672
  class CrossEntropyFunction(torch.autograd.Function):
536
673
  @staticmethod
537
- def forward(ctx, x, target, inplace_backward=False):
538
- loss, lse = _cross_entropy(x, target, return_lse=True)
674
+ def forward(ctx, x, target, lse_partial=None, ignore_index=-100, inplace_backward=False):
675
+ if lse_partial is None:
676
+ loss, lse = cross_entropy_fwd(x, target, ignore_index=ignore_index, return_lse=True)
677
+ else:
678
+ # if the partial lse has already been computed, then to compute the final lse we treat
679
+ # @lse_partial as @x and @x as @target_logit
680
+ loss, lse = cross_entropy_fwd(
681
+ lse_partial, target, target_logit=x, ignore_index=ignore_index, return_lse=True
682
+ )
539
683
  ctx.save_for_backward(x, target, lse)
684
+ ctx.ignore_index = ignore_index
540
685
  ctx.inplace_backward = inplace_backward
541
686
  return loss
542
687
 
543
688
  @staticmethod
544
689
  def backward(ctx, dloss):
545
690
  x, target, lse = ctx.saved_tensors
546
- dx = _cross_entropy_backward(x, target, dloss, lse, inplace_backward=ctx.inplace_backward)
547
- return dx, None, None
691
+ dx = cross_entropy_bwd(
692
+ x, target, dloss, lse, ctx.ignore_index, inplace_backward=ctx.inplace_backward
693
+ )
694
+ return dx, None, None, None, None
548
695
 
549
696
 
550
697
  def cross_entropy(
551
698
  x: torch.Tensor,
552
699
  target: torch.Tensor,
553
- inplace_backward: bool = True,
554
- reduction: str = "none",
700
+ lse_partial: Optional[torch.Tensor] = None,
701
+ ignore_index: int = -100,
702
+ reduction: Literal["none", "mean", "sum"] = "mean",
703
+ inplace_backward: bool = False,
555
704
  ) -> torch.Tensor:
556
705
  """Cross entropy loss with automatic differentiation support.
557
706
 
558
707
  Args:
559
708
  x: Input logits tensor of shape (M, N)
560
709
  target: Target class indices tensor of shape (M,)
561
- inplace_backward: Whether to perform backward pass in-place
710
+ lse_partial: Optional precomputed log-sum-exp partial results
562
711
  reduction: Specifies the reduction to apply to the output:
563
712
  'none': no reduction will be applied
564
713
  'mean': the sum of the output will be divided by the number of targets not equal to ignore_index (default)
565
714
  'sum': the output will be summed
715
+ inplace_backward: Whether to perform backward pass in-place
716
+ ignore_index: Index to ignore in loss computation (loss will be 0 for these indices)
566
717
 
567
718
  Returns:
568
719
  Cross entropy loss tensor:
@@ -570,10 +721,9 @@ def cross_entropy(
570
721
  - If reduction='mean': scalar tensor with mean loss
571
722
  - If reduction='sum': scalar tensor with sum of losses
572
723
  """
573
- loss = CrossEntropyFunction.apply(x, target, inplace_backward)
574
-
724
+ loss = CrossEntropyFunction.apply(x, target, lse_partial, ignore_index, inplace_backward)
575
725
  if reduction == "mean":
576
- return loss.mean()
726
+ return loss.sum() / (target != ignore_index).sum().float()
577
727
  elif reduction == "sum":
578
728
  return loss.sum()
579
729
  elif reduction == "none":