quack-kernels 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,6 +2,7 @@
2
2
  from typing import Optional, Tuple, Dict, Any
3
3
  from dataclasses import dataclass
4
4
 
5
+ import torch
5
6
  from torch import Tensor
6
7
 
7
8
  import cutlass.cute as cute
@@ -9,6 +10,7 @@ from cutlass import Int32
9
10
  from cutlass.cute.runtime import from_dlpack, make_ptr
10
11
 
11
12
  from quack.cute_dsl_utils import torch2cute_dtype_map
13
+ from quack.varlen_utils import VarlenArguments
12
14
  from quack.dense_gemm_sm90 import TileSchedulerOptions
13
15
 
14
16
 
@@ -22,8 +24,8 @@ class GemmTensorInfo:
22
24
 
23
25
  class GemmWrapperBase:
24
26
  @staticmethod
25
- def validate_tensor_3d(tensor: Tensor, name: str) -> None:
26
- assert tensor.dim() == 3 and tensor.is_cuda, f"{name} must be a 3D CUDA tensor"
27
+ def validate_tensor(tensor: Tensor, name: str, ndim: int) -> None:
28
+ assert tensor.dim() == ndim and tensor.is_cuda, f"{name} must be a {ndim}D CUDA tensor"
27
29
  assert tensor.dtype in torch2cute_dtype_map, f"Unsupported dtype for {name}"
28
30
 
29
31
  @staticmethod
@@ -47,7 +49,7 @@ class GemmWrapperBase:
47
49
  ) -> Optional[cute.Tensor]:
48
50
  if tensor is None:
49
51
  return None
50
- # Tensor is already permuted to (dims[0], dims[1], dims[2])
52
 + # Tensor is already permuted to (dims[0], dims[1], dims[2]) or (dims[0], dims[1])
51
53
  # If major is dims[1], leading_dim is 1; if major is dims[0], leading_dim is 0
52
54
  leading_dim = 1 if major == dims[1] else 0
53
55
  return from_dlpack(tensor.detach(), assumed_align=assumed_align).mark_layout_dynamic(
@@ -61,43 +63,131 @@ class GemmWrapperBase:
61
63
  D: Optional[Tensor] = None,
62
64
  C: Optional[Tensor] = None,
63
65
  additional_tensors: Optional[Dict[str, Tensor]] = None,
66
+ cu_seqlens_m: Optional[Tensor] = None,
67
+ cu_seqlens_k: Optional[Tensor] = None,
68
+ A_idx: Optional[Tensor] = None,
64
69
  ) -> Tuple[int, int, int, int, Dict[str, GemmTensorInfo]]:
65
- GemmWrapperBase.validate_tensor_3d(A, "A")
66
- L, M, K = A.shape
67
- GemmWrapperBase.validate_tensor_3d(B, "B")
68
- _, N, _ = B.shape
70
+ assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), (
71
+ "Only one of cu_seqlens_m and cu_seqlens_k can be specified"
72
+ )
69
73
  assert B.dtype == A.dtype, "A and B must have the same dtype"
