quack-kernels 0.1.2-py3-none-any.whl → 0.1.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/utils.py CHANGED
@@ -2,13 +2,14 @@
 
 import operator
 import math
-from typing import Type, Callable, Optional
+from typing import Callable, Optional, Tuple
 
 import cutlass
 import cutlass.cute as cute
 
+from cutlass import Float32
 from cutlass.cutlass_dsl import T, dsl_user_op
-from cutlass._mlir.dialects import nvvm, llvm
+from cutlass._mlir.dialects import llvm, vector
 from cutlass.cute.runtime import from_dlpack
 
 
@@ -36,27 +37,29 @@ def min_constexpr(
     return a if a < b else b
 
 
+@cute.jit
 def warp_reduce(
     val: cute.TensorSSA | cute.Numeric,
     op: Callable,
-    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE
+    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
 ) -> cute.TensorSSA | cute.Numeric:
-    if isinstance(val, cute.TensorSSA):
+    if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
         res = cute.make_fragment(val.shape, val.dtype)
         res.store(val)
-        for i in range(cute.size(val.shape)):
+        for i in cutlass.range_constexpr(cute.size(val.shape)):
             res[i] = warp_reduce(res[i], op, width)
         return res.load()
     else:
-        for i in range(int(math.log2(width))):
+        for i in cutlass.range_constexpr(int(math.log2(width))):
             val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
         return val
 
 
 @cute.jit
-def block_reduce(val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0) -> cute.Numeric:
-    """reduction_buffer has shape (num_warps / warp_per_row, warps_per_row)
-    """
+def block_reduce(
+    val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0
+) -> cute.Numeric:
+    """reduction_buffer has shape (num_warps / warp_per_row, warps_per_row)"""
     lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
     warps_per_row = cute.size(reduction_buffer.shape[1])
     row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
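
Note on the `warp_reduce` change: the loop now uses `cutlass.range_constexpr`, so the log2(width) butterfly rounds are unrolled at trace time. The butterfly itself is unchanged: in round r each lane combines its value with the lane whose index differs in bit r, so after all rounds every lane holds the full reduction. A host-side NumPy sketch of that exchange (an XOR-indexed gather stands in for `cute.arch.shuffle_sync_bfly`, whose semantics are assumed here):

    import math
    import numpy as np

    def simulated_warp_reduce(lanes, op, width):
        # lanes[i] plays the role of the value held by lane i; width is a power of two.
        lanes = lanes.copy()
        for r in range(int(math.log2(width))):
            partner = np.arange(len(lanes)) ^ (1 << r)  # butterfly: lane i pairs with i XOR 2^r
            lanes = op(lanes, lanes[partner])
        return lanes

    vals = np.arange(32, dtype=np.float32)
    assert np.all(simulated_warp_reduce(vals, np.add, 32) == vals.sum())  # every lane agrees
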
@@ -75,9 +78,10 @@ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cut
 
 
 @dsl_user_op
-def set_block_rank(smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32, *, loc=None, ip=None) -> cutlass.Int32:
-    """Map the given smem pointer to the address at another CTA rank in the cluster.
-    """
+def set_block_rank(
+    smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32, *, loc=None, ip=None
+) -> cutlass.Int32:
+    """Map the given smem pointer to the address at another CTA rank in the cluster."""
     smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
     return cutlass.Int32(
         llvm.inline_asm(
@@ -94,16 +98,29 @@ def set_block_rank(smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32,
 
 @dsl_user_op
 def store_shared_remote(
-    val: float | cute.Float32, smem_ptr: cute.Pointer, mbar_ptr: cute.Pointer,
-    peer_cta_rank_in_cluster: cute.typing.Int, *, loc=None, ip=None
+    val: float | Float32 | cutlass.Int64,
+    smem_ptr: cute.Pointer,
+    mbar_ptr: cute.Pointer,
+    peer_cta_rank_in_cluster: cute.typing.Int,
+    *,
+    loc=None,
+    ip=None,
 ) -> None:
-    remote_smem_ptr_i32 = set_block_rank(smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip).ir_value()
-    remote_mbar_ptr_i32 = set_block_rank(mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip).ir_value()
+    remote_smem_ptr_i32 = set_block_rank(
+        smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
+    ).ir_value()
+    remote_mbar_ptr_i32 = set_block_rank(
+        mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
+    ).ir_value()
+    if cutlass.const_expr(isinstance(val, float)):
+        val = Float32(val)
+    assert isinstance(val, (Float32, cutlass.Int64)), "val must be Float32 or Int64"
+    suffix = "f32" if cutlass.const_expr(isinstance(val, Float32)) else "s64"
     llvm.inline_asm(
         None,
-        [remote_smem_ptr_i32, cute.Float32(val).ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
-        "st.async.shared::cluster.mbarrier::complete_tx::bytes.f32 [$0], $1, [$2];",
-        "r,f,r",
+        [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
+        f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];",
+        f"r,{'f' if cutlass.const_expr(isinstance(val, Float32)) else 'l'},r",
         has_side_effects=True,
         is_align_stack=False,
         asm_dialect=llvm.AsmDialect.AD_ATT,
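
The substantive change to `store_shared_remote` is that the payload may now be a packed `Int64` as well as a `Float32`, with the PTX store suffix and the inline-asm operand constraint chosen to match. A hypothetical plain-Python restatement of that selection ("f" is the 32-bit float register constraint, "l" the 64-bit register constraint; the strings come straight from the diff above):

    def ptx_store_variant(is_f32: bool) -> tuple[str, str]:
        # Returns (PTX type suffix, LLVM inline-asm constraint) for the payload.
        return ("f32", "f") if is_f32 else ("s64", "l")

    assert ptx_store_variant(True) == ("f32", "f")
    assert ptx_store_variant(False) == ("s64", "l")

The 64-bit path exists so that `online_softmax_reduce` further down can ship a (max, sum) pair to a peer CTA in a single remote store.
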
@@ -111,17 +128,24 @@ def store_shared_remote(
 
 
 @cute.jit
-def cluster_reduce(val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, mbar_ptr: cute.Pointer, init_val: cute.Numeric = 0.0) -> cute.Numeric:
-    """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))
-    """
+def cluster_reduce(
+    val: cute.Numeric,
+    op: Callable,
+    reduction_buffer: cute.Tensor,
+    mbar_ptr: cute.Pointer,
+    init_val: cute.Numeric = 0.0,
+) -> cute.Numeric:
+    """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
     cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
     lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
     warps_per_row, cluster_n = reduction_buffer.shape[1]
     row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
     if lane_idx < cluster_n:
         store_shared_remote(
-            val, elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
-            mbar_ptr, peer_cta_rank_in_cluster=lane_idx
+            val,
+            elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
+            mbar_ptr,
+            peer_cta_rank_in_cluster=lane_idx,
         )
     cute.arch.mbarrier_wait(mbar_ptr, phase=0)
     block_reduce_val = init_val
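
`cluster_reduce` follows a broadcast-then-reduce pattern: each CTA's first `cluster_n` lanes store the CTA's per-warp partials into the mirrored buffer slot of every peer CTA (one `store_shared_remote` per peer rank), the mbarrier wait covers all remote stores landing, and each CTA then reduces its now-complete buffer locally. A host-side NumPy sketch of why every CTA ends up with the same answer (assumed semantics; arrays stand in for shared memory):

    import numpy as np

    cluster_n, warps_per_row = 4, 2
    partials = np.random.rand(cluster_n, warps_per_row).astype(np.float32)  # per-CTA warp partials
    # After the exchange, every CTA's buffer holds all cluster_n * warps_per_row partials.
    buffer_per_cta = np.tile(partials.reshape(-1), (cluster_n, 1))
    assert np.allclose(buffer_per_cta.sum(axis=1), partials.sum())  # identical total in each CTA
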
@@ -134,9 +158,14 @@ def cluster_reduce(val: cute.Numeric, op: Callable, reduction_buffer: cute.Tenso
 
 
 @cute.jit
-def block_or_cluster_reduce(val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, mbar_ptr: Optional[cute.Pointer], init_val: cute.Numeric = 0.0) -> cute.Numeric:
-    """Perform either block or cluster reduction based on whether mbar_ptr is provided.
-    """
+def block_or_cluster_reduce(
+    val: cute.Numeric,
+    op: Callable,
+    reduction_buffer: cute.Tensor,
+    mbar_ptr: Optional[cute.Pointer],
+    init_val: cute.Numeric = 0.0,
+) -> cute.Numeric:
+    """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
     if cutlass.const_expr(mbar_ptr is None):
         return block_reduce(val, op, reduction_buffer, init_val=init_val)
     else:
@@ -153,15 +182,14 @@ def row_reduce(
     init_val: cute.Numeric = 0.0,
     hook_fn: Optional[Callable] = None,
 ) -> cute.Numeric:
-    """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))
-    """
+    """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
     if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
         val = x.reduce(op, init_val=init_val, reduction_profile=0)
     else:
         val = x
     warp_op = {
         cute.ReductionOp.ADD: operator.add,
-        cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == cute.Float32) else max,
+        cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
         cute.ReductionOp.MIN: min,
         cute.ReductionOp.MUL: operator.mul,
     }[op]
@@ -174,7 +202,9 @@ def row_reduce(
         hook_fn()
     if cutlass.const_expr(reduction_buffer is not None):
         warps_per_row, cluster_n = reduction_buffer.shape[1]
-        assert cluster_n == 1 or mbar_ptr is not None, "mbar_ptr must be provided for cluster reduction"
+        assert (
+            cluster_n == 1 or mbar_ptr is not None
+        ), "mbar_ptr must be provided for cluster reduction"
         if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
             val = block_or_cluster_reduce(
                 val, warp_op, reduction_buffer, mbar_ptr, init_val=init_val
@@ -182,19 +212,107 @@ def row_reduce(
     return val
 
 
+@cute.jit
+def online_softmax_reduce(
+    x: cute.TensorSSA,
+    threads_per_row: cutlass.Constexpr[int],
+    reduction_buffer: Optional[cute.Tensor] = None,
+    mbar_ptr: Optional[cute.Pointer] = None,
+    hook_fn: Optional[Callable] = None,
+    return_exp_x: bool = False,
+) -> [Float32, Float32, Optional[cute.TensorSSA]]:
+    assert x.dtype == Float32, "x must be of type Float32"
+    """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2)"""
+    max_x = warp_reduce(
+        x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
+        cute.arch.fmax,
+        width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
+    )
+    log2_e = math.log2(math.e)
+    exp_x = exp2f(x * log2_e - (max_x * log2_e))
+    # exp_x = exp2f((x - max_x) * log2_e)
+    sum_exp_x = warp_reduce(
+        exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
+        operator.add,
+        width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
+    )
+    if cutlass.const_expr(hook_fn is not None):
+        hook_fn()
+    if cutlass.const_expr(reduction_buffer is not None):
+        warps_per_row, cluster_n = reduction_buffer.shape[1]
+        assert (
+            cluster_n == 1 or mbar_ptr is not None
+        ), "mbar_ptr must be provided for cluster reduction"
+        if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
+            assert (
+                reduction_buffer.element_type == cutlass.Int64
+            ), "reduction_buffer must be of type cute.Int64"
+            lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
+            row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+            if cutlass.const_expr(mbar_ptr is None):
+                if lane_idx == 0:
+                    reduction_buffer[row_idx, col_idx] = f32x2_to_i64(max_x, sum_exp_x)
+                cute.arch.barrier()
+                max_x_single_warp = -Float32.inf
+                sum_exp_x = 0.0
+                if lane_idx < warps_per_row:
+                    max_x_single_warp, sum_exp_x = i64_to_f32x2(reduction_buffer[row_idx, lane_idx])
+                max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
+                sum_exp_x *= exp2f((max_x_single_warp - max_x_final) * log2_e)
+                sum_exp_x = warp_reduce(sum_exp_x, operator.add)
+                if cutlass.const_expr(return_exp_x):
+                    exp_x *= exp2f((max_x - max_x_final) * log2_e)
+                max_x = max_x_final
+            else:
+                cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
+                if lane_idx < cluster_n:
+                    store_shared_remote(
+                        f32x2_to_i64(max_x, sum_exp_x),
+                        elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
+                        mbar_ptr,
+                        peer_cta_rank_in_cluster=lane_idx,
+                    )
+                cute.arch.mbarrier_wait(mbar_ptr, phase=0)
+                num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
+                max_x_single_warp = cute.make_fragment(num_iter, Float32)
+                max_x_single_warp.fill(-Float32.inf)
+                sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
+                sum_exp_x_single_warp.fill(0.0)
+                for i in cutlass.range_constexpr(num_iter):
+                    idx = lane_idx + i * cute.arch.WARP_SIZE
+                    if idx < cute.size(reduction_buffer, mode=[1]):
+                        max_x_single_warp[i], sum_exp_x_single_warp[i] = i64_to_f32x2(
+                            reduction_buffer[row_idx, idx]
+                        )
+                max_x_final = max_x_single_warp.load().reduce(
+                    cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0
+                )
+                max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
+                sum_exp_x = 0.0
+                for i in cutlass.range_constexpr(num_iter):
+                    sum_exp_x += sum_exp_x_single_warp[i] * exp2f(
+                        (max_x_single_warp[i] - max_x_final) * log2_e
+                    )
+                sum_exp_x = warp_reduce(sum_exp_x, operator.add)
+                if cutlass.const_expr(return_exp_x):
+                    exp_x *= exp2f((max_x - max_x_final) * log2_e)
+                max_x = max_x_final
+    return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
+
 
-def exp2f(x: cute.TensorSSA | cutlass.Float32) -> cute.TensorSSA | cutlass.Float32:
+@cute.jit
+def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
     """exp2f calculation for both vector and scalar.
 
     :param x: input value
-    :type x: cute.TensorSSA or cutlass.Float32
+    :type x: cute.TensorSSA or Float32
     :return: exp2 value
-    :rtype: cute.TensorSSA or cutlass.Float32
+    :rtype: cute.TensorSSA or Float32
     """
-    if isinstance(x, cute.TensorSSA):
-        res = cute.make_fragment(x.shape, cutlass.Float32)
+    if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
+        res = cute.make_fragment(x.shape, Float32)
         res.store(x)
-        for i in range(cute.size(x.shape)):
+        for i in cutlass.range_constexpr(cute.size(x.shape)):
             res[i] = cute.arch.exp2(res[i])
         return res.load()
     else:
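
The new `online_softmax_reduce` combines partial softmax states (max, sum-of-exponentials) across warps and CTAs: each partial sum is rescaled by exp(old_max - new_max) before being added, which is the standard online-softmax merge. The kernel works in base 2 (`exp2f` with a log2(e) factor); the identity is the same in natural base, and easiest to check there. A host-side NumPy verification of the merge rule:

    import numpy as np

    rng = np.random.default_rng(0)
    a, b = rng.normal(size=128), rng.normal(size=128)
    m_a, s_a = a.max(), np.exp(a - a.max()).sum()   # partial state for chunk a
    m_b, s_b = b.max(), np.exp(b - b.max()).sum()   # partial state for chunk b
    m = max(m_a, m_b)
    s = s_a * np.exp(m_a - m) + s_b * np.exp(m_b - m)  # rescale-and-add merge
    full = np.concatenate([a, b])
    assert np.isclose(m, full.max())
    assert np.isclose(s, np.exp(full - full.max()).sum())
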
@@ -202,11 +320,11 @@ def exp2f(x: cute.TensorSSA | cutlass.Float32) -> cute.TensorSSA | cutlass.Float
 
 
 @dsl_user_op
-def log2f(a: float | cutlass.Float32, *, loc=None, ip=None) -> cutlass.Float32:
-    return cutlass.Float32(
+def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
+    return Float32(
         llvm.inline_asm(
             T.f32(),
-            [cutlass.Float32(a).ir_value(loc=loc, ip=ip)],
+            [Float32(a).ir_value(loc=loc, ip=ip)],
             "lg2.approx.ftz.f32 $0, $1;",
             "=f,f",
             has_side_effects=False,
@@ -217,11 +335,11 @@ def log2f(a: float | cutlass.Float32, *, loc=None, ip=None) -> cutlass.Float32:
 
 
 @dsl_user_op
-def rsqrt(a: float | cute.Float32, *, loc=None, ip=None) -> cute.Float32:
-    return cute.Float32(
+def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
+    return Float32(
         llvm.inline_asm(
             T.f32(),
-            [cute.Float32(a).ir_value(loc=loc, ip=ip)],
+            [Float32(a).ir_value(loc=loc, ip=ip)],
             "rsqrt.approx.ftz.f32 $0, $1;",
             "=f,f",
             has_side_effects=False,
@@ -231,6 +349,7 @@ def rsqrt(a: float | cute.Float32, *, loc=None, ip=None) -> cute.Float32:
     )
 
 
+@cute.jit
 def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
     # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
     tApA = cute.make_fragment(
@@ -240,7 +359,49 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
         ),
         cutlass.Boolean,
     )
-    for rest_v in range(tApA.shape[0]):
-        for rest_k in range(tApA.shape[2]):
+    for rest_v in cutlass.range_constexpr(tApA.shape[0]):
+        for rest_k in cutlass.range_constexpr(tApA.shape[2]):
             tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
     return tApA
+
+
+@cute.jit
+def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) -> None:
+    """Fill out-of-bounds values in shared memory tensor.
+
+    Args:
+        tXsX: Shared memory tensor to fill
+        tXpX: Predicate tensor indicating valid elements
+        fill_value: Value to fill OOB locations with
+    """
+    tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), 0, 0])
+    tXrX_fill.fill(fill_value)
+    for rest_v in cutlass.range_constexpr(tXpX.shape[0]):
+        for rest_k in cutlass.range_constexpr(tXpX.shape[2]):
+            if not tXpX[rest_v, 0, rest_k]:
+                cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
+
+
+@dsl_user_op
+def f32x2_to_i64(a: Float32, b: Float32, *, loc=None, ip=None) -> cutlass.Int64:
+    vec_f32x2 = vector.from_elements(
+        T.vector(2, T.f32()), (a.ir_value(), b.ir_value()), loc=loc, ip=ip
+    )
+    vec_i64x1 = vector.bitcast(T.vector(1, T.i64()), vec_f32x2)
+    res = cutlass.Int64(
+        vector.extract(vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
+    )
+    return res
+
+
+@dsl_user_op
+def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
+    vec_i64x1 = vector.from_elements(T.vector(1, T.i64()), (c.ir_value(),), loc=loc, ip=ip)
+    vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1)
+    res0 = Float32(
+        vector.extract(vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
+    )
+    res1 = Float32(
+        vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)
+    )
+    return res0, res1
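
`f32x2_to_i64` and `i64_to_f32x2` exist so a (max, sum) pair can travel through the `Int64` reduction buffer and the 64-bit `store_shared_remote` path as one value. They are pure bitcasts, equivalent to the following plain-Python round trip (little-endian byte order assumed, as on NVIDIA GPUs):

    import struct

    def pack_f32x2(a: float, b: float) -> int:
        # Bit-pack two float32 values into one 64-bit integer.
        return int.from_bytes(struct.pack("<2f", a, b), "little")

    def unpack_f32x2(c: int) -> tuple[float, float]:
        return struct.unpack("<2f", c.to_bytes(8, "little"))

    assert unpack_f32x2(pack_f32x2(1.5, -2.25)) == (1.5, -2.25)
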
quack_kernels-0.1.4.dist-info/METADATA ADDED
@@ -0,0 +1,11 @@
+Metadata-Version: 2.4
+Name: quack-kernels
+Version: 0.1.4
+Requires-Python: >=3.9
+License-File: LICENSE
+Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
+Requires-Dist: torch
+Provides-Extra: dev
+Requires-Dist: pre-commit; extra == "dev"
+Requires-Dist: ruff; extra == "dev"
+Dynamic: license-file
quack_kernels-0.1.4.dist-info/RECORD ADDED
@@ -0,0 +1,11 @@
+quack/__init__.py,sha256=cFLxO6nA_faFqHf4N-Fy7G0j8ykuYPB1uOt9uoJ2dkQ,203
+quack/cross_entropy.py,sha256=HnF2OErEzb10SWxY6HoYE42lnvlw2DsWCks7mylPwnI,9511
+quack/reduction_base.py,sha256=Rsj9ZeSHcKAXGn1p7mY1vrrBqxevi4feLjY0JJhKnmY,3663
+quack/rmsnorm.py,sha256=TkOZsXJwcsoZMLnmEWQ-pEF0r-iiZhGrCNLSFCXfv6s,10676
+quack/softmax.py,sha256=VfhlC2huRuv7olFSVFgS8LF1yF8TFV64yjjjQxYX9yk,16364
+quack/utils.py,sha256=zVc9U-5No19trE585KqDdXx9chAruXPRIPMZdO7mkRg,15603
+quack_kernels-0.1.4.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+quack_kernels-0.1.4.dist-info/METADATA,sha256=xl62C5WFgiUbnOICAzjldsljJ9j1Fb_JxZVksHLCI8I,289
+quack_kernels-0.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+quack_kernels-0.1.4.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+quack_kernels-0.1.4.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- Metadata-Version: 2.4
2
- Name: quack-kernels
3
- Version: 0.1.2
4
- Requires-Python: >=3.9
5
- License-File: LICENSE
6
- Requires-Dist: nvidia-cutlass-dsl==4.0.0
7
- Requires-Dist: torch
8
- Dynamic: license-file
quack_kernels-0.1.2.dist-info/RECORD DELETED
@@ -1,10 +0,0 @@
-quack/__init__.py,sha256=Nf01m1CGrOjSkqGJom6P65hSLkckljRMhlkSoqqlO9k,137
-quack/cross_entropy.py,sha256=gdo8sR9KT5TsrShbgAmy-bwRZLu0gTs_ykXBF2RMbFI,8900
-quack/rmsnorm.py,sha256=JhwJSAPDDpB_hV90xU9ymiLU-zu4WScrSHc5JX2JarY,10470
-quack/softmax.py,sha256=C8e8ZNaF5ePJ1NlrWZN1goCcvsx1C60FWlRyuFCcYoM,7737
-quack/utils.py,sha256=PRdu-P7azA_PeHUNdtoy1zyxZwg_QyVrSiVwE1iXaWo,8961
-quack_kernels-0.1.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-quack_kernels-0.1.2.dist-info/METADATA,sha256=3WjugLu1IhLlgsg2qUcLBZq1HI4-BIyyJIuQc5Hk-rU,186
-quack_kernels-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-quack_kernels-0.1.2.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
-quack_kernels-0.1.2.dist-info/RECORD,,