quack-kernels 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/cross_entropy.py CHANGED
@@ -1,17 +1,22 @@
1
1
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
2
 
3
3
  import math
4
- from typing import Optional, Type
4
+ from typing import Optional, Type, Literal
5
+
6
+ import torch
7
+ from torch import Tensor
5
8
 
6
9
  import cuda.bindings.driver as cuda
7
10
 
8
11
  import cutlass
9
12
  import cutlass.cute as cute
13
+ from cutlass import Int32, Float32, Boolean, const_expr
14
+ from cutlass.cute.runtime import from_dlpack
10
15
 
11
16
  import quack.utils as utils
12
- import torch
13
- from cutlass.cute.runtime import from_dlpack
14
- from quack.reduction_base import ReductionBase, torch2cute_dtype_map
17
+ from quack.reduce import row_reduce, online_softmax_reduce
18
+ from quack.reduction_base import ReductionBase
19
+ from quack.cute_dsl_utils import torch2cute_dtype_map
15
20
 
16
21
 
17
22
  class CrossEntropy(ReductionBase):
@@ -21,7 +26,7 @@ class CrossEntropy(ReductionBase):
21
26
  dtype,
22
27
  N,
23
28
  stage=2 if not online_softmax else 1,
24
- reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
29
+ reduction_dtype=Float32 if not online_softmax else cutlass.Int64,
25
30
  )
26
31
  self.online_softmax = online_softmax
27
32
  self.reload_from = None if N <= 16384 or online_softmax else "smem"
@@ -40,7 +45,7 @@ class CrossEntropy(ReductionBase):
40
45
 
41
46
  def _set_cluster_n(self):
42
47
  N = self.N
