quack-kernels 0.2.0.tar.gz → 0.2.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. {quack_kernels-0.2.0/quack_kernels.egg-info → quack_kernels-0.2.1}/PKG-INFO +2 -2
  2. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/pyproject.toml +1 -1
  3. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/__init__.py +1 -1
  4. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/activation.py +16 -25
  5. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/cross_entropy.py +6 -10
  6. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/layernorm.py +1 -1
  7. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/reduce.py +6 -7
  8. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/rmsnorm.py +57 -23
  9. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/softmax.py +1 -1
  10. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/tile_scheduler.py +3 -2
  11. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/utils.py +0 -63
  12. {quack_kernels-0.2.0 → quack_kernels-0.2.1/quack_kernels.egg-info}/PKG-INFO +2 -2
  13. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/LICENSE +0 -0
  14. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/README.md +0 -0
  15. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/autotuner.py +0 -0
  16. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/cute_dsl_utils.py +0 -0
  17. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/dense_gemm_sm100.py +0 -0
  18. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/dense_gemm_sm90.py +0 -0
  19. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/fast_math.py +0 -0
  20. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/gemm_act_sm90.py +0 -0
  21. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/gemm_config.py +0 -0
  22. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/gemm_dact_sm90.py +0 -0
  23. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/gemm_interface.py +0 -0
  24. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/gemm_wrapper_utils.py +0 -0
  25. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/linear.py +0 -0
  26. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/linear_cross_entropy.py +0 -0
  27. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/mlp.py +0 -0
  28. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/pipeline.py +0 -0
  29. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/reduction_base.py +0 -0
  30. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/sort/bitonic_sort.py +0 -0
  31. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/sort/generate_sorting_networks.py +0 -0
  32. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/sort/sorting_networks.py +0 -0
  33. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/sort/utils.py +0 -0
  34. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/symmetric_dense_gemm_sm90.py +0 -0
  35. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/tensormap_manager.py +0 -0
  36. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/topk.py +0 -0
  37. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack/varlen_utils.py +0 -0
  38. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack_kernels.egg-info/SOURCES.txt +0 -0
  39. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack_kernels.egg-info/dependency_links.txt +0 -0
  40. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack_kernels.egg-info/requires.txt +0 -0
  41. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/quack_kernels.egg-info/top_level.txt +0 -0
  42. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/setup.cfg +0 -0
  43. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/tests/test_cross_entropy.py +0 -0
  44. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/tests/test_layernorm.py +0 -0
  45. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/tests/test_linear.py +0 -0
  46. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/tests/test_linear_cross_entropy.py +0 -0
  47. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/tests/test_rmsnorm.py +0 -0
  48. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/tests/test_softmax.py +0 -0
  49. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/tests/test_symmetric_dense_gemm_sm90.py +0 -0
  50. {quack_kernels-0.2.0 → quack_kernels-0.2.1}/tests/test_topk.py +0 -0
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: quack-kernels
3
- Version: 0.2.0
4
- Requires-Python: >=3.12
3
+ Version: 0.2.1
4
+ Requires-Python: >=3.10
5
5
  License-File: LICENSE
6
6
  Requires-Dist: nvidia-cutlass-dsl==4.2.0
7
7
  Requires-Dist: torch
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
5
5
  [project]
6
6
  name = "quack-kernels"
7
7
  dynamic = ["version"]
