quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/utils.py CHANGED
@@ -1,15 +1,14 @@
  # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

- import operator
  import math
- from typing import Callable, Optional, Tuple
+ from typing import Optional, Tuple, Type, Union

  import cutlass
  import cutlass.cute as cute

- from cutlass import Float32
+ from cutlass import Float32, Int32
  from cutlass.cutlass_dsl import T, dsl_user_op
- from cutlass._mlir.dialects import llvm, vector
+ from cutlass._mlir.dialects import llvm, nvvm, vector
  from cutlass.cute.runtime import from_dlpack


@@ -23,46 +22,20 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te
      )


- @cute.jit
- def warp_reduce(
-     val: cute.TensorSSA | cute.Numeric,
-     op: Callable,
-     width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
- ) -> cute.TensorSSA | cute.Numeric:
-     if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
-         res = cute.make_fragment(val.shape, val.dtype)
-         res.store(val)
-         for i in cutlass.range_constexpr(cute.size(val.shape)):
-             res[i] = warp_reduce(res[i], op, width)
-         return res.load()
-     else:
-         for i in cutlass.range_constexpr(int(math.log2(width))):
-             val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
-         return val
-
-
- @cute.jit
- def block_reduce(
-     val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0
- ) -> cute.Numeric:
-     """reduction_buffer has shape (num_warps / warp_per_row, warps_per_row)"""
-     lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
-     warps_per_row = cute.size(reduction_buffer.shape[1])
-     row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
-     if lane_idx == 0:
-         reduction_buffer[row_idx, col_idx] = val
-     cute.arch.barrier()
-     block_reduce_val = init_val
-     if lane_idx < warps_per_row:
-         block_reduce_val = reduction_buffer[row_idx, lane_idx]
-     return warp_reduce(block_reduce_val, op)
-
-
  @dsl_user_op
  def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
      return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)


+ @cute.jit
+ def load_scalar_or_pointer(x: Float32 | cute.Pointer) -> Float32:
+     if cutlass.const_expr(isinstance(x, cute.Pointer)):
+         return Float32(cute.make_tensor(x, cute.make_layout(1))[0])
+     else:
+         assert isinstance(x, Float32)
+         return x
+
+
  @dsl_user_op
  def set_block_rank(
      smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32, *, loc=None, ip=None
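Note on `load_scalar_or_pointer`: the `cutlass.const_expr(isinstance(...))` test is resolved while the kernel is traced, so each specialization compiles only one branch, letting a single entry point accept either an immediate scalar or a pointer to a device-side scalar. A plain-Python analogue of the dispatch pattern (illustrative only; `list` stands in for `cute.Pointer`):

    def load_scalar_or_ref(x):
        # Mirrors the const_expr isinstance test: pick the branch per input kind.
        if isinstance(x, list):      # "pointer" case: read the value it refers to
            return float(x[0])
        assert isinstance(x, float)  # scalar case: pass through unchanged
        return x

    print(load_scalar_or_ref(2.5), load_scalar_or_ref([2.5]))  # 2.5 2.5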
@@ -100,208 +73,31 @@ def store_shared_remote(
      ).ir_value()
      if cutlass.const_expr(isinstance(val, float)):
          val = Float32(val)
-     assert isinstance(val, (Float32, cutlass.Int64)), "val must be Float32 or Int64"
-     suffix = "f32" if cutlass.const_expr(isinstance(val, Float32)) else "s64"
+     assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64"
+     suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)]
+     constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)]
      llvm.inline_asm(
          None,
          [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
          f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];",
-         f"r,{'f' if cutlass.const_expr(isinstance(val, Float32)) else 'l'},r",
+         f"r,{constraint},r",
          has_side_effects=True,
          is_align_stack=False,
          asm_dialect=llvm.AsmDialect.AD_ATT,
      )


- @cute.jit
- def cluster_reduce(
-     val: cute.Numeric,
-     op: Callable,
-     reduction_buffer: cute.Tensor,
-     mbar_ptr: cute.Pointer,
-     init_val: cute.Numeric = 0.0,
-     phase: Optional[cutlass.Int32] = None,
- ) -> cute.Numeric:
-     """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
-     cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
-     lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
-     rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
-     row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
-     if warp_idx == 0:
-         with cute.arch.elect_one():
-             num_warps = rows_per_block * warps_per_row
-             cute.arch.mbarrier_arrive_and_expect_tx(
-                 mbar_ptr,
-                 num_warps * cluster_n * reduction_buffer.element_type.width // 8,
-             )
-     if lane_idx < cluster_n:
-         store_shared_remote(
-             val,
-             elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
-             mbar_ptr,
-             peer_cta_rank_in_cluster=lane_idx,
+ @dsl_user_op
+ def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None) -> Float32:
+     return Float32(
+         nvvm.fmin(
+             T.f32(),
+             Float32(a).ir_value(loc=loc, ip=ip),
+             Float32(b).ir_value(loc=loc, ip=ip),
+             loc=loc,
+             ip=ip,
          )
-     cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
-     block_reduce_val = init_val
-     num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
-     for i in cutlass.range_constexpr(num_iter):
-         idx = lane_idx + i * cute.arch.WARP_SIZE
-         if idx < cute.size(reduction_buffer, mode=[1]):
-             block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])
-     return warp_reduce(block_reduce_val, op)
-
-
- @cute.jit
- def block_or_cluster_reduce(
-     val: cute.Numeric,
-     op: Callable,
-     reduction_buffer: cute.Tensor,
-     mbar_ptr: Optional[cute.Pointer],
-     phase: Optional[cutlass.Int32] = None,
-     init_val: cute.Numeric = 0.0,
- ) -> cute.Numeric:
-     """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
-     if cutlass.const_expr(mbar_ptr is None):
-         return block_reduce(val, op, reduction_buffer, init_val=init_val)
-     else:
-         return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val)
-
-
- @cute.jit
- def row_reduce(
-     x: cute.TensorSSA | cute.Numeric,
-     op: cute.ReductionOp,
-     threads_per_row: cutlass.Constexpr[int],
-     reduction_buffer: Optional[cute.Tensor] = None,
-     mbar_ptr: Optional[cute.Pointer] = None,
-     phase: Optional[cutlass.Int32] = None,
-     init_val: cute.Numeric = 0.0,
-     hook_fn: Optional[Callable] = None,
- ) -> cute.Numeric:
-     """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))"""
-     if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
-         val = x.reduce(op, init_val=init_val, reduction_profile=0)
-     else:
-         val = x
-     warp_op = {
-         cute.ReductionOp.ADD: operator.add,
-         cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
-         cute.ReductionOp.MIN: min,
-         cute.ReductionOp.MUL: operator.mul,
-     }[op]
-     val = warp_reduce(
-         val,
-         warp_op,
-         width=min(threads_per_row, cute.arch.WARP_SIZE),
      )
-     if cutlass.const_expr(hook_fn is not None):
-         hook_fn()
-     if cutlass.const_expr(reduction_buffer is not None):
-         warps_per_row, cluster_n = reduction_buffer.shape[1]
-         assert (
-             cluster_n == 1 or mbar_ptr is not None
-         ), "mbar_ptr must be provided for cluster reduction"
-         if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-             val = block_or_cluster_reduce(
-                 val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val
-             )
-     return val
-
-
- @cute.jit
- def online_softmax_reduce(
-     x: cute.TensorSSA,
-     threads_per_row: cutlass.Constexpr[int],
-     reduction_buffer: Optional[cute.Tensor] = None,
-     mbar_ptr: Optional[cute.Pointer] = None,
-     hook_fn: Optional[Callable] = None,
-     phase: Optional[cutlass.Int32] = None,
-     return_exp_x: bool = False,
- ) -> [Float32, Float32, Optional[cute.TensorSSA]]:
-     assert x.dtype == Float32, "x must be of type Float32"
-     """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2)"""
-     max_x = warp_reduce(
-         x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
-         cute.arch.fmax,
-         width=min(threads_per_row, cute.arch.WARP_SIZE),
-     )
-     log2_e = math.log2(math.e)
-     exp_x = exp2f(x * log2_e - (max_x * log2_e))
-     # exp_x = exp2f((x - max_x) * log2_e)
-     sum_exp_x = warp_reduce(
-         exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
-         operator.add,
-         width=min(threads_per_row, cute.arch.WARP_SIZE),
-     )
-     if cutlass.const_expr(hook_fn is not None):
-         hook_fn()
-     if cutlass.const_expr(reduction_buffer is not None):
-         rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
-         assert (
-             cluster_n == 1 or mbar_ptr is not None
-         ), "mbar_ptr must be provided for cluster reduction"
-         if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
-             assert (
-                 reduction_buffer.element_type == cutlass.Int64
-             ), "reduction_buffer must be of type cute.Int64"
-             lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
-             row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
-             if cutlass.const_expr(mbar_ptr is None):
-                 if lane_idx == 0:
-                     reduction_buffer[row_idx, col_idx] = f32x2_to_i64(max_x, sum_exp_x)
-                 cute.arch.barrier()
-                 max_x_single_warp = -Float32.inf
-                 sum_exp_x = 0.0
-                 if lane_idx < warps_per_row:
-                     max_x_single_warp, sum_exp_x = i64_to_f32x2(reduction_buffer[row_idx, lane_idx])
-                 max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
-                 sum_exp_x *= exp2f((max_x_single_warp - max_x_final) * log2_e)
-                 sum_exp_x = warp_reduce(sum_exp_x, operator.add)
-                 if cutlass.const_expr(return_exp_x):
-                     exp_x *= exp2f((max_x - max_x_final) * log2_e)
-                 max_x = max_x_final
-             else:
-                 cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
-                 if warp_idx == 0:
-                     with cute.arch.elect_one():
-                         num_warps = rows_per_block * warps_per_row
-                         cute.arch.mbarrier_arrive_and_expect_tx(
-                             mbar_ptr,
-                             num_warps * cluster_n * reduction_buffer.element_type.width // 8,
-                         )
-                 if lane_idx < cluster_n:
-                     store_shared_remote(
-                         f32x2_to_i64(max_x, sum_exp_x),
-                         elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
-                         mbar_ptr,
-                         peer_cta_rank_in_cluster=lane_idx,
-                     )
-                 cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
-                 num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
-                 max_x_single_warp = cute.make_fragment(num_iter, Float32)
-                 max_x_single_warp.fill(-Float32.inf)
-                 sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
-                 sum_exp_x_single_warp.fill(0.0)
-                 for i in cutlass.range_constexpr(num_iter):
-                     idx = lane_idx + i * cute.arch.WARP_SIZE
-                     if idx < cute.size(reduction_buffer, mode=[1]):
-                         max_x_single_warp[i], sum_exp_x_single_warp[i] = i64_to_f32x2(
-                             reduction_buffer[row_idx, idx]
-                         )
-                 max_x_final = max_x_single_warp.load().reduce(
-                     cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0
-                 )
-                 max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
-                 sum_exp_x = 0.0
-                 for i in cutlass.range_constexpr(num_iter):
-                     sum_exp_x += sum_exp_x_single_warp[i] * exp2f(
-                         (max_x_single_warp[i] - max_x_final) * log2_e
-                     )
-                 sum_exp_x = warp_reduce(sum_exp_x, operator.add)
-                 if cutlass.const_expr(return_exp_x):
-                     exp_x *= exp2f((max_x - max_x_final) * log2_e)
-                 max_x = max_x_final
-     return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)


  @cute.jit
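`store_shared_remote` now accepts Int32 as well, with the PTX store suffix and the inline-asm operand constraint selected from small per-type tables instead of the old two-way conditional; the hunk also adds `fmin`, a thin wrapper over `nvvm.fmin`. For reference, the instruction and constraint string produced for each supported type ("f", "r", and "l" are the standard LLVM/NVPTX constraints for f32, 32-bit, and 64-bit registers):

    for ty, suffix, constraint in [
        ("Float32", "f32", "f"),
        ("Int32", "s32", "r"),
        ("Int64", "s64", "l"),
    ]:
        asm = f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];"
        print(f"{ty}: {asm}  constraints=r,{constraint},r")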
@@ -337,6 +133,21 @@ def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
      )


+ @dsl_user_op
+ def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
+     return Float32(
+         llvm.inline_asm(
+             T.f32(),
+             [Float32(a).ir_value(loc=loc, ip=ip)],
+             "sqrt.approx.ftz.f32 $0, $1;",
+             "=f,f",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
  @dsl_user_op
  def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
      return Float32(
@@ -352,6 +163,73 @@ def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
      )


+ @dsl_user_op
+ def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
+     return Int32(
+         llvm.inline_asm(
+             T.i32(),
+             [Float32(a).ir_value(loc=loc, ip=ip)],
+             "cvt.rpi.ftz.s32.f32 $0, $1;",
+             "=r,f",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
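The new `ceil` helper above uses `cvt.rpi.ftz.s32.f32`, a float-to-int conversion that rounds toward positive infinity. A runnable plain-Python model of the rounding (the `ftz` qualifier, which flushes denormal inputs to zero, is not modeled):

    import math

    def cvt_rpi_model(x: float) -> int:
        # Round toward +infinity, then treat the result as a signed 32-bit integer.
        return math.ceil(x)

    print([cvt_rpi_model(x) for x in (2.0, 2.1, -2.1, -2.9)])  # [2, 3, -2, -2]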
+ @dsl_user_op
+ def prmt(a: int | Int32, b: int | Int32, c: int | Int32, *, loc=None, ip=None) -> Int32:
+     return Int32(
+         llvm.inline_asm(
+             T.i32(),
+             [
+                 Int32(a).ir_value(loc=loc, ip=ip),
+                 Int32(b).ir_value(loc=loc, ip=ip),
+                 Int32(c).ir_value(loc=loc, ip=ip),
+             ],
+             "prmt.b32 $0, $1, $2, $3;",
+             "=r,r,r,r",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
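PTX `prmt.b32` builds a 4-byte result by picking bytes from the 8-byte pool formed by its two source registers; each nibble of the selector names one pool byte. A runnable plain-Python model (default mode only; the sign-replication modes selected by nibble bit 3 are not modeled, and the selectors used below never set it):

    def prmt_model(a: int, b: int, c: int) -> int:
        # Pool bytes 0-3 come from a, bytes 4-7 from b (little-endian order);
        # nibble i of selector c picks the byte for result position i.
        pool = a.to_bytes(4, "little") + b.to_bytes(4, "little")
        out = bytes(pool[(c >> (4 * i)) & 0x7] for i in range(4))
        return int.from_bytes(out, "little")

    a, b = 0xBBBBAAAA, 0xDDDDCCCC  # two registers, each holding a pair of b16 halves
    print(hex(prmt_model(a, b, 0x5410)))  # 0xccccaaaa: the two lower b16 halves
    print(hex(prmt_model(a, b, 0x7632)))  # 0xddddbbbb: the two upper b16 halves

This is exactly how `permute_gated_Cregs_b16` below pairs up 16-bit values: selector 0x5410 packs the lower halves of two registers and 0x7632 packs the upper halves.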
+ @cute.jit
+ def permute_gated_Cregs_b16(t: cute.Tensor) -> None:
+     assert t.element_type.width == 16
+     assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation"
+     t_u32 = cute.recast_tensor(t, Int32)
+
+     quad_idx = cute.arch.lane_idx() % 4
+     lane_03 = quad_idx == 0 or quad_idx == 3
+     selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054)
+     selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276)
+     # upper_map = [0, 3, 1, 2]
+     # lower_map = [1, 2, 0, 3]
+     # upper_idx = upper_map[quad_idx]
+     # indexing isn't supported so we have to do arithmetic
+     upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2
+     lower_idx = upper_idx ^ 1
+
+     # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
+     width = 4
+     mask = cute.arch.WARP_SIZE - width
+     clamp = cute.arch.WARP_SIZE - 1
+     mask_and_clamp = mask << 8 | clamp
+
+     for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True):
+         upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1]
+         upper0 = upper if lane_03 else lower
+         lower0 = lower if lane_03 else upper
+         upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp)
+         lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp)
+         t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper)
+         t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower)
+
+
  @cute.jit
  def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
      # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
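A note on the shuffle control word in `permute_gated_Cregs_b16` above: assuming the standard PTX `shfl.sync` encoding, the packed operand carries a segment mask in bits [12:8] and a source-lane clamp in bits [4:0], so `width = 4` keeps each shuffle within a 4-lane group. The constants work out as (illustrative arithmetic):

    WARP_SIZE = 32
    width = 4                      # shuffles stay within 4-lane segments
    mask = WARP_SIZE - width       # 0b11100: the upper lane-index bits are held fixed
    clamp = WARP_SIZE - 1          # 31: maximum source-lane index
    mask_and_clamp = mask << 8 | clamp
    print(bin(mask), mask_and_clamp)  # 0b11100 7199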
@@ -377,7 +255,7 @@ def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Nu
          tXpX: Predicate tensor indicating valid elements
          fill_value: Value to fill OOB locations with
      """
-     tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), 0, 0])
+     tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), None, 0])
      tXrX_fill.fill(fill_value)
      for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]):
          for rest_k in cutlass.range_constexpr(tXsX.shape[2]):
@@ -446,3 +324,98 @@ def coord_offset_i64(
          assumed_align=tensor.iterator.max_alignment,
      )
      return cute.make_tensor(new_ptr, tensor.layout)
+
+
+ @cute.jit
+ def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
+     if cutlass.const_expr(lane is None):
+         lane = cute.arch.lane_idx()
+     for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
+         offset = 1 << i
+         # Very important that we set mask_and_clamp to 0
+         partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0)
+         if lane >= offset:
+             val += partial_sum
+     return val
+
+
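`warp_prefix_sum` above is a Hillis–Steele inclusive scan: each round, every lane adds the previous round's value from a fixed offset below it, with the guard leaving lanes below the offset unchanged, and the offset doubling each round. A runnable plain-Python model of the lane-parallel behavior:

    def warp_prefix_sum_model(vals):
        # Each round reads the previous round's values (the comprehension copies),
        # doubling the offset each time: log2(n) rounds for n lanes.
        n = len(vals)  # stands in for WARP_SIZE; assumed a power of two
        offset = 1
        while offset < n:
            vals = [v + (vals[lane - offset] if lane >= offset else 0)
                    for lane, v in enumerate(vals)]
            offset <<= 1
        return vals

    print(warp_prefix_sum_model([1] * 8))  # [1, 2, 3, 4, 5, 6, 7, 8]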
+ def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout:
+     """
+     For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
+     For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
+     """
+     acc_layout_col_major = cute.make_layout(acc_layout.shape)
+     acc_layout_mn = cute.make_layout(
+         (
+             (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]),  # MMA_M
+             (
+                 acc_layout_col_major.shape[0][0],
+                 *acc_layout_col_major.shape[0][2:],
+                 acc_layout_col_major.shape[2],
+             ),  # MMA_N
+             *acc_layout_col_major.shape[3:],
+         ),
+         stride=(
+             (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]),  # MMA_M
+             (
+                 acc_layout_col_major.stride[0][0],
+                 *acc_layout_col_major.stride[0][2:],
+                 acc_layout_col_major.stride[2],
+             ),  # MMA_N
+             *acc_layout_col_major.stride[3:],
+         ),
+     )
+     return cute.composition(acc_layout, acc_layout_mn)
+
+
+ def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor:
+     return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout))
+
+
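Shape-only illustration of `convert_layout_acc_mn` above for the Sm80 case, with hypothetical MMA_M=4 and MMA_N=8: the per-thread (2, 2) accumulator mode is split so one factor of 2 joins the M mode and the other joins the N mode, giving a view indexable as (rows, cols):

    # ((2, 2), MMA_M, MMA_N) -> ((2, MMA_M), (2, MMA_N)), shapes only.
    shape = ((2, 2), 4, 8)  # hypothetical accumulator shape
    mn_shape = (
        (shape[0][1], shape[1]),  # M mode
        (shape[0][0], shape[2]),  # N mode
    )
    print(mn_shape)  # ((2, 4), (2, 8))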
+ @dsl_user_op
+ def sm90_get_smem_load_op(
+     layout_c: cutlass.utils.LayoutEnum,
+     elem_ty_c: Type[cutlass.Numeric],
+     *,
+     loc=None,
+     ip=None,
+ ) -> cute.CopyAtom:
+     """
+     Selects the largest vectorized smem load atom available subject to constraint of gmem layout.
+
+     Parameters:
+     -----------
+     layout_c : LayoutEnum
+         The layout enum of the output tensor D.
+
+     elem_ty_c : Type[Numeric]
+         The element type for output tensor D.
+
+     Returns:
+     --------
+     Either SmemLoadMatrix or SimtSyncCopy, based on the input parameters.
+     """
+
+     if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta):
+         raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
+     is_m_major = layout_c.is_m_major_c()
+     if elem_ty_c.width == 16:
+         return cute.make_copy_atom(
+             cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
+         )
+     else:
+         return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
+
+
+ @dsl_user_op
+ def atomic_add_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
+     return nvvm.atomicrmw(
+         res=T.i32(), op=nvvm.AtomicOpKind.ADD, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
+     )
+
+
+ @dsl_user_op
+ def atomic_inc_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32:
+     return nvvm.atomicrmw(
+         res=T.i32(), op=nvvm.AtomicOpKind.INC, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value()
+     )
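The two atomics above differ in wrap-around behavior: ADD accumulates without bound, while INC increments modulo a bound, e.g. for cyclic counters. A plain-Python model of the wrapping rule, assuming the documented PTX `atom.inc` semantics (`old >= bound ? 0 : old + 1`, returning the old value) rather than anything specific to this codebase:

    def atomic_inc_model(mem: list, addr: int, bound: int) -> int:
        old = mem[addr]
        mem[addr] = 0 if old >= bound else old + 1
        return old  # atomics return the value seen before the update

    mem = [0]
    for _ in range(6):
        atomic_inc_model(mem, 0, bound=3)
    print(mem[0])  # counter cycles 1, 2, 3, 0, 1, 2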
quack/varlen_utils.py ADDED
@@ -0,0 +1,22 @@
+ # Copyright (c) 2025, Tri Dao.
+
+ from typing import Optional
+ from dataclasses import dataclass
+
+ import cutlass.cute as cute
+
+ from quack.cute_dsl_utils import ArgumentsBase
+
+
+ # Grouping arguments together that should be passed to __call__
+ @dataclass
+ class VarlenArguments(ArgumentsBase):
+     mCuSeqlensM: Optional[cute.Tensor] = None
+     mCuSeqlensK: Optional[cute.Tensor] = None
+     mTensormaps: Optional[cute.Tensor] = None
+
+     def __post_init__(self):
+         if self.mCuSeqlensM is not None or self.mCuSeqlensK is not None:
+             assert (
+                 self.mTensormaps is not None
+             ), "mTensormaps must be provided if mCuSeqlensM or mCuSeqlensK is provided"
quack_kernels-{0.1.10 → 0.2.0}.dist-info/METADATA CHANGED
@@ -1,9 +1,9 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.1.10
+ Version: 0.2.0
  Requires-Python: >=3.12
  License-File: LICENSE
- Requires-Dist: nvidia-cutlass-dsl==4.1.0
+ Requires-Dist: nvidia-cutlass-dsl==4.2.0
  Requires-Dist: torch
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == "dev"
quack_kernels-0.2.0.dist-info/RECORD ADDED
@@ -0,0 +1,37 @@
+ quack/__init__.py,sha256=fGBYbb9JlaNT7HdtUTbUnuAkL5G2Dg8XZAA5Ir1R-ow,364
+ quack/activation.py,sha256=ysXaVUXX2yGQC5o4ZVeRXw_fDIHOrqnzpHJaIsc0kHc,10271
+ quack/autotuner.py,sha256=czO6JrYL0EJpOeJOYDSsVdrJaFuwfL3vTdG8QfL1F34,10792
+ quack/cross_entropy.py,sha256=Kc3P83Vsu1nGaCu7llsO3vct3J_t3frRYPxij7JfHMA,28619
+ quack/cute_dsl_utils.py,sha256=D2Pw7rzX9jY8u8wikIPvPvinmFLCDeZg95HPBLqGej4,4635
+ quack/dense_gemm_sm100.py,sha256=hKBNC34UxdctrTKVP68nvANZl4Dq2rnUjRcweESEq3g,109965
+ quack/dense_gemm_sm90.py,sha256=TjnjHnjhAwWH5YQWsFlADq07xSxtsprkw_p2Cy0yw7I,100407
+ quack/fast_math.py,sha256=E1XUqfUt0_n9BPZNggF-UDzZ6anso9bYUrwqafemWvQ,2297
+ quack/gemm_act_sm90.py,sha256=N5UAFWZvw1na22Vh5JSGgcdqZ2zI6kQMBVOLxYbCAUU,14332
+ quack/gemm_config.py,sha256=gbYjPFeyT5wAhVwFQroRHlHoMKEJqAWX9P8wWy04l8Q,2258
+ quack/gemm_dact_sm90.py,sha256=KCXgjOzdamSDexwrwf_pX2r-ippPRirbClrlU6BP7b8,4990
+ quack/gemm_interface.py,sha256=_JTpE7zQw6NUw-v65Wql_XUOZBfW0oSEgiMnharTJU4,20501
+ quack/gemm_wrapper_utils.py,sha256=aMMtu-Ojhtjay_5xJH4AjP-JRVks1AB8jmtNme_DIqU,5960
+ quack/layernorm.py,sha256=JkK0sVdUfZ-SmoBmNqLF3wCiszDbdorvcBH2julv0Vg,13560
+ quack/linear.py,sha256=SrhRiAFjC7ONIMVmiNu-kSPLHNUyaCXt59a1f_5nNXo,9383
+ quack/linear_cross_entropy.py,sha256=Zhy_gdMsKHOie-jntBaqIuiDJtkiq6qEBwnyuWwIRw4,10092
+ quack/mlp.py,sha256=YjdwQRwEePA9KyidFXp5H1-lxiJc8dZ41vl8Fv8pgss,2259
+ quack/pipeline.py,sha256=DyCwZX8WvoUBFcMBz7CeYm9VUM31haEGgBhAzmxu8cE,5519
+ quack/reduce.py,sha256=hsYByu6haCZjLTLB-qpYmKDjqS2UqlwPgfWTup38GNA,10341
+ quack/reduction_base.py,sha256=CT-t_j7z8H1ByD9FkQYDRik_-THMDFv9QoXHmr9Xx9E,3636
+ quack/rmsnorm.py,sha256=93qlTPjY9JBm3R5M-HeHse1PbAfD9931G3OFs71yo_g,48998
+ quack/softmax.py,sha256=Mq3_2Ul8H64zeGUI9wOKEpIISJnrCcHQpZvk2sb10Tg,17101
+ quack/symmetric_dense_gemm_sm90.py,sha256=2UXooIpClT2izdyGis1XaIgYYlLj-7MrcOMg2yR7YCk,88694
+ quack/tensormap_manager.py,sha256=Ts3Mxp0_es2RNA0ffvUjWMXN79lsfWEBZ0DQYhtbcnw,5338
+ quack/tile_scheduler.py,sha256=8qqYmx6GpQzt8XiidcrdLIaWf0TGbJVdwKFfeb1X_us,42265
+ quack/topk.py,sha256=RQl-23lIicQ9ry9Njur8i0JGem_WbO_Gchr6jy8EtVM,9185
+ quack/utils.py,sha256=tiqeJZiPPFl5irQWCUd7dTPA_OAv4SjHUW5S-u9wO8Y,14526
+ quack/varlen_utils.py,sha256=vkduMEpo5bJJvZRNnIcKPb6pp1wD34vaIpMIB0ZGIZA,681
+ quack/sort/bitonic_sort.py,sha256=8t0SG1a6iEpYIlY8YM_AWvm4aN-4AA4vEzdBuJMJm9g,4768
+ quack/sort/generate_sorting_networks.py,sha256=vkJBOjTVEinQkWT4OtFqOWxFVdTIPoNAQocneKc9-rM,14477
+ quack/sort/sorting_networks.py,sha256=l_26zi3gXD_z-tnm2eAczRrmE-mbaz00KmqH6ONivL8,9686
+ quack/sort/utils.py,sha256=Mkr-l97RMAV-ZoNrwuzA1U3KO0Wjr38CV9Jm7ScyZoI,1090
+ quack_kernels-0.2.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ quack_kernels-0.2.0.dist-info/METADATA,sha256=DAeQymRUqp7lSfSTNyS7TZF3oWcFzCKriGJ2p8JLu6A,285
+ quack_kernels-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ quack_kernels-0.2.0.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
+ quack_kernels-0.2.0.dist-info/RECORD,,
quack_kernels-0.1.10.dist-info/RECORD DELETED
@@ -1,13 +0,0 @@
- quack/__init__.py,sha256=4tLchTx7d0d1ZVg6psRjjoXAWKHqzIWRF5mUk8ZdgkQ,204
- quack/cross_entropy.py,sha256=xsg2bXZ4wNvusBARhN4PwAzm5PbejEcfwj71nR7bzuE,20852
- quack/dense_gemm_sm90.py,sha256=jULXfAQkRh1SUAOpesx8wouY-GLDCm05Fb5LynozSl8,59932
- quack/layernorm.py,sha256=1WUspbr6ktPZ25O00kKs-FK_lm_Fejat72BMV8tBSfw,13504
- quack/reduction_base.py,sha256=4nAzkZR1yoQVA4Lc-GpU0XMjS5ARAmvYdeE0Doy7UCU,3789
- quack/rmsnorm.py,sha256=bJEHqc8ila-LTGco-tNNCUyFBjJ2UdXeoMplYNJPXFI,32740
- quack/softmax.py,sha256=3-5P_ORBrfQ6JYTIzgDs9jwmV7Za73SogaX7q9M7GCM,16698
- quack/utils.py,sha256=RZq-7YA8UMUizHpVyZM1we4zGm9NaC178M2g2HXdjmE,17799
- quack_kernels-0.1.10.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- quack_kernels-0.1.10.dist-info/METADATA,sha256=baMTwibt6u0IQb8YJFFhCY0RD3Aervf5sl6EpYF6IQ8,286
- quack_kernels-0.1.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- quack_kernels-0.1.10.dist-info/top_level.txt,sha256=6e4Jr_vNJbZTYwlO_Ahf_sDeHDE0zcqcf7Le11FKxxo,6
- quack_kernels-0.1.10.dist-info/RECORD,,