43
- if cutlass.const_expr(self.dtype.width == 16):
48
+ if const_expr(self.dtype.width == 16):
44
49
  cluster_n = (
45
50
  1
46
51
  if N <= 16 * 1024
@@ -65,21 +70,30 @@ class CrossEntropy(ReductionBase):
65
70
  @cute.jit
66
71
  def __call__(
67
72
  self,
68
- mX: cute.Tensor,
69
- mTarget: cute.Tensor,
70
- mLoss: cute.Tensor,
71
- mLSE: Optional[cute.Tensor],
73
+ mX: cute.Tensor, # (M, N)
74
+ mTarget: cute.Tensor, # (M,)
75
+ mTargetLogit: Optional[cute.Tensor], # (M, K) or (M,). If None, we use mX
76
+ mLoss: cute.Tensor, # (M,)
77
+ mLSE: Optional[cute.Tensor], # (M,)
78
+ mdX: Optional[cute.Tensor], # (M, N) - if provided, compute gradient
79
+ ignore_index: Int32, # Index to ignore in loss computation
72
80
  stream: cuda.CUstream,
73
81
  ):
74
82
  assert mX.element_type == self.dtype
83
+ if const_expr(mTargetLogit is None):
84
+ mTargetLogit = mX
75
85
  self._set_cluster_n()
76
- tiler_mn, tv_layout = self._get_tv_layout()
86
+ # e.g. if self.N isn't divisible by 8 for bf16, we might use 64 bits (4 elements) copy
87
+ num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
88
+ tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
77
89
  num_threads = cute.size(tv_layout, mode=[0])
78
90
  num_warps = num_threads // cute.arch.WARP_SIZE
79
- self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
91
+ self.kernel(
92
+ mX, mTarget, mTargetLogit, mLoss, mLSE, mdX, ignore_index, tv_layout, tiler_mn
93
+ ).launch(
80
94
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
81
95
  block=[num_threads, 1, 1],
82
- cluster=([1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None),
96
+ cluster=([1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None),
83
97
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
84
98
  stream=stream,
85
99
  )
@@ -89,17 +103,20 @@ class CrossEntropy(ReductionBase):
89
103
  self,
90
104
  mX: cute.Tensor, # (M, N)
91
105
  mTarget: cute.Tensor, # (M,)
106
+ mTargetLogit: cute.Tensor, # (M, K) or (M,)
92
107
  mLoss: cute.Tensor, # (M,)
93
108
  mLSE: Optional[cute.Tensor], # (M,)
109
+ mdX: Optional[cute.Tensor], # (M, N) - if provided, compute gradient
110
+ ignore_index: Int32, # Index to ignore in loss computation
94
111
  tv_layout: cute.Layout,
95
112
  tiler_mn: cute.Shape,
96
113
  ):
97
114
  tidx, _, _ = cute.arch.thread_idx()
98
115
  bidx, _, _ = cute.arch.block_idx()
99
- if cutlass.const_expr(self.cluster_n > 1):
116
+ if const_expr(self.cluster_n > 1):
100
117
  cluster_y = cute.arch.block_idx()[1]
101
118
  else:
102
- cluster_y = cutlass.const_expr(0)
119
+ cluster_y = const_expr(0)
103
120
 
104
121
  shape: cute.Shape = mX.shape
105
122
  idX = cute.make_identity_tensor(shape)
@@ -118,12 +135,14 @@ class CrossEntropy(ReductionBase):
118
135
  reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
119
136
 
120
137
  # declare the atoms which will be used later for memory copy
138
+ num_copy_elems_X = tv_layout.shape[1][0]
139
+ num_copy_bits_X = mX.element_type.width * num_copy_elems_X
121
140
  copy_atom_load_X = cute.make_copy_atom(
122
- cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
141
+ cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=num_copy_bits_X
123
142
  )
124
143
  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
125
144
 
126
- #### Thread View
145
+ #### Partition to get thread view
127
146
  tXgX = thr_copy_X.partition_S(gX)
128
147
  tXsX = thr_copy_X.partition_D(sX)
129
148
  tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
@@ -133,14 +152,14 @@ class CrossEntropy(ReductionBase):
133
152
  self._initialize_cluster(tidx, mbar_ptr, num_warps)
134
153
 
135
154
  row = tXcX[0][0]
136
- target = cute.Int32.zero
137
- if row < shape[0] and tXcX[0][1] == 0:
138
- target = cute.Int32(mTarget[row])
155
+ target = Int32.zero
156
+ if row < shape[0]:
157
+ target = Int32(mTarget[row])
139
158
 
140
- is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
159
+ is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
141
160
  tXpX = (
142
161
  utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
143
- if cutlass.const_expr(not is_even_N)
162
+ if const_expr(not is_even_N)
144
163
  else None
145
164
  )
146
165
  if row < shape[0]:
@@ -148,99 +167,148 @@ class CrossEntropy(ReductionBase):
148
167
  cute.arch.cp_async_commit_group()
149
168
  cute.arch.cp_async_wait_group(0)
150
169
  # Fill OOB values with -inf
151
- if cutlass.const_expr(not is_even_N):
170
+ if const_expr(not is_even_N):
152
171
  utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
153
172
  cute.autovec_copy(tXsX, tXrX)
154
- x = tXrX.load().to(cute.Float32)
155
-
156
- target_logit = cute.Float32.zero
157
- if row < shape[0] and tXcX[0][1] == 0:
158
- # Use Int64 for indexing to deal with large tensors
159
- mX_off = utils.domain_offset_i64((row, 0), mX)
160
- target_logit = cute.Float32(mX_off[0, target])
173
+ x = tXrX.load().to(Float32)
174
+
175
+ target_logit = Float32.zero
176
+ should_ignore = Boolean(target == ignore_index)
177
+ if row < shape[0] and tXcX[0][1] == 0 and not should_ignore:
178
+ # Only load target logit if not ignoring this index
179
+ if const_expr(cute.rank(mTargetLogit.shape) == 2):
180
+ # Use Int64 for indexing to deal with large tensors
181
+ mTargetLogit_off = utils.domain_offset_i64((row, 0), mTargetLogit)
182
+ target_logit = Float32(mTargetLogit_off[0, target])
183
+ else:
184
+ assert cute.rank(mTargetLogit.shape) == 1
185
+ target_logit = Float32(mTargetLogit[row])
161
186
 
162
187
  threads_per_row = tv_layout.shape[0][0]
163
- if cutlass.const_expr(not self.online_softmax):
164
- max_x = utils.row_reduce(
188
+ if const_expr(not self.online_softmax):
189
+ max_x = row_reduce(
165
190
  x,
166
191
  cute.ReductionOp.MAX,
167
192
  threads_per_row,
168
193
  reduction_buffer[None, None, 0],
169
- mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
170
- init_val=-cutlass.Float32.inf,
171
- hook_fn=(
172
- cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
173
- ),
194
+ mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None,
195
+ init_val=-Float32.inf,
196
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
174
197
  )
175
- if cutlass.const_expr(self.reload_from == "smem"):
198
+ if const_expr(self.reload_from == "smem"):
176
199
  cute.autovec_copy(tXsX, tXrX)
177
- x = tXrX.load().to(cute.Float32)
200
+ x = tXrX.load().to(Float32)
178
201
  log2_e = math.log2(math.e)
179
- # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
180
- # a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
181
- # exp_x = utils.exp2f((x - max_x) * log2_e)
182
202
  # This would use ffma instead of fadd then fmul
183
- exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
184
- denom = utils.row_reduce(
203
+ exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=False)
204
+ denom = row_reduce(
185
205
  exp_x,
186
206
  cute.ReductionOp.ADD,
187
207
  threads_per_row,
188
208
  reduction_buffer[None, None, 1],
189
- mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
209
+ mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None,
190
210
  init_val=0.0,
191
211
  )
192
212
  else:
193
- max_x, denom, _ = utils.online_softmax_reduce(
213
+ max_x, denom, exp_x = online_softmax_reduce(
194
214
  x,
195
215
  threads_per_row,
196
216
  reduction_buffer[None, None, 0],
197
217
  mbar_ptr,
198
- hook_fn=(
199
- cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
200
- ),
218
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
219
+ return_exp_x=const_expr(mdX is not None),
201
220
  )
202
221
 
222
+ # Write loss and lse to gmem
203
223
  if (
204
224
  tXcX[0][1] == 0
205
225
  and row < shape[0]
206
226
  and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
207
227
  ):
208
- ln_2 = math.log(2.0)
209
- lse = max_x + utils.log2f(denom) * ln_2
210
- loss_val = lse - target_logit
211
- mLoss[row] = loss_val.to(mLoss.element_type)
212
- if cutlass.const_expr(mLSE is not None):
228
+ lse = max_x + cute.math.log(denom, fastmath=True)
229
+ # Set loss to 0 if this index should be ignored, otherwise compute normally
230
+ loss_val = (lse - target_logit) if not should_ignore else Float32.zero
231
+ mLoss[row] = mLoss.element_type(loss_val)
232
+ if const_expr(mLSE is not None):
213
233
  mLSE[row] = lse
214
234
 
215
-
216
- def _cross_entropy(
217
- x: torch.Tensor,
218
- target: torch.Tensor,
219
- return_lse: bool = False,
220
- ) -> torch.Tensor:
235
+ # Compute gradient if mdX is provided
236
+ if const_expr(mdX is not None):
237
+ # Compute probabilities: exp(x) / sum(exp(x))
238
+ # If ignored, gradient should be zero
239
+ denom_inv = (
240
+ 1.0 / denom
241
+ if not (denom == 0.0 or denom != denom or should_ignore)
242
+ else Float32.zero
243
+ )
244
+ probs = exp_x * denom_inv
245
+ mdX_off = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mdX)
246
+ gdX = cute.local_tile(mdX_off, tiler_mn, (0, cluster_y))
247
+ # Setup copy atom for storing gradient
248
+ copy_atom_store = cute.make_copy_atom(
249
+ cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_copy_bits_X
250
+ )
251
+ thr_copy_dX = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx)
252
+ tXgdX = thr_copy_dX.partition_D(gdX)
253
+ tXrdX = cute.make_fragment_like(tXgdX)
254
+ tXcFull = thr_copy_X.partition_S(cX)
255
+ # Compute gradient: probs for all classes, (probs - 1) for target class
256
+ # If ignored, gradient is already zero
257
+ tXrdX_f32 = cute.make_fragment_like(tXrX, Float32)
258
+ tXrdX_f32.store(probs)
259
+ if not should_ignore:
260
+ for i in cutlass.range(cute.size(tXrX), unroll_full=True):
261
+ tXrdX_f32[i] = tXrdX_f32[i] if tXcFull[i][1] != target else tXrdX_f32[i] - 1.0
262
+ tXrdX.store(tXrdX_f32.load().to(tXrdX.element_type))
263
+ tXpdX = (
264
+ utils.predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1])
265
+ if not is_even_N
266
+ else None
267
+ )
268
+ if row < shape[0]:
269
+ cute.copy(copy_atom_store, tXrdX, tXgdX, pred=tXpdX)
270
+
271
+
272
+ @torch.library.custom_op("quack::cross_entropy_fwd_out", mutates_args={"loss", "lse", "dx"})
273
+ def cross_entropy_fwd_out(
274
+ x: Tensor,
275
+ target: Tensor,
276
+ target_logit: Optional[Tensor],
277
+ loss: Tensor,
278
+ lse: Optional[Tensor],
279
+ dx: Optional[Tensor],
280
+ ignore_index: int = -100,
281
+ ) -> None:
221
282
  """Cross entropy forward pass.
222
283
 
223
284
  Args:
224
285
  x: Input logits tensor of shape (M, N)
225
286
  target: Target class indices tensor of shape (M,)
287
+ target_logit: (M, K) or (M,).
288
+ If provided, the target logit will be read from this tensor instead of x.
289
+ loss: Output loss tensor of shape (M,)
290
+ lse: Optional output log-sum-exp tensor of shape (M,)
291
+ dx: Optional output gradient tensor of shape (M, N)
292
+ ignore_index: Index to ignore in loss computation
226
293
 
227
294
  Returns:
228
- Cross entropy loss tensor of shape (M,)
295
+ None (mutates loss, lse, and optionally dx in-place)
229
296
  """