70
- GemmWrapperBase.validate_shape(B, (L, N, K), "B")
74
+
75
+ # Validate A_idx if provided (for gather_A case)
76
+ gather_A = A_idx is not None
77
+ if gather_A:
78
+ assert cu_seqlens_m is not None or cu_seqlens_k is not None, (
79
+ "gather_A requires either varlen_m or varlen_k"
80
+ )
81
+ assert A_idx.dtype == torch.int32, f"A_idx must be int32, got {A_idx.dtype}"
82
+ assert A_idx.dim() == 1, f"A_idx must be 1D, got {A_idx.dim()}D"
83
+
84
+ # Determine mode and extract dimensions
85
+ if cu_seqlens_m is not None:
86
+ # varlen_m: A is (total_m, k) or (whatever, k) if gather_A, B is (l, n, k), D/C are (total_m, n)
87
+ assert A.dim() == 2, f"A must be 2D when using varlen_m, got {A.dim()}D"
88
+ assert B.dim() == 3, f"B must be 3D with varlen_m, got {B.dim()}D"
89
+
90
+ if gather_A:
91
+ # When gather_A, A can have any number of rows, we use A_idx.shape[0] as total_M
92
+ total_M = A_idx.shape[0]
93
+ _, K = A.shape
94
+ else:
95
+ total_M, K = A.shape
96
+
97
+ L, N, K_B = B.shape
98
+ assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}"
99
+ assert cu_seqlens_m.shape == (L + 1,), (
100
+ f"cu_seqlens_m must have shape ({L + 1},), got {cu_seqlens_m.shape}"
101
+ )
102
+ M = total_M
103
+ dc_shape = (total_M, N)
104
+ dc_ndim = 2
105
+ elif cu_seqlens_k is not None:
106
+ # varlen_k: A is (m, total_k) or (m, whatever) if gather_A, B is (n, total_k), D/C are (l, m, n)
107
+ assert A.dim() == 2, f"A must be 2D when using varlen_k, got {A.dim()}D"
108
+ assert B.dim() == 2, f"B must be 2D with varlen_k, got {B.dim()}D"
109
+
110
+ if gather_A:
111
+ # When gather_A with varlen_k, A can have any number of columns, we use A_idx.shape[0] as total_K
112
+ M, _ = A.shape
113
+ total_K = A_idx.shape[0]
114
+ else:
115
+ M, total_K = A.shape
116
+
117
+ N, K_B = B.shape
118
+ assert total_K == K_B, f"K dimension mismatch: expected {total_K}, B has {K_B}"
119
+ L = cu_seqlens_k.shape[0] - 1
120
+ assert cu_seqlens_k.shape == (L + 1,), (
121
+ f"cu_seqlens_k must have shape ({L + 1},), got {cu_seqlens_k.shape}"
122
+ )
123
+ K = total_K
124
+ dc_shape = (L, M, N)
125
+ dc_ndim = 3
126
+ else:
127
+ # Normal case - all tensors must be 3D
128
+ GemmWrapperBase.validate_tensor(A, "A", 3)
129
+ GemmWrapperBase.validate_tensor(B, "B", 3)
130
+ L, M, K = A.shape
131
+ _, N, K_B = B.shape
132
+ assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}"
133
+ GemmWrapperBase.validate_shape(B, (L, N, K), "B")
134
+ dc_shape = (L, M, N)
135
+ dc_ndim = 3
136
+
137
+ # Validate D and C shapes uniformly
138
+ for tensor, name in [(D, "D"), (C, "C")]:
139
+ if tensor is not None:
140
+ assert tensor.dim() == dc_ndim, (
141
+ f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D"
142
+ )
143
+ assert tensor.shape == dc_shape, (
144
+ f"{name} shape {tensor.shape} doesn't match expected {dc_shape}"
145
+ )
146
+
71
147
  tensors = {
72
148
  "A": GemmTensorInfo(A),
73
149
  "B": GemmTensorInfo(B),
74
150
  "D": GemmTensorInfo(D),
75
151
  "C": GemmTensorInfo(C),
76
152
  }
77
- if D is not None:
78
- GemmWrapperBase.validate_tensor_3d(D, "D")
79
- GemmWrapperBase.validate_shape(D, (L, M, N), "D")
80
- if C is not None:
81
- GemmWrapperBase.validate_tensor_3d(C, "C")
82
- GemmWrapperBase.validate_shape(C, (L, M, N), "C")
153
+
83
154
  if additional_tensors:
84
155
  for name, tensor in additional_tensors.items():
85
156
  if tensor is not None:
86
- GemmWrapperBase.validate_tensor_3d(tensor, name)
87
- GemmWrapperBase.validate_shape(tensor, (L, M, N), name)
157
+ assert tensor.dim() == dc_ndim, (
158
+ f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D"
159
+ )
160
+ assert tensor.shape == dc_shape, (
161
+ f"{name} shape {tensor.shape} doesn't match expected {dc_shape}"
162
+ )
88
163
  tensors[name] = GemmTensorInfo(tensor)
89
164
 
90
165
  return L, M, K, N, tensors
91
166
 
