quack-kernels 0.2.5 (py3-none-any wheel) → 0.2.6 (py3-none-any wheel)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/pipeline.py CHANGED
@@ -6,9 +6,10 @@ from dataclasses import dataclass
6
6
  import cutlass.cute as cute
7
7
  from cutlass import Boolean, Int32, const_expr
8
8
  from cutlass.cutlass_dsl import if_generate, and_, dsl_user_op
9
- from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp, pipeline_init_wait
10
- from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType
9
+ from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp
10
+ from cutlass.pipeline import PipelineTmaAsync, PipelineState, PipelineUserType
11
11
  from cutlass.pipeline import PipelineTmaUmma
12
+ from cutlass.pipeline import Agent, agent_sync
12
13
 
13
14
 
14
15
  class PipelineStateWAdvance(PipelineState):
@@ -57,75 +58,12 @@ class PipelineTmaCpAsync(PipelineTmaAsync):
57
58
  """
58
59
 
59
60
  @staticmethod
60
- def create(
61
- *,
62
- num_stages: int,
63
- producer_group: CooperativeGroup,
64
- consumer_group: CooperativeGroup,
65
- tx_count: int,
66
- barrier_storage: cute.Pointer = None,
67
- cta_layout_vmnk: Optional[cute.Layout] = None,
68
- tidx: Optional[Int32] = None,
69
- ):
70
- """
71
- This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync.
72
- :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
73
- :type barrier_storage: cute.Pointer
74
- :param num_stages: Number of buffer stages for this pipeline
75
- :type num_stages: Int32
76
- :param producer_group: CooperativeGroup for the producer agent
77
- :type producer_group: CooperativeGroup
78
- :param consumer_group: CooperativeGroup for the consumer agent
79
- :type consumer_group: CooperativeGroup
80
- :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
81
- :type tx_count: int
82
- :param cta_layout_vmnk: Layout of the cluster shape
83
- :type cta_layout_vmnk: cute.Layout | None
84
- :param tidx: thread index to consumer async threads
85
- :type tidx: Int32 | None
86
- """
87
- if not isinstance(barrier_storage, cute.Pointer):
88
- raise ValueError(
89
- f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
90
- )
91
-
92
- producer_type = PipelineOp.TmaLoad
93
- consumer_type = PipelineOp.AsyncThread
94
-
95
- producer = (producer_type, producer_group)
96
- consumer = (consumer_type, consumer_group)
97
-
98
- sync_object_full = PipelineAsync._make_sync_object(
99
- barrier_storage.align(min_align=8), num_stages, producer, tx_count
100
- )
101
- sync_object_empty = PipelineAsync._make_sync_object(
102
- barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
103
- )
104
- if tidx is None:
105
- tidx, _, _ = cute.arch.thread_idx()
106
- if cta_layout_vmnk is None:
107
- cta_layout_vmnk = cute.make_layout((1, 1, 1, 1))
108
- (
109
- dst_rank,
110
- is_signalling_thread,
111
- ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx)
112
- if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
113
- dst_rank = None
114
- else:
115
- dst_rank = dst_rank
116
-
117
- producer_mask = None
118
-
119
- pipeline_init_wait(cta_layout_vmnk)
120
-
121
- return PipelineTmaCpAsync(
122
- sync_object_full,
123
- sync_object_empty,
124
- num_stages,
125
- producer_mask,
126
- dst_rank,
127
- is_signalling_thread,
128
- )
61
+ def create(*args, **kwargs):
62
+ obj = PipelineTmaAsync.create(*args, **kwargs)
63
+ # Can't assign to __class__ directly since the dataclass is frozen
64
+ # obj.__class__ = PipelineTmaCpAsync
65
+ object.__setattr__(obj, "__class__", PipelineTmaCpAsync)
66
+ return obj
129
67
 
130
68
  @dsl_user_op
131
69
  def producer_acquire(
@@ -143,12 +81,16 @@ class PipelineTmaCpAsync(PipelineTmaAsync):
143
81
  if_generate(
144
82
  try_acquire_token is None or try_acquire_token == 0,
145
83
  lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
84
+ loc=loc,
85
+ ip=ip,
146
86
  )
147
87
  # This is the difference between this and PipelineTmaAsync: we could have multiple
148
88
  # warps calling this, but only 1 warp should do the arrive on the full barrier
149
89
  if_generate(
150
90
  is_tma_warp,
151
91
  lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
92
+ loc=loc,
93
+ ip=ip,
152
94
  )
153
95
 
154
96
  @dsl_user_op
@@ -156,7 +98,9 @@ class PipelineTmaCpAsync(PipelineTmaAsync):
156
98
  """
157
99
  We need the mbarrier to track the completion of cp.async
158
100
  """
159
- cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip)
101
+ cute.arch.cp_async_mbarrier_arrive_noinc(
102
+ self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip
103
+ )
160
104
 
161
105
 
162
106
  class MbarrierArrayWDropCount(MbarrierArray):
@@ -211,6 +155,7 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
211
155
  (e.g. Blackwell mainloops)
212
156
  """
