quack-kernels 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/__init__.py CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.2.5"
+__version__ = "0.2.6"
 
 import os
 
quack/activation.py CHANGED
@@ -2,18 +2,24 @@
 
 import math
 from typing import Tuple
+from functools import partial
 
 import cutlass.cute as cute
 from cutlass import Float32, Boolean, const_expr
 from cutlass.cutlass_dsl import T, dsl_user_op
-from cutlass._mlir.dialects import llvm
-
-import quack.utils as utils
+from cutlass._mlir.dialects import llvm, nvvm
 
 
 F32_or_F32x2 = Float32 | Tuple[Float32, Float32]
 
 
+sub_packed_f32x2 = partial(
+    cute.arch.calc_packed_f32x2_op,
+    src_c=None,
+    calc_func=nvvm.sub_packed_f32x2,
+)
+
+
 @dsl_user_op
 def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
     return Float32(
@@ -35,9 +41,9 @@ def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
         # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
         return 0.5 + 0.5 * tanh(0.5 * x)
     else:
-        x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
+        x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
         tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
-        return utils.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))
+        return cute.arch.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))
 
 
 @dsl_user_op
@@ -75,7 +81,7 @@ def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
         return cute.arch.fmax(x, Float32(0.0)) * x
     else:
         relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
-        return utils.mul_packed_f32x2(relu_x, x)
+        return cute.arch.mul_packed_f32x2(relu_x, x)
 
 
 @dsl_user_op
@@ -98,8 +104,8 @@ def drelu_sq(
         return dx, relu_sq_out
     else:
         relu_x = relu(x)
-        relu_sq_out = utils.mul_packed_f32x2(relu_x, x)
-        dx = utils.mul_packed_f32x2((2.0, 2.0), utils.mul_packed_f32x2(dout, relu_x))
+        relu_sq_out = cute.arch.mul_packed_f32x2(relu_x, x)
+        dx = cute.arch.mul_packed_f32x2((2.0, 2.0), cute.arch.mul_packed_f32x2(dout, relu_x))
         return dx, relu_sq_out
 
 
@@ -119,14 +125,14 @@ def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
             * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
         )
     else:
-        x_sq = utils.mul_packed_f32x2(x, x)
-        x_sq_scaled = utils.fma_packed_f32x2(
+        x_sq = cute.arch.mul_packed_f32x2(x, x)
+        x_sq_scaled = cute.arch.fma_packed_f32x2(
             x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
         )
-        z = utils.mul_packed_f32x2(x, x_sq_scaled)
+        z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
         tanh_z = (tanh(z[0]), tanh(z[1]))
-        x_tanh_z = utils.fma_packed_f32x2(tanh_z, x, x)
-        return utils.mul_packed_f32x2((0.5, 0.5), x_tanh_z)
+        x_tanh_z = cute.arch.fma_packed_f32x2(tanh_z, x, x)
+        return cute.arch.mul_packed_f32x2((0.5, 0.5), x_tanh_z)
 
 
 @dsl_user_op
@@ -167,28 +173,28 @@ def dgelu_tanh_approx(
         return dx, gelu_out
     else:
         # Compute z = x * (c1 + c2 * x^2)
-        x_sq = utils.mul_packed_f32x2(x, x)
-        x_sq_scaled = utils.fma_packed_f32x2(
+        x_sq = cute.arch.mul_packed_f32x2(x, x)
+        x_sq_scaled = cute.arch.fma_packed_f32x2(
             x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
         )
-        z = utils.mul_packed_f32x2(x, x_sq_scaled)
+        z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
         tanh_z = (tanh(z[0]), tanh(z[1]))
-        half_tanh_z_plus_one = utils.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
-        gelu_out = utils.mul_packed_f32x2(x, half_tanh_z_plus_one)
+        half_tanh_z_plus_one = cute.arch.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
+        gelu_out = cute.arch.mul_packed_f32x2(x, half_tanh_z_plus_one)
 
         # Compute gradient
         # sech^2(z) = 1 - tanh^2(z)
-        sech2_z = utils.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
+        sech2_z = cute.arch.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
         # dz/dx = c1 + 3 * c2 * x^2
-        dz_dx = utils.fma_packed_f32x2(
+        dz_dx = cute.arch.fma_packed_f32x2(
             x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
         )
         # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
-        sech2_dz_dx = utils.mul_packed_f32x2(sech2_z, dz_dx)
-        x_sech2_dz_dx = utils.mul_packed_f32x2(x, sech2_dz_dx)
-        dgelu = utils.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)
+        sech2_dz_dx = cute.arch.mul_packed_f32x2(sech2_z, dz_dx)
+        x_sech2_dz_dx = cute.arch.mul_packed_f32x2(x, sech2_dz_dx)
+        dgelu = cute.arch.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)
 
-        dx = utils.mul_packed_f32x2(dout, dgelu)
+        dx = cute.arch.mul_packed_f32x2(dout, dgelu)
         return dx, gelu_out
 
 
@@ -204,15 +210,15 @@ def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
         )
     else:
         log2_e = math.log2(math.e)
-        x_log2e = utils.mul_packed_f32x2(x, (log2_e, log2_e))
+        x_log2e = cute.arch.mul_packed_f32x2(x, (log2_e, log2_e))
         x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True))
-        x_exp_p1 = utils.add_packed_f32x2(x_exp, (1.0, 1.0))
+        x_exp_p1 = cute.arch.add_packed_f32x2(x_exp, (1.0, 1.0))
         log_x_exp_p1 = (
             cute.math.log2(x_exp_p1[0], fastmath=True),
             cute.math.log2(x_exp_p1[1], fastmath=True),
         )
         ln2 = math.log(2.0)
-        softplus_x = utils.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
+        softplus_x = cute.arch.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
         use_linear_0 = Boolean(x[0] > 20.0)
         use_linear_1 = Boolean(x[1] > 20.0)
         return (
@@ -241,9 +247,9 @@ def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) ->
         # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
         return x_half * tanh(x_half) + x_half
     else:
-        x_half = utils.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
+        x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
        tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
-        return utils.fma_packed_f32x2(x_half, tanh_x_half, x_half)
+        return cute.arch.fma_packed_f32x2(x_half, tanh_x_half, x_half)
 
 
 @dsl_user_op
@@ -251,7 +257,7 @@ def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32
     if const_expr(not isinstance(x, tuple)):
         return silu(x) * y
     else:
-        return utils.mul_packed_f32x2(silu(x), y)
+        return cute.arch.mul_packed_f32x2(silu(x), y)
 
 
 @dsl_user_op
@@ -301,20 +307,22 @@ def dswiglu(
         # Compute sigmoid(x) and silu(x)
         if const_expr(not already_halved):
             sigmoid_x = sigmoid(x)
-            silu_x = utils.mul_packed_f32x2(x, sigmoid_x)
+            silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_x)
         else:
             tanh_x = (tanh(x[0]), tanh(x[1]))
-            sigmoid_x = utils.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
-            silu_x = utils.fma_packed_f32x2(x, tanh_x, x)
-        silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
+            sigmoid_x = cute.arch.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
+            silu_x = cute.arch.fma_packed_f32x2(x, tanh_x, x)
+        silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
         # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
-        sigmoid_x_minus_silu_x_sigmoid_x = utils.fma_packed_f32x2(
+        sigmoid_x_minus_silu_x_sigmoid_x = cute.arch.fma_packed_f32x2(
             sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
         )
-        d_silu_x_dout = utils.fma_packed_f32x2(sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout)
-        dx = utils.mul_packed_f32x2(d_silu_x_dout, y)
+        d_silu_x_dout = cute.arch.fma_packed_f32x2(
+            sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout
+        )
+        dx = cute.arch.mul_packed_f32x2(d_silu_x_dout, y)
         dy = silu_x_dout
-        swiglu_out = utils.mul_packed_f32x2(silu_x, y)
+        swiglu_out = cute.arch.mul_packed_f32x2(silu_x, y)
         return dx, dy, swiglu_out
 
 
@@ -334,11 +342,11 @@ def swiglu_oai(
         silu_x = x_half * tanh(alpha * x_half) + x_half
         return silu_x * y + silu_x
     else:
-        x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
-        alpha_x_half = utils.mul_packed_f32x2((alpha, alpha), x_half)
+        x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
+        alpha_x_half = cute.arch.mul_packed_f32x2((alpha, alpha), x_half)
         tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
-        silu_x = utils.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
-        return utils.fma_packed_f32x2(silu_x, y, silu_x)
+        silu_x = cute.arch.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
+        return cute.arch.fma_packed_f32x2(silu_x, y, silu_x)
 
 
 @dsl_user_op
@@ -370,22 +378,22 @@ def dswiglu_oai(
         return dx, dy, swiglu_out
     else:
         # Compute sigmoid(alpha * x)
-        alpha_x_half = utils.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
+        alpha_x_half = cute.arch.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
         tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
-        sigmoid_alpha_x = utils.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
-        silu_x = utils.mul_packed_f32x2(x, sigmoid_alpha_x)
-        silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
+        sigmoid_alpha_x = cute.arch.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
+        silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_alpha_x)
+        silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
         # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
-        silu_x_minus_product = utils.fma_packed_f32x2(
+        silu_x_minus_product = cute.arch.fma_packed_f32x2(
             silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
         )
-        sigmoid_plus_alpha_diff = utils.fma_packed_f32x2(
+        sigmoid_plus_alpha_diff = cute.arch.fma_packed_f32x2(
            (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
         )
-        d_silu_x_dout = utils.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
-        dx = utils.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
+        d_silu_x_dout = cute.arch.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
+        dx = cute.arch.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
         dy = silu_x_dout
-        swiglu_out = utils.fma_packed_f32x2(silu_x, y, silu_x)
+        swiglu_out = cute.arch.fma_packed_f32x2(silu_x, y, silu_x)
         return dx, dy, swiglu_out
 
 
@@ -400,7 +408,7 @@ def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
         return sigmoid_x * y  # FMUL
     else:
         sigmoid_x = sigmoid(x)
-        return utils.mul_packed_f32x2(sigmoid_x, y)
+        return cute.arch.mul_packed_f32x2(sigmoid_x, y)
 
 
 @dsl_user_op
@@ -430,11 +438,11 @@ def dglu(
         return dx, dy, glu_out
     else:
         sigmoid_x = sigmoid(x)
-        sigmoid_x_dout = utils.mul_packed_f32x2(sigmoid_x, dout)
-        glu_out = utils.mul_packed_f32x2(sigmoid_x, y)
+        sigmoid_x_dout = cute.arch.mul_packed_f32x2(sigmoid_x, dout)
+        glu_out = cute.arch.mul_packed_f32x2(sigmoid_x, y)
         # dx = (y - glu_out) * sigmoid_x_dout
-        y_minus_glu_out = utils.sub_packed_f32x2(y, glu_out)
-        dx = utils.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
+        y_minus_glu_out = sub_packed_f32x2(y, glu_out)
+        dx = cute.arch.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
         dy = sigmoid_x_dout
         return dx, dy, glu_out
 
@@ -448,7 +456,7 @@ def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x
         return cute.arch.fmax(x, Float32(0.0)) * y
     else:
         relu_x = relu(x)
-        return utils.mul_packed_f32x2(relu_x, y)
+        return cute.arch.mul_packed_f32x2(relu_x, y)
 
 
 @dsl_user_op
@@ -475,10 +483,10 @@ def dreglu(
         x0_pos = Boolean(x[0] > 0)
         x1_pos = Boolean(x[1] > 0)
         relu_x = relu(x)
-        dout_y = utils.mul_packed_f32x2(dout, y)
+        dout_y = cute.arch.mul_packed_f32x2(dout, y)
         dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
-        dy = utils.mul_packed_f32x2(dout, relu_x)
-        reglu_out = utils.mul_packed_f32x2(relu_x, y)
+        dy = cute.arch.mul_packed_f32x2(dout, relu_x)
+        reglu_out = cute.arch.mul_packed_f32x2(relu_x, y)
         return dx, dy, reglu_out
 
 
@@ -491,7 +499,7 @@ def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x
     if const_expr(not isinstance(x, tuple)):
         return gelu_tanh_approx(x) * y
     else:
-        return utils.mul_packed_f32x2(gelu_tanh_approx(x), y)
+        return cute.arch.mul_packed_f32x2(gelu_tanh_approx(x), y)
 
 
 @dsl_user_op
@@ -518,7 +526,7 @@ def dgeglu(
     # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
    dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
     # Compute gradients for geglu
-    dx = utils.mul_packed_f32x2(dgelu_x_dout, y)
-    dy = utils.mul_packed_f32x2(gelu_x, dout)
-    geglu_out = utils.mul_packed_f32x2(gelu_x, y)
+    dx = cute.arch.mul_packed_f32x2(dgelu_x_dout, y)
+    dy = cute.arch.mul_packed_f32x2(gelu_x, dout)
+    geglu_out = cute.arch.mul_packed_f32x2(gelu_x, y)
     return dx, dy, geglu_out
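
Note on the activation.py changes above: every call into the project-local packed-math wrappers from quack.utils (mul_packed_f32x2, fma_packed_f32x2, add_packed_f32x2, sub_packed_f32x2) now goes through the cute.arch equivalents, so the quack.utils import is dropped. cute.arch supplies mul/fma/add directly; for subtraction the module composes its own sub_packed_f32x2 from cute.arch.calc_packed_f32x2_op and nvvm.sub_packed_f32x2 via functools.partial, and dglu now calls that module-level helper. A minimal sketch of the call pattern, assuming a CuTe DSL jit context (the operand values are illustrative, not from the diff):

    # Each helper maps lanewise over a (Float32, Float32) pair,
    # lowering to packed f32x2 instructions.
    a = (Float32(1.0), Float32(2.0))
    b = (Float32(3.0), Float32(4.0))
    c = (Float32(5.0), Float32(6.0))
    prod = cute.arch.mul_packed_f32x2(a, b)      # (1*3, 2*4)
    fused = cute.arch.fma_packed_f32x2(a, b, c)  # (1*3+5, 2*4+6)
    diff = sub_packed_f32x2(a, b)                # (1-3, 2-4), via the new partial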
quack/broadcast_utils.py CHANGED
@@ -11,7 +11,7 @@ from quack.layout_utils import make_acc_tensor_mn_view
 @cute.jit
 def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None:
     if const_expr(tCrC.element_type != Float32):  # Convert to f32
-        tCrC_f32 = cute.make_fragment(tCrC.shape, Float32)
+        tCrC_f32 = cute.make_rmem_tensor(tCrC.shape, Float32)
         tCrC_f32.store(tCrC.load().to(Float32))
     else:
         tCrC_f32 = tCrC
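
The broadcast_utils.py hunk is one instance of a rename that recurs throughout this release: cute.make_fragment becomes cute.make_rmem_tensor, with the argument list (shape or layout, element type) unchanged. For example, assuming the same DSL context as the surrounding code:

    # Old (0.2.5): tCrC_f32 = cute.make_fragment(tCrC.shape, Float32)
    # New (0.2.6): same arguments, new name
    tCrC_f32 = cute.make_rmem_tensor(tCrC.shape, Float32)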
quack/copy_utils.py CHANGED
@@ -7,7 +7,7 @@ import cutlass
 import cutlass.cute as cute
 
 from cutlass import Int32, Boolean, const_expr
-from cutlass.cute.nvgpu import cpasync, warpgroup
+from cutlass.cute.nvgpu import cpasync, warp, warpgroup
 from cutlass.cutlass_dsl import dsl_user_op
 import cutlass.pipeline
 
@@ -52,7 +52,7 @@ def load_s2r_retile(
 ) -> cute.Tensor:
     # Will also accept dst_shape being a tensor, in which case we write into that tensor
     if const_expr(not isinstance(dst_shape, cute.Tensor)):
-        dst = cute.make_fragment(dst_shape, src.element_type, loc=loc, ip=ip)
+        dst = cute.make_rmem_tensor(dst_shape, src.element_type, loc=loc, ip=ip)
     else:
         dst = dst_shape
     cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip)
@@ -117,7 +117,7 @@ def tiled_copy_2d(
 @cute.jit
 def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
     # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
-    tApA = cute.make_fragment(
+    tApA = cute.make_rmem_tensor(
         cute.make_layout(
             (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
             stride=(cute.size(tAcA, mode=[2]), 0, 1),
@@ -242,9 +242,7 @@ def sm90_get_smem_load_op(
         raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
     is_m_major = layout_c.is_m_major_c()
     if elem_ty_c.width == 16:
-        return cute.make_copy_atom(
-            cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
-        )
+        return cute.make_copy_atom(warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip)
     else:
         return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
 
@@ -260,7 +258,7 @@ def get_smem_store_atom(
         )
     else:
         return cute.make_copy_atom(
-            cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
+            warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
             element_type,
         )
 
@@ -276,7 +274,7 @@ def get_smem_load_atom(
         )
     else:
         return cute.make_copy_atom(
-            cute.nvgpu.warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
+            warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
             element_type,
         )
 
@@ -368,8 +366,6 @@ def get_smem_load_A(
         tSR_sA = thr_copy.partition_S(sA)
     else:
         tSR_sA = partition_S_position_independent(thr_copy, sA)
-    copy_atom_RS = get_smem_store_atom(arch, dtype, transpose)
-    thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
     tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2])
 
     def copy_fn(src_idx: Int32, **new_kwargs):
@@ -464,10 +460,10 @@ def gather_m_get_copy_fn(
     # Read and cache indices for A
     rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
     cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
-    tApA_m = cute.make_fragment(rows_per_thread, Boolean)
+    tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
     for m in cutlass.range(rows_per_thread, unroll_full=True):
         tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
-    m_idx = cute.make_fragment(rows_per_thread, Int32)
+    m_idx = cute.make_rmem_tensor(rows_per_thread, Int32)
     for m in cutlass.range(rows_per_thread, unroll_full=True):
         row_idx = tAcA[0, m, 0][0]
         if tApA_m[m]:
@@ -480,7 +476,7 @@ def gather_m_get_copy_fn(
     def copy_fn(src_idx, dst_idx, pred: bool = False):
         tApA_k = None
         if const_expr(pred):
-            tApA_k = cute.make_fragment(cols_per_thread, Boolean)
+            tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
             limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
             for k in cutlass.range(cols_per_thread, unroll_full=True):
                 tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
@@ -538,7 +534,7 @@ def gather_k_get_copy_fn(
     # Read and cache indices for A
     rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
     cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
-    tApA_m = cute.make_fragment(rows_per_thread, Boolean)
+    tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
     for m in cutlass.range(rows_per_thread, unroll_full=True):
         tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
     threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
@@ -554,12 +550,12 @@ def gather_k_get_copy_fn(
         # Prefetch mAIdx early, even before smem is free
         tApA_k = None
         if const_expr(pred):
-            tApA_k = cute.make_fragment(cols_per_thread, Boolean)
+            tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
             limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
             for k in cutlass.range(cols_per_thread, unroll_full=True):
                 tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
         gAIdx_cur = gAIdx[None, src_idx]
-        k_idx = cute.make_fragment(cols_per_thread, Int32)
+        k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
         for k in cutlass.range(cols_per_thread):
             col_idx = tAcA[0, 0, k][1]
             if const_expr(not pred):
@@ -576,13 +572,13 @@ def gather_k_get_copy_fn(
     ) -> Tuple[cute.Tensor, cute.Tensor]:
         tApA_k = None
         if const_expr(pred):
-            tApA_k = cute.make_fragment(cols_per_thread, Boolean)
+            tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
             limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
             for k in cutlass.range(cols_per_thread, unroll_full=True):
                 tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
         a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
         sAIdx_cur = sAIdx[None, dst_idx]
-        k_idx = cute.make_fragment(cols_per_thread, Int32)
+        k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
         for k in cutlass.range(cols_per_thread):
             col_idx = tAcA[0, 0, k][1]
             k_idx[k] = sAIdx_cur[col_idx]
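
Three mechanical changes run through copy_utils.py: the make_fragment to make_rmem_tensor rename again; warp is now imported once from cutlass.cute.nvgpu, shortening call sites from cute.nvgpu.warp.LdMatrix8x8x16bOp(...) to warp.LdMatrix8x8x16bOp(...); and get_smem_load_A drops copy_atom_RS and thr_copy_RS, which were computed but never used. A before/after sketch of the import cleanup (arguments as they appear in the diff):

    # Before (0.2.5)
    from cutlass.cute.nvgpu import cpasync, warpgroup
    atom = cute.make_copy_atom(
        cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c
    )

    # After (0.2.6)
    from cutlass.cute.nvgpu import cpasync, warp, warpgroup
    atom = cute.make_copy_atom(warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c)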
quack/fast_math.py CHANGED
@@ -1,80 +1,33 @@
 # Copyright (c) 2025, Tri Dao.
 
-from typing import Tuple
-from dataclasses import dataclass
-
 import cutlass
 import cutlass.cute as cute
-from cutlass import Int32, Uint32
-from cutlass.cutlass_dsl import T, dsl_user_op
-from cutlass._mlir.dialects import llvm
-
-from quack.cute_dsl_utils import ParamsBase
-
-
-@cute.jit
-def clz(x: Int32) -> Int32:
-    # for i in cutlass.range_constexpr(32):
-    #     if (1 << (31 - i)) & x:
-    #         return Int32(i)
-    # return Int32(32)
-    # Early exit is not supported yet
-    res = Int32(32)
-    done = False
-    for i in cutlass.range(32):
-        if ((1 << (31 - i)) & x) and not done:
-            res = Int32(i)
-            done = True
-    return res
-
-
-def find_log2(x: Int32) -> Int32:
-    a: Int32 = Int32(31 - clz(x))
-    return a + ((x & (x - 1)) != 0)  # Round up, add 1 if not a power of 2.
-
-
-@dsl_user_op
-def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32:
-    return Uint32(
-        llvm.inline_asm(
-            T.i32(),
-            [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)],
-            "mul.hi.u32 $0, $1, $2;",
-            "=r,r,r",
-            has_side_effects=False,
-            is_align_stack=False,
-            asm_dialect=llvm.AsmDialect.AD_ATT,
-        )
-    )
-
-
-@dataclass
-class FastDivmod(ParamsBase):
-    divisor: Int32
-    multiplier: Uint32
-    shift_right: Uint32
-
-    # called by host
-    @staticmethod
-    def create(divisor: Int32) -> "FastDivmod":
-        """Construct the FastDivmod object, in host code.
-        This precomputes some values based on the divisor and is computationally expensive.
-        """
-        p = Uint32(31 + find_log2(divisor))
-        divisor_u32 = Uint32(divisor)
-        multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32)
-        shift_right = Uint32(p - 32)
-        return FastDivmod(divisor, multiplier, shift_right)
-
-    @cute.jit
-    def div(self, dividend: Int32) -> Int32:
-        return (
-            Int32(umulhi(dividend, self.multiplier) >> self.shift_right)
-            if self.divisor != 1
-            else dividend
-        )
-
-    def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]:
-        quotient = self.div(dividend)
-        remainder = dividend - quotient * self.divisor
-        return quotient, remainder
+from cutlass.base_dsl.typing import Integer
+from cutlass.cutlass_dsl import dsl_user_op
+
+
+class FastDivmod(cute.FastDivmodDivisor):
+    """We store the divisor along with the FastDivmodDivisor."""
+
+    @dsl_user_op
+    def __init__(
+        self,
+        divisor: Integer,
+        is_power_of_2: bool = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        super().__init__(divisor, is_power_of_2=is_power_of_2, loc=loc, ip=ip)
+        self.divisor = divisor
+
+    def __extract_mlir_values__(self):
+        """Extract MLIR values for Host->Device transfer."""
+        return [self._divisor] + cutlass.extract_mlir_values(self.divisor)
+
+    def __new_from_mlir_values__(self, values):
+        """Reconstruct FastDivmodDivisor from MLIR values."""
+        new_obj = object.__new__(FastDivmod)
+        new_obj._divisor = values[0]
+        new_obj.divisor = cutlass.new_from_mlir_values(self.divisor, values[1:])
+        return new_obj
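
fast_math.py shrinks from a hand-rolled fast division (clz, find_log2, umulhi via inline PTX, and a dataclass precomputing multiplier and shift_right) to a thin subclass of cute.FastDivmodDivisor, which now supplies that machinery upstream. The subclass additionally keeps the plain divisor as an attribute and threads it through the __extract_mlir_values__ / __new_from_mlir_values__ protocol so it survives the host-to-device transfer, presumably letting callers recover the remainder the way the deleted divmod did. A hedged usage sketch (the divisor value is illustrative, and the base class's division API is not shown in this diff):

    # Construct on the host; precomputation happens in the base class.
    dm = FastDivmod(128)
    # dm.divisor remains available in device code, e.g. to compute
    # remainder = dividend - quotient * dm.divisor, as the old divmod did.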