quack-kernels 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -3
- quack/activation.py +279 -0
- quack/autotuner.py +2 -1
- quack/cross_entropy.py +330 -184
- quack/cute_dsl_utils.py +83 -4
- quack/dense_gemm_sm100.py +1 -1
- quack/dense_gemm_sm90.py +911 -1140
- quack/fast_math.py +10 -27
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +43 -35
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +491 -243
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +6 -4
- quack/linear.py +128 -64
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +30 -160
- quack/pipeline.py +2 -17
- quack/reduce.py +240 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +614 -228
- quack/softmax.py +28 -16
- quack/symmetric_dense_gemm_sm90.py +6 -3
- quack/tensormap_manager.py +1 -0
- quack/tile_scheduler.py +64 -61
- quack/topk.py +14 -8
- quack/utils.py +14 -322
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/METADATA +3 -3
- quack_kernels-0.2.1.dist-info/RECORD +37 -0
- quack/lse.py +0 -62
- quack_kernels-0.1.11.dist-info/RECORD +0 -31
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/top_level.txt +0 -0
quack/reduce.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao.
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
import operator
|
|
5
|
+
from typing import Callable, Optional
|
|
6
|
+
|
|
7
|
+
import cutlass
|
|
8
|
+
import cutlass.cute as cute
|
|
9
|
+
from cutlass import Float32
|
|
10
|
+
|
|
11
|
+
import quack.utils as utils
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@cute.jit
def warp_reduce(
    val: cute.TensorSSA | cute.Numeric,
    op: Callable,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.TensorSSA | cute.Numeric:
    """Reduce ``val`` across aligned groups of ``width`` lanes using butterfly shuffles.

    For a scalar ``val`` this performs ``log2(width)`` rounds of
    ``val = op(val, shuffle_sync_bfly(val, offset=1 << i))``. Because every xor offset is
    smaller than ``width``, partners stay inside the same aligned group, and afterwards
    every lane of a group holds the group's reduction. A ``TensorSSA`` is reduced
    element by element in the same way.

    Args:
        val: per-lane value (scalar or register TensorSSA) to reduce.
        op: binary combine function, e.g. ``operator.add`` or ``cute.arch.fmax``.
        width: compile-time number of lanes per reduction group. It must be a power
            of two; defaults to the full warp.

    Returns:
        The reduced value, of the same kind (scalar or TensorSSA) as ``val``.
    """
    # The xor/butterfly pattern only spans exactly `width` lanes when width is a power of
    # two. Otherwise int(log2(width)) truncates and the result silently covers fewer lanes.
    assert width > 0 and (width & (width - 1)) == 0, (
        f"warp_reduce width must be a power of two, got {width}"
    )
    if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
        # Materialize into a register fragment so each element can be reduced separately.
        res = cute.make_fragment(val.shape, val.dtype)
        res.store(val)
        for i in cutlass.range_constexpr(cute.size(val.shape)):
            res[i] = warp_reduce(res[i], op, width)
        return res.load()
    else:
        for i in cutlass.range_constexpr(int(math.log2(width))):
            val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
        return val
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@cute.jit
def block_reduce(
    val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0
) -> cute.Numeric:
    """Combine per-warp partial results across the warps that share a row of the CTA.

    ``reduction_buffer`` is laid out as (num_warps / warps_per_row, warps_per_row).
    Lane 0 of each warp publishes ``val`` into its slot. After a CTA barrier, lanes
    [0, warps_per_row) of every warp load that row's partials, and a full-warp reduction
    combines them. Lanes without a slot contribute ``init_val``, which should therefore
    be the identity of ``op``.
    """
    lane = cute.arch.lane_idx()
    warp = cute.arch.warp_idx()
    n_cols = cute.size(reduction_buffer.shape[1])
    row = warp // n_cols
    col = warp % n_cols
    # One writer per warp publishes this warp's partial result.
    if lane == 0:
        reduction_buffer[row, col] = val
    cute.arch.barrier()
    # Each of the first n_cols lanes picks up one partial of this row.
    partial = init_val
    if lane < n_cols:
        partial = reduction_buffer[row, lane]
    return warp_reduce(partial, op)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@cute.jit
def cluster_reduce(
    val: cute.Numeric,
    op: Callable,
    reduction_buffer: cute.Tensor,
    mbar_ptr: cute.Pointer,
    init_val: cute.Numeric = 0.0,
    phase: Optional[cutlass.Int32] = None,
) -> cute.Numeric:
    """Combine per-warp partials across every warp of a row in every CTA of the cluster.

    ``reduction_buffer`` is laid out as (num_warps / warps_per_row, (warps_per_row, cluster_n)).
    Each warp sends ``val`` to slot (row, (col, this CTA's rank)) in every CTA of the
    cluster. Warp 0 arms the mbarrier at ``mbar_ptr`` with the total number of bytes this
    CTA will receive. After waiting on that barrier (``phase`` defaults to 0), each warp
    reduces all slots of its row. ``init_val`` fills idle lanes and should be the identity
    of ``op``.
    """
    my_rank = cute.arch.block_idx_in_cluster()
    lane = cute.arch.lane_idx()
    warp = cute.arch.warp_idx()
    n_rows, (n_cols, n_ctas) = reduction_buffer.shape
    row = warp // n_cols
    col = warp % n_cols
    if warp == 0:
        # A single elected thread registers the expected transaction size:
        # one element per warp of this CTA, delivered by each of the n_ctas CTAs.
        with cute.arch.elect_one():
            total_warps = n_rows * n_cols
            expected_bytes = total_warps * n_ctas * reduction_buffer.element_type.width // 8
            cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr, expected_bytes)
    # Lane k of every warp delivers this warp's partial to the CTA with rank k.
    # NOTE(review): store_shared_remote presumably does a remote shared-memory store that
    # completes a transaction on the peer's mbarrier; confirm in quack.utils.
    if lane < n_ctas:
        utils.store_shared_remote(
            val,
            utils.elem_pointer(reduction_buffer, (row, (col, my_rank))),
            mbar_ptr,
            peer_cta_rank_in_cluster=lane,
        )
    cute.arch.mbarrier_wait(mbar_ptr, phase=0 if phase is None else phase)
    # Each lane folds a strided subset of the row's slots, then the warp combines them.
    row_slots = cute.size(reduction_buffer, mode=[1])
    acc = init_val
    for i in cutlass.range_constexpr(cute.ceil_div(n_cols * n_ctas, cute.arch.WARP_SIZE)):
        slot = lane + i * cute.arch.WARP_SIZE
        if slot < row_slots:
            acc = op(acc, reduction_buffer[row, slot])
    return warp_reduce(acc, op)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@cute.jit
