quack-kernels 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -1
- quack/cross_entropy.py +197 -166
- quack/reduction_base.py +98 -0
- quack/rmsnorm.py +211 -181
- quack/softmax.py +409 -156
- quack/utils.py +197 -39
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.3.dist-info}/METADATA +4 -1
- quack_kernels-0.1.3.dist-info/RECORD +11 -0
- quack_kernels-0.1.2.dist-info/RECORD +0 -10
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.2.dist-info → quack_kernels-0.1.3.dist-info}/top_level.txt +0 -0
quack/utils.py
CHANGED
|
@@ -2,13 +2,14 @@
|
|
|
2
2
|
|
|
3
3
|
import operator
|
|
4
4
|
import math
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import Callable, Optional, Tuple
|
|
6
6
|
|
|
7
7
|
import cutlass
|
|
8
8
|
import cutlass.cute as cute
|
|
9
9
|
|
|
10
|
+
from cutlass import Float32
|
|
10
11
|
from cutlass.cutlass_dsl import T, dsl_user_op
|
|
11
|
-
from cutlass._mlir.dialects import
|
|
12
|
+
from cutlass._mlir.dialects import llvm, vector
|
|
12
13
|
from cutlass.cute.runtime import from_dlpack
|
|
13
14
|
|
|
14
15
|
|
|
@@ -39,7 +40,7 @@ def min_constexpr(
|
|
|
39
40
|
def warp_reduce(
|
|
40
41
|
val: cute.TensorSSA | cute.Numeric,
|
|
41
42
|
op: Callable,
|
|
42
|
-
width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE
|
|
43
|
+
width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
|
|
43
44
|
) -> cute.TensorSSA | cute.Numeric:
|
|
44
45
|
if isinstance(val, cute.TensorSSA):
|
|
45
46
|
res = cute.make_fragment(val.shape, val.dtype)
|
|
@@ -54,9 +55,10 @@ def warp_reduce(
|
|
|
54
55
|
|
|
55
56
|
|
|
56
57
|
@cute.jit
|
|
57
|
-
def block_reduce(
|
|
58
|
-
|
|
59
|
-
|
|
58
|
+
def block_reduce(
|
|
59
|
+
val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0
|
|
60
|
+
) -> cute.Numeric:
|
|
61
|
+
"""reduction_buffer has shape (num_warps / warp_per_row, warps_per_row)"""
|
|
60
62
|
lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
|
|
61
63
|
warps_per_row = cute.size(reduction_buffer.shape[1])
|
|
62
64
|
row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
|
|
@@ -75,9 +77,10 @@ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cut
|
|
|
75
77
|
|
|
76
78
|
|
|
77
79
|
@dsl_user_op
|
|
78
|
-
def set_block_rank(
|
|
79
|
-
|
|
80
|
-
|
|
80
|
+
def set_block_rank(
|
|
81
|
+
smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32, *, loc=None, ip=None
|
|
82
|
+
) -> cutlass.Int32:
|
|
83
|
+
"""Map the given smem pointer to the address at another CTA rank in the cluster."""
|
|
81
84
|
smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
|
|
82
85
|
return cutlass.Int32(
|
|
83
86
|
llvm.inline_asm(
|
|
@@ -94,16 +97,29 @@ def set_block_rank(smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32,
|
|
|
94
97
|
|
|
95
98
|
@dsl_user_op
|
|
96
99
|
def store_shared_remote(
|
|
97
|
-
val: float |
|
|
98
|
-
|
|
100
|
+
val: float | Float32 | cutlass.Int64,
|
|
101
|
+
smem_ptr: cute.Pointer,
|
|
102
|
+
mbar_ptr: cute.Pointer,
|
|
103
|
+
peer_cta_rank_in_cluster: cute.typing.Int,
|
|
104
|
+
*,
|
|
105
|
+
loc=None,
|
|
106
|
+
ip=None,
|
|
99
107
|
) -> None:
|
|
100
|
-
remote_smem_ptr_i32 = set_block_rank(
|
|
101
|
-
|
|
108
|
+
remote_smem_ptr_i32 = set_block_rank(
|
|
109
|
+
smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
|
|
110
|
+
).ir_value()
|
|
111
|
+
remote_mbar_ptr_i32 = set_block_rank(
|
|
112
|
+
mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
|
|
113
|
+
).ir_value()
|
|
114
|
+
if isinstance(val, float):
|
|
115
|
+
val = Float32(val)
|
|
116
|
+
assert isinstance(val, (Float32, cutlass.Int64)), "val must be Float32 or Int64"
|
|
117
|
+
suffix = "f32" if isinstance(val, Float32) else "s64"
|
|
102
118
|
llvm.inline_asm(
|
|
103
119
|
None,
|
|
104
|
-
[remote_smem_ptr_i32,
|
|
105
|
-
"st.async.shared::cluster.mbarrier::complete_tx::bytes.
|
|
106
|
-
"r,f,r",
|
|
120
|
+
[remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
|
|
121
|
+
f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];",
|
|
122
|
+
f"r,{'f' if isinstance(val, Float32) else 'l'},r",
|
|
107
123
|
has_side_effects=True,
|
|
108
124
|
is_align_stack=False,
|
|
109
125
|
asm_dialect=llvm.AsmDialect.AD_ATT,
|
|
@@ -111,17 +127,24 @@ def store_shared_remote(
|
|
|
111
127
|
|
|
112
128
|
|
|
113
129
|
@cute.jit
|
|
114
|
-
def cluster_reduce(
|
|
115
|
-
|
|
116
|
-
|
|
130
|
+
def cluster_reduce(
|
|
131
|
+
val: cute.Numeric,
|
|
132
|
+
op: Callable,
|
|
133
|
+
reduction_buffer: cute.Tensor,
|
|
134
|
+
mbar_ptr: cute.Pointer,
|
|
135
|
+
init_val: cute.Numeric = 0.0,
|
|
136
|
+
) -> cute.Numeric:
|
|
137
|
+
"""reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
|
|
117
138
|
cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
|
|
118
139
|
lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
|
|
119
140
|
warps_per_row, cluster_n = reduction_buffer.shape[1]
|
|
120
141
|
row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
|
|
121
142
|
if lane_idx < cluster_n:
|
|
122
143
|
store_shared_remote(
|
|
123
|
-
val,
|
|
124
|
-
|
|
144
|
+
val,
|
|
145
|
+
elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
|
|
146
|
+
mbar_ptr,
|
|
147
|
+
peer_cta_rank_in_cluster=lane_idx,
|
|
125
148
|
)
|
|
126
149
|
cute.arch.mbarrier_wait(mbar_ptr, phase=0)
|
|
127
150
|
block_reduce_val = init_val
|
|
@@ -134,9 +157,14 @@ def cluster_reduce(val: cute.Numeric, op: Callable, reduction_buffer: cute.Tenso
|
|
|
134
157
|
|
|
135
158
|
|
|
136
159
|
@cute.jit
|
|
137
|
-
def block_or_cluster_reduce(
|
|
138
|
-
|
|
139
|
-
|
|
160
|
+
def block_or_cluster_reduce(
|
|
161
|
+
val: cute.Numeric,
|
|
162
|
+
op: Callable,
|
|
163
|
+
reduction_buffer: cute.Tensor,
|
|
164
|
+
mbar_ptr: Optional[cute.Pointer],
|
|
165
|
+
init_val: cute.Numeric = 0.0,
|
|
166
|
+
) -> cute.Numeric:
|
|
167
|
+
"""Perform either block or cluster reduction based on whether mbar_ptr is provided."""
|
|
140
168
|
if cutlass.const_expr(mbar_ptr is None):
|
|
141
169
|
return block_reduce(val, op, reduction_buffer, init_val=init_val)
|
|
142
170
|
else:
|
|
@@ -153,15 +181,14 @@ def row_reduce(
|
|
|
153
181
|
init_val: cute.Numeric = 0.0,
|
|
154
182
|
hook_fn: Optional[Callable] = None,
|
|
155
183
|
) -> cute.Numeric:
|
|
156
|
-
"""reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))
|
|
157
|
-
"""
|
|
184
|
+
"""reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
|
|
158
185
|
if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
|
|
159
186
|
val = x.reduce(op, init_val=init_val, reduction_profile=0)
|
|
160
187
|
else:
|
|
161
188
|
val = x
|
|
162
189
|
warp_op = {
|
|
163
190
|
cute.ReductionOp.ADD: operator.add,
|
|
164
|
-
cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype ==
|
|
191
|
+
cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
|
|
165
192
|
cute.ReductionOp.MIN: min,
|
|
166
193
|
cute.ReductionOp.MUL: operator.mul,
|
|
167
194
|
}[op]
|
|
@@ -174,7 +201,9 @@ def row_reduce(
|
|
|
174
201
|
hook_fn()
|
|
175
202
|
if cutlass.const_expr(reduction_buffer is not None):
|
|
176
203
|
warps_per_row, cluster_n = reduction_buffer.shape[1]
|
|
177
|
-
assert
|
|
204
|
+
assert (
|
|
205
|
+
cluster_n == 1 or mbar_ptr is not None
|
|
206
|
+
), "mbar_ptr must be provided for cluster reduction"
|
|
178
207
|
if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
|
|
179
208
|
val = block_or_cluster_reduce(
|
|
180
209
|
val, warp_op, reduction_buffer, mbar_ptr, init_val=init_val
|
|
@@ -182,17 +211,104 @@ def row_reduce(
|
|
|
182
211
|
return val
|
|
183
212
|
|
|
184
213
|
|
|
185
|
-
|
|
186
|
-
def
|
|
214
|
+
@cute.jit
|
|
215
|
+
def online_softmax_reduce(
|
|
216
|
+
x: cute.TensorSSA,
|
|
217
|
+
threads_per_row: cutlass.Constexpr[int],
|
|
218
|
+
reduction_buffer: Optional[cute.Tensor] = None,
|
|
219
|
+
mbar_ptr: Optional[cute.Pointer] = None,
|
|
220
|
+
hook_fn: Optional[Callable] = None,
|
|
221
|
+
return_exp_x: bool = False,
|
|
222
|
+
) -> [Float32, Float32, Optional[cute.TensorSSA]]:
|
|
223
|
+
assert x.dtype == Float32, "x must be of type Float32"
|
|
224
|
+
"""reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2)"""
|
|
225
|
+
max_x = warp_reduce(
|
|
226
|
+
x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
|
|
227
|
+
cute.arch.fmax,
|
|
228
|
+
width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
|
|
229
|
+
)
|
|
230
|
+
log2_e = math.log2(math.e)
|
|
231
|
+
exp_x = exp2f(x * log2_e - (max_x * log2_e))
|
|
232
|
+
# exp_x = exp2f((x - max_x) * log2_e)
|
|
233
|
+
sum_exp_x = warp_reduce(
|
|
234
|
+
exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
|
|
235
|
+
operator.add,
|
|
236
|
+
width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
|
|
237
|
+
)
|
|
238
|
+
if cutlass.const_expr(hook_fn is not None):
|
|
239
|
+
hook_fn()
|
|
240
|
+
if cutlass.const_expr(reduction_buffer is not None):
|
|
241
|
+
warps_per_row, cluster_n = reduction_buffer.shape[1]
|
|
242
|
+
assert (
|
|
243
|
+
cluster_n == 1 or mbar_ptr is not None
|
|
244
|
+
), "mbar_ptr must be provided for cluster reduction"
|
|
245
|
+
if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
|
|
246
|
+
assert (
|
|
247
|
+
reduction_buffer.element_type == cutlass.Int64
|
|
248
|
+
), "reduction_buffer must be of type cute.Int64"
|
|
249
|
+
lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
|
|
250
|
+
row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
|
|
251
|
+
if cutlass.const_expr(mbar_ptr is None):
|
|
252
|
+
if lane_idx == 0:
|
|
253
|
+
reduction_buffer[row_idx, col_idx] = f32x2_to_i64(max_x, sum_exp_x)
|
|
254
|
+
cute.arch.barrier()
|
|
255
|
+
max_x_single_warp = -Float32.inf
|
|
256
|
+
sum_exp_x = 0.0
|
|
257
|
+
if lane_idx < warps_per_row:
|
|
258
|
+
max_x_single_warp, sum_exp_x = i64_to_f32x2(reduction_buffer[row_idx, lane_idx])
|
|
259
|
+
max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
|
|
260
|
+
sum_exp_x *= exp2f((max_x_single_warp - max_x_final) * log2_e)
|
|
261
|
+
sum_exp_x = warp_reduce(sum_exp_x, operator.add)
|
|
262
|
+
if cutlass.const_expr(return_exp_x):
|
|
263
|
+
exp_x *= exp2f((max_x - max_x_final) * log2_e)
|
|
264
|
+
max_x = max_x_final
|
|
265
|
+
else:
|
|
266
|
+
cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
|
|
267
|
+
if lane_idx < cluster_n:
|
|
268
|
+
store_shared_remote(
|
|
269
|
+
f32x2_to_i64(max_x, sum_exp_x),
|
|
270
|
+
elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
|
|
271
|
+
mbar_ptr,
|
|
272
|
+
peer_cta_rank_in_cluster=lane_idx,
|
|
273
|
+
)
|
|
274
|
+
cute.arch.mbarrier_wait(mbar_ptr, phase=0)
|
|
275
|
+
num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
|
|
276
|
+
max_x_single_warp = cute.make_fragment(num_iter, Float32)
|
|
277
|
+
max_x_single_warp.fill(-Float32.inf)
|
|
278
|
+
sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
|
|
279
|
+
sum_exp_x_single_warp.fill(0.0)
|
|
280
|
+
for i in cutlass.range_constexpr(num_iter):
|
|
281
|
+
idx = lane_idx + i * cute.arch.WARP_SIZE
|
|
282
|
+
if idx < cute.size(reduction_buffer, mode=[1]):
|
|
283
|
+
max_x_single_warp[i], sum_exp_x_single_warp[i] = i64_to_f32x2(
|
|
284
|
+
reduction_buffer[row_idx, idx]
|
|
285
|
+
)
|
|
286
|
+
max_x_final = max_x_single_warp.load().reduce(
|
|
287
|
+
cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0
|
|
288
|
+
)
|
|
289
|
+
max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
|
|
290
|
+
sum_exp_x = 0.0
|
|
291
|
+
for i in cutlass.range_constexpr(num_iter):
|
|
292
|
+
sum_exp_x += sum_exp_x_single_warp[i] * exp2f(
|
|
293
|
+
(max_x_single_warp[i] - max_x_final) * log2_e
|
|
294
|
+
)
|
|
295
|
+
sum_exp_x = warp_reduce(sum_exp_x, operator.add)
|
|
296
|
+
if cutlass.const_expr(return_exp_x):
|
|
297
|
+
exp_x *= exp2f((max_x - max_x_final) * log2_e)
|
|
298
|
+
max_x = max_x_final
|
|
299
|
+
return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
|
|
187
303
|
"""exp2f calculation for both vector and scalar.
|
|
188
304
|
|
|
189
305
|
:param x: input value
|
|
190
|
-
:type x: cute.TensorSSA or
|
|
306
|
+
:type x: cute.TensorSSA or Float32
|
|
191
307
|
:return: exp2 value
|
|
192
|
-
:rtype: cute.TensorSSA or
|
|
308
|
+
:rtype: cute.TensorSSA or Float32
|
|
193
309
|
"""
|
|
194
310
|
if isinstance(x, cute.TensorSSA):
|
|
195
|
-
res = cute.make_fragment(x.shape,
|
|
311
|
+
res = cute.make_fragment(x.shape, Float32)
|
|
196
312
|
res.store(x)
|
|
197
313
|
for i in range(cute.size(x.shape)):
|
|
198
314
|
res[i] = cute.arch.exp2(res[i])
|
|
@@ -202,11 +318,11 @@ def exp2f(x: cute.TensorSSA | cutlass.Float32) -> cute.TensorSSA | cutlass.Float
|
|
|
202
318
|
|
|
203
319
|
|
|
204
320
|
@dsl_user_op
|
|
205
|
-
def log2f(a: float |
|
|
206
|
-
return
|
|
321
|
+
def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
|
322
|
+
return Float32(
|
|
207
323
|
llvm.inline_asm(
|
|
208
324
|
T.f32(),
|
|
209
|
-
[
|
|
325
|
+
[Float32(a).ir_value(loc=loc, ip=ip)],
|
|
210
326
|
"lg2.approx.ftz.f32 $0, $1;",
|
|
211
327
|
"=f,f",
|
|
212
328
|
has_side_effects=False,
|
|
@@ -217,11 +333,11 @@ def log2f(a: float | cutlass.Float32, *, loc=None, ip=None) -> cutlass.Float32:
|
|
|
217
333
|
|
|
218
334
|
|
|
219
335
|
@dsl_user_op
|
|
220
|
-
def rsqrt(a: float |
|
|
221
|
-
return
|
|
336
|
+
def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
|
337
|
+
return Float32(
|
|
222
338
|
llvm.inline_asm(
|
|
223
339
|
T.f32(),
|
|
224
|
-
[
|
|
340
|
+
[Float32(a).ir_value(loc=loc, ip=ip)],
|
|
225
341
|
"rsqrt.approx.ftz.f32 $0, $1;",
|
|
226
342
|
"=f,f",
|
|
227
343
|
has_side_effects=False,
|
|
@@ -244,3 +360,45 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
|
|
|
244
360
|
for rest_k in range(tApA.shape[2]):
|
|
245
361
|
tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
|
|
246
362
|
return tApA
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
@cute.jit
|
|
366
|
+
def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) -> None:
|
|
367
|
+
"""Fill out-of-bounds values in shared memory tensor.
|
|
368
|
+
|
|
369
|
+
Args:
|
|
370
|
+
tXsX: Shared memory tensor to fill
|
|
371
|
+
tXpX: Predicate tensor indicating valid elements
|
|
372
|
+
fill_value: Value to fill OOB locations with
|
|
373
|
+
"""
|
|
374
|
+
tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), 0, 0])
|
|
375
|
+
tXrX_fill.fill(fill_value)
|
|
376
|
+
for rest_v in range(tXpX.shape[0]):
|
|
377
|
+
for rest_k in range(tXpX.shape[2]):
|
|
378
|
+
if not tXpX[rest_v, 0, rest_k]:
|
|
379
|
+
cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
@dsl_user_op
|
|
383
|
+
def f32x2_to_i64(a: Float32, b: Float32, *, loc=None, ip=None) -> cutlass.Int64:
|
|
384
|
+
vec_f32x2 = vector.from_elements(
|
|
385
|
+
T.vector(2, T.f32()), (a.ir_value(), b.ir_value()), loc=loc, ip=ip
|
|
386
|
+
)
|
|
387
|
+
vec_i64x1 = vector.bitcast(T.vector(1, T.i64()), vec_f32x2)
|
|
388
|
+
res = cutlass.Int64(
|
|
389
|
+
vector.extract(vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
|
|
390
|
+
)
|
|
391
|
+
return res
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
@dsl_user_op
|
|
395
|
+
def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
|
|
396
|
+
vec_i64x1 = vector.from_elements(T.vector(1, T.i64()), (c.ir_value(),), loc=loc, ip=ip)
|
|
397
|
+
vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1)
|
|
398
|
+
res0 = Float32(
|
|
399
|
+
vector.extract(vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
|
|
400
|
+
)
|
|
401
|
+
res1 = Float32(
|
|
402
|
+
vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)
|
|
403
|
+
)
|
|
404
|
+
return res0, res1
|
|
@@ -1,8 +1,11 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: quack-kernels
|
|
3
|
-
Version: 0.1.2
|
|
3
|
+
Version: 0.1.3
|
|
4
4
|
Requires-Python: >=3.9
|
|
5
5
|
License-File: LICENSE
|
|
6
6
|
Requires-Dist: nvidia-cutlass-dsl==4.0.0
|
|
7
7
|
Requires-Dist: torch
|
|
8
|
+
Provides-Extra: dev
|
|
9
|
+
Requires-Dist: pre-commit; extra == "dev"
|
|
10
|
+
Requires-Dist: ruff; extra == "dev"
|
|
8
11
|
Dynamic: license-file
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
quack/__init__.py,sha256=aUR7drzgaqmbzw9H_eoFselMUVQVF3BHc9VOzZg5d-Q,203
|
|
2
|
+
quack/cross_entropy.py,sha256=_Xlyifd_YS8LaYxYlZEsuBfsi8zTH4At3i9DDggGCf8,9319
|
|
3
|
+
quack/reduction_base.py,sha256=nrRsXwTpLVQkPp2Gr_FgHRPnifqkMHRodve5ciHzx58,3667
|
|
4
|
+
quack/rmsnorm.py,sha256=YqGTTKHHXYzw3xnnjBRfaN9TDlhG8D_fSI9CHKAU40A,10548
|
|
5
|
+
quack/softmax.py,sha256=mWaUfaY6PBtO1ioYxXxS-yodQmcBNGasWVMUg9G066Y,15938
|
|
6
|
+
quack/utils.py,sha256=1-HMcFTEvGdAtqC3ucQGZ3DLa_PoJQsqwYlKd9bcXO8,15347
|
|
7
|
+
quack_kernels-0.1.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
8
|
+
quack_kernels-0.1.3.dist-info/METADATA,sha256=DDuEKHLjFx9dFTQV5YtXsnKVFZVoueO7NwhcwOtpw6g,284
|
|
9
|
+
quack_kernels-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
10
|
+
quack_kernels-0.1.3.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
|
|
11
|
+
quack_kernels-0.1.3.dist-info/RECORD,,
|
|
@@ -1,10 +0,0 @@
|
|
|
1
|
-
quack/__init__.py,sha256=Nf01m1CGrOjSkqGJom6P65hSLkckljRMhlkSoqqlO9k,137
|
|
2
|
-
quack/cross_entropy.py,sha256=gdo8sR9KT5TsrShbgAmy-bwRZLu0gTs_ykXBF2RMbFI,8900
|
|
3
|
-
quack/rmsnorm.py,sha256=JhwJSAPDDpB_hV90xU9ymiLU-zu4WScrSHc5JX2JarY,10470
|
|
4
|
-
quack/softmax.py,sha256=C8e8ZNaF5ePJ1NlrWZN1goCcvsx1C60FWlRyuFCcYoM,7737
|
|
5
|
-
quack/utils.py,sha256=PRdu-P7azA_PeHUNdtoy1zyxZwg_QyVrSiVwE1iXaWo,8961
|
|
6
|
-
quack_kernels-0.1.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
7
|
-
quack_kernels-0.1.2.dist-info/METADATA,sha256=3WjugLu1IhLlgsg2qUcLBZq1HI4-BIyyJIuQc5Hk-rU,186
|
|
8
|
-
quack_kernels-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
9
|
-
quack_kernels-0.1.2.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
|
|
10
|
-
quack_kernels-0.1.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|