quack-kernels 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/softmax.py CHANGED
@@ -1,176 +1,186 @@
  import math
  import torch
- import operator
- from typing import Callable
+ from typing import Type

  import cuda.bindings.driver as cuda

  import cutlass
  import cutlass.cute as cute
  from cutlass.cute.runtime import from_dlpack
- import cutlass.torch as cutlass_torch

  import quack.utils as utils
+ from quack.reduction_base import ReductionBase, torch2cute_dtype_map


- @cute.kernel
- def softmax_kernel(
-     gX: cute.Tensor,
-     gO: cute.Tensor,
-     cX: cute.Tensor, # coordinate tensor
-     shape: cute.Shape,
-     tv_layout: cute.Layout,
-     tiler_mn: cute.Shape,
-     cluster_n: cutlass.Constexpr = 1,
- ):
-     tidx, _, _ = cute.arch.thread_idx()
-     bidx, cluster_y, _ = cute.arch.block_idx()
-     gdim, _, _ = cute.arch.grid_dim()
-
-     # slice for CTAs
-     # logical id -> address
-     blkX, blkOut, blkCrd = [gT[(None, None), bidx if cluster_n == 1 else (bidx, cluster_y)] for gT in (gX, gO, cX)]
-
-     # declare the atoms which will be used later for memory copy
-     copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128)
-     copy_atom_load_X_async = cute.make_copy_atom(cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128)
-     copy_atom_store_O = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gO.element_type, num_bits_per_copy=128)
-
-     thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-     thr_copy_X_async = cute.make_tiled_copy(copy_atom_load_X_async, tv_layout, tiler_mn).get_slice(tidx)
-     thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
-
-     smem = cutlass.utils.SmemAllocator()
-     # Don't use blkX.layout here, because the stride is N, not N_rounded
-     sX = smem.allocate_tensor(gX.element_type, cute.make_ordered_layout(blkX.shape, order=(1, 0)), byte_alignment=16)
-     num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
-     warps_per_row = utils.max_constexpr(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
-
-     reduction_buffer_layout = cute.make_ordered_layout(
+ class Softmax(ReductionBase):
+     def __init__(self, dtype: Type[cutlass.Numeric], N: int, online_softmax: bool = True):
          # 2 stages: 1 for max, 1 for sum
-         (num_warps // warps_per_row, warps_per_row if cluster_n == 1 else (warps_per_row, cluster_n), 2),
-         order=(1, 0, 2)
-     )
-     reduction_buffer = smem.allocate_tensor(cutlass.Float32, reduction_buffer_layout, byte_alignment=4)
-     if cutlass.const_expr(cluster_n > 1):
-         # 1 mbar for max reduction, 1 mbar for sum reduction
-         mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=2)
-     else:
-         mbar_ptr = None
-
-     tXgX = thr_copy_X_async.partition_S(blkX)
-     tXsX = thr_copy_X_async.partition_S(sX)
-     tXgO = thr_copy_O.partition_D(blkOut)
-     tXcX = thr_copy_X.partition_S(blkCrd)[(0, None), None, None]
-
-     # allocate fragments for gmem->rmem
-     tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
-
-     if cluster_n > 1:
-         if tidx < 2:
-             cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + tidx, 1)
-         cute.arch.mbarrier_init_fence()
-         if tidx < 2:
-             cute.arch.mbarrier_init_tx_bytes(mbar_ptr + tidx, num_warps * cluster_n * cutlass.Float32.width // 8)
-         # Cluster arrive after barrier init
-         cute.arch.cluster_arrive_relaxed()
-
-     tXpX = cute.make_fragment_like(tXgX[(0, None), None, None], cutlass.Boolean)
-     for i in range(cute.size(tXpX)):
-         tXpX[i] = cute.elem_less(tXcX[i][1], shape[1])
-
-     if tXcX[0][0] < shape[0]:
-         cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
-     cute.arch.cp_async_commit_group()
-     cute.arch.cp_async_wait_group(0)
-
-     cute.autovec_copy(tXsX, tXrX)
-     x = tXrX.load().to(cute.Float32)
-     max_x = utils.warp_reduce(
-         x.reduce(cute.ReductionOp.MAX, init_val=float('-inf'), reduction_profile=0),
-         cute.arch.fmax,
-         width=utils.min_constexpr(tv_layout.shape[0][0], cute.arch.WARP_SIZE),
-     )
-     if cutlass.const_expr(cluster_n > 1):
-         cute.arch.cluster_wait()
-     if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-         max_mbar_ptr = mbar_ptr + 0 if cluster_n > 1 else None
-         max_x = utils.block_or_cluster_reduce(
-             max_x, cute.arch.fmax, reduction_buffer[None, None, 0], max_mbar_ptr, init_val=-cutlass.Float32.inf
-         )
-     log2_e = math.log2(math.e)
-     exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
-     denom = utils.warp_reduce(
-         exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
-         operator.add,
-         width=utils.min_constexpr(tv_layout.shape[0][0], cute.arch.WARP_SIZE),
-     )
-     if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-         sum_mbar_ptr = mbar_ptr + 1 if cluster_n > 1 else None
-         denom = utils.block_or_cluster_reduce(
-             denom, operator.add, reduction_buffer[None, None, 1], sum_mbar_ptr, init_val=0.0
-         )
-     inv = 1.0 / denom
-     y = exp_x * inv
-
-     tXrO.store(y.to(tXrO.element_type))
-     tOcX = thr_copy_O.partition_S(blkCrd)[(0, None), None, None]
-     tOpO = cute.make_fragment_like(tXgO[(0, None), None, None], cutlass.Boolean)
-     for i in range(cute.size(tOpO)):
-         tOpO[i] = cute.elem_less(tOcX[i][1], shape[1])
-     if tXcX[0][0] < shape[0]:
-         cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
-
-
- @cute.jit
- def softmax_interface(
-     mX: cute.Tensor,
-     mOut: cute.Tensor,
-     stream: cuda.CUstream,
-     N: cutlass.Constexpr,
-     copy_bits: cutlass.Constexpr = 128
- ):
-     vecsize = copy_bits // mX.element_type.width
-     assert N % vecsize == 0, f"Input N {N} is not divisible by vector size {vecsize}"
-     num_threads = 128 if N <= 16384 else 256
-     num_warps = num_threads // cute.arch.WARP_SIZE
-     assert num_threads % cute.arch.WARP_SIZE == 0
-     threads_per_row = 8 if N <= 64 else (16 if N <= 128 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))))
-     if cutlass.const_expr(mX.element_type.width == 16):
-         cluster_n = 1 if N <= 16 * 1024 else (2 if N <= 32 * 1024 else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)))
-     else: # fp32
-         cluster_n = 1 if N <= 32 * 1024 else (2 if N <= 64 * 1024 else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)))
-
-     num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row * cluster_n)
-     cols_per_block = num_threads // threads_per_row
-     tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row) # This rounds up N
-     tv_layout = cute.make_layout(
-         ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
-         stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * threads_per_row))
-     )
+         super().__init__(
+             dtype,
+             N,
+             stage=2 if not online_softmax else 1,
+             reduction_dtype=cutlass.Float32 if not online_softmax else cutlass.Int64,
+         )
+         self.online_softmax = online_softmax

-     idX = cute.make_identity_tensor(mX.shape)
-     gX, gO, cX = [cute.zipped_divide(mT, tiler_mn) for mT in (mX, mOut, idX)] # ((TileM,TileN),(RestM,RestN))
-
-     smem_allocated = cute.size_in_bytes(mX.element_type, cute.make_layout(gX.shape[0])) + 2 * num_warps * cluster_n * (cutlass.Float32.width // 8) + 2 * (cutlass.Int64.width // 8)
-     softmax_kernel(gX, gO, cX, mX.shape, tv_layout, tiler_mn, cluster_n).launch(
-         grid=[cute.size(gX, mode=[1, 0]), cluster_n, 1],
-         block=[cute.size(tv_layout, mode=[0]), 1, 1],
-         # Launching with cluster=[1, 1, 1] instead of None slows down the kernel by ~8us
-         cluster=[1, cluster_n, 1] if cluster_n > 1 else None,
-         smem=smem_allocated,
-         stream=stream,
-     )
+     def _calculate_threads_per_row(self):
+         N = self.N
+         return (
+             8
+             if N <= 64
+             else (
+                 16
+                 if N <= 128
+                 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+             )
+         )

+     def _set_cluster_n(self):
+         N = self.N
+         if cutlass.const_expr(self.dtype.width == 16):
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 32 * 1024
+                     else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                 )
+             )
+         else: # fp32
+             cluster_n = (
+                 1
+                 if N <= 32 * 1024
+                 else (
+                     2
+                     if N <= 64 * 1024
+                     else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+                 )
+             )
+         self.cluster_n = cluster_n

- torch2cute_dtype_map = {
-     torch.float16: cutlass.Float16,
-     torch.bfloat16: cutlass.BFloat16,
-     torch.float32: cutlass.Float32,
- }
+     @cute.jit
+     def __call__(
+         self,
+         mX: cute.Tensor,
+         mO: cute.Tensor,
+         stream: cuda.CUstream,
+     ):
+         assert mX.element_type == self.dtype
+         assert mO.element_type == self.dtype
+         self._set_cluster_n()
+         tiler_mn, tv_layout = self._get_tv_layout()
+         num_threads = cute.size(tv_layout, mode=[0])
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         self.kernel(mX, mO, tv_layout, tiler_mn).launch(
+             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+             block=[num_threads, 1, 1],
+             cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+             stream=stream,
+         )

+     @cute.kernel
+     def kernel(
+         self,
+         mX: cute.Tensor,
+         mO: cute.Tensor,
+         tv_layout: cute.Layout,
+         tiler_mn: cute.Shape,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         bidx, cluster_y, _ = cute.arch.block_idx()

- def softmax(x: torch.Tensor) -> torch.Tensor:
+         shape = mX.shape
+         idX = cute.make_identity_tensor(shape)
+         # slice for CTAs
+         gX, gO, cX = [
+             cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+             for mT in (mX, mO, idX)
+         ]
+
+         smem = cutlass.utils.SmemAllocator()
+         sX = smem.allocate_tensor(
+             mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+         )
+         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+         # declare the atoms which will be used later for memory copy
+         copy_atom_load_X = cute.make_copy_atom(
+             cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128
+         )
+         copy_atom_store_O = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), gO.element_type, num_bits_per_copy=128
+         )
+
+         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+         thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
+
+         tXgX = thr_copy_X.partition_S(gX)
+         tXsX = thr_copy_X.partition_D(sX)
+         tXgO = thr_copy_O.partition_D(gO)
+         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+         # allocate fragments for gmem->rmem
+         tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+         is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+         tXpX = (
+             utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
+         )
+         if tXcX[0][0] < shape[0]:
+             cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+         cute.arch.cp_async_commit_group()
+         cute.arch.cp_async_wait_group(0)
+         # Fill OOB values with -inf
+         if cutlass.const_expr(not is_even_N):
+             utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+
+         cute.autovec_copy(tXsX, tXrX)
+         x = tXrX.load().to(cute.Float32)
+         threads_per_row = tv_layout.shape[0][0]
+         if cutlass.const_expr(not self.online_softmax):
+             max_x = utils.row_reduce(
+                 x,
+                 cute.ReductionOp.MAX,
+                 threads_per_row,
+                 reduction_buffer[None, None, 0],
+                 mbar_ptr + 0 if self.cluster_n > 1 else None,
+                 init_val=-cutlass.Float32.inf,
+                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+             )
+             log2_e = math.log2(math.e)
+             exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
+             denom = utils.row_reduce(
+                 exp_x,
+                 cute.ReductionOp.ADD,
+                 threads_per_row,
+                 reduction_buffer[None, None, 1],
+                 mbar_ptr + 1 if self.cluster_n > 1 else None,
+                 init_val=0.0,
+             )
+         else:
+             max_x, denom, exp_x = utils.online_softmax_reduce(
+                 x,
+                 threads_per_row,
+                 reduction_buffer[None, None, 0],
+                 mbar_ptr,
+                 hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+                 return_exp_x=True,
+             )
+         y = exp_x * (1.0 / denom)
+         tXrO.store(y.to(tXrO.element_type))
+         tOpO = (
+             utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
+         )
+         if tXcX[0][0] < shape[0]:
+             cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+
+
+ def _softmax_fwd(x: torch.Tensor) -> torch.Tensor:
      """Softmax forward pass.
      Args:
          x: Input tensor of shape (M, N)
@@ -181,22 +191,258 @@ def softmax(x: torch.Tensor) -> torch.Tensor:
      assert x.is_cuda, "Tensor must be on CUDA device"
      assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
      M, N = x.shape
-     device = x.device
      out = torch.empty_like(x)
      dtype = torch2cute_dtype_map[x.dtype]
      convert_from_dlpack = lambda tensor: (
-         from_dlpack(tensor.detach(), assumed_align=16)
-         .mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
+         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+             mode=0, stride_order=(0, 1)
+         )
      )
      x_tensor, out_tensor = [convert_from_dlpack(tensor) for tensor in (x, out)]
      current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
      compile_key = (dtype, N)
-     if compile_key not in softmax.compile_cache:
-         softmax.compile_cache[compile_key] = cute.compile(
-             softmax_interface, x_tensor, out_tensor, current_stream, N
+     if compile_key not in _softmax_fwd.compile_cache:
+         softmax_op = Softmax(dtype, N)
+         _softmax_fwd.compile_cache[compile_key] = cute.compile(
+             softmax_op, x_tensor, out_tensor, current_stream
          )
-     softmax.compile_cache[compile_key](x_tensor, out_tensor, current_stream)
+     _softmax_fwd.compile_cache[compile_key](x_tensor, out_tensor, current_stream)
      return out


- softmax.compile_cache = {}
+ _softmax_fwd.compile_cache = {}
+
+
+ class SoftmaxBackward(ReductionBase):
+     def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+         # 1 stage for computing dot product
+         super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)
+
+     def _calculate_threads_per_row(self):
+         N = self.N
+         return (
+             8
+             if N <= 64
+             else (
+                 16
+                 if N <= 128
+                 else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 8192 else 256)))
+             )
+         )
+
+     def _set_cluster_n(self):
+         N = self.N
+         if cutlass.const_expr(self.dtype.width == 16):
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 32 * 1024
+                     else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                 )
+             )
+         else: # fp32
+             cluster_n = (
+                 1
+                 if N <= 16 * 1024
+                 else (
+                     2
+                     if N <= 32 * 1024
+                     else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+                 )
+             )
+         self.cluster_n = cluster_n
+
+     def _get_num_threads(self):
+         return 128 if self.N <= 8192 else 256
+
+     def _smem_size_in_bytes(self, tiler_mn, num_warps):
+         return (
+             # Multiply by 2 since we need space for Y and dY
+             cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
+             + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+             + self.stage * (cutlass.Int64.width // 8)
+         )
+
+     @cute.jit
+     def __call__(
+         self,
+         mdY: cute.Tensor,
+         mY: cute.Tensor,
+         mdX: cute.Tensor,
+         stream: cuda.CUstream,
+     ):
+         assert mdY.element_type == self.dtype
+         assert mY.element_type == self.dtype
+         assert mdX.element_type == self.dtype
+         self._set_cluster_n()
+         tiler_mn, tv_layout = self._get_tv_layout()
+         num_threads = cute.size(tv_layout, mode=[0])
+         num_warps = num_threads // cute.arch.WARP_SIZE
+         self.kernel(mdY, mY, mdX, tv_layout, tiler_mn).launch(
+             grid=[cute.ceil_div(mdY.shape[0], tiler_mn[0]), self.cluster_n, 1],
+             block=[num_threads, 1, 1],
+             cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+             stream=stream,
+         )
+
+     @cute.kernel
+     def kernel(
+         self,
+         mdY: cute.Tensor,
+         mY: cute.Tensor,
+         mdX: cute.Tensor,
+         tv_layout: cute.Layout,
+         tiler_mn: cute.Shape,
+     ):
+         tidx, _, _ = cute.arch.thread_idx()
+         bidx, cluster_y, _ = cute.arch.block_idx()
+
+         shape = mdY.shape
+         idX = cute.make_identity_tensor(shape)
+         # slice for CTAs
+         gdY, gY, gdX, cX = [
+             cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
+             for mT in (mdY, mY, mdX, idX)
+         ]
+
+         smem = cutlass.utils.SmemAllocator()
+         sdY = smem.allocate_tensor(
+             mdY.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+         )
+         sY = smem.allocate_tensor(
+             mY.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16
+         )
+         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+         # declare the atoms which will be used later for memory copy
+         copy_atom_load = cute.make_copy_atom(
+             cute.nvgpu.cpasync.CopyG2SOp(), mdY.element_type, num_bits_per_copy=128
+         )
+         copy_atom_store = cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=128
+         )
+
+         thr_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn).get_slice(tidx)
+         thr_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx)
+
+         tdYgdY = thr_copy_load.partition_S(gdY)
+         tdYsdY = thr_copy_load.partition_D(sdY)
+         tYgY = thr_copy_load.partition_S(gY)
+         tYsY = thr_copy_load.partition_D(sY)
+         tdXgdX = thr_copy_store.partition_D(gdX)
+         tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None]
+
+         # allocate fragments for gmem->rmem
+         tdYrdY, tYrY, tdXrdX = [cute.make_fragment_like(thr) for thr in (tdYgdY, tYgY, tdXgdX)]
+
+         num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+         self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+         is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+         tdYpdY = (
+             utils.predicate_k(thr_copy_load.partition_S(cX), limit=shape[1])
+             if not is_even_N
+             else None
+         )
+
+         if tXcX[0][0] < shape[0]:
+             cute.copy(copy_atom_load, tdYgdY, tdYsdY, pred=tdYpdY)
+             cute.copy(copy_atom_load, tYgY, tYsY, pred=tdYpdY)
+         cute.arch.cp_async_commit_group()
+         cute.arch.cp_async_wait_group(0)
+
+         cute.autovec_copy(tdYsdY, tdYrdY)
+         cute.autovec_copy(tYsY, tYrY)
+         dy = tdYrdY.load().to(cute.Float32)
+         y = tYrY.load().to(cute.Float32)
+
+         # Compute dot product: dot = Σⱼ dy_j × y_j
+         threads_per_row = tv_layout.shape[0][0]
+         dot = utils.row_reduce(
+             dy * y,
+             cute.ReductionOp.ADD,
+             threads_per_row,
+             reduction_buffer[None, None, 0],
+             mbar_ptr if self.cluster_n > 1 else None,
+             init_val=0.0,
+             hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
+         )
+
+         # Compute gradient: dx_i = y_i × (dy_i - dot)
+         dx = y * (dy - dot)
+         tdXrdX.store(dx.to(tdXrdX.element_type))
+         tdXpdX = (
+             utils.predicate_k(thr_copy_store.partition_S(cX), limit=shape[1])
+             if not is_even_N
+             else None
+         )
+         if tXcX[0][0] < shape[0]:
+             cute.copy(copy_atom_store, tdXrdX, tdXgdX, pred=tdXpdX)
+
+
+ def _softmax_backward(dy: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+     """Softmax backward pass.
+     Args:
+         dy: Upstream gradients tensor of shape (M, N)
+         y: Softmax output tensor of shape (M, N)
+     Returns:
+         Input gradients tensor of same shape as dy and y
+     """
+     assert dy.dim() == 2, "dy must be 2D"
+     assert y.dim() == 2, "y must be 2D"
+     assert dy.shape == y.shape, "dy and y must have same shape"
+     assert dy.is_cuda and y.is_cuda, "Tensors must be on CUDA device"
+     assert dy.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
+     assert y.dtype == dy.dtype, "dy and y must have same dtype"
+
+     M, N = dy.shape
+     dx = torch.empty_like(dy)
+     dtype = torch2cute_dtype_map[dy.dtype]
+     convert_from_dlpack = lambda tensor: (
+         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
+             mode=0, stride_order=(0, 1)
+         )
+     )
+     dy_tensor, y_tensor, dx_tensor = [convert_from_dlpack(tensor) for tensor in (dy, y, dx)]
+     current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+     compile_key = (dtype, N)
+     if compile_key not in _softmax_backward.compile_cache:
+         softmax_backward_op = SoftmaxBackward(dtype, N)
+         _softmax_backward.compile_cache[compile_key] = cute.compile(
+             softmax_backward_op, dy_tensor, y_tensor, dx_tensor, current_stream
+         )
+     _softmax_backward.compile_cache[compile_key](dy_tensor, y_tensor, dx_tensor, current_stream)
+     return dx
+
+
+ _softmax_backward.compile_cache = {}
+
+
+ class SoftmaxFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, x):
+         y = _softmax_fwd(x)
+         ctx.save_for_backward(y)
+         return y
+
+     @staticmethod
+     def backward(ctx, dy):
+         (y,) = ctx.saved_tensors
+         dx = _softmax_backward(dy, y)
+         return dx
+
+
+ def softmax(x: torch.Tensor) -> torch.Tensor:
+     """Softmax forward pass with automatic differentiation support.
+
+     Args:
+         x: Input tensor of shape (M, N)
+
+     Returns:
+         Softmax output tensor of same shape as x
+     """
+     return SoftmaxFunction.apply(x)
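
For context on what the new version adds: `softmax` is now differentiable, since `SoftmaxFunction` saves the forward output `y` and the backward kernel applies dx = y * (dy - sum(dy * y)) per row. A minimal usage sketch, assuming quack-kernels 0.1.3 is installed on a CUDA machine; the shapes, upstream gradient, and tolerances below are illustrative choices, not part of the package:

import torch
from quack.softmax import softmax

# Forward pass: Softmax.__call__ is compiled once per (dtype, N) and cached.
x = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = softmax(x)

# Backward pass: SoftmaxBackward computes dx = y * (dy - rowwise dot(dy, y)).
g = torch.randn_like(y)
y.backward(g)

# Cross-check against the PyTorch reference implementation.
x_ref = x.detach().clone().requires_grad_()
y_ref = torch.softmax(x_ref.float(), dim=-1).to(x.dtype)
y_ref.backward(g)
torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(x.grad, x_ref.grad, atol=1e-2, rtol=1e-2)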