quack-kernels 0.1.1-py3-none-any.whl → 0.1.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -1
- quack/cross_entropy.py +199 -168
- quack/reduction_base.py +98 -0
- quack/rmsnorm.py +208 -195
- quack/softmax.py +409 -163
- quack/utils.py +249 -35
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/METADATA +4 -1
- quack_kernels-0.1.3.dist-info/RECORD +11 -0
- quack_kernels-0.1.1.dist-info/RECORD +0 -10
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/top_level.txt +0 -0
quack/utils.py
CHANGED
@@ -1,13 +1,15 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
 
+import operator
 import math
-from typing import Callable
+from typing import Callable, Optional, Tuple
 
 import cutlass
 import cutlass.cute as cute
 
+from cutlass import Float32
 from cutlass.cutlass_dsl import T, dsl_user_op
-from cutlass._mlir.dialects import llvm
+from cutlass._mlir.dialects import llvm, vector
 from cutlass.cute.runtime import from_dlpack
 
 
@@ -38,7 +40,7 @@ def min_constexpr(
 def warp_reduce(
     val: cute.TensorSSA | cute.Numeric,
     op: Callable,
-    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE
+    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
 ) -> cute.TensorSSA | cute.Numeric:
     if isinstance(val, cute.TensorSSA):
         res = cute.make_fragment(val.shape, val.dtype)
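
For reference, a warp-level reduction of this shape is conventionally lowered to a butterfly exchange over shuffle intrinsics. A minimal host-side model of that pattern (a sketch only, assuming `width` is a power of two; the XOR indexing emulates `shfl_xor`-style lane pairing, not the DSL's actual shuffle ops):

def warp_reduce_reference(lane_vals, op, width=32):
    vals = list(lane_vals)
    offset = width // 2
    while offset:
        # Each round, lane i combines with lane i ^ offset; after
        # log2(width) rounds every lane holds the full reduction.
        vals = [op(v, vals[i ^ offset]) for i, v in enumerate(vals)]
        offset //= 2
    return vals

assert warp_reduce_reference(range(32), lambda a, b: a + b)[0] == sum(range(32))
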
@@ -53,11 +55,12 @@ def warp_reduce(
 
 
 @cute.jit
-def block_reduce(
+def block_reduce(
+    val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0
+) -> cute.Numeric:
+    """reduction_buffer has shape (num_warps / warp_per_row, warps_per_row)"""
     lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
-    warps_per_row = reduction_buffer.shape[1]
+    warps_per_row = cute.size(reduction_buffer.shape[1])
     row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
     if lane_idx == 0:
         reduction_buffer[row_idx, col_idx] = val
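
The buffer layout is easier to see on the host. A NumPy sketch of the semantics (an assumption-laden illustration, not the device code: `warp_vals` stands in for the per-warp partials that lane 0 of each warp writes into the buffer):

import numpy as np

def block_reduce_reference(warp_vals, warps_per_row, init_val=0.0):
    # Shape (num_warps / warps_per_row, warps_per_row): each row of warps
    # is reduced independently, e.g. one row of the input tile per buffer row.
    buf = np.asarray(warp_vals, dtype=np.float32).reshape(-1, warps_per_row)
    return np.add.reduce(buf, axis=1, initial=init_val)

# 8 warps arranged as 2 rows x 4 warps per row -> 2 reduced values
print(block_reduce_reference([1, 2, 3, 4, 5, 6, 7, 8], warps_per_row=4))  # [10. 26.]
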
@@ -74,9 +77,10 @@ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cut
 
 
 @dsl_user_op
-def set_block_rank(
+def set_block_rank(
+    smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32, *, loc=None, ip=None
+) -> cutlass.Int32:
+    """Map the given smem pointer to the address at another CTA rank in the cluster."""
     smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
     return cutlass.Int32(
         llvm.inline_asm(
@@ -93,16 +97,29 @@ def set_block_rank(smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32,
 
 @dsl_user_op
 def store_shared_remote(
-    val: float | cutlass.Float32,
+    val: float | Float32 | cutlass.Int64,
+    smem_ptr: cute.Pointer,
+    mbar_ptr: cute.Pointer,
+    peer_cta_rank_in_cluster: cute.typing.Int,
+    *,
+    loc=None,
+    ip=None,
 ) -> None:
-    remote_smem_ptr_i32 = set_block_rank(
+    remote_smem_ptr_i32 = set_block_rank(
+        smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
+    ).ir_value()
+    remote_mbar_ptr_i32 = set_block_rank(
+        mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
+    ).ir_value()
+    if isinstance(val, float):
+        val = Float32(val)
+    assert isinstance(val, (Float32, cutlass.Int64)), "val must be Float32 or Int64"
+    suffix = "f32" if isinstance(val, Float32) else "s64"
     llvm.inline_asm(
         None,
-        [remote_smem_ptr_i32,
-        "st.async.shared::cluster.mbarrier::complete_tx::bytes.f32 [$0], $1, [$2];",
-        "r,f,r",
+        [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
+        f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];",
+        f"r,{'f' if isinstance(val, Float32) else 'l'},r",
         has_side_effects=True,
         is_align_stack=False,
         asm_dialect=llvm.AsmDialect.AD_ATT,
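
This change generalizes the remote store from hard-coded f32 to f32 or s64 operands: the PTX suffix and the inline-asm constraint for `$1` ('f' is a 32-bit float register, 'l' a 64-bit integer register, 'r' a 32-bit register) are now derived from the value type. A small standalone sketch of that dispatch, mirroring only the string logic in the diff:

def ptx_for(is_f32: bool) -> tuple[str, str]:
    # Select the store width and the matching register constraint
    suffix = "f32" if is_f32 else "s64"
    asm = f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];"
    return asm, f"r,{'f' if is_f32 else 'l'},r"

print(ptx_for(True))   # (... bytes.f32 [$0], $1, [$2];', 'r,f,r')
print(ptx_for(False))  # (... bytes.s64 [$0], $1, [$2];', 'r,l,r')
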
@@ -110,17 +127,24 @@ def store_shared_remote(
 
 
 @cute.jit
-def cluster_reduce(
+def cluster_reduce(
+    val: cute.Numeric,
+    op: Callable,
+    reduction_buffer: cute.Tensor,
+    mbar_ptr: cute.Pointer,
+    init_val: cute.Numeric = 0.0,
+) -> cute.Numeric:
+    """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
     cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
     lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
     warps_per_row, cluster_n = reduction_buffer.shape[1]
     row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
     if lane_idx < cluster_n:
         store_shared_remote(
-            val,
+            val,
+            elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
+            mbar_ptr,
+            peer_cta_rank_in_cluster=lane_idx,
         )
     cute.arch.mbarrier_wait(mbar_ptr, phase=0)
     block_reduce_val = init_val
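
In the cluster path, the first `cluster_n` lanes of each warp push the CTA's partial into the matching buffer slot of every peer CTA, and the mbarrier wait makes the exchange visible before the local combine. A toy model of that all-to-all (an assumption-level illustration: one partial per CTA, warps and mbarrier timing ignored):

cluster_n = 4
partials = [10.0, 20.0, 30.0, 40.0]                   # one partial per CTA
buffers = [[0.0] * cluster_n for _ in range(cluster_n)]
for rank in range(cluster_n):                         # lane l of CTA `rank`...
    for peer in range(cluster_n):                     # ...stores into peer l's buffer
        buffers[peer][rank] = partials[rank]
# After mbarrier_wait, every CTA can reduce over all partials locally:
assert all(sum(buf) == 100.0 for buf in buffers)
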
@@ -133,25 +157,158 @@ def cluster_reduce(val: cute.Numeric, op: Callable, reduction_buffer: cute.Tenso
 
 
 @cute.jit
-def block_or_cluster_reduce(
+def block_or_cluster_reduce(
+    val: cute.Numeric,
+    op: Callable,
+    reduction_buffer: cute.Tensor,
+    mbar_ptr: Optional[cute.Pointer],
+    init_val: cute.Numeric = 0.0,
+) -> cute.Numeric:
+    """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
     if cutlass.const_expr(mbar_ptr is None):
         return block_reduce(val, op, reduction_buffer, init_val=init_val)
     else:
         return cluster_reduce(val, op, reduction_buffer, mbar_ptr, init_val=init_val)
 
 
-def exp2f(x: cute.TensorSSA | cutlass.Float32) -> cute.TensorSSA | cutlass.Float32:
+@cute.jit
+def row_reduce(
+    x: cute.TensorSSA | cute.Numeric,
+    op: cute.ReductionOp,
+    threads_per_row: cutlass.Constexpr[int],
+    reduction_buffer: Optional[cute.Tensor] = None,
+    mbar_ptr: Optional[cute.Pointer] = None,
+    init_val: cute.Numeric = 0.0,
+    hook_fn: Optional[Callable] = None,
+) -> cute.Numeric:
+    """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
+    if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
+        val = x.reduce(op, init_val=init_val, reduction_profile=0)
+    else:
+        val = x
+    warp_op = {
+        cute.ReductionOp.ADD: operator.add,
+        cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
+        cute.ReductionOp.MIN: min,
+        cute.ReductionOp.MUL: operator.mul,
+    }[op]
+    val = warp_reduce(
+        val,
+        warp_op,
+        width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
+    )
+    if cutlass.const_expr(hook_fn is not None):
+        hook_fn()
+    if cutlass.const_expr(reduction_buffer is not None):
+        warps_per_row, cluster_n = reduction_buffer.shape[1]
+        assert (
+            cluster_n == 1 or mbar_ptr is not None
+        ), "mbar_ptr must be provided for cluster reduction"
+        if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
+            val = block_or_cluster_reduce(
+                val, warp_op, reduction_buffer, mbar_ptr, init_val=init_val
+            )
+    return val
+
+
+@cute.jit
+def online_softmax_reduce(
+    x: cute.TensorSSA,
+    threads_per_row: cutlass.Constexpr[int],
+    reduction_buffer: Optional[cute.Tensor] = None,
+    mbar_ptr: Optional[cute.Pointer] = None,
+    hook_fn: Optional[Callable] = None,
+    return_exp_x: bool = False,
+) -> [Float32, Float32, Optional[cute.TensorSSA]]:
+    assert x.dtype == Float32, "x must be of type Float32"
+    """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2)"""
+    max_x = warp_reduce(
+        x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
+        cute.arch.fmax,
+        width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
+    )
+    log2_e = math.log2(math.e)
+    exp_x = exp2f(x * log2_e - (max_x * log2_e))
+    # exp_x = exp2f((x - max_x) * log2_e)
+    sum_exp_x = warp_reduce(
+        exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
+        operator.add,
+        width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
+    )
+    if cutlass.const_expr(hook_fn is not None):
+        hook_fn()
+    if cutlass.const_expr(reduction_buffer is not None):
+        warps_per_row, cluster_n = reduction_buffer.shape[1]
+        assert (
+            cluster_n == 1 or mbar_ptr is not None
+        ), "mbar_ptr must be provided for cluster reduction"
+        if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
+            assert (
+                reduction_buffer.element_type == cutlass.Int64
+            ), "reduction_buffer must be of type cute.Int64"
+            lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
+            row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+            if cutlass.const_expr(mbar_ptr is None):
+                if lane_idx == 0:
+                    reduction_buffer[row_idx, col_idx] = f32x2_to_i64(max_x, sum_exp_x)
+                cute.arch.barrier()
+                max_x_single_warp = -Float32.inf
+                sum_exp_x = 0.0
+                if lane_idx < warps_per_row:
+                    max_x_single_warp, sum_exp_x = i64_to_f32x2(reduction_buffer[row_idx, lane_idx])
+                max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
+                sum_exp_x *= exp2f((max_x_single_warp - max_x_final) * log2_e)
+                sum_exp_x = warp_reduce(sum_exp_x, operator.add)
+                if cutlass.const_expr(return_exp_x):
+                    exp_x *= exp2f((max_x - max_x_final) * log2_e)
+                max_x = max_x_final
+            else:
+                cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
+                if lane_idx < cluster_n:
+                    store_shared_remote(
+                        f32x2_to_i64(max_x, sum_exp_x),
+                        elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
+                        mbar_ptr,
+                        peer_cta_rank_in_cluster=lane_idx,
+                    )
+                cute.arch.mbarrier_wait(mbar_ptr, phase=0)
+                num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
+                max_x_single_warp = cute.make_fragment(num_iter, Float32)
+                max_x_single_warp.fill(-Float32.inf)
+                sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
+                sum_exp_x_single_warp.fill(0.0)
+                for i in cutlass.range_constexpr(num_iter):
+                    idx = lane_idx + i * cute.arch.WARP_SIZE
+                    if idx < cute.size(reduction_buffer, mode=[1]):
+                        max_x_single_warp[i], sum_exp_x_single_warp[i] = i64_to_f32x2(
+                            reduction_buffer[row_idx, idx]
+                        )
+                max_x_final = max_x_single_warp.load().reduce(
+                    cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0
+                )
+                max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
+                sum_exp_x = 0.0
+                for i in cutlass.range_constexpr(num_iter):
+                    sum_exp_x += sum_exp_x_single_warp[i] * exp2f(
+                        (max_x_single_warp[i] - max_x_final) * log2_e
+                    )
+                sum_exp_x = warp_reduce(sum_exp_x, operator.add)
+                if cutlass.const_expr(return_exp_x):
+                    exp_x *= exp2f((max_x - max_x_final) * log2_e)
+                max_x = max_x_final
+    return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
+
+
+def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
     """exp2f calculation for both vector and scalar.
 
     :param x: input value
-    :type x: cute.TensorSSA or cutlass.Float32
+    :type x: cute.TensorSSA or Float32
     :return: exp2 value
-    :rtype: cute.TensorSSA or cutlass.Float32
+    :rtype: cute.TensorSSA or Float32
     """
     if isinstance(x, cute.TensorSSA):
-        res = cute.make_fragment(x.shape, cutlass.Float32)
+        res = cute.make_fragment(x.shape, Float32)
         res.store(x)
         for i in range(cute.size(x.shape)):
             res[i] = cute.arch.exp2(res[i])
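
The new online_softmax_reduce combines per-chunk (max, sum-of-exponentials) pairs with the standard online-softmax merge rule: the global max is the max of the partial maxes, and each partial sum is rescaled by exp(m_i - m) before summing (the kernel does this in base 2 via exp2f and log2(e)). A NumPy reference of the merge math only, not the device code:

import numpy as np

def merge_online_softmax(partials):
    # partials: (max_i, sum_i) where sum_i = sum(exp(chunk - max_i))
    m = max(m_i for m_i, _ in partials)
    # Rescale each partial sum to the global max before adding:
    # sum(exp(x - m)) = sum_i sum_i * exp(m_i - m)
    s = sum(s_i * np.exp(m_i - m) for m_i, s_i in partials)
    return m, s

x = np.random.randn(256).astype(np.float32)
parts = [(c.max(), np.exp(c - c.max()).sum()) for c in np.split(x, 8)]
m, s = merge_online_softmax(parts)
assert m == x.max() and np.isclose(s, np.exp(x - x.max()).sum(), rtol=1e-5)
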
@@ -161,11 +318,11 @@ def exp2f(x: cute.TensorSSA | cutlass.Float32) -> cute.TensorSSA | cutlass.Float
 
 
 @dsl_user_op
-def log2f(a: float | cutlass.Float32, *, loc=None, ip=None) -> cutlass.Float32:
-    return cutlass.Float32(
+def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
+    return Float32(
         llvm.inline_asm(
             T.f32(),
-            [cutlass.Float32(a).ir_value(loc=loc, ip=ip)],
+            [Float32(a).ir_value(loc=loc, ip=ip)],
             "lg2.approx.ftz.f32 $0, $1;",
             "=f,f",
             has_side_effects=False,
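
lg2.approx.ftz.f32 yields a base-2 logarithm; natural-log and natural-exp results throughout these kernels follow from the base-change identities the code applies with log2_e = math.log2(math.e). A quick host-side check of those identities (illustration only):

import math

x, m = 7.5, 2.0
log2_e = math.log2(math.e)
assert math.isclose(math.log(x), math.log2(x) / log2_e)       # ln via lg2
assert math.isclose(math.exp(x - m), 2.0 ** ((x - m) * log2_e))  # exp via exp2
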
@@ -176,11 +333,11 @@ def log2f(a: float | cutlass.Float32, *, loc=None, ip=None) -> cutlass.Float32:
 
 
 @dsl_user_op
-def rsqrt(a: float | cute.Float32, *, loc=None, ip=None) -> cute.Float32:
-    return cute.Float32(
+def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
+    return Float32(
         llvm.inline_asm(
             T.f32(),
-            [cute.Float32(a).ir_value(loc=loc, ip=ip)],
+            [Float32(a).ir_value(loc=loc, ip=ip)],
             "rsqrt.approx.ftz.f32 $0, $1;",
             "=f,f",
             has_side_effects=False,
@@ -188,3 +345,60 @@ def rsqrt(a: float | cute.Float32, *, loc=None, ip=None) -> cute.Float32:
             asm_dialect=llvm.AsmDialect.AD_ATT,
         )
     )
+
+
+def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
+    # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
+    tApA = cute.make_fragment(
+        cute.make_layout(
+            (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
+            stride=(cute.size(tAcA, mode=[2]), 0, 1),
+        ),
+        cutlass.Boolean,
+    )
+    for rest_v in range(tApA.shape[0]):
+        for rest_k in range(tApA.shape[2]):
+            tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
+    return tApA
+
+
+@cute.jit
+def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) -> None:
+    """Fill out-of-bounds values in shared memory tensor.
+
+    Args:
+        tXsX: Shared memory tensor to fill
+        tXpX: Predicate tensor indicating valid elements
+        fill_value: Value to fill OOB locations with
+    """
+    tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), 0, 0])
+    tXrX_fill.fill(fill_value)
+    for rest_v in range(tXpX.shape[0]):
+        for rest_k in range(tXpX.shape[2]):
+            if not tXpX[rest_v, 0, rest_k]:
+                cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
+
+
+@dsl_user_op
+def f32x2_to_i64(a: Float32, b: Float32, *, loc=None, ip=None) -> cutlass.Int64:
+    vec_f32x2 = vector.from_elements(
+        T.vector(2, T.f32()), (a.ir_value(), b.ir_value()), loc=loc, ip=ip
+    )
+    vec_i64x1 = vector.bitcast(T.vector(1, T.i64()), vec_f32x2)
+    res = cutlass.Int64(
+        vector.extract(vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
+    )
+    return res
+
+
+@dsl_user_op
+def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
+    vec_i64x1 = vector.from_elements(T.vector(1, T.i64()), (c.ir_value(),), loc=loc, ip=ip)
+    vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1)
+    res0 = Float32(
+        vector.extract(vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
+    )
+    res1 = Float32(
+        vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)
+    )
+    return res1, res0
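
f32x2_to_i64 and i64_to_f32x2 let a (max, sum) pair travel through the Int64 reduction buffer as one 8-byte value, so a single remote store carries both halves of the online-softmax state. The bitcast round trip can be checked on the host with plain struct packing (a sketch; assumes the same little-endian lane order as the vector bitcast):

import struct

def f32x2_to_i64_ref(a: float, b: float) -> int:
    # Two f32 lanes reinterpreted as one signed 64-bit integer
    return struct.unpack("<q", struct.pack("<2f", a, b))[0]

def i64_to_f32x2_ref(c: int) -> tuple[float, float]:
    return struct.unpack("<2f", struct.pack("<q", c))

assert i64_to_f32x2_ref(f32x2_to_i64_ref(1.5, 1024.0)) == (1.5, 1024.0)
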
{quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/METADATA
CHANGED

@@ -1,8 +1,11 @@
 Metadata-Version: 2.4
 Name: quack-kernels
-Version: 0.1.1
+Version: 0.1.3
 Requires-Python: >=3.9
 License-File: LICENSE
 Requires-Dist: nvidia-cutlass-dsl==4.0.0
 Requires-Dist: torch
+Provides-Extra: dev
+Requires-Dist: pre-commit; extra == "dev"
+Requires-Dist: ruff; extra == "dev"
 Dynamic: license-file
quack_kernels-0.1.3.dist-info/RECORD
ADDED

@@ -0,0 +1,11 @@
+quack/__init__.py,sha256=aUR7drzgaqmbzw9H_eoFselMUVQVF3BHc9VOzZg5d-Q,203
+quack/cross_entropy.py,sha256=_Xlyifd_YS8LaYxYlZEsuBfsi8zTH4At3i9DDggGCf8,9319
+quack/reduction_base.py,sha256=nrRsXwTpLVQkPp2Gr_FgHRPnifqkMHRodve5ciHzx58,3667
+quack/rmsnorm.py,sha256=YqGTTKHHXYzw3xnnjBRfaN9TDlhG8D_fSI9CHKAU40A,10548
+quack/softmax.py,sha256=mWaUfaY6PBtO1ioYxXxS-yodQmcBNGasWVMUg9G066Y,15938
+quack/utils.py,sha256=1-HMcFTEvGdAtqC3ucQGZ3DLa_PoJQsqwYlKd9bcXO8,15347
+quack_kernels-0.1.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+quack_kernels-0.1.3.dist-info/METADATA,sha256=DDuEKHLjFx9dFTQV5YtXsnKVFZVoueO7NwhcwOtpw6g,284
+quack_kernels-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+quack_kernels-0.1.3.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+quack_kernels-0.1.3.dist-info/RECORD,,
quack_kernels-0.1.1.dist-info/RECORD
DELETED

@@ -1,10 +0,0 @@
-quack/__init__.py,sha256=y3Oa4OVPqaGU_P1miI435DzfpMgIwKVmU8-Eogv58jQ,137
-quack/cross_entropy.py,sha256=V0kG8DCNh2735sPIDwe68NB50rAqDF3XQApnGyo-sKg,9220
-quack/rmsnorm.py,sha256=RNqcT-q4uvMbF6ejpzuqQH8l8VVuTRlnueXf28V47sc,11954
-quack/softmax.py,sha256=QABgOESH5JjDm3yuUkyZZKXXpzn7CTuMSs0NEBnFD80,8536
-quack/utils.py,sha256=ofV7QLDuq80h3nEA3TwZW-ti8CnYwMgnz1dpxpvhHpk,6859
-quack_kernels-0.1.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-quack_kernels-0.1.1.dist-info/METADATA,sha256=XG3zS0_q48TzkoR7CemzaJGVYHS731yVOrzH49_uRK8,186
-quack_kernels-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-quack_kernels-0.1.1.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
-quack_kernels-0.1.1.dist-info/RECORD,,

{quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/WHEEL
File without changes

{quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/licenses/LICENSE
File without changes

{quack_kernels-0.1.1.dist-info → quack_kernels-0.1.3.dist-info}/top_level.txt
File without changes