quack-kernels 0.1.3__tar.gz → 0.1.4__tar.gz

This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (21)
  1. {quack_kernels-0.1.3/quack_kernels.egg-info → quack_kernels-0.1.4}/PKG-INFO +2 -2
  2. {quack_kernels-0.1.3 → quack_kernels-0.1.4}/README.md +5 -1
  3. {quack_kernels-0.1.3 → quack_kernels-0.1.4}/pyproject.toml +1 -1
  4. {quack_kernels-0.1.3 → quack_kernels-0.1.4}/quack/__init__.py +1 -1
  5. {quack_kernels-0.1.3 → quack_kernels-0.1.4}/quack/cross_entropy.py +12 -9
  6. {quack_kernels-0.1.3 → quack_kernels-0.1.4}/quack/reduction_base.py +2 -2
  7. {quack_kernels-0.1.3 → quack_kernels-0.1.4}/quack/rmsnorm.py +13 -12
  8. {quack_kernels-0.1.3 → quack_kernels-0.1.4}/quack/softmax.py +25 -17
  9. {quack_kernels-0.1.3 → quack_kernels-0.1.4}/quack/utils.py +15 -12
  10. {quack_kernels-0.1.3 → quack_kernels-0.1.4/quack_kernels.egg-info}/PKG-INFO +2 -2
  11. quack_kernels-0.1.4/quack_kernels.egg-info/requires.txt +6 -0
  12. quack_kernels-0.1.3/quack_kernels.egg-info/requires.txt +0 -6
  13. {quack_kernels-0.1.3 → quack_kernels-0.1.4}/LICENSE +0 -0
  14. {quack_kernels-0.1.3 → quack_kernels-0.1.4}/quack_kernels.egg-info/SOURCES.txt +0 -0
  15. {quack_kernels-0.1.3 → quack_kernels-0.1.4}/quack_kernels.egg-info/dependency_links.txt +0 -0
  16. {quack_kernels-0.1.3 → quack_kernels-0.1.4}/quack_kernels.egg-info/top_level.txt +0 -0
  17. {quack_kernels-0.1.3 → quack_kernels-0.1.4}/setup.cfg +0 -0
  18. {quack_kernels-0.1.3 → quack_kernels-0.1.4}/setup.py +0 -0
  19. {quack_kernels-0.1.3 → quack_kernels-0.1.4}/tests/test_cross_entropy.py +0 -0
  20. {quack_kernels-0.1.3 → quack_kernels-0.1.4}/tests/test_rmsnorm.py +0 -0
  21. {quack_kernels-0.1.3 → quack_kernels-0.1.4}/tests/test_softmax.py +0 -0
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: quack-kernels
3
- Version: 0.1.3
3
+ Version: 0.1.4
4
4
  Requires-Python: >=3.9
5
5
  License-File: LICENSE
6
- Requires-Dist: nvidia-cutlass-dsl==4.0.0
6
+ Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
7
7
  Requires-Dist: torch
8
8
  Provides-Extra: dev
9
9
  Requires-Dist: pre-commit; extra == "dev"
@@ -20,6 +20,10 @@ pip install quack-kernels
20
20
  - 🦆 Softmax forward and backward
21
21
  - 🦆 Cross entropy forward
22
22
 
23
+ Upcoming:
24
+ - 🦆 Cross entropy backward
25
+ - 🦆 RMSNorm backward
26
+ - 🦆 Rotary forward + backward
23
27
 
24
28
  ## Usage
25
29
 
@@ -32,6 +36,6 @@ from quack import rmsnorm, softmax, cross_entropy
32
36
  To set up the development environment:
33
37
 
34
38
  ```bash
35
- pip install -e .[dev]
39
+ pip install -e '.[dev]'
36
40
  pre-commit install
37
41
  ```
@@ -7,7 +7,7 @@ name = "quack-kernels"
7
7
  dynamic = ["version"]
8
8
  requires-python = ">=3.9"
9
9
  dependencies = [
10
- "nvidia-cutlass-dsl==4.0.0",
10
+ "nvidia-cutlass-dsl==4.1.0.dev0",
11
11
  "torch",
12
12
  ]
13
13
 
@@ -1,4 +1,4 @@
1
- __version__ = "0.1.3"
1
+ __version__ = "0.1.4"
2
2
 
3
3
  from quack.rmsnorm import rmsnorm
