quack-kernels 0.2.4__tar.gz → 0.2.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. {quack_kernels-0.2.4/quack_kernels.egg-info → quack_kernels-0.2.6}/PKG-INFO +2 -2
  2. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/pyproject.toml +3 -2
  3. quack_kernels-0.2.6/quack/__init__.py +21 -0
  4. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/activation.py +72 -64
  5. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/broadcast_utils.py +1 -1
  6. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/copy_utils.py +143 -20
  7. quack_kernels-0.2.6/quack/cute_dsl_ptxas.py +151 -0
  8. quack_kernels-0.2.6/quack/fast_math.py +33 -0
  9. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/gemm_act.py +296 -8
  10. quack_kernels-0.2.6/quack/gemm_dact.py +731 -0
  11. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/gemm_default_epi.py +4 -4
  12. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/gemm_interface.py +363 -0
  13. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/gemm_sm100.py +62 -88
  14. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/gemm_sm90.py +68 -114
  15. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/gemm_symmetric.py +2 -6
  16. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/layout_utils.py +10 -4
  17. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/linear.py +37 -0
  18. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/pipeline.py +87 -99
  19. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/reduce.py +2 -2
  20. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/rmsnorm.py +1 -3
  21. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/sm90_utils.py +34 -2
  22. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/sort/bitonic_sort.py +4 -4
  23. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/tile_scheduler.py +310 -256
  24. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/topk.py +4 -4
  25. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/utils.py +76 -40
  26. {quack_kernels-0.2.4 → quack_kernels-0.2.6/quack_kernels.egg-info}/PKG-INFO +2 -2
  27. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack_kernels.egg-info/SOURCES.txt +1 -0
  28. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack_kernels.egg-info/requires.txt +1 -1
  29. quack_kernels-0.2.6/quack_kernels.egg-info/top_level.txt +1 -0
  30. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_linear.py +93 -0
  31. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_linear_varlen_m.py +163 -0
  32. quack_kernels-0.2.4/quack/__init__.py +0 -11
  33. quack_kernels-0.2.4/quack/fast_math.py +0 -80
  34. quack_kernels-0.2.4/quack/gemm_dact.py +0 -215
  35. quack_kernels-0.2.4/quack_kernels.egg-info/top_level.txt +0 -5
  36. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/LICENSE +0 -0
  37. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/README.md +0 -0
  38. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/autotuner.py +0 -0
  39. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/compile_utils.py +0 -0
  40. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/cross_entropy.py +0 -0
  41. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/cute_dsl_utils.py +0 -0
  42. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/gemm.py +0 -0
  43. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/gemm_config.py +0 -0
  44. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/gemm_wrapper_utils.py +0 -0
  45. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/linear_cross_entropy.py +0 -0
  46. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/mlp.py +0 -0
  47. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/reduction_base.py +0 -0
  48. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/sm100_utils.py +0 -0
  49. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/softmax.py +0 -0
  50. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/sort/generate_sorting_networks.py +0 -0
  51. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/sort/sorting_networks.py +0 -0
  52. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/sort/utils.py +0 -0
  53. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/tensormap_manager.py +0 -0
  54. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/varlen_utils.py +0 -0
  55. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack_kernels.egg-info/dependency_links.txt +0 -0
  56. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/setup.cfg +0 -0
  57. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_cross_entropy.py +0 -0
  58. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_layernorm.py +0 -0
  59. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_linear_cross_entropy.py +0 -0
  60. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_linear_varlen_k.py +0 -0
  61. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_rmsnorm.py +0 -0
  62. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_softmax.py +0 -0
  63. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_symmetric_gemm.py +0 -0
  64. {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_topk.py +0 -0
@@ -1,9 +1,9 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.2.4
+ Version: 0.2.6
  Requires-Python: >=3.10
  License-File: LICENSE
- Requires-Dist: nvidia-cutlass-dsl<4.4.0,>=4.3.4
+ Requires-Dist: nvidia-cutlass-dsl>=4.4.0.dev1
  Requires-Dist: torch
  Requires-Dist: apache-tvm-ffi<0.2,>=0.1.6
  Requires-Dist: torch-c-dlpack-ext
@@ -7,7 +7,7 @@ name = "quack-kernels"
  dynamic = ["version"]
  requires-python = ">=3.10"
  dependencies = [
-     "nvidia-cutlass-dsl>=4.3.4,<4.4.0",
+     "nvidia-cutlass-dsl>=4.4.0.dev1",
      "torch",
      "apache-tvm-ffi>=0.1.6,<0.2",
      "torch-c-dlpack-ext",
@@ -20,7 +20,8 @@ dev = [
  ]
  
  [tool.setuptools.packages.find]
- exclude = ["tests", "benchmarks"]
+ where = ["."]
+ include = ["quack*"]
  
  [tool.setuptools.dynamic]
  version = {attr = "quack.__version__"}
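
Note on the loosened dependency pin: a PEP 440 specifier that itself names a dev release admits pre-release versions, so ">=4.4.0.dev1" should let pip resolve the 4.4.0 dev builds for this requirement without a global --pre flag. A quick check with the packaging library (assumed installed):

    from packaging.specifiers import SpecifierSet
    from packaging.version import Version

    # A specifier that mentions a dev release implicitly allows pre-releases.
    spec = SpecifierSet(">=4.4.0.dev1")
    assert Version("4.4.0.dev1") in spec
    assert Version("4.4.0") in spec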
@@ -0,0 +1,21 @@
+ __version__ = "0.2.6"
+ 
+ import os
+ 
+ from quack.rmsnorm import rmsnorm
+ from quack.softmax import softmax
+ from quack.cross_entropy import cross_entropy
+ 
+ 
+ if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None:
+     import quack.cute_dsl_ptxas  # noqa: F401
+ 
+     # Patch to dump ptx and then use system ptxas to compile to cubin
+     quack.cute_dsl_ptxas.patch()
+ 
+ 
+ __all__ = [
+     "rmsnorm",
+     "softmax",
+     "cross_entropy",
+ ]
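
A minimal usage sketch for the new import-time hook; the ptxas path below is illustrative, not from the package:

    import os

    # Must be set before importing quack, since the patch is applied at import time.
    os.environ["CUTE_DSL_PTXAS_PATH"] = "/usr/local/cuda/bin/ptxas"

    import quack  # applies quack.cute_dsl_ptxas.patch() because the variable is set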
@@ -2,18 +2,24 @@
  
  import math
  from typing import Tuple
+ from functools import partial
  
  import cutlass.cute as cute
  from cutlass import Float32, Boolean, const_expr
  from cutlass.cutlass_dsl import T, dsl_user_op
- from cutlass._mlir.dialects import llvm
- 
- import quack.utils as utils
+ from cutlass._mlir.dialects import llvm, nvvm
  
  
  F32_or_F32x2 = Float32 | Tuple[Float32, Float32]
  
  
+ sub_packed_f32x2 = partial(
+     cute.arch.calc_packed_f32x2_op,
+     src_c=None,
+     calc_func=nvvm.sub_packed_f32x2,
+ )
+ 
+ 
  @dsl_user_op
  def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
      return Float32(
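
This release routes all two-lane math through the cute.arch packed-f32x2 helpers, with sub_packed_f32x2 assembled from nvvm above. A minimal sketch of the scalar-vs-packed dispatch pattern the activations below all follow; scale_by_half is a hypothetical example, not part of the diff:

    @dsl_user_op
    def scale_by_half(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
        if const_expr(not isinstance(x, tuple)):
            return 0.5 * x  # scalar path: a single FMUL
        else:
            # packed path: one f32x2 instruction covering both lanes
            return cute.arch.mul_packed_f32x2((0.5, 0.5), x)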
@@ -35,9 +41,9 @@ def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
          # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
          return 0.5 + 0.5 * tanh(0.5 * x)
      else:
-         x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
+         x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
          tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
-         return utils.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))
+         return cute.arch.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))
  
  
  @dsl_user_op
  @dsl_user_op
@@ -75,7 +81,7 @@ def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
75
81
  return cute.arch.fmax(x, Float32(0.0)) * x
76
82
  else:
77
83
  relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
78
- return utils.mul_packed_f32x2(relu_x, x)
84
+ return cute.arch.mul_packed_f32x2(relu_x, x)
79
85
 
80
86
 
81
87
  @dsl_user_op
@@ -98,8 +104,8 @@ def drelu_sq(
          return dx, relu_sq_out
      else:
          relu_x = relu(x)
-         relu_sq_out = utils.mul_packed_f32x2(relu_x, x)
-         dx = utils.mul_packed_f32x2((2.0, 2.0), utils.mul_packed_f32x2(dout, relu_x))
+         relu_sq_out = cute.arch.mul_packed_f32x2(relu_x, x)
+         dx = cute.arch.mul_packed_f32x2((2.0, 2.0), cute.arch.mul_packed_f32x2(dout, relu_x))
          return dx, relu_sq_out
  
  
@@ -119,14 +125,14 @@ def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
              * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
          )
      else:
-         x_sq = utils.mul_packed_f32x2(x, x)
-         x_sq_scaled = utils.fma_packed_f32x2(
+         x_sq = cute.arch.mul_packed_f32x2(x, x)
+         x_sq_scaled = cute.arch.fma_packed_f32x2(
              x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
          )
-         z = utils.mul_packed_f32x2(x, x_sq_scaled)
+         z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
          tanh_z = (tanh(z[0]), tanh(z[1]))
-         x_tanh_z = utils.fma_packed_f32x2(tanh_z, x, x)
-         return utils.mul_packed_f32x2((0.5, 0.5), x_tanh_z)
+         x_tanh_z = cute.arch.fma_packed_f32x2(tanh_z, x, x)
+         return cute.arch.mul_packed_f32x2((0.5, 0.5), x_tanh_z)
  
  
  @dsl_user_op
@@ -167,28 +173,28 @@ def dgelu_tanh_approx(
          return dx, gelu_out
      else:
          # Compute z = x * (c1 + c2 * x^2)
-         x_sq = utils.mul_packed_f32x2(x, x)
-         x_sq_scaled = utils.fma_packed_f32x2(
+         x_sq = cute.arch.mul_packed_f32x2(x, x)
+         x_sq_scaled = cute.arch.fma_packed_f32x2(
              x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
          )
-         z = utils.mul_packed_f32x2(x, x_sq_scaled)
+         z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
          tanh_z = (tanh(z[0]), tanh(z[1]))
-         half_tanh_z_plus_one = utils.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
-         gelu_out = utils.mul_packed_f32x2(x, half_tanh_z_plus_one)
+         half_tanh_z_plus_one = cute.arch.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
+         gelu_out = cute.arch.mul_packed_f32x2(x, half_tanh_z_plus_one)
  
          # Compute gradient
          # sech^2(z) = 1 - tanh^2(z)
-         sech2_z = utils.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
+         sech2_z = cute.arch.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
          # dz/dx = c1 + 3 * c2 * x^2
-         dz_dx = utils.fma_packed_f32x2(
+         dz_dx = cute.arch.fma_packed_f32x2(
              x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
          )
          # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
-         sech2_dz_dx = utils.mul_packed_f32x2(sech2_z, dz_dx)
-         x_sech2_dz_dx = utils.mul_packed_f32x2(x, sech2_dz_dx)
-         dgelu = utils.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)
+         sech2_dz_dx = cute.arch.mul_packed_f32x2(sech2_z, dz_dx)
+         x_sech2_dz_dx = cute.arch.mul_packed_f32x2(x, sech2_dz_dx)
+         dgelu = cute.arch.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)
  
-         dx = utils.mul_packed_f32x2(dout, dgelu)
+         dx = cute.arch.mul_packed_f32x2(dout, dgelu)
          return dx, gelu_out
  
  
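
Restating the gradient from the comments above as one formula, with c1 = sqrt(2/pi) and c2 = 0.044715 * c1:

    \mathrm{gelu}(x) \approx \tfrac{1}{2}\, x\, (1 + \tanh z), \qquad z = x\,(c_1 + c_2 x^2)
    \frac{d}{dx}\,\mathrm{gelu}(x) = \tfrac{1}{2}(1 + \tanh z) + \tfrac{1}{2}\, x\, \mathrm{sech}^2(z)\,(c_1 + 3 c_2 x^2), \qquad \mathrm{sech}^2(z) = 1 - \tanh^2 z

which is exactly half_tanh_z_plus_one plus 0.5 * x * sech2_z * dz_dx in the code.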
@@ -204,15 +210,15 @@ def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
          )
      else:
          log2_e = math.log2(math.e)
-         x_log2e = utils.mul_packed_f32x2(x, (log2_e, log2_e))
+         x_log2e = cute.arch.mul_packed_f32x2(x, (log2_e, log2_e))
          x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True))
-         x_exp_p1 = utils.add_packed_f32x2(x_exp, (1.0, 1.0))
+         x_exp_p1 = cute.arch.add_packed_f32x2(x_exp, (1.0, 1.0))
          log_x_exp_p1 = (
              cute.math.log2(x_exp_p1[0], fastmath=True),
              cute.math.log2(x_exp_p1[1], fastmath=True),
          )
          ln2 = math.log(2.0)
-         softplus_x = utils.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
+         softplus_x = cute.arch.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
          use_linear_0 = Boolean(x[0] > 20.0)
          use_linear_1 = Boolean(x[1] > 20.0)
          return (
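
The base-2 reformulation used above is the identity

    \mathrm{softplus}(x) = \ln(1 + e^{x}) = \ln 2 \cdot \log_2\!\left(1 + 2^{\,x \log_2 e}\right)

with the use_linear branches falling back to softplus(x) ≈ x for x > 20, where e^x dominates the added 1.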
@@ -241,9 +247,9 @@ def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) ->
          # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
          return x_half * tanh(x_half) + x_half
      else:
-         x_half = utils.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
+         x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
          tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
-         return utils.fma_packed_f32x2(x_half, tanh_x_half, x_half)
+         return cute.arch.fma_packed_f32x2(x_half, tanh_x_half, x_half)
  
  
  @dsl_user_op
@@ -251,7 +257,7 @@ def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32
      if const_expr(not isinstance(x, tuple)):
          return silu(x) * y
      else:
-         return utils.mul_packed_f32x2(silu(x), y)
+         return cute.arch.mul_packed_f32x2(silu(x), y)
  
  
  @dsl_user_op
@@ -301,20 +307,22 @@ def dswiglu(
          # Compute sigmoid(x) and silu(x)
          if const_expr(not already_halved):
              sigmoid_x = sigmoid(x)
-             silu_x = utils.mul_packed_f32x2(x, sigmoid_x)
+             silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_x)
          else:
              tanh_x = (tanh(x[0]), tanh(x[1]))
-             sigmoid_x = utils.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
-             silu_x = utils.fma_packed_f32x2(x, tanh_x, x)
-         silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
+             sigmoid_x = cute.arch.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
+             silu_x = cute.arch.fma_packed_f32x2(x, tanh_x, x)
+         silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
          # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
-         sigmoid_x_minus_silu_x_sigmoid_x = utils.fma_packed_f32x2(
+         sigmoid_x_minus_silu_x_sigmoid_x = cute.arch.fma_packed_f32x2(
              sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
          )
-         d_silu_x_dout = utils.fma_packed_f32x2(sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout)
-         dx = utils.mul_packed_f32x2(d_silu_x_dout, y)
+         d_silu_x_dout = cute.arch.fma_packed_f32x2(
+             sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout
+         )
+         dx = cute.arch.mul_packed_f32x2(d_silu_x_dout, y)
          dy = silu_x_dout
-         swiglu_out = utils.mul_packed_f32x2(silu_x, y)
+         swiglu_out = cute.arch.mul_packed_f32x2(silu_x, y)
          return dx, dy, swiglu_out
  
  
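
The backward pass leans on x * sigmoid(x) = silu(x) to avoid recomputing the sigmoid:

    \mathrm{silu}'(x) = \sigma(x) + x\,\sigma(x)\,(1 - \sigma(x)) = \sigma(x) - \mathrm{silu}(x)\,\sigma(x) + \mathrm{silu}(x)

which matches the comment d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout; then dx = silu'(x) * dout * y and dy = silu(x) * dout.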
@@ -334,11 +342,11 @@ def swiglu_oai(
          silu_x = x_half * tanh(alpha * x_half) + x_half
          return silu_x * y + silu_x
      else:
-         x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
-         alpha_x_half = utils.mul_packed_f32x2((alpha, alpha), x_half)
+         x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
+         alpha_x_half = cute.arch.mul_packed_f32x2((alpha, alpha), x_half)
          tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
-         silu_x = utils.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
-         return utils.fma_packed_f32x2(silu_x, y, silu_x)
+         silu_x = cute.arch.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
+         return cute.arch.fma_packed_f32x2(silu_x, y, silu_x)
  
  
  @dsl_user_op
@@ -370,22 +378,22 @@ def dswiglu_oai(
          return dx, dy, swiglu_out
      else:
          # Compute sigmoid(alpha * x)
-         alpha_x_half = utils.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
+         alpha_x_half = cute.arch.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
          tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
-         sigmoid_alpha_x = utils.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
-         silu_x = utils.mul_packed_f32x2(x, sigmoid_alpha_x)
-         silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
+         sigmoid_alpha_x = cute.arch.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
+         silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_alpha_x)
+         silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
          # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
-         silu_x_minus_product = utils.fma_packed_f32x2(
+         silu_x_minus_product = cute.arch.fma_packed_f32x2(
              silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
          )
-         sigmoid_plus_alpha_diff = utils.fma_packed_f32x2(
+         sigmoid_plus_alpha_diff = cute.arch.fma_packed_f32x2(
              (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
          )
-         d_silu_x_dout = utils.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
-         dx = utils.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
+         d_silu_x_dout = cute.arch.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
+         dx = cute.arch.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
          dy = silu_x_dout
-         swiglu_out = utils.fma_packed_f32x2(silu_x, y, silu_x)
+         swiglu_out = cute.arch.fma_packed_f32x2(silu_x, y, silu_x)
          return dx, dy, swiglu_out
  
  
@@ -400,7 +408,7 @@ def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
          return sigmoid_x * y  # FMUL
      else:
          sigmoid_x = sigmoid(x)
-         return utils.mul_packed_f32x2(sigmoid_x, y)
+         return cute.arch.mul_packed_f32x2(sigmoid_x, y)
  
  
  @dsl_user_op
@@ -430,11 +438,11 @@ def dglu(
          return dx, dy, glu_out
      else:
          sigmoid_x = sigmoid(x)
-         sigmoid_x_dout = utils.mul_packed_f32x2(sigmoid_x, dout)
-         glu_out = utils.mul_packed_f32x2(sigmoid_x, y)
+         sigmoid_x_dout = cute.arch.mul_packed_f32x2(sigmoid_x, dout)
+         glu_out = cute.arch.mul_packed_f32x2(sigmoid_x, y)
          # dx = (y - glu_out) * sigmoid_x_dout
-         y_minus_glu_out = utils.sub_packed_f32x2(y, glu_out)
-         dx = utils.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
+         y_minus_glu_out = sub_packed_f32x2(y, glu_out)
+         dx = cute.arch.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
          dy = sigmoid_x_dout
          return dx, dy, glu_out
  
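
The dx rewrite in the comment follows from sigma'(x) = sigma(x)(1 - sigma(x)):

    \frac{\partial}{\partial x}\left[\sigma(x)\,y\right] = \sigma(x)\,(1 - \sigma(x))\,y = \left(y - \sigma(x)\,y\right)\sigma(x) = (y - \mathrm{glu\_out})\,\sigma(x)

so dx reuses glu_out and sigmoid_x_dout, needing only the module-level sub_packed_f32x2 and one more packed multiply.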
@@ -448,7 +456,7 @@ def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x
          return cute.arch.fmax(x, Float32(0.0)) * y
      else:
          relu_x = relu(x)
-         return utils.mul_packed_f32x2(relu_x, y)
+         return cute.arch.mul_packed_f32x2(relu_x, y)
  
  
  @dsl_user_op
@@ -475,10 +483,10 @@ def dreglu(
          x0_pos = Boolean(x[0] > 0)
          x1_pos = Boolean(x[1] > 0)
          relu_x = relu(x)
-         dout_y = utils.mul_packed_f32x2(dout, y)
+         dout_y = cute.arch.mul_packed_f32x2(dout, y)
          dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
-         dy = utils.mul_packed_f32x2(dout, relu_x)
-         reglu_out = utils.mul_packed_f32x2(relu_x, y)
+         dy = cute.arch.mul_packed_f32x2(dout, relu_x)
+         reglu_out = cute.arch.mul_packed_f32x2(relu_x, y)
          return dx, dy, reglu_out
  
@@ -491,7 +499,7 @@ def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x
      if const_expr(not isinstance(x, tuple)):
          return gelu_tanh_approx(x) * y
      else:
-         return utils.mul_packed_f32x2(gelu_tanh_approx(x), y)
+         return cute.arch.mul_packed_f32x2(gelu_tanh_approx(x), y)
  
  
  @dsl_user_op
@@ -518,7 +526,7 @@ def dgeglu(
          # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
          dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
          # Compute gradients for geglu
-         dx = utils.mul_packed_f32x2(dgelu_x_dout, y)
-         dy = utils.mul_packed_f32x2(gelu_x, dout)
-         geglu_out = utils.mul_packed_f32x2(gelu_x, y)
+         dx = cute.arch.mul_packed_f32x2(dgelu_x_dout, y)
+         dy = cute.arch.mul_packed_f32x2(gelu_x, dout)
+         geglu_out = cute.arch.mul_packed_f32x2(gelu_x, y)
          return dx, dy, geglu_out
@@ -11,7 +11,7 @@ from quack.layout_utils import make_acc_tensor_mn_view
  @cute.jit
  def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None:
      if const_expr(tCrC.element_type != Float32):  # Convert to f32
-         tCrC_f32 = cute.make_fragment(tCrC.shape, Float32)
+         tCrC_f32 = cute.make_rmem_tensor(tCrC.shape, Float32)
          tCrC_f32.store(tCrC.load().to(Float32))
      else:
          tCrC_f32 = tCrC
@@ -7,18 +7,19 @@ import cutlass
  import cutlass.cute as cute
  
  from cutlass import Int32, Boolean, const_expr
- from cutlass.cute.nvgpu import cpasync
+ from cutlass.cute.nvgpu import cpasync, warp, warpgroup
  from cutlass.cutlass_dsl import dsl_user_op
  import cutlass.pipeline
  
  
  @dsl_user_op
  def cvt_copy(
-     atom: cute.CopyAtom,
+     tiled_copy: cute.TiledCopy,
      src: cute.Tensor,
      dst: cute.Tensor,
      *,
      pred: Optional[cute.Tensor] = None,
+     retile: bool = False,
      loc=None,
      ip=None,
      **kwargs,
@@ -28,7 +29,9 @@ def cvt_copy(
          src_cvt = cute.make_fragment_like(src, dst.element_type)
          src_cvt.store(src.load().to(dst.element_type))
          src = src_cvt
-     cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
+     if const_expr(retile):
+         src = tiled_copy.retile(src)
+     cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
  
  
  @dsl_user_op
@@ -49,7 +52,7 @@ def load_s2r_retile(
  ) -> cute.Tensor:
      # Will also accept dst_shape being a tensor, in which case we write into that tensor
      if const_expr(not isinstance(dst_shape, cute.Tensor)):
-         dst = cute.make_fragment(dst_shape, src.element_type, loc=loc, ip=ip)
+         dst = cute.make_rmem_tensor(dst_shape, src.element_type, loc=loc, ip=ip)
      else:
          dst = dst_shape
      cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip)
@@ -114,7 +117,7 @@ def tiled_copy_2d(
  @cute.jit
  def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
      # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
-     tApA = cute.make_fragment(
+     tApA = cute.make_rmem_tensor(
          cute.make_layout(
              (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
              stride=(cute.size(tAcA, mode=[2]), 0, 1),
@@ -239,9 +242,7 @@ def sm90_get_smem_load_op(
          raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
      is_m_major = layout_c.is_m_major_c()
      if elem_ty_c.width == 16:
-         return cute.make_copy_atom(
-             cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
-         )
+         return cute.make_copy_atom(warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip)
      else:
          return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
  
@@ -257,11 +258,127 @@ def get_smem_store_atom(
          )
      else:
          return cute.make_copy_atom(
-             cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
+             warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
              element_type,
          )
  
  
+ def get_smem_load_atom(
+     arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
+ ) -> cute.CopyAtom:
+     if const_expr(arch < 90 or element_type.width != 16):
+         return cute.make_copy_atom(
+             cute.nvgpu.CopyUniversalOp(),
+             element_type,
+             num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
+         )
+     else:
+         return cute.make_copy_atom(
+             warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
+             element_type,
+         )
+ 
+ 
+ def get_smem_store_C(
+     tiled_mma: cute.TiledMma,
+     sC: cute.Tensor,
+     tidx: Int32,
+     arch: int,
+     transpose: bool = False,
+     position_independent=False,
+ ) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
+     dtype = sC.element_type
+     copy_atom = get_smem_store_atom(arch, dtype, transpose)
+     tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
+     thr_copy = tiled_copy.get_slice(tidx)
+     if const_expr(not position_independent):
+         tRS_sC = thr_copy.partition_D(sC)
+     else:
+         tRS_sC = partition_D_position_independent(thr_copy, sC)
+ 
+     def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
+         cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], retile=True, **new_kwargs)
+ 
+     return copy_fn, thr_copy, tRS_sC
+ 
+ 
+ def get_smem_load_C(
+     tiled_mma: cute.TiledMma,
+     sC: cute.Tensor,
+     tidx: Int32,
+     arch: int,
+     transpose: bool = False,
+     position_independent=False,
+ ) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
+     dtype = sC.element_type
+     copy_atom = get_smem_load_atom(arch, dtype, transpose)
+     tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
+     thr_copy = tiled_copy.get_slice(tidx)
+     if const_expr(not position_independent):
+         tSR_sC = thr_copy.partition_S(sC)
+     else:
+         tSR_sC = partition_S_position_independent(thr_copy, sC)
+     copy_atom_RS = get_smem_store_atom(arch, dtype, transpose)
+     thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
+     tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape
+ 
+     def copy_fn(src_idx: Int32, **new_kwargs):
+         return load_s2r_retile(
+             tiled_copy, tSR_sC[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs
+         )
+ 
+     return copy_fn, thr_copy, tSR_sC
+ 
+ 
+ def get_smem_store_A(
+     tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False
+ ) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
+     dtype = sA.element_type
+     transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
+     copy_atom = get_smem_store_atom(arch, dtype, transpose)
+     tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma)
+     thr_copy = tiled_copy.get_slice(tidx)
+     if const_expr(not position_independent):
+         tRS_sA = thr_copy.partition_D(sA)
+     else:
+         tRS_sA = partition_D_position_independent(thr_copy, sA)
+ 
+     def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
+         cvt_copy(tiled_copy, src, tRS_sA[None, None, None, dst_idx], retile=True, **new_kwargs)
+ 
+     return copy_fn, thr_copy, tRS_sA
+ 
+ 
+ def get_smem_load_A(
+     tiled_mma: cute.TiledMma,
+     sA: cute.Tensor,
+     tidx: Int32,
+     arch: int,
+     with_dst_tensor: bool = False,
+     position_independent=False,
+ ) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
+     dtype = sA.element_type
+     transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
+     copy_atom = get_smem_load_atom(arch, dtype, transpose)
+     tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma)
+     thr_copy = tiled_copy.get_slice(tidx)
+     if const_expr(not position_independent):
+         tSR_sA = thr_copy.partition_S(sA)
+     else:
+         tSR_sA = partition_S_position_independent(thr_copy, sA)
+     tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2])
+ 
+     def copy_fn(src_idx: Int32, **new_kwargs):
+         return load_s2r_retile(
+             tiled_copy, tSR_sA[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs
+         )
+ 
+     def copy_fn_w_dst_tensor(src_idx: Int32, dst: cute.Tensor, **new_kwargs):
+         return load_s2r_retile(tiled_copy, tSR_sA[None, None, None, src_idx], dst, **new_kwargs)
+ 
+     return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA
+ 
+ 
  def tma_get_copy_fn(
      atom: cute.CopyAtom,
      cta_coord: cute.Coord,
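
A hedged sketch of how the new store/load helper pairs might be used from an epilogue; tiled_mma, sC, tidx, acc_frag, and stage are assumed to come from the surrounding kernel, and arch=90 is illustrative:

    # r2s: convert-and-store a register fragment into smem stage stage
    store_fn, thr_copy_r2s, tRS_sC = get_smem_store_C(tiled_mma, sC, tidx, arch=90)
    store_fn(acc_frag, stage)
    # s2r: read the same stage back into a fresh register tensor
    load_fn, thr_copy_s2r, tSR_sC = get_smem_load_C(tiled_mma, sC, tidx, arch=90)
    frag = load_fn(stage)

get_smem_store_A / get_smem_load_A follow the same shape, but derive transpose from the MMA operand's major mode instead of taking it as an argument.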
@@ -269,6 +386,7 @@ def tma_get_copy_fn(
      src_tensor: cute.Tensor,
      dst_tensor: cute.Tensor,
      filter_zeros: bool = False,
+     single_stage: bool = False,
      **kwargs,
  ) -> Callable:
      src_is_smem = const_expr(
@@ -276,13 +394,15 @@ def tma_get_copy_fn(
          and src_tensor.memspace == cute.AddressSpace.smem
      )
      smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor)
+     group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0))
+     group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0))
      # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
      s, g = cpasync.tma_partition(
          atom,
          cta_coord,
          cta_layout,
-         cute.group_modes(smem_tensor, 0, cute.rank(smem_tensor) - 1),
-         cute.group_modes(gmem_tensor, 0, cute.rank(gmem_tensor) - 1),
+         cute.group_modes(smem_tensor, 0, group_rank_smem),
+         cute.group_modes(gmem_tensor, 0, group_rank_gmem),
      )
      if const_expr(filter_zeros):
          s = cute.filter_zeros(s)
@@ -292,7 +412,10 @@ def tma_get_copy_fn(
      def copy_tma(src_idx, dst_idx, **new_kwargs):
          cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)
  
-     return copy_tma, s, g
+     def copy_tma_single_stage(**new_kwargs):
+         cute.copy(atom, src, dst, **new_kwargs, **kwargs)
+ 
+     return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g
  
  
  def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
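
A hedged sketch of the new single_stage switch; the atom, layouts, and tensors (gA, sA) are assumed from the surrounding code:

    # Multi-stage (default): the smem/gmem tensors keep a trailing stage/tile
    # mode, and copy_fn indexes into it.
    copy_fn, s, g = tma_get_copy_fn(atom, cta_coord, cta_layout, gA, sA)
    copy_fn(src_idx=k_tile, dst_idx=stage)

    # Single-stage: all modes are grouped, so the returned closure copies the
    # whole tensor and takes no indices.
    copy_once, s1, g1 = tma_get_copy_fn(atom, cta_coord, cta_layout, gA, sA, single_stage=True)
    copy_once()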
@@ -337,10 +460,10 @@ def gather_m_get_copy_fn(
      # Read and cache indices for A
      rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
      cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
-     tApA_m = cute.make_fragment(rows_per_thread, Boolean)
+     tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
      for m in cutlass.range(rows_per_thread, unroll_full=True):
          tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
-     m_idx = cute.make_fragment(rows_per_thread, Int32)
+     m_idx = cute.make_rmem_tensor(rows_per_thread, Int32)
      for m in cutlass.range(rows_per_thread, unroll_full=True):
          row_idx = tAcA[0, m, 0][0]
          if tApA_m[m]:
@@ -353,7 +476,7 @@ def gather_m_get_copy_fn(
      def copy_fn(src_idx, dst_idx, pred: bool = False):
          tApA_k = None
          if const_expr(pred):
-             tApA_k = cute.make_fragment(cols_per_thread, Boolean)
+             tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
              limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
              for k in cutlass.range(cols_per_thread, unroll_full=True):
                  tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
@@ -411,7 +534,7 @@ def gather_k_get_copy_fn(
      # Read and cache indices for A
      rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
      cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
-     tApA_m = cute.make_fragment(rows_per_thread, Boolean)
+     tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
      for m in cutlass.range(rows_per_thread, unroll_full=True):
          tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
      threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
@@ -427,12 +550,12 @@ def gather_k_get_copy_fn(
          # Prefetch mAIdx early, even before smem is free
          tApA_k = None
          if const_expr(pred):
-             tApA_k = cute.make_fragment(cols_per_thread, Boolean)
+             tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
              limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
              for k in cutlass.range(cols_per_thread, unroll_full=True):
                  tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
          gAIdx_cur = gAIdx[None, src_idx]
-         k_idx = cute.make_fragment(cols_per_thread, Int32)
+         k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
          for k in cutlass.range(cols_per_thread):
              col_idx = tAcA[0, 0, k][1]
              if const_expr(not pred):
@@ -449,13 +572,13 @@ def gather_k_get_copy_fn(
      ) -> Tuple[cute.Tensor, cute.Tensor]:
          tApA_k = None
          if const_expr(pred):
-             tApA_k = cute.make_fragment(cols_per_thread, Boolean)
+             tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
              limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
              for k in cutlass.range(cols_per_thread, unroll_full=True):
                  tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
          a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
          sAIdx_cur = sAIdx[None, dst_idx]
-         k_idx = cute.make_fragment(cols_per_thread, Int32)
+         k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
          for k in cutlass.range(cols_per_thread):
              col_idx = tAcA[0, 0, k][1]
              k_idx[k] = sAIdx_cur[col_idx]