230
297
  assert x.dim() == 2, "Input must be 2D"
231
298
  assert target.dim() == 1, "Target must be 1D"
232
299
  assert x.shape[0] == target.shape[0], "Batch dimensions must match"
233
300
  assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
234
- assert x.dtype in [
235
- torch.float16,
236
- torch.bfloat16,
237
- torch.float32,
238
- ], "Unsupported input dtype"
301
+ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
239
302
  assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
240
- M, N = x.shape
241
- device = x.device
242
- loss = torch.empty(M, device=device, dtype=torch.float32)
243
- lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
303
+ if target_logit is not None:
304
+ assert target_logit.shape[0] == x.shape[0]
305
+ assert target_logit.is_cuda, "Target logits must be on CUDA device"
306
+ assert target_logit.dtype in [torch.float16, torch.bfloat16, torch.float32]
307
+ if dx is not None:
308
+ assert dx.shape == x.shape, "dx must have same shape as x"
309
+ assert dx.is_cuda, "dx must be on CUDA device"
310
+ assert dx.dtype == x.dtype, "dx must have same dtype as x"
311
+ N = x.size(1)
244
312
  dtype = torch2cute_dtype_map[x.dtype]
245
313
  convert_from_dlpack = lambda tensor: (
246
314
  from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
@@ -248,28 +316,86 @@ def _cross_entropy(
248
316
  )
249
317
  )