92
167
  @staticmethod
93
- def permute_tensors(tensors: Dict[str, GemmTensorInfo]) -> None:
94
- for info in tensors.values():
95
- if info.tensor is not None:
96
- info.tensor = info.tensor.permute(1, 2, 0)
168
+ def permute_tensors(
169
+ tensors: Dict[str, GemmTensorInfo], varlen_m: bool = False, varlen_k: bool = False
170
+ ) -> None:
171
+ # Determine which tensors need permutation
172
+ if varlen_m:
173
+ # Only B needs permutation (3D tensor)
174
+ tensors_to_permute = ["B"]
175
+ elif varlen_k:
176
+ # Only D and C need permutation (3D tensors)
177
+ tensors_to_permute = ["D", "C"]
178
+ else:
179
+ # All tensors need permutation
180
+ tensors_to_permute = None
181
+
182
+ # Apply permutation from (L, *, *) -> (*, *, L) for selected tensors
183
+ for name, info in tensors.items():
184
+ if info.tensor is not None and info.tensor.ndim == 3:
185
+ if tensors_to_permute is None or name in tensors_to_permute:
186
+ info.tensor = info.tensor.permute(1, 2, 0)
97
187
 
98
188
  @staticmethod
99
189
  def extract_dtypes(tensors: Dict[str, GemmTensorInfo]) -> None:
100
- for info in tensors.values():
190
+ for name, info in tensors.items():
101
191
  if info.tensor is not None:
102
192
  info.dtype = torch2cute_dtype_map[info.tensor.dtype]
103
193
 
@@ -121,7 +211,9 @@ class GemmWrapperBase:
121
211
 
122
212
  @staticmethod
123
213
  def create_scheduler_args(
124
- max_active_clusters: int, tile_count_semaphore: Optional[Tensor] = None
214
+ max_active_clusters: int,
215
+ tile_count_semaphore: Optional[Tensor] = None,
216
+ batch_idx_permute: Optional[Tensor] = None,
125
217
  ) -> TileSchedulerOptions:
126
218
  return TileSchedulerOptions(
127
219
  Int32(max_active_clusters),
@@ -130,6 +222,71 @@ class GemmWrapperBase:
130
222
  )
131
223
  if tile_count_semaphore is not None
132
224
  else None,
225
+ batch_idx_permute=(
226
+ from_dlpack(batch_idx_permute, assumed_align=4).mark_layout_dynamic(leading_dim=0)
227
+ )
228
+ if batch_idx_permute is not None
229
+ else None,
230
+ )
231
+
232
+ @staticmethod
233
+ def create_varlen_args(
234
+ cu_seqlens_m: Optional[Tensor],
235
+ cu_seqlens_k: Optional[Tensor],
236
+ A_idx: Optional[Tensor],
237
+ max_active_clusters: int,
238
+ cluster_shape_mnk: Tuple[int, int, int],
239
+ tensors: Dict[str, GemmTensorInfo],
240
+ num_epi_tensormaps: int = 0,
241
+ pingpong: bool = False,
242
+ ) -> Optional[Any]:
243
+ if cu_seqlens_m is None and cu_seqlens_k is None:
244
+ return None
245
+ # When varlen_m, we assume persistent=True
246
+ # Grid size depends on num_active_clusters and cluster size
247
+ cluster_size = cluster_shape_mnk[0] * cluster_shape_mnk[1]
248
+ num_blocks = max_active_clusters * cluster_size
249
+ # Calculate number of tensormaps needed
250
+ if cu_seqlens_m is not None:
251
+ # For varlen_m: need tensormaps for D and epilogue tensors
252
+ num_tensormaps = num_epi_tensormaps * (1 if not pingpong else 2)
253
+ if tensors["D"].tensor is not None:
254
+ num_tensormaps += 1 if not pingpong else 2 # D tensormap
255
+ else:
256
+ # For varlen_k: need tensormaps for A & B
257
+ num_tensormaps = 2 if A_idx is None else 1
258
+ # Create tensormap buffer (each tensormap is 128 bytes = 16 int64s)
259
+ tensormap_size = 128 // 8 # 16 int64s
260
+ if num_tensormaps > 0:
261
+ device = cu_seqlens_m.device if cu_seqlens_m is not None else cu_seqlens_k.device
262
+ tensormaps = torch.empty(
263
+ (num_blocks, num_tensormaps, tensormap_size),
264
+ dtype=torch.int64,
265
+ device=device,
266
+ )
267
+ tensormaps_cute = from_dlpack(tensormaps, assumed_align=128).mark_compact_shape_dynamic(
268
+ mode=0, stride_order=(0, 1, 2)
269
+ )
270
+ else:
271
+ tensormaps_cute = None
272
+
273
+ return VarlenArguments(
274
+ mCuSeqlensM=(
275
+ from_dlpack(cu_seqlens_m, assumed_align=4).mark_layout_dynamic(leading_dim=0)
276
+ if cu_seqlens_m is not None
277
+ else None
278
+ ),
279
+ mCuSeqlensK=(
280
+ from_dlpack(cu_seqlens_k, assumed_align=4).mark_layout_dynamic(leading_dim=0)
281
+ if cu_seqlens_k is not None
282
+ else None
283
+ ),
284
+ mTensormaps=tensormaps_cute,
285
+ mAIdx=(
286
+ from_dlpack(A_idx, assumed_align=4).mark_layout_dynamic(leading_dim=0)
287
+ if A_idx is not None
288
+ else None
289
+ ),
133
290
  )