4
4
  from quack.softmax import softmax
@@ -77,7 +77,7 @@ class CrossEntropy(ReductionBase):
77
77
  self.kernel(mX, mTarget, mLoss, mLSE, tv_layout, tiler_mn).launch(
78
78
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
79
79
  block=[num_threads, 1, 1],
80
- cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
80
+ cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
81
81
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
82
82
  stream=stream,
83
83
  )
@@ -93,15 +93,16 @@ class CrossEntropy(ReductionBase):
93
93
  tiler_mn: cute.Shape,
94
94
  ):
95
95
  tidx, _, _ = cute.arch.thread_idx()
96
- bidx, cluster_y, _ = cute.arch.block_idx()
96
+ bidx, _, _ = cute.arch.block_idx()
97
+ if cutlass.const_expr(self.cluster_n > 1):
98
+ cluster_y = cute.arch.block_idx()[1]
99
+ else:
100
+ cluster_y = cutlass.const_expr(0)
97
101
 
98
102
  shape: cute.Shape = mX.shape
99
103
  idX = cute.make_identity_tensor(shape)
100
104
  # slice for CTAs
101
- gX, cX = [
102
- cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
103
- for mT in (mX, idX)
104
- ]
105
+ gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, idX)]
105
106
 
106
107
  smem = cutlass.utils.SmemAllocator()
107
108
  sX = smem.allocate_tensor(
@@ -131,7 +132,9 @@ class CrossEntropy(ReductionBase):
131
132
 
132
133
  is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
133
134
  tXpX = (
134
- utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
135
+ utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
136
+ if cutlass.const_expr(not is_even_N)
137
+ else None
135
138
  )
136
139
  if row < shape[0]:
137
140
  cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
@@ -154,7 +157,7 @@ class CrossEntropy(ReductionBase):
154
157
  cute.ReductionOp.MAX,
155
158
  threads_per_row,
156
159
  reduction_buffer[None, None, 0],
157
- mbar_ptr + 0 if self.cluster_n > 1 else None,
160
+ mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
158
161
  init_val=-cutlass.Float32.inf,
159
162
  hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
160
163
  )
@@ -172,7 +175,7 @@ class CrossEntropy(ReductionBase):
172
175
  cute.ReductionOp.ADD,
173
176
  threads_per_row,
174
177
  reduction_buffer[None, None, 1],
175
- mbar_ptr + 1 if self.cluster_n > 1 else None,
178
+ mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
176
179
  init_val=0.0,
177
180
  )
178
181
  else:
@@ -88,10 +88,10 @@ class ReductionBase:
88
88
  def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_warps: int):
89
89
  if cutlass.const_expr(self.cluster_n > 1):
90
90
  if tidx < self.stage:
91
- cute.arch.mbarrier_init_arrive_cnt(mbar_ptr + tidx, 1)
91
+ cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
92
92
  cute.arch.mbarrier_init_fence()
93
93
  if tidx < self.stage:
94
- cute.arch.mbarrier_init_tx_bytes(
94
+ cute.arch.mbarrier_arrive_and_expect_tx(
95
95
  mbar_ptr + tidx, num_warps * self.cluster_n * self.reduction_dtype.width // 8
96
96
  )
97
97
  # Cluster arrive after barrier init
@@ -84,7 +84,7 @@ class RMSNorm(ReductionBase):
84
84
  self.kernel(mX, mW, mO, mRstd, eps, tv_layout, tiler_mn, self.reload_from).launch(
85
85
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
86
86
  block=[num_threads, 1, 1],
87
- cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
87
+ cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
88
88
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
89
89
  stream=stream,
90
90
  )
@@ -103,7 +103,11 @@ class RMSNorm(ReductionBase):
103
103
  delay_w_load: cutlass.Constexpr = False,
104
104
  ):
105
105
  tidx, _, _ = cute.arch.thread_idx()
106
- bidx, cluster_y, _ = cute.arch.block_idx()
106
+ bidx, _, _ = cute.arch.block_idx()
107
+ if cutlass.const_expr(self.cluster_n > 1):
108
+ cluster_y = cute.arch.block_idx()[1]
109
+ else:
110
+ cluster_y = cutlass.const_expr(0)
107
111
 
108
112
  smem = cutlass.utils.SmemAllocator()