250
318
  x_tensor = convert_from_dlpack(x)
251
- loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
319
+ loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_layout_dynamic()
252
320
  lse_tensor = (
253
- from_dlpack(lse.detach(), assumed_align=4).mark_compact_shape_dynamic(mode=0)
321
+ from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic()
254
322
  if lse is not None
255
323
  else None
256
324
  )
257
- target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_compact_shape_dynamic(mode=0)
325
+ target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_layout_dynamic()
326
+ target_logit_tensor = (
327
+ from_dlpack(target_logit.detach(), assumed_align=4).mark_layout_dynamic(
328
+ leading_dim=target_logit.ndim - 1
329
+ )
330
+ if target_logit is not None
331
+ else None
332
+ )
333
+ dx_tensor = convert_from_dlpack(dx) if dx is not None else None
258
334
  stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
259
335
 
260
- compile_key = (dtype, N, lse is not None)
261
- if compile_key not in _cross_entropy.compile_cache:
262
- cross_entropy_op = CrossEntropy(dtype, N)
263
- _cross_entropy.compile_cache[compile_key] = cute.compile(
264
- cross_entropy_op, x_tensor, target_tensor, loss_tensor, lse_tensor, stream
336
+ compile_key = (
337
+ dtype,
338
+ N,
339
+ target_logit.dtype if target_logit is not None else None,
340
+ lse.dtype if lse is not None else None,
341
+ dx is not None,
342
+ loss.stride(),
343
+ lse.stride() if lse is not None else None,
344
+ target.stride(),
345
+ target_logit.stride(-1) if target_logit is not None else None,
346
+ )
347
+ if compile_key not in cross_entropy_fwd_out.compile_cache:
348
+ # If there's dx, it's faster to not use online softmax since we want the exp(x - max)
349
+ cross_entropy_op = CrossEntropy(dtype, N, online_softmax=dx is None)
350
+ cross_entropy_fwd_out.compile_cache[compile_key] = cute.compile(
351
+ cross_entropy_op,
352
+ x_tensor,
353
+ target_tensor,
354
+ target_logit_tensor,
355
+ loss_tensor,
356
+ lse_tensor,
357
+ dx_tensor,
358
+ Int32(ignore_index),
359
+ stream,
265
360
  )
266
- _cross_entropy.compile_cache[compile_key](
267
- x_tensor, target_tensor, loss_tensor, lse_tensor, stream
361
+ cross_entropy_fwd_out.compile_cache[compile_key](
362
+ x_tensor,
363
+ target_tensor,
364
+ target_logit_tensor,
365
+ loss_tensor,
366
+ lse_tensor,
367
+ dx_tensor,
368
+ Int32(ignore_index),
369
+ stream,
268
370
  )
269
- return loss if not return_lse else (loss, lse)
270
371
 
271
372
 
272
- _cross_entropy.compile_cache = {}
373
+ cross_entropy_fwd_out.compile_cache = {}
374
+
375
+
376
+ def cross_entropy_fwd(
377
+ x: torch.Tensor,
378
+ target: torch.Tensor,
379
+ target_logit: Optional[torch.Tensor] = None,
380
+ ignore_index: int = -100,
381
+ return_lse: bool = False,
382
+ return_dx: bool = False,
383
+ inplace_backward: bool = False,
384
+ ) -> torch.Tensor | tuple[torch.Tensor]:
385
+ M = x.size(0)
386
+ device = x.device
387
+ loss = torch.empty(M, device=device, dtype=torch.float32)
388
+ lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
389
+ dx = (torch.empty_like(x) if not inplace_backward else x) if return_dx else None
390
+ cross_entropy_fwd_out(x, target, target_logit, loss, lse, dx, ignore_index)
391
+ if return_lse and return_dx:
392
+ return loss, lse, dx
393
+ elif return_lse:
394
+ return loss, lse
395
+ elif return_dx:
396
+ return loss, dx
397
+ else:
398
+ return loss
273
399
 
274
400
 
275
401
  class CrossEntropyBackward:
@@ -279,7 +405,7 @@ class CrossEntropyBackward:
279
405
  self.vecsize = 128 // dtype.width
280
406
 
281
407
  def _calculate_threads_per_row(self):
282
- N = self.N
408
+ N = min(self.N, 16384) # We split by blocks of 16k
283
409
  return (
284
410
  8
285
411
  if N <= 64
@@ -290,13 +416,14 @@ class CrossEntropyBackward:
290
416
  )
291
417
  )
292
418
 
293
- def _get_tv_layout(self):
294
- N = self.N
295
- vecsize = self.vecsize
419
+ def _get_tv_layout(self, num_copy_bits=128):
420
+ vecsize = num_copy_bits // self.dtype.width
421
+ assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
422
+ N = min(self.N, 16384)
296
423
  num_threads = 128 if N <= 16384 else 256
297
424
  threads_per_row = self._calculate_threads_per_row()
298
425
  cols_per_block = num_threads // threads_per_row
299
- num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row)
426
+ num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row)
300
427
  tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
