quack-kernels 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -3
- quack/activation.py +279 -0
- quack/autotuner.py +2 -1
- quack/cross_entropy.py +330 -184
- quack/cute_dsl_utils.py +83 -4
- quack/dense_gemm_sm100.py +1 -1
- quack/dense_gemm_sm90.py +911 -1140
- quack/fast_math.py +10 -27
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +43 -35
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +491 -243
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +6 -4
- quack/linear.py +128 -64
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +30 -160
- quack/pipeline.py +2 -17
- quack/reduce.py +240 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +614 -228
- quack/softmax.py +28 -16
- quack/symmetric_dense_gemm_sm90.py +6 -3
- quack/tensormap_manager.py +1 -0
- quack/tile_scheduler.py +64 -61
- quack/topk.py +14 -8
- quack/utils.py +14 -322
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/METADATA +3 -3
- quack_kernels-0.2.1.dist-info/RECORD +37 -0
- quack/lse.py +0 -62
- quack_kernels-0.1.11.dist-info/RECORD +0 -31
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/top_level.txt +0 -0
quack/linear_cross_entropy.py
ADDED
@@ -0,0 +1,275 @@
+# Copyright (c) 2025, Tri Dao
+from typing import Optional, Literal
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.amp import custom_fwd, custom_bwd
+
+from quack.cross_entropy import cross_entropy, cross_entropy_fwd_out
+from quack.gemm_interface import gemm, gemm_add, gemm_add_inplace
+from quack.linear import linear_fwd_convert_type
+
+
+def linear_cross_entropy_func(
+    x: Tensor,  # (..., d)
+    weight: Tensor,  # (V, d)
+    bias: Optional[Tensor],  # (V,) or None
+    target: Tensor,  # (...,), int or long
+    ignore_index: int = -100,
+    reduction: Literal["none", "mean", "sum"] = "mean",
+    inplace_backward: bool = False,
+) -> Tensor:
+    y = F.linear(x, weight, bias)  # (..., V)
+    return cross_entropy(
+        y, target, ignore_index=ignore_index, reduction=reduction, inplace_backward=inplace_backward
+    )
+
+
+def linear_cross_entropy_func_ref(
+    x: Tensor,  # (..., d)
+    weight: Tensor,  # (V, d)
+    bias: Optional[Tensor],  # (V,) or None
+    target: Tensor,  # (...,), int or long
+    ignore_index: int = -100,
+    reduction: Literal["none", "mean", "sum"] = "mean",
+) -> Tensor:
+    y = F.linear(x, weight, bias)  # (..., V)
+    return F.cross_entropy(y, target, ignore_index=ignore_index, reduction=reduction)
+
+
+def chunked_linear_cross_entropy_fwd(
+    x: Tensor,  # (B*L, d) where B is batch, L is seqlen
+    weight: Tensor,  # (V, d) where V is vocab size
+    target: Tensor,  # (B*L,)
+    chunk_size: int = 4096,
+    ignore_index: int = -100,
+    tuned: bool = True,
+) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
+    """
+    Chunked forward pass for linear cross entropy.
+
+    Splits input along batch dimension, computes matmul and cross_entropy_fwd
+    for each chunk, stores dx for each chunk, and accumulates dw.
+
+    Returns:
+        loss: (B*L,) loss values
+        dx: (B*L, d) gradient w.r.t. input
+        dw: (V, d) gradient w.r.t. weight (accumulated across chunks except last)
+        last_dlogits_chunk: (chunk_len, V) gradient of last chunk's logits (for deferred dw computation)
+        last_x_chunk: (chunk_len, d) last chunk's input (for deferred dw computation)
+    """
+    B_L, d = x.shape
+    V, _ = weight.shape
+    device = x.device
+    num_chunks = (B_L + chunk_size - 1) // chunk_size
+    # Since we use gemm with TMA we require some alignment
+    assert chunk_size % 8 == 0, "chunk_size must be multiple of 8"
+    assert B_L % 8 == 0
+    # Pre-allocate outputs
+    loss = torch.empty(B_L, device=device, dtype=torch.float32)
+    logits_chunk_preallocated = torch.empty((chunk_size, V), device=device, dtype=x.dtype)
+    dx = torch.empty_like(x)
+    # Last chunk of dw will be deferred to the backward pass
+    dw = torch.empty_like(weight, dtype=torch.float32) if num_chunks > 1 else None
+    last_dlogits_chunk = None
+    last_x_chunk = None
+
+    # Process in chunks
+    for i, (x_chunk, target_chunk, loss_chunk, dx_chunk) in enumerate(
+        zip(*(t.split(chunk_size) for t in (x, target, loss, dx)))
+    ):
+        chunk_len = x_chunk.shape[0]
+        logits_chunk = logits_chunk_preallocated[:chunk_len]  # (chunk_len, V)
+        torch.mm(x_chunk, weight.mT, out=logits_chunk)
+        # Compute cross entropy forward with gradients
+        dlogits_chunk = logits_chunk  # inplace_backward
+        cross_entropy_fwd_out(
+            logits_chunk,
+            target_chunk,
+            None,  # target_logit
+            loss=loss_chunk,
+            lse=None,  # we don't need lse here
+            dx=dlogits_chunk,
+            ignore_index=ignore_index,
+        )
+        # Compute dx for this chunk: dlogits @ weight
+        torch.mm(dlogits_chunk, weight, out=dx_chunk)  # (chunk_len, d)
+        # Compute dw for all chunks except the last
+        if i == num_chunks - 1:
+            # Last chunk: save for backward pass
+            last_dlogits_chunk = dlogits_chunk
+            last_x_chunk = x_chunk
+        elif i == 0:
+            # First chunk: dw = dlogits.T @ x_chunk
+            gemm(dlogits_chunk.T, x_chunk, out=dw, tuned=tuned)
+        else:
+            # Middle chunks: dw += dlogits.T @ x_chunk
+            gemm_add_inplace(dlogits_chunk.T, x_chunk, dw, tuned=tuned)
+    return loss, dx, dw, last_dlogits_chunk, last_x_chunk
+
+
+class ChunkedLinearCrossEntropyFunction(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd(device_type="cuda")
+    def forward(
+        ctx,
+        x: Tensor,
+        weight: Tensor,
+        target: Tensor,
+        ignore_index: int = -100,
+        reduction: Literal["mean", "sum"] = "mean",
+        chunk_size: int = 4096,
+        tuned: bool = True,
+    ):
+        """
+        Forward pass computes loss and stores dx and dw for backward.
+        """
+        ctx.weight_dtype = weight.dtype
+        x, weight = linear_fwd_convert_type(x, weight)
+        batch_shape = x.shape[:-1]
+        x = x.reshape(-1, x.shape[-1])
+        # TODO: don't need to compute bwd if neither x nor weight requires grad, or not training
+        loss, dx, dw, last_dlogits_chunk, last_x_chunk = chunked_linear_cross_entropy_fwd(
+            x, weight, target, chunk_size, ignore_index, tuned=tuned
+        )
+        loss_sum = loss.sum()
+        loss_scale = None if reduction == "sum" else 1.0 / (target != ignore_index).sum().float()
+        ctx.save_for_backward(dx, dw, last_dlogits_chunk, last_x_chunk, loss_scale)
+        ctx.batch_shape = batch_shape
+        ctx.ignore_index = ignore_index
+        ctx.reduction = reduction
+        ctx.tuned = tuned
+        return loss_sum if loss_scale is None else loss_sum * loss_scale
+
+    @staticmethod
+    @custom_bwd(device_type="cuda")
+    def backward(ctx, dloss):
+        """
+        Backward pass scales pre-computed gradients by dloss and completes
+        the last chunk's dw computation.
+        dloss is a scalar.
+        """
+        dx, dw, last_dlogits_chunk, last_x_chunk, loss_scale = ctx.saved_tensors
+        tuned = ctx.tuned
+        if loss_scale is not None:
+            dloss = dloss * loss_scale
+        # TODO: the case where x or weight doesn't require grad
+        dx.mul_(dloss)
+        dx = dx.reshape(*ctx.batch_shape, dx.shape[-1])
+        # Complete dw computation: dw = dloss * dw + dloss * (last_dlogits_chunk.T @ last_x_chunk)
+        if dw is None:
+            # Only had one chunk, compute dw directly with dloss scaling
+            dw = gemm(
+                last_dlogits_chunk.T,
+                last_x_chunk,
+                out_dtype=ctx.weight_dtype,
+                alpha=dloss,
+                tuned=tuned,
+            )
+        else:
+            # Add last chunk's contribution with dloss scaling
+            # dw = dloss * dw + dloss * (last_dlogits_chunk.T @ last_x_chunk)
+            # We use alpha=dloss, beta=dloss
+            if ctx.weight_dtype == dw.dtype:
+                gemm_add_inplace(
+                    last_dlogits_chunk.T, last_x_chunk, dw, alpha=dloss, beta=dloss, tuned=tuned
+                )
+            else:
+                dw = gemm_add(
+                    last_dlogits_chunk.T,
+                    last_x_chunk,
+                    dw,
+                    alpha=dloss,
+                    beta=dloss,
+                    out_dtype=ctx.weight_dtype,
+                    tuned=tuned,
+                )
+        return dx, dw, None, None, None, None, None
+
+
+def chunked_linear_cross_entropy(
+    x: Tensor,
+    weight: Tensor,
+    target: Tensor,
+    chunk_size: int = 4096,
+    ignore_index: int = -100,
+    reduction: Literal["mean", "sum"] = "mean",
+    tuned: bool = True,
+) -> Tensor:
+    """
+    Chunked linear cross entropy with automatic differentiation support.
+
+    Args:
+        x: Input tensor of shape (B*L, d)
+        weight: Weight tensor of shape (V, d)
+        target: Target indices of shape (B*L,)
+        chunk_size: Size of chunks to process
+        ignore_index: Index to ignore in loss computation
+        reduction: Type of reduction to apply
+        tuned: Whether to use tuned kernels
+
+    Returns:
+        Loss tensor with specified reduction
+    """
+    if reduction not in ["mean", "sum"]:
+        raise ValueError(f"Invalid reduction: {reduction}")
+    loss = ChunkedLinearCrossEntropyFunction.apply(
+        x, weight, target, ignore_index, reduction, chunk_size, tuned
+    )
+    return loss
+
+
+class LinearCrossEntropy(nn.Linear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = False,
+        ignore_index: int = -100,
+        reduction: Literal["none", "mean", "sum"] = "mean",
+        chunk_size: Optional[int] = None,
+        inplace_backward: bool = False,
+        tuned: bool = True,
+        device=None,
+        dtype=None,
+    ) -> None:
+        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
+        self.ignore_index = ignore_index
+        self.reduction = reduction
+        self.chunk_size = chunk_size
+        self.inplace_backward = inplace_backward
+        self.tuned = tuned
+
+    def forward(self, input: Tensor, target: Tensor) -> Tensor:
+        if (
+            self.bias is None
+            and input.is_cuda
+            and input.stride(-1) == 1
+            and self.in_features % 8 == 0
+            and self.out_features % 8 == 0
+            and input.shape[:-1].numel() % 8 == 0
+            and self.chunk_size is not None
+            and self.chunk_size % 8 == 0
+            and self.reduction in ["mean", "sum"]
+        ):
+            return chunked_linear_cross_entropy(
+                input,
+                self.weight,
+                target,
+                chunk_size=self.chunk_size,
+                ignore_index=self.ignore_index,
+                reduction=self.reduction,
+                tuned=self.tuned,
+            )
+        else:
+            return linear_cross_entropy_func(
+                input,
+                self.weight,
+                self.bias,
+                target,
+                ignore_index=self.ignore_index,
+                reduction=self.reduction,
+                inplace_backward=self.inplace_backward,
+            )
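A minimal usage sketch of the new API (not part of the diff itself; sizes, device, and dtype are illustrative assumptions, and it presumes quack-kernels 0.2.1 is installed on a supported CUDA GPU):

import torch
from quack.linear_cross_entropy import LinearCrossEntropy

# Hypothetical sizes: hidden dim d, vocab V, flattened tokens B*L, all multiples of 8
# so the chunked fast path in LinearCrossEntropy.forward can be taken.
d, V, num_tokens = 4096, 32768, 8192
x = torch.randn(num_tokens, d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, V, (num_tokens,), device="cuda")

lm_head = LinearCrossEntropy(d, V, bias=False, chunk_size=4096, device="cuda", dtype=torch.bfloat16)
loss = lm_head(x, target)   # logits are materialized one (chunk_size, V) chunk at a time
loss.backward()             # dx and all but the last chunk of dw were precomputed in forward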
quack/mlp.py
CHANGED
@@ -3,131 +3,31 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-from torch.amp import custom_fwd, custom_bwd
 
-from
+from quack.linear import linear_act_func, act_linear_func
 
-from gemm_cublas import gemm as gemm_cb, gemm_add_ as gemm_add_cb_
-# from gemm_cublas.interface import gemm_tuned as gemm_cb, gemm_add_tuned_ as gemm_add_cb_
 
-
+def mlp_func(x, weight1, weight2, activation: str, fuse_grad_accum=False, tuned=True):
+    preact, postact = linear_act_func(
+        x,
+        weight1,
+        activation,
+        store_preact=torch.is_grad_enabled(),
+        fuse_grad_accum=fuse_grad_accum,
+        tuned=tuned,
+    )
+    out = act_linear_func(
+        preact,
+        weight2,
+        postact,
+        activation=activation,
+        fuse_grad_accum=fuse_grad_accum,
+        tuned=tuned,
+    )
+    return out
 
 
-class MLPSwiGLUFunc(torch.autograd.Function):
-    @staticmethod
-    @custom_fwd(device_type="cuda")
-    def forward(ctx, x, weight1, weight2, fuse_grad_accum=False):
-        """
-        x: (..., in_features)
-        weight1: (2 * intermediate_features, in_features)
-        weight2: (out_features, intermediate_features)
-        out: (..., out_features)
-        Note that we do swiglu on the even and odd indices of the intermediate output,
-        i.e. silu(y[..., ::2]) * y[..., 1::2].
-        This is different from the usual swiglu implementation that does: y1, y2 = y.chunk(2, dim=-1); silu(y1) * y2
-        """
-        needs_weight1_grad = weight1.requires_grad
-        needs_weight2_grad = weight2.requires_grad
-        needs_input_grad = x.requires_grad
-        ctx.weight1_dtype = weight1.dtype
-        ctx.weight2_dtype = weight2.dtype
-        autocast_dtype = torch.get_autocast_dtype("cuda")
-        if torch.is_autocast_enabled():
-            x = x.to(dtype=autocast_dtype)
-        weight1_og = weight1
-        weight2_og = weight2
-        if torch.is_autocast_enabled():
-            weight1 = weight1.to(dtype=autocast_dtype)
-            weight2 = weight2.to(dtype=autocast_dtype)
-        batch_shape = x.shape[:-1]
-        x = x.reshape(-1, x.shape[-1])
-        # don't need preact if not computing gradient
-        store_preact = needs_input_grad or needs_weight1_grad or needs_weight2_grad
-        # (batch, inter_dim) & (batch, 2 * inter_dim)
-        y, preact = gemm_swiglu(x, weight1.T, store_preact=store_preact)
-        # out = F.linear(y, weight2)
-        out = gemm(y, weight2.T)
-        if not needs_input_grad:
-            weight1, weight1_og = None, None
-        if not needs_weight1_grad:
-            x = None
-        if not needs_input_grad and not needs_weight1_grad and not needs_weight2_grad:
-            weight2, weight2_og = None, None
-            preact = None
-        ctx.save_for_backward(
-            x,
-            preact,
-            weight1,
-            weight2,
-            *((weight1_og, weight2_og) if fuse_grad_accum else (None, None)),
-        )
-        ctx.fuse_grad_accum = fuse_grad_accum
-        return out.reshape(*batch_shape, out.shape[-1])
-
-    @staticmethod
-    @custom_bwd(device_type="cuda")
-    def backward(ctx, dout):
-        """
-        dout: (..., out_features)
-        """
-        if not torch.compiler.is_dynamo_compiling():
-            assert dout.stride(-1) == 1
-        # weight1_og and weight2_og are None if not ctx.fused_grad_accum
-        x, preact, weight1, weight2, weight1_og, weight2_og = ctx.saved_tensors
-        batch_shape = dout.shape[:-1]
-        dout = dout.reshape(-1, dout.shape[-1])
-        if (
-            not ctx.needs_input_grad[0]
-            and not ctx.needs_weight1_grad[0]
-            and not ctx.needs_weight2_grad[0]
-        ):
-            return (None,) * 4
-        assert preact is not None
-        # (batch, 2 * inter_dim) and (batch, inter_dim)
-        # dpreact, y = gemm_dswiglu(dout, weight2, preact)
-        dpreact, y = gemm_dswiglu(dout, weight2, preact, sm_carveout=16)
-        if ctx.needs_input_grad[2]:
-            # fuse_grad_accum is not compatible with torch.compile
-            if not ctx.fuse_grad_accum or weight2_og.grad is None or torch.compiler.is_compiling():
-                dweight2 = gemm_cb(dout.T, y, out_dtype=ctx.weight2_dtype)
-                # dweight2 = gemm_cb(dout.T, y, out_dtype=ctx.weight2_dtype, sm_carveout=16)
-            else:
-                # print("Using fuse grad accum in MLP 2", dout.shape, y.shape, weight2_og.grad.shape)
-                gemm_add_cb_(dout.T, y, weight2_og.grad)
-                # gemm_add_cb_(dout.T, y, weight2_og.grad, sm_carveout=16)
-                dweight2 = weight2_og.grad
-                weight2_og.grad = (
-                    None  # So that pytorch doesn't add dweight to weight2_og.grad again
-                )
-        else:
-            dweight2 = None
-        if ctx.needs_input_grad[0]:
-            dx = dpreact @ weight1  # (batch, in_features)
-            # dx = gemm(dpreact, weight1)  # (batch, in_features)
-            dx = dx.reshape(*batch_shape, dx.shape[-1])
-        else:
-            dx = None
-        if ctx.needs_input_grad[1]:
-            # fuse_grad_accum is not compatible with torch.compile
-            if not ctx.fuse_grad_accum or weight1_og.grad is None or torch.compiler.is_compiling():
-                dweight1 = gemm_cb(dpreact.T, x, out_dtype=ctx.weight1_dtype)
-            else:
-                # print("Using fuse grad accum in MLP 1", dpreact.shape, x.shape, weight1_og.grad.shape)
-                gemm_add_cb_(dpreact.T, x, weight1_og.grad)
-                dweight1 = weight1_og.grad
-                weight1_og.grad = (
-                    None  # So that pytorch doesn't add dweight to weight1_og.grad again
-                )
-        else:
-            dweight1 = None
-        return dx, dweight1, dweight2, None
-
-
-def mlp_swiglu_func(x, weight1, weight2, fuse_grad_accum=False):
-    return MLPSwiGLUFunc.apply(x, weight1, weight2, fuse_grad_accum)
-
-
-class MLPSwiGLU(nn.Module):
+class MLP(nn.Module):
     def __init__(
         self,
         in_features,
@@ -135,25 +35,21 @@ class MLPSwiGLU(nn.Module):
         out_features=None,
         bias1=False,
         bias2=False,
-        multiple_of=128,
+        activation="gelu",
         device=None,
         dtype=None,
         fuse_grad_accum: bool = False,
+        tuned: bool = True,
     ):
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         out_features = out_features if out_features is not None else in_features
-        hidden_features = (
-            hidden_features if hidden_features is not None else int(8 * in_features / 3)
-        )
-        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
-        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
-        self.fc1.weight._muon_reshape_functions = (
-            lambda w: rearrange(w, "(d two) e -> two d e", two=2),
-            lambda w: rearrange(w, "two d e -> (d two) e"),
-        )
+        hidden_features = hidden_features if hidden_features is not None else 4 * in_features
+        self.activation = activation
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
         self.fuse_grad_accum = fuse_grad_accum
+        self.tuned = tuned
 
     def forward(self, input: Tensor) -> Tensor:
         if (
@@ -162,43 +58,17 @@ class MLPSwiGLU(nn.Module):
             and input.is_cuda
            and input.stride(-1) == 1
             and self.fc1.in_features % 8 == 0
-            and self.fc1.out_features %
+            and self.fc1.out_features % 8 == 0
             and self.fc2.out_features % 8 == 0
         ):
-            return mlp_swiglu_func(
+            return mlp_func(
                 input,
                 self.fc1.weight,
                 self.fc2.weight,
+                activation=self.activation,
                 fuse_grad_accum=self.fuse_grad_accum,
+                tuned=self.tuned,
             )
         else:
             y = self.fc1(input)
             return self.fc2(F.silu(y[..., ::2]) * y[..., 1::2])
-
-
-class MLPSwiGLURef(nn.Module):
-    def __init__(
-        self,
-        in_features,
-        hidden_features=None,
-        out_features=None,
-        bias1=False,
-        bias2=False,
-        multiple_of=128,
-        device=None,
-        dtype=None,
-    ):
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        out_features = out_features if out_features is not None else in_features
-        hidden_features = (
-            hidden_features if hidden_features is not None else int(8 * in_features / 3)
-        )
-        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
-        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
-
-    def forward(self, input: Tensor) -> Tensor:
-        y = self.fc1(input)
-        y1, y2 = y.chunk(2, dim=-1)
-        return self.fc2(F.silu(y1) * y2)
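A rough usage sketch of the reworked MLP module (illustrative only, not from the diff; sizes and dtype are assumptions, and the hidden_features keyword is inferred from its use in __init__ above):

import torch
from quack.mlp import MLP

# Hypothetical sizes; all feature dims are multiples of 8 so the fused
# linear_act_func / act_linear_func path in mlp_func can be taken, subject to
# the remaining checks in forward().
mlp = MLP(4096, hidden_features=16384, activation="gelu", device="cuda", dtype=torch.bfloat16)
x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = mlp(x)
out.sum().backward()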
quack/pipeline.py
CHANGED
@@ -8,21 +8,6 @@ from cutlass.cutlass_dsl import Boolean, Int32, if_generate
 from cutlass.pipeline import CooperativeGroup, PipelineOp, pipeline_init_wait
 from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType
 
-from cutlass.cutlass_dsl import dsl_user_op
-from cutlass._mlir.dialects import nvvm
-
-
-@dsl_user_op
-def cp_async_mbarrier_arrive_shared(
-    mbar_ptr: cute.Pointer, noinc: bool = False, *, loc=None, ip=None
-) -> None:
-    nvvm.cp_async_mbarrier_arrive_shared(
-        mbar_ptr.llvm_ptr,
-        noinc=noinc,
-        loc=loc,
-        ip=ip,
-    )
-
 
 class PipelineStateWAdvance(PipelineState):
     def advance_iters(self, num_iterations: Int32):
@@ -65,7 +50,7 @@ def make_pipeline_state(type: PipelineUserType, stages: int):
 @dataclass(frozen=True)
 class PipelineTmaCpAsync(PipelineTmaAsync):
     """
-    PipelineTmaCpAsync is used for
+    PipelineTmaCpAsync is used for CpAsync + TMA producers and AsyncThread consumers
     """
 
     @staticmethod
@@ -163,4 +148,4 @@ class PipelineTmaCpAsync(PipelineTmaAsync):
         """
         We need the mbarrier to track the completion of cp.async
         """
-
+        cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state))