213
157
 
158
+ @dsl_user_op
214
159
  @staticmethod
215
160
  def create(
216
161
  *,
@@ -220,28 +165,34 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
220
165
  tx_count: int,
221
166
  barrier_storage: cute.Pointer = None,
222
167
  cta_layout_vmnk: Optional[cute.Layout] = None,
223
- producer_drop_count: Optional[Int32] = None,
224
168
  mcast_mode_mn: tuple[int, int] = (1, 1),
169
+ defer_sync: bool = False,
170
+ producer_drop_count: Optional[Int32] = None,
171
+ loc=None,
172
+ ip=None,
225
173
  ):
226
- """
227
- This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma.
228
- :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers
229
- :type barrier_storage: cute.Pointer
174
+ """Creates and initializes a new PipelineTmaUmma instance.
175
+
230
176
  :param num_stages: Number of buffer stages for this pipeline
231
- :type num_stages: Int32
232
- :param producer_group: `CooperativeGroup` for the producer agent
177
+ :type num_stages: int
178
+ :param producer_group: CooperativeGroup for the producer agent
233
179
  :type producer_group: CooperativeGroup
234
- :param consumer_group: `CooperativeGroup` for the consumer agent
180
+ :param consumer_group: CooperativeGroup for the consumer agent
235
181
  :type consumer_group: CooperativeGroup
236
182
  :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage
237
183
  :type tx_count: int
184
+ :param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers
185
+ :type barrier_storage: cute.Pointer, optional
238
186
  :param cta_layout_vmnk: Layout of the cluster shape
239
- :type cta_layout_vmnk: cute.Layout | None
187
+ :type cta_layout_vmnk: cute.Layout, optional
240
188
  :param mcast_mode_mn: Tuple specifying multicast modes for m and n dimensions (each 0 or 1)
241
189
  :type mcast_mode_mn: tuple[int, int], optional
190
+ :raises ValueError: If barrier_storage is not a cute.Pointer instance
191
+ :return: A new PipelineTmaUmma instance configured with the provided parameters
192
+ :rtype: PipelineTmaUmma
242
193
  """
243
194
  if not isinstance(barrier_storage, cute.Pointer):
244
- raise ValueError(
195
+ raise TypeError(
245
196
  f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
246
197
  )
247
198
 
@@ -257,29 +208,42 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
257
208
  producer,
258
209
  tx_count,
259
210
  drop_count=producer_drop_count,
211
+ loc=loc,
212
+ ip=ip,
260
213
  )
261
214
  sync_object_empty = PipelineTmaUmma._make_sync_object(
262
- barrier_storage.align(min_align=8) + num_stages, num_stages, consumer
215
+ barrier_storage.align(min_align=8) + num_stages,
216
+ num_stages,
217
+ consumer,
218
+ loc=loc,
219
+ ip=ip,
263
220
  )
264
221
 
265
- if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
222
+ if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, loc=loc, ip=ip) == 1:
266
223
  # No mcast mask if not using clusters
267
224
  producer_mask = None
268
225
  # All threadblocks are leaders if not using clusters
269
226
  is_leader_cta = True
270
227
  else:
271
- producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk, mcast_mode_mn)
272
- is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk)
228
+ producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(
229
+ cta_layout_vmnk, mcast_mode_mn, loc=loc, ip=ip
230
+ )
231
+ is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk, loc=loc, ip=ip)
273
232
 
274
233
  cta_group = (
275
234
  cute.nvgpu.tcgen05.CtaGroup.ONE
276
- if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1
235
+ if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0], loc=loc, ip=ip) == 1
277
236
  else cute.nvgpu.tcgen05.CtaGroup.TWO
278
237
  )
279
238
 
280
239
  consumer_mask = producer_mask
281
240
 
282
- pipeline_init_wait(cta_layout_vmnk)
241
+ if not defer_sync:
242
+ cute.arch.mbarrier_init_fence()
243
+ if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, loc=loc, ip=ip) == 1:
244
+ agent_sync(Agent.ThreadBlock)
245
+ else:
246
+ agent_sync(Agent.ThreadBlockCluster, is_relaxed=True)
283
247
 
284
248
  return PipelineTmaCpAsyncUmma(
285
249
  sync_object_full,
@@ -308,12 +272,16 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
308
272
  if_generate(
309
273
  try_acquire_token is None or try_acquire_token == 0,
310
274
  lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),
275
+ loc=loc,
276
+ ip=ip,
311
277
  )
312
278
  # This is the difference between this and PipelineTmaAsync: we could have multiple
313
279
  # warps calling this, but only 1 warp should do the arrive on the full barrier
314
280
  if_generate(
315
281
  and_(self.is_leader_cta, is_tma_warp),
316
282
  lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip),
283
+ loc=loc,
284
+ ip=ip,
317
285
  )
318
286
 
319
287
  @dsl_user_op
@@ -321,4 +289,6 @@ class PipelineTmaCpAsyncUmma(PipelineTmaUmma):
321
289
  """
