quack-kernels 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/utils.py CHANGED
@@ -1,13 +1,15 @@
1
1
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
2
2
 
3
+ import operator
3
4
  import math
4
- from typing import Type, Callable, Optional
5
+ from typing import Callable, Optional, Tuple
5
6
 
6
7
  import cutlass
7
8
  import cutlass.cute as cute
8
9
 
10
+ from cutlass import Float32
9
11
  from cutlass.cutlass_dsl import T, dsl_user_op
10
- from cutlass._mlir.dialects import nvvm, llvm
12
+ from cutlass._mlir.dialects import llvm, vector
11
13
  from cutlass.cute.runtime import from_dlpack
12
14
 
13
15
 
@@ -38,7 +40,7 @@ def min_constexpr(
38
40
  def warp_reduce(
39
41
  val: cute.TensorSSA | cute.Numeric,
40
42
  op: Callable,
41
- width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE
43
+ width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
42
44
  ) -> cute.TensorSSA | cute.Numeric:
43
45
  if isinstance(val, cute.TensorSSA):
44
46
  res = cute.make_fragment(val.shape, val.dtype)
@@ -53,11 +55,12 @@ def warp_reduce(
53
55
 
54
56
 
55
57
  @cute.jit
56
- def block_reduce(val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0) -> cute.Numeric:
57
- """reduction_buffer has shape (num_warps / warp_per_row, warps_per_row)
58
- """
58
+ def block_reduce(
59
+ val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0
60
+ ) -> cute.Numeric:
61
+ """reduction_buffer has shape (num_warps / warp_per_row, warps_per_row)"""
59
62
  lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
60
- warps_per_row = reduction_buffer.shape[1]
63
+ warps_per_row = cute.size(reduction_buffer.shape[1])
61
64
  row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
62
65
  if lane_idx == 0:
63
66
  reduction_buffer[row_idx, col_idx] = val
@@ -74,9 +77,10 @@ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cut
74
77
 
75
78
 
76
79
  @dsl_user_op
77
- def set_block_rank(smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32, *, loc=None, ip=None) -> cutlass.Int32:
78
- """Map the given smem pointer to the address at another CTA rank in the cluster.
79
- """
80
+ def set_block_rank(
81
+ smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32, *, loc=None, ip=None
82
+ ) -> cutlass.Int32:
83
+ """Map the given smem pointer to the address at another CTA rank in the cluster."""
80
84
  smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
81
85
  return cutlass.Int32(
82
86
  llvm.inline_asm(
@@ -93,16 +97,29 @@ def set_block_rank(smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32,
93
97
 
94
98
  @dsl_user_op
95
99
  def store_shared_remote(
96
- val: float | cute.Float32, smem_ptr: cute.Pointer, mbar_ptr: cute.Pointer,
97
- peer_cta_rank_in_cluster: cute.typing.Int, *, loc=None, ip=None
100
+ val: float | Float32 | cutlass.Int64,
101
+ smem_ptr: cute.Pointer,
102
+ mbar_ptr: cute.Pointer,
103
+ peer_cta_rank_in_cluster: cute.typing.Int,
104
+ *,
105
+ loc=None,
106
+ ip=None,
98
107
  ) -> None:
99
- remote_smem_ptr_i32 = set_block_rank(smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip).ir_value()
100
- remote_mbar_ptr_i32 = set_block_rank(mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip).ir_value()
108
+ remote_smem_ptr_i32 = set_block_rank(
109
+ smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
110
+ ).ir_value()
111
+ remote_mbar_ptr_i32 = set_block_rank(
112
+ mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
113
+ ).ir_value()
114
+ if isinstance(val, float):
115
+ val = Float32(val)
116
+ assert isinstance(val, (Float32, cutlass.Int64)), "val must be Float32 or Int64"
117
+ suffix = "f32" if isinstance(val, Float32) else "s64"
101
118
  llvm.inline_asm(
102
119
  None,
103
- [remote_smem_ptr_i32, cute.Float32(val).ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
104
- "st.async.shared::cluster.mbarrier::complete_tx::bytes.f32 [$0], $1, [$2];",
105
- "r,f,r",
120
+ [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
121
+ f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];",
122
+ f"r,{'f' if isinstance(val, Float32) else 'l'},r",
106
123
  has_side_effects=True,
107
124
  is_align_stack=False,
108
125
  asm_dialect=llvm.AsmDialect.AD_ATT,
@@ -110,17 +127,24 @@ def store_shared_remote(
110
127
 
111
128
 
112
129
  @cute.jit
113
- def cluster_reduce(val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, mbar_ptr: cute.Pointer, init_val: cute.Numeric = 0.0) -> cute.Numeric:
114
- """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))
115
- """
130
+ def cluster_reduce(
131
+ val: cute.Numeric,
132
+ op: Callable,
133
+ reduction_buffer: cute.Tensor,
134
+ mbar_ptr: cute.Pointer,
135
+ init_val: cute.Numeric = 0.0,
136
+ ) -> cute.Numeric:
137
+ """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
116
138
  cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
117
139
  lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
118
140
  warps_per_row, cluster_n = reduction_buffer.shape[1]
119
141
  row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
120
142
  if lane_idx < cluster_n:
121
143
  store_shared_remote(
122
- val, elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
123
- mbar_ptr, peer_cta_rank_in_cluster=lane_idx
144
+ val,
145
+ elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
146
+ mbar_ptr,
147
+ peer_cta_rank_in_cluster=lane_idx,
124
148
  )
125
149
  cute.arch.mbarrier_wait(mbar_ptr, phase=0)
126
150
  block_reduce_val = init_val
@@ -133,25 +157,158 @@ def cluster_reduce(val: cute.Numeric, op: Callable, reduction_buffer: cute.Tenso
133
157
 
134
158
 
135
159
  @cute.jit
136
- def block_or_cluster_reduce(val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, mbar_ptr: Optional[cute.Pointer], init_val: cute.Numeric = 0.0) -> cute.Numeric:
137
- """Perform either block or cluster reduction based on whether mbar_ptr is provided.
138
- """
160
+ def block_or_cluster_reduce(
161
+ val: cute.Numeric,
162
+ op: Callable,
163
+ reduction_buffer: cute.Tensor,
164
+ mbar_ptr: Optional[cute.Pointer],
165
+ init_val: cute.Numeric = 0.0,
166
+ ) -> cute.Numeric:
167
+ """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
139
168
  if cutlass.const_expr(mbar_ptr is None):
140
169
  return block_reduce(val, op, reduction_buffer, init_val=init_val)
141
170
  else:
142
171
  return cluster_reduce(val, op, reduction_buffer, mbar_ptr, init_val=init_val)
143
172
 
144
173
 
145
- def exp2f(x: cute.TensorSSA | cutlass.Float32) -> cute.TensorSSA | cutlass.Float32:
174
+ @cute.jit
175
+ def row_reduce(
176
+ x: cute.TensorSSA | cute.Numeric,
177
+ op: cute.ReductionOp,
178
+ threads_per_row: cutlass.Constexpr[int],
179
+ reduction_buffer: Optional[cute.Tensor] = None,
180
+ mbar_ptr: Optional[cute.Pointer] = None,
181
+ init_val: cute.Numeric = 0.0,
182
+ hook_fn: Optional[Callable] = None,
183
+ ) -> cute.Numeric:
184
+ """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
185
+ if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
186
+ val = x.reduce(op, init_val=init_val, reduction_profile=0)
187
+ else:
188
+ val = x
189
+ warp_op = {
190
+ cute.ReductionOp.ADD: operator.add,
191
+ cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
192
+ cute.ReductionOp.MIN: min,
193
+ cute.ReductionOp.MUL: operator.mul,
194
+ }[op]
195
+ val = warp_reduce(
196
+ val,
197
+ warp_op,
198
+ width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
199
+ )
200
+ if cutlass.const_expr(hook_fn is not None):
201
+ hook_fn()
202
+ if cutlass.const_expr(reduction_buffer is not None):
203
+ warps_per_row, cluster_n = reduction_buffer.shape[1]
204
+ assert (
205
+ cluster_n == 1 or mbar_ptr is not None
206
+ ), "mbar_ptr must be provided for cluster reduction"
207
+ if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
208
+ val = block_or_cluster_reduce(
209
+ val, warp_op, reduction_buffer, mbar_ptr, init_val=init_val
210
+ )
211
+ return val
212
+
213
+
214
+ @cute.jit
215
+ def online_softmax_reduce(
216
+ x: cute.TensorSSA,
217
+ threads_per_row: cutlass.Constexpr[int],
218
+ reduction_buffer: Optional[cute.Tensor] = None,
219
+ mbar_ptr: Optional[cute.Pointer] = None,
220
+ hook_fn: Optional[Callable] = None,
221
+ return_exp_x: bool = False,
222
+ ) -> [Float32, Float32, Optional[cute.TensorSSA]]:
223
+ assert x.dtype == Float32, "x must be of type Float32"
224
+ """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2)"""
225
+ max_x = warp_reduce(
226
+ x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
227
+ cute.arch.fmax,
228
+ width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
229
+ )
230
+ log2_e = math.log2(math.e)
231
+ exp_x = exp2f(x * log2_e - (max_x * log2_e))
232
+ # exp_x = exp2f((x - max_x) * log2_e)
233
+ sum_exp_x = warp_reduce(
234
+ exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
235
+ operator.add,
236
+ width=min_constexpr(threads_per_row, cute.arch.WARP_SIZE),
237
+ )
238
+ if cutlass.const_expr(hook_fn is not None):
239
+ hook_fn()
240
+ if cutlass.const_expr(reduction_buffer is not None):
241
+ warps_per_row, cluster_n = reduction_buffer.shape[1]
242
+ assert (
243
+ cluster_n == 1 or mbar_ptr is not None
244
+ ), "mbar_ptr must be provided for cluster reduction"
245
+ if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
246
+ assert (
247
+ reduction_buffer.element_type == cutlass.Int64
248
+ ), "reduction_buffer must be of type cute.Int64"
249
+ lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
250
+ row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
251
+ if cutlass.const_expr(mbar_ptr is None):
252
+ if lane_idx == 0:
253
+ reduction_buffer[row_idx, col_idx] = f32x2_to_i64(max_x, sum_exp_x)
254
+ cute.arch.barrier()
255
+ max_x_single_warp = -Float32.inf
256
+ sum_exp_x = 0.0
257
+ if lane_idx < warps_per_row:
258
+ max_x_single_warp, sum_exp_x = i64_to_f32x2(reduction_buffer[row_idx, lane_idx])
259
+ max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
260
+ sum_exp_x *= exp2f((max_x_single_warp - max_x_final) * log2_e)
261
+ sum_exp_x = warp_reduce(sum_exp_x, operator.add)
262
+ if cutlass.const_expr(return_exp_x):
263
+ exp_x *= exp2f((max_x - max_x_final) * log2_e)
264
+ max_x = max_x_final
265
+ else:
266
+ cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
267
+ if lane_idx < cluster_n:
268
+ store_shared_remote(
269
+ f32x2_to_i64(max_x, sum_exp_x),
270
+ elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
271
+ mbar_ptr,
272
+ peer_cta_rank_in_cluster=lane_idx,
273
+ )
274
+ cute.arch.mbarrier_wait(mbar_ptr, phase=0)
275
+ num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
276
+ max_x_single_warp = cute.make_fragment(num_iter, Float32)
277
+ max_x_single_warp.fill(-Float32.inf)
278
+ sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
279
+ sum_exp_x_single_warp.fill(0.0)
280
+ for i in cutlass.range_constexpr(num_iter):
281
+ idx = lane_idx + i * cute.arch.WARP_SIZE
282
+ if idx < cute.size(reduction_buffer, mode=[1]):
283
+ max_x_single_warp[i], sum_exp_x_single_warp[i] = i64_to_f32x2(
284
+ reduction_buffer[row_idx, idx]
285
+ )
286
+ max_x_final = max_x_single_warp.load().reduce(
287
+ cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0
288
+ )
289
+ max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
290
+ sum_exp_x = 0.0
291
+ for i in cutlass.range_constexpr(num_iter):
292
+ sum_exp_x += sum_exp_x_single_warp[i] * exp2f(
293
+ (max_x_single_warp[i] - max_x_final) * log2_e
294
+ )
295
+ sum_exp_x = warp_reduce(sum_exp_x, operator.add)
296
+ if cutlass.const_expr(return_exp_x):
297
+ exp_x *= exp2f((max_x - max_x_final) * log2_e)
298
+ max_x = max_x_final
299
+ return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
300
+
301
+
302
+ def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
146
303
  """exp2f calculation for both vector and scalar.
147
304
 
148
305
  :param x: input value
149
- :type x: cute.TensorSSA or cutlass.Float32
306
+ :type x: cute.TensorSSA or Float32
150
307
  :return: exp2 value
151
- :rtype: cute.TensorSSA or cutlass.Float32
308
+ :rtype: cute.TensorSSA or Float32
152
309
  """
153
310
  if isinstance(x, cute.TensorSSA):
154
- res = cute.make_fragment(x.shape, cutlass.Float32)
311
+ res = cute.make_fragment(x.shape, Float32)
155
312
  res.store(x)
156
313
  for i in range(cute.size(x.shape)):
157
314
  res[i] = cute.arch.exp2(res[i])
@@ -161,11 +318,11 @@ def exp2f(x: cute.TensorSSA | cutlass.Float32) -> cute.TensorSSA | cutlass.Float
161
318
 
162
319
 
163
320
  @dsl_user_op
164
- def log2f(a: float | cutlass.Float32, *, loc=None, ip=None) -> cutlass.Float32:
165
- return cutlass.Float32(
321
+ def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
322
+ return Float32(
166
323
  llvm.inline_asm(
167
324
  T.f32(),
168
- [cutlass.Float32(a).ir_value(loc=loc, ip=ip)],
325
+ [Float32(a).ir_value(loc=loc, ip=ip)],
169
326
  "lg2.approx.ftz.f32 $0, $1;",
170
327
  "=f,f",
171
328
  has_side_effects=False,
@@ -176,11 +333,11 @@ def log2f(a: float | cutlass.Float32, *, loc=None, ip=None) -> cutlass.Float32:
176
333
 
177
334
 
178
335
  @dsl_user_op
179
- def rsqrt(a: float | cute.Float32, *, loc=None, ip=None) -> cute.Float32:
180
- return cute.Float32(
336
+ def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
337
+ return Float32(
181
338
  llvm.inline_asm(
182
339
  T.f32(),
183
- [cute.Float32(a).ir_value(loc=loc, ip=ip)],
340
+ [Float32(a).ir_value(loc=loc, ip=ip)],
184
341
  "rsqrt.approx.ftz.f32 $0, $1;",
185
342
  "=f,f",
186
343
  has_side_effects=False,
@@ -188,3 +345,60 @@ def rsqrt(a: float | cute.Float32, *, loc=None, ip=None) -> cute.Float32:
188
345
  asm_dialect=llvm.AsmDialect.AD_ATT,
189
346
  )
190
347
  )
348
+
349
+
350
+ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
351
+ # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
352
+ tApA = cute.make_fragment(
353
+ cute.make_layout(
354
+ (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
355
+ stride=(cute.size(tAcA, mode=[2]), 0, 1),
356
+ ),
357
+ cutlass.Boolean,
358
+ )
359
+ for rest_v in range(tApA.shape[0]):
360
+ for rest_k in range(tApA.shape[2]):
361
+ tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
362
+ return tApA
363
+
364
+
365
+ @cute.jit
366
+ def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) -> None:
367
+ """Fill out-of-bounds values in shared memory tensor.
368
+
369
+ Args:
370
+ tXsX: Shared memory tensor to fill
371
+ tXpX: Predicate tensor indicating valid elements
372
+ fill_value: Value to fill OOB locations with
373
+ """
374
+ tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), 0, 0])
375
+ tXrX_fill.fill(fill_value)
376
+ for rest_v in range(tXpX.shape[0]):
377
+ for rest_k in range(tXpX.shape[2]):
378
+ if not tXpX[rest_v, 0, rest_k]:
379
+ cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
380
+
381
+
382
+ @dsl_user_op
383
+ def f32x2_to_i64(a: Float32, b: Float32, *, loc=None, ip=None) -> cutlass.Int64:
384
+ vec_f32x2 = vector.from_elements(
385
+ T.vector(2, T.f32()), (a.ir_value(), b.ir_value()), loc=loc, ip=ip
386
+ )
387
+ vec_i64x1 = vector.bitcast(T.vector(1, T.i64()), vec_f32x2)
388
+ res = cutlass.Int64(
389
+ vector.extract(vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
390
+ )
391
+ return res
392
+
393
+
394
+ @dsl_user_op
395
+ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
396
+ vec_i64x1 = vector.from_elements(T.vector(1, T.i64()), (c.ir_value(),), loc=loc, ip=ip)
397
+ vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1)
398
+ res0 = Float32(
399
+ vector.extract(vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
400
+ )
401
+ res1 = Float32(
402
+ vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)
403
+ )
404
+ return res0, res1
@@ -1,8 +1,11 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: quack-kernels
3
- Version: 0.1.1
3
+ Version: 0.1.3
4
4
  Requires-Python: >=3.9
5
5
  License-File: LICENSE
6
6
  Requires-Dist: nvidia-cutlass-dsl==4.0.0
7
7
  Requires-Dist: torch
8
+ Provides-Extra: dev
9
+ Requires-Dist: pre-commit; extra == "dev"
10
+ Requires-Dist: ruff; extra == "dev"
8
11
  Dynamic: license-file
@@ -0,0 +1,11 @@
1
+ quack/__init__.py,sha256=aUR7drzgaqmbzw9H_eoFselMUVQVF3BHc9VOzZg5d-Q,203
2
+ quack/cross_entropy.py,sha256=_Xlyifd_YS8LaYxYlZEsuBfsi8zTH4At3i9DDggGCf8,9319
3
+ quack/reduction_base.py,sha256=nrRsXwTpLVQkPp2Gr_FgHRPnifqkMHRodve5ciHzx58,3667
4
+ quack/rmsnorm.py,sha256=YqGTTKHHXYzw3xnnjBRfaN9TDlhG8D_fSI9CHKAU40A,10548
5
+ quack/softmax.py,sha256=mWaUfaY6PBtO1ioYxXxS-yodQmcBNGasWVMUg9G066Y,15938
6
+ quack/utils.py,sha256=1-HMcFTEvGdAtqC3ucQGZ3DLa_PoJQsqwYlKd9bcXO8,15347
7
+ quack_kernels-0.1.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
8
+ quack_kernels-0.1.3.dist-info/METADATA,sha256=DDuEKHLjFx9dFTQV5YtXsnKVFZVoueO7NwhcwOtpw6g,284
9
+ quack_kernels-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
10
+ quack_kernels-0.1.3.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
11
+ quack_kernels-0.1.3.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- quack/__init__.py,sha256=y3Oa4OVPqaGU_P1miI435DzfpMgIwKVmU8-Eogv58jQ,137
2
- quack/cross_entropy.py,sha256=V0kG8DCNh2735sPIDwe68NB50rAqDF3XQApnGyo-sKg,9220
3
- quack/rmsnorm.py,sha256=RNqcT-q4uvMbF6ejpzuqQH8l8VVuTRlnueXf28V47sc,11954
4
- quack/softmax.py,sha256=QABgOESH5JjDm3yuUkyZZKXXpzn7CTuMSs0NEBnFD80,8536
5
- quack/utils.py,sha256=ofV7QLDuq80h3nEA3TwZW-ti8CnYwMgnz1dpxpvhHpk,6859
6
- quack_kernels-0.1.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
7
- quack_kernels-0.1.1.dist-info/METADATA,sha256=XG3zS0_q48TzkoR7CemzaJGVYHS731yVOrzH49_uRK8,186
8
- quack_kernels-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
9
- quack_kernels-0.1.1.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
10
- quack_kernels-0.1.1.dist-info/RECORD,,