quack-kernels 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -1
- quack/cross_entropy.py +201 -167
- quack/reduction_base.py +98 -0
- quack/rmsnorm.py +212 -181
- quack/softmax.py +417 -156
- quack/utils.py +206 -45
- quack_kernels-0.1.4.dist-info/METADATA +11 -0
- quack_kernels-0.1.4.dist-info/RECORD +11 -0
- quack_kernels-0.1.2.dist-info/METADATA +0 -8
- quack_kernels-0.1.2.dist-info/RECORD +0 -10
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.4.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.4.dist-info}/top_level.txt +0 -0
quack/utils.py
CHANGED
|
@@ -2,13 +2,14 @@
|
|
|
2
2
|
|
|
3
3
|
import operator
|
|
4
4
|
import math
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import Callable, Optional, Tuple
|
|
6
6
|
|
|
7
7
|
import cutlass
|
|
8
8
|
import cutlass.cute as cute
|
|
9
9
|
|
|
10
|
+
from cutlass import Float32
|
|
10
11
|
from cutlass.cutlass_dsl import T, dsl_user_op
|
|
11
|
-
from cutlass._mlir.dialects import
|
|
12
|
+
from cutlass._mlir.dialects import llvm, vector
|
|
12
13
|
from cutlass.cute.runtime import from_dlpack
|
|
13
14
|
|
|
14
15
|
|
|
@@ -36,27 +37,29 @@ def min_constexpr(
|
|
|
36
37
|
return a if a < b else b
|
|
37
38
|
|
|
38
39
|
|
|
40
|
+
@cute.jit
|
|
39
41
|
def warp_reduce(
|
|
40
42
|
val: cute.TensorSSA | cute.Numeric,
|
|
41
43
|
op: Callable,
|
|
42
|
-
width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE
|
|
44
|
+
width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
|
|
43
45
|
) -> cute.TensorSSA | cute.Numeric:
|
|
44
|
-
if isinstance(val, cute.TensorSSA):
|
|
46
|
+
if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
|
|
45
47
|
res = cute.make_fragment(val.shape, val.dtype)
|
|
46
48
|
res.store(val)
|
|
47
|
-
for i in
|
|
49
|
+
for i in cutlass.range_constexpr(cute.size(val.shape)):
|
|
48
50
|
res[i] = warp_reduce(res[i], op, width)
|
|
49
51
|
return res.load()
|
|
50
52
|
else:
|
|
51
|
-
for i in
|
|
53
|
+
for i in cutlass.range_constexpr(int(math.log2(width))):
|
|
52
54
|
val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
|
|
53
55
|
return val
|
|
54
56
|
|
|
55
57
|
|
|
56
58
|
@cute.jit
|
|
57
|
-
def block_reduce(
|
|
58
|
-
|
|
59
|
-
|
|
59
|
+
def block_reduce(
|
|
60
|
+
val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0
|
|
61
|
+
) -> cute.Numeric:
|
|
62
|
+
"""reduction_buffer has shape (num_warps / warp_per_row, warps_per_row)"""
|
|
60
63
|
lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
|
|
61
64
|
warps_per_row = cute.size(reduction_buffer.shape[1])
|
|
62
65
|
row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
|
|
@@ -75,9 +78,10 @@ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cut
|
|
|
75
78
|
|
|
76
79
|
|
|
77
80
|
@dsl_user_op
|
|
78
|
-
def set_block_rank(
|
|
79
|
-
|
|
80
|
-
|
|
81
|
+
def set_block_rank(
|
|
82
|
+
smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32, *, loc=None, ip=None
|
|
83
|
+
) -> cutlass.Int32:
|
|
84
|
+
"""Map the given smem pointer to the address at another CTA rank in the cluster."""
|
|
81
85
|
smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
|
|
82
86
|
return cutlass.Int32(
|
|
83
87
|
llvm.inline_asm(
|
|
@@ -94,16 +98,29 @@ def set_block_rank(smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32,
|
|
|
94
98
|
|
|
95
99
|
@dsl_user_op
|
|
96
100
|
def store_shared_remote(
|
|
97
|
-
val: float |
|
|
98
|
-
|
|
101
|
+
val: float | Float32 | cutlass.Int64,
|
|
102
|
+
smem_ptr: cute.Pointer,
|
|
103
|
+
mbar_ptr: cute.Pointer,
|
|
104
|
+
peer_cta_rank_in_cluster: cute.typing.Int,
|
|
105
|
+
*,
|
|
106
|
+
loc=None,
|
|
107
|
+
ip=None,
|
|
99
108
|
) -> None:
|
|
100
|
-
remote_smem_ptr_i32 = set_block_rank(
|
|
101
|
-
|
|
109
|
+
remote_smem_ptr_i32 = set_block_rank(
|
|
110
|
+
smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
|
|
111
|
+
).ir_value()
|
|
112
|
+
remote_mbar_ptr_i32 = set_block_rank(
|
|
113
|
+
mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
|
|
114
|
+
).ir_value()
|
|
115
|
+
if cutlass.const_expr(isinstance(val, float)):
|
|
116
|
+
val = Float32(val)
|
|
117
|
+
assert isinstance(val, (Float32, cutlass.Int64)), "val must be Float32 or Int64"
|
|
118
|
+
suffix = "f32" if cutlass.const_expr(isinstance(val, Float32)) else "s64"
|
|
102
119
|
llvm.inline_asm(
|
|
103
120
|
None,
|
|
104
|
-
[remote_smem_ptr_i32,
|
|
105
|
-
"st.async.shared::cluster.mbarrier::complete_tx::bytes.
|
|
106
|
-
"r,f,r",
|
|
121
|
+
[remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
|
|
122
|
+
f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];",
|
|
123
|
+
f"r,{'f' if cutlass.const_expr(isinstance(val, Float32)) else 'l'},r",
|
|
107
124
|
has_side_effects=True,
|
|
108
125
|
is_align_stack=False,
|
|
109
126
|
asm_dialect=llvm.AsmDialect.AD_ATT,
|
|
@@ -111,17 +128,24 @@ def store_shared_remote(
|
|
|
111
128
|
|
|
112
129
|
|
|
113
130
|
@cute.jit
|
|
114
|
-
def cluster_reduce(
|
|
115
|
-
|
|
116
|
-
|
|
131
|
+
def cluster_reduce(
|
|
132
|
+
val: cute.Numeric,
|
|
133
|
+
op: Callable,
|
|
134
|
+
reduction_buffer: cute.Tensor,
|
|
135
|
+
mbar_ptr: cute.Pointer,
|
|
136
|
+
init_val: cute.Numeric = 0.0,
|
|
137
|
+
) -> cute.Numeric:
|
|
138
|
+
"""reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
|
|
117
139
|
cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
|
|
118
140
|
lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
|
|
119
141
|
warps_per_row, cluster_n = reduction_buffer.shape[1]
|
|
120
142
|
row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
|
|
121
143
|
if lane_idx < cluster_n:
|
|
122
144
|
store_shared_remote(
|
|
123
|
-
val,
|
|
124
|
-
|
|
145
|
+
val,
|
|
146
|
+
elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
|
|
147
|
+
mbar_ptr,
|
|
148
|
+
peer_cta_rank_in_cluster=lane_idx,
|
|
125
149
|
)
|
|
126
150
|
cute.arch.mbarrier_wait(mbar_ptr, phase=0)
|
|
127
151
|
block_reduce_val = init_val
|
|
@@ -134,9 +158,14 @@ def cluster_reduce(val: cute.Numeric, op: Callable, reduction_buffer: cute.Tenso
|
|
|
134
158
|
|
|
135
159
|
|
|
136
160
|
@cute.jit
|
|
137
|
-
def block_or_cluster_reduce(
|
|
138
|
-
|
|
139
|
-
|
|
161
|
+
def block_or_cluster_reduce(
|
|
162
|
+
val: cute.Numeric,
|
|
163
|
+
op: Callable,
|
|
164
|
+
reduction_buffer: cute.Tensor,
|
|
165
|
+
mbar_ptr: Optional[cute.Pointer],
|
|
166
|
+
init_val: cute.Numeric = 0.0,
|
|
167
|
+
) -> cute.Numeric:
|
|
168
|
+
"""Perform either block or cluster reduction based on whether mbar_ptr is provided."""
|
|
140
169
|
if cutlass.const_expr(mbar_ptr is None):
|
|
141
170
|
return block_reduce(val, op, reduction_buffer, init_val=init_val)
|
|
142
171
|
else:
|
|
@@ -153,15 +182,14 @@ def row_reduce(
|
|
|
153
182
|
init_val: cute.Numeric = 0.0,
|
|
154
183
|
hook_fn: Optional[Callable] = None,
|
|
155
184
|
) -> cute.Numeric:
|
|
156
|
-
"""reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))
|
|
157
|
-
"""
|
|
185
|
+
"""reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
|
|
158
186
|
if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
|
|
159
187
|
val = x.reduce(op, init_val=init_val, reduction_profile=0)
|
|
160
188
|
else:
|
|
161
189
|
val = x
|
|
162
190
|
warp_op = {
|
|
163
191
|
cute.ReductionOp.ADD: operator.add,
|
|
164
|
-
cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype ==
|
|
192
|
+
cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
|
|
165
193
|
cute.ReductionOp.MIN: min,
|
|
166
194
|
cute.ReductionOp.MUL: operator.mul,
|
|
167
195
|
}[op]
|
|
@@ -174,7 +202,9 @@ def row_reduce(
|
|
|
174
202
|
hook_fn()
|
|
175
203
|
if cutlass.const_expr(reduction_buffer is not None):
|
|
176
204
|
warps_per_row, cluster_n = reduction_buffer.shape[1]
|
|
177
|
-
assert
|
|
205
|
+
assert (
|
|
206
|
+
cluster_n == 1 or mbar_ptr is not None
|
|
207
|
+
), "mbar_ptr must be provided for cluster reduction"
|
|
178
208
|
if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
|
|
179
209
|
val = block_or_cluster_reduce(
|
|
180
210
|
val, warp_op, reduction_buffer, mbar_ptr, init_val=init_val
|
|
@@ -182,19 +212,107 @@ def row_reduce(
|
|
|
182
212
|
return val
|
|
183
213
|
|
|
184
214
|
|
|
215
|
+
@cute.jit
|
|
216
|
+
def online_softmax_reduce(
|
|
217
|
+
x: cute.TensorSSA,
|
|
218
|
+
threads_per_row: cutlass.Constexpr[int],
|
|
219
|
+
reduction_buffer: Optional[cute.Tensor] = None,
|
|
220
|
+
mbar_ptr: Optional[cute.Pointer] = None,
|
|
221
|
+
hook_fn: Optional[Callable] = None,
|
|
222
|
+
return_exp_x: bool = False,
|
|
223
|
+
) -> [Float32, Float32, Optional[cute.TensorSSA]]:
|
|
224
|
+
assert x.dtype == Float32, "x must be of type Float32"
|
|
225
|
+
"""reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2)"""
|
|
226
|
+
max_x = warp_reduce(
|
|
227
|
+
x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
|
|
228
|
+
cute.arch.fmax,
|
|
229
|
+
width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
|
|
230
|
+
)
|
|
231
|
+
log2_e = math.log2(math.e)
|
|
232
|
+
exp_x = exp2f(x * log2_e - (max_x * log2_e))
|
|
233
|
+
# exp_x = exp2f((x - max_x) * log2_e)
|
|
234
|
+
sum_exp_x = warp_reduce(
|
|
235
|
+
exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
|
|
236
|
+
operator.add,
|
|
237
|
+
width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
|
|
238
|
+
)
|
|
239
|
+
if cutlass.const_expr(hook_fn is not None):
|
|
240
|
+
hook_fn()
|
|
241
|
+
if cutlass.const_expr(reduction_buffer is not None):
|
|
242
|
+
warps_per_row, cluster_n = reduction_buffer.shape[1]
|
|
243
|
+
assert (
|
|
244
|
+
cluster_n == 1 or mbar_ptr is not None
|
|
245
|
+
), "mbar_ptr must be provided for cluster reduction"
|
|
246
|
+
if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
|
|
247
|
+
assert (
|
|
248
|
+
reduction_buffer.element_type == cutlass.Int64
|
|
249
|
+
), "reduction_buffer must be of type cute.Int64"
|
|
250
|
+
lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
|
|
251
|
+
row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
|
|
252
|
+
if cutlass.const_expr(mbar_ptr is None):
|
|
253
|
+
if lane_idx == 0:
|
|
254
|
+
reduction_buffer[row_idx, col_idx] = f32x2_to_i64(max_x, sum_exp_x)
|
|
255
|
+
cute.arch.barrier()
|
|
256
|
+
max_x_single_warp = -Float32.inf
|
|
257
|
+
sum_exp_x = 0.0
|
|
258
|
+
if lane_idx < warps_per_row:
|
|
259
|
+
max_x_single_warp, sum_exp_x = i64_to_f32x2(reduction_buffer[row_idx, lane_idx])
|
|
260
|
+
max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
|
|
261
|
+
sum_exp_x *= exp2f((max_x_single_warp - max_x_final) * log2_e)
|
|
262
|
+
sum_exp_x = warp_reduce(sum_exp_x, operator.add)
|
|
263
|
+
if cutlass.const_expr(return_exp_x):
|
|
264
|
+
exp_x *= exp2f((max_x - max_x_final) * log2_e)
|
|
265
|
+
max_x = max_x_final
|
|
266
|
+
else:
|
|
267
|
+
cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
|
|
268
|
+
if lane_idx < cluster_n:
|
|
269
|
+
store_shared_remote(
|
|
270
|
+
f32x2_to_i64(max_x, sum_exp_x),
|
|
271
|
+
elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
|
|
272
|
+
mbar_ptr,
|
|
273
|
+
peer_cta_rank_in_cluster=lane_idx,
|
|
274
|
+
)
|
|
275
|
+
cute.arch.mbarrier_wait(mbar_ptr, phase=0)
|
|
276
|
+
num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
|
|
277
|
+
max_x_single_warp = cute.make_fragment(num_iter, Float32)
|
|
278
|
+
max_x_single_warp.fill(-Float32.inf)
|
|
279
|
+
sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
|
|
280
|
+
sum_exp_x_single_warp.fill(0.0)
|
|
281
|
+
for i in cutlass.range_constexpr(num_iter):
|
|
282
|
+
idx = lane_idx + i * cute.arch.WARP_SIZE
|
|
283
|
+
if idx < cute.size(reduction_buffer, mode=[1]):
|
|
284
|
+
max_x_single_warp[i], sum_exp_x_single_warp[i] = i64_to_f32x2(
|
|
285
|
+
reduction_buffer[row_idx, idx]
|
|
286
|
+
)
|
|
287
|
+
max_x_final = max_x_single_warp.load().reduce(
|
|
288
|
+
cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0
|
|
289
|
+
)
|
|
290
|
+
max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
|
|
291
|
+
sum_exp_x = 0.0
|
|
292
|
+
for i in cutlass.range_constexpr(num_iter):
|
|
293
|
+
sum_exp_x += sum_exp_x_single_warp[i] * exp2f(
|
|
294
|
+
(max_x_single_warp[i] - max_x_final) * log2_e
|
|
295
|
+
)
|
|
296
|
+
sum_exp_x = warp_reduce(sum_exp_x, operator.add)
|
|
297
|
+
if cutlass.const_expr(return_exp_x):
|
|
298
|
+
exp_x *= exp2f((max_x - max_x_final) * log2_e)
|
|
299
|
+
max_x = max_x_final
|
|
300
|
+
return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
|
|
301
|
+
|
|
185
302
|
|
|
186
|
-
|
|
303
|
+
@cute.jit
|
|
304
|
+
def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
|
|
187
305
|
"""exp2f calculation for both vector and scalar.
|
|
188
306
|
|
|
189
307
|
:param x: input value
|
|
190
|
-
:type x: cute.TensorSSA or
|
|
308
|
+
:type x: cute.TensorSSA or Float32
|
|
191
309
|
:return: exp2 value
|
|
192
|
-
:rtype: cute.TensorSSA or
|
|
310
|
+
:rtype: cute.TensorSSA or Float32
|
|
193
311
|
"""
|
|
194
|
-
if isinstance(x, cute.TensorSSA):
|
|
195
|
-
res = cute.make_fragment(x.shape,
|
|
312
|
+
if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
|
|
313
|
+
res = cute.make_fragment(x.shape, Float32)
|
|
196
314
|
res.store(x)
|
|
197
|
-
for i in
|
|
315
|
+
for i in cutlass.range_constexpr(cute.size(x.shape)):
|
|
198
316
|
res[i] = cute.arch.exp2(res[i])
|
|
199
317
|
return res.load()
|
|
200
318
|
else:
|
|
@@ -202,11 +320,11 @@ def exp2f(x: cute.TensorSSA | cutlass.Float32) -> cute.TensorSSA | cutlass.Float
|
|
|
202
320
|
|
|
203
321
|
|
|
204
322
|
@dsl_user_op
|
|
205
|
-
def log2f(a: float |
|
|
206
|
-
return
|
|
323
|
+
def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
|
324
|
+
return Float32(
|
|
207
325
|
llvm.inline_asm(
|
|
208
326
|
T.f32(),
|
|
209
|
-
[
|
|
327
|
+
[Float32(a).ir_value(loc=loc, ip=ip)],
|
|
210
328
|
"lg2.approx.ftz.f32 $0, $1;",
|
|
211
329
|
"=f,f",
|
|
212
330
|
has_side_effects=False,
|
|
@@ -217,11 +335,11 @@ def log2f(a: float | cutlass.Float32, *, loc=None, ip=None) -> cutlass.Float32:
|
|
|
217
335
|
|
|
218
336
|
|
|
219
337
|
@dsl_user_op
|
|
220
|
-
def rsqrt(a: float |
|
|
221
|
-
return
|
|
338
|
+
def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
|
339
|
+
return Float32(
|
|
222
340
|
llvm.inline_asm(
|
|
223
341
|
T.f32(),
|
|
224
|
-
[
|
|
342
|
+
[Float32(a).ir_value(loc=loc, ip=ip)],
|
|
225
343
|
"rsqrt.approx.ftz.f32 $0, $1;",
|
|
226
344
|
"=f,f",
|
|
227
345
|
has_side_effects=False,
|
|
@@ -231,6 +349,7 @@ def rsqrt(a: float | cute.Float32, *, loc=None, ip=None) -> cute.Float32:
|
|
|
231
349
|
)
|
|
232
350
|
|
|
233
351
|
|
|
352
|
+
@cute.jit
|
|
234
353
|
def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
|
|
235
354
|
# Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
|
|
236
355
|
tApA = cute.make_fragment(
|
|
@@ -240,7 +359,49 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
|
|
|
240
359
|
),
|
|
241
360
|
cutlass.Boolean,
|
|
242
361
|
)
|
|
243
|
-
for rest_v in
|
|
244
|
-
for rest_k in
|
|
362
|
+
for rest_v in cutlass.range_constexpr(tApA.shape[0]):
|
|
363
|
+
for rest_k in cutlass.range_constexpr(tApA.shape[2]):
|
|
245
364
|
tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
|
|
246
365
|
return tApA
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
@cute.jit
|
|
369
|
+
def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) -> None:
|
|
370
|
+
"""Fill out-of-bounds values in shared memory tensor.
|
|
371
|
+
|
|
372
|
+
Args:
|
|
373
|
+
tXsX: Shared memory tensor to fill
|
|
374
|
+
tXpX: Predicate tensor indicating valid elements
|
|
375
|
+
fill_value: Value to fill OOB locations with
|
|
376
|
+
"""
|
|
377
|
+
tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), 0, 0])
|
|
378
|
+
tXrX_fill.fill(fill_value)
|
|
379
|
+
for rest_v in cutlass.range_constexpr(tXpX.shape[0]):
|
|
380
|
+
for rest_k in cutlass.range_constexpr(tXpX.shape[2]):
|
|
381
|
+
if not tXpX[rest_v, 0, rest_k]:
|
|
382
|
+
cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
@dsl_user_op
|
|
386
|
+
def f32x2_to_i64(a: Float32, b: Float32, *, loc=None, ip=None) -> cutlass.Int64:
|
|
387
|
+
vec_f32x2 = vector.from_elements(
|
|
388
|
+
T.vector(2, T.f32()), (a.ir_value(), b.ir_value()), loc=loc, ip=ip
|
|
389
|
+
)
|
|
390
|
+
vec_i64x1 = vector.bitcast(T.vector(1, T.i64()), vec_f32x2)
|
|
391
|
+
res = cutlass.Int64(
|
|
392
|
+
vector.extract(vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
|
|
393
|
+
)
|
|
394
|
+
return res
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
@dsl_user_op
|
|
398
|
+
def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
|
|
399
|
+
vec_i64x1 = vector.from_elements(T.vector(1, T.i64()), (c.ir_value(),), loc=loc, ip=ip)
|
|
400
|
+
vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1)
|
|
401
|
+
res0 = Float32(
|
|
402
|
+
vector.extract(vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
|
|
403
|
+
)
|
|
404
|
+
res1 = Float32(
|
|
405
|
+
vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)
|
|
406
|
+
)
|
|
407
|
+
return res0, res1
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: quack-kernels
|
|
3
|
+
Version: 0.1.4
|
|
4
|
+
Requires-Python: >=3.9
|
|
5
|
+
License-File: LICENSE
|
|
6
|
+
Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
|
|
7
|
+
Requires-Dist: torch
|
|
8
|
+
Provides-Extra: dev
|
|
9
|
+
Requires-Dist: pre-commit; extra == "dev"
|
|
10
|
+
Requires-Dist: ruff; extra == "dev"
|
|
11
|
+
Dynamic: license-file
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
quack/__init__.py,sha256=cFLxO6nA_faFqHf4N-Fy7G0j8ykuYPB1uOt9uoJ2dkQ,203
|
|
2
|
+
quack/cross_entropy.py,sha256=HnF2OErEzb10SWxY6HoYE42lnvlw2DsWCks7mylPwnI,9511
|
|
3
|
+
quack/reduction_base.py,sha256=Rsj9ZeSHcKAXGn1p7mY1vrrBqxevi4feLjY0JJhKnmY,3663
|
|
4
|
+
quack/rmsnorm.py,sha256=TkOZsXJwcsoZMLnmEWQ-pEF0r-iiZhGrCNLSFCXfv6s,10676
|
|
5
|
+
quack/softmax.py,sha256=VfhlC2huRuv7olFSVFgS8LF1yF8TFV64yjjjQxYX9yk,16364
|
|
6
|
+
quack/utils.py,sha256=zVc9U-5No19trE585KqDdXx9chAruXPRIPMZdO7mkRg,15603
|
|
7
|
+
quack_kernels-0.1.4.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
8
|
+
quack_kernels-0.1.4.dist-info/METADATA,sha256=xl62C5WFgiUbnOICAzjldsljJ9j1Fb_JxZVksHLCI8I,289
|
|
9
|
+
quack_kernels-0.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
10
|
+
quack_kernels-0.1.4.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
|
|
11
|
+
quack_kernels-0.1.4.dist-info/RECORD,,
|
|
@@ -1,10 +0,0 @@
|
|
|
1
|
-
quack/__init__.py,sha256=Nf01m1CGrOjSkqGJom6P65hSLkckljRMhlkSoqqlO9k,137
|
|
2
|
-
quack/cross_entropy.py,sha256=gdo8sR9KT5TsrShbgAmy-bwRZLu0gTs_ykXBF2RMbFI,8900
|
|
3
|
-
quack/rmsnorm.py,sha256=JhwJSAPDDpB_hV90xU9ymiLU-zu4WScrSHc5JX2JarY,10470
|
|
4
|
-
quack/softmax.py,sha256=C8e8ZNaF5ePJ1NlrWZN1goCcvsx1C60FWlRyuFCcYoM,7737
|
|
5
|
-
quack/utils.py,sha256=PRdu-P7azA_PeHUNdtoy1zyxZwg_QyVrSiVwE1iXaWo,8961
|
|
6
|
-
quack_kernels-0.1.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
7
|
-
quack_kernels-0.1.2.dist-info/METADATA,sha256=3WjugLu1IhLlgsg2qUcLBZq1HI4-BIyyJIuQc5Hk-rU,186
|
|
8
|
-
quack_kernels-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
9
|
-
quack_kernels-0.1.2.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
|
|
10
|
-
quack_kernels-0.1.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|