109
113
  sX = smem.allocate_tensor(
@@ -114,13 +118,10 @@ class RMSNorm(ReductionBase):
114
118
  shape = mX.shape
115
119
  idX = cute.make_identity_tensor(shape)
116
120
  # slice for CTAs
117
- gX, gO, cX = [
118
- cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
119
- for mT in (mX, mO, idX)
120
- ]
121
- gW = cute.local_tile(mW, tiler_mn, (0, 0 if self.cluster_n == 1 else cluster_y))
121
+ gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]
122
+ gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
122
123
  gRstd = (
123
- cute.local_tile(mRstd, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
124
+ cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
124
125
  if cutlass.const_expr(mRstd is not None)
125
126
  else None
126
127
  )
@@ -167,7 +168,7 @@ class RMSNorm(ReductionBase):
167
168
  cute.arch.cp_async_commit_group()
168
169
 
169
170
  tWpW = utils.predicate_k(thr_copy_W.partition_S(cX), limit=shape[1])
170
- if not delay_w_load:
171
+ if cutlass.const_expr(not delay_w_load):
171
172
  cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
172
173
 
173
174
  cute.arch.cp_async_wait_group(0)
@@ -192,12 +193,12 @@ class RMSNorm(ReductionBase):
192
193
  and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
193
194
  ):
194
195
  tXrRstd[0] = rstd
195
- if delay_w_load:
196
+ if cutlass.const_expr(delay_w_load):
196
197
  cute.copy(copy_atom_load_W, tWgW, tWrW, pred=tWpW)
197
- if reload_from == "smem":
198
+ if cutlass.const_expr(reload_from == "smem"):
198
199
  cute.autovec_copy(tXsX, tXrX)
199
200
  x = tXrX.load().to(cute.Float32)
200
- elif reload_from == "gmem":
201
+ elif cutlass.const_expr(reload_from == "gmem"):
201
202
  cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
202
203
  x = tXrX.load().to(cute.Float32)
203
204
  x_hat = x * rstd
@@ -75,7 +75,7 @@ class Softmax(ReductionBase):
75
75
  self.kernel(mX, mO, tv_layout, tiler_mn).launch(
76
76
  grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
77
77
  block=[num_threads, 1, 1],
78
- cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
78
+ cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
79
79
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
80
80
  stream=stream,
81
81
  )
@@ -89,15 +89,16 @@ class Softmax(ReductionBase):
89
89
  tiler_mn: cute.Shape,
90
90
  ):
91
91
  tidx, _, _ = cute.arch.thread_idx()
92
- bidx, cluster_y, _ = cute.arch.block_idx()
92
+ bidx, _, _ = cute.arch.block_idx()
93
+ if cutlass.const_expr(self.cluster_n > 1):
94
+ cluster_y = cute.arch.block_idx()[1]
95
+ else:
96
+ cluster_y = cutlass.const_expr(0)
93
97
 
94
98
  shape = mX.shape
95
99
  idX = cute.make_identity_tensor(shape)
96
100
  # slice for CTAs
97
- gX, gO, cX = [
98
- cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
99
- for mT in (mX, mO, idX)
100
- ]
101
+ gX, gO, cX = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX)]
101
102
 
102
103
  smem = cutlass.utils.SmemAllocator()
103
104
  sX = smem.allocate_tensor(
@@ -129,7 +130,9 @@ class Softmax(ReductionBase):
129
130
 
130
131
  is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
131
132
  tXpX = (
132
- utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
133
+ utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
134
+ if cutlass.const_expr(not is_even_N)
135
+ else None
133
136
  )
134
137
  if tXcX[0][0] < shape[0]:
135
138
  cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
@@ -148,7 +151,7 @@ class Softmax(ReductionBase):
148
151
  cute.ReductionOp.MAX,
149
152
  threads_per_row,
150
153
  reduction_buffer[None, None, 0],
151
- mbar_ptr + 0 if self.cluster_n > 1 else None,
154
+ mbar_ptr + 0 if cutlass.const_expr(self.cluster_n > 1) else None,
152
155
  init_val=-cutlass.Float32.inf,
153
156
  hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
154
157
  )
@@ -159,7 +162,7 @@ class Softmax(ReductionBase):
159
162
  cute.ReductionOp.ADD,
160
163
  threads_per_row,
161
164
  reduction_buffer[None, None, 1],
162
- mbar_ptr + 1 if self.cluster_n > 1 else None,
165
+ mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
163
166
  init_val=0.0,
