quack-kernels 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/softmax.py CHANGED
@@ -1,169 +1,191 @@
  import math
  import torch
- import operator
- from typing import Callable
+ from typing import Type

  import cuda.bindings.driver as cuda

  import cutlass
  import cutlass.cute as cute
  from cutlass.cute.runtime import from_dlpack
- import cutlass.torch as cutlass_torch

  import quack.utils as utils
+ from quack.reduction_base import ReductionBase, torch2cute_dtype_map


- @cute.kernel
- def softmax_kernel(
-     mX: cute.Tensor,
-     mO: cute.Tensor,
-     tv_layout: cute.Layout,
-     tiler_mn: cute.Shape,
-     cluster_n: cutlass.Constexpr = 1,
- ):
-     tidx, _, _ = cute.arch.thread_idx()
-     bidx, cluster_y, _ = cute.arch.block_idx()
-
-     shape = mX.shape
-     idX = cute.make_identity_tensor(shape)
-     # slice for CTAs
-     gX, gO, cX = [
-         cute.local_tile(mT, tiler_mn, (bidx, 0 if cluster_n == 1 else cluster_y))
-         for mT in (mX, mO, idX)
-     ]
-
-     smem = cutlass.utils.SmemAllocator()
-     sX = smem.allocate_tensor(mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16)
-     num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-     warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
-     reduction_buffer_layout = cute.make_ordered_layout(
+ class Softmax(ReductionBase):
+     def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
          # 2 stages: 1 for max, 1 for sum
-         (num_warps // warps_per_row, (warps_per_row, cluster_n), 2),
-         order=(1, 0, 2)
-     )
-     reduction_buffer = smem.allocate_tensor(cutlass.Float32, reduction_buffer_layout, byte_alignment=4)
-     if cutlass.const_expr(cluster_n > 1):
-         # 1 mbar for max reduction, 1 mbar for sum reduction
-         mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=2)
-     else:
-         mbar_ptr = None
-
-     # declare the atoms which will be used later for memory copy
-     copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128)
-     copy_atom_store_O = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gO.element_type, num_bits_per_copy=128)
-
-     thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-     thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
-
-     tXgX = thr_copy_X.partition_S(gX)
-     tXsX = thr_copy_X.partition_D(sX)
-     tXgO = thr_copy_O.partition_D(gO)
-     tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
-
-     # allocate fragments for gmem->rmem
-     tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
-
-     if cluster_n > 1:
-         if tidx < 2:
-             cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + tidx, 1)
-         cute.arch.mbarrier_init_fence()
-         if tidx < 2:
-             cute.arch.mbarrier_init_tx_bytes(mbar_ptr + tidx, num_warps * cluster_n * cutlass.Float32.width // 8)
-         # Cluster arrive after barrier init
-         cute.arch.cluster_arrive_relaxed()
-
-     is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * cluster_n)
-     tXpX = utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
-     if tXcX[0][0] < shape[0]:
-         cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
-     cute.arch.cp_async_commit_group()
-     cute.arch.cp_async_wait_group(0)
-
-     cute.autovec_copy(tXsX, tXrX)
-     x = tXrX.load().to(cute.Float32)
-     # Fill OOB values with -inf
-     if cutlass.const_expr(not is_even_N):
-         tXrX_fp32 = cute.make_fragment_like(tXrX, cutlass.Float32)
-         tXrX_fp32.store(x)
-         for rest_v in range(tXpX.shape[0]):
-             for rest_k in range(tXpX.shape[2]):
-                 if not tXpX[rest_v, 0, rest_k]:
-                     tXrX_fp32[(None, rest_v), None, rest_k].fill(-cutlass.Float32.inf)
-         x = tXrX_fp32.load()
-     threads_per_row = tv_layout.shape[0][0]
-     max_x = utils.row_reduce(
-         x,
-         cute.ReductionOp.MAX,
-         threads_per_row,
-         reduction_buffer[None, None, 0],
-         mbar_ptr + 0 if cluster_n > 1 else None,
-         init_val=-cutlass.Float32.inf,
-         hook_fn=cute.arch.cluster_wait if cutlass.const_expr(cluster_n > 1) else None
-     )
-     log2_e = math.log2(math.e)
-     exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
-     denom = utils.row_reduce(
-         exp_x,
-         cute.ReductionOp.ADD,
-         threads_per_row,
-         reduction_buffer[None, None, 1],
-         mbar_ptr + 1 if cluster_n > 1 else None,
-         init_val=0.0,
-     )
-     inv = 1.0 / denom
-     y = exp_x * inv
-     tXrO.store(y.to(tXrO.element_type))
-     tOpO = utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
-     if tXcX[0][0] < shape[0]:
-         cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
-
-
- @cute.jit
- def softmax_interface(
-     mX: cute.Tensor,
-     mO: cute.Tensor,
-     stream: cuda.CUstream,
-     N: cutlass.Constexpr,
-     copy_bits: cutlass.Constexpr = 128
- ):
-     vecsize = copy_bits // mX.element_type.width
-     assert N % vecsize == 0, f"Input N {N} is not divisible by vector size {vecsize}"
-     num_threads = 128 if N <= 16384 else 256
-     num_warps = num_threads // cute.arch.WARP_SIZE
-     assert num_threads % cute.arch.WARP_SIZE == 0
-     threads_per_row = 8 if N <= 64 else (16 if N <= 128 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))))
-     if cutlass.const_expr(mX.element_type.width == 16):
-         cluster_n = 1 if N <= 16 * 1024 else (2 if N <= 32 * 1024 else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)))
-     else:  # fp32
-         cluster_n = 1 if N <= 32 * 1024 else (2 if N <= 64 * 1024 else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)))
-
-     num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row * cluster_n)
-     cols_per_block = num_threads // threads_per_row
-     tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)  # This rounds up N
-     tv_layout = cute.make_layout(
-         ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
-         stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * threads_per_row))
-     )
+         super().__init__(
+             dtype,
+             N,
+             stage=2 if not online_softmax else 1,
+             reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
+         )
+         self.online_softmax = online_softmax

-     smem_allocated = cute.size_in_bytes(mX.element_type, cute.make_layout(tiler_mn)) + 2 * num_warps * cluster_n * (cutlass.Float32.width // 8) + 2 * (cutlass.Int64.width // 8)
-     softmax_kernel(mX, mO, tv_layout, tiler_mn, cluster_n).launch(
-         grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1],
-         block=[cute.size(tv_layout, mode=[0]), 1, 1],
-         # Launching with cluster=[1, 1, 1] instead of None slows down the kernel by ~8us
-         cluster=[1, cluster_n, 1] if cluster_n > 1 else None,
-         smem=smem_allocated,
-         stream=stream,
-     )
+     def _calculate_threads_per_row(self):
+         N = self.N
+         return (
+             8
+             if N <= 64
+             else (
+                 16
+                 if N <= 128
+                 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+             )
+         )

+     def _set_cluster_n(self):
+         N = self.N
+         if cutlass.const_expr(self.dtype.width == 16):
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 32 * 1024
+                     else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                 )
+             )
+         else:  # fp32
+             cluster_n = (
+                 1
+                 if N <= 32 * 1024
+                 else (
+                     2
+                     if N <= 64 * 1024
+                     else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                 )
+             )
+         self.cluster_n = cluster_n

- torch2cute_dtype_map = {
-     torch.float16: cutlass.Float16,
-     torch.bfloat16: cutlass.BFloat16,
-     torch.float32: cutlass.Float32,
- }
+     @cute.jit
+     def __call__(
+         self,
+         mX: cute.Tensor,
+         mO: cute.Tensor,
+         stream: cuda.CUstream,
+     ):
+         assert mX.element_type == self.dtype
+         assert mO.element_type == self.dtype
+         self._set_cluster_n()
+         tiler_mn, tv_layout = self._get_tv_layout()
+         num_threads = cute.size(tv_layout, mode=[0])
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         self.kernel(mX, mO, tv_layout, tiler_mn).launch(
+             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+             block=[num_threads, 1, 1],
+             cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+             stream=stream,
+         )

+     @cute.kernel
+     def kernel(
+         self,
+         mX: cute.Tensor,
+         mO: cute.Tensor,
+         tv_layout: cute.Layout,
+         tiler_mn: cute.Shape,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         bidx, _, _ = cute.arch.block_idx()
+         if cutlass.const_expr(self.cluster_n > 1):
+             cluster_y = cute.arch.block_idx()[1]
+         else:
+             cluster_y = cutlass.const_expr(0)

- def softmax(x: torch.Tensor) -> torch.Tensor:
+         shape = mX.shape
+         idX = cute.make_identity_tensor(shape)
+         # slice for CTAs
+         gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]
+
+         smem = cutlass.utils.SmemAllocator()
+         sX = smem.allocate_tensor(
+             mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+         )
+         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+         # declare the atoms which will be used later for memory copy
+         copy_atom_load_X = cute.make_copy_atom(
+             cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+         )
+         copy_atom_store_O = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), gO.element_type, num_bits_per_copy=128
+         )
+
+         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+         thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+         tXgX = thr_copy_X.partition_S(gX)
+         tXsX = thr_copy_X.partition_D(sX)
+         tXgO = thr_copy_O.partition_D(gO)
+         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+         # allocate fragments for gmem->rmem
+         tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+         is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+         tXpX = (
+             utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+             if cutlass.const_expr(not is_even_N)
+             else None
+         )
+         if tXcX[0][0] < shape[0]:
+             cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+         cute.arch.cp_async_commit_group()
+         cute.arch.cp_async_wait_group(0)
+         # Fill OOB values with -inf
+         if cutlass.const_expr(not is_even_N):
+             utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+
+         cute.autovec_copy(tXsX, tXrX)
+         x = tXrX.load().to(cute.Float32)
+         threads_per_row = tv_layout.shape[0][0]
+         if cutlass.const_expr(not self.online_softmax):
+             max_x = utils.row_reduce(
+                 x,
+                 cute.ReductionOp.MAX,
+                 threads_per_row,
+                 reduction_buffer[None, None, 0],
+                 mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
+                 init_val=-cutlass.Float32.inf,
+                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+             )
+             log2_e = math.log2(math.e)
+             exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
+             denom = utils.row_reduce(
+                 exp_x,
+                 cute.ReductionOp.ADD,
+                 threads_per_row,
+                 reduction_buffer[None, None, 1],
+                 mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
+                 init_val=0.0,
+             )
+         else:
+             max_x, denom, exp_x = utils.online_softmax_reduce(
+                 x,
+                 threads_per_row,
+                 reduction_buffer[None, None, 0],
+                 mbar_ptr,
+                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+                 return_exp_x=True,
+             )
+         y = exp_x * (1.0 / denom)
+         tXrO.store(y.to(tXrO.element_type))
+         tOpO = (
+             utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
+             if cutlass.const_expr(not is_even_N)
+             else None
+         )
+         if tXcX[0][0] < shape[0]:
+             cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+
+
+ def _softmax_fwd(x: torch.Tensor) -> torch.Tensor:
      """Softmax forward pass.
      Args:
          x: Input tensor of shape (M, N)
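Note on the `online_softmax=True` default of the new `Softmax` class: instead of one reduction pass for the row max and a second for the sum of exponentials, `utils.online_softmax_reduce` maintains both in a single pass (which is why the class switches to `stage=1` and an `Int64` reduction buffer, presumably so the running max/sum pair fits in one reduction slot). That helper is not part of this file; the sketch below is only a plain-Python illustration of the standard single-pass recurrence it is based on, not the quack implementation.

# Illustrative only: a scalar rendition of the single-pass ("online") softmax
# recurrence. The real utils.online_softmax_reduce operates on CuTe fragments
# across threads/CTAs and is not shown in this diff.
import math

def online_softmax_row(row):
    """One pass over the row keeps a running max and a rescaled running sum."""
    running_max = float("-inf")
    running_sum = 0.0
    for v in row:
        new_max = max(running_max, v)
        # Rescale the partial sum whenever the running max increases.
        running_sum = running_sum * math.exp(running_max - new_max) + math.exp(v - new_max)
        running_max = new_max
    return [math.exp(v - running_max) / running_sum for v in row]

assert abs(sum(online_softmax_row([1.0, 2.0, 3.0])) - 1.0) < 1e-9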
@@ -174,22 +196,261 @@ def softmax(x: torch.Tensor) -> torch.Tensor:
      assert x.is_cuda, "Tensor must be on CUDA device"
      assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
      M, N = x.shape
-     device = x.device
      out = torch.empty_like(x)
      dtype = torch2cute_dtype_map[x.dtype]
      convert_from_dlpack = lambda tensor: (
-         from_dlpack(tensor.detach(), assumed_align=16)
-         .mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
+         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+             mode=0, stride_order=(0, 1)
+         )
      )
      x_tensor, out_tensor = [convert_from_dlpack(tensor) for tensor in (x, out)]
      current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
      compile_key = (dtype, N)
-     if compile_key not in softmax.compile_cache:
-         softmax.compile_cache[compile_key] = cute.compile(
-             softmax_interface, x_tensor, out_tensor, current_stream, N
+     if compile_key not in _softmax_fwd.compile_cache:
+         softmax_op = Softmax(dtype, N)
+         _softmax_fwd.compile_cache[compile_key] = cute.compile(
+             softmax_op, x_tensor, out_tensor, current_stream
          )
-     softmax.compile_cache[compile_key](x_tensor, out_tensor, current_stream)
+     _softmax_fwd.compile_cache[compile_key](x_tensor, out_tensor, current_stream)
      return out


- softmax.compile_cache = {}
+ _softmax_fwd.compile_cache = {}
+
+
+ class SoftmaxBackward(ReductionBase):
+     def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+         # 1 stage for computing dot product
+         super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)
+
+     def _calculate_threads_per_row(self):
+         N = self.N
+         return (
+             8
+             if N <= 64
+             else (
+                 16
+                 if N <= 128
+                 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 8192 else 256)))
+             )
+         )
+
+     def _set_cluster_n(self):
+         N = self.N
+         if cutlass.const_expr(self.dtype.width == 16):
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 32 * 1024
+                     else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                 )
+             )
+         else:  # fp32
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 32 * 1024
+                     else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                 )
+             )
+         self.cluster_n = cluster_n
+
+     def _get_num_threads(self):
+         return 128 if self.N <= 8192 else 256
+
+     def _smem_size_in_bytes(self, tiler_mn, num_warps):
+         return (
+             # Multiply by 2 since we need space for Y and dY
+             cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
+             + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+             + self.stage * (cutlass.Int64.width // 8)
+         )
+
+     @cute.jit
+     def __call__(
+         self,
+         mdY: cute.Tensor,
+         mY: cute.Tensor,
+         mdX: cute.Tensor,
+         stream: cuda.CUstream,
+     ):
+         assert mdY.element_type == self.dtype
+         assert mY.element_type == self.dtype
+         assert mdX.element_type == self.dtype
+         self._set_cluster_n()
+         tiler_mn, tv_layout = self._get_tv_layout()
+         num_threads = cute.size(tv_layout, mode=[0])
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         self.kernel(mdY, mY, mdX, tv_layout, tiler_mn).launch(
+             grid=[cute.ceil_div(mdY.shape[0], tiler_mn[0]), self.cluster_n, 1],
+             block=[num_threads, 1, 1],
+             cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
+             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+             stream=stream,
+         )
+
+     @cute.kernel
+     def kernel(
+         self,
+         mdY: cute.Tensor,
+         mY: cute.Tensor,
+         mdX: cute.Tensor,
+         tv_layout: cute.Layout,
+         tiler_mn: cute.Shape,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         bidx, _, _ = cute.arch.block_idx()
+         if cutlass.const_expr(self.cluster_n > 1):
+             cluster_y = cute.arch.block_idx()[1]
+         else:
+             cluster_y = cutlass.const_expr(0)
+
+         shape = mdY.shape
+         idX = cute.make_identity_tensor(shape)
+         # slice for CTAs
+         gdY, gY, gdX, cX = [
+             cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mdY, mY, mdX, idX)
+         ]
+
+         smem = cutlass.utils.SmemAllocator()
+         sdY = smem.allocate_tensor(
+             mdY.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+         )
+         sY = smem.allocate_tensor(
+             mY.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+         )
+         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+         # declare the atoms which will be used later for memory copy
+         copy_atom_load = cute.make_copy_atom(
+             cute.nvgpu.cpasync.CopyG2SOp(), mdY.element_type, num_bits_per_copy=128
+         )
+         copy_atom_store = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=128
+         )
+
+         thr_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn).get_slice(tidx)
+         thr_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx)
+
+         tdYgdY = thr_copy_load.partition_S(gdY)
+         tdYsdY = thr_copy_load.partition_D(sdY)
+         tYgY = thr_copy_load.partition_S(gY)
+         tYsY = thr_copy_load.partition_D(sY)
+         tdXgdX = thr_copy_store.partition_D(gdX)
+         tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None]
+
+         # allocate fragments for gmem->rmem
+         tdYrdY, tYrY, tdXrdX = [cute.make_fragment_like(thr) for thr in (tdYgdY, tYgY, tdXgdX)]
+
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+         is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+         tdYpdY = (
+             utils.predicate_k(thr_copy_load.partition_S(cX), limit=shape[1])
+             if cutlass.const_expr(not is_even_N)
+             else None
+         )
+
+         if tXcX[0][0] < shape[0]:
+             cute.copy(copy_atom_load, tdYgdY, tdYsdY, pred=tdYpdY)
+             cute.copy(copy_atom_load, tYgY, tYsY, pred=tdYpdY)
+         cute.arch.cp_async_commit_group()
+         cute.arch.cp_async_wait_group(0)
+
+         cute.autovec_copy(tdYsdY, tdYrdY)
+         cute.autovec_copy(tYsY, tYrY)
+         dy = tdYrdY.load().to(cute.Float32)
+         y = tYrY.load().to(cute.Float32)
+
+         # Compute dot product: dot = Σⱼ dy_j × y_j
+         threads_per_row = tv_layout.shape[0][0]
+         dot = utils.row_reduce(
+             dy * y,
+             cute.ReductionOp.ADD,
+             threads_per_row,
+             reduction_buffer[None, None, 0],
+             mbar_ptr if cutlass.const_expr(self.cluster_n > 1) else None,
+             init_val=0.0,
+             hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+         )
+
+         # Compute gradient: dx_i = y_i × (dy_i - dot)
+         dx = y * (dy - dot)
+         tdXrdX.store(dx.to(tdXrdX.element_type))
+         tdXpdX = (
+             utils.predicate_k(thr_copy_store.partition_S(cX), limit=shape[1])
+             if cutlass.const_expr(not is_even_N)
+             else None
+         )
+         if tXcX[0][0] < shape[0]:
+             cute.copy(copy_atom_store, tdXrdX, tdXgdX, pred=tdXpdX)
+
+
+ def _softmax_backward(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+     """Softmax backward pass.
+     Args:
+         dy: Upstream gradients tensor of shape (M, N)
+         y: Softmax output tensor of shape (M, N)
+     Returns:
+         Input gradients tensor of same shape as dy and y
+     """
+     assert dy.dim() == 2, "dy must be 2D"
+     assert y.dim() == 2, "y must be 2D"
+     assert dy.shape == y.shape, "dy and y must have same shape"
+     assert dy.is_cuda and y.is_cuda, "Tensors must be on CUDA device"
+     assert dy.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+     assert y.dtype == dy.dtype, "dy and y must have same dtype"
+
+     M, N = dy.shape
+     dx = torch.empty_like(dy)
+     dtype = torch2cute_dtype_map[dy.dtype]
+     convert_from_dlpack = lambda tensor: (
+         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+             mode=0, stride_order=(0, 1)
+         )
+     )
+     dy_tensor, y_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (dy, y, dx)]
+     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+     compile_key = (dtype, N)
+     if compile_key not in _softmax_backward.compile_cache:
+         softmax_backward_op = SoftmaxBackward(dtype, N)
+         _softmax_backward.compile_cache[compile_key] = cute.compile(
+             softmax_backward_op, dy_tensor, y_tensor, dx_tensor, current_stream
+         )
+     _softmax_backward.compile_cache[compile_key](dy_tensor, y_tensor, dx_tensor, current_stream)
+     return dx
+
+
+ _softmax_backward.compile_cache = {}
+
+
+ class SoftmaxFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, x):
+         y = _softmax_fwd(x)
+         ctx.save_for_backward(y)
+         return y
+
+     @staticmethod
+     def backward(ctx, dy):
+         (y,) = ctx.saved_tensors
+         dx = _softmax_backward(dy, y)
+         return dx
+
+
+ def softmax(x: torch.Tensor) -> torch.Tensor:
+     """Softmax forward pass with automatic differentiation support.
+
+     Args:
+         x: Input tensor of shape (M, N)
+
+     Returns:
+         Softmax output tensor of same shape as x
+     """
+     return SoftmaxFunction.apply(x)
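`SoftmaxFunction` wires the forward and backward kernels into torch.autograd; the backward kernel computes the standard softmax Jacobian-vector product, dx_i = y_i * (dy_i - Σⱼ dy_j y_j), per row, matching the comments in `SoftmaxBackward.kernel`. A rough end-to-end usage sketch follows; it assumes a CUDA device and that the module is importable as `quack.softmax`, and the tolerances are illustrative rather than part of the package.

# Usage sketch only: not part of the released package.
import torch
from quack.softmax import softmax

x = torch.randn(4, 8192, device="cuda", dtype=torch.bfloat16, requires_grad=True)
g = torch.randn(4, 8192, device="cuda", dtype=torch.bfloat16)

y = softmax(x)   # forward: Softmax via _softmax_fwd
y.backward(g)    # backward: SoftmaxBackward via SoftmaxFunction.backward

# Cross-check against the PyTorch reference in float32.
x_ref = x.detach().float().requires_grad_()
y_ref = torch.softmax(x_ref, dim=-1)
y_ref.backward(g.float())

print(torch.allclose(y.float(), y_ref, atol=1e-2),
      torch.allclose(x.grad.float(), x_ref.grad, atol=1e-2))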