8
- requires-python = ">=3.12"
8
+ requires-python = ">=3.10"
9
9
  dependencies = [
10
10
  "nvidia-cutlass-dsl==4.2.0",
11
11
  "torch",
@@ -1,4 +1,4 @@
1
- __version__ = "0.2.0"
1
+ __version__ = "0.2.1"
2
2
 
3
3
  import cutlass.cute as cute
4
4
 
@@ -6,23 +6,12 @@ from typing import Tuple
6
6
  import cutlass
7
7
  import cutlass.cute as cute
8
8
  from cutlass import Float32
9
- from cutlass.cutlass_dsl import T, dsl_user_op
10
- from cutlass._mlir.dialects import llvm
9
+ from cutlass.cutlass_dsl import dsl_user_op
11
10
 
12
11
 
13
12
  @dsl_user_op
14
- def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
15
- return Float32(
16
- llvm.inline_asm(
17
- T.f32(),
18
- [Float32(a).ir_value(loc=loc, ip=ip)],
19
- "tanh.approx.f32 $0, $1;",
20
- "=f,f",
21
- has_side_effects=False,
22
- is_align_stack=False,
23
- asm_dialect=llvm.AsmDialect.AD_ATT,
24
- )
25
- )
13
+ def sigmoid(x: Float32, *, loc=None, ip=None) -> Float32:
14
+ return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
26
15
 
27
16
 
28
17
  @dsl_user_op
@@ -67,7 +56,10 @@ def gelu_tanh_approx(x: Float32, *, loc=None, ip=None) -> Float32:
67
56
  """
68
57
  sqrt_2_over_pi = math.sqrt(2 / math.pi) # ~0.797885
69
58
  sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi # ~0.0356774
70
- return 0.5 * (x * (1 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)))))
59
+ return 0.5 * (
60
+ x
61
+ * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True))
62
+ )
71
63
 
72
64
 
73
65
  @dsl_user_op
@@ -88,7 +80,7 @@ def dgelu_tanh_approx(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[
88
80
 
89
81
  # Compute z = x * (c1 + c2 * x^2)
90
82
  x_sq = x * x
91
- tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
83
+ tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True)
92
84
  half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
93
85
  gelu_out = x * half_tanh_z_plus_one
94
86
 
@@ -111,7 +103,7 @@ def silu(x: Float32, *, loc=None, ip=None) -> Float32:
111
103
  This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.
112
104
  """
113
105
  x_half = 0.5 * x
114
- return x_half * tanh(x_half) + x_half
106
+ return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
115
107
 
116
108
 
117
109
  @dsl_user_op
@@ -134,8 +126,8 @@ def dswiglu(
134
126
  to use FFMA instead of FADD and FMUL).
135
127
  """
136
128
  # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
137
- x_half = 0.5 * x # FMUL
138
- sigmoid_x = 0.5 + 0.5 * tanh(x_half) # MUFU.TANH, then FFMA
129
+ # FMUL, MUFU.TANH, then FFMA
130
+ sigmoid_x = sigmoid(x)
139
131
  silu_x = x * sigmoid_x # FMUL
140
132
  silu_x_dout = silu_x * dout # FMUL
141
133
  # d_silu(x) * dout
@@ -161,7 +153,7 @@ def swiglu_oai(x: Float32, y: Float32, alpha: float = 1.702, *, loc=None, ip=Non
161
153
  """
162
154
  # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
163
155
  x_half = 0.5 * x
164
- silu_x = x_half * tanh(alpha * x_half) + x_half
156
+ silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half
165
157
  return silu_x * y + silu_x
166
158
 
167
159
 
@@ -179,7 +171,8 @@ def dswiglu_oai(
179
171
  """
180
172
  # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
181
173
  alpha_x_half = (0.5 * alpha) * x # FMUL
182
- sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half) # MUFU.TANH, then FFMA
174
+ # MUFU.TANH, then FFMA
175
+ sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True)
183
176
  silu_x = x * sigmoid_alpha_x # FMUL
184
177
  silu_x_dout = silu_x * dout # FMUL
185
178
  # FFMA, FFMA, FMUL
@@ -197,8 +190,7 @@ def glu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
197
190
  glu(x, y) = sigmoid(x) * y
198
191
  Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))
199
192
  """
200
- x_half = 0.5 * x # FMUL
201
- sigmoid_x = 0.5 + 0.5 * tanh(x_half) # MUFU.TANH, then FFMA
193
+ sigmoid_x = sigmoid(x) # FMUL, MUFU.TANH, then FFMA
202
194
  return sigmoid_x * y # FMUL
203
195
 
204
196
 
@@ -215,8 +207,7 @@ def dglu(
215
207
  - glu_out = sigmoid(x) * y
216
208
  """
217
209
  # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
218
- x_half = 0.5 * x # FMUL
219
- sigmoid_x = 0.5 + 0.5 * tanh(x_half) # MUFU.TANH, then FFMA
210
+ sigmoid_x = sigmoid(x) # FMUL, MUFU.TANH, then FFMA
220
211
  sigmoid_x_dout = sigmoid_x * dout # FMUL
221
212
  glu_out = sigmoid_x * y # FMUL
222
213
  # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
@@ -199,11 +199,8 @@ class CrossEntropy(ReductionBase):
199
199
  cute.autovec_copy(tXsX, tXrX)
200
200
  x = tXrX.load().to(Float32)
201
201
  log2_e = math.log2(math.e)
202
- # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
203
- # a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
204
- # exp_x = utils.exp2f((x - max_x) * log2_e)
205
202
  # This would use ffma instead of fadd then fmul
206
- exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
203
+ exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=False)
207
204
  denom = row_reduce(
208
205
  exp_x,
209
206
  cute.ReductionOp.ADD,
@@ -228,8 +225,7 @@ class CrossEntropy(ReductionBase):
228
225
  and row < shape[0]
229
226
  and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
230
227
  ):
231
- ln_2 = math.log(2.0)
232
- lse = max_x + utils.log2f(denom) * ln_2
228
+ lse = max_x + cute.math.log(denom, fastmath=True)
233
229
  # Set loss to 0 if this index should be ignored, otherwise compute normally
234
230
  loss_val = (lse - target_logit) if not should_ignore else Float32.zero
235
231
  mLoss[row] = mLoss.element_type(loss_val)
@@ -552,7 +548,7 @@ class CrossEntropyBackward:
552
548
  lse = Float32(mLSE[row])
553
549
 
554
550
  log2_e = math.log2(math.e)
555
- probs = utils.exp2f(x * log2_e - lse * log2_e)
551
+ probs = cute.math.exp2(x * log2_e - (lse * log2_e), fastmath=True)
556
552
  prob_shifted = probs - 1.0
557
553
  mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
558
554
  for i in cutlass.range(cute.size(tXcFull), unroll_full=True):
@@ -594,9 +590,9 @@ def _cross_entropy_backward(
594
590
  assert x.shape[0] == target.shape[0], "Batch dimensions must match"
595
591
  assert x.shape[0] == dloss.shape[0], "Batch dimensions must match"
596
592
  assert x.shape[0] == lse.shape[0], "Batch dimensions must match"
597
- assert (
598
- x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
599
- ), "Tensors must be on CUDA device"
593
+ assert x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda, (
594
+ "Tensors must be on CUDA device"
595
+ )
600
596
  assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
601
597
  assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
602
598
 
@@ -217,7 +217,7 @@ class LayerNorm(ReductionBase):
217
217
  mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
218
218
  init_val=0.0,
219
219
  )
220
- rstd = utils.rsqrt(sum_sq_x_sub_mean / shape[1] + eps)
220
+ rstd = cute.math.rsqrt(sum_sq_x_sub_mean / shape[1] + eps, fastmath=True)
221
221
  if cutlass.const_expr(mRstd is not None):
222
222
  # Only the thread corresponding to column 0 writes out the rstd to gmem
223
223
  if (
@@ -159,8 +159,7 @@ def online_softmax_reduce(
159
159
  width=min(threads_per_row, cute.arch.WARP_SIZE),
160
160
  )
161
161
  log2_e = math.log2(math.e)
162
- exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
163
- # exp_x = exp2f((x - max_x) * log2_e)
162
+ exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
164
163
  sum_exp_x = warp_reduce(
165
164
  exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
166
165
  operator.add,
@@ -190,10 +189,10 @@ def online_softmax_reduce(
190
189
  reduction_buffer[row_idx, lane_idx]
191
190
  )
192
191
  max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
193
- sum_exp_x *= utils.exp2f((max_x_single_warp - max_x_final) * log2_e)
192
+ sum_exp_x *= cute.math.exp(max_x_single_warp - max_x_final, fastmath=True)
194
193
  sum_exp_x = warp_reduce(sum_exp_x, operator.add)
195
194
  if cutlass.const_expr(return_exp_x):
196
- exp_x *= utils.exp2f((max_x - max_x_final) * log2_e)
195
+ exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
197
196
  max_x = max_x_final
198
197
  else:
199
198
  cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
@@ -231,11 +230,11 @@ def online_softmax_reduce(
231
230
  max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
232
231
  sum_exp_x = 0.0
233
232
  for i in cutlass.range_constexpr(num_iter):
234
- sum_exp_x += sum_exp_x_single_warp[i] * utils.exp2f(
235
- (max_x_single_warp[i] - max_x_final) * log2_e
233
+ sum_exp_x += sum_exp_x_single_warp[i] * cute.math.exp(
234
+ max_x_single_warp[i] - max_x_final, fastmath=True
236
235
  )
237
236
  sum_exp_x = warp_reduce(sum_exp_x, operator.add)
238
237
  if cutlass.const_expr(return_exp_x):
239
- exp_x *= utils.exp2f((max_x - max_x_final) * log2_e)
238
+ exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
240
239
  max_x = max_x_final
241
240
  return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
@@ -19,6 +19,7 @@ from quack.reduce import row_reduce
19
19
  from quack.reduction_base import ReductionBase
20
20
  from quack.cute_dsl_utils import torch2cute_dtype_map
21
21
 
22
+
22
23
  class RMSNorm(ReductionBase):
23
24
  def __init__(self, dtype: cutlass.Numeric, N: int):
24
25
  super().__init__(dtype, N, stage=1)
@@ -132,7 +133,9 @@ class RMSNorm(ReductionBase):
132
133
  mW_expanded_layout = cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
133
134
  mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
134
135
  if const_expr(mB is not None):
135
- mB_expanded_layout = cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
136
+ mB_expanded_layout = cute.prepend(
137
+ mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
138
+ )
136
139
  mB = cute.make_tensor(mB.iterator, mB_expanded_layout)
137
140
  if const_expr(mRstd is not None):
138
141
  mRstd_expanded_layout = cute.append(
@@ -202,11 +205,7 @@ class RMSNorm(ReductionBase):
202
205
  ]
203
206
  cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
204
207
  gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
205
- gB = (
206
- cute.local_tile(mB, tiler_mn, (0, cluster_y))
207
- if const_expr(mB is not None)
208
- else None
209
- )
208
+ gB = cute.local_tile(mB, tiler_mn, (0, cluster_y)) if const_expr(mB is not None) else None
210
209
  gRstd = (
211
210
  cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
212
211
  if const_expr(mRstd is not None)
@@ -226,12 +225,18 @@ class RMSNorm(ReductionBase):
226
225
  copy_atom_load_W = cute.make_copy_atom(
227
226
  cute.nvgpu.CopyUniversalOp(), mW.element_type, num_bits_per_copy=num_copy_bits_W
228
227
  )
229
- num_bits_per_copy_B = cutlass.const_expr(
230
- min(128, num_copy_elems_X * mB.element_type.width)
231
- ) if const_expr(mB is not None) else 0
232
- copy_atom_load_B = cute.make_copy_atom(
233
- cute.nvgpu.CopyUniversalOp(), mB.element_type, num_bits_per_copy=num_bits_per_copy_B
234
- ) if const_expr(mB is not None) else None
228
+ num_bits_per_copy_B = (
229
+ cutlass.const_expr(min(128, num_copy_elems_X * mB.element_type.width))
230
+ if const_expr(mB is not None)
231
+ else 0
232
+ )
233
+ copy_atom_load_B = (
234
+ cute.make_copy_atom(
235
+ cute.nvgpu.CopyUniversalOp(), mB.element_type, num_bits_per_copy=num_bits_per_copy_B
236
+ )
237
+ if const_expr(mB is not None)
238
+ else None
239
+ )
235
240
  if const_expr(mRes is not None):
236
241
  num_copy_bits_Res = const_expr(min(128, num_copy_elems_X * mRes.element_type.width))
237
242
  copy_atom_load_Res_async = cute.make_copy_atom(
@@ -317,7 +322,7 @@ class RMSNorm(ReductionBase):
317
322
  init_val=0.0,
318
323
  hook_fn=(cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None),
319
324
  )
320
- rstd = utils.rsqrt(sum_sq_x / shape[1] + eps)
325
+ rstd = cute.math.rsqrt(sum_sq_x / shape[1] + eps, fastmath=True)
321
326
  if const_expr(mRstd is not None):
322
327
  # Only the thread corresponding to column 0 writes out the rstd to gmem
323
328
  if (
@@ -355,7 +360,7 @@ class RMSNorm(ReductionBase):
355
360
  mutates_args=("out", "rstd", "residual_out"),
356
361
  device_types="cuda",
357
362
  # We need to specify the schema manually since we're mutating an optional tensor
358
- schema="(Tensor x, Tensor weight, Tensor(a!) out, Tensor? bias, Tensor(a!)? rstd, Tensor? residual, Tensor(a!)? residual_out, float eps=1e-6) -> ()",
363
+ schema="(Tensor x, Tensor weight, Tensor(a2!) out, Tensor? bias, Tensor(a4!)? rstd, Tensor? residual, Tensor(a6!)? residual_out, float eps=1e-6) -> ()",
359
364
  )
360
365
  def _rmsnorm_fwd(
361
366
  x: Tensor,
@@ -509,6 +514,7 @@ def rmsnorm_ref(x, w, bias=None, residual=None, eps=1e-6):
509
514
  else:
510
515
  return out.to(x.dtype), x_f32.to(residual.dtype)
511
516
 
517
+
512
518
  def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
513
519
  """Reference implementation for RMSNorm backward pass."""
514
520
  x_f32 = x.float()
@@ -521,6 +527,7 @@ def rmsnorm_bwd_ref(x, w, dout, rstd, eps=1e-6):
521
527
  dw = (dout * x_hat).sum(dim=0)
522
528
  return dx.to(x.dtype), dw.to(w.dtype)
523
529
 
530
+
524
531
  class RMSNormBackward(ReductionBase):
525
532
  def __init__(self, dtype: cutlass.Numeric, N: int):
526
533
  # 2 stages for double buffering when computing mean of x_hat * wdy
@@ -744,7 +751,11 @@ class RMSNormBackward(ReductionBase):
744
751
  # Always compute partial weight gradients in fp32
745
752
  tXrdW = cute.make_fragment_like(tXgdW, Float32)
746
753
 
747
- gdB = cute.local_tile(mdB, (1, tiler_mn[1]), (bidx_start, cluster_y)) if const_expr(mdB is not None) else None
754
+ gdB = (
755
+ cute.local_tile(mdB, (1, tiler_mn[1]), (bidx_start, cluster_y))
756
+ if const_expr(mdB is not None)
757
+ else None
758
+ )
748
759
  tXgdB = thr_copy_X.partition_S(gdB) if const_expr(mdB is not None) else None
749
760
  tXrdB = cute.make_fragment_like(tXgdB, Float32) if const_expr(mdB is not None) else None
750
761
 
@@ -772,8 +783,10 @@ class RMSNormBackward(ReductionBase):
772
783
  tXrX, tXrdO, tXrdX = [
773
784
  cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdO, tXgdX)
774
785
  ]
786
+ tXrdResO = None
775
787
  if const_expr(mdResO is not None):
776
788
  tXrdResO = cute.make_fragment_like(tXgdResO[None, None, None, 0])
789
+ tXrdRes = None
777
790
  if const_expr(mdRes is not None):
778
791
  tXrdRes = cute.make_fragment_like(tXgdRes[None, None, None, 0])
779
792
 
@@ -930,7 +943,9 @@ class RMSNormBackward(ReductionBase):
930
943
  if row == 0:
931
944
  for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
932
945
  tXrdB_other = cute.make_fragment_like(tXrdB)
933
- tXsdB_other = cute.make_tensor(tXsdB.iterator + i * sdB.stride[0], tXsdB.layout)
946
+ tXsdB_other = cute.make_tensor(
947
+ tXsdB.iterator + i * sdB.stride[0], tXsdB.layout
948
+ )
934
949
  cute.autovec_copy(tXsdB_other, tXrdB_other)
935
950
  tXrdB.store(tXrdB.load() + tXrdB_other.load())
936
951
  cute.copy(copy_atom_store_dB, tXrdB, tXgdB, pred=tXpdB)
@@ -963,7 +978,7 @@ def _get_sm_count(N: int, device: torch.device) -> int:
963
978
  mutates_args={"dx", "dw_partial", "db_partial", "dresidual"},
964
979
  device_types="cuda",
965
980
  # We need to specify the schema manually since we're mutating an optional tensor
966
- schema="(Tensor x, Tensor weight, Tensor dout, Tensor rstd, Tensor(a!) dx, Tensor(a!) dw_partial, Tensor(a!)? db_partial, Tensor? dresidual_out, Tensor(a!)? dresidual) -> ()",
981
+ schema="(Tensor x, Tensor weight, Tensor dout, Tensor rstd, Tensor(a4!) dx, Tensor(a5!) dw_partial, Tensor(a6!)? db_partial, Tensor? dresidual_out, Tensor(a8!)? dresidual) -> ()",
967
982
  )
968
983
  def _rmsnorm_bwd(
969
984
  x: Tensor,
@@ -1031,14 +1046,23 @@ def _rmsnorm_bwd(
1031
1046
  )
1032
1047
 
1033
1048
  dw_partial_tensor = from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
1034
- db_partial_tensor = from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0) if db_partial is not None else None
1049
+ db_partial_tensor = (
1050
+ from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
1051
+ if db_partial is not None
1052
+ else None
1053
+ )
1035
1054
  rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
1036
1055
 
1037
1056
  current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
1038
1057
 
1039
- compile_key = (N, x_tensor.element_type, weight_tensor.element_type, db_partial.dtype if db_partial is not None else None,
1058
+ compile_key = (
1059
+ N,
1060
+ x_tensor.element_type,
1061
+ weight_tensor.element_type,
1062
+ db_partial.dtype if db_partial is not None else None,
1040
1063
  dresidual.dtype if dresidual is not None else None,
1041
- dresidual_out.dtype if dresidual_out is not None else None)
1064
+ dresidual_out.dtype if dresidual_out is not None else None,
1065
+ )
1042
1066
  if compile_key not in _rmsnorm_bwd.compile_cache:
1043
1067
  rmsnorm_backward_op = RMSNormBackward(x_tensor.element_type, N)
1044
1068
  _rmsnorm_bwd.compile_cache[compile_key] = cute.compile(
@@ -1106,7 +1130,17 @@ def rmsnorm_bwd(
1106
1130
 
1107
1131
  class RMSNormFunction(torch.autograd.Function):
1108
1132
  @staticmethod
1109
- def forward(ctx, x, weight, bias=None, residual=None, out_dtype=None, residual_dtype=None, eps=1e-6, prenorm=False):
1133
+ def forward(
1134
+ ctx,
1135
+ x,
1136
+ weight,
1137
+ bias=None,
1138
+ residual=None,
1139
+ out_dtype=None,
1140
+ residual_dtype=None,
1141
+ eps=1e-6,
1142
+ prenorm=False,
1143
+ ):
1110
1144
  x_shape_og = x.shape
1111
1145
  # Flatten input
1112
1146
  x = x.reshape(-1, x.shape[-1])
@@ -1129,7 +1163,7 @@ class RMSNormFunction(torch.autograd.Function):
1129
1163
  ctx.x_shape_og = x_shape_og
1130
1164
  ctx.residual_dtype = residual.dtype if residual is not None else None
1131
1165
  ctx.prenorm = prenorm
1132
- if residual_out is None or prenorm == False:
1166
+ if residual_out is None or not prenorm:
1133
1167
  return out.reshape(x_shape_og)
1134
1168
  else:
1135
1169
  return out.reshape(x_shape_og), residual_out.reshape(x_shape_og)
@@ -1213,4 +1247,4 @@ class QuackRMSNorm(torch.nn.Module):
1213
1247
 
1214
1248
  def reset_parameters(self):
1215
1249
  """Reset the weight parameter to ones."""
1216
- torch.nn.init.ones_(self.weight)
1250
+ torch.nn.init.ones_(self.weight)
@@ -159,7 +159,7 @@ class Softmax(ReductionBase):
159
159
  hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
160
160
  )
161
161
  log2_e = math.log2(math.e)
162
- exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
162
+ exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
163
163
  denom = row_reduce(
164
164
  exp_x,
165
165
  cute.ReductionOp.ADD,
@@ -390,7 +390,7 @@ def triangular_idx_to_coord(idx: Int32) -> Tuple[Int32, Int32]:
390
390
  Convert a triangular index to 2D coordinates.
391
391
  This is used to convert the linear index to 2D coordinates for triangular matrices.
392
392
  """
393
- row = utils.ceil((utils.sqrt(2 * idx + 2.25) - 0.5)) - 1
393
+ row = utils.ceil((cute.math.sqrt(2 * idx + 2.25, fastmath=True) - 0.5)) - 1
394
394
  col = idx - (row * (row + 1)) // 2
395
395
  return row, col
396
396
 
@@ -524,7 +524,8 @@ class TriangularTileScheduler(TileScheduler):
524
524
  group_size = params.group_size_divmod.divisor
525
525
  group_id = (
526
526
  utils.ceil(
527
- (utils.sqrt(2 * cluster_id_in_problem + 2.25) - 0.5) * params.group_size_inv_f32
527
+ (cute.math.sqrt(2 * cluster_id_in_problem + 2.25, fastmath=True) - 0.5)
528
+ * params.group_size_inv_f32
528
529
  )
529
530
  - 1
530
531
  )
@@ -100,69 +100,6 @@ def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=Non
100
100
  )
101
101
 
102
102
 
103
- @cute.jit
104
- def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
105
- """exp2f calculation for both vector and scalar.
106
- :param x: input value
107
- :type x: cute.TensorSSA or Float32
108
- :return: exp2 value
109
- :rtype: cute.TensorSSA or Float32
110
- """
111
- if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
112
- res = cute.make_fragment(x.shape, Float32)
113
- res.store(x)
114
- for i in cutlass.range(cute.size(x.shape), unroll_full=True):
115
- res[i] = cute.arch.exp2(res[i])
116
- return res.load()
117
- else:
118
- return cute.arch.exp2(x)
119
-
120
-
121
- @dsl_user_op
122
- def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
123
- return Float32(
124
- llvm.inline_asm(
125
- T.f32(),
126
- [Float32(a).ir_value(loc=loc, ip=ip)],
127
- "lg2.approx.ftz.f32 $0, $1;",
128
- "=f,f",
129
- has_side_effects=False,
130
- is_align_stack=False,
131
- asm_dialect=llvm.AsmDialect.AD_ATT,
132
- )
133
- )
134
-
135
-
136
- @dsl_user_op
137
- def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
138
- return Float32(
139
- llvm.inline_asm(
140
- T.f32(),
141
- [Float32(a).ir_value(loc=loc, ip=ip)],
142
- "sqrt.approx.ftz.f32 $0, $1;",
143
- "=f,f",
144
- has_side_effects=False,
145
- is_align_stack=False,
146
- asm_dialect=llvm.AsmDialect.AD_ATT,
147
- )
148
- )
149
-
150
-
151
- @dsl_user_op
152
- def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
153
- return Float32(
154
- llvm.inline_asm(
155
- T.f32(),
156
- [Float32(a).ir_value(loc=loc, ip=ip)],
157
- "rsqrt.approx.ftz.f32 $0, $1;",
158
- "=f,f",
159
- has_side_effects=False,
160
- is_align_stack=False,
161
- asm_dialect=llvm.AsmDialect.AD_ATT,
162
- )
163
- )
164
-
165
-
166
103
  @dsl_user_op
167
104
  def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32:
168
105
  return Int32(
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: quack-kernels
3
- Version: 0.2.0
4
- Requires-Python: >=3.12
3
+ Version: 0.2.1
4
+ Requires-Python: >=3.10
5
5
  License-File: LICENSE
6
6
  Requires-Dist: nvidia-cutlass-dsl==4.2.0
7
7
  Requires-Dist: torch
File without changes
File without changes
File without changes