quack-kernels 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/reduce.py ADDED
@@ -0,0 +1,240 @@
+ # Copyright (c) 2025, Tri Dao.
+
+ import math
+ import operator
+ from typing import Callable, Optional
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass import Float32
+
+ import quack.utils as utils
+
+
+ @cute.jit
+ def warp_reduce(
+     val: cute.TensorSSA | cute.Numeric,
+     op: Callable,
+     width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
+ ) -> cute.TensorSSA | cute.Numeric:
+     if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
+         res = cute.make_fragment(val.shape, val.dtype)
+         res.store(val)
+         for i in cutlass.range_constexpr(cute.size(val.shape)):
+             res[i] = warp_reduce(res[i], op, width)
+         return res.load()
+     else:
+         for i in cutlass.range_constexpr(int(math.log2(width))):
+             val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
+         return val
+
+
+ @cute.jit
+ def block_reduce(
+     val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0
+ ) -> cute.Numeric:
+     """reduction_buffer has shape (num_warps / warp_per_row, warps_per_row)"""
+     lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
+     warps_per_row = cute.size(reduction_buffer.shape[1])
+     row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+     if lane_idx == 0:
+         reduction_buffer[row_idx, col_idx] = val
+     cute.arch.barrier()
+     block_reduce_val = init_val
+     if lane_idx < warps_per_row:
+         block_reduce_val = reduction_buffer[row_idx, lane_idx]
+     return warp_reduce(block_reduce_val, op)
+
+
+ @cute.jit
+ def cluster_reduce(
+     val: cute.Numeric,
+     op: Callable,
+     reduction_buffer: cute.Tensor,
+     mbar_ptr: cute.Pointer,
+     init_val: cute.Numeric = 0.0,
+     phase: Optional[cutlass.Int32] = None,
+ ) -> cute.Numeric:
+     """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
+     cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
+     lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
+     rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
+     row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+     if warp_idx == 0:
+         with cute.arch.elect_one():
+             num_warps = rows_per_block * warps_per_row
+             cute.arch.mbarrier_arrive_and_expect_tx(
+                 mbar_ptr,
+                 num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+             )
+     if lane_idx < cluster_n:
+         utils.store_shared_remote(
+             val,
+             utils.elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
+             mbar_ptr,
+             peer_cta_rank_in_cluster=lane_idx,
+         )
+     cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
+     block_reduce_val = init_val
+     num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
+     for i in cutlass.range_constexpr(num_iter):
+         idx = lane_idx + i * cute.arch.WARP_SIZE
+         if idx < cute.size(reduction_buffer, mode=[1]):
+             block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])
+     return warp_reduce(block_reduce_val, op)
+
+
+ @cute.jit
+ def block_or_cluster_reduce(
+     val: cute.Numeric,
+     op: Callable,
+     reduction_buffer: cute.Tensor,
+     mbar_ptr: Optional[cute.Pointer],
+     phase: Optional[cutlass.Int32] = None,
+     init_val: cute.Numeric = 0.0,
+ ) -> cute.Numeric:
+     """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
+     if cutlass.const_expr(mbar_ptr is None):
+         return block_reduce(val, op, reduction_buffer, init_val=init_val)
+     else:
+         return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)
+
+
+ @cute.jit
+ def row_reduce(
+     x: cute.TensorSSA | cute.Numeric,
+     op: cute.ReductionOp,
+     threads_per_row: cutlass.Constexpr[int],
+     reduction_buffer: Optional[cute.Tensor] = None,
+     mbar_ptr: Optional[cute.Pointer] = None,
+     phase: Optional[cutlass.Int32] = None,
+     init_val: cute.Numeric = 0.0,
+     hook_fn: Optional[Callable] = None,
+ ) -> cute.Numeric:
+     """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
+     if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
+         val = x.reduce(op, init_val=init_val, reduction_profile=0)
+     else:
+         val = x
+     warp_op = {
+         cute.ReductionOp.ADD: operator.add,
+         cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
+         cute.ReductionOp.MIN: min,
+         cute.ReductionOp.MUL: operator.mul,
+     }[op]
+     val = warp_reduce(
+         val,
+         warp_op,
+         width=min(threads_per_row, cute.arch.WARP_SIZE),
+     )
+     if cutlass.const_expr(hook_fn is not None):
+         hook_fn()
+     if cutlass.const_expr(reduction_buffer is not None):
+         warps_per_row, cluster_n = reduction_buffer.shape[1]
+         assert cluster_n == 1 or mbar_ptr is not None, (
+             "mbar_ptr must be provided for cluster reduction"
+         )
+         if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
+             val = block_or_cluster_reduce(
+                 val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
+             )
+     return val
+
+
+ @cute.jit
+ def online_softmax_reduce(
+     x: cute.TensorSSA,
+     threads_per_row: cutlass.Constexpr[int],
+     reduction_buffer: Optional[cute.Tensor] = None,
+     mbar_ptr: Optional[cute.Pointer] = None,
+     hook_fn: Optional[Callable] = None,
+     phase: Optional[cutlass.Int32] = None,
+     return_exp_x: bool = False,
+ ) -> [Float32, Float32, Optional[cute.TensorSSA]]:
+     assert x.dtype == Float32, "x must be of type Float32"
+     """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2)"""
+     max_x = warp_reduce(
+         x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
+         cute.arch.fmax,
+         width=min(threads_per_row, cute.arch.WARP_SIZE),
+     )
+     log2_e = math.log2(math.e)
+     exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
+     sum_exp_x = warp_reduce(
+         exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
+         operator.add,
+         width=min(threads_per_row, cute.arch.WARP_SIZE),
+     )
+     if cutlass.const_expr(hook_fn is not None):
+         hook_fn()
+     if cutlass.const_expr(reduction_buffer is not None):
+         rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
+         assert cluster_n == 1 or mbar_ptr is not None, (
+             "mbar_ptr must be provided for cluster reduction"
+         )
+         if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
+             assert reduction_buffer.element_type == cutlass.Int64, (
+                 "reduction_buffer must be of type cute.Int64"
+             )
+             lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
+             row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+             if cutlass.const_expr(mbar_ptr is None):
+                 if lane_idx == 0:
+                     reduction_buffer[row_idx, col_idx] = utils.f32x2_to_i64(max_x, sum_exp_x)
+                 cute.arch.barrier()
+                 max_x_single_warp = -Float32.inf
+                 sum_exp_x = 0.0
+                 if lane_idx < warps_per_row:
+                     max_x_single_warp, sum_exp_x = utils.i64_to_f32x2(
+                         reduction_buffer[row_idx, lane_idx]
+                     )
+                 max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
+                 sum_exp_x *= cute.math.exp(max_x_single_warp - max_x_final, fastmath=True)
+                 sum_exp_x = warp_reduce(sum_exp_x, operator.add)
+                 if cutlass.const_expr(return_exp_x):
+                     exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
+                 max_x = max_x_final
+             else:
+                 cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
+                 if warp_idx == 0:
+                     with cute.arch.elect_one():
+                         num_warps = rows_per_block * warps_per_row
+                         cute.arch.mbarrier_arrive_and_expect_tx(
+                             mbar_ptr,
+                             num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+                         )
+                 if lane_idx < cluster_n:
+                     utils.store_shared_remote(
+                         utils.f32x2_to_i64(max_x, sum_exp_x),
+                         utils.elem_pointer(
+                             reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))
+                         ),
+                         mbar_ptr,
+                         peer_cta_rank_in_cluster=lane_idx,
+                     )
+                 cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
+                 num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
+                 max_x_single_warp = cute.make_fragment(num_iter, Float32)
+                 max_x_single_warp.fill(-Float32.inf)
+                 sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
+                 sum_exp_x_single_warp.fill(0.0)
+                 for i in cutlass.range_constexpr(num_iter):
+                     idx = lane_idx + i * cute.arch.WARP_SIZE
+                     if idx < cute.size(reduction_buffer, mode=[1]):
+                         max_x_single_warp[i], sum_exp_x_single_warp[i] = utils.i64_to_f32x2(
+                             reduction_buffer[row_idx, idx]
+                         )
+                 max_x_final = max_x_single_warp.load().reduce(
+                     cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0
+                 )
+                 max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
+                 sum_exp_x = 0.0
+                 for i in cutlass.range_constexpr(num_iter):
+                     sum_exp_x += sum_exp_x_single_warp[i] * cute.math.exp(
+                         max_x_single_warp[i] - max_x_final, fastmath=True
+                     )
+                 sum_exp_x = warp_reduce(sum_exp_x, operator.add)
+                 if cutlass.const_expr(return_exp_x):
+                     exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
+                 max_x = max_x_final
+     return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
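
The key step in online_softmax_reduce above is the merge of per-warp (max, sum_exp) pairs: each partial sum of exp(x - partial_max) is rescaled by exp(partial_max - global_max) before being summed, so the combined value equals the softmax denominator taken against the global maximum. Below is a minimal NumPy sketch of that merge rule, as an editor's illustration only (not part of the package; merge_online_softmax is a hypothetical helper name):

import numpy as np

def merge_online_softmax(partials):
    # Combine per-warp (max, sum_exp) pairs into a global (max, sum_exp),
    # mirroring the rescaling done in online_softmax_reduce.
    maxes = np.array([m for m, _ in partials])
    sums = np.array([s for _, s in partials])
    global_max = maxes.max()
    # Each partial sum was computed relative to its own max; re-express it
    # relative to the global max before adding.
    global_sum = float(np.sum(sums * np.exp(maxes - global_max)))
    return global_max, global_sum

# Quick check against a directly computed softmax denominator.
x = np.random.randn(4, 32)  # 4 "warps" of 32 lanes each
partials = [(row.max(), float(np.exp(row - row.max()).sum())) for row in x]
m, s = merge_online_softmax(partials)
assert np.isclose(s, np.exp(x - x.max()).sum())

The same rescaling appears twice in the kernel: once when the per-warp sums are corrected via exp(max_x_single_warp - max_x_final), and once when exp_x itself is corrected via exp(max_x - max_x_final) for the return_exp_x case.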
quack/reduction_base.py CHANGED
@@ -1,19 +1,11 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
- import torch
  from typing import Type, Tuple, Optional
 
  import cutlass
  import cutlass.cute as cute
 
 
- torch2cute_dtype_map = {
-     torch.float16: cutlass.Float16,
-     torch.bfloat16: cutlass.BFloat16,
-     torch.float32: cutlass.Float32,
- }
-
-
  class ReductionBase:
      def __init__(
          self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=cutlass.Float32
@@ -32,9 +24,8 @@ class ReductionBase:
      def _get_num_threads(self):
          return 128 if self.N <= 16384 else 256
 
-     def _get_tv_layout(self):
-         copy_bits = 128
-         vecsize = copy_bits // self.dtype.width
+     def _get_tv_layout(self, num_copy_bits=128):
+         vecsize = num_copy_bits // self.dtype.width
          assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
          num_threads = self._get_num_threads()
          assert num_threads % cute.arch.WARP_SIZE == 0
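
The _get_tv_layout change makes the copy width a parameter rather than a hard-coded 128 bits, with vecsize = num_copy_bits // dtype.width still guarded by the assert that N is divisible by vecsize. A small illustration of that arithmetic (editor's sketch, not package code):

# vecsize = num_copy_bits // dtype_width; N must be divisible by vecsize.
for dtype_width in (16, 32):              # e.g. Float16/BFloat16 vs Float32
    for num_copy_bits in (128, 64, 32):
        vecsize = num_copy_bits // dtype_width
        print(f"{dtype_width:2d}-bit dtype, {num_copy_bits:3d}-bit copy -> vecsize {vecsize}")
# With a 16-bit dtype and N = 1012, the default 128-bit copy would require
# N % 8 == 0 (fails), while num_copy_bits=64 only requires N % 4 == 0 (passes).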