def block_or_cluster_reduce(
    val: cute.Numeric,
    op: Callable,
    reduction_buffer: cute.Tensor,
    mbar_ptr: Optional[cute.Pointer],
    phase: Optional[cutlass.Int32] = None,
    init_val: cute.Numeric = 0.0,
) -> cute.Numeric:
    """Select the cross-warp reduction at compile time.

    With an mbarrier pointer the reduction spans the whole cluster (``cluster_reduce``).
    Without one (``mbar_ptr is None``) it stays inside the CTA (``block_reduce``), and
    ``phase`` is ignored.
    """
    if cutlass.const_expr(mbar_ptr is not None):
        return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)
    else:
        return block_reduce(val, op, reduction_buffer, init_val=init_val)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@cute.jit
def row_reduce(
    x: cute.TensorSSA | cute.Numeric,
    op: cute.ReductionOp,
    threads_per_row: cutlass.Constexpr[int],
    reduction_buffer: Optional[cute.Tensor] = None,
    mbar_ptr: Optional[cute.Pointer] = None,
    phase: Optional[cutlass.Int32] = None,
    init_val: cute.Numeric = 0.0,
    hook_fn: Optional[Callable] = None,
) -> cute.Numeric:
    """Reduce ``x`` with ``op`` across all threads that cooperate on one row.

    The reduction proceeds in four steps:

    1. If ``x`` is a TensorSSA, reduce its elements locally, starting from ``init_val``.
    2. Butterfly-reduce across min(threads_per_row, WARP_SIZE) lanes.
    3. Invoke ``hook_fn`` if one is given.
    4. If ``reduction_buffer`` is given and a row spans several warps or CTAs, finish
       with a block reduction (``mbar_ptr`` None) or a cluster reduction.

    ``init_val`` is used as the thread-local start value and as filler for idle lanes,
    so it should be the identity of ``op``.

    reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n)).
    """
    if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
        partial = x.reduce(op, init_val=init_val, reduction_profile=0)
    else:
        partial = x
    # Scalar combine used by the shuffles. MAX uses hardware fmax for Float32.
    combine = {
        cute.ReductionOp.ADD: operator.add,
        cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
        cute.ReductionOp.MIN: min,
        cute.ReductionOp.MUL: operator.mul,
    }[op]
    lanes = min(threads_per_row, cute.arch.WARP_SIZE)
    partial = warp_reduce(partial, combine, width=lanes)
    if cutlass.const_expr(hook_fn is not None):
        hook_fn()
    if cutlass.const_expr(reduction_buffer is not None):
        warps_per_row, cluster_n = reduction_buffer.shape[1]
        assert cluster_n == 1 or mbar_ptr is not None, (
            "mbar_ptr must be provided for cluster reduction"
        )
        spans_multiple_warps = warps_per_row > 1 or cluster_n > 1
        if cutlass.const_expr(spans_multiple_warps):
            partial = block_or_cluster_reduce(
                partial, combine, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
            )
    return partial
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
@cute.jit
def online_softmax_reduce(
    x: cute.TensorSSA,
    threads_per_row: cutlass.Constexpr[int],
    reduction_buffer: Optional[cute.Tensor] = None,
    mbar_ptr: Optional[cute.Pointer] = None,
    hook_fn: Optional[Callable] = None,
    phase: Optional[cutlass.Int32] = None,
    return_exp_x: bool = False,
) -> [Float32, Float32, Optional[cute.TensorSSA]]:
    """Compute a row's max and its sum of exp(x - max) in one pass.

    Each warp first reduces its ``threads_per_row`` lanes. If ``reduction_buffer`` is
    given and a row spans several warps or CTAs, the per-warp (max, sum) pairs are
    merged. Each partial sum is rescaled by exp(partial_max - final_max) before it is
    added. The merge happens within the CTA when ``mbar_ptr`` is None, otherwise across
    the cluster via the mbarrier (``phase`` defaults to 0). ``hook_fn`` is called after
    the warp-level step.

    reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))
    with Int64 elements; each element stores a packed (max, sum) Float32 pair.

    Returns:
        (max_x, sum_exp_x, exp_x), where exp_x is ``None`` unless ``return_exp_x`` is set.
        When returned, exp_x is exp(x - max_x) relative to the final max_x.
    """
    assert x.dtype == Float32, "x must be of type Float32"
    # Thread-local max, then max over the lanes of this row within the warp.
    max_x = warp_reduce(
        x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
        cute.arch.fmax,
        width=min(threads_per_row, cute.arch.WARP_SIZE),
    )
    log2_e = math.log2(math.e)
    # exp(x - max) evaluated as exp2(x*log2(e) - max*log2(e)).
    # NOTE(review): a row that is entirely -inf gives max_x = -inf and (-inf) - (-inf) = NaN;
    # confirm callers never pass fully masked rows.
    exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
    sum_exp_x = warp_reduce(
        exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
        operator.add,
        width=min(threads_per_row, cute.arch.WARP_SIZE),
    )
    if cutlass.const_expr(hook_fn is not None):
        hook_fn()
    if cutlass.const_expr(reduction_buffer is not None):
        rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
        assert cluster_n == 1 or mbar_ptr is not None, (
            "mbar_ptr must be provided for cluster reduction"
        )
        if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
            # (max, sum) travel together as one 64-bit element (see utils.f32x2_to_i64).
            assert reduction_buffer.element_type == cutlass.Int64, (
                "reduction_buffer must be of type cute.Int64"
            )
            lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
            row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
            if cutlass.const_expr(mbar_ptr is None):
                # CTA-level merge: lane 0 of each warp publishes its packed (max, sum).
                if lane_idx == 0:
                    reduction_buffer[row_idx, col_idx] = utils.f32x2_to_i64(max_x, sum_exp_x)
                cute.arch.barrier()
                # Lanes without a slot contribute the identities (-inf, 0).
                max_x_single_warp = -Float32.inf
                sum_exp_x = 0.0
                if lane_idx < warps_per_row:
                    max_x_single_warp, sum_exp_x = utils.i64_to_f32x2(
                        reduction_buffer[row_idx, lane_idx]
                    )
                max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
                # Rescale each warp's partial sum from its own max to the row max.
                sum_exp_x *= cute.math.exp(max_x_single_warp - max_x_final, fastmath=True)
                sum_exp_x = warp_reduce(sum_exp_x, operator.add)
                if cutlass.const_expr(return_exp_x):
                    exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
                max_x = max_x_final
            else:
                # Cluster-level merge: every warp sends its packed pair to every CTA.
                cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
                if warp_idx == 0:
                    with cute.arch.elect_one():
                        num_warps = rows_per_block * warps_per_row
                        # Expect one element per warp from each of the cluster_n CTAs.
                        cute.arch.mbarrier_arrive_and_expect_tx(
                            mbar_ptr,
                            num_warps * cluster_n * reduction_buffer.element_type.width // 8,
                        )
                if lane_idx < cluster_n:
                    utils.store_shared_remote(
                        utils.f32x2_to_i64(max_x, sum_exp_x),
                        utils.elem_pointer(
                            reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))
                        ),
                        mbar_ptr,
                        peer_cta_rank_in_cluster=lane_idx,
                    )
                cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
                # The row may hold more than WARP_SIZE slots, so each lane may load several.
                num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
                max_x_single_warp = cute.make_fragment(num_iter, Float32)
                max_x_single_warp.fill(-Float32.inf)
                sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
                sum_exp_x_single_warp.fill(0.0)
                for i in cutlass.range_constexpr(num_iter):
                    idx = lane_idx + i * cute.arch.WARP_SIZE
                    if idx < cute.size(reduction_buffer, mode=[1]):
                        max_x_single_warp[i], sum_exp_x_single_warp[i] = utils.i64_to_f32x2(
                            reduction_buffer[row_idx, idx]
                        )
                max_x_final = max_x_single_warp.load().reduce(
                    cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0
                )
                max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
                # Rescale every loaded partial sum to the final max before summing.
                sum_exp_x = 0.0
                for i in cutlass.range_constexpr(num_iter):
                    sum_exp_x += sum_exp_x_single_warp[i] * cute.math.exp(
                        max_x_single_warp[i] - max_x_final, fastmath=True
                    )
                sum_exp_x = warp_reduce(sum_exp_x, operator.add)
                if cutlass.const_expr(return_exp_x):
                    exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
                max_x = max_x_final
    return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
