quack-kernels 0.1.11__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -3
- quack/activation.py +288 -0
- quack/autotuner.py +2 -1
- quack/cross_entropy.py +325 -175
- quack/cute_dsl_utils.py +83 -4
- quack/dense_gemm_sm100.py +1 -1
- quack/dense_gemm_sm90.py +911 -1140
- quack/fast_math.py +10 -27
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +43 -35
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +491 -243
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +5 -3
- quack/linear.py +128 -64
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +30 -160
- quack/pipeline.py +2 -17
- quack/reduce.py +241 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +583 -231
- quack/softmax.py +27 -15
- quack/symmetric_dense_gemm_sm90.py +6 -3
- quack/tensormap_manager.py +1 -0
- quack/tile_scheduler.py +61 -59
- quack/topk.py +14 -8
- quack/utils.py +14 -259
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.0.dist-info}/METADATA +2 -2
- quack_kernels-0.2.0.dist-info/RECORD +37 -0
- quack/lse.py +0 -62
- quack_kernels-0.1.11.dist-info/RECORD +0 -31
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.0.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.0.dist-info}/top_level.txt +0 -0
quack/utils.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
2
|
|
|
3
|
-
import operator
|
|
4
3
|
import math
|
|
5
|
-
from typing import
|
|
4
|
+
from typing import Optional, Tuple, Type, Union
|
|
6
5
|
|
|
7
6
|
import cutlass
|
|
8
7
|
import cutlass.cute as cute
|
|
@@ -23,46 +22,20 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te
|
|
|
23
22
|
)
|
|
24
23
|
|
|
25
24
|
|
|
26
|
-
@cute.jit
|
|
27
|
-
def warp_reduce(
|
|
28
|
-
val: cute.TensorSSA | cute.Numeric,
|
|
29
|
-
op: Callable,
|
|
30
|
-
width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
|
|
31
|
-
) -> cute.TensorSSA | cute.Numeric:
|
|
32
|
-
if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
|
|
33
|
-
res = cute.make_fragment(val.shape, val.dtype)
|
|
34
|
-
res.store(val)
|
|
35
|
-
for i in cutlass.range_constexpr(cute.size(val.shape)):
|
|
36
|
-
res[i] = warp_reduce(res[i], op, width)
|
|
37
|
-
return res.load()
|
|
38
|
-
else:
|
|
39
|
-
for i in cutlass.range_constexpr(int(math.log2(width))):
|
|
40
|
-
val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
|
|
41
|
-
return val
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
@cute.jit
|
|
45
|
-
def block_reduce(
|
|
46
|
-
val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0
|
|
47
|
-
) -> cute.Numeric:
|
|
48
|
-
"""reduction_buffer has shape (num_warps / warp_per_row, warps_per_row)"""
|
|
49
|
-
lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
|
|
50
|
-
warps_per_row = cute.size(reduction_buffer.shape[1])
|
|
51
|
-
row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
|
|
52
|
-
if lane_idx == 0:
|
|
53
|
-
reduction_buffer[row_idx, col_idx] = val
|
|
54
|
-
cute.arch.barrier()
|
|
55
|
-
block_reduce_val = init_val
|
|
56
|
-
if lane_idx < warps_per_row:
|
|
57
|
-
block_reduce_val = reduction_buffer[row_idx, lane_idx]
|
|
58
|
-
return warp_reduce(block_reduce_val, op)
|
|
59
|
-
|
|
60
|
-
|
|
61
25
|
@dsl_user_op
|
|
62
26
|
def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
|
|
63
27
|
return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
|
|
64
28
|
|
|
65
29
|
|
|
30
|
+
@cute.jit
|
|
31
|
+
def load_scalar_or_pointer(x: Float32 | cute.Pointer) -> Float32:
|
|
32
|
+
if cutlass.const_expr(isinstance(x, cute.Pointer)):
|
|
33
|
+
return Float32(cute.make_tensor(x, cute.make_layout(1))[0])
|
|
34
|
+
else:
|
|
35
|
+
assert isinstance(x, Float32)
|
|
36
|
+
return x
|
|
37
|
+
|
|
38
|
+
|
|
66
39
|
@dsl_user_op
|
|
67
40
|
def set_block_rank(
|
|
68
41
|
smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32, *, loc=None, ip=None
|
|
@@ -114,197 +87,6 @@ def store_shared_remote(
|
|
|
114
87
|
)
|
|
115
88
|
|
|
116
89
|
|
|
117
|
-
@cute.jit
|
|
118
|
-
def cluster_reduce(
|
|
119
|
-
val: cute.Numeric,
|
|
120
|
-
op: Callable,
|
|
121
|
-
reduction_buffer: cute.Tensor,
|
|
122
|
-
mbar_ptr: cute.Pointer,
|
|
123
|
-
init_val: cute.Numeric = 0.0,
|
|
124
|
-
phase: Optional[cutlass.Int32] = None,
|
|
125
|
-
) -> cute.Numeric:
|
|
126
|
-
"""reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
|
|
127
|
-
cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
|
|
128
|
-
lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
|
|
129
|
-
rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
|
|
130
|
-
row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
|
|
131
|
-
if warp_idx == 0:
|
|
132
|
-
with cute.arch.elect_one():
|
|
133
|
-
num_warps = rows_per_block * warps_per_row
|
|
134
|
-
cute.arch.mbarrier_arrive_and_expect_tx(
|
|
135
|
-
mbar_ptr,
|
|
136
|
-
num_warps * cluster_n * reduction_buffer.element_type.width // 8,
|
|
137
|
-
)
|
|
138
|
-
if lane_idx < cluster_n:
|
|
139
|
-
store_shared_remote(
|
|
140
|
-
val,
|
|
141
|
-
elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
|
|
142
|
-
mbar_ptr,
|
|
143
|
-
peer_cta_rank_in_cluster=lane_idx,
|
|
144
|
-
)
|
|
145
|
-
cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
|
|
146
|
-
block_reduce_val = init_val
|
|
147
|
-
num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
|
|
148
|
-
for i in cutlass.range_constexpr(num_iter):
|
|
149
|
-
idx = lane_idx + i * cute.arch.WARP_SIZE
|
|
150
|
-
if idx < cute.size(reduction_buffer, mode=[1]):
|
|
151
|
-
block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])
|
|
152
|
-
return warp_reduce(block_reduce_val, op)
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
@cute.jit
|
|
156
|
-
def block_or_cluster_reduce(
|
|
157
|
-
val: cute.Numeric,
|
|
158
|
-
op: Callable,
|
|
159
|
-
reduction_buffer: cute.Tensor,
|
|
160
|
-
mbar_ptr: Optional[cute.Pointer],
|
|
161
|
-
phase: Optional[cutlass.Int32] = None,
|
|
162
|
-
init_val: cute.Numeric = 0.0,
|
|
163
|
-
) -> cute.Numeric:
|
|
164
|
-
"""Perform either block or cluster reduction based on whether mbar_ptr is provided."""
|
|
165
|
-
if cutlass.const_expr(mbar_ptr is None):
|
|
166
|
-
return block_reduce(val, op, reduction_buffer, init_val=init_val)
|
|
167
|
-
else:
|
|
168
|
-
return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
@cute.jit
|
|
172
|
-
def row_reduce(
|
|
173
|
-
x: cute.TensorSSA | cute.Numeric,
|
|
174
|
-
op: cute.ReductionOp,
|
|
175
|
-
threads_per_row: cutlass.Constexpr[int],
|
|
176
|
-
reduction_buffer: Optional[cute.Tensor] = None,
|
|
177
|
-
mbar_ptr: Optional[cute.Pointer] = None,
|
|
178
|
-
phase: Optional[cutlass.Int32] = None,
|
|
179
|
-
init_val: cute.Numeric = 0.0,
|
|
180
|
-
hook_fn: Optional[Callable] = None,
|
|
181
|
-
) -> cute.Numeric:
|
|
182
|
-
"""reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
|
|
183
|
-
if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
|
|
184
|
-
val = x.reduce(op, init_val=init_val, reduction_profile=0)
|
|
185
|
-
else:
|
|
186
|
-
val = x
|
|
187
|
-
warp_op = {
|
|
188
|
-
cute.ReductionOp.ADD: operator.add,
|
|
189
|
-
cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
|
|
190
|
-
cute.ReductionOp.MIN: min,
|
|
191
|
-
cute.ReductionOp.MUL: operator.mul,
|
|
192
|
-
}[op]
|
|
193
|
-
val = warp_reduce(
|
|
194
|
-
val,
|
|
195
|
-
warp_op,
|
|
196
|
-
width=min(threads_per_row, cute.arch.WARP_SIZE),
|
|
197
|
-
)
|
|
198
|
-
if cutlass.const_expr(hook_fn is not None):
|
|
199
|
-
hook_fn()
|
|
200
|
-
if cutlass.const_expr(reduction_buffer is not None):
|
|
201
|
-
warps_per_row, cluster_n = reduction_buffer.shape[1]
|
|
202
|
-
assert cluster_n == 1 or mbar_ptr is not None, (
|
|
203
|
-
"mbar_ptr must be provided for cluster reduction"
|
|
204
|
-
)
|
|
205
|
-
if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
|
|
206
|
-
val = block_or_cluster_reduce(
|
|
207
|
-
val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
|
|
208
|
-
)
|
|
209
|
-
return val
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
@cute.jit
|
|
213
|
-
def online_softmax_reduce(
|
|
214
|
-
x: cute.TensorSSA,
|
|
215
|
-
threads_per_row: cutlass.Constexpr[int],
|
|
216
|
-
reduction_buffer: Optional[cute.Tensor] = None,
|
|
217
|
-
mbar_ptr: Optional[cute.Pointer] = None,
|
|
218
|
-
hook_fn: Optional[Callable] = None,
|
|
219
|
-
phase: Optional[cutlass.Int32] = None,
|
|
220
|
-
return_exp_x: bool = False,
|
|
221
|
-
) -> [Float32, Float32, Optional[cute.TensorSSA]]:
|
|
222
|
-
assert x.dtype == Float32, "x must be of type Float32"
|
|
223
|
-
"""reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2)"""
|
|
224
|
-
max_x = warp_reduce(
|
|
225
|
-
x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
|
|
226
|
-
cute.arch.fmax,
|
|
227
|
-
width=min(threads_per_row, cute.arch.WARP_SIZE),
|
|
228
|
-
)
|
|
229
|
-
log2_e = math.log2(math.e)
|
|
230
|
-
exp_x = exp2f(x * log2_e - (max_x * log2_e))
|
|
231
|
-
# exp_x = exp2f((x - max_x) * log2_e)
|
|
232
|
-
sum_exp_x = warp_reduce(
|
|
233
|
-
exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
|
|
234
|
-
operator.add,
|
|
235
|
-
width=min(threads_per_row, cute.arch.WARP_SIZE),
|
|
236
|
-
)
|
|
237
|
-
if cutlass.const_expr(hook_fn is not None):
|
|
238
|
-
hook_fn()
|
|
239
|
-
if cutlass.const_expr(reduction_buffer is not None):
|
|
240
|
-
rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
|
|
241
|
-
assert cluster_n == 1 or mbar_ptr is not None, (
|
|
242
|
-
"mbar_ptr must be provided for cluster reduction"
|
|
243
|
-
)
|
|
244
|
-
if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
|
|
245
|
-
assert reduction_buffer.element_type == cutlass.Int64, (
|
|
246
|
-
"reduction_buffer must be of type cute.Int64"
|
|
247
|
-
)
|
|
248
|
-
lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
|
|
249
|
-
row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
|
|
250
|
-
if cutlass.const_expr(mbar_ptr is None):
|
|
251
|
-
if lane_idx == 0:
|
|
252
|
-
reduction_buffer[row_idx, col_idx] = f32x2_to_i64(max_x, sum_exp_x)
|
|
253
|
-
cute.arch.barrier()
|
|
254
|
-
max_x_single_warp = -Float32.inf
|
|
255
|
-
sum_exp_x = 0.0
|
|
256
|
-
if lane_idx < warps_per_row:
|
|
257
|
-
max_x_single_warp, sum_exp_x = i64_to_f32x2(reduction_buffer[row_idx, lane_idx])
|
|
258
|
-
max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
|
|
259
|
-
sum_exp_x *= exp2f((max_x_single_warp - max_x_final) * log2_e)
|
|
260
|
-
sum_exp_x = warp_reduce(sum_exp_x, operator.add)
|
|
261
|
-
if cutlass.const_expr(return_exp_x):
|
|
262
|
-
exp_x *= exp2f((max_x - max_x_final) * log2_e)
|
|
263
|
-
max_x = max_x_final
|
|
264
|
-
else:
|
|
265
|
-
cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
|
|
266
|
-
if warp_idx == 0:
|
|
267
|
-
with cute.arch.elect_one():
|
|
268
|
-
num_warps = rows_per_block * warps_per_row
|
|
269
|
-
cute.arch.mbarrier_arrive_and_expect_tx(
|
|
270
|
-
mbar_ptr,
|
|
271
|
-
num_warps * cluster_n * reduction_buffer.element_type.width // 8,
|
|
272
|
-
)
|
|
273
|
-
if lane_idx < cluster_n:
|
|
274
|
-
store_shared_remote(
|
|
275
|
-
f32x2_to_i64(max_x, sum_exp_x),
|
|
276
|
-
elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
|
|
277
|
-
mbar_ptr,
|
|
278
|
-
peer_cta_rank_in_cluster=lane_idx,
|
|
279
|
-
)
|
|
280
|
-
cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
|
|
281
|
-
num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
|
|
282
|
-
max_x_single_warp = cute.make_fragment(num_iter, Float32)
|
|
283
|
-
max_x_single_warp.fill(-Float32.inf)
|
|
284
|
-
sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
|
|
285
|
-
sum_exp_x_single_warp.fill(0.0)
|
|
286
|
-
for i in cutlass.range_constexpr(num_iter):
|
|
287
|
-
idx = lane_idx + i * cute.arch.WARP_SIZE
|
|
288
|
-
if idx < cute.size(reduction_buffer, mode=[1]):
|
|
289
|
-
max_x_single_warp[i], sum_exp_x_single_warp[i] = i64_to_f32x2(
|
|
290
|
-
reduction_buffer[row_idx, idx]
|
|
291
|
-
)
|
|
292
|
-
max_x_final = max_x_single_warp.load().reduce(
|
|
293
|
-
cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0
|
|
294
|
-
)
|
|
295
|
-
max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
|
|
296
|
-
sum_exp_x = 0.0
|
|
297
|
-
for i in cutlass.range_constexpr(num_iter):
|
|
298
|
-
sum_exp_x += sum_exp_x_single_warp[i] * exp2f(
|
|
299
|
-
(max_x_single_warp[i] - max_x_final) * log2_e
|
|
300
|
-
)
|
|
301
|
-
sum_exp_x = warp_reduce(sum_exp_x, operator.add)
|
|
302
|
-
if cutlass.const_expr(return_exp_x):
|
|
303
|
-
exp_x *= exp2f((max_x - max_x_final) * log2_e)
|
|
304
|
-
max_x = max_x_final
|
|
305
|
-
return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
|
|
306
|
-
|
|
307
|
-
|
|
308
90
|
@dsl_user_op
|
|
309
91
|
def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None) -> Float32:
|
|
310
92
|
return Float32(
|
|
@@ -381,21 +163,6 @@ def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
|
|
381
163
|
)
|
|
382
164
|
|
|
383
165
|
|
|
384
|
-
@dsl_user_op
|
|
385
|
-
def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
|
386
|
-
return Float32(
|
|
387
|
-
llvm.inline_asm(
|
|
388
|
-
T.f32(),
|
|
389
|
-
[Float32(a).ir_value(loc=loc, ip=ip)],
|
|
390
|
-
"tanh.approx.f32 $0, $1;",
|
|
391
|
-
"=f,f",
|
|
392
|
-
has_side_effects=False,
|
|
393
|
-
is_align_stack=False,
|
|
394
|
-
asm_dialect=llvm.AsmDialect.AD_ATT,
|
|
395
|
-
)
|
|
396
|
-
)
|
|
397
|
-
|
|
398
|
-
|
|
399
166
|
@dsl_user_op
|
|
400
167
|
def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
|
|
401
168
|
return Int32(
|
|
@@ -411,16 +178,6 @@ def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
|
|
|
411
178
|
)
|
|
412
179
|
|
|
413
180
|
|
|
414
|
-
@dsl_user_op
|
|
415
|
-
def silu(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
|
416
|
-
"""
|
|
417
|
-
silu(a) = a * sigmoid(a) = a * (1 + tanh(a / 2)) / 2 = (0.5 * a) * tanh(0.5 * a) + (0.5 * a)
|
|
418
|
-
This compiles down to 3 SASS instructions: FMUL to get 0.5 * a, MUFU.TANH, and FFMA.
|
|
419
|
-
"""
|
|
420
|
-
a_half = 0.5 * a
|
|
421
|
-
return a_half * tanh(a_half) + a_half
|
|
422
|
-
|
|
423
|
-
|
|
424
181
|
@dsl_user_op
|
|
425
182
|
def prmt(a: int | Int32, b: int | Int32, c: int | Int32, *, loc=None, ip=None) -> Int32:
|
|
426
183
|
return Int32(
|
|
@@ -498,7 +255,7 @@ def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Nu
|
|
|
498
255
|
tXpX: Predicate tensor indicating valid elements
|
|
499
256
|
fill_value: Value to fill OOB locations with
|
|
500
257
|
"""
|
|
501
|
-
tXrX_fill = cute.make_fragment_like(tXsX[(None, 0),
|
|
258
|
+
tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), None, 0])
|
|
502
259
|
tXrX_fill.fill(fill_value)
|
|
503
260
|
for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
|
|
504
261
|
for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
|
|
@@ -538,9 +295,9 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
|
|
|
538
295
|
def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
|
|
539
296
|
flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
|
|
540
297
|
flat_stride = cute.flatten_to_tuple(tensor.stride)
|
|
541
|
-
assert len(flat_coord_i64) == len(
|
|
542
|
-
|
|
543
|
-
)
|
|
298
|
+
assert len(flat_coord_i64) == len(
|
|
299
|
+
flat_stride
|
|
300
|
+
), "Coordinate and stride must have the same length"
|
|
544
301
|
offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
|
|
545
302
|
assert isinstance(tensor.iterator, cute.Pointer)
|
|
546
303
|
# HACK: we assume that applying the offset does not change the pointer alignment
|
|
@@ -662,5 +419,3 @@ def atomic_inc_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None)
|
|
|
662
419
|
return nvvm.atomicrmw(
|
|
663
420
|
res=T.i32(), op=nvvm.AtomicOpKind.INC, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
|
|
664
421
|
)
|
|
665
|
-
|
|
666
|
-
|
quack/varlen_utils.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao.
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
import cutlass.cute as cute
|
|
7
|
+
|
|
8
|
+
from quack.cute_dsl_utils import ArgumentsBase
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Grouping arguments together that should be passed to __call__
|
|
12
|
+
@dataclass
|
|
13
|
+
class VarlenArguments(ArgumentsBase):
|
|
14
|
+
mCuSeqlensM: Optional[cute.Tensor] = None
|
|
15
|
+
mCuSeqlensK: Optional[cute.Tensor] = None
|
|
16
|
+
mTensormaps: Optional[cute.Tensor] = None
|
|
17
|
+
|
|
18
|
+
def __post_init__(self):
|
|
19
|
+
if self.mCuSeqlensM is not None or self.mCuSeqlensK is not None:
|
|
20
|
+
assert (
|
|
21
|
+
self.mTensormaps is not None
|
|
22
|
+
), "mTensormaps must be provided if mCuSeqlensM or mCuSeqlensK is provided"
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: quack-kernels
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Requires-Python: >=3.12
|
|
5
5
|
License-File: LICENSE
|
|
6
|
-
Requires-Dist: nvidia-cutlass-dsl==4.
|
|
6
|
+
Requires-Dist: nvidia-cutlass-dsl==4.2.0
|
|
7
7
|
Requires-Dist: torch
|
|
8
8
|
Provides-Extra: dev
|
|
9
9
|
Requires-Dist: pre-commit; extra == "dev"
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
quack/__init__.py,sha256=fGBYbb9JlaNT7HdtUTbUnuAkL5G2Dg8XZAA5Ir1R-ow,364
|
|
2
|
+
quack/activation.py,sha256=ysXaVUXX2yGQC5o4ZVeRXw_fDIHOrqnzpHJaIsc0kHc,10271
|
|
3
|
+
quack/autotuner.py,sha256=czO6JrYL0EJpOeJOYDSsVdrJaFuwfL3vTdG8QfL1F34,10792
|
|
4
|
+
quack/cross_entropy.py,sha256=Kc3P83Vsu1nGaCu7llsO3vct3J_t3frRYPxij7JfHMA,28619
|
|
5
|
+
quack/cute_dsl_utils.py,sha256=D2Pw7rzX9jY8u8wikIPvPvinmFLCDeZg95HPBLqGej4,4635
|
|
6
|
+
quack/dense_gemm_sm100.py,sha256=hKBNC34UxdctrTKVP68nvANZl4Dq2rnUjRcweESEq3g,109965
|
|
7
|
+
quack/dense_gemm_sm90.py,sha256=TjnjHnjhAwWH5YQWsFlADq07xSxtsprkw_p2Cy0yw7I,100407
|
|
8
|
+
quack/fast_math.py,sha256=E1XUqfUt0_n9BPZNggF-UDzZ6anso9bYUrwqafemWvQ,2297
|
|
9
|
+
quack/gemm_act_sm90.py,sha256=N5UAFWZvw1na22Vh5JSGgcdqZ2zI6kQMBVOLxYbCAUU,14332
|
|
10
|
+
quack/gemm_config.py,sha256=gbYjPFeyT5wAhVwFQroRHlHoMKEJqAWX9P8wWy04l8Q,2258
|
|
11
|
+
quack/gemm_dact_sm90.py,sha256=KCXgjOzdamSDexwrwf_pX2r-ippPRirbClrlU6BP7b8,4990
|
|
12
|
+
quack/gemm_interface.py,sha256=_JTpE7zQw6NUw-v65Wql_XUOZBfW0oSEgiMnharTJU4,20501
|
|
13
|
+
quack/gemm_wrapper_utils.py,sha256=aMMtu-Ojhtjay_5xJH4AjP-JRVks1AB8jmtNme_DIqU,5960
|
|
14
|
+
quack/layernorm.py,sha256=JkK0sVdUfZ-SmoBmNqLF3wCiszDbdorvcBH2julv0Vg,13560
|
|
15
|
+
quack/linear.py,sha256=SrhRiAFjC7ONIMVmiNu-kSPLHNUyaCXt59a1f_5nNXo,9383
|
|
16
|
+
quack/linear_cross_entropy.py,sha256=Zhy_gdMsKHOie-jntBaqIuiDJtkiq6qEBwnyuWwIRw4,10092
|
|
17
|
+
quack/mlp.py,sha256=YjdwQRwEePA9KyidFXp5H1-lxiJc8dZ41vl8Fv8pgss,2259
|
|
18
|
+
quack/pipeline.py,sha256=DyCwZX8WvoUBFcMBz7CeYm9VUM31haEGgBhAzmxu8cE,5519
|
|
19
|
+
quack/reduce.py,sha256=hsYByu6haCZjLTLB-qpYmKDjqS2UqlwPgfWTup38GNA,10341
|
|
20
|
+
quack/reduction_base.py,sha256=CT-t_j7z8H1ByD9FkQYDRik_-THMDFv9QoXHmr9Xx9E,3636
|
|
21
|
+
quack/rmsnorm.py,sha256=93qlTPjY9JBm3R5M-HeHse1PbAfD9931G3OFs71yo_g,48998
|
|
22
|
+
quack/softmax.py,sha256=Mq3_2Ul8H64zeGUI9wOKEpIISJnrCcHQpZvk2sb10Tg,17101
|
|
23
|
+
quack/symmetric_dense_gemm_sm90.py,sha256=2UXooIpClT2izdyGis1XaIgYYlLj-7MrcOMg2yR7YCk,88694
|
|
24
|
+
quack/tensormap_manager.py,sha256=Ts3Mxp0_es2RNA0ffvUjWMXN79lsfWEBZ0DQYhtbcnw,5338
|
|
25
|
+
quack/tile_scheduler.py,sha256=8qqYmx6GpQzt8XiidcrdLIaWf0TGbJVdwKFfeb1X_us,42265
|
|
26
|
+
quack/topk.py,sha256=RQl-23lIicQ9ry9Njur8i0JGem_WbO_Gchr6jy8EtVM,9185
|
|
27
|
+
quack/utils.py,sha256=tiqeJZiPPFl5irQWCUd7dTPA_OAv4SjHUW5S-u9wO8Y,14526
|
|
28
|
+
quack/varlen_utils.py,sha256=vkduMEpo5bJJvZRNnIcKPb6pp1wD34vaIpMIB0ZGIZA,681
|
|
29
|
+
quack/sort/bitonic_sort.py,sha256=8t0SG1a6iEpYIlY8YM_AWvm4aN-4AA4vEzdBuJMJm9g,4768
|
|
30
|
+
quack/sort/generate_sorting_networks.py,sha256=vkJBOjTVEinQkWT4OtFqOWxFVdTIPoNAQocneKc9-rM,14477
|
|
31
|
+
quack/sort/sorting_networks.py,sha256=l_26zi3gXD_z-tnm2eAczRrmE-mbaz00KmqH6ONivL8,9686
|
|
32
|
+
quack/sort/utils.py,sha256=Mkr-l97RMAV-ZoNrwuzA1U3KO0Wjr38CV9Jm7ScyZoI,1090
|
|
33
|
+
quack_kernels-0.2.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
34
|
+
quack_kernels-0.2.0.dist-info/METADATA,sha256=DAeQymRUqp7lSfSTNyS7TZF3oWcFzCKriGJ2p8JLu6A,285
|
|
35
|
+
quack_kernels-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
36
|
+
quack_kernels-0.2.0.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
|
|
37
|
+
quack_kernels-0.2.0.dist-info/RECORD,,
|
quack/lse.py
DELETED
|
@@ -1,62 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2025, Tri Dao.
|
|
2
|
-
# TODO: we probably dont' need this kernel, just use torch.logsumexp
|
|
3
|
-
import torch
|
|
4
|
-
|
|
5
|
-
import triton
|
|
6
|
-
import triton.language as tl
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
@triton.jit
|
|
10
|
-
def _lse_kernel(
|
|
11
|
-
lse_ptr,
|
|
12
|
-
logits_ptr,
|
|
13
|
-
n_rows,
|
|
14
|
-
n_cols,
|
|
15
|
-
logits_row_stride,
|
|
16
|
-
logits_col_stride,
|
|
17
|
-
BLOCK_SIZE_M: tl.constexpr,
|
|
18
|
-
BLOCK_SIZE_N: tl.constexpr,
|
|
19
|
-
):
|
|
20
|
-
row_start = tl.program_id(0) * BLOCK_SIZE_M
|
|
21
|
-
rows = row_start + tl.arange(0, BLOCK_SIZE_M)
|
|
22
|
-
cols = tl.arange(0, BLOCK_SIZE_N)
|
|
23
|
-
logits = tl.load(
|
|
24
|
-
logits_ptr + rows[:, None] * logits_row_stride + cols[None, :] * logits_col_stride,
|
|
25
|
-
mask=(rows[:, None] < n_rows) & (cols[None, :] < n_cols),
|
|
26
|
-
other=-float("inf"),
|
|
27
|
-
).to(tl.float32)
|
|
28
|
-
m = tl.max(logits, 1)
|
|
29
|
-
lse = tl.log(tl.sum(tl.exp(logits - m[:, None]), 1)) + m
|
|
30
|
-
tl.store(lse_ptr + rows, lse, mask=rows < n_rows)
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
def logsumexp(logits):
|
|
34
|
-
n_rows, n_cols = logits.shape
|
|
35
|
-
BLOCK_SIZE_M = 32 if logits.stride(1) != 1 else 1
|
|
36
|
-
MAX_BLOCK_SIZE = 64 * 1024
|
|
37
|
-
# BLOCK_SIZE_N = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE // BLOCK_SIZE_M)
|
|
38
|
-
BLOCK_SIZE_N = triton.next_power_of_2(n_cols)
|
|
39
|
-
assert (
|
|
40
|
-
BLOCK_SIZE_M * BLOCK_SIZE_N <= MAX_BLOCK_SIZE
|
|
41
|
-
), f"Only support max dimension {MAX_BLOCK_SIZE // BLOCK_SIZE_M}"
|
|
42
|
-
num_warps = (
|
|
43
|
-
4
|
|
44
|
-
if BLOCK_SIZE_N < 2048
|
|
45
|
-
else (8 if BLOCK_SIZE_N < 8192 else (16 if BLOCK_SIZE_N < 128 * 1024 else 32))
|
|
46
|
-
)
|
|
47
|
-
lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
|
|
48
|
-
# Need this, otherwise Triton tries to launch from cuda:0 and we get
|
|
49
|
-
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
|
|
50
|
-
with torch.cuda.device(logits.device.index):
|
|
51
|
-
_lse_kernel[(triton.cdiv(n_rows, BLOCK_SIZE_M),)](
|
|
52
|
-
lse,
|
|
53
|
-
logits,
|
|
54
|
-
n_rows,
|
|
55
|
-
n_cols, # shapes
|
|
56
|
-
logits.stride(0), # strides
|
|
57
|
-
logits.stride(1),
|
|
58
|
-
BLOCK_SIZE_M=BLOCK_SIZE_M, # constants
|
|
59
|
-
BLOCK_SIZE_N=BLOCK_SIZE_N, # constants
|
|
60
|
-
num_warps=num_warps,
|
|
61
|
-
)
|
|
62
|
-
return lse
|
|
@@ -1,31 +0,0 @@
|
|
|
1
|
-
quack/__init__.py,sha256=AD0T-rBhSfKXpwZ6E4JIPiugvlFaAePjl-3pUhWOlPE,292
|
|
2
|
-
quack/autotuner.py,sha256=aF9-Cw47gaX7_LZvyVbLsj6Z2AWi4UZ-0Qwjy06Xd5I,10733
|
|
3
|
-
quack/cross_entropy.py,sha256=xsg2bXZ4wNvusBARhN4PwAzm5PbejEcfwj71nR7bzuE,20852
|
|
4
|
-
quack/cute_dsl_utils.py,sha256=LkNyFEKwYrgp-tLt_775EZWuBR3v7G80El3UAObHY2U,1292
|
|
5
|
-
quack/dense_gemm_sm100.py,sha256=W_j8BO-ilb1YUYFuclo7_itfPIRTkjPV_ittWgQy8t4,109937
|
|
6
|
-
quack/dense_gemm_sm90.py,sha256=Dff0GbIv92uTjrtsUE1GjVKCtwSf6_5KZbrqYZm-ZMY,110418
|
|
7
|
-
quack/fast_math.py,sha256=XqXVvKLSxXC3c9tIGLvKVRWdPsmjAa_O4C0plmsfZ0w,3106
|
|
8
|
-
quack/gemm_config.py,sha256=Gz4dkHH1Uwg9IdW-x5W_5tjdaFHBfxq4bn7hJx_xu5s,1789
|
|
9
|
-
quack/gemm_interface.py,sha256=XHgxo08d8LIu6dTlQKBOBJtjCegUB5uLh4k9hC-5mvY,9525
|
|
10
|
-
quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
|
|
11
|
-
quack/linear.py,sha256=Wd0KeXWvWjbkKrgW4Av1ud2v_mbhzf1RvubF7BYhcw4,6425
|
|
12
|
-
quack/lse.py,sha256=aANOleIYREyrkUQM9cfJ9Gt63eawMb2KVd7YAGWNoZU,2092
|
|
13
|
-
quack/mlp.py,sha256=D9V7aIfvoBMzhKwN8ZE6GlSOmwFJe_JGqgOvQprU0OQ,8224
|
|
14
|
-
quack/pipeline.py,sha256=SwvRZAR4RqYH60wAFC3OTu5DisN1XDMv5umQF4czJW4,5867
|
|
15
|
-
quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
|
|
16
|
-
quack/rmsnorm.py,sha256=bJEHqc8ila-LTGco-tNNCUyFBjJ2UdXeoMplYNJPXFI,32740
|
|
17
|
-
quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
|
|
18
|
-
quack/symmetric_dense_gemm_sm90.py,sha256=t-6eLasZwyu1NW4HpnvVBBPOvfqUzOg8VHe9sJQYdmg,88637
|
|
19
|
-
quack/tensormap_manager.py,sha256=pzBNwLCB8kV_yp8X8_BoDdtbwWeht2jrgRhyyfVIcMI,5261
|
|
20
|
-
quack/tile_scheduler.py,sha256=mImjD2LuIVchM6USJoJY4-CSG54jGuwyLIvFG6LTP9Y,42205
|
|
21
|
-
quack/topk.py,sha256=1pObblNJnxKLaE_T3qGvaMnUua0dqG2en9OU5PSp71s,9020
|
|
22
|
-
quack/utils.py,sha256=4ViEFgHecaX5wcYpO6XzTCzdnuZv2rniUJAJH5Ta0bA,24981
|
|
23
|
-
quack/sort/bitonic_sort.py,sha256=8t0SG1a6iEpYIlY8YM_AWvm4aN-4AA4vEzdBuJMJm9g,4768
|
|
24
|
-
quack/sort/generate_sorting_networks.py,sha256=vkJBOjTVEinQkWT4OtFqOWxFVdTIPoNAQocneKc9-rM,14477
|
|
25
|
-
quack/sort/sorting_networks.py,sha256=l_26zi3gXD_z-tnm2eAczRrmE-mbaz00KmqH6ONivL8,9686
|
|
26
|
-
quack/sort/utils.py,sha256=Mkr-l97RMAV-ZoNrwuzA1U3KO0Wjr38CV9Jm7ScyZoI,1090
|
|
27
|
-
quack_kernels-0.1.11.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
28
|
-
quack_kernels-0.1.11.dist-info/METADATA,sha256=WTYlk9lmhr4Jkin71stp3h-NrBdme-8OrBc7lAf4vSw,286
|
|
29
|
-
quack_kernels-0.1.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
30
|
-
quack_kernels-0.1.11.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
|
|
31
|
-
quack_kernels-0.1.11.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|