164
167
  )
165
168
  else:
@@ -174,7 +177,9 @@ class Softmax(ReductionBase):
174
177
  y = exp_x * (1.0 / denom)
175
178
  tXrO.store(y.to(tXrO.element_type))
176
179
  tOpO = (
177
- utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1]) if not is_even_N else None
180
+ utils.predicate_k(thr_copy_O.partition_S(cX), limit=shape[1])
181
+ if cutlass.const_expr(not is_even_N)
182
+ else None
178
183
  )
179
184
  if tXcX[0][0] < shape[0]:
180
185
  cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
@@ -283,7 +288,7 @@ class SoftmaxBackward(ReductionBase):
283
288
  self.kernel(mdY, mY, mdX, tv_layout, tiler_mn).launch(
284
289
  grid=[cute.ceil_div(mdY.shape[0], tiler_mn[0]), self.cluster_n, 1],
285
290
  block=[num_threads, 1, 1],
286
- cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
291
+ cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None,
287
292
  smem=self._smem_size_in_bytes(tiler_mn, num_warps),
288
293
  stream=stream,
289
294
  )
@@ -298,14 +303,17 @@ class SoftmaxBackward(ReductionBase):
298
303
  tiler_mn: cute.Shape,
299
304
  ):
300
305
  tidx, _, _ = cute.arch.thread_idx()
301
- bidx, cluster_y, _ = cute.arch.block_idx()
306
+ bidx, _, _ = cute.arch.block_idx()
307
+ if cutlass.const_expr(self.cluster_n > 1):
308
+ cluster_y = cute.arch.block_idx()[1]
309
+ else:
310
+ cluster_y = cutlass.const_expr(0)
302
311
 
303
312
  shape = mdY.shape
304
313
  idX = cute.make_identity_tensor(shape)
305
314
  # slice for CTAs
306
315
  gdY, gY, gdX, cX = [
307
- cute.local_tile(mT, tiler_mn, (bidx, 0 if self.cluster_n == 1 else cluster_y))
308
- for mT in (mdY, mY, mdX, idX)
316
+ cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mdY, mY, mdX, idX)
309
317
  ]
310
318
 
311
319
  smem = cutlass.utils.SmemAllocator()
@@ -344,7 +352,7 @@ class SoftmaxBackward(ReductionBase):
344
352
  is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
345
353
  tdYpdY = (
346
354
  utils.predicate_k(thr_copy_load.partition_S(cX), limit=shape[1])
347
- if not is_even_N
355
+ if cutlass.const_expr(not is_even_N)
348
356
  else None
349
357
  )
350
358
 
@@ -366,7 +374,7 @@ class SoftmaxBackward(ReductionBase):
366
374
  cute.ReductionOp.ADD,
367
375
  threads_per_row,
368
376
  reduction_buffer[None, None, 0],
369
- mbar_ptr if self.cluster_n > 1 else None,
377
+ mbar_ptr if cutlass.const_expr(self.cluster_n > 1) else None,
370
378
  init_val=0.0,
371
379
  hook_fn=cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None,
372
380
  )
@@ -376,7 +384,7 @@ class SoftmaxBackward(ReductionBase):
376
384
  tdXrdX.store(dx.to(tdXrdX.element_type))
377
385
  tdXpdX = (
378
386
  utils.predicate_k(thr_copy_store.partition_S(cX), limit=shape[1])
379
- if not is_even_N
387
+ if cutlass.const_expr(not is_even_N)
380
388
  else None
381
389
  )
382
390
  if tXcX[0][0] < shape[0]:
