quack-kernels 0.1.11__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,275 @@
+# Copyright (c) 2025, Tri Dao
+from typing import Optional, Literal
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.amp import custom_fwd, custom_bwd
+
+from quack.cross_entropy import cross_entropy, cross_entropy_fwd_out
+from quack.gemm_interface import gemm, gemm_add, gemm_add_inplace
+from quack.linear import linear_fwd_convert_type
+
+
+def linear_cross_entropy_func(
+    x: Tensor,  # (..., d)
+    weight: Tensor,  # (V, d)
+    bias: Optional[Tensor],  # (V,) or None
+    target: Tensor,  # (...,), int or long
+    ignore_index: int = -100,
+    reduction: Literal["none", "mean", "sum"] = "mean",
+    inplace_backward: bool = False,
+) -> Tensor:
+    y = F.linear(x, weight, bias)  # (..., V)
+    return cross_entropy(
+        y, target, ignore_index=ignore_index, reduction=reduction, inplace_backward=inplace_backward
+    )
+
+
+def linear_cross_entropy_func_ref(
+    x: Tensor,  # (..., d)
+    weight: Tensor,  # (V, d)
+    bias: Optional[Tensor],  # (V,) or None
+    target: Tensor,  # (...,), int or long
+    ignore_index: int = -100,
+    reduction: Literal["none", "mean", "sum"] = "mean",
+) -> Tensor:
+    y = F.linear(x, weight, bias)  # (..., V)
+    return F.cross_entropy(y, target, ignore_index=ignore_index, reduction=reduction)
+
+
+def chunked_linear_cross_entropy_fwd(
+    x: Tensor,  # (B*L, d) where B is batch, L is seqlen
+    weight: Tensor,  # (V, d) where V is vocab size
+    target: Tensor,  # (B*L,)
+    chunk_size: int = 4096,
+    ignore_index: int = -100,
+    tuned: bool = True,
+) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
+    """
+    Chunked forward pass for linear cross entropy.
+
+    Splits input along batch dimension, computes matmul and cross_entropy_fwd
+    for each chunk, stores dx for each chunk, and accumulates dw.
+
+    Returns:
+        loss: (B*L,) loss values
+        dx: (B*L, d) gradient w.r.t. input
+        dw: (V, d) gradient w.r.t. weight (accumulated across chunks except last)
+        last_dlogits_chunk: (chunk_len, V) gradient of last chunk's logits (for deferred dw computation)
+        last_x_chunk: (chunk_len, d) last chunk's input (for deferred dw computation)
+    """
+    B_L, d = x.shape
+    V, _ = weight.shape
+    device = x.device
+    num_chunks = (B_L + chunk_size - 1) // chunk_size
+    # Since we use gemm with TMA we require some alignment
+    assert chunk_size % 8 == 0, "chunk_size must be multiple of 8"
+    assert B_L % 8 == 0
+    # Pre-allocate outputs
+    loss = torch.empty(B_L, device=device, dtype=torch.float32)
+    logits_chunk_preallocated = torch.empty((chunk_size, V), device=device, dtype=x.dtype)
+    dx = torch.empty_like(x)
+    # Last chunk of dw will be deferred to the backward pass
+    dw = torch.empty_like(weight, dtype=torch.float32) if num_chunks > 1 else None
+    last_dlogits_chunk = None
+    last_x_chunk = None
+
+    # Process in chunks
+    for i, (x_chunk, target_chunk, loss_chunk, dx_chunk) in enumerate(
+        zip(*(t.split(chunk_size) for t in (x, target, loss, dx)))
+    ):
+        chunk_len = x_chunk.shape[0]
+        logits_chunk = logits_chunk_preallocated[:chunk_len]  # (chunk_len, V)
+        torch.mm(x_chunk, weight.mT, out=logits_chunk)
+        # Compute cross entropy forward with gradients
+        dlogits_chunk = logits_chunk  # inplace_backward
+        cross_entropy_fwd_out(
+            logits_chunk,
+            target_chunk,
+            None,  # target_logit
+            loss=loss_chunk,
+            lse=None,  # we don't need lse here
+            dx=dlogits_chunk,
+            ignore_index=ignore_index,
+        )
+        # Compute dx for this chunk: dlogits @ weight
+        torch.mm(dlogits_chunk, weight, out=dx_chunk)  # (chunk_len, d)
+        # Compute dw for all chunks except the last
+        if i == num_chunks - 1:
+            # Last chunk: save for backward pass
+            last_dlogits_chunk = dlogits_chunk
+            last_x_chunk = x_chunk
+        elif i == 0:
+            # First chunk: dw = dlogits.T @ x_chunk
+            gemm(dlogits_chunk.T, x_chunk, out=dw, tuned=tuned)
+        else:
+            # Middle chunks: dw += dlogits.T @ x_chunk
+            gemm_add_inplace(dlogits_chunk.T, x_chunk, dw, tuned=tuned)
+    return loss, dx, dw, last_dlogits_chunk, last_x_chunk
+
+
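
The dw bookkeeping above is a chunked outer-product reduction in which the final chunk's term is held back so the backward pass can fold it in together with the upstream gradient scaling. A minimal pure-PyTorch illustration of that identity, with toy shapes and no quack kernels involved (all names below are illustrative only):

import torch

B_L, d, V, chunk = 16, 8, 32, 4
x = torch.randn(B_L, d)
dlogits = torch.randn(B_L, V)  # stand-in for the per-chunk softmax gradients

x_chunks, dl_chunks = x.split(chunk), dlogits.split(chunk)
# Accumulate every chunk except the last, exactly like the loop above.
dw_partial = sum(dl.T @ xc for dl, xc in zip(dl_chunks[:-1], x_chunks[:-1]))
# The deferred last term is added later (and scaled by dloss) in the backward pass.
dw_full = dw_partial + dl_chunks[-1].T @ x_chunks[-1]
assert torch.allclose(dw_full, dlogits.T @ x, atol=1e-5)
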
+class ChunkedLinearCrossEntropyFunction(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd(device_type="cuda")
+    def forward(
+        ctx,
+        x: Tensor,
+        weight: Tensor,
+        target: Tensor,
+        ignore_index: int = -100,
+        reduction: Literal["mean", "sum"] = "mean",
+        chunk_size: int = 4096,
+        tuned: bool = True,
+    ):
+        """
+        Forward pass computes loss and stores dx and dw for backward.
+        """
+        ctx.weight_dtype = weight.dtype
+        x, weight = linear_fwd_convert_type(x, weight)
+        batch_shape = x.shape[:-1]
+        x = x.reshape(-1, x.shape[-1])
+        # TODO: don't need to compute bwd if neither x nor weight requires grad, or not training
+        loss, dx, dw, last_dlogits_chunk, last_x_chunk = chunked_linear_cross_entropy_fwd(
+            x, weight, target, chunk_size, ignore_index, tuned=tuned
+        )
+        loss_sum = loss.sum()
+        loss_scale = None if reduction == "sum" else 1.0 / (target != ignore_index).sum().float()
+        ctx.save_for_backward(dx, dw, last_dlogits_chunk, last_x_chunk, loss_scale)
+        ctx.batch_shape = batch_shape
+        ctx.ignore_index = ignore_index
+        ctx.reduction = reduction
+        ctx.tuned = tuned
+        return loss_sum if loss_scale is None else loss_sum * loss_scale
+
+    @staticmethod
+    @custom_bwd(device_type="cuda")
+    def backward(ctx, dloss):
+        """
+        Backward pass scales pre-computed gradients by dloss and completes
+        the last chunk's dw computation.
+        dloss is a scalar.
+        """
+        dx, dw, last_dlogits_chunk, last_x_chunk, loss_scale = ctx.saved_tensors
+        tuned = ctx.tuned
+        if loss_scale is not None:
+            dloss = dloss * loss_scale
+        # TODO: the case where x or weight doesn't require grad
+        dx.mul_(dloss)
+        dx = dx.reshape(*ctx.batch_shape, dx.shape[-1])
+        # Complete dw computation: dw = dloss * dw + dloss * (last_dlogits_chunk.T @ last_x_chunk)
+        if dw is None:
+            # Only had one chunk, compute dw directly with dloss scaling
+            dw = gemm(
+                last_dlogits_chunk.T,
+                last_x_chunk,
+                out_dtype=ctx.weight_dtype,
+                alpha=dloss,
+                tuned=tuned,
+            )
+        else:
+            # Add last chunk's contribution with dloss scaling
+            # dw = dloss * dw + dloss * (last_dlogits_chunk.T @ last_x_chunk)
+            # We use alpha=dloss, beta=dloss
+            if ctx.weight_dtype == dw.dtype:
+                gemm_add_inplace(
+                    last_dlogits_chunk.T, last_x_chunk, dw, alpha=dloss, beta=dloss, tuned=tuned
+                )
+            else:
+                dw = gemm_add(
+                    last_dlogits_chunk.T,
+                    last_x_chunk,
+                    dw,
+                    alpha=dloss,
+                    beta=dloss,
+                    out_dtype=ctx.weight_dtype,
+                    tuned=tuned,
+                )
+        return dx, dw, None, None, None, None, None
+
+
+def chunked_linear_cross_entropy(
+    x: Tensor,
+    weight: Tensor,
+    target: Tensor,
+    chunk_size: int = 4096,
+    ignore_index: int = -100,
+    reduction: Literal["mean", "sum"] = "mean",
+    tuned: bool = True,
+) -> Tensor:
+    """
+    Chunked linear cross entropy with automatic differentiation support.
+
+    Args:
+        x: Input tensor of shape (B*L, d)
+        weight: Weight tensor of shape (V, d)
+        target: Target indices of shape (B*L,)
+        chunk_size: Size of chunks to process
+        ignore_index: Index to ignore in loss computation
+        reduction: Type of reduction to apply
+        tuned: Whether to use tuned kernels
+
+    Returns:
+        Loss tensor with specified reduction
+    """
+    if reduction not in ["mean", "sum"]:
+        raise ValueError(f"Invalid reduction: {reduction}")
+    loss = ChunkedLinearCrossEntropyFunction.apply(
+        x, weight, target, ignore_index, reduction, chunk_size, tuned
+    )
+    return loss
+
+
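
Because the chunked autograd path precomputes dx and most of dw during the forward, a quick sanity check is to compare it against the eager reference defined at the top of the file. A minimal sketch, assuming a CUDA device with the quack GEMM kernels available and that the functions above are in scope (the new file's import path is not shown in this diff); sizes and tolerances are illustrative:

import torch

x = torch.randn(8192, 512, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16, requires_grad=True)
t = torch.randint(0, 1024, (8192,), device="cuda")

# Chunked path: loss is reduced with "mean" by default.
chunked_linear_cross_entropy(x, w, t, chunk_size=4096).backward()

# Eager reference on detached copies of the same leaves.
x_ref = x.detach().clone().requires_grad_()
w_ref = w.detach().clone().requires_grad_()
linear_cross_entropy_func_ref(x_ref, w_ref, None, t).backward()

# Expect close agreement up to bf16 accumulation differences.
print((x.grad - x_ref.grad).abs().max(), (w.grad - w_ref.grad).abs().max())
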
+class LinearCrossEntropy(nn.Linear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = False,
+        ignore_index: int = -100,
+        reduction: Literal["none", "mean", "sum"] = "mean",
+        chunk_size: Optional[int] = None,
+        inplace_backward: bool = False,
+        tuned: bool = True,
+        device=None,
+        dtype=None,
+    ) -> None:
+        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
+        self.ignore_index = ignore_index
+        self.reduction = reduction
+        self.chunk_size = chunk_size
+        self.inplace_backward = inplace_backward
+        self.tuned = tuned
+
+    def forward(self, input: Tensor, target: Tensor) -> Tensor:
+        if (
+            self.bias is None
+            and input.is_cuda
+            and input.stride(-1) == 1
+            and self.in_features % 8 == 0
+            and self.out_features % 8 == 0
+            and input.shape[:-1].numel() % 8 == 0
+            and self.chunk_size is not None
+            and self.chunk_size % 8 == 0
+            and self.reduction in ["mean", "sum"]
+        ):
+            return chunked_linear_cross_entropy(
+                input,
+                self.weight,
+                target,
+                chunk_size=self.chunk_size,
+                ignore_index=self.ignore_index,
+                reduction=self.reduction,
+                tuned=self.tuned,
+            )
+        else:
+            return linear_cross_entropy_func(
+                input,
+                self.weight,
+                self.bias,
+                target,
+                ignore_index=self.ignore_index,
+                reduction=self.reduction,
+                inplace_backward=self.inplace_backward,
+            )
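
The LinearCrossEntropy module above fuses the output projection with the loss so that the full (tokens, vocab) logits tensor is never materialized at once on the chunked path. A minimal usage sketch; the import path is assumed here since the diff does not show the new file's name, and shapes, dtype, and device are illustrative:

import torch
from quack.linear_cross_entropy import LinearCrossEntropy  # assumed module path

lm_head = LinearCrossEntropy(
    in_features=4096, out_features=32768, chunk_size=4096,
    device="cuda", dtype=torch.bfloat16,
)
hidden = torch.randn(8 * 512, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
targets = torch.randint(0, 32768, (8 * 512,), device="cuda")
loss = lm_head(hidden, targets)  # scalar; chunked path is taken since bias is None and dims are multiples of 8
loss.backward()                  # reuses the dx/dw partials computed during the forward
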
quack/mlp.py CHANGED
@@ -3,131 +3,31 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-from torch.amp import custom_fwd, custom_bwd
 
-from einops import rearrange
+from quack.linear import linear_act_func, act_linear_func
 
-from gemm_cublas import gemm as gemm_cb, gemm_add_ as gemm_add_cb_
-# from gemm_cublas.interface import gemm_tuned as gemm_cb, gemm_add_tuned_ as gemm_add_cb_
 
-from quack import gemm, gemm_swiglu, gemm_dswiglu  # TODO: implement these
+def mlp_func(x, weight1, weight2, activation: str, fuse_grad_accum=False, tuned=True):
+    preact, postact = linear_act_func(
+        x,
+        weight1,
+        activation,
+        store_preact=torch.is_grad_enabled(),
+        fuse_grad_accum=fuse_grad_accum,
+        tuned=tuned,
+    )
+    out = act_linear_func(
+        preact,
+        weight2,
+        postact,
+        activation=activation,
+        fuse_grad_accum=fuse_grad_accum,
+        tuned=tuned,
+    )
+    return out
 
 
-class MLPSwiGLUFunc(torch.autograd.Function):
-    @staticmethod
-    @custom_fwd(device_type="cuda")
-    def forward(ctx, x, weight1, weight2, fuse_grad_accum=False):
-        """
-        x: (..., in_features)
-        weight1: (2 * intermediate_features, in_features)
-        weight2: (out_features, intermediate_features)
-        out: (..., out_features)
-        Note that we do swiglu on the even and odd indices of the intermediate output,
-        i.e. silu(y[..., ::2]) * y[..., 1::2].
-        This is different from the usual swiglu implementation that does: y1, y2 = y.chunk(2, dim=-1); silu(y1) * y2
-        """
-        needs_weight1_grad = weight1.requires_grad
-        needs_weight2_grad = weight2.requires_grad
-        needs_input_grad = x.requires_grad
-        ctx.weight1_dtype = weight1.dtype
-        ctx.weight2_dtype = weight2.dtype
-        autocast_dtype = torch.get_autocast_dtype("cuda")
-        if torch.is_autocast_enabled():
-            x = x.to(dtype=autocast_dtype)
-        weight1_og = weight1
-        weight2_og = weight2
-        if torch.is_autocast_enabled():
-            weight1 = weight1.to(dtype=autocast_dtype)
-            weight2 = weight2.to(dtype=autocast_dtype)
-        batch_shape = x.shape[:-1]
-        x = x.reshape(-1, x.shape[-1])
-        # don't need preact if not computing gradient
-        store_preact = needs_input_grad or needs_weight1_grad or needs_weight2_grad
-        # (batch, inter_dim) & (batch, 2 * inter_dim)
-        y, preact = gemm_swiglu(x, weight1.T, store_preact=store_preact)
-        # out = F.linear(y, weight2)
-        out = gemm(y, weight2.T)
-        if not needs_input_grad:
-            weight1, weight1_og = None, None
-        if not needs_weight1_grad:
-            x = None
-        if not needs_input_grad and not needs_weight1_grad and not needs_weight2_grad:
-            weight2, weight2_og = None, None
-            preact = None
-        ctx.save_for_backward(
-            x,
-            preact,
-            weight1,
-            weight2,
-            *((weight1_og, weight2_og) if fuse_grad_accum else (None, None)),
-        )
-        ctx.fuse_grad_accum = fuse_grad_accum
-        return out.reshape(*batch_shape, out.shape[-1])
-
-    @staticmethod
-    @custom_bwd(device_type="cuda")
-    def backward(ctx, dout):
-        """
-        dout: (..., out_features)
-        """
-        if not torch.compiler.is_dynamo_compiling():
-            assert dout.stride(-1) == 1
-        # weight1_og and weight2_og are None if not ctx.fused_grad_accum
-        x, preact, weight1, weight2, weight1_og, weight2_og = ctx.saved_tensors
-        batch_shape = dout.shape[:-1]
-        dout = dout.reshape(-1, dout.shape[-1])
-        if (
-            not ctx.needs_input_grad[0]
-            and not ctx.needs_weight1_grad[0]
-            and not ctx.needs_weight2_grad[0]
-        ):
-            return (None,) * 4
-        assert preact is not None
-        # (batch, 2 * inter_dim) and (batch, inter_dim)
-        # dpreact, y = gemm_dswiglu(dout, weight2, preact)
-        dpreact, y = gemm_dswiglu(dout, weight2, preact, sm_carveout=16)
-        if ctx.needs_input_grad[2]:
-            # fuse_grad_accum is not compatible with torch.compile
-            if not ctx.fuse_grad_accum or weight2_og.grad is None or torch.compiler.is_compiling():
-                dweight2 = gemm_cb(dout.T, y, out_dtype=ctx.weight2_dtype)
-                # dweight2 = gemm_cb(dout.T, y, out_dtype=ctx.weight2_dtype, sm_carveout=16)
-            else:
-                # print("Using fuse grad accum in MLP 2", dout.shape, y.shape, weight2_og.grad.shape)
-                gemm_add_cb_(dout.T, y, weight2_og.grad)
-                # gemm_add_cb_(dout.T, y, weight2_og.grad, sm_carveout=16)
-                dweight2 = weight2_og.grad
-                weight2_og.grad = (
-                    None  # So that pytorch doesn't add dweight to weight2_og.grad again
-                )
-        else:
-            dweight2 = None
-        if ctx.needs_input_grad[0]:
-            dx = dpreact @ weight1  # (batch, in_features)
-            # dx = gemm(dpreact, weight1)  # (batch, in_features)
-            dx = dx.reshape(*batch_shape, dx.shape[-1])
-        else:
-            dx = None
-        if ctx.needs_input_grad[1]:
-            # fuse_grad_accum is not compatible with torch.compile
-            if not ctx.fuse_grad_accum or weight1_og.grad is None or torch.compiler.is_compiling():
-                dweight1 = gemm_cb(dpreact.T, x, out_dtype=ctx.weight1_dtype)
-            else:
-                # print("Using fuse grad accum in MLP 1", dpreact.shape, x.shape, weight1_og.grad.shape)
-                gemm_add_cb_(dpreact.T, x, weight1_og.grad)
-                dweight1 = weight1_og.grad
-                weight1_og.grad = (
-                    None  # So that pytorch doesn't add dweight to weight1_og.grad again
-                )
-        else:
-            dweight1 = None
-        return dx, dweight1, dweight2, None
-
-
-def mlp_swiglu_func(x, weight1, weight2, fuse_grad_accum=False):
-    return MLPSwiGLUFunc.apply(x, weight1, weight2, fuse_grad_accum)
-
-
-class MLPSwiGLU(nn.Module):
+class MLP(nn.Module):
     def __init__(
         self,
         in_features,
@@ -135,25 +35,21 @@ class MLPSwiGLU(nn.Module):
         out_features=None,
         bias1=False,
         bias2=False,
-        multiple_of=128,
+        activation="gelu",
         device=None,
         dtype=None,
         fuse_grad_accum: bool = False,
+        tuned: bool = True,
     ):
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
        out_features = out_features if out_features is not None else in_features
-        hidden_features = (
-            hidden_features if hidden_features is not None else int(8 * in_features / 3)
-        )
-        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
-        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
-        self.fc1.weight._muon_reshape_functions = (
-            lambda w: rearrange(w, "(d two) e -> two d e", two=2),
-            lambda w: rearrange(w, "two d e -> (d two) e"),
-        )
+        hidden_features = hidden_features if hidden_features is not None else 4 * in_features
+        self.activation = activation
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
         self.fuse_grad_accum = fuse_grad_accum
+        self.tuned = tuned
 
     def forward(self, input: Tensor) -> Tensor:
         if (
@@ -162,43 +58,17 @@ class MLPSwiGLU(nn.Module):
             and input.is_cuda
             and input.stride(-1) == 1
             and self.fc1.in_features % 8 == 0
-            and self.fc1.out_features % 16 == 0
+            and self.fc1.out_features % 8 == 0
             and self.fc2.out_features % 8 == 0
         ):
-            return mlp_swiglu_func(
+            return mlp_func(
                 input,
                 self.fc1.weight,
                 self.fc2.weight,
+                activation=self.activation,
                 fuse_grad_accum=self.fuse_grad_accum,
+                tuned=self.tuned,
             )
         else:
             y = self.fc1(input)
             return self.fc2(F.silu(y[..., ::2]) * y[..., 1::2])
-
-
-class MLPSwiGLURef(nn.Module):
-    def __init__(
-        self,
-        in_features,
-        hidden_features=None,
-        out_features=None,
-        bias1=False,
-        bias2=False,
-        multiple_of=128,
-        device=None,
-        dtype=None,
-    ):
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        out_features = out_features if out_features is not None else in_features
-        hidden_features = (
-            hidden_features if hidden_features is not None else int(8 * in_features / 3)
-        )
-        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
-        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
-
-    def forward(self, input: Tensor) -> Tensor:
-        y = self.fc1(input)
-        y1, y2 = y.chunk(2, dim=-1)
-        return self.fc2(F.silu(y1) * y2)
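
For comparison, a short sketch of how the reworked MLP module might be used (sizes are illustrative; assumes a CUDA device and a contiguous last dimension so the fused linear_act_func / act_linear_func path is taken):

import torch
from quack.mlp import MLP

mlp = MLP(in_features=1024, hidden_features=4096, activation="gelu",
          device="cuda", dtype=torch.bfloat16)
x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = mlp(x)           # fc1 -> activation -> fc2 via the fused quack kernels
y.sum().backward()   # gradients flow through linear_act_func / act_linear_func
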
quack/pipeline.py CHANGED
@@ -8,21 +8,6 @@ from cutlass.cutlass_dsl import Boolean, Int32, if_generate
 from cutlass.pipeline import CooperativeGroup, PipelineOp, pipeline_init_wait
 from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType
 
-from cutlass.cutlass_dsl import dsl_user_op
-from cutlass._mlir.dialects import nvvm
-
-
-@dsl_user_op
-def cp_async_mbarrier_arrive_shared(
-    mbar_ptr: cute.Pointer, noinc: bool = False, *, loc=None, ip=None
-) -> None:
-    nvvm.cp_async_mbarrier_arrive_shared(
-        mbar_ptr.llvm_ptr,
-        noinc=noinc,
-        loc=loc,
-        ip=ip,
-    )
-
 
 class PipelineStateWAdvance(PipelineState):
     def advance_iters(self, num_iterations: Int32):
@@ -65,7 +50,7 @@ def make_pipeline_state(type: PipelineUserType, stages: int):
 @dataclass(frozen=True)
 class PipelineTmaCpAsync(PipelineTmaAsync):
     """
-    PipelineTmaCpAsync is used for CpAync + TMA producers and AsyncThread consumers
+    PipelineTmaCpAsync is used for CpAsync + TMA producers and AsyncThread consumers
     """
 
     @staticmethod
@@ -163,4 +148,4 @@ class PipelineTmaCpAsync(PipelineTmaAsync):
         """
         We need the mbarrier to track the completion of cp.async
         """
-        cp_async_mbarrier_arrive_shared(self.producer_get_barrier(state), noinc=True)
+        cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state))