301
428
  tv_layout = cute.make_layout(
302
429
  ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
@@ -315,40 +442,27 @@ class CrossEntropyBackward:
315
442
  mDLoss: cute.Tensor,
316
443
  mdX: cute.Tensor,
317
444
  mLSE: cute.Tensor,
445
+ ignore_index: Int32, # Index to ignore in gradient computation
318
446
  stream: cuda.CUstream,
319
447
  ):
320
448
  assert mX.element_type == self.dtype
321
449
  assert mdX.element_type == self.dtype
322
-
323
- tiler_mn, tv_layout = self._get_tv_layout()
450
+ # e.g. if self.N isn't divisible by 8 for bf16, we might use 64 bits (4 elements) copy
451
+ num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
452
+ tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
324
453
  num_threads = cute.size(tv_layout, mode=[0])
325
-
326
- mDLoss = cute.make_tensor(
327
- mDLoss.iterator,
328
- cute.append(mDLoss.layout, cute.make_layout((self.N,), stride=(0,))),
329
- )
330
- mTarget = cute.make_tensor(
331
- mTarget.iterator,
332
- cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,))),
333
- )
334
- mLSE = cute.make_tensor(
335
- mLSE.iterator,
336
- cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,))),
337
- )
338
-
454
+ # (M,) -> (M, N) with stride 0 in the N dimension
455
+ mDLoss, mTarget, mLSE = [
456
+ cute.make_tensor(
457
+ X.iterator, cute.append(X.layout, cute.make_layout((self.N,), stride=(0,)))
458
+ )
459
+ for X in (mDLoss, mTarget, mLSE)
460
+ ]
339
461
  smem_size = cute.size_in_bytes(
340
462
  mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0))
341
463
  )