@@ -37,19 +37,20 @@ def min_constexpr(
37
37
  return a if a < b else b
38
38
 
39
39
 
40
+ @cute.jit
40
41
  def warp_reduce(
41
42
  val: cute.TensorSSA | cute.Numeric,
42
43
  op: Callable,
43
44
  width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
44
45
  ) -> cute.TensorSSA | cute.Numeric:
45
- if isinstance(val, cute.TensorSSA):
46
+ if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
46
47
  res = cute.make_fragment(val.shape, val.dtype)
47
48
  res.store(val)
48
- for i in range(cute.size(val.shape)):
49
+ for i in cutlass.range_constexpr(cute.size(val.shape)):
49
50
  res[i] = warp_reduce(res[i], op, width)
50
51
  return res.load()
51
52
  else:
52
- for i in range(int(math.log2(width))):
53
+ for i in cutlass.range_constexpr(int(math.log2(width))):
53
54
  val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
54
55
  return val
55
56
 
@@ -111,15 +112,15 @@ def store_shared_remote(
111
112
  remote_mbar_ptr_i32 = set_block_rank(
112
113
  mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
113
114
  ).ir_value()
114
- if isinstance(val, float):
115
+ if cutlass.const_expr(isinstance(val, float)):
115
116
  val = Float32(val)
116
117
  assert isinstance(val, (Float32, cutlass.Int64)), "val must be Float32 or Int64"
117
- suffix = "f32" if isinstance(val, Float32) else "s64"
118
+ suffix = "f32" if cutlass.const_expr(isinstance(val, Float32)) else "s64"
118
119
  llvm.inline_asm(
119
120
  None,
120
121
  [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
121
122
  f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];",
122
- f"r,{'f' if isinstance(val, Float32) else 'l'},r",
123
+ f"r,{'f' if cutlass.const_expr(isinstance(val, Float32)) else 'l'},r",
123
124
  has_side_effects=True,
124
125
  is_align_stack=False,
125
126
  asm_dialect=llvm.AsmDialect.AD_ATT,
@@ -299,6 +300,7 @@ def online_softmax_reduce(
299
300
  return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
300
301
 
301
302
 
303
+ @cute.jit
302
304
  def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
303
305
  """exp2f calculation for both vector and scalar.
304
306
 
@@ -307,10 +309,10 @@ def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
307
309
  :return: exp2 value
308
310
  :rtype: cute.TensorSSA or Float32
309
311
  """
310
- if isinstance(x, cute.TensorSSA):
312
+ if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
311
313
  res = cute.make_fragment(x.shape, Float32)
312
314
  res.store(x)
313
- for i in range(cute.size(x.shape)):
315
+ for i in cutlass.range_constexpr(cute.size(x.shape)):
314
316
  res[i] = cute.arch.exp2(res[i])
315
317
  return res.load()
316
318
  else:
@@ -347,6 +349,7 @@ def rsqrt(a: float | Float32, *, loc=None, ip=None) -> Float32:
347
349
  )
348
350
 
349
351
 
352
+ @cute.jit
350
353
  def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
351
354
  # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
352
355
  tApA = cute.make_fragment(
@@ -356,8 +359,8 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
356
359
  ),
357
360
  cutlass.Boolean,
358
361
  )
359
- for rest_v in range(tApA.shape[0]):
360
- for rest_k in range(tApA.shape[2]):
362
+ for rest_v in cutlass.range_constexpr(tApA.shape[0]):
363
+ for rest_k in cutlass.range_constexpr(tApA.shape[2]):
361
364
  tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
362
365
  return tApA
363
366
 
@@ -373,8 +376,8 @@ def fill_oob(tXsX: cute.Tensor, tXpX: cute.Tensor, fill_value: cute.Numeric) ->
373
376
  """
374
377
  tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), 0, 0])
375
378
  tXrX_fill.fill(fill_value)
376
- for rest_v in range(tXpX.shape[0]):
377
- for rest_k in range(tXpX.shape[2]):
379
+ for rest_v in cutlass.range_constexpr(tXpX.shape[0]):
380
+ for rest_k in cutlass.range_constexpr(tXpX.shape[2]):
378
381
  if not tXpX[rest_v, 0, rest_k]:
379
382
  cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
380
383
 
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: quack-kernels
3
- Version: 0.1.3
3
+ Version: 0.1.4
4
4
  Requires-Python: >=3.9
5
5
  License-File: LICENSE
6
- Requires-Dist: nvidia-cutlass-dsl==4.0.0
6
+ Requires-Dist: nvidia-cutlass-dsl==4.1.0.dev0
7
7
  Requires-Dist: torch
8
8
  Provides-Extra: dev
9
9
  Requires-Dist: pre-commit; extra == "dev"
@@ -0,0 +1,6 @@
1
+ nvidia-cutlass-dsl==4.1.0.dev0
2
+ torch
3
+
4
+ [dev]
5
+ pre-commit
6
+ ruff
@@ -1,6 +0,0 @@
1
- nvidia-cutlass-dsl==4.0.0
2
- torch
3
-
4
- [dev]
5
- pre-commit
6
- ruff
File without changes
File without changes
File without changes