322
290
  We need the mbarrier to track the completion of cp.async
323
291
  """
324
- cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip)
292
+ cute.arch.cp_async_mbarrier_arrive_noinc(
293
+ self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip
294
+ )
quack/reduce.py CHANGED
@@ -196,9 +196,9 @@ def online_softmax_reduce(
196
196
  )
197
197
  cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
198
198
  num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
199
- max_x_single_warp = cute.make_fragment(num_iter, Float32)
199
+ max_x_single_warp = cute.make_rmem_tensor(num_iter, Float32)
200
200
  max_x_single_warp.fill(-Float32.inf)
201
- sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
201
+ sum_exp_x_single_warp = cute.make_rmem_tensor(num_iter, Float32)
202
202
  sum_exp_x_single_warp.fill(0.0)
203
203
  for i in cutlass.range_constexpr(num_iter):
204
204
  idx = lane_idx + i * cute.arch.WARP_SIZE
quack/rmsnorm.py CHANGED
@@ -686,9 +686,7 @@ class RMSNormBackward(ReductionBase):
686
686
 
687
687
  if const_expr(self.cluster_n > 1):
688
688
  # Need this fence since the STAS from the producer is using the async proxy.
689
- cute.arch.fence_proxy(
690
- cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
691
- )
689
+ cute.arch.fence_view_async_shared()
692
690
  # It's faster to have 1 lane per warp to signal the mbar, rather than all lanes
693
691
  # Requires adjusting the thread_count when initializing the mbar
694
692
  cute.arch.sync_warp()
quack/sm90_utils.py CHANGED
@@ -102,7 +102,7 @@ def gemm_zero_init(
102
102
  tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False
103
103
  )
104
104
  else:
105
- acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32)
105
+ acc = cute.make_rmem_tensor(tiled_mma.partition_shape_C(shape), Float32)
106
106
  rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
107
107
  rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
108
108
  gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait)
@@ -137,7 +137,7 @@ def partition_fragment_ABC(
137
137
  ):
138
138
  is_rs = thr_mma.op.a_src == warpgroup.OperandSource.RMEM
139
139
  if const_expr(not swap_AB):
140
- acc = cute.make_fragment(thr_mma.partition_shape_C(shape_mnk[:2]), Float32)
140
+ acc = cute.make_rmem_tensor(thr_mma.partition_shape_C(shape_mnk[:2]), Float32)
141
141
  if const_expr(not is_rs):
142
142
  assert sA is not None
143
143
  tCrA = thr_mma.make_fragment_A(thr_mma.partition_A(sA))
@@ -146,7 +146,9 @@ def partition_fragment_ABC(
146
146
  assert sB is not None
147
147
  tCrB = thr_mma.make_fragment_B(thr_mma.partition_B(sB))
148
148
  else:
149
- acc = cute.make_fragment(thr_mma.partition_shape_C((shape_mnk[1], shape_mnk[0])), Float32)
149
+ acc = cute.make_rmem_tensor(
150
+ thr_mma.partition_shape_C((shape_mnk[1], shape_mnk[0])), Float32
151
+ )
150
152
  if const_expr(not is_rs):
151
153
  assert sB is not None
152
154
  tCrB = thr_mma.make_fragment_A(thr_mma.partition_A(sB))
@@ -108,12 +108,12 @@ def bitonic_topk(
108
108
  n = cute.size(arr.shape)
109
109
  assert k == 1 << int(math.log2(k)), "k must be a power of 2"
110
110
  assert n % k == 0, "n must be divisible by k"
111
- topk_vals = cute.make_fragment(k, arr.element_type)
111
+ topk_vals = cute.make_rmem_tensor(k, arr.element_type)
112
112
  for v in cutlass.range(k, unroll_full=True):
113
113
  topk_vals[v] = arr[v]
114
114
  bitonic_sort(topk_vals, ascending=ascending)
115
115
  for i in cutlass.range(1, n // k, unroll_full=True):
116
- other_vals = cute.make_fragment(k, arr.element_type)
116
+ other_vals = cute.make_rmem_tensor(k, arr.element_type)
117
117
  for v in cutlass.range(k, unroll_full=True):
118
118
  other_vals[v] = arr[i * k + v]
119
119
  bitonic_sort(other_vals, ascending=ascending)
@@ -122,7 +122,7 @@ def bitonic_topk(
122
122
  # TODO: this is not efficient for large k (e.g. >= 16) since threads in the same warps
123
123
  # do duplicate work.
124
124
  for i in cutlass.range(int(math.log2(warp_width)), unroll_full=True):
125
- other_vals = cute.make_fragment(k, arr.element_type)
125
+ other_vals = cute.make_rmem_tensor(k, arr.element_type)
126
126
  for v in cutlass.range(k, unroll_full=True):
127
127
  other_vals[v] = cute.arch.shuffle_sync_bfly(topk_vals[v], offset=1 << i)
128
128
  bitonic_topk_merge(topk_vals, other_vals, ascending=ascending)