quack-kernels 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/broadcast_utils.py +29 -0
  4. quack/compile_utils.py +19 -0
  5. quack/copy_utils.py +487 -0
  6. quack/cross_entropy.py +157 -233
  7. quack/cute_dsl_utils.py +20 -34
  8. quack/gemm.py +194 -0
  9. quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
  10. quack/gemm_config.py +72 -46
  11. quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
  12. quack/gemm_default_epi.py +259 -0
  13. quack/gemm_interface.py +177 -31
  14. quack/gemm_sm100.py +729 -506
  15. quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
  16. quack/gemm_symmetric.py +330 -0
  17. quack/gemm_wrapper_utils.py +3 -1
  18. quack/layout_utils.py +287 -0
  19. quack/linear.py +24 -16
  20. quack/pipeline.py +158 -3
  21. quack/reduce.py +88 -49
  22. quack/reduction_base.py +25 -36
  23. quack/rmsnorm.py +476 -526
  24. quack/sm100_utils.py +62 -0
  25. quack/sm90_utils.py +127 -0
  26. quack/softmax.py +135 -203
  27. quack/sort/bitonic_sort.py +13 -10
  28. quack/sort/utils.py +6 -6
  29. quack/tile_scheduler.py +23 -16
  30. quack/topk.py +409 -85
  31. quack/utils.py +32 -220
  32. quack/varlen_utils.py +370 -1
  33. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/METADATA +4 -2
  34. quack_kernels-0.2.4.dist-info/RECORD +44 -0
  35. quack/layernorm.py +0 -353
  36. quack/symmetric_dense_gemm_sm90.py +0 -2091
  37. quack_kernels-0.2.2.dist-info/RECORD +0 -37
  38. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/WHEEL +0 -0
  39. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/licenses/LICENSE +0 -0
  40. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/top_level.txt +0 -0
quack/cross_entropy.py CHANGED
@@ -1,6 +1,7 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
 import math
+from functools import partial
 from typing import Optional, Type, Literal
 
 import torch
@@ -10,10 +11,12 @@ import cuda.bindings.driver as cuda
 
 import cutlass
 import cutlass.cute as cute
-from cutlass import Int32, Float32, Boolean, const_expr
-from cutlass.cute.runtime import from_dlpack
+from cutlass import Int32, Int64, Float32, Boolean, const_expr
 
 import quack.utils as utils
+import quack.copy_utils as copy_utils
+import quack.layout_utils as layout_utils
+from quack.compile_utils import make_fake_tensor as fake_tensor
 from quack.reduce import row_reduce, online_softmax_reduce
 from quack.reduction_base import ReductionBase
 from quack.cute_dsl_utils import torch2cute_dtype_map
@@ -26,46 +29,29 @@ class CrossEntropy(ReductionBase):
             dtype,
             N,
             stage=2 if not online_softmax else 1,
-            reduction_dtype=Float32 if not online_softmax else cutlass.Int64,
+            reduction_dtype=Float32 if not online_softmax else Int64,
         )
         self.online_softmax = online_softmax
         self.reload_from = None if N <= 16384 or online_softmax else "smem"
 
-    def _calculate_threads_per_row(self):
+    def _threads_per_row(self):
         N = self.N
-        return (
-            8
-            if N <= 64
-            else (
-                16
-                if N <= 128
-                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
-            )
-        )
+        for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
+            if N <= limit:
+                return threads
+        return 256
 
     def _set_cluster_n(self):
         N = self.N
         if const_expr(self.dtype.width == 16):
-            cluster_n = (
-                1
-                if N <= 16 * 1024
-                else (
-                    2
-                    if N <= 32 * 1024
-                    else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
-                )
-            )
-        else:  # fp32
-            cluster_n = (
-                1
-                if N <= 16 * 1024
-                else (
-                    2
-                    if N <= 64 * 1024
-                    else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
-                )
-            )
-        self.cluster_n = cluster_n
+            thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
+        else:
+            thresholds = [(16 * 1024, 1), (64 * 1024, 2), (128 * 1024, 4), (256 * 1024, 8)]
+        for limit, cluster in thresholds:
+            if N <= limit:
+                self.cluster_n = cluster
+                return
+        self.cluster_n = 16
 
     @cute.jit
     def __call__(
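For reference, the threshold tables introduced above behave like the following plain-Python sketch; the numeric values are taken from the diff, while the free-standing helper names are illustrative only.

def threads_per_row(N: int) -> int:
    # (upper bound on N, threads) pairs, checked in order; 256 is the fallback.
    for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
        if N <= limit:
            return threads
    return 256

def cluster_n(N: int, dtype_width: int) -> int:
    # Wider rows get larger thread-block clusters; the fp32 limits are twice the fp16 ones.
    if dtype_width == 16:
        thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
    else:
        thresholds = [(16 * 1024, 1), (64 * 1024, 2), (128 * 1024, 4), (256 * 1024, 8)]
    for limit, cluster in thresholds:
        if N <= limit:
            return cluster
    return 16

assert threads_per_row(4096) == 64       # 3072 < 4096 <= 6144
assert cluster_n(96 * 1024, 16) == 8     # 64K < 96K <= 128K for fp16
assert cluster_n(96 * 1024, 32) == 4     # fp32 thresholds are twice as wide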
@@ -82,19 +68,30 @@ class CrossEntropy(ReductionBase):
         assert mX.element_type == self.dtype
         if const_expr(mTargetLogit is None):
             mTargetLogit = mX
+        if const_expr(mdX is not None):
+            assert mdX.element_type == self.dtype
         self._set_cluster_n()
-        # e.g. if self.N isn't divisible by 8 for bf16, we might use 64 bits (4 elements) copy
-        num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
-        tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
-        num_threads = cute.size(tv_layout, mode=[0])
-        num_warps = num_threads // cute.arch.WARP_SIZE
+        largest_dtype_width = const_expr(mX.element_type.width)
+        if const_expr(mdX is not None):
+            largest_dtype_width = const_expr(max(largest_dtype_width, mdX.element_type.width))
+        vecsize = math.gcd(self.N, 128 // largest_dtype_width)
+        tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize)
+        num_threads = tiled_copy.size
         self.kernel(
-            mX, mTarget, mTargetLogit, mLoss, mLSE, mdX, ignore_index, tv_layout, tiler_mn
+            mX,
+            mTarget,
+            mTargetLogit,
+            mLoss,
+            mLSE,
+            mdX,
+            ignore_index,
+            tiler_mn,
+            tiled_copy,
+            threads_per_row,
         ).launch(
             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
             block=[num_threads, 1, 1],
-            cluster=([1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None),
-            smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+            cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
             stream=stream,
         )
 
@@ -108,47 +105,40 @@ class CrossEntropy(ReductionBase):
         mLSE: Optional[cute.Tensor],  # (M,)
         mdX: Optional[cute.Tensor],  # (M, N) - if provided, compute gradient
         ignore_index: Int32,  # Index to ignore in loss computation
-        tv_layout: cute.Layout,
         tiler_mn: cute.Shape,
+        tiled_copy: cute.TiledCopy,
+        threads_per_row: cutlass.Constexpr[int],
     ):
         tidx, _, _ = cute.arch.thread_idx()
         bidx, _, _ = cute.arch.block_idx()
-        if const_expr(self.cluster_n > 1):
-            cluster_y = cute.arch.block_idx()[1]
-        else:
-            cluster_y = const_expr(0)
+        cluster_y = const_expr(0) if const_expr(self.cluster_n == 1) else cute.arch.block_idx()[1]
+        tv_layout = tiled_copy.layout_tv_tiled
 
-        shape: cute.Shape = mX.shape
+        shape = mX.shape
         idX = cute.make_identity_tensor(shape)
         # slice for CTAs
-        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
-        mX_off = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mX)
-        gX = cute.local_tile(mX_off, tiler_mn, (0, cluster_y))
-        cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
+        gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, idX)]
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
-            mX.element_type,
-            cute.make_ordered_layout(tiler_mn, order=(1, 0)),
-            byte_alignment=16,
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
         )
         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
 
-        # declare the atoms which will be used later for memory copy
-        num_copy_elems_X = tv_layout.shape[1][0]
-        num_copy_bits_X = mX.element_type.width * num_copy_elems_X
-        copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=num_copy_bits_X
-        )
-        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+        thr_copy = tiled_copy.get_slice(tidx)
 
-        #### Partition to get thread view
-        tXgX = thr_copy_X.partition_S(gX)
-        tXsX = thr_copy_X.partition_D(sX)
-        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+        tXgX = thr_copy.partition_S(gX)
+        tXsX = thr_copy.partition_D(sX)
+        tXcX = thr_copy.partition_S(cX)[(0, None), None, None]
         tXrX = cute.make_fragment_like(tXgX)
 
-        num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+        is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+        tXpX = (
+            None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1])
+        )
+        copy = partial(copy_utils.copy, pred=tXpX)
+
+        num_warps = cute.size(tiled_copy) // cute.arch.WARP_SIZE
         self._initialize_cluster(tidx, mbar_ptr, num_warps)
 
         row = tXcX[0][0]
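Two patterns above are worth spelling out: the copy vector width is now the gcd of the row length and the widest 128-bit copy the dtype allows, and the out-of-bounds predicate is bound once with functools.partial so the same guarded copy serves both the async gmem-to-smem load and the later store. A minimal sketch, assuming nothing about the real copy_utils signatures:

import math
from functools import partial

def vector_elems(N: int, dtype_width_bits: int) -> int:
    # Widest per-thread copy (in elements) that still divides the row length N.
    return math.gcd(N, 128 // dtype_width_bits)

assert vector_elems(4096, 16) == 8   # bf16, N % 8 == 0 -> full 128-bit copies
assert vector_elems(1020, 16) == 4   # bf16, N % 8 != 0 -> fall back to 64-bit copies
assert vector_elems(1000, 32) == 4   # fp32: 4 elements is already 128 bits

def copy(src, dst, pred=None, is_async=False):
    # Stand-in for copy_utils.copy; the real helper issues the actual cp.async / store.
    ...

tXpX = None  # or a predicate fragment when the tile overhangs the row
guarded_copy = partial(copy, pred=tXpX)
# guarded_copy(tXgX, tXsX, is_async=True)   # load, as in the kernel above
# guarded_copy(tXrdX, tXgdX)                # store, same predicate reused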
@@ -156,14 +146,8 @@
         if row < shape[0]:
             target = Int32(mTarget[row])
 
-        is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
-        tXpX = (
-            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
-            if const_expr(not is_even_N)
-            else None
-        )
         if row < shape[0]:
-            cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+            copy(tXgX, tXsX, is_async=True)
         cute.arch.cp_async_commit_group()
         cute.arch.cp_async_wait_group(0)
         # Fill OOB values with -inf
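The -inf fill mentioned in the comment above keeps out-of-bounds columns from affecting the row reduction: -inf never wins the running max, and exp(-inf - max) contributes exactly zero to the sum of exponentials. A minimal check:

import math

vals = [0.3, 1.7, -0.2]
padded = vals + [-math.inf]          # out-of-bounds lanes filled with -inf
m = max(padded)
assert m == max(vals)                # max unchanged
assert sum(math.exp(v - m) for v in padded) == sum(math.exp(v - m) for v in vals)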
@@ -177,14 +161,11 @@
         if row < shape[0] and tXcX[0][1] == 0 and not should_ignore:
             # Only load target logit if not ignoring this index
             if const_expr(cute.rank(mTargetLogit.shape) == 2):
-                # Use Int64 for indexing to deal with large tensors
-                mTargetLogit_off = utils.domain_offset_i64((row, 0), mTargetLogit)
-                target_logit = Float32(mTargetLogit_off[0, target])
+                target_logit = Float32(mTargetLogit[row, target])
             else:
                 assert cute.rank(mTargetLogit.shape) == 1
                 target_logit = Float32(mTargetLogit[row])
 
-        threads_per_row = tv_layout.shape[0][0]
         if const_expr(not self.online_softmax):
             max_x = row_reduce(
                 x,
@@ -237,21 +218,16 @@
             # Compute probabilities: exp(x) / sum(exp(x))
             # If ignored, gradient should be zero
             denom_inv = (
-                1.0 / denom
+                # 1.0 / denom
+                cute.arch.rcp_approx(denom)
                 if not (denom == 0.0 or denom != denom or should_ignore)
                 else Float32.zero
             )
             probs = exp_x * denom_inv
-            mdX_off = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mdX)
-            gdX = cute.local_tile(mdX_off, tiler_mn, (0, cluster_y))
-            # Setup copy atom for storing gradient
-            copy_atom_store = cute.make_copy_atom(
-                cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_copy_bits_X
-            )
-            thr_copy_dX = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx)
-            tXgdX = thr_copy_dX.partition_D(gdX)
+            gdX = cute.local_tile(mdX, tiler_mn, (bidx, cluster_y))
+            tXgdX = thr_copy.partition_D(gdX)
             tXrdX = cute.make_fragment_like(tXgdX)
-            tXcFull = thr_copy_X.partition_S(cX)
+            tXcFull = thr_copy.partition_S(cX)
             # Compute gradient: probs for all classes, (probs - 1) for target class
             # If ignored, gradient is already zero
             tXrdX_f32 = cute.make_fragment_like(tXrX, Float32)
@@ -260,13 +236,8 @@
             for i in cutlass.range(cute.size(tXrX), unroll_full=True):
                 tXrdX_f32[i] = tXrdX_f32[i] if tXcFull[i][1] != target else tXrdX_f32[i] - 1.0
             tXrdX.store(tXrdX_f32.load().to(tXrdX.element_type))
-            tXpdX = (
-                utils.predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1])
-                if not is_even_N
-                else None
-            )
             if row < shape[0]:
-                cute.copy(copy_atom_store, tXrdX, tXgdX, pred=tXpdX)
+                copy(tXrdX, tXgdX)
 
 
 @torch.library.custom_op("quack::cross_entropy_fwd_out", mutates_args={"loss", "lse", "dx"})
@@ -296,77 +267,61 @@ def cross_entropy_fwd_out(
     """
    assert x.dim() == 2, "Input must be 2D"
     assert target.dim() == 1, "Target must be 1D"
-    assert x.shape[0] == target.shape[0], "Batch dimensions must match"
     assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
     assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
     assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
     if target_logit is not None:
-        assert target_logit.shape[0] == x.shape[0]
         assert target_logit.is_cuda, "Target logits must be on CUDA device"
         assert target_logit.dtype in [torch.float16, torch.bfloat16, torch.float32]
     if dx is not None:
-        assert dx.shape == x.shape, "dx must have same shape as x"
         assert dx.is_cuda, "dx must be on CUDA device"
-        assert dx.dtype == x.dtype, "dx must have same dtype as x"
     N = x.size(1)
     dtype = torch2cute_dtype_map[x.dtype]
-    convert_from_dlpack = lambda tensor: (
-        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
-            mode=0, stride_order=(0, 1)
-        )
+    target_dtype = torch2cute_dtype_map[target.dtype]
+    target_logit_dtype = (
+        torch2cute_dtype_map[target_logit.dtype] if target_logit is not None else None
     )
-    x_tensor = convert_from_dlpack(x)
-    loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_layout_dynamic()
-    lse_tensor = (
-        from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic()
-        if lse is not None
-        else None
-    )
-    target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_layout_dynamic()
-    target_logit_tensor = (
-        from_dlpack(target_logit.detach(), assumed_align=4).mark_layout_dynamic(
-            leading_dim=target_logit.ndim - 1
-        )
-        if target_logit is not None
-        else None
-    )
-    dx_tensor = convert_from_dlpack(dx) if dx is not None else None
-    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-
     compile_key = (
         dtype,
+        target_dtype,
+        target_logit_dtype,
         N,
-        target_logit.dtype if target_logit is not None else None,
-        lse.dtype if lse is not None else None,
+        lse is not None,
         dx is not None,
-        loss.stride(),
-        lse.stride() if lse is not None else None,
-        target.stride(),
-        target_logit.stride(-1) if target_logit is not None else None,
     )
     if compile_key not in cross_entropy_fwd_out.compile_cache:
+        batch_sym = cute.sym_int()
+        div = math.gcd(128 // dtype.width, N)
+        x_cute = fake_tensor(dtype, (batch_sym, N), div)
+        dx_cute = fake_tensor(dtype, (batch_sym, N), div) if dx is not None else None
+        target_cute = fake_tensor(target_dtype, (batch_sym,))
+        if target_logit is not None:
+            if target_logit.ndim == 2:
+                target_logit_cute = fake_tensor(
+                    target_logit_dtype, (batch_sym, cute.sym_int()), div
+                )
+            else:
+                target_logit_cute = fake_tensor(target_logit_dtype, (batch_sym,))
+        else:
+            target_logit_cute = None
+        loss_cute = fake_tensor(Float32, (batch_sym,))
+        lse_cute = fake_tensor(Float32, (batch_sym,)) if lse is not None else None
         # If there's dx, it's faster to not use online softmax since we want the exp(x - max)
         cross_entropy_op = CrossEntropy(dtype, N, online_softmax=dx is None)
         cross_entropy_fwd_out.compile_cache[compile_key] = cute.compile(
             cross_entropy_op,
-            x_tensor,
-            target_tensor,
-            target_logit_tensor,
-            loss_tensor,
-            lse_tensor,
-            dx_tensor,
-            Int32(ignore_index),
-            stream,
+            x_cute,
+            target_cute,
+            target_logit_cute,
+            loss_cute,
+            lse_cute,
+            dx_cute,
+            Int32(0),  # ignore_index, just for compilation
+            cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
+            options="--enable-tvm-ffi",
        )
     cross_entropy_fwd_out.compile_cache[compile_key](
-        x_tensor,
-        target_tensor,
-        target_logit_tensor,
-        loss_tensor,
-        lse_tensor,
-        dx_tensor,
-        Int32(ignore_index),
-        stream,
+        x, target, target_logit, loss, lse, dx, Int32(ignore_index)
     )


@@ -404,35 +359,25 @@ class CrossEntropyBackward:
         self.N = N
         self.vecsize = 128 // dtype.width
 
-    def _calculate_threads_per_row(self):
+    def _threads_per_row(self):
         N = min(self.N, 16384)  # We split by blocks of 16k
-        return (
-            8
-            if N <= 64
-            else (
-                16
-                if N <= 128
-                else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
-            )
-        )
+        for limit, threads in [(64, 8), (128, 16), (3072, 32), (6144, 64), (16384, 128)]:
+            if N <= limit:
+                return threads
+        return 256
 
-    def _get_tv_layout(self, num_copy_bits=128):
-        vecsize = num_copy_bits // self.dtype.width
+    def _get_tiled_copy(self, vecsize: int):
         assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
         N = min(self.N, 16384)
         num_threads = 128 if N <= 16384 else 256
-        threads_per_row = self._calculate_threads_per_row()
+        threads_per_row = self._threads_per_row()
         cols_per_block = num_threads // threads_per_row
         num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row)
         tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
-        tv_layout = cute.make_layout(
-            ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
-            stride=(
-                (vecsize * cols_per_block, 1),
-                (cols_per_block, cols_per_block * vecsize * threads_per_row),
-            ),
+        tiled_copy = copy_utils.tiled_copy_2d(
+            self.dtype, threads_per_row, num_threads, num_copy_elems=vecsize
         )
-        return tiler_mn, tv_layout
+        return tiled_copy, tiler_mn, threads_per_row
 
     @cute.jit
     def __call__(
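One worked case makes the tile arithmetic in _get_tiled_copy concrete (bf16, N = 4096; the numbers follow directly from the expressions above):

def ceil_div(a: int, b: int) -> int:
    return -(-a // b)

N, dtype_width = 4096, 16
vecsize = 8                                    # gcd(4096, 128 // 16)
num_threads = 128                              # N <= 16384
threads_per_row = 64                           # 3072 < N <= 6144 in the table above
cols_per_block = num_threads // threads_per_row            # 2 rows per CTA
num_blocks_N = ceil_div(N // vecsize, threads_per_row)     # 8 vectors per thread
tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
assert tiler_mn == (2, 4096)                   # each CTA covers 2 rows x the full row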
@@ -448,21 +393,24 @@
         assert mX.element_type == self.dtype
         assert mdX.element_type == self.dtype
         # e.g. if self.N isn't divisible by 8 for bf16, we might use 64 bits (4 elements) copy
-        num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
-        tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
-        num_threads = cute.size(tv_layout, mode=[0])
+        vecsize = math.gcd(self.N, 128 // self.dtype.width)
+        tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(vecsize=vecsize)
+        num_threads = tiled_copy.size
         # (M,) -> (M, N) with stride 0 in the N dimension
         mDLoss, mTarget, mLSE = [
-            cute.make_tensor(
-                X.iterator, cute.append(X.layout, cute.make_layout((self.N,), stride=(0,)))
-            )
-            for X in (mDLoss, mTarget, mLSE)
+            layout_utils.expand(X, dim=1, size=self.N) for X in (mDLoss, mTarget, mLSE)
         ]
-        smem_size = cute.size_in_bytes(
-            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0))
-        )
         self.kernel(
-            mX, mTarget, mDLoss, mdX, mLSE, ignore_index, mX.shape, tv_layout, tiler_mn
+            mX,
+            mTarget,
+            mDLoss,
+            mdX,
+            mLSE,
+            ignore_index,
+            mX.shape,
+            tiler_mn,
+            tiled_copy,
+            threads_per_row,
         ).launch(
             grid=[
                 cute.ceil_div(mX.shape[0], tiler_mn[0]),
@@ -470,7 +418,6 @@
                 1,
             ],
             block=[num_threads, 1, 1],
-            smem=smem_size,
             stream=stream,
         )
 
@@ -484,52 +431,39 @@
         mLSE: cute.Tensor,  # (M,)
         ignore_index: Int32,  # Index to ignore in gradient computation
         shape: cute.Shape,
-        tv_layout: cute.Layout,
         tiler_mn: cute.Shape,
+        tiled_copy: cute.TiledCopy,
+        threads_per_row: cutlass.Constexpr[int],
     ):
         tidx, _, _ = cute.arch.thread_idx()
         bidx, bidy, _ = cute.arch.block_idx()
 
         smem = cutlass.utils.SmemAllocator()
         sX = smem.allocate_tensor(
-            mX.element_type,
-            cute.make_ordered_layout(tiler_mn, order=(1, 0)),
-            byte_alignment=16,
+            mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
        )
 
         idX = cute.make_identity_tensor(shape)
-        # We use domain_offset_i64 to deal with tensors larger than 2^31 elements
-        mX, mdX = [utils.domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mdX)]
-        gX, gdX = [cute.local_tile(mT, tiler_mn, (0, bidy)) for mT in (mX, mdX)]
-        cX = cute.local_tile(idX, tiler_mn, (bidx, bidy))
-
-        num_copy_elems_X = tv_layout.shape[1][0]
-        num_copy_bits_X = mX.element_type.width * num_copy_elems_X
-        copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=num_copy_bits_X
-        )
-        copy_atom_store_dX = cute.make_copy_atom(
-            cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=num_copy_bits_X
-        )
-        thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-        thr_copy_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
-
-        #### Partition to get thread view
-        tXgX = thr_copy_X.partition_S(gX)
-        tXsX = thr_copy_X.partition_S(sX)
-        tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
-        tXcFull = thr_copy_X.partition_S(cX)
-        tXgdX = thr_copy_dX.partition_D(gdX)
-        # allocate fragments for gmem->rmem
+        gX, gdX, cX = [cute.local_tile(mT, tiler_mn, (bidx, bidy)) for mT in (mX, mdX, idX)]
+
+        thr_copy = tiled_copy.get_slice(tidx)
+
+        tXgX = thr_copy.partition_S(gX)
+        tXsX = thr_copy.partition_D(sX)
+        tXcX = thr_copy.partition_S(cX)[(0, None), None, None]
+        tXcFull = thr_copy.partition_S(cX)
+        tXgdX = thr_copy.partition_D(gdX)
         tXrX, tXrdX = [cute.make_fragment_like(thr) for thr in (tXgX, tXgdX)]
 
         is_even_N = const_expr(shape[1] % tiler_mn[1] == 0)
-        row = tXcX[0][0]
         tXpX = (
-            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
+            None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1])
         )
+        copy = partial(copy_utils.copy, pred=tXpX)
+
+        row = tXcX[0][0]
         if row < shape[0]:
-            cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+            copy(tXgX, tXsX, is_async=True)
         cute.arch.cp_async_commit_group()
         cute.arch.cp_async_wait_group(0)
         if const_expr(not is_even_N):
@@ -544,26 +478,22 @@
             target = Int32(mTarget[row])
             should_ignore = Boolean(target == ignore_index)
             # Set dloss to 0 if this index should be ignored
-            dloss = Float32(mDLoss[row]) if not should_ignore else Float32.zero
+            if not should_ignore:
+                dloss = Float32(mDLoss[row])
             lse = Float32(mLSE[row])
 
         log2_e = math.log2(math.e)
         probs = cute.math.exp2(x * log2_e - (lse * log2_e), fastmath=True)
         prob_shifted = probs - 1.0
-        mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
+        mask = cute.make_fragment_like(tXrX, Boolean)
         for i in cutlass.range(cute.size(tXcFull), unroll_full=True):
             mask[i] = tXcFull[i][1] == target
         grad = cute.where(mask.load(), prob_shifted, probs)
         grad = grad * dloss
 
         tXrdX.store(grad.to(tXrdX.element_type))
-        tXpdX = (
-            utils.predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1])
-            if not is_even_N
-            else None
-        )
         if row < shape[0]:
-            cute.copy(copy_atom_store_dX, tXrdX, tXgdX, pred=tXpdX)
+            copy(tXrdX, tXgdX)
 
 
 def _cross_entropy_backward(
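The backward kernel above reconstructs the softmax probabilities from the saved logsumexp and subtracts one at the target class, scaling by the incoming dloss. A reference sketch in PyTorch (not part of the package) for what it computes when no rows are ignored:

import torch

def cross_entropy_bwd_reference(x, target, dloss, lse):
    # dL/dx[i, j] = dloss[i] * (softmax(x)[i, j] - (j == target[i]))
    probs = torch.exp(x.float() - lse[:, None])        # softmax via the saved logsumexp
    probs[torch.arange(x.shape[0]), target] -= 1.0
    return (probs * dloss[:, None]).to(x.dtype)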
@@ -598,34 +528,28 @@ def _cross_entropy_backward(
 
     N = x.size(1)
     dtype = torch2cute_dtype_map[x.dtype]
-
-    convert_from_dlpack = lambda tensor: (
-        from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
-            mode=0, stride_order=(0, 1)
-        )
-    )
-    x_tensor = convert_from_dlpack(x)
-    dx_tensor = convert_from_dlpack(dx)
-    dloss_tensor = from_dlpack(dloss.detach(), assumed_align=4).mark_layout_dynamic()
-    lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic()
-    target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_layout_dynamic()
-    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-
-    compile_key = (dtype, N, target.dtype, dloss.stride(), lse.stride(), target.stride())
+    target_dtype = torch2cute_dtype_map[target.dtype]
+    compile_key = (dtype, target_dtype, N)
     if compile_key not in _cross_entropy_backward.compile_cache:
+        batch_sym = cute.sym_int()
+        div = math.gcd(128 // dtype.width, N)
+        x_cute, dx_cute = [fake_tensor(dtype, (batch_sym, N), div)] * 2
+        target_cute = fake_tensor(target_dtype, (batch_sym,))
+        dloss_cute, lse_cute = [fake_tensor(Float32, (batch_sym,))] * 2
         cross_entropy_backward_op = CrossEntropyBackward(dtype, N)
         _cross_entropy_backward.compile_cache[compile_key] = cute.compile(
             cross_entropy_backward_op,
-            x_tensor,
-            target_tensor,
-            dloss_tensor,
-            dx_tensor,
-            lse_tensor,
-            Int32(ignore_index),
-            stream,
+            x_cute,
+            target_cute,
+            dloss_cute,
+            dx_cute,
+            lse_cute,
+            Int32(0),  # ignore_index, just for compilation
+            cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
+            options="--enable-tvm-ffi",
         )
     _cross_entropy_backward.compile_cache[compile_key](
-        x_tensor, target_tensor, dloss_tensor, dx_tensor, lse_tensor, Int32(ignore_index), stream
+        x, target, dloss, dx, lse, Int32(ignore_index)
     )
 
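Both host wrappers now follow the same compile-once / call-direct pattern: kernels are compiled against symbolic "fake" tensors and a fake stream under a key that only captures dtype- and N-level facts, and the cached callable is then invoked on the torch tensors themselves, with the stream taken from the environment at call time via tvm-ffi. This replaces the per-call from_dlpack conversions and stride-sensitive cache keys of 0.2.2. A condensed sketch of the pattern as it appears above; the helper name build_fake_args is illustrative only:

import cutlass.cute as cute

compile_cache = {}

def run(op, key, build_fake_args, runtime_args):
    if key not in compile_cache:
        # Compile once per key against symbolic shapes; new batch sizes and
        # strides reuse the same compiled kernel.
        compile_cache[key] = cute.compile(
            op,
            *build_fake_args(),
            cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
            options="--enable-tvm-ffi",
        )
    # Torch tensors are passed straight through; no per-call dlpack wrapping.
    compile_cache[key](*runtime_args)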