342
-
343
464
  self.kernel(
344
- mX,
345
- mTarget,
346
- mDLoss,
347
- mdX,
348
- mLSE,
349
- mX.shape,
350
- tv_layout,
351
- tiler_mn,
465
+ mX, mTarget, mDLoss, mdX, mLSE, ignore_index, mX.shape, tv_layout, tiler_mn
352
466
  ).launch(
353
467
  grid=[
354
468
  cute.ceil_div(mX.shape[0], tiler_mn[0]),
@@ -368,6 +482,7 @@ class CrossEntropyBackward:
368
482
  mDLoss: cute.Tensor, # (M,)
369
483
  mdX: cute.Tensor, # (M, N)
370
484
  mLSE: cute.Tensor, # (M,)
485
+ ignore_index: Int32, # Index to ignore in gradient computation
371
486
  shape: cute.Shape,
372
487
  tv_layout: cute.Layout,
373
488
  tiler_mn: cute.Shape,
@@ -388,76 +503,67 @@ class CrossEntropyBackward:
388
503
  gX, gdX = [cute.local_tile(mT, tiler_mn, (0, bidy)) for mT in (mX, mdX)]
389
504
  cX = cute.local_tile(idX, tiler_mn, (bidx, bidy))
390
505
 
506
+ num_copy_elems_X = tv_layout.shape[1][0]
507
+ num_copy_bits_X = mX.element_type.width * num_copy_elems_X
391
508
  copy_atom_load_X = cute.make_copy_atom(
392
- cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128
393
- )
394
- copy_atom_load_X_async = cute.make_copy_atom(
395
- cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
509
+ cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=num_copy_bits_X
396
510
  )
397
- copy_atom_store_O = cute.make_copy_atom(
398
- cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=128
511
+ copy_atom_store_dX = cute.make_copy_atom(
512
+ cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=num_copy_bits_X
399
513
  )
400
-
401
514
  thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
402
- thr_copy_X_async = cute.make_tiled_copy(
403
- copy_atom_load_X_async, tv_layout, tiler_mn
404
- ).get_slice(tidx)
405
- thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
406
-
407
- #### Thread View
408
- tXgX = thr_copy_X_async.partition_S(gX)
409
- tXsX = thr_copy_X_async.partition_S(sX)
515
+ thr_copy_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
410
516
 
517
+ #### Partition to get thread view
518
+ tXgX = thr_copy_X.partition_S(gX)
519
+ tXsX = thr_copy_X.partition_S(sX)
411
520
  tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
412
- tXcFull = thr_copy_X.partition_S(cX) # improve
413
-
414
- tXgO = thr_copy_O.partition_D(gdX)
415
-
521
+ tXcFull = thr_copy_X.partition_S(cX)
522
+ tXgdX = thr_copy_dX.partition_D(gdX)
416
523
  # allocate fragments for gmem->rmem
417
- tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
524
+ tXrX, tXrdX = [cute.make_fragment_like(thr) for thr in (tXgX, tXgdX)]
418
525
 
419
- is_even_N = cutlass.const_expr(shape[1] % tiler_mn[1] == 0)
526
+ is_even_N = const_expr(shape[1] % tiler_mn[1] == 0)
420
527
  row = tXcX[0][0]
421
-
422
528
  tXpX = (
423
- utils.predicate_k(thr_copy_X_async.partition_S(cX), limit=shape[1])
424
- if not is_even_N
425
- else None
529
+ utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
426
530
  )
427
-
428
531
  if row < shape[0]:
429
- cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
532
+ cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
430
533
  cute.arch.cp_async_commit_group()
431
534
  cute.arch.cp_async_wait_group(0)
432
- if cutlass.const_expr(not is_even_N):
535
+ if const_expr(not is_even_N):
433
536
  utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
434
-
435
537
  cute.autovec_copy(tXsX, tXrX)
436
- x = tXrX.load().to(cute.Float32)
538
+ x = tXrX.load().to(Float32)
437
539
 
438
- label = cute.Int32.zero
439
- dloss = cute.Float32.zero
440
- lse = cute.Float32.zero
540
+ target = Int32.zero
541
+ dloss = Float32.zero
542
+ lse = Float32.zero
441
543
  if row < shape[0]:
442
- label = cute.Int32(mTarget[row])
443
- dloss = cute.Float32(mDLoss[row])
444
- lse = cute.Float32(mLSE[row])
544
+ target = Int32(mTarget[row])
545
+ should_ignore = Boolean(target == ignore_index)
546
+ # Set dloss to 0 if this index should be ignored
547
+ dloss = Float32(mDLoss[row]) if not should_ignore else Float32.zero
548
+ lse = Float32(mLSE[row])
445
549
 
446
550
  log2_e = math.log2(math.e)
447
- probs = utils.exp2f((x - lse) * log2_e)
551
+ probs = cute.math.exp2(x * log2_e - (lse * log2_e), fastmath=True)
448
552
  prob_shifted = probs - 1.0
449
553
  mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
450
554
  for i in cutlass.range(cute.size(tXcFull), unroll_full=True):
451
- mask[i] = tXcFull[i][1] == label
555
+ mask[i] = tXcFull[i][1] == target
452
556
  grad = cute.where(mask.load(), prob_shifted, probs)
453
557
  grad = grad * dloss
454
558
 
455
- tXrO.store(grad.to(tXrO.element_type))
456
- tOpO = (
457
- utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
559
+ tXrdX.store(grad.to(tXrdX.element_type))
560
+ tXpdX = (
561
+ utils.predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1])
562
+ if not is_even_N
563
+ else None
458
564
  )
459
565
  if row < shape[0]:
460
- cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
566
+ cute.copy(copy_atom_store_dX, tXrdX, tXgdX, pred=tXpdX)
461
567
 
462
568
 
463
569
  def _cross_entropy_backward(
@@ -465,8 +571,9 @@ def _cross_entropy_backward(
465
571
  target: torch.Tensor,
466
572
  dloss: torch.Tensor,
467
573
  lse: torch.Tensor,
468
- inplace_backward: bool = False,
469
- ) -> torch.Tensor:
574
+ dx: torch.Tensor,
575
+ ignore_index=-100,
576
+ ) -> None:
470
577
  """Cross entropy backward pass.
471
578
  Args:
472
579
  x: Input logits tensor of shape (M, N)
@@ -483,18 +590,13 @@ def _cross_entropy_backward(
483
590
  assert x.shape[0] == target.shape[0], "Batch dimensions must match"
484
591
  assert x.shape[0] == dloss.shape[0], "Batch dimensions must match"
485
592
  assert x.shape[0] == lse.shape[0], "Batch dimensions must match"
486
- assert (
487
- x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
488
- ), "Tensors must be on CUDA device"
489
- assert x.dtype in [
490
- torch.float16,
491
- torch.bfloat16,
492
- torch.float32,
493
- ], "Unsupported input dtype"
593
+ assert x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda, (
594
+ "Tensors must be on CUDA device"
595
+ )
596
+ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
494
597
  assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
495
598
 
496
- M, N = x.shape
497
- dx = torch.empty_like(x) if not inplace_backward else x
599
+ N = x.size(1)
498
600
  dtype = torch2cute_dtype_map[x.dtype]
499
601
 
500
602
  convert_from_dlpack = lambda tensor: (
@@ -504,14 +606,12 @@ def _cross_entropy_backward(
504
606
  )
505
607
  x_tensor = convert_from_dlpack(x)
506
608
  dx_tensor = convert_from_dlpack(dx)
507
- dloss_tensor = from_dlpack(dloss.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
508
- lse_tensor = from_dlpack(lse.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0)
509
- target_tensor = from_dlpack(target.detach(), assumed_align=32).mark_compact_shape_dynamic(
510
- mode=0
511
- )
609
+ dloss_tensor = from_dlpack(dloss.detach(), assumed_align=4).mark_layout_dynamic()
610
+ lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic()
611
+ target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_layout_dynamic()
512
612
  stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
513
613
 
514
- compile_key = (dtype, N)
614
+ compile_key = (dtype, N, target.dtype, dloss.stride(), lse.stride(), target.stride())
515
615
  if compile_key not in _cross_entropy_backward.compile_cache:
516
616
  cross_entropy_backward_op = CrossEntropyBackward(dtype, N)
517
617
  _cross_entropy_backward.compile_cache[compile_key] = cute.compile(
@@ -521,48 +621,95 @@ def _cross_entropy_backward(
521
621
  dloss_tensor,
522
622
  dx_tensor,
523
623
  lse_tensor,
624
+ Int32(ignore_index),
524
625
  stream,
525
626
  )
526
627
  _cross_entropy_backward.compile_cache[compile_key](
527
- x_tensor, target_tensor, dloss_tensor, dx_tensor, lse_tensor, stream
628
+ x_tensor, target_tensor, dloss_tensor, dx_tensor, lse_tensor, Int32(ignore_index), stream
528
629
  )
529
- return dx
530
630
 
531
631
 
532
632
  _cross_entropy_backward.compile_cache = {}
533
633
 
534
634
 
635
+ @torch.library.custom_op("quack::cross_entropy_bwd_out", mutates_args={"dx"})
636
+ def cross_entropy_bwd_out(
637
+ x: torch.Tensor,
638
+ target: torch.Tensor,
639
+ dloss: torch.Tensor,
640
+ lse: torch.Tensor,
641
+ dx: torch.Tensor,
642
+ ignore_index: int = -100,
643
+ ) -> None:
644
+ _cross_entropy_backward(x, target, dloss, lse, dx, ignore_index)
645
+
646
+
647
+ def cross_entropy_bwd(
648
+ x: torch.Tensor,
649
+ target: torch.Tensor,
650
+ dloss: torch.Tensor,
651
+ lse: torch.Tensor,
652
+ ignore_index: int = -100,
653
+ inplace_backward: bool = False,
654
+ ) -> None:
655
+ if inplace_backward and not torch.compiler.is_compiling():
656
+ dx = x
657
+ _cross_entropy_backward(
658
+ x=x, target=target, dloss=dloss, lse=lse, dx=x, ignore_index=ignore_index
659
+ )
660
+ else:
661
+ dx = torch.empty_like(x)
662
+ cross_entropy_bwd_out(
663
+ x=x, target=target, dloss=dloss, lse=lse, dx=dx, ignore_index=ignore_index
664
+ )
665
+ return dx
666
+
667
+
535
668
  class CrossEntropyFunction(torch.autograd.Function):
536
669
  @staticmethod
537
- def forward(ctx, x, target, inplace_backward=False):
538
- loss, lse = _cross_entropy(x, target, return_lse=True)
670
+ def forward(ctx, x, target, lse_partial=None, ignore_index=-100, inplace_backward=False):
671
+ if lse_partial is None:
672
+ loss, lse = cross_entropy_fwd(x, target, ignore_index=ignore_index, return_lse=True)
673
+ else:
674
+ # if we already compute partial lse, then to compute the final lse we treat
675
+ # @lse_partial as @x and @x as @target_logit
676
+ loss, lse = cross_entropy_fwd(
677
+ lse_partial, target, target_logit=x, ignore_index=ignore_index, return_lse=True
678
+ )
539
679
  ctx.save_for_backward(x, target, lse)
680
+ ctx.ignore_index = ignore_index
540
681
  ctx.inplace_backward = inplace_backward
541
682
  return loss
542
683
 
543
684
  @staticmethod
544
685
  def backward(ctx, dloss):
545
686
  x, target, lse = ctx.saved_tensors
546
- dx = _cross_entropy_backward(x, target, dloss, lse, inplace_backward=ctx.inplace_backward)
547
- return dx, None, None
687
+ dx = cross_entropy_bwd(
688
+ x, target, dloss, lse, ctx.ignore_index, inplace_backward=ctx.inplace_backward
689
+ )
690
+ return dx, None, None, None, None
548
691
 
549
692
 
550
693
  def cross_entropy(
551
694
  x: torch.Tensor,
552
695
  target: torch.Tensor,
553
- inplace_backward: bool = True,
554
- reduction: str = "none",
696
+ lse_partial: Optional[torch.Tensor] = None,
697
+ ignore_index: int = -100,
698
+ reduction: Literal["none", "mean", "sum"] = "mean",
699
+ inplace_backward: bool = False,
555
700
  ) -> torch.Tensor:
556
701
  """Cross entropy loss with automatic differentiation support.
557
702
 
558
703
  Args:
559
704
  x: Input logits tensor of shape (M, N)
560
705
  target: Target class indices tensor of shape (M,)
561
- inplace_backward: Whether to perform backward pass in-place
706
+ lse_partial: Optional precomputed log-sum-exp partial results
562
707
  reduction: Specifies the reduction to apply to the output:
563
708
  'none': no reduction will be applied (default)
564
709
  'mean': the sum of the output will be divided by the number of elements
565
710
  'sum': the output will be summed
711
+ inplace_backward: Whether to perform backward pass in-place
712
+ ignore_index: Index to ignore in loss computation (loss will be 0 for these indices)
566
713
 
567
714
  Returns:
568
715
  Cross entropy loss tensor:
@@ -570,10 +717,9 @@ def cross_entropy(
570
717
  - If reduction='mean': scalar tensor with mean loss
571
718
  - If reduction='sum': scalar tensor with sum of losses
572
719
  """
573
- loss = CrossEntropyFunction.apply(x, target, inplace_backward)
574
-
720
+ loss = CrossEntropyFunction.apply(x, target, lse_partial, ignore_index, inplace_backward)
575
721
  if reduction == "mean":
576
- return loss.mean()
722
+ return loss.sum() / (target != ignore_index).sum().float()
577
723
  elif reduction == "sum":
578
724
  return loss.sum()
579
725
  elif reduction == "none":