quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +8 -1
- quack/activation.py +288 -0
- quack/autotuner.py +310 -0
- quack/cross_entropy.py +325 -175
- quack/cute_dsl_utils.py +119 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +1657 -842
- quack/fast_math.py +80 -0
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +69 -0
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +569 -0
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +5 -3
- quack/linear.py +240 -0
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +74 -0
- quack/pipeline.py +151 -0
- quack/reduce.py +241 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +583 -231
- quack/softmax.py +27 -15
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2091 -0
- quack/tensormap_manager.py +115 -0
- quack/tile_scheduler.py +937 -0
- quack/topk.py +227 -0
- quack/utils.py +203 -230
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/METADATA +2 -2
- quack_kernels-0.2.0.dist-info/RECORD +37 -0
- quack_kernels-0.1.10.dist-info/RECORD +0 -13
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/top_level.txt +0 -0
quack/utils.py
CHANGED
@@ -1,15 +1,14 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

-import operator
 import math
-from typing import
+from typing import Optional, Tuple, Type, Union

 import cutlass
 import cutlass.cute as cute

-from cutlass import Float32
+from cutlass import Float32, Int32
 from cutlass.cutlass_dsl import T, dsl_user_op
-from cutlass._mlir.dialects import llvm, vector
+from cutlass._mlir.dialects import llvm, nvvm, vector
 from cutlass.cute.runtime import from_dlpack


@@ -23,46 +22,20 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
     )


-@cute.jit
-def warp_reduce(
-    val: cute.TensorSSA | cute.Numeric,
-    op: Callable,
-    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
-) -> cute.TensorSSA | cute.Numeric:
-    if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
-        res = cute.make_fragment(val.shape, val.dtype)
-        res.store(val)
-        for i in cutlass.range_constexpr(cute.size(val.shape)):
-            res[i] = warp_reduce(res[i], op, width)
-        return res.load()
-    else:
-        for i in cutlass.range_constexpr(int(math.log2(width))):
-            val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
-        return val
-
-
-@cute.jit
-def block_reduce(
-    val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0
-) -> cute.Numeric:
-    """reduction_buffer has shape (num_warps / warp_per_row, warps_per_row)"""
-    lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
-    warps_per_row = cute.size(reduction_buffer.shape[1])
-    row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
-    if lane_idx == 0:
-        reduction_buffer[row_idx, col_idx] = val
-    cute.arch.barrier()
-    block_reduce_val = init_val
-    if lane_idx < warps_per_row:
-        block_reduce_val = reduction_buffer[row_idx, lane_idx]
-    return warp_reduce(block_reduce_val, op)
-
-
 @dsl_user_op
 def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
     return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)


+@cute.jit
+def load_scalar_or_pointer(x: Float32 | cute.Pointer) -> Float32:
+    if cutlass.const_expr(isinstance(x, cute.Pointer)):
+        return Float32(cute.make_tensor(x, cute.make_layout(1))[0])
+    else:
+        assert isinstance(x, Float32)
+        return x
+
+
 @dsl_user_op
 def set_block_rank(
     smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32, *, loc=None, ip=None
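
Note: the warp_reduce and block_reduce helpers deleted in this hunk do not disappear from the package; they presumably move into the new quack/reduce.py (+241 lines in the summary above). To clarify what the deleted warp_reduce computed, here is a minimal pure-Python model of its butterfly (XOR) shuffle reduction; simulate_warp_reduce is a hypothetical name, and the list stands in for the 32 lanes of a warp:

import math
import operator

def simulate_warp_reduce(vals, op=operator.add, width=32):
    lanes = list(vals)
    assert len(lanes) == width
    for i in range(int(math.log2(width))):
        # shuffle_sync_bfly: each lane reads the value held by lane (lane ^ (1 << i))
        partner = [lanes[lane ^ (1 << i)] for lane in range(width)]
        lanes = [op(v, p) for v, p in zip(lanes, partner)]
    return lanes  # after log2(width) steps, every lane holds the full reduction

assert simulate_warp_reduce(list(range(32)))[0] == sum(range(32))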
@@ -100,208 +73,31 @@ def store_shared_remote(
     ).ir_value()
     if cutlass.const_expr(isinstance(val, float)):
         val = Float32(val)
-    assert isinstance(val, (Float32, cutlass.Int64)), "val must be Float32 or Int64"
-    suffix = "f32"
+    assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64"
+    suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)]
+    constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)]
     llvm.inline_asm(
         None,
         [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
         f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];",
-        f"r,{
+        f"r,{constraint},r",
         has_side_effects=True,
         is_align_stack=False,
         asm_dialect=llvm.AsmDialect.AD_ATT,
     )


-@cute.jit
-def cluster_reduce(
-    val: cute.Numeric,
-    op: Callable,
-    reduction_buffer: cute.Tensor,
-    mbar_ptr: cute.Pointer,
-    phase: Optional[cutlass.Int32] = None,
-    init_val: cute.Numeric = 0.0,
-) -> cute.Numeric:
-    """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
-    cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
-    lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
-    rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
-    row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
-    if warp_idx == 0:
-        with cute.arch.elect_one():
-            num_warps = rows_per_block * warps_per_row
-            cute.arch.mbarrier_arrive_and_expect_tx(
-                mbar_ptr,
-                num_warps * cluster_n * reduction_buffer.element_type.width // 8,
-            )
-    if lane_idx < cluster_n:
-        store_shared_remote(
-            val,
-            elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
-            mbar_ptr,
-            peer_cta_rank_in_cluster=lane_idx,
+@dsl_user_op
+def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None) -> Float32:
+    return Float32(
+        nvvm.fmin(
+            T.f32(),
+            Float32(a).ir_value(loc=loc, ip=ip),
+            Float32(b).ir_value(loc=loc, ip=ip),
+            loc=loc,
+            ip=ip,
         )
-    cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
-    block_reduce_val = init_val
-    num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
-    for i in cutlass.range_constexpr(num_iter):
-        idx = lane_idx + i * cute.arch.WARP_SIZE
-        if idx < cute.size(reduction_buffer, mode=[1]):
-            block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])
-    return warp_reduce(block_reduce_val, op)
-
-
-@cute.jit
-def block_or_cluster_reduce(
-    val: cute.Numeric,
-    op: Callable,
-    reduction_buffer: cute.Tensor,
-    mbar_ptr: Optional[cute.Pointer],
-    phase: Optional[cutlass.Int32] = None,
-    init_val: cute.Numeric = 0.0,
-) -> cute.Numeric:
-    """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
-    if cutlass.const_expr(mbar_ptr is None):
-        return block_reduce(val, op, reduction_buffer, init_val=init_val)
-    else:
-        return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)
-
-
-@cute.jit
-def row_reduce(
-    x: cute.TensorSSA | cute.Numeric,
-    op: cute.ReductionOp,
-    threads_per_row: cutlass.Constexpr[int],
-    reduction_buffer: Optional[cute.Tensor] = None,
-    mbar_ptr: Optional[cute.Pointer] = None,
-    phase: Optional[cutlass.Int32] = None,
-    init_val: cute.Numeric = 0.0,
-    hook_fn: Optional[Callable] = None,
-) -> cute.Numeric:
-    """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
-    if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
-        val = x.reduce(op, init_val=init_val, reduction_profile=0)
-    else:
-        val = x
-    warp_op = {
-        cute.ReductionOp.ADD: operator.add,
-        cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
-        cute.ReductionOp.MIN: min,
-        cute.ReductionOp.MUL: operator.mul,
-    }[op]
-    val = warp_reduce(
-        val,
-        warp_op,
-        width=min(threads_per_row, cute.arch.WARP_SIZE),
     )
-    if cutlass.const_expr(hook_fn is not None):
-        hook_fn()
-    if cutlass.const_expr(reduction_buffer is not None):
-        warps_per_row, cluster_n = reduction_buffer.shape[1]
-        assert (
-            cluster_n == 1 or mbar_ptr is not None
-        ), "mbar_ptr must be provided for cluster reduction"
-        if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-            val = block_or_cluster_reduce(
-                val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
-            )
-    return val
-
-
-@cute.jit
-def online_softmax_reduce(
-    x: cute.TensorSSA,
-    threads_per_row: cutlass.Constexpr[int],
-    reduction_buffer: Optional[cute.Tensor] = None,
-    mbar_ptr: Optional[cute.Pointer] = None,
-    hook_fn: Optional[Callable] = None,
-    phase: Optional[cutlass.Int32] = None,
-    return_exp_x: bool = False,
-) -> [Float32, Float32, Optional[cute.TensorSSA]]:
-    assert x.dtype == Float32, "x must be of type Float32"
-    """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2)"""
-    max_x = warp_reduce(
-        x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
-        cute.arch.fmax,
-        width=min(threads_per_row, cute.arch.WARP_SIZE),
-    )
-    log2_e = math.log2(math.e)
-    exp_x = exp2f(x * log2_e - (max_x * log2_e))
-    # exp_x = exp2f((x - max_x) * log2_e)
-    sum_exp_x = warp_reduce(
-        exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
-        operator.add,
-        width=min(threads_per_row, cute.arch.WARP_SIZE),
-    )
-    if cutlass.const_expr(hook_fn is not None):
-        hook_fn()
-    if cutlass.const_expr(reduction_buffer is not None):
-        rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
-        assert (
-            cluster_n == 1 or mbar_ptr is not None
-        ), "mbar_ptr must be provided for cluster reduction"
-        if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-            assert (
-                reduction_buffer.element_type == cutlass.Int64
-            ), "reduction_buffer must be of type cute.Int64"
-            lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
-            row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
-            if cutlass.const_expr(mbar_ptr is None):
-                if lane_idx == 0:
-                    reduction_buffer[row_idx, col_idx] = f32x2_to_i64(max_x, sum_exp_x)
-                cute.arch.barrier()
-                max_x_single_warp = -Float32.inf
-                sum_exp_x = 0.0
-                if lane_idx < warps_per_row:
-                    max_x_single_warp, sum_exp_x = i64_to_f32x2(reduction_buffer[row_idx, lane_idx])
-                max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
-                sum_exp_x *= exp2f((max_x_single_warp - max_x_final) * log2_e)
-                sum_exp_x = warp_reduce(sum_exp_x, operator.add)
-                if cutlass.const_expr(return_exp_x):
-                    exp_x *= exp2f((max_x - max_x_final) * log2_e)
-                max_x = max_x_final
-            else:
-                cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
-                if warp_idx == 0:
-                    with cute.arch.elect_one():
-                        num_warps = rows_per_block * warps_per_row
-                        cute.arch.mbarrier_arrive_and_expect_tx(
-                            mbar_ptr,
-                            num_warps * cluster_n * reduction_buffer.element_type.width // 8,
-                        )
-                if lane_idx < cluster_n:
-                    store_shared_remote(
-                        f32x2_to_i64(max_x, sum_exp_x),
-                        elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
-                        mbar_ptr,
-                        peer_cta_rank_in_cluster=lane_idx,
-                    )
-                cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
-                num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
-                max_x_single_warp = cute.make_fragment(num_iter, Float32)
-                max_x_single_warp.fill(-Float32.inf)
-                sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
-                sum_exp_x_single_warp.fill(0.0)
-                for i in cutlass.range_constexpr(num_iter):
-                    idx = lane_idx + i * cute.arch.WARP_SIZE
-                    if idx < cute.size(reduction_buffer, mode=[1]):
-                        max_x_single_warp[i], sum_exp_x_single_warp[i] = i64_to_f32x2(
-                            reduction_buffer[row_idx, idx]
-                        )
-                max_x_final = max_x_single_warp.load().reduce(
-                    cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0
-                )
-                max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
-                sum_exp_x = 0.0
-                for i in cutlass.range_constexpr(num_iter):
-                    sum_exp_x += sum_exp_x_single_warp[i] * exp2f(
-                        (max_x_single_warp[i] - max_x_final) * log2_e
-                    )
-                sum_exp_x = warp_reduce(sum_exp_x, operator.add)
-                if cutlass.const_expr(return_exp_x):
-                    exp_x *= exp2f((max_x - max_x_final) * log2_e)
-                max_x = max_x_final
-    return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)


 @cute.jit
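
The generalization of store_shared_remote above is mechanical: the st.async type suffix and the LLVM inline-asm register constraint must agree with the operand type ("f" is an .f32 register, "r" a .b32 register, "l" a .b64 register). A hypothetical host-side sketch of the same selection logic, mirroring the tables in the diff:

def pick_asm_operands(val_type: str) -> tuple[str, str]:
    # suffix/constraint pairs as in the diff above
    table = {"Float32": ("f32", "f"), "Int32": ("s32", "r"), "Int64": ("s64", "l")}
    suffix, constraint = table[val_type]
    asm = f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];"
    # $0: destination smem address (.b32), $1: value, $2: mbarrier address (.b32)
    return asm, f"r,{constraint},r"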
@@ -337,6 +133,21 @@ def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
     )


+@dsl_user_op
+def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
+    return Float32(
+        llvm.inline_asm(
+            T.f32(),
+            [Float32(a).ir_value(loc=loc, ip=ip)],
+            "sqrt.approx.ftz.f32 $0, $1;",
+            "=f,f",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+        )
+    )
+
+
 @dsl_user_op
 def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
     return Float32(
@@ -352,6 +163,73 @@ def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
     )


+@dsl_user_op
+def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
+    return Int32(
+        llvm.inline_asm(
+            T.i32(),
+            [Float32(a).ir_value(loc=loc, ip=ip)],
+            "cvt.rpi.ftz.s32.f32 $0, $1;",
+            "=r,f",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+        )
+    )
+
+
+@dsl_user_op
+def prmt(a: int | Int32, b: int | Int32, c: int | Int32, *, loc=None, ip=None) -> Int32:
+    return Int32(
+        llvm.inline_asm(
+            T.i32(),
+            [
+                Int32(a).ir_value(loc=loc, ip=ip),
+                Int32(b).ir_value(loc=loc, ip=ip),
+                Int32(c).ir_value(loc=loc, ip=ip),
+            ],
+            "prmt.b32 $0, $1, $2, $3;",
+            "=r,r,r,r",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+        )
+    )
+
+
+@cute.jit
+def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
+    assert t.element_type.width == 16
+    assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation"
+    t_u32 = cute.recast_tensor(t, Int32)
+
+    quad_idx = cute.arch.lane_idx() % 4
+    lane_03 = quad_idx == 0 or quad_idx == 3
+    selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054)
+    selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276)
+    # upper_map = [0, 3, 1, 2]
+    # lower_map = [1, 2, 0, 3]
+    # upper_idx = upper_map[quad_idx]
+    # indexing isn't supported so we have to do arithmetic
+    upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2
+    lower_idx = upper_idx ^ 1
+
+    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
+    width = 4
+    mask = cute.arch.WARP_SIZE - width
+    clamp = cute.arch.WARP_SIZE - 1
+    mask_and_clamp = mask << 8 | clamp
+
+    for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True):
+        upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1]
+        upper0 = upper if lane_03 else lower
+        lower0 = lower if lane_03 else upper
+        upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
+        lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
+        t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper)
+        t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower)
+
+
 @cute.jit
 def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
     # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
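
The new prmt wrapper exposes the PTX byte-permute instruction that permute_gated_Cregs_b16 uses to reshuffle 16-bit register fragments. Here is a pure-Python model of default-mode prmt.b32 (prmt_b32 is a hypothetical reference, following the PTX ISA description): c supplies four 4-bit selectors, one per result byte; bits [2:0] of each selector index into the eight source bytes of (a, b), and bit 3 replicates the selected byte's sign bit instead.

def prmt_b32(a: int, b: int, c: int) -> int:
    # source byte pool: a.b0..a.b3 then b.b0..b.b3
    pool = [(a >> (8 * i)) & 0xFF for i in range(4)] + [(b >> (8 * i)) & 0xFF for i in range(4)]
    out = 0
    for i in range(4):
        sel = (c >> (4 * i)) & 0xF
        byte = pool[sel & 0x7]
        if sel & 0x8:  # msb set: replicate the selected byte's sign bit
            byte = 0xFF if byte & 0x80 else 0x00
        out |= byte << (8 * i)
    return out

# selector 0x5410 interleaves the low halves: result bytes = a.b0, a.b1, b.b0, b.b1,
# i.e. it gathers the pair of 16-bit values that selector_upper picks in the kernel
assert prmt_b32(0x44332211, 0x88776655, 0x5410) == 0x66552211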
@@ -377,7 +255,7 @@ def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric):
     tXpX: Predicate tensor indicating valid elements
     fill_value: Value to fill OOB locations with
     """
-    tXrX_fill = cute.make_fragment_like(tXsX[(None, 0),
+    tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), None, 0])
     tXrX_fill.fill(fill_value)
     for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
         for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
@@ -446,3 +324,98 @@ def coord_offset_i64(
         assumed_align=tensor.iterator.max_alignment,
     )
     return cute.make_tensor(new_ptr, tensor.layout)
+
+
+@cute.jit
+def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
+    if cutlass.const_expr(lane is None):
+        lane = cute.arch.lane_idx()
+    for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
+        offset = 1 << i
+        # Very important that we set mask_and_clamp to 0
+        partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0)
+        if lane >= offset:
+            val += partial_sum
+    return val
+
+
+def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout:
+    """
+    For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
+    For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
+    """
+    acc_layout_col_major = cute.make_layout(acc_layout.shape)
+    acc_layout_mn = cute.make_layout(
+        (
+            (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]),  # MMA_M
+            (
+                acc_layout_col_major.shape[0][0],
+                *acc_layout_col_major.shape[0][2:],
+                acc_layout_col_major.shape[2],
+            ),  # MMA_N
+            *acc_layout_col_major.shape[3:],
+        ),
+        stride=(
+            (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]),  # MMA_M
+            (
+                acc_layout_col_major.stride[0][0],
+                *acc_layout_col_major.stride[0][2:],
+                acc_layout_col_major.stride[2],
+            ),  # MMA_N
+            *acc_layout_col_major.stride[3:],
+        ),
+    )
+    return cute.composition(acc_layout, acc_layout_mn)
+
+
+def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
+    return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
+
+
+@dsl_user_op
+def sm90_get_smem_load_op(
+    layout_c: cutlass.utils.LayoutEnum,
+    elem_ty_c: Type[cutlass.Numeric],
+    *,
+    loc=None,
+    ip=None,
+) -> cute.CopyAtom:
+    """
+    Selects the largest vectorized smem load atom available subject to constraint of gmem layout.
+
+    Parameters:
+    -----------
+    layout_c : LayoutEnum
+        The layout enum of the output tensor D.
+
+    elem_ty_c : Type[Numeric]
+        The element type for output tensor D.
+
+    Returns:
+    --------
+    Either SmemLoadMatrix or SimtSyncCopy, based on the input parameters.
+    """
+
+    if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta):
+        raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
+    is_m_major = layout_c.is_m_major_c()
+    if elem_ty_c.width == 16:
+        return cute.make_copy_atom(
+            cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
+        )
+    else:
+        return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
+
+
+@dsl_user_op
+def atomic_add_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
+    return nvvm.atomicrmw(
+        res=T.i32(), op=nvvm.AtomicOpKind.ADD, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
+    )
+
+
+@dsl_user_op
+def atomic_inc_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
+    return nvvm.atomicrmw(
+        res=T.i32(), op=nvvm.AtomicOpKind.INC, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
+    )
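
For reference, the warp_prefix_sum added above is a textbook Hillis-Steele inclusive scan. A minimal pure-Python model (simulate_warp_prefix_sum is a hypothetical name; the list models the 32 lanes): shuffle_sync_up with mask_and_clamp=0 fetches the value held by the lane `offset` positions below, and the `lane >= offset` guard leaves the low lanes, which have no such neighbor, unchanged.

import math

def simulate_warp_prefix_sum(vals, width=32):
    lanes = list(vals)
    for i in range(int(math.log2(width))):
        offset = 1 << i
        # each lane adds the partial sum held by the lane `offset` positions below it
        shifted = [lanes[lane - offset] if lane >= offset else 0 for lane in range(width)]
        lanes = [v + s for v, s in zip(lanes, shifted)]
    return lanes  # lanes[k] == sum(vals[:k + 1])

assert simulate_warp_prefix_sum([1] * 32) == list(range(1, 33))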
quack/varlen_utils.py
ADDED
@@ -0,0 +1,22 @@
+# Copyright (c) 2025, Tri Dao.
+
+from typing import Optional
+from dataclasses import dataclass
+
+import cutlass.cute as cute
+
+from quack.cute_dsl_utils import ArgumentsBase
+
+
+# Grouping arguments together that should be passed to __call__
+@dataclass
+class VarlenArguments(ArgumentsBase):
+    mCuSeqlensM: Optional[cute.Tensor] = None
+    mCuSeqlensK: Optional[cute.Tensor] = None
+    mTensormaps: Optional[cute.Tensor] = None
+
+    def __post_init__(self):
+        if self.mCuSeqlensM is not None or self.mCuSeqlensK is not None:
+            assert (
+                self.mTensormaps is not None
+            ), "mTensormaps must be provided if mCuSeqlensM or mCuSeqlensK is provided"
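
A hypothetical usage sketch for the new dataclass (variable names are illustrative): the mCuSeqlens tensors presumably carry cumulative sequence lengths for a packed variable-length batch, and __post_init__ enforces that a tensormap workspace accompanies them.

from quack.varlen_utils import VarlenArguments

args = VarlenArguments()  # dense case: every field defaults to None
# varlen case: cu_seqlens without a tensormap buffer fails __post_init__
# args = VarlenArguments(mCuSeqlensM=cu_seqlens_m, mTensormaps=tensormap_workspace)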
{quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/METADATA
CHANGED

@@ -1,9 +1,9 @@
 Metadata-Version: 2.4
 Name: quack-kernels
-Version: 0.1.10
+Version: 0.2.0
 Requires-Python: >=3.12
 License-File: LICENSE
-Requires-Dist: nvidia-cutlass-dsl==4.
+Requires-Dist: nvidia-cutlass-dsl==4.2.0
 Requires-Dist: torch
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
quack_kernels-0.2.0.dist-info/RECORD
ADDED

@@ -0,0 +1,37 @@
+quack/__init__.py,sha256=fGBYbb9JlaNT7HdtUTbUnuAkL5G2Dg8XZAA5Ir1R-ow,364
+quack/activation.py,sha256=ysXaVUXX2yGQC5o4ZVeRXw_fDIHOrqnzpHJaIsc0kHc,10271
+quack/autotuner.py,sha256=czO6JrYL0EJpOeJOYDSsVdrJaFuwfL3vTdG8QfL1F34,10792
+quack/cross_entropy.py,sha256=Kc3P83Vsu1nGaCu7llsO3vct3J_t3frRYPxij7JfHMA,28619
+quack/cute_dsl_utils.py,sha256=D2Pw7rzX9jY8u8wikIPvPvinmFLCDeZg95HPBLqGej4,4635
+quack/dense_gemm_sm100.py,sha256=hKBNC34UxdctrTKVP68nvANZl4Dq2rnUjRcweESEq3g,109965
+quack/dense_gemm_sm90.py,sha256=TjnjHnjhAwWH5YQWsFlADq07xSxtsprkw_p2Cy0yw7I,100407
+quack/fast_math.py,sha256=E1XUqfUt0_n9BPZNggF-UDzZ6anso9bYUrwqafemWvQ,2297
+quack/gemm_act_sm90.py,sha256=N5UAFWZvw1na22Vh5JSGgcdqZ2zI6kQMBVOLxYbCAUU,14332
+quack/gemm_config.py,sha256=gbYjPFeyT5wAhVwFQroRHlHoMKEJqAWX9P8wWy04l8Q,2258
+quack/gemm_dact_sm90.py,sha256=KCXgjOzdamSDexwrwf_pX2r-ippPRirbClrlU6BP7b8,4990
+quack/gemm_interface.py,sha256=_JTpE7zQw6NUw-v65Wql_XUOZBfW0oSEgiMnharTJU4,20501
+quack/gemm_wrapper_utils.py,sha256=aMMtu-Ojhtjay_5xJH4AjP-JRVks1AB8jmtNme_DIqU,5960
+quack/layernorm.py,sha256=JkK0sVdUfZ-SmoBmNqLF3wCiszDbdorvcBH2julv0Vg,13560
+quack/linear.py,sha256=SrhRiAFjC7ONIMVmiNu-kSPLHNUyaCXt59a1f_5nNXo,9383
+quack/linear_cross_entropy.py,sha256=Zhy_gdMsKHOie-jntBaqIuiDJtkiq6qEBwnyuWwIRw4,10092
+quack/mlp.py,sha256=YjdwQRwEePA9KyidFXp5H1-lxiJc8dZ41vl8Fv8pgss,2259
+quack/pipeline.py,sha256=DyCwZX8WvoUBFcMBz7CeYm9VUM31haEGgBhAzmxu8cE,5519
+quack/reduce.py,sha256=hsYByu6haCZjLTLB-qpYmKDjqS2UqlwPgfWTup38GNA,10341
+quack/reduction_base.py,sha256=CT-t_j7z8H1ByD9FkQYDRik_-THMDFv9QoXHmr9Xx9E,3636
+quack/rmsnorm.py,sha256=93qlTPjY9JBm3R5M-HeHse1PbAfD9931G3OFs71yo_g,48998
+quack/softmax.py,sha256=Mq3_2Ul8H64zeGUI9wOKEpIISJnrCcHQpZvk2sb10Tg,17101
+quack/symmetric_dense_gemm_sm90.py,sha256=2UXooIpClT2izdyGis1XaIgYYlLj-7MrcOMg2yR7YCk,88694
+quack/tensormap_manager.py,sha256=Ts3Mxp0_es2RNA0ffvUjWMXN79lsfWEBZ0DQYhtbcnw,5338
+quack/tile_scheduler.py,sha256=8qqYmx6GpQzt8XiidcrdLIaWf0TGbJVdwKFfeb1X_us,42265
+quack/topk.py,sha256=RQl-23lIicQ9ry9Njur8i0JGem_WbO_Gchr6jy8EtVM,9185
+quack/utils.py,sha256=tiqeJZiPPFl5irQWCUd7dTPA_OAv4SjHUW5S-u9wO8Y,14526
+quack/varlen_utils.py,sha256=vkduMEpo5bJJvZRNnIcKPb6pp1wD34vaIpMIB0ZGIZA,681
+quack/sort/bitonic_sort.py,sha256=8t0SG1a6iEpYIlY8YM_AWvm4aN-4AA4vEzdBuJMJm9g,4768
+quack/sort/generate_sorting_networks.py,sha256=vkJBOjTVEinQkWT4OtFqOWxFVdTIPoNAQocneKc9-rM,14477
+quack/sort/sorting_networks.py,sha256=l_26zi3gXD_z-tnm2eAczRrmE-mbaz00KmqH6ONivL8,9686
+quack/sort/utils.py,sha256=Mkr-l97RMAV-ZoNrwuzA1U3KO0Wjr38CV9Jm7ScyZoI,1090
+quack_kernels-0.2.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+quack_kernels-0.2.0.dist-info/METADATA,sha256=DAeQymRUqp7lSfSTNyS7TZF3oWcFzCKriGJ2p8JLu6A,285
+quack_kernels-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+quack_kernels-0.2.0.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+quack_kernels-0.2.0.dist-info/RECORD,,
quack_kernels-0.1.10.dist-info/RECORD
DELETED

@@ -1,13 +0,0 @@
-quack/__init__.py,sha256=4tLchTx7d0d1ZVg6psRjjoXAWKHqzIWRF5mUk8ZdgkQ,204
-quack/cross_entropy.py,sha256=xsg2bXZ4wNvusBARhN4PwAzm5PbejEcfwj71nR7bzuE,20852
-quack/dense_gemm_sm90.py,sha256=jULXfAQkRh1SUAOpesx8wouY-GLDCm05Fb5LynozSl8,59932
-quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
-quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
-quack/rmsnorm.py,sha256=bJEHqc8ila-LTGco-tNNCUyFBjJ2UdXeoMplYNJPXFI,32740
-quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
-quack/utils.py,sha256=RZq-7YA8UMUizHpVyZM1we4zGm9NaC178M2g2HXdjmE,17799
-quack_kernels-0.1.10.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-quack_kernels-0.1.10.dist-info/METADATA,sha256=baMTwibt6u0IQb8YJFFhCY0RD3Aervf5sl6EpYF6IQ8,286
-quack_kernels-0.1.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-quack_kernels-0.1.10.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
-quack_kernels-0.1.10.dist-info/RECORD,,
{quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/WHEEL
File without changes

{quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/licenses/LICENSE
File without changes

{quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/top_level.txt
File without changes