quack-kernels 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
quack/utils.py CHANGED
@@ -1,8 +1,7 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

- import operator
  import math
- from typing import Callable, Optional, Tuple, Type, Union
+ from typing import Optional, Tuple, Type, Union

  import cutlass
  import cutlass.cute as cute
@@ -23,46 +22,20 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te
  )


- @cute.jit
- def warp_reduce(
-     val: cute.TensorSSA | cute.Numeric,
-     op: Callable,
-     width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
- ) -> cute.TensorSSA | cute.Numeric:
-     if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
-         res = cute.make_fragment(val.shape, val.dtype)
-         res.store(val)
-         for i in cutlass.range_constexpr(cute.size(val.shape)):
-             res[i] = warp_reduce(res[i], op, width)
-         return res.load()
-     else:
-         for i in cutlass.range_constexpr(int(math.log2(width))):
-             val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
-         return val
-
-
- @cute.jit
- def block_reduce(
-     val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0
- ) -> cute.Numeric:
-     """reduction_buffer has shape (num_warps / warps_per_row, warps_per_row)"""
-     lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
-     warps_per_row = cute.size(reduction_buffer.shape[1])
-     row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
-     if lane_idx == 0:
-         reduction_buffer[row_idx, col_idx] = val
-     cute.arch.barrier()
-     block_reduce_val = init_val
-     if lane_idx < warps_per_row:
-         block_reduce_val = reduction_buffer[row_idx, lane_idx]
-     return warp_reduce(block_reduce_val, op)
-
-
  @dsl_user_op
  def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
      return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)


+ @cute.jit
+ def load_scalar_or_pointer(x: Float32 | cute.Pointer) -> Float32:
+     if cutlass.const_expr(isinstance(x, cute.Pointer)):
+         return Float32(cute.make_tensor(x, cute.make_layout(1))[0])
+     else:
+         assert isinstance(x, Float32)
+         return x
+
+
  @dsl_user_op
  def set_block_rank(
      smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32, *, loc=None, ip=None
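A note on the removed warp_reduce above: it is a butterfly all-reduce, where each of the log2(width) rounds combines a lane's value with the lane whose index differs in exactly one bit (cute.arch.shuffle_sync_bfly is an XOR shuffle), so after the last round every lane holds the full reduction. A minimal pure-Python model of that access pattern (illustrative names, not part of the package):

import math
import operator

def warp_reduce_model(vals, op, width=32):
    # Simulate a butterfly (XOR-shuffle) all-reduce across `width` lanes.
    vals = list(vals)
    for i in range(int(math.log2(width))):
        offset = 1 << i
        # Each lane combines its value with lane (lane ^ offset), mirroring
        # op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i)).
        vals = [op(vals[lane], vals[lane ^ offset]) for lane in range(width)]
    return vals

lanes = list(range(32))
assert warp_reduce_model(lanes, operator.add) == [sum(lanes)] * 32  # every lane gets the sum
assert warp_reduce_model(lanes, max) == [31] * 32                   # every lane gets the max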
@@ -114,197 +87,6 @@ def store_shared_remote(
  )


- @cute.jit
- def cluster_reduce(
-     val: cute.Numeric,
-     op: Callable,
-     reduction_buffer: cute.Tensor,
-     mbar_ptr: cute.Pointer,
-     init_val: cute.Numeric = 0.0,
-     phase: Optional[cutlass.Int32] = None,
- ) -> cute.Numeric:
-     """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
-     cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
-     lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
-     rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
-     row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
-     if warp_idx == 0:
-         with cute.arch.elect_one():
-             num_warps = rows_per_block * warps_per_row
-             cute.arch.mbarrier_arrive_and_expect_tx(
-                 mbar_ptr,
-                 num_warps * cluster_n * reduction_buffer.element_type.width // 8,
-             )
-     if lane_idx < cluster_n:
-         store_shared_remote(
-             val,
-             elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
-             mbar_ptr,
-             peer_cta_rank_in_cluster=lane_idx,
-         )
-     cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
-     block_reduce_val = init_val
-     num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
-     for i in cutlass.range_constexpr(num_iter):
-         idx = lane_idx + i * cute.arch.WARP_SIZE
-         if idx < cute.size(reduction_buffer, mode=[1]):
-             block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])
-     return warp_reduce(block_reduce_val, op)
-
-
- @cute.jit
- def block_or_cluster_reduce(
-     val: cute.Numeric,
-     op: Callable,
-     reduction_buffer: cute.Tensor,
-     mbar_ptr: Optional[cute.Pointer],
-     phase: Optional[cutlass.Int32] = None,
-     init_val: cute.Numeric = 0.0,
- ) -> cute.Numeric:
-     """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
-     if cutlass.const_expr(mbar_ptr is None):
-         return block_reduce(val, op, reduction_buffer, init_val=init_val)
-     else:
-         return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)
-
-
- @cute.jit
- def row_reduce(
-     x: cute.TensorSSA | cute.Numeric,
-     op: cute.ReductionOp,
-     threads_per_row: cutlass.Constexpr[int],
-     reduction_buffer: Optional[cute.Tensor] = None,
-     mbar_ptr: Optional[cute.Pointer] = None,
-     phase: Optional[cutlass.Int32] = None,
-     init_val: cute.Numeric = 0.0,
-     hook_fn: Optional[Callable] = None,
- ) -> cute.Numeric:
-     """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
-     if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
-         val = x.reduce(op, init_val=init_val, reduction_profile=0)
-     else:
-         val = x
-     warp_op = {
-         cute.ReductionOp.ADD: operator.add,
-         cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
-         cute.ReductionOp.MIN: min,
-         cute.ReductionOp.MUL: operator.mul,
-     }[op]
-     val = warp_reduce(
-         val,
-         warp_op,
-         width=min(threads_per_row, cute.arch.WARP_SIZE),
-     )
-     if cutlass.const_expr(hook_fn is not None):
-         hook_fn()
-     if cutlass.const_expr(reduction_buffer is not None):
-         warps_per_row, cluster_n = reduction_buffer.shape[1]
-         assert cluster_n == 1 or mbar_ptr is not None, (
-             "mbar_ptr must be provided for cluster reduction"
-         )
-         if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-             val = block_or_cluster_reduce(
-                 val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
-             )
-     return val
-
-
- @cute.jit
- def online_softmax_reduce(
-     x: cute.TensorSSA,
-     threads_per_row: cutlass.Constexpr[int],
-     reduction_buffer: Optional[cute.Tensor] = None,
-     mbar_ptr: Optional[cute.Pointer] = None,
-     hook_fn: Optional[Callable] = None,
-     phase: Optional[cutlass.Int32] = None,
-     return_exp_x: bool = False,
- ) -> [Float32, Float32, Optional[cute.TensorSSA]]:
-     assert x.dtype == Float32, "x must be of type Float32"
-     """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2)"""
-     max_x = warp_reduce(
-         x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
-         cute.arch.fmax,
-         width=min(threads_per_row, cute.arch.WARP_SIZE),
-     )
-     log2_e = math.log2(math.e)
-     exp_x = exp2f(x * log2_e - (max_x * log2_e))
-     # exp_x = exp2f((x - max_x) * log2_e)
-     sum_exp_x = warp_reduce(
-         exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
-         operator.add,
-         width=min(threads_per_row, cute.arch.WARP_SIZE),
-     )
-     if cutlass.const_expr(hook_fn is not None):
-         hook_fn()
-     if cutlass.const_expr(reduction_buffer is not None):
-         rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
-         assert cluster_n == 1 or mbar_ptr is not None, (
-             "mbar_ptr must be provided for cluster reduction"
-         )
-         if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-             assert reduction_buffer.element_type == cutlass.Int64, (
-                 "reduction_buffer must be of type cute.Int64"
-             )
-             lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
-             row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
-             if cutlass.const_expr(mbar_ptr is None):
-                 if lane_idx == 0:
-                     reduction_buffer[row_idx, col_idx] = f32x2_to_i64(max_x, sum_exp_x)
-                 cute.arch.barrier()
-                 max_x_single_warp = -Float32.inf
-                 sum_exp_x = 0.0
-                 if lane_idx < warps_per_row:
-                     max_x_single_warp, sum_exp_x = i64_to_f32x2(reduction_buffer[row_idx, lane_idx])
-                 max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
-                 sum_exp_x *= exp2f((max_x_single_warp - max_x_final) * log2_e)
-                 sum_exp_x = warp_reduce(sum_exp_x, operator.add)
-                 if cutlass.const_expr(return_exp_x):
-                     exp_x *= exp2f((max_x - max_x_final) * log2_e)
-                 max_x = max_x_final
-             else:
-                 cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
-                 if warp_idx == 0:
-                     with cute.arch.elect_one():
-                         num_warps = rows_per_block * warps_per_row
-                         cute.arch.mbarrier_arrive_and_expect_tx(
-                             mbar_ptr,
-                             num_warps * cluster_n * reduction_buffer.element_type.width // 8,
-                         )
-                 if lane_idx < cluster_n:
-                     store_shared_remote(
-                         f32x2_to_i64(max_x, sum_exp_x),
-                         elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
-                         mbar_ptr,
-                         peer_cta_rank_in_cluster=lane_idx,
-                     )
-                 cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
-                 num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
-                 max_x_single_warp = cute.make_fragment(num_iter, Float32)
-                 max_x_single_warp.fill(-Float32.inf)
-                 sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
-                 sum_exp_x_single_warp.fill(0.0)
-                 for i in cutlass.range_constexpr(num_iter):
-                     idx = lane_idx + i * cute.arch.WARP_SIZE
-                     if idx < cute.size(reduction_buffer, mode=[1]):
-                         max_x_single_warp[i], sum_exp_x_single_warp[i] = i64_to_f32x2(
-                             reduction_buffer[row_idx, idx]
-                         )
-                 max_x_final = max_x_single_warp.load().reduce(
-                     cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0
-                 )
-                 max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
-                 sum_exp_x = 0.0
-                 for i in cutlass.range_constexpr(num_iter):
-                     sum_exp_x += sum_exp_x_single_warp[i] * exp2f(
-                         (max_x_single_warp[i] - max_x_final) * log2_e
-                     )
-                 sum_exp_x = warp_reduce(sum_exp_x, operator.add)
-                 if cutlass.const_expr(return_exp_x):
-                     exp_x *= exp2f((max_x - max_x_final) * log2_e)
-                 max_x = max_x_final
-     return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
-
-
  @dsl_user_op
  def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None) -> Float32:
      return Float32(
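The removed online_softmax_reduce merges per-warp (max, sum-of-exp) partials by rescaling each partial sum to a common maximum, using the identity sum_i exp(x_i - m_new) = exp(m_old - m_new) * sum_i exp(x_i - m_old). A quick plain-Python check of that rescaling step (illustrative only, not the package's API):

import math

def partial(xs):
    # Per-warp partial result: (local max, sum of exp(x - local max)).
    m = max(xs)
    return m, sum(math.exp(x - m) for x in xs)

def merge(a, b):
    # Combine two partials by rescaling each sum to the shared max;
    # this is the correction online_softmax_reduce applies via exp2f.
    (ma, sa), (mb, sb) = a, b
    m = max(ma, mb)
    return m, sa * math.exp(ma - m) + sb * math.exp(mb - m)

xs, ys = [0.1, 2.0, -1.5], [3.0, 0.5]
m, s = merge(partial(xs), partial(ys))
ref_m = max(xs + ys)
ref_s = sum(math.exp(v - ref_m) for v in xs + ys)
assert m == ref_m and abs(s - ref_s) < 1e-12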
@@ -318,84 +100,6 @@ def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=Non
  )


- @cute.jit
- def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
-     """exp2f calculation for both vector and scalar.
-     :param x: input value
-     :type x: cute.TensorSSA or Float32
-     :return: exp2 value
-     :rtype: cute.TensorSSA or Float32
-     """
-     if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
-         res = cute.make_fragment(x.shape, Float32)
-         res.store(x)
-         for i in cutlass.range(cute.size(x.shape), unroll_full=True):
-             res[i] = cute.arch.exp2(res[i])
-         return res.load()
-     else:
-         return cute.arch.exp2(x)
-
-
- @dsl_user_op
- def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
-     return Float32(
-         llvm.inline_asm(
-             T.f32(),
-             [Float32(a).ir_value(loc=loc, ip=ip)],
-             "lg2.approx.ftz.f32 $0, $1;",
-             "=f,f",
-             has_side_effects=False,
-             is_align_stack=False,
-             asm_dialect=llvm.AsmDialect.AD_ATT,
-         )
-     )
-
-
- @dsl_user_op
- def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
-     return Float32(
-         llvm.inline_asm(
-             T.f32(),
-             [Float32(a).ir_value(loc=loc, ip=ip)],
-             "sqrt.approx.ftz.f32 $0, $1;",
-             "=f,f",
-             has_side_effects=False,
-             is_align_stack=False,
-             asm_dialect=llvm.AsmDialect.AD_ATT,
-         )
-     )
-
-
- @dsl_user_op
- def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
-     return Float32(
-         llvm.inline_asm(
-             T.f32(),
-             [Float32(a).ir_value(loc=loc, ip=ip)],
-             "rsqrt.approx.ftz.f32 $0, $1;",
-             "=f,f",
-             has_side_effects=False,
-             is_align_stack=False,
-             asm_dialect=llvm.AsmDialect.AD_ATT,
-         )
-     )
-
-
- @dsl_user_op
- def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
-     return Float32(
-         llvm.inline_asm(
-             T.f32(),
-             [Float32(a).ir_value(loc=loc, ip=ip)],
-             "tanh.approx.f32 $0, $1;",
-             "=f,f",
-             has_side_effects=False,
-             is_align_stack=False,
-             asm_dialect=llvm.AsmDialect.AD_ATT,
-         )
-     )
-
-
  @dsl_user_op
  def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
      return Int32(
@@ -411,16 +115,6 @@ def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
  )


- @dsl_user_op
- def silu(a: float | Float32, *, loc=None, ip=None) -> Float32:
-     """
-     silu(a) = a * sigmoid(a) = a * (1 + tanh(a / 2)) / 2 = (0.5 * a) * tanh(0.5 * a) + (0.5 * a)
-     This compiles down to 3 SASS instructions: FMUL to get 0.5 * a, MUFU.TANH, and FFMA.
-     """
-     a_half = 0.5 * a
-     return a_half * tanh(a_half) + a_half
-
-
  @dsl_user_op
  def prmt(a: int | Int32, b: int | Int32, c: int | Int32, *, loc=None, ip=None) -> Int32:
      return Int32(
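The deleted silu rests on the identity silu(a) = a * sigmoid(a) = (0.5 * a) * tanh(0.5 * a) + 0.5 * a, which is what reduces it to FMUL, MUFU.TANH, and FFMA. A quick numeric check in plain Python (illustrative, not the package's API):

import math

def silu_ref(a):
    return a * (1.0 / (1.0 + math.exp(-a)))  # a * sigmoid(a)

def silu_tanh_form(a):
    a_half = 0.5 * a
    return a_half * math.tanh(a_half) + a_half  # the form the deleted kernel used

for a in (-3.0, -0.5, 0.0, 1.0, 4.0):
    assert abs(silu_ref(a) - silu_tanh_form(a)) < 1e-12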
@@ -498,7 +192,7 @@ def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Nu
          tXpX: Predicate tensor indicating valid elements
          fill_value: Value to fill OOB locations with
      """
-     tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), 0, 0])
+     tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), None, 0])
      tXrX_fill.fill(fill_value)
      for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
          for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
@@ -538,9 +232,9 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
  def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
      flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
      flat_stride = cute.flatten_to_tuple(tensor.stride)
-     assert len(flat_coord_i64) == len(flat_stride), (
-         "Coordinate and stride must have the same length"
-     )
+     assert len(flat_coord_i64) == len(
+         flat_stride
+     ), "Coordinate and stride must have the same length"
      offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
      assert isinstance(tensor.iterator, cute.Pointer)
      # HACK: we assume that applying the offset does not change the pointer alignment
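For reference, the offset domain_offset_i64 computes is just the coordinate dotted with the tensor's strides, carried out in 64-bit so large tensors cannot overflow 32-bit indexing; in plain Python terms (illustrative values):

coord = (2, 3)      # logical coordinate into a hypothetical 2-D tensor
stride = (128, 1)   # row-major strides
offset = sum(c * s for c, s in zip(coord, stride))  # same expression as in the diff
assert offset == 2 * 128 + 3  # flat element offset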
@@ -662,5 +356,3 @@ def atomic_inc_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None)
      return nvvm.atomicrmw(
          res=T.i32(), op=nvvm.AtomicOpKind.INC, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
      )
-
-
quack/varlen_utils.py ADDED
@@ -0,0 +1,22 @@
+ # Copyright (c) 2025, Tri Dao.
+
+ from typing import Optional
+ from dataclasses import dataclass
+
+ import cutlass.cute as cute
+
+ from quack.cute_dsl_utils import ArgumentsBase
+
+
+ # Grouping arguments together that should be passed to __call__
+ @dataclass
+ class VarlenArguments(ArgumentsBase):
+     mCuSeqlensM: Optional[cute.Tensor] = None
+     mCuSeqlensK: Optional[cute.Tensor] = None
+     mTensormaps: Optional[cute.Tensor] = None
+
+     def __post_init__(self):
+         if self.mCuSeqlensM is not None or self.mCuSeqlensK is not None:
+             assert (
+                 self.mTensormaps is not None
+             ), "mTensormaps must be provided if mCuSeqlensM or mCuSeqlensK is provided"
quack_kernels-0.1.11.dist-info/METADATA → quack_kernels-0.2.1.dist-info/METADATA CHANGED
@@ -1,9 +1,9 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.1.11
- Requires-Python: >=3.12
+ Version: 0.2.1
+ Requires-Python: >=3.10
  License-File: LICENSE
- Requires-Dist: nvidia-cutlass-dsl==4.1.0
+ Requires-Dist: nvidia-cutlass-dsl==4.2.0
  Requires-Dist: torch
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == "dev"
quack_kernels-0.2.1.dist-info/RECORD ADDED
@@ -0,0 +1,37 @@
+ quack/__init__.py,sha256=H1m0CnfPidSSmprZeTGJc8LVh7stdBPmPLEuZwgN_7M,364
+ quack/activation.py,sha256=SzQDUCB-kccqsy1aYUrHYJ2cGxKMXxxqpjJaJoqBYaE,10017
+ quack/autotuner.py,sha256=czO6JrYL0EJpOeJOYDSsVdrJaFuwfL3vTdG8QfL1F34,10792
+ quack/cross_entropy.py,sha256=TE8j21c-7E4cInKtFjcKsgKXNhKCRFkNfhCJpgpasj8,28409
+ quack/cute_dsl_utils.py,sha256=D2Pw7rzX9jY8u8wikIPvPvinmFLCDeZg95HPBLqGej4,4635
+ quack/dense_gemm_sm100.py,sha256=hKBNC34UxdctrTKVP68nvANZl4Dq2rnUjRcweESEq3g,109965
+ quack/dense_gemm_sm90.py,sha256=TjnjHnjhAwWH5YQWsFlADq07xSxtsprkw_p2Cy0yw7I,100407
+ quack/fast_math.py,sha256=E1XUqfUt0_n9BPZNggF-UDzZ6anso9bYUrwqafemWvQ,2297
+ quack/gemm_act_sm90.py,sha256=N5UAFWZvw1na22Vh5JSGgcdqZ2zI6kQMBVOLxYbCAUU,14332
+ quack/gemm_config.py,sha256=gbYjPFeyT5wAhVwFQroRHlHoMKEJqAWX9P8wWy04l8Q,2258
+ quack/gemm_dact_sm90.py,sha256=KCXgjOzdamSDexwrwf_pX2r-ippPRirbClrlU6BP7b8,4990
+ quack/gemm_interface.py,sha256=_JTpE7zQw6NUw-v65Wql_XUOZBfW0oSEgiMnharTJU4,20501
+ quack/gemm_wrapper_utils.py,sha256=aMMtu-Ojhtjay_5xJH4AjP-JRVks1AB8jmtNme_DIqU,5960
+ quack/layernorm.py,sha256=AOe95-YqhFPw96x8pJq7FfBe26ROX9ZTvH025lM1ILs,13579
+ quack/linear.py,sha256=SrhRiAFjC7ONIMVmiNu-kSPLHNUyaCXt59a1f_5nNXo,9383
+ quack/linear_cross_entropy.py,sha256=Zhy_gdMsKHOie-jntBaqIuiDJtkiq6qEBwnyuWwIRw4,10092
+ quack/mlp.py,sha256=YjdwQRwEePA9KyidFXp5H1-lxiJc8dZ41vl8Fv8pgss,2259
+ quack/pipeline.py,sha256=DyCwZX8WvoUBFcMBz7CeYm9VUM31haEGgBhAzmxu8cE,5519
+ quack/reduce.py,sha256=0hRFMFfn6xC5QLk32Qmgc17XVkQ1yKC-3TfksccSBaU,10341
+ quack/reduction_base.py,sha256=CT-t_j7z8H1ByD9FkQYDRik_-THMDFv9QoXHmr9Xx9E,3636
+ quack/rmsnorm.py,sha256=PrW2zuaQs_Gr6g8B6DMsGSJFZdEsWf32if_EwUR_IDQ,49386
+ quack/softmax.py,sha256=WFWtgc40iLPFBpdStBBTC9803Npnv9rZjOzb_nK-RDs,17110
+ quack/symmetric_dense_gemm_sm90.py,sha256=2UXooIpClT2izdyGis1XaIgYYlLj-7MrcOMg2yR7YCk,88694
+ quack/tensormap_manager.py,sha256=Ts3Mxp0_es2RNA0ffvUjWMXN79lsfWEBZ0DQYhtbcnw,5338
+ quack/tile_scheduler.py,sha256=BQ-SeW5wxulKuwmpq0CAIjkuirv4KWdUdoIGQB88aGE,42319
+ quack/topk.py,sha256=RQl-23lIicQ9ry9Njur8i0JGem_WbO_Gchr6jy8EtVM,9185
+ quack/utils.py,sha256=wOgNw9VL40FCsLwN52juPfk48zVpX-rta3MQhAQe8Wc,12767
+ quack/varlen_utils.py,sha256=vkduMEpo5bJJvZRNnIcKPb6pp1wD34vaIpMIB0ZGIZA,681
+ quack/sort/bitonic_sort.py,sha256=8t0SG1a6iEpYIlY8YM_AWvm4aN-4AA4vEzdBuJMJm9g,4768
+ quack/sort/generate_sorting_networks.py,sha256=vkJBOjTVEinQkWT4OtFqOWxFVdTIPoNAQocneKc9-rM,14477
+ quack/sort/sorting_networks.py,sha256=l_26zi3gXD_z-tnm2eAczRrmE-mbaz00KmqH6ONivL8,9686
+ quack/sort/utils.py,sha256=Mkr-l97RMAV-ZoNrwuzA1U3KO0Wjr38CV9Jm7ScyZoI,1090
+ quack_kernels-0.2.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ quack_kernels-0.2.1.dist-info/METADATA,sha256=_AFigx6aFt-25GzUP6YWalDBwHvwzgK9EU85WjZXvsI,285
+ quack_kernels-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ quack_kernels-0.2.1.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+ quack_kernels-0.2.1.dist-info/RECORD,,
quack/lse.py DELETED
@@ -1,62 +0,0 @@
- # Copyright (c) 2025, Tri Dao.
- # TODO: we probably don't need this kernel, just use torch.logsumexp
- import torch
-
- import triton
- import triton.language as tl
-
-
- @triton.jit
- def _lse_kernel(
-     lse_ptr,
-     logits_ptr,
-     n_rows,
-     n_cols,
-     logits_row_stride,
-     logits_col_stride,
-     BLOCK_SIZE_M: tl.constexpr,
-     BLOCK_SIZE_N: tl.constexpr,
- ):
-     row_start = tl.program_id(0) * BLOCK_SIZE_M
-     rows = row_start + tl.arange(0, BLOCK_SIZE_M)
-     cols = tl.arange(0, BLOCK_SIZE_N)
-     logits = tl.load(
-         logits_ptr + rows[:, None] * logits_row_stride + cols[None, :] * logits_col_stride,
-         mask=(rows[:, None] < n_rows) & (cols[None, :] < n_cols),
-         other=-float("inf"),
-     ).to(tl.float32)
-     m = tl.max(logits, 1)
-     lse = tl.log(tl.sum(tl.exp(logits - m[:, None]), 1)) + m
-     tl.store(lse_ptr + rows, lse, mask=rows < n_rows)
-
-
- def logsumexp(logits):
-     n_rows, n_cols = logits.shape
-     BLOCK_SIZE_M = 32 if logits.stride(1) != 1 else 1
-     MAX_BLOCK_SIZE = 64 * 1024
-     # BLOCK_SIZE_N = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE // BLOCK_SIZE_M)
-     BLOCK_SIZE_N = triton.next_power_of_2(n_cols)
-     assert (
-         BLOCK_SIZE_M * BLOCK_SIZE_N <= MAX_BLOCK_SIZE
-     ), f"Only support max dimension {MAX_BLOCK_SIZE // BLOCK_SIZE_M}"
-     num_warps = (
-         4
-         if BLOCK_SIZE_N < 2048
-         else (8 if BLOCK_SIZE_N < 8192 else (16 if BLOCK_SIZE_N < 128 * 1024 else 32))
-     )
-     lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
-     # Need this, otherwise Triton tries to launch from cuda:0 and we get
-     # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
-     with torch.cuda.device(logits.device.index):
-         _lse_kernel[(triton.cdiv(n_rows, BLOCK_SIZE_M),)](
-             lse,
-             logits,
-             n_rows,
-             n_cols,  # shapes
-             logits.stride(0),  # strides
-             logits.stride(1),
-             BLOCK_SIZE_M=BLOCK_SIZE_M,  # constants
-             BLOCK_SIZE_N=BLOCK_SIZE_N,  # constants
-             num_warps=num_warps,
-         )
-     return lse
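The deleted _lse_kernel uses the standard max-shift trick: logsumexp(x) = log(sum(exp(x - m))) + m with m = max(x), so the exponentials cannot overflow. A plain-Python reference of the same computation (illustrative):

import math

def logsumexp_ref(xs):
    # Subtract the row max before exponentiating, then add it back,
    # exactly as the deleted Triton kernel does per row.
    m = max(xs)
    return math.log(sum(math.exp(x - m) for x in xs)) + m

xs = [1000.0, 1001.0, 1002.0]
# A naive log(sum(exp(x))) overflows to inf here; the shifted form is finite.
print(round(logsumexp_ref(xs), 4))  # 1002.4076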
quack_kernels-0.1.11.dist-info/RECORD DELETED
@@ -1,31 +0,0 @@
- quack/__init__.py,sha256=AD0T-rBhSfKXpwZ6E4JIPiugvlFaAePjl-3pUhWOlPE,292
- quack/autotuner.py,sha256=aF9-Cw47gaX7_LZvyVbLsj6Z2AWi4UZ-0Qwjy06Xd5I,10733
- quack/cross_entropy.py,sha256=xsg2bXZ4wNvusBARhN4PwAzm5PbejEcfwj71nR7bzuE,20852
- quack/cute_dsl_utils.py,sha256=LkNyFEKwYrgp-tLt_775EZWuBR3v7G80El3UAObHY2U,1292
- quack/dense_gemm_sm100.py,sha256=W_j8BO-ilb1YUYFuclo7_itfPIRTkjPV_ittWgQy8t4,109937
- quack/dense_gemm_sm90.py,sha256=Dff0GbIv92uTjrtsUE1GjVKCtwSf6_5KZbrqYZm-ZMY,110418
- quack/fast_math.py,sha256=XqXVvKLSxXC3c9tIGLvKVRWdPsmjAa_O4C0plmsfZ0w,3106
- quack/gemm_config.py,sha256=Gz4dkHH1Uwg9IdW-x5W_5tjdaFHBfxq4bn7hJx_xu5s,1789
- quack/gemm_interface.py,sha256=XHgxo08d8LIu6dTlQKBOBJtjCegUB5uLh4k9hC-5mvY,9525
- quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
- quack/linear.py,sha256=Wd0KeXWvWjbkKrgW4Av1ud2v_mbhzf1RvubF7BYhcw4,6425
- quack/lse.py,sha256=aANOleIYREyrkUQM9cfJ9Gt63eawMb2KVd7YAGWNoZU,2092
- quack/mlp.py,sha256=D9V7aIfvoBMzhKwN8ZE6GlSOmwFJe_JGqgOvQprU0OQ,8224
- quack/pipeline.py,sha256=SwvRZAR4RqYH60wAFC3OTu5DisN1XDMv5umQF4czJW4,5867
- quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
- quack/rmsnorm.py,sha256=bJEHqc8ila-LTGco-tNNCUyFBjJ2UdXeoMplYNJPXFI,32740
- quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
- quack/symmetric_dense_gemm_sm90.py,sha256=t-6eLasZwyu1NW4HpnvVBBPOvfqUzOg8VHe9sJQYdmg,88637
- quack/tensormap_manager.py,sha256=pzBNwLCB8kV_yp8X8_BoDdtbwWeht2jrgRhyyfVIcMI,5261
- quack/tile_scheduler.py,sha256=mImjD2LuIVchM6USJoJY4-CSG54jGuwyLIvFG6LTP9Y,42205
- quack/topk.py,sha256=1pObblNJnxKLaE_T3qGvaMnUua0dqG2en9OU5PSp71s,9020
- quack/utils.py,sha256=4ViEFgHecaX5wcYpO6XzTCzdnuZv2rniUJAJH5Ta0bA,24981
- quack/sort/bitonic_sort.py,sha256=8t0SG1a6iEpYIlY8YM_AWvm4aN-4AA4vEzdBuJMJm9g,4768
- quack/sort/generate_sorting_networks.py,sha256=vkJBOjTVEinQkWT4OtFqOWxFVdTIPoNAQocneKc9-rM,14477
- quack/sort/sorting_networks.py,sha256=l_26zi3gXD_z-tnm2eAczRrmE-mbaz00KmqH6ONivL8,9686
- quack/sort/utils.py,sha256=Mkr-l97RMAV-ZoNrwuzA1U3KO0Wjr38CV9Jm7ScyZoI,1090
- quack_kernels-0.1.11.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- quack_kernels-0.1.11.dist-info/METADATA,sha256=WTYlk9lmhr4Jkin71stp3h-NrBdme-8OrBc7lAf4vSw,286
- quack_kernels-0.1.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- quack_kernels-0.1.11.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
- quack_kernels-0.1.11.dist-info/RECORD,,