quack-kernels 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/pipeline.py CHANGED
@@ -5,14 +5,16 @@ from dataclasses import dataclass
 
 import cutlass.cute as cute
 from cutlass import Boolean, Int32, const_expr
-from cutlass.cutlass_dsl import if_generate, and_
-from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp, pipeline_init_wait
-from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType
+from cutlass.cutlass_dsl import if_generate, and_, dsl_user_op
+from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp
+from cutlass.pipeline import PipelineTmaAsync, PipelineState, PipelineUserType
 from cutlass.pipeline import PipelineTmaUmma
+from cutlass.pipeline import Agent, agent_sync
 
 
 class PipelineStateWAdvance(PipelineState):
-    def advance_iters(self, num_iterations: Int32):
+    @dsl_user_op
+    def advance_iters(self, num_iterations: Int32, *, loc=None, ip=None):
         self._count += Int32(num_iterations)
         new_index = self._index + Int32(num_iterations)
         # How many times did we cross the stages boundary
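
The pattern running through this whole file is visible already in this first hunk: each user-facing method gains a `@dsl_user_op` decorator plus keyword-only `loc=None, ip=None` parameters, and forwards them to every nested DSL call so generated IR is attributed to the user's call site rather than to library internals. A toy sketch of that convention (an illustration of the pattern only, not the real `cutlass.cutlass_dsl.dsl_user_op`):

    import functools
    import inspect

    def toy_dsl_user_op(fn):
        """Stand-in decorator: capture the caller's source location if none given."""
        @functools.wraps(fn)
        def wrapper(*args, loc=None, ip=None, **kwargs):
            if loc is None:
                caller = inspect.stack()[1]
                loc = f"{caller.filename}:{caller.lineno}"  # stand-in for an MLIR Location
            return fn(*args, loc=loc, ip=ip, **kwargs)
        return wrapper

    @toy_dsl_user_op
    def emit_op(value, *, loc=None, ip=None):
        print(f"op({value}) attributed to {loc}")

    emit_op(42)  # prints this call site, not a location inside the library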
@@ -56,104 +58,53 @@ class PipelineTmaCpAsync(PipelineTmaAsync):
     """
 
     @staticmethod
-    def create(
-        *,
-        num_stages: int,
-        producer_group: CooperativeGroup,
-        consumer_group: CooperativeGroup,
-        tx_count: int,
-        barrier_storage: cute.Pointer = None,
-        cta_layout_vmnk: Optional[cute.Layout] = None,
-        tidx: Optional[Int32] = None,
-    ):
-        """
-        This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync.
-        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
-        :type barrier_storage: cute.Pointer
-        :param num_stages: Number of buffer stages for this pipeline
-        :type num_stages: Int32
-        :param producer_group: CooperativeGroup for the producer agent
-        :type producer_group: CooperativeGroup
-        :param consumer_group: CooperativeGroup for the consumer agent
-        :type consumer_group: CooperativeGroup
-        :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
-        :type tx_count: int
-        :param cta_layout_vmnk: Layout of the cluster shape
-        :type cta_layout_vmnk: cute.Layout | None
-        :param tidx: thread index to consumer async threads
-        :type tidx: Int32 | None
-        """
-        if not isinstance(barrier_storage, cute.Pointer):
-            raise ValueError(
-                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
-            )
-
-        producer_type = PipelineOp.TmaLoad
-        consumer_type = PipelineOp.AsyncThread
-
-        producer = (producer_type, producer_group)
-        consumer = (consumer_type, consumer_group)
-
-        sync_object_full = PipelineAsync._make_sync_object(
-            barrier_storage.align(min_align=8), num_stages, producer, tx_count
-        )
-        sync_object_empty = PipelineAsync._make_sync_object(
-            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
-        )
-        if tidx is None:
-            tidx, _, _ = cute.arch.thread_idx()
-        if cta_layout_vmnk is None:
-            cta_layout_vmnk = cute.make_layout((1, 1, 1, 1))
-        (
-            dst_rank,
-            is_signalling_thread,
-        ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx)
-        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
-            dst_rank = None
-        else:
-            dst_rank = dst_rank
-
-        producer_mask = None
-
-        pipeline_init_wait(cta_layout_vmnk)
-
-        return PipelineTmaCpAsync(
-            sync_object_full,
-            sync_object_empty,
-            num_stages,
-            producer_mask,
-            dst_rank,
-            is_signalling_thread,
-        )
-
+    def create(*args, **kwargs):
+        obj = PipelineTmaAsync.create(*args, **kwargs)
+        # Can't assign to __class__ directly since the dataclass is frozen
+        # obj.__class__ = PipelineTmaCpAsync
+        object.__setattr__(obj, "__class__", PipelineTmaCpAsync)
+        return obj
+
+    @dsl_user_op
     def producer_acquire(
         self,
         state: PipelineState,
         try_acquire_token: Optional[Boolean] = None,
         is_tma_warp: Optional[Boolean] = True,
+        *,
+        loc=None,
+        ip=None,
     ):
         """
         TMA producer commit conditionally waits on buffer empty and sets the transaction barrier.
         """
         if_generate(
             try_acquire_token is None or try_acquire_token == 0,
-            lambda: self.sync_object_empty.wait(state.index, state.phase),
+            lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
+            loc=loc,
+            ip=ip,
         )
         # This is the difference between this and PipelineTmaAsync: we could have multiple
         # warps calling this, but only 1 warp should do the arrive on the full barrier
         if_generate(
             is_tma_warp,
-            lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
+            lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
+            loc=loc,
+            ip=ip,
         )
 
-    def producer_cpasync_commit(self, state: PipelineState):
+    @dsl_user_op
+    def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None):
         """
         We need the mbarrier to track the completion of cp.async
         """
-        cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state))
+        cute.arch.cp_async_mbarrier_arrive_noinc(
+            self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip
+        )
 
 
 class MbarrierArrayWDropCount(MbarrierArray):
+    @dsl_user_op
     def __init__(
         self,
         barrier_storage: cute.Pointer,
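
The rewritten `create` above leans on a Python detail worth spelling out: the pipeline classes are frozen dataclasses, so their `__setattr__` raises on any assignment, including to `__class__`; `object.__setattr__` bypasses that guard and rebinds the freshly created instance to the subclass. A self-contained demonstration of the same trick with throwaway names:

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class Base:
        x: int

    class Sub(Base):
        def double(self):
            return 2 * self.x

    obj = Base(3)
    # obj.__class__ = Sub would raise FrozenInstanceError here
    object.__setattr__(obj, "__class__", Sub)  # bypasses the frozen __setattr__
    assert isinstance(obj, Sub) and obj.double() == 6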
@@ -161,6 +112,9 @@ class MbarrierArrayWDropCount(MbarrierArray):
         agent: tuple[PipelineOp, CooperativeGroup],
         tx_count: int = 0,
         drop_count: Optional[Int32] = None,
+        *,
+        loc=None,
+        ip=None,
     ) -> None:
         self.barrier_storage = barrier_storage
         self.tx_count = tx_count
@@ -183,7 +137,7 @@ class MbarrierArrayWDropCount(MbarrierArray):
         self.mbarrier_base = self.barrier_storage
 
         # Mbarrier initialization in constructor
-        self.mbarrier_init()
+        self.mbarrier_init(loc=loc, ip=ip)
 
     def __extract_mlir_values__(self):
         return [self.barrier_storage, self.drop_count]
@@ -201,6 +155,7 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
     (e.g. Blackwell mainloops)
     """
 
+    @dsl_user_op
     @staticmethod
     def create(
         *,
@@ -210,25 +165,34 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
         tx_count: int,
         barrier_storage: cute.Pointer = None,
         cta_layout_vmnk: Optional[cute.Layout] = None,
+        mcast_mode_mn: tuple[int, int] = (1, 1),
+        defer_sync: bool = False,
         producer_drop_count: Optional[Int32] = None,
+        loc=None,
+        ip=None,
     ):
-        """
-        This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma.
-        :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
-        :type barrier_storage: cute.Pointer
+        """Creates and initializes a new PipelineTmaUmma instance.
+
         :param num_stages: Number of buffer stages for this pipeline
-        :type num_stages: Int32
-        :param producer_group: `CooperativeGroup` for the producer agent
+        :type num_stages: int
+        :param producer_group: CooperativeGroup for the producer agent
         :type producer_group: CooperativeGroup
-        :param consumer_group: `CooperativeGroup` for the consumer agent
+        :param consumer_group: CooperativeGroup for the consumer agent
         :type consumer_group: CooperativeGroup
         :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
         :type tx_count: int
+        :param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers
+        :type barrier_storage: cute.Pointer, optional
         :param cta_layout_vmnk: Layout of the cluster shape
-        :type cta_layout_vmnk: cute.Layout | None
+        :type cta_layout_vmnk: cute.Layout, optional
+        :param mcast_mode_mn: Tuple specifying multicast modes for m and n dimensions (each 0 or 1)
+        :type mcast_mode_mn: tuple[int, int], optional
+        :raises ValueError: If barrier_storage is not a cute.Pointer instance
+        :return: A new PipelineTmaUmma instance configured with the provided parameters
+        :rtype: PipelineTmaUmma
         """
         if not isinstance(barrier_storage, cute.Pointer):
-            raise ValueError(
+            raise TypeError(
                 f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
             )
 
@@ -244,29 +208,42 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
             producer,
             tx_count,
             drop_count=producer_drop_count,
+            loc=loc,
+            ip=ip,
         )
-        sync_object_empty = PipelineAsync._make_sync_object(
-            barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
+        sync_object_empty = PipelineTmaUmma._make_sync_object(
+            barrier_storage.align(min_align=8) + num_stages,
+            num_stages,
+            consumer,
+            loc=loc,
+            ip=ip,
         )
 
-        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
+        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, loc=loc, ip=ip) == 1:
             # No mcast mask if not using clusters
             producer_mask = None
             # All threadblocks are leaders if not using clusters
             is_leader_cta = True
         else:
-            producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk)
-            is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)
+            producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(
+                cta_layout_vmnk, mcast_mode_mn, loc=loc, ip=ip
+            )
+            is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk, loc=loc, ip=ip)
 
         cta_group = (
             cute.nvgpu.tcgen05.CtaGroup.ONE
-            if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
+            if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0], loc=loc, ip=ip) == 1
             else cute.nvgpu.tcgen05.CtaGroup.TWO
         )
 
         consumer_mask = producer_mask
 
-        pipeline_init_wait(cta_layout_vmnk)
+        if not defer_sync:
+            cute.arch.mbarrier_init_fence()
+            if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, loc=loc, ip=ip) == 1:
+                agent_sync(Agent.ThreadBlock)
+            else:
+                agent_sync(Agent.ThreadBlockCluster, is_relaxed=True)
 
         return PipelineTmaCpAsyncUmma(
             sync_object_full,
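
The new `defer_sync` flag, together with the fence/sync sequence inlined above in place of the old `pipeline_init_wait`, appears to let a kernel that constructs several pipelines pay for the post-initialization barrier once. A hedged kernel-side sketch of that usage; `prod`, `cons`, `tx_bytes`, and the `smem_ptr_*` pointers are placeholders the surrounding kernel would supply:

    pipe_a = PipelineTmaCpAsyncUmma.create(
        num_stages=2, producer_group=prod, consumer_group=cons,
        tx_count=tx_bytes, barrier_storage=smem_ptr_a, defer_sync=True,
    )
    pipe_b = PipelineTmaCpAsyncUmma.create(
        num_stages=2, producer_group=prod, consumer_group=cons,
        tx_count=tx_bytes, barrier_storage=smem_ptr_b, defer_sync=True,
    )
    # One fence + one block-wide sync covers both pipelines' mbarrier init.
    cute.arch.mbarrier_init_fence()
    agent_sync(Agent.ThreadBlock)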
@@ -278,11 +255,15 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
             cta_group,
         )
 
+    @dsl_user_op
     def producer_acquire(
         self,
         state: PipelineState,
         try_acquire_token: Optional[Boolean] = None,
         is_tma_warp: Optional[Boolean] = True,
+        *,
+        loc=None,
+        ip=None,
     ):
         """
         TMA producer commit conditionally waits on buffer empty and sets the
@@ -290,17 +271,24 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
         """
         if_generate(
             try_acquire_token is None or try_acquire_token == 0,
-            lambda: self.sync_object_empty.wait(state.index, state.phase),
+            lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
+            loc=loc,
+            ip=ip,
         )
         # This is the difference between this and PipelineTmaAsync: we could have multiple
         # warps calling this, but only 1 warp should do the arrive on the full barrier
         if_generate(
             and_(self.is_leader_cta, is_tma_warp),
-            lambda: self.sync_object_full.arrive(state.index, self.producer_mask),
+            lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
+            loc=loc,
+            ip=ip,
         )
 
-    def producer_cpasync_commit(self, state: PipelineState):
+    @dsl_user_op
+    def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None):
         """
         We need the mbarrier to track the completion of cp.async
         """
-        cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state))
+        cute.arch.cp_async_mbarrier_arrive_noinc(
+            self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip
+        )
quack/reduce.py CHANGED
@@ -196,9 +196,9 @@ def online_softmax_reduce(
     )
     cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
     num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
-    max_x_single_warp = cute.make_fragment(num_iter, Float32)
+    max_x_single_warp = cute.make_rmem_tensor(num_iter, Float32)
     max_x_single_warp.fill(-Float32.inf)
-    sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
+    sum_exp_x_single_warp = cute.make_rmem_tensor(num_iter, Float32)
     sum_exp_x_single_warp.fill(0.0)
     for i in cutlass.range_constexpr(num_iter):
         idx = lane_idx + i * cute.arch.WARP_SIZE
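
This hunk, the sm90_utils.py hunks, and the top-k hunks at the end all make the same substitution: `cute.make_fragment` becomes `cute.make_rmem_tensor` with identical arguments, which reads as a plain rename of the register-tensor allocator between CUTLASS DSL releases. Under that assumption, code that must run against both versions could resolve the name once (a hedged compatibility sketch, not part of the package):

    import cutlass.cute as cute

    # Prefer the new name; fall back to the old one on an older DSL.
    make_rmem_tensor = getattr(cute, "make_rmem_tensor", None) or cute.make_fragment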
quack/rmsnorm.py CHANGED
@@ -686,9 +686,7 @@ class RMSNormBackward(ReductionBase):
 
         if const_expr(self.cluster_n > 1):
             # Need this fence since the STAS from the producer is using the async proxy.
-            cute.arch.fence_proxy(
-                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
-            )
+            cute.arch.fence_view_async_shared()
             # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
             # Requires adjusting the thread_count when initializing the mbar
             cute.arch.sync_warp()
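
Judging from this hunk alone, `cute.arch.fence_view_async_shared()` is shorthand for the spelled-out async-proxy fence it replaces. If that reading is right, code targeting both DSL versions could guard on the wrapper's presence (a hedged sketch; the compat helper name is ours):

    import cutlass.cute as cute

    def fence_async_shared_compat():
        # Use the new shorthand when available, else the explicit proxy fence.
        if hasattr(cute.arch, "fence_view_async_shared"):
            cute.arch.fence_view_async_shared()
        else:
            cute.arch.fence_proxy(
                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
            )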
quack/sm90_utils.py CHANGED
@@ -27,10 +27,11 @@ def make_smem_layout(
         sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size),
         dtype,
     )
+    order = (1, 0, 2) if const_expr(layout.is_m_major_c()) else (0, 1, 2)
     smem_layout_staged = cute.tile_to_shape(
         smem_layout_atom,
         cute.append(shape, stage) if const_expr(stage is not None) else shape,
-        order=(1, 0, 2) if layout.is_m_major_c() else (0, 1, 2),
+        order=order if const_expr(stage is not None) else order[:2],
    )
    return smem_layout_staged
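
The fix here is rank agreement: when `stage is None`, the shape handed to `cute.tile_to_shape` is rank-2, so the `order` tuple presumably must be rank-2 as well, hence `order[:2]`; the old code passed a rank-3 order unconditionally. The invariant, restated in plain Python:

    order = (1, 0, 2)            # m-major ordering for the staged case
    staged_shape = (128, 64, 2)  # (M, N, num_stages)
    unstaged_shape = (128, 64)   # no stage mode appended
    assert len(order) == len(staged_shape)
    assert len(order[:2]) == len(unstaged_shape)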
 
@@ -101,7 +102,7 @@
             tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False
         )
     else:
-        acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32)
+        acc = cute.make_rmem_tensor(tiled_mma.partition_shape_C(shape), Float32)
         rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
         rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
         gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait)
@@ -125,3 +126,34 @@ def gemm_w_idx(
     rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
     rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
     gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait)
+
+
+def partition_fragment_ABC(
+    thr_mma: cute.ThrMma,
+    shape_mnk: cute.Shape,
+    sA: Optional[cute.Tensor],
+    sB: Optional[cute.Tensor],
+    swap_AB: bool = False,
+):
+    is_rs = thr_mma.op.a_src == warpgroup.OperandSource.RMEM
+    if const_expr(not swap_AB):
+        acc = cute.make_rmem_tensor(thr_mma.partition_shape_C(shape_mnk[:2]), Float32)
+        if const_expr(not is_rs):
+            assert sA is not None
+            tCrA = thr_mma.make_fragment_A(thr_mma.partition_A(sA))
+        else:
+            tCrA = thr_mma.make_fragment_A(thr_mma.partition_shape_A((shape_mnk[0], shape_mnk[2])))
+        assert sB is not None
+        tCrB = thr_mma.make_fragment_B(thr_mma.partition_B(sB))
+    else:
+        acc = cute.make_rmem_tensor(
+            thr_mma.partition_shape_C((shape_mnk[1], shape_mnk[0])), Float32
+        )
+        if const_expr(not is_rs):
+            assert sB is not None
+            tCrB = thr_mma.make_fragment_A(thr_mma.partition_A(sB))
+        else:  # B in rmem
+            tCrB = thr_mma.make_fragment_A(thr_mma.partition_shape_A((shape_mnk[1], shape_mnk[2])))
+        assert sA is not None
+        tCrA = thr_mma.make_fragment_B(thr_mma.partition_B(sA))
+    return acc, tCrA, tCrB
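
The new `partition_fragment_ABC` helper collapses the hand-written swap_AB and no-swap partitioning paths into one call that returns the accumulator plus both operand fragments. A hypothetical usage sketch (`tiled_mma`, `tidx`, `tile_shape_mnk`, `sA`, `sB` are placeholders a kernel would define):

    thr_mma = tiled_mma.get_slice(tidx)
    acc, tCrA, tCrB = partition_fragment_ABC(thr_mma, tile_shape_mnk, sA, sB, swap_AB=False)
    gemm(tiled_mma, acc, tCrA, tCrB, zero_init=True, wg_wait=0)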
@@ -83,7 +83,7 @@ def bitonic_topk_merge(
     else:
         minmax_fn = min if ascending else max
         # Write the top k elements to the first half of the array
-        for i in cutlass.range(k, unfoll_full=True):
+        for i in cutlass.range(k, unroll_full=True):
            arr0[start0 + i] = minmax_fn(arr0[start0 + i], arr1[start1 + k - 1 - i])
        # Now the 1st half is bitonic, we just need to merge it
        bitonic_merge(arr0, k, start0, ascending)
@@ -108,12 +108,12 @@ def bitonic_topk(
     n = cute.size(arr.shape)
     assert k == 1 << int(math.log2(k)), "k must be a power of 2"
     assert n % k == 0, "n must be divisible by k"
-    topk_vals = cute.make_fragment(k, arr.element_type)
+    topk_vals = cute.make_rmem_tensor(k, arr.element_type)
     for v in cutlass.range(k, unroll_full=True):
         topk_vals[v] = arr[v]
     bitonic_sort(topk_vals, ascending=ascending)
     for i in cutlass.range(1, n // k, unroll_full=True):
-        other_vals = cute.make_fragment(k, arr.element_type)
+        other_vals = cute.make_rmem_tensor(k, arr.element_type)
         for v in cutlass.range(k, unroll_full=True):
             other_vals[v] = arr[i * k + v]
         bitonic_sort(other_vals, ascending=ascending)
@@ -122,7 +122,7 @@
     # TODO: this is not efficient for large k (e.g. >= 16) since threads in the same warps
     # do duplicate work.
     for i in cutlass.range(int(math.log2(warp_width)), unroll_full=True):
-        other_vals = cute.make_fragment(k, arr.element_type)
+        other_vals = cute.make_rmem_tensor(k, arr.element_type)
         for v in cutlass.range(k, unroll_full=True):
             other_vals[v] = cute.arch.shuffle_sync_bfly(topk_vals[v], offset=1 << i)
         bitonic_topk_merge(topk_vals, other_vals, ascending=ascending)
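
For context on what these hunks touch (the first one fixes an `unfoll_full` → `unroll_full` keyword typo): given two descending-sorted length-k arrays, taking `max(arr0[i], arr1[k-1-i])` elementwise keeps exactly the top k of the combined 2k elements and leaves them in bitonic order, so a single bitonic merge restores sortedness. A pure-Python reference of that merge step (illustration only; the kernel version runs per thread on register arrays):

    def topk_merge_ref(arr0, arr1, k):
        # Both inputs sorted descending; keep the top k of their union.
        merged = [max(arr0[i], arr1[k - 1 - i]) for i in range(k)]  # bitonic
        return sorted(merged, reverse=True)  # stands in for bitonic_merge

    assert topk_merge_ref([9, 7, 4, 1], [8, 6, 5, 2], 4) == [9, 8, 7, 6]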