134
291
 
135
292
  @staticmethod
quack/layernorm.py CHANGED
@@ -217,7 +217,7 @@ class LayerNorm(ReductionBase):
217
217
  mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
218
218
  init_val=0.0,
219
219
  )
220
- rstd = utils.rsqrt(sum_sq_x_sub_mean / shape[1] + eps)
220
+ rstd = cute.math.rsqrt(sum_sq_x_sub_mean / shape[1] + eps, fastmath=True)
221
221
  if cutlass.const_expr(mRstd is not None):
222
222
  # Only the thread corresponding to column 0 writes out the rstd to gmem
223
223
  if (
quack/reduce.py CHANGED
@@ -159,8 +159,7 @@ def online_softmax_reduce(
159
159
  width=min(threads_per_row, cute.arch.WARP_SIZE),
160
160
  )
161
161
  log2_e = math.log2(math.e)
162
- exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
163
- # exp_x = exp2f((x - max_x) * log2_e)
162
+ exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
164
163
  sum_exp_x = warp_reduce(
165
164
  exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
166
165
  operator.add,
@@ -190,10 +189,10 @@ def online_softmax_reduce(
190
189
  reduction_buffer[row_idx, lane_idx]
191
190
  )
192
191
  max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
193
- sum_exp_x *= utils.exp2f((max_x_single_warp - max_x_final) * log2_e)
192
+ sum_exp_x *= cute.math.exp(max_x_single_warp - max_x_final, fastmath=True)
194
193
  sum_exp_x = warp_reduce(sum_exp_x, operator.add)
195
194
  if cutlass.const_expr(return_exp_x):
196
- exp_x *= utils.exp2f((max_x - max_x_final) * log2_e)
195
+ exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
197
196
  max_x = max_x_final
198
197
  else:
199
198
  cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
@@ -231,11 +230,11 @@ def online_softmax_reduce(
231
230
  max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
232
231
  sum_exp_x = 0.0
233
232
  for i in cutlass.range_constexpr(num_iter):
234
- sum_exp_x += sum_exp_x_single_warp[i] * utils.exp2f(
235
- (max_x_single_warp[i] - max_x_final) * log2_e
233
+ sum_exp_x += sum_exp_x_single_warp[i] * cute.math.exp(
234
+ max_x_single_warp[i] - max_x_final, fastmath=True
236
235
  )
237
236
  sum_exp_x = warp_reduce(sum_exp_x, operator.add)
238
237
  if cutlass.const_expr(return_exp_x):
239
- exp_x *= utils.exp2f((max_x - max_x_final) * log2_e)
238
+ exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
240
239
  max_x = max_x_final
241
240
  return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)