quack-kernels 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -3
- quack/activation.py +279 -0
- quack/autotuner.py +2 -1
- quack/cross_entropy.py +330 -184
- quack/cute_dsl_utils.py +83 -4
- quack/dense_gemm_sm100.py +1 -1
- quack/dense_gemm_sm90.py +911 -1140
- quack/fast_math.py +10 -27
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +43 -35
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +491 -243
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +6 -4
- quack/linear.py +128 -64
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +30 -160
- quack/pipeline.py +2 -17
- quack/reduce.py +240 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +614 -228
- quack/softmax.py +28 -16
- quack/symmetric_dense_gemm_sm90.py +6 -3
- quack/tensormap_manager.py +1 -0
- quack/tile_scheduler.py +64 -61
- quack/topk.py +14 -8
- quack/utils.py +14 -322
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/METADATA +3 -3
- quack_kernels-0.2.1.dist-info/RECORD +37 -0
- quack/lse.py +0 -62
- quack_kernels-0.1.11.dist-info/RECORD +0 -31
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/top_level.txt +0 -0
quack/utils.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
1
1
|
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
2
|
|
|
3
|
-
import operator
|
|
4
3
|
import math
|
|
5
|
-
from typing import
|
|
4
|
+
from typing import Optional, Tuple, Type, Union
|
|
6
5
|
|
|
7
6
|
import cutlass
|
|
8
7
|
import cutlass.cute as cute
|
|
@@ -23,46 +22,20 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te
|
|
|
23
22
|
)
|
|
24
23
|
|
|
25
24
|
|
|
26
|
-
@cute.jit
|
|
27
|
-
def warp_reduce(
|
|
28
|
-
val: cute.TensorSSA | cute.Numeric,
|
|
29
|
-
op: Callable,
|
|
30
|
-
width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
|
|
31
|
-
) -> cute.TensorSSA | cute.Numeric:
|
|
32
|
-
if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
|
|
33
|
-
res = cute.make_fragment(val.shape, val.dtype)
|
|
34
|
-
res.store(val)
|
|
35
|
-
for i in cutlass.range_constexpr(cute.size(val.shape)):
|
|
36
|
-
res[i] = warp_reduce(res[i], op, width)
|
|
37
|
-
return res.load()
|
|
38
|
-
else:
|
|
39
|
-
for i in cutlass.range_constexpr(int(math.log2(width))):
|
|
40
|
-
val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
|
|
41
|
-
return val
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
@cute.jit
|
|
45
|
-
def block_reduce(
|
|
46
|
-
val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0
|
|
47
|
-
) -> cute.Numeric:
|
|
48
|
-
"""reduction_buffer has shape (num_warps / warp_per_row, warps_per_row)"""
|
|
49
|
-
lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
|
|
50
|
-
warps_per_row = cute.size(reduction_buffer.shape[1])
|
|
51
|
-
row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
|
|
52
|
-
if lane_idx == 0:
|
|
53
|
-
reduction_buffer[row_idx, col_idx] = val
|
|
54
|
-
cute.arch.barrier()
|
|
55
|
-
block_reduce_val = init_val
|
|
56
|
-
if lane_idx < warps_per_row:
|
|
57
|
-
block_reduce_val = reduction_buffer[row_idx, lane_idx]
|
|
58
|
-
return warp_reduce(block_reduce_val, op)
|
|
59
|
-
|
|
60
|
-
|
|
61
25
|
@dsl_user_op
|
|
62
26
|
def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
|
|
63
27
|
return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
|
|
64
28
|
|
|
65
29
|
|
|
30
|
+
@cute.jit
|
|
31
|
+
def load_scalar_or_pointer(x: Float32 | cute.Pointer) -> Float32:
|
|
32
|
+
if cutlass.const_expr(isinstance(x, cute.Pointer)):
|
|
33
|
+
return Float32(cute.make_tensor(x, cute.make_layout(1))[0])
|
|
34
|
+
else:
|
|
35
|
+
assert isinstance(x, Float32)
|
|
36
|
+
return x
|
|
37
|
+
|
|
38
|
+
|
|
66
39
|
@dsl_user_op
|
|
67
40
|
def set_block_rank(
|
|
68
41
|
smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32, *, loc=None, ip=None
|
|
@@ -114,197 +87,6 @@ def store_shared_remote(
|
|
|
114
87
|
)
|
|
115
88
|
|
|
116
89
|
|
|
117
|
-
@cute.jit
|
|
118
|
-
def cluster_reduce(
|
|
119
|
-
val: cute.Numeric,
|
|
120
|
-
op: Callable,
|
|
121
|
-
reduction_buffer: cute.Tensor,
|
|
122
|
-
mbar_ptr: cute.Pointer,
|
|
123
|
-
init_val: cute.Numeric = 0.0,
|
|
124
|
-
phase: Optional[cutlass.Int32] = None,
|
|
125
|
-
) -> cute.Numeric:
|
|
126
|
-
"""reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
|
|
127
|
-
cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
|
|
128
|
-
lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
|
|
129
|
-
rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
|
|
130
|
-
row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
|
|
131
|
-
if warp_idx == 0:
|
|
132
|
-
with cute.arch.elect_one():
|
|
133
|
-
num_warps = rows_per_block * warps_per_row
|
|
134
|
-
cute.arch.mbarrier_arrive_and_expect_tx(
|
|
135
|
-
mbar_ptr,
|
|
136
|
-
num_warps * cluster_n * reduction_buffer.element_type.width // 8,
|
|
137
|
-
)
|
|
138
|
-
if lane_idx < cluster_n:
|
|
139
|
-
store_shared_remote(
|
|
140
|
-
val,
|
|
141
|
-
elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
|
|
142
|
-
mbar_ptr,
|
|
143
|
-
peer_cta_rank_in_cluster=lane_idx,
|
|
144
|
-
)
|
|
145
|
-
cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
|
|
146
|
-
block_reduce_val = init_val
|
|
147
|
-
num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
|
|
148
|
-
for i in cutlass.range_constexpr(num_iter):
|
|
149
|
-
idx = lane_idx + i * cute.arch.WARP_SIZE
|
|
150
|
-
if idx < cute.size(reduction_buffer, mode=[1]):
|
|
151
|
-
block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])
|
|
152
|
-
return warp_reduce(block_reduce_val, op)
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
@cute.jit
|
|
156
|
-
def block_or_cluster_reduce(
|
|
157
|
-
val: cute.Numeric,
|
|
158
|
-
op: Callable,
|
|
159
|
-
reduction_buffer: cute.Tensor,
|
|
160
|
-
mbar_ptr: Optional[cute.Pointer],
|
|
161
|
-
phase: Optional[cutlass.Int32] = None,
|
|
162
|
-
init_val: cute.Numeric = 0.0,
|
|
163
|
-
) -> cute.Numeric:
|
|
164
|
-
"""Perform either block or cluster reduction based on whether mbar_ptr is provided."""
|
|
165
|
-
if cutlass.const_expr(mbar_ptr is None):
|
|
166
|
-
return block_reduce(val, op, reduction_buffer, init_val=init_val)
|
|
167
|
-
else:
|
|
168
|
-
return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
@cute.jit
|
|
172
|
-
def row_reduce(
|
|
173
|
-
x: cute.TensorSSA | cute.Numeric,
|
|
174
|
-
op: cute.ReductionOp,
|
|
175
|
-
threads_per_row: cutlass.Constexpr[int],
|
|
176
|
-
reduction_buffer: Optional[cute.Tensor] = None,
|
|
177
|
-
mbar_ptr: Optional[cute.Pointer] = None,
|
|
178
|
-
phase: Optional[cutlass.Int32] = None,
|
|
179
|
-
init_val: cute.Numeric = 0.0,
|
|
180
|
-
hook_fn: Optional[Callable] = None,
|
|
181
|
-
) -> cute.Numeric:
|
|
182
|
-
"""reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
|
|
183
|
-
if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
|
|
184
|
-
val = x.reduce(op, init_val=init_val, reduction_profile=0)
|
|
185
|
-
else:
|
|
186
|
-
val = x
|
|
187
|
-
warp_op = {
|
|
188
|
-
cute.ReductionOp.ADD: operator.add,
|
|
189
|
-
cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
|
|
190
|
-
cute.ReductionOp.MIN: min,
|
|
191
|
-
cute.ReductionOp.MUL: operator.mul,
|
|
192
|
-
}[op]
|
|
193
|
-
val = warp_reduce(
|
|
194
|
-
val,
|
|
195
|
-
warp_op,
|
|
196
|
-
width=min(threads_per_row, cute.arch.WARP_SIZE),
|
|
197
|
-
)
|
|
198
|
-
if cutlass.const_expr(hook_fn is not None):
|
|
199
|
-
hook_fn()
|
|
200
|
-
if cutlass.const_expr(reduction_buffer is not None):
|
|
201
|
-
warps_per_row, cluster_n = reduction_buffer.shape[1]
|
|
202
|
-
assert cluster_n == 1 or mbar_ptr is not None, (
|
|
203
|
-
"mbar_ptr must be provided for cluster reduction"
|
|
204
|
-
)
|
|
205
|
-
if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
|
|
206
|
-
val = block_or_cluster_reduce(
|
|
207
|
-
val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
|
|
208
|
-
)
|
|
209
|
-
return val
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
@cute.jit
|
|
213
|
-
def online_softmax_reduce(
|
|
214
|
-
x: cute.TensorSSA,
|
|
215
|
-
threads_per_row: cutlass.Constexpr[int],
|
|
216
|
-
reduction_buffer: Optional[cute.Tensor] = None,
|
|
217
|
-
mbar_ptr: Optional[cute.Pointer] = None,
|
|
218
|
-
hook_fn: Optional[Callable] = None,
|
|
219
|
-
phase: Optional[cutlass.Int32] = None,
|
|
220
|
-
return_exp_x: bool = False,
|
|
221
|
-
) -> [Float32, Float32, Optional[cute.TensorSSA]]:
|
|
222
|
-
assert x.dtype == Float32, "x must be of type Float32"
|
|
223
|
-
"""reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2)"""
|
|
224
|
-
max_x = warp_reduce(
|
|
225
|
-
x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
|
|
226
|
-
cute.arch.fmax,
|
|
227
|
-
width=min(threads_per_row, cute.arch.WARP_SIZE),
|
|
228
|
-
)
|
|
229
|
-
log2_e = math.log2(math.e)
|
|
230
|
-
exp_x = exp2f(x * log2_e - (max_x * log2_e))
|
|
231
|
-
# exp_x = exp2f((x - max_x) * log2_e)
|
|
232
|
-
sum_exp_x = warp_reduce(
|
|
233
|
-
exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
|
|
234
|
-
operator.add,
|
|
235
|
-
width=min(threads_per_row, cute.arch.WARP_SIZE),
|
|
236
|
-
)
|
|
237
|
-
if cutlass.const_expr(hook_fn is not None):
|
|
238
|
-
hook_fn()
|
|
239
|
-
if cutlass.const_expr(reduction_buffer is not None):
|
|
240
|
-
rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
|
|
241
|
-
assert cluster_n == 1 or mbar_ptr is not None, (
|
|
242
|
-
"mbar_ptr must be provided for cluster reduction"
|
|
243
|
-
)
|
|
244
|
-
if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
|
|
245
|
-
assert reduction_buffer.element_type == cutlass.Int64, (
|
|
246
|
-
"reduction_buffer must be of type cute.Int64"
|
|
247
|
-
)
|
|
248
|
-
lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
|
|
249
|
-
row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
|
|
250
|
-
if cutlass.const_expr(mbar_ptr is None):
|
|
251
|
-
if lane_idx == 0:
|
|
252
|
-
reduction_buffer[row_idx, col_idx] = f32x2_to_i64(max_x, sum_exp_x)
|
|
253
|
-
cute.arch.barrier()
|
|
254
|
-
max_x_single_warp = -Float32.inf
|
|
255
|
-
sum_exp_x = 0.0
|
|
256
|
-
if lane_idx < warps_per_row:
|
|
257
|
-
max_x_single_warp, sum_exp_x = i64_to_f32x2(reduction_buffer[row_idx, lane_idx])
|
|
258
|
-
max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
|
|
259
|
-
sum_exp_x *= exp2f((max_x_single_warp - max_x_final) * log2_e)
|
|
260
|
-
sum_exp_x = warp_reduce(sum_exp_x, operator.add)
|
|
261
|
-
if cutlass.const_expr(return_exp_x):
|
|
262
|
-
exp_x *= exp2f((max_x - max_x_final) * log2_e)
|
|
263
|
-
max_x = max_x_final
|
|
264
|
-
else:
|
|
265
|
-
cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
|
|
266
|
-
if warp_idx == 0:
|
|
267
|
-
with cute.arch.elect_one():
|
|
268
|
-
num_warps = rows_per_block * warps_per_row
|
|
269
|
-
cute.arch.mbarrier_arrive_and_expect_tx(
|
|
270
|
-
mbar_ptr,
|
|
271
|
-
num_warps * cluster_n * reduction_buffer.element_type.width // 8,
|
|
272
|
-
)
|
|
273
|
-
if lane_idx < cluster_n:
|
|
274
|
-
store_shared_remote(
|
|
275
|
-
f32x2_to_i64(max_x, sum_exp_x),
|
|
276
|
-
elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
|
|
277
|
-
mbar_ptr,
|
|
278
|
-
peer_cta_rank_in_cluster=lane_idx,
|
|
279
|
-
)
|
|
280
|
-
cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
|
|
281
|
-
num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
|
|
282
|
-
max_x_single_warp = cute.make_fragment(num_iter, Float32)
|
|
283
|
-
max_x_single_warp.fill(-Float32.inf)
|
|
284
|
-
sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
|
|
285
|
-
sum_exp_x_single_warp.fill(0.0)
|
|
286
|
-
for i in cutlass.range_constexpr(num_iter):
|
|
287
|
-
idx = lane_idx + i * cute.arch.WARP_SIZE
|
|
288
|
-
if idx < cute.size(reduction_buffer, mode=[1]):
|
|
289
|
-
max_x_single_warp[i], sum_exp_x_single_warp[i] = i64_to_f32x2(
|
|
290
|
-
reduction_buffer[row_idx, idx]
|
|
291
|
-
)
|
|
292
|
-
max_x_final = max_x_single_warp.load().reduce(
|
|
293
|
-
cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0
|
|
294
|
-
)
|
|
295
|
-
max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
|
|
296
|
-
sum_exp_x = 0.0
|
|
297
|
-
for i in cutlass.range_constexpr(num_iter):
|
|
298
|
-
sum_exp_x += sum_exp_x_single_warp[i] * exp2f(
|
|
299
|
-
(max_x_single_warp[i] - max_x_final) * log2_e
|
|
300
|
-
)
|
|
301
|
-
sum_exp_x = warp_reduce(sum_exp_x, operator.add)
|
|
302
|
-
if cutlass.const_expr(return_exp_x):
|
|
303
|
-
exp_x *= exp2f((max_x - max_x_final) * log2_e)
|
|
304
|
-
max_x = max_x_final
|
|
305
|
-
return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
|
|
306
|
-
|
|
307
|
-
|
|
308
90
|
@dsl_user_op
|
|
309
91
|
def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None) -> Float32:
|
|
310
92
|
return Float32(
|
|
@@ -318,84 +100,6 @@ def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=Non
|
|
|
318
100
|
)
|
|
319
101
|
|
|
320
102
|
|
|
321
|
-
@cute.jit
|
|
322
|
-
def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
|
|
323
|
-
"""exp2f calculation for both vector and scalar.
|
|
324
|
-
:param x: input value
|
|
325
|
-
:type x: cute.TensorSSA or Float32
|
|
326
|
-
:return: exp2 value
|
|
327
|
-
:rtype: cute.TensorSSA or Float32
|
|
328
|
-
"""
|
|
329
|
-
if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
|
|
330
|
-
res = cute.make_fragment(x.shape, Float32)
|
|
331
|
-
res.store(x)
|
|
332
|
-
for i in cutlass.range(cute.size(x.shape), unroll_full=True):
|
|
333
|
-
res[i] = cute.arch.exp2(res[i])
|
|
334
|
-
return res.load()
|
|
335
|
-
else:
|
|
336
|
-
return cute.arch.exp2(x)
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
@dsl_user_op
|
|
340
|
-
def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
|
341
|
-
return Float32(
|
|
342
|
-
llvm.inline_asm(
|
|
343
|
-
T.f32(),
|
|
344
|
-
[Float32(a).ir_value(loc=loc, ip=ip)],
|
|
345
|
-
"lg2.approx.ftz.f32 $0, $1;",
|
|
346
|
-
"=f,f",
|
|
347
|
-
has_side_effects=False,
|
|
348
|
-
is_align_stack=False,
|
|
349
|
-
asm_dialect=llvm.AsmDialect.AD_ATT,
|
|
350
|
-
)
|
|
351
|
-
)
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
@dsl_user_op
|
|
355
|
-
def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
|
356
|
-
return Float32(
|
|
357
|
-
llvm.inline_asm(
|
|
358
|
-
T.f32(),
|
|
359
|
-
[Float32(a).ir_value(loc=loc, ip=ip)],
|
|
360
|
-
"sqrt.approx.ftz.f32 $0, $1;",
|
|
361
|
-
"=f,f",
|
|
362
|
-
has_side_effects=False,
|
|
363
|
-
is_align_stack=False,
|
|
364
|
-
asm_dialect=llvm.AsmDialect.AD_ATT,
|
|
365
|
-
)
|
|
366
|
-
)
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
@dsl_user_op
|
|
370
|
-
def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
|
371
|
-
return Float32(
|
|
372
|
-
llvm.inline_asm(
|
|
373
|
-
T.f32(),
|
|
374
|
-
[Float32(a).ir_value(loc=loc, ip=ip)],
|
|
375
|
-
"rsqrt.approx.ftz.f32 $0, $1;",
|
|
376
|
-
"=f,f",
|
|
377
|
-
has_side_effects=False,
|
|
378
|
-
is_align_stack=False,
|
|
379
|
-
asm_dialect=llvm.AsmDialect.AD_ATT,
|
|
380
|
-
)
|
|
381
|
-
)
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
@dsl_user_op
|
|
385
|
-
def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
|
386
|
-
return Float32(
|
|
387
|
-
llvm.inline_asm(
|
|
388
|
-
T.f32(),
|
|
389
|
-
[Float32(a).ir_value(loc=loc, ip=ip)],
|
|
390
|
-
"tanh.approx.f32 $0, $1;",
|
|
391
|
-
"=f,f",
|
|
392
|
-
has_side_effects=False,
|
|
393
|
-
is_align_stack=False,
|
|
394
|
-
asm_dialect=llvm.AsmDialect.AD_ATT,
|
|
395
|
-
)
|
|
396
|
-
)
|
|
397
|
-
|
|
398
|
-
|
|
399
103
|
@dsl_user_op
|
|
400
104
|
def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
|
|
401
105
|
return Int32(
|
|
@@ -411,16 +115,6 @@ def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
|
|
|
411
115
|
)
|
|
412
116
|
|
|
413
117
|
|
|
414
|
-
@dsl_user_op
|
|
415
|
-
def silu(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
|
416
|
-
"""
|
|
417
|
-
silu(a) = a * sigmoid(a) = a * (1 + tanh(a / 2)) / 2 = (0.5 * a) * tanh(0.5 * a) + (0.5 * a)
|
|
418
|
-
This compiles down to 3 SASS instructions: FMUL to get 0.5 * a, MUFU.TANH, and FFMA.
|
|
419
|
-
"""
|
|
420
|
-
a_half = 0.5 * a
|
|
421
|
-
return a_half * tanh(a_half) + a_half
|
|
422
|
-
|
|
423
|
-
|
|
424
118
|
@dsl_user_op
|
|
425
119
|
def prmt(a: int | Int32, b: int | Int32, c: int | Int32, *, loc=None, ip=None) -> Int32:
|
|
426
120
|
return Int32(
|
|
@@ -498,7 +192,7 @@ def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Nu
|
|
|
498
192
|
tXpX: Predicate tensor indicating valid elements
|
|
499
193
|
fill_value: Value to fill OOB locations with
|
|
500
194
|
"""
|
|
501
|
-
tXrX_fill = cute.make_fragment_like(tXsX[(None, 0),
|
|
195
|
+
tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), None, 0])
|
|
502
196
|
tXrX_fill.fill(fill_value)
|
|
503
197
|
for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
|
|
504
198
|
for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
|
|
@@ -538,9 +232,9 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
|
|
|
538
232
|
def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
|
|
539
233
|
flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
|
|
540
234
|
flat_stride = cute.flatten_to_tuple(tensor.stride)
|
|
541
|
-
assert len(flat_coord_i64) == len(
|
|
542
|
-
|
|
543
|
-
)
|
|
235
|
+
assert len(flat_coord_i64) == len(
|
|
236
|
+
flat_stride
|
|
237
|
+
), "Coordinate and stride must have the same length"
|
|
544
238
|
offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
|
|
545
239
|
assert isinstance(tensor.iterator, cute.Pointer)
|
|
546
240
|
# HACK: we assume that applying the offset does not change the pointer alignment
|
|
@@ -662,5 +356,3 @@ def atomic_inc_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None)
|
|
|
662
356
|
return nvvm.atomicrmw(
|
|
663
357
|
res=T.i32(), op=nvvm.AtomicOpKind.INC, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
|
|
664
358
|
)
|
|
665
|
-
|
|
666
|
-
|
quack/varlen_utils.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao.
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
import cutlass.cute as cute
|
|
7
|
+
|
|
8
|
+
from quack.cute_dsl_utils import ArgumentsBase
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Grouping arguments together that should be passed to __call__
|
|
12
|
+
@dataclass
|
|
13
|
+
class VarlenArguments(ArgumentsBase):
|
|
14
|
+
mCuSeqlensM: Optional[cute.Tensor] = None
|
|
15
|
+
mCuSeqlensK: Optional[cute.Tensor] = None
|
|
16
|
+
mTensormaps: Optional[cute.Tensor] = None
|
|
17
|
+
|
|
18
|
+
def __post_init__(self):
|
|
19
|
+
if self.mCuSeqlensM is not None or self.mCuSeqlensK is not None:
|
|
20
|
+
assert (
|
|
21
|
+
self.mTensormaps is not None
|
|
22
|
+
), "mTensormaps must be provided if mCuSeqlensM or mCuSeqlensK is provided"
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: quack-kernels
|
|
3
|
-
Version: 0.1
|
|
4
|
-
Requires-Python: >=3.
|
|
3
|
+
Version: 0.2.1
|
|
4
|
+
Requires-Python: >=3.10
|
|
5
5
|
License-File: LICENSE
|
|
6
|
-
Requires-Dist: nvidia-cutlass-dsl==4.
|
|
6
|
+
Requires-Dist: nvidia-cutlass-dsl==4.2.0
|
|
7
7
|
Requires-Dist: torch
|
|
8
8
|
Provides-Extra: dev
|
|
9
9
|
Requires-Dist: pre-commit; extra == "dev"
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
quack/__init__.py,sha256=H1m0CnfPidSSmprZeTGJc8LVh7stdBPmPLEuZwgN_7M,364
|
|
2
|
+
quack/activation.py,sha256=SzQDUCB-kccqsy1aYUrHYJ2cGxKMXxxqpjJaJoqBYaE,10017
|
|
3
|
+
quack/autotuner.py,sha256=czO6JrYL0EJpOeJOYDSsVdrJaFuwfL3vTdG8QfL1F34,10792
|
|
4
|
+
quack/cross_entropy.py,sha256=TE8j21c-7E4cInKtFjcKsgKXNhKCRFkNfhCJpgpasj8,28409
|
|
5
|
+
quack/cute_dsl_utils.py,sha256=D2Pw7rzX9jY8u8wikIPvPvinmFLCDeZg95HPBLqGej4,4635
|
|
6
|
+
quack/dense_gemm_sm100.py,sha256=hKBNC34UxdctrTKVP68nvANZl4Dq2rnUjRcweESEq3g,109965
|
|
7
|
+
quack/dense_gemm_sm90.py,sha256=TjnjHnjhAwWH5YQWsFlADq07xSxtsprkw_p2Cy0yw7I,100407
|
|
8
|
+
quack/fast_math.py,sha256=E1XUqfUt0_n9BPZNggF-UDzZ6anso9bYUrwqafemWvQ,2297
|
|
9
|
+
quack/gemm_act_sm90.py,sha256=N5UAFWZvw1na22Vh5JSGgcdqZ2zI6kQMBVOLxYbCAUU,14332
|
|
10
|
+
quack/gemm_config.py,sha256=gbYjPFeyT5wAhVwFQroRHlHoMKEJqAWX9P8wWy04l8Q,2258
|
|
11
|
+
quack/gemm_dact_sm90.py,sha256=KCXgjOzdamSDexwrwf_pX2r-ippPRirbClrlU6BP7b8,4990
|
|
12
|
+
quack/gemm_interface.py,sha256=_JTpE7zQw6NUw-v65Wql_XUOZBfW0oSEgiMnharTJU4,20501
|
|
13
|
+
quack/gemm_wrapper_utils.py,sha256=aMMtu-Ojhtjay_5xJH4AjP-JRVks1AB8jmtNme_DIqU,5960
|
|
14
|
+
quack/layernorm.py,sha256=AOe95-YqhFPw96x8pJq7FfBe26ROX9ZTvH025lM1ILs,13579
|
|
15
|
+
quack/linear.py,sha256=SrhRiAFjC7ONIMVmiNu-kSPLHNUyaCXt59a1f_5nNXo,9383
|
|
16
|
+
quack/linear_cross_entropy.py,sha256=Zhy_gdMsKHOie-jntBaqIuiDJtkiq6qEBwnyuWwIRw4,10092
|
|
17
|
+
quack/mlp.py,sha256=YjdwQRwEePA9KyidFXp5H1-lxiJc8dZ41vl8Fv8pgss,2259
|
|
18
|
+
quack/pipeline.py,sha256=DyCwZX8WvoUBFcMBz7CeYm9VUM31haEGgBhAzmxu8cE,5519
|
|
19
|
+
quack/reduce.py,sha256=0hRFMFfn6xC5QLk32Qmgc17XVkQ1yKC-3TfksccSBaU,10341
|
|
20
|
+
quack/reduction_base.py,sha256=CT-t_j7z8H1ByD9FkQYDRik_-THMDFv9QoXHmr9Xx9E,3636
|
|
21
|
+
quack/rmsnorm.py,sha256=PrW2zuaQs_Gr6g8B6DMsGSJFZdEsWf32if_EwUR_IDQ,49386
|
|
22
|
+
quack/softmax.py,sha256=WFWtgc40iLPFBpdStBBTC9803Npnv9rZjOzb_nK-RDs,17110
|
|
23
|
+
quack/symmetric_dense_gemm_sm90.py,sha256=2UXooIpClT2izdyGis1XaIgYYlLj-7MrcOMg2yR7YCk,88694
|
|
24
|
+
quack/tensormap_manager.py,sha256=Ts3Mxp0_es2RNA0ffvUjWMXN79lsfWEBZ0DQYhtbcnw,5338
|
|
25
|
+
quack/tile_scheduler.py,sha256=BQ-SeW5wxulKuwmpq0CAIjkuirv4KWdUdoIGQB88aGE,42319
|
|
26
|
+
quack/topk.py,sha256=RQl-23lIicQ9ry9Njur8i0JGem_WbO_Gchr6jy8EtVM,9185
|
|
27
|
+
quack/utils.py,sha256=wOgNw9VL40FCsLwN52juPfk48zVpX-rta3MQhAQe8Wc,12767
|
|
28
|
+
quack/varlen_utils.py,sha256=vkduMEpo5bJJvZRNnIcKPb6pp1wD34vaIpMIB0ZGIZA,681
|
|
29
|
+
quack/sort/bitonic_sort.py,sha256=8t0SG1a6iEpYIlY8YM_AWvm4aN-4AA4vEzdBuJMJm9g,4768
|
|
30
|
+
quack/sort/generate_sorting_networks.py,sha256=vkJBOjTVEinQkWT4OtFqOWxFVdTIPoNAQocneKc9-rM,14477
|
|
31
|
+
quack/sort/sorting_networks.py,sha256=l_26zi3gXD_z-tnm2eAczRrmE-mbaz00KmqH6ONivL8,9686
|
|
32
|
+
quack/sort/utils.py,sha256=Mkr-l97RMAV-ZoNrwuzA1U3KO0Wjr38CV9Jm7ScyZoI,1090
|
|
33
|
+
quack_kernels-0.2.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
34
|
+
quack_kernels-0.2.1.dist-info/METADATA,sha256=_AFigx6aFt-25GzUP6YWalDBwHvwzgK9EU85WjZXvsI,285
|
|
35
|
+
quack_kernels-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
36
|
+
quack_kernels-0.2.1.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
|
|
37
|
+
quack_kernels-0.2.1.dist-info/RECORD,,
|
quack/lse.py
DELETED
|
@@ -1,62 +0,0 @@
|
|
|
1
|
-
# Copyright (c) 2025, Tri Dao.
|
|
2
|
-
# TODO: we probably dont' need this kernel, just use torch.logsumexp
|
|
3
|
-
import torch
|
|
4
|
-
|
|
5
|
-
import triton
|
|
6
|
-
import triton.language as tl
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
@triton.jit
|
|
10
|
-
def _lse_kernel(
|
|
11
|
-
lse_ptr,
|
|
12
|
-
logits_ptr,
|
|
13
|
-
n_rows,
|
|
14
|
-
n_cols,
|
|
15
|
-
logits_row_stride,
|
|
16
|
-
logits_col_stride,
|
|
17
|
-
BLOCK_SIZE_M: tl.constexpr,
|
|
18
|
-
BLOCK_SIZE_N: tl.constexpr,
|
|
19
|
-
):
|
|
20
|
-
row_start = tl.program_id(0) * BLOCK_SIZE_M
|
|
21
|
-
rows = row_start + tl.arange(0, BLOCK_SIZE_M)
|
|
22
|
-
cols = tl.arange(0, BLOCK_SIZE_N)
|
|
23
|
-
logits = tl.load(
|
|
24
|
-
logits_ptr + rows[:, None] * logits_row_stride + cols[None, :] * logits_col_stride,
|
|
25
|
-
mask=(rows[:, None] < n_rows) & (cols[None, :] < n_cols),
|
|
26
|
-
other=-float("inf"),
|
|
27
|
-
).to(tl.float32)
|
|
28
|
-
m = tl.max(logits, 1)
|
|
29
|
-
lse = tl.log(tl.sum(tl.exp(logits - m[:, None]), 1)) + m
|
|
30
|
-
tl.store(lse_ptr + rows, lse, mask=rows < n_rows)
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
def logsumexp(logits):
|
|
34
|
-
n_rows, n_cols = logits.shape
|
|
35
|
-
BLOCK_SIZE_M = 32 if logits.stride(1) != 1 else 1
|
|
36
|
-
MAX_BLOCK_SIZE = 64 * 1024
|
|
37
|
-
# BLOCK_SIZE_N = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE // BLOCK_SIZE_M)
|
|
38
|
-
BLOCK_SIZE_N = triton.next_power_of_2(n_cols)
|
|
39
|
-
assert (
|
|
40
|
-
BLOCK_SIZE_M * BLOCK_SIZE_N <= MAX_BLOCK_SIZE
|
|
41
|
-
), f"Only support max dimension {MAX_BLOCK_SIZE // BLOCK_SIZE_M}"
|
|
42
|
-
num_warps = (
|
|
43
|
-
4
|
|
44
|
-
if BLOCK_SIZE_N < 2048
|
|
45
|
-
else (8 if BLOCK_SIZE_N < 8192 else (16 if BLOCK_SIZE_N < 128 * 1024 else 32))
|
|
46
|
-
)
|
|
47
|
-
lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
|
|
48
|
-
# Need this, otherwise Triton tries to launch from cuda:0 and we get
|
|
49
|
-
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
|
|
50
|
-
with torch.cuda.device(logits.device.index):
|
|
51
|
-
_lse_kernel[(triton.cdiv(n_rows, BLOCK_SIZE_M),)](
|
|
52
|
-
lse,
|
|
53
|
-
logits,
|
|
54
|
-
n_rows,
|
|
55
|
-
n_cols, # shapes
|
|
56
|
-
logits.stride(0), # strides
|
|
57
|
-
logits.stride(1),
|
|
58
|
-
BLOCK_SIZE_M=BLOCK_SIZE_M, # constants
|
|
59
|
-
BLOCK_SIZE_N=BLOCK_SIZE_N, # constants
|
|
60
|
-
num_warps=num_warps,
|
|
61
|
-
)
|
|
62
|
-
return lse
|
|
@@ -1,31 +0,0 @@
|
|
|
1
|
-
quack/__init__.py,sha256=AD0T-rBhSfKXpwZ6E4JIPiugvlFaAePjl-3pUhWOlPE,292
|
|
2
|
-
quack/autotuner.py,sha256=aF9-Cw47gaX7_LZvyVbLsj6Z2AWi4UZ-0Qwjy06Xd5I,10733
|
|
3
|
-
quack/cross_entropy.py,sha256=xsg2bXZ4wNvusBARhN4PwAzm5PbejEcfwj71nR7bzuE,20852
|
|
4
|
-
quack/cute_dsl_utils.py,sha256=LkNyFEKwYrgp-tLt_775EZWuBR3v7G80El3UAObHY2U,1292
|
|
5
|
-
quack/dense_gemm_sm100.py,sha256=W_j8BO-ilb1YUYFuclo7_itfPIRTkjPV_ittWgQy8t4,109937
|
|
6
|
-
quack/dense_gemm_sm90.py,sha256=Dff0GbIv92uTjrtsUE1GjVKCtwSf6_5KZbrqYZm-ZMY,110418
|
|
7
|
-
quack/fast_math.py,sha256=XqXVvKLSxXC3c9tIGLvKVRWdPsmjAa_O4C0plmsfZ0w,3106
|
|
8
|
-
quack/gemm_config.py,sha256=Gz4dkHH1Uwg9IdW-x5W_5tjdaFHBfxq4bn7hJx_xu5s,1789
|
|
9
|
-
quack/gemm_interface.py,sha256=XHgxo08d8LIu6dTlQKBOBJtjCegUB5uLh4k9hC-5mvY,9525
|
|
10
|
-
quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
|
|
11
|
-
quack/linear.py,sha256=Wd0KeXWvWjbkKrgW4Av1ud2v_mbhzf1RvubF7BYhcw4,6425
|
|
12
|
-
quack/lse.py,sha256=aANOleIYREyrkUQM9cfJ9Gt63eawMb2KVd7YAGWNoZU,2092
|
|
13
|
-
quack/mlp.py,sha256=D9V7aIfvoBMzhKwN8ZE6GlSOmwFJe_JGqgOvQprU0OQ,8224
|
|
14
|
-
quack/pipeline.py,sha256=SwvRZAR4RqYH60wAFC3OTu5DisN1XDMv5umQF4czJW4,5867
|
|
15
|
-
quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
|
|
16
|
-
quack/rmsnorm.py,sha256=bJEHqc8ila-LTGco-tNNCUyFBjJ2UdXeoMplYNJPXFI,32740
|
|
17
|
-
quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
|
|
18
|
-
quack/symmetric_dense_gemm_sm90.py,sha256=t-6eLasZwyu1NW4HpnvVBBPOvfqUzOg8VHe9sJQYdmg,88637
|
|
19
|
-
quack/tensormap_manager.py,sha256=pzBNwLCB8kV_yp8X8_BoDdtbwWeht2jrgRhyyfVIcMI,5261
|
|
20
|
-
quack/tile_scheduler.py,sha256=mImjD2LuIVchM6USJoJY4-CSG54jGuwyLIvFG6LTP9Y,42205
|
|
21
|
-
quack/topk.py,sha256=1pObblNJnxKLaE_T3qGvaMnUua0dqG2en9OU5PSp71s,9020
|
|
22
|
-
quack/utils.py,sha256=4ViEFgHecaX5wcYpO6XzTCzdnuZv2rniUJAJH5Ta0bA,24981
|
|
23
|
-
quack/sort/bitonic_sort.py,sha256=8t0SG1a6iEpYIlY8YM_AWvm4aN-4AA4vEzdBuJMJm9g,4768
|
|
24
|
-
quack/sort/generate_sorting_networks.py,sha256=vkJBOjTVEinQkWT4OtFqOWxFVdTIPoNAQocneKc9-rM,14477
|
|
25
|
-
quack/sort/sorting_networks.py,sha256=l_26zi3gXD_z-tnm2eAczRrmE-mbaz00KmqH6ONivL8,9686
|
|
26
|
-
quack/sort/utils.py,sha256=Mkr-l97RMAV-ZoNrwuzA1U3KO0Wjr38CV9Jm7ScyZoI,1090
|
|
27
|
-
quack_kernels-0.1.11.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
28
|
-
quack_kernels-0.1.11.dist-info/METADATA,sha256=WTYlk9lmhr4Jkin71stp3h-NrBdme-8OrBc7lAf4vSw,286
|
|
29
|
-
quack_kernels-0.1.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
30
|
-
quack_kernels-0.1.11.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
|
|
31
|
-
quack_kernels-0.1.11.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|