|
quack/reduction_base.py
CHANGED
|
@@ -1,19 +1,11 @@
|
|
|
1
1
|
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
2
|
|
|
3
|
-
import torch
|
|
4
3
|
from typing import Type, Tuple, Optional
|
|
5
4
|
|
|
6
5
|
import cutlass
|
|
7
6
|
import cutlass.cute as cute
|
|
8
7
|
|
|
9
8
|
|
|
10
|
-
torch2cute_dtype_map = {
|
|
11
|
-
torch.float16: cutlass.Float16,
|
|
12
|
-
torch.bfloat16: cutlass.BFloat16,
|
|
13
|
-
torch.float32: cutlass.Float32,
|
|
14
|
-
}
|
|
15
|
-
|
|
16
|
-
|
|
17
9
|
class ReductionBase:
|
|
18
10
|
def __init__(
|
|
19
11
|
self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=cutlass.Float32
|
|
@@ -32,9 +24,8 @@ class ReductionBase:
|
|
|
32
24
|
def _get_num_threads(self):
|
|
33
25
|
return 128 if self.N <= 16384 else 256
|
|
34
26
|
|
|
35
|
-
def _get_tv_layout(self):
|
|
36
|
-
|
|
37
|
-
vecsize = copy_bits // self.dtype.width
|
|
27
|
+
def _get_tv_layout(self, num_copy_bits=128):
|
|
28
|
+
vecsize = num_copy_bits // self.dtype.width
|
|
38
29
|
assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
|
|
39
30
|
num_threads = self._get_num_threads()
|
|
40
31
|
assert num_threads % cute.arch.WARP_SIZE == 0
|