quack-kernels 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff shows the content changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- quack/__init__.py +4 -1
- quack/autotuner.py +309 -0
- quack/cross_entropy.py +2 -5
- quack/cute_dsl_utils.py +40 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +2474 -0
- quack/fast_math.py +97 -0
- quack/gemm_config.py +61 -0
- quack/gemm_interface.py +321 -0
- quack/linear.py +176 -0
- quack/lse.py +62 -0
- quack/mlp.py +204 -0
- quack/pipeline.py +166 -0
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2088 -0
- quack/tensormap_manager.py +114 -0
- quack/tile_scheduler.py +935 -0
- quack/topk.py +221 -0
- quack/utils.py +237 -19
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/METADATA +3 -3
- quack_kernels-0.1.11.dist-info/RECORD +31 -0
- quack_kernels-0.1.9.dist-info/RECORD +0 -12
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.9.dist-info → quack_kernels-0.1.11.dist-info}/top_level.txt +0 -0
quack/fast_math.py
ADDED
@@ -0,0 +1,97 @@
+# Copyright (c) 2025, Tri Dao.
+
+from typing import Tuple
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Int32, Uint32
+from cutlass.cutlass_dsl import T, dsl_user_op
+from cutlass._mlir.dialects import llvm
+
+
+@cute.jit
+def clz(x: Int32) -> Int32:
+    # for i in cutlass.range_constexpr(32):
+    #     if (1 << (31 - i)) & x:
+    #         return Int32(i)
+    # return Int32(32)
+    # Early exit is not supported yet
+    res = Int32(32)
+    done = False
+    for i in cutlass.range(32):
+        if ((1 << (31 - i)) & x) and not done:
+            res = Int32(i)
+            done = True
+    return res
+
+
+def find_log2(x: Int32) -> Int32:
+    a: Int32 = Int32(31 - clz(x))
+    return a + ((x & (x - 1)) != 0)  # Round up, add 1 if not a power of 2.
+
+
+@dsl_user_op
+def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32:
+    return Uint32(
+        llvm.inline_asm(
+            T.i32(),
+            [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)],
+            "mul.hi.u32 $0, $1, $2;",
+            "=r,r,r",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+        )
+    )
+
+
+class FastDivmod:
+    def __init__(
+        self, divisor: Int32, multipler: Uint32, shift_right: Uint32, *, loc=None, ip=None
+    ):
+        self.divisor = divisor
+        self.multiplier = multipler
+        self.shift_right = shift_right
+        self._loc = loc
+
+    # called by host
+    @staticmethod
+    def create(divisor: Int32, *, loc=None, ip=None) -> "FastDivmod":
+        """Construct the FastDivmod object, in host code.
+        This precomputes some values based on the divisor and is computationally expensive.
+        """
+        p = Uint32(31 + find_log2(divisor))
+        divisor_u32 = Uint32(divisor)
+        multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32)
+        shift_right = Uint32(p - 32)
+        return FastDivmod(divisor, multiplier, shift_right, loc=loc, ip=ip)
+
+    @cute.jit
+    def div(self, dividend: Int32) -> Int32:
+        return (
+            Int32(umulhi(dividend, self.multiplier) >> self.shift_right)
+            if self.divisor != 1
+            else dividend
+        )
+
+    def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]:
+        quotient = self.div(dividend)
+        remainder = dividend - quotient * self.divisor
+        return quotient, remainder
+
+    def __extract_mlir_values__(self):
+        values, self._values_pos = [], []
+        for obj in [self.divisor, self.multiplier, self.shift_right]:
+            obj_values = cutlass.extract_mlir_values(obj)
+            values += obj_values
+            self._values_pos.append(len(obj_values))
+        return values
+
+    def __new_from_mlir_values__(self, values):
+        obj_list = []
+        for obj, n_items in zip(
+            [self.divisor, self.multiplier, self.shift_right], self._values_pos
+        ):
+            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
+            values = values[n_items:]
+        return FastDivmod(*(tuple(obj_list)), loc=self._loc)
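The `FastDivmod` class above is the standard CUTLASS fast-divmod trick: division by a runtime-constant divisor is replaced by a 32-bit multiply-high plus a shift, with `multiplier = ceil(2^p / divisor)`, `p = 31 + ceil(log2(divisor))`, and `shift_right = p - 32`. Below is a minimal host-side sketch of the same arithmetic in plain Python (no CUTLASS DSL; the helper names are illustrative, not part of the quack API), checked against Python's built-in `divmod`:

```python
# Plain-Python sketch of the FastDivmod arithmetic above (illustrative only).
def make_fast_divmod(divisor: int):
    assert divisor >= 1
    log2_ceil = (divisor - 1).bit_length()            # ceil(log2(divisor)) for divisor >= 1
    p = 31 + log2_ceil
    multiplier = ((1 << p) + divisor - 1) // divisor  # ceil(2^p / divisor), fits in a uint32
    shift_right = p - 32

    def fast_divmod(dividend: int):
        if divisor == 1:
            return dividend, 0
        # umulhi(dividend, multiplier) is the upper 32 bits of the 32x32-bit product
        quotient = ((dividend * multiplier) >> 32) >> shift_right
        return quotient, dividend - quotient * divisor

    return fast_divmod


# Sanity check against Python's divmod for non-negative int32 dividends.
for d in (1, 3, 7, 96, 208):
    fd = make_fast_divmod(d)
    for n in (0, 1, 2, 10, 12345, 2**31 - 1):
        assert fd(n) == divmod(n, d)
```

On the device side, `div` then costs one `mul.hi.u32` and one shift instead of an integer division, which is why `create` precomputes the multiplier and shift on the host.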
quack/gemm_config.py
ADDED
@@ -0,0 +1,61 @@
+# Copyright (C) 2025, Tri Dao.
+import itertools
+from typing import Optional
+from pydantic import BaseModel
+
+
+class GemmConfig(BaseModel, frozen=True):
+    tile_m: int = 256
+    tile_n: int = 128
+    cluster_m: int = 2
+    cluster_n: int = 1
+    swap_ab: bool = False
+    pingpong: bool = False
+    raster_order: int = 2
+    max_swizzle_size: int = 1
+
+
+def get_all_configs(
+    epilogue: Optional[str],
+    tune_pingpong=True,
+    tune_raster_order=True,
+) -> list[GemmConfig]:
+    tile_n_vals = [128, 144, 160, 176, 192, 208]
+    tile_mn_vals = [(256, tile_n) for tile_n in tile_n_vals]
+    if epilogue in ["swiglu"]:
+        tile_mn_vals = [(m, n) for m, n in tile_mn_vals if n % 32 == 0]
+    cluster = [(1, 1), (1, 2), (2, 1)]
+    # cluster = [(1, 2), (2, 1)]
+    if epilogue in ["lse"]:
+        cluster = [(1, 2), (2, 1)]
+    swap_ab_vals = [False, True]
+    if epilogue in ["lse", "swiglu"]:
+        swap_ab_vals = [False]
+    pingpong_vals = [False, True] if tune_pingpong else [False]
+    raster_swizzle = (
+        [(0, 1)]
+        if not tune_raster_order
+        else [(1, 1), (1, 2), (1, 4), (1, 8), (2, 1), (2, 2), (2, 4), (2, 8)]
+    )
+    return [
+        GemmConfig(
+            tile_m=tile_m if not pingpong else 128,
+            tile_n=tile_n,
+            cluster_m=cluster_m,
+            cluster_n=cluster_n,
+            swap_ab=swap_ab,
+            pingpong=pingpong,
+            raster_order=raster_order,
+            max_swizzle_size=max_swizzle_size,
+        )
+        for (tile_m, tile_n), (cluster_m, cluster_n), swap_ab, pingpong, (
+            raster_order,
+            max_swizzle_size,
+        ) in itertools.product(
+            tile_mn_vals,
+            cluster,
+            swap_ab_vals,
+            pingpong_vals,
+            raster_swizzle,
+        )
+    ]
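A quick usage sketch (assuming the `quack` package from this wheel is importable): enumerating the candidate configurations that the autotuner sweeps for a plain GEMM versus an `lse` epilogue. Since `GemmConfig` is a frozen pydantic model, each config is hashable and can be used directly as a cache key.

```python
from quack.gemm_config import GemmConfig, get_all_configs

plain = get_all_configs(epilogue=None)  # full sweep: tile_n x cluster x swap_ab x pingpong x raster/swizzle
lse = get_all_configs("lse")            # restricted sweep: swap_ab disabled, fewer cluster shapes
print(len(plain), len(lse))

default = GemmConfig()                  # tile 256x128, cluster (2, 1), no pingpong, raster_order=2
print(default in set(plain))            # frozen=True makes configs hashable
```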
quack/gemm_interface.py
ADDED
@@ -0,0 +1,321 @@
+# Copyright (c) 2025, Tri Dao
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from quack.gemm_config import GemmConfig, get_all_configs
+
+from quack.autotuner import autotune, AutotuneConfig
+from quack.lse import logsumexp
+
+
+def gemm_swiglu_out_ref(
+    A: Tensor, B: Tensor, out: Optional[Tensor], store_preact: bool
+) -> (Tensor, Tensor):
+    preact = torch.mm(A, B)
+    out_ = F.silu(preact[..., ::2]) * preact[..., 1::2]
+    if out is not None:
+        out.copy_(out_)
+    else:
+        out = out_
+    if not store_preact:
+        preact = None
+    return out, preact
+
+
+@autotune(
+    configs=[AutotuneConfig(config=c) for c in get_all_configs(epilogue=None)], key=["sm_carveout"]
+)
+def gemm_tuned(
+    A: Tensor,
+    B: Tensor,
+    sm_carveout: int = 0,
+    config: Optional[GemmConfig] = None,
+) -> (Tensor, Optional[Tensor]):
+    if config is None:
+        config = GemmConfig(
+            tile_m=256,
+            tile_n=192,
+            cluster_m=2,
+            cluster_n=1,
+            pingpong=False,
+            raster_order=2,
+            max_swizzle_size=1,
+        )
+    out = torch.ops.quack.gemm_impl.default(
+        A if not config.swap_ab else B.T,
+        B if not config.swap_ab else A.T,
+        sm_carveout,
+        config.tile_m,
+        config.tile_n,
+        config.cluster_m,
+        config.cluster_n,
+        not config.swap_ab,  # C_rowmajor
+        config.pingpong,
+        config.raster_order,
+        config.max_swizzle_size,
+    )
+    return out if not config.swap_ab else out.T
+
+
+@torch.library.custom_op("quack::gemm", mutates_args=(), device_types="cuda")
+def gemm(A: Tensor, B: Tensor, sm_carveout: int = 0) -> Tensor:
+    return gemm_tuned(A, B, sm_carveout)
+
+
+@torch.library.register_fake("quack::gemm")
+def gemm_ref(A: Tensor, B: Tensor, sm_carveout: int = 0) -> Tensor:
+    return torch.mm(A, B)
+
+
+@autotune(configs=[AutotuneConfig(config=c) for c in get_all_configs("add")])
+def gemm_add_tuned(
+    A: Tensor,
+    B: Tensor,
+    C: Tensor,
+    config: Optional[GemmConfig] = None,
+) -> (Tensor, Optional[Tensor]):
+    if config is None:
+        config = GemmConfig(
+            tile_m=256,
+            tile_n=192,
+            cluster_m=2,
+            cluster_n=1,
+            pingpong=False,
+            raster_order=2,
+            max_swizzle_size=1,
+        )
+    out = torch.ops.quack.gemm_add_impl.default(
+        A if not config.swap_ab else B.T,
+        B if not config.swap_ab else A.T,
+        C if not config.swap_ab else C.T,
+        config.tile_m,
+        config.tile_n,
+        config.cluster_m,
+        config.cluster_n,
+        config.pingpong,
+        config.raster_order,
+        config.max_swizzle_size,
+    )
+    return out if not config.swap_ab else out.T
+
+
+@torch.library.custom_op("quack::gemm_add", mutates_args=(), device_types="cuda")
+def gemm_add(A: Tensor, B: Tensor, C: Tensor) -> Tensor:
+    return gemm_add_tuned(A, B, C)
+
+
+@torch.library.register_fake("quack::gemm_add")
+def gemm_add_ref(A: Tensor, B: Tensor, C: Tensor) -> Tensor:
+    return C + torch.mm(A, B)
+
+
+@torch.library.custom_op("quack::gemm_add_t", mutates_args=(), device_types="cuda")
+def gemm_t_add(A: Tensor, B: Tensor, C: Tensor) -> Tensor:
+    return gemm_add_tuned(A, B.T, C)
+
+
+@torch.library.register_fake("quack::gemm_add_t")
+def gemm_t_add_ref(A: Tensor, B: Tensor, C: Tensor) -> Tensor:
+    return gemm_add_ref(A, B.T, C)
+
+
+@autotune(
+    configs=[AutotuneConfig(config=c) for c in get_all_configs("swiglu")], key=["store_preact"]
+)
+def gemm_swiglu_tuned(
+    A: Tensor,
+    B: Tensor,
+    store_preact: bool = True,
+    config: Optional[GemmConfig] = None,
+) -> (Tensor, Optional[Tensor]):
+    if config is None:
+        config = GemmConfig(
+            tile_m=256,
+            tile_n=192,
+            cluster_m=2,
+            cluster_n=1,
+            pingpong=False,
+            raster_order=2,
+            max_swizzle_size=1,
+        )
+    # out, preact
+    return torch.ops.quack.gemm_swiglu_impl.default(
+        A,
+        B,
+        store_preact,
+        config.tile_m,
+        config.tile_n,
+        config.cluster_m,
+        config.cluster_n,
+        config.pingpong,
+        config.raster_order,
+        config.max_swizzle_size,
+    )
+
+
+# Specifying the schema manually here since torch.library._infer_schema doesn't work when return
+# type is a tuple of Tensor
+@torch.library.custom_op(
+    "quack::gemm_swiglu",
+    mutates_args=(),
+    device_types="cuda",
+    schema="(Tensor A, Tensor B, bool store_preact) -> (Tensor, Tensor)",
+)
+def gemm_swiglu(A: Tensor, B: Tensor, store_preact: bool = True) -> (Tensor, Tensor):
+    return gemm_swiglu_tuned(A, B, store_preact=store_preact)
+
+
+@torch.library.register_fake("quack::gemm_swiglu")
+def gemm_swiglu_ref(A: Tensor, B: Tensor, store_preact: bool) -> (Tensor, Tensor):
+    return gemm_swiglu_out_ref(A, B, None, store_preact)
+
+
+# @torch.library.custom_op("quack::gemm_swiglu_t", mutates_args=(), device_types="cuda",
+# schema="(Tensor A, Tensor B, bool store_preact) -> (Tensor, Tensor)")
+# def gemm_swiglu_t(A: Tensor, B: Tensor, store_preact: bool = True) -> (Tensor, Tensor):
+# return gemm_swiglu_tuned(A, B.T, store_preact=store_preact)
+
+
+# @torch.library.register_fake("quack::gemm_swiglu_t")
+# def gemm_swiglu_t_ref(A: Tensor, B: Tensor, store_preact: bool) -> (Tensor, Tensor):
+# return gemm_swiglu_ref(A, B.T, store_preact)
+
+
+@autotune(
+    configs=[AutotuneConfig(config=c) for c in get_all_configs("dswiglu")], key=["sm_carveout"]
+)
+def gemm_dswiglu_tuned(
+    A: Tensor,
+    B: Tensor,
+    preact: Tensor,
+    sm_carveout: int = 0,
+    config: Optional[GemmConfig] = None,
+) -> (Tensor, Tensor):
+    if config is None:
+        config = GemmConfig(
+            tile_m=128,
+            tile_n=192,
+            cluster_m=2,
+            cluster_n=1,
+            pingpong=True,
+            raster_order=2,
+            max_swizzle_size=1,
+        )
+    out, postact = torch.ops.quack.gemm_dswiglu_impl.default(
+        A if not config.swap_ab else B.T,
+        B if not config.swap_ab else A.T,
+        preact if not config.swap_ab else preact.T,
+        sm_carveout,
+        config.tile_m,
+        config.tile_n,
+        config.cluster_m,
+        config.cluster_n,
+        not config.swap_ab,  # C_rowmajor
+        config.pingpong,
+        config.raster_order,
+        config.max_swizzle_size,
+    )
+    return (out, postact) if not config.swap_ab else (out.T, postact.T)
+
+
+# Specifying the schema manually here since torch.library._infer_schema doesn't work when return
+# type is a tuple of Tensor
+@torch.library.custom_op(
+    "quack::gemm_dswiglu",
+    mutates_args=(),
+    device_types="cuda",
+    schema="(Tensor A, Tensor B, Tensor preact, int sm_carveout=0) -> (Tensor, Tensor)",
+)
+def gemm_dswiglu(A: Tensor, B: Tensor, preact: Tensor, sm_carveout: int = 0) -> (Tensor, Tensor):
+    return gemm_dswiglu_tuned(A, B, preact, sm_carveout)
+
+
+@torch.library.register_fake("quack::gemm_dswiglu")
+def gemm_dswiglu_ref(
+    A: Tensor, B: Tensor, preact: Tensor, sm_carveout: int = 0
+) -> (Tensor, Tensor):
+    # A: (M, K), B: (K, N), preact: (M, 2 * N)
+    dout = torch.mm(A, B)
+    p0, p1 = preact[..., ::2], preact[..., 1::2]
+    sigmoid = torch.sigmoid(p0)
+    silu = F.silu(p0)
+    postact = silu * p1
+    d0 = sigmoid * (1 + p0 * (1 - sigmoid)) * p1 * dout
+    d1 = F.silu(p0) * dout
+    out = torch.stack([d0, d1], dim=-1).reshape(d0.shape[:-1] + (2 * d0.shape[-1],))
+    return out, postact
+
+
+@autotune(configs=[AutotuneConfig(config=c) for c in get_all_configs("lse")])
+def gemm_lse_tuned(
+    A: Tensor,
+    B: Tensor,
+    softcap: float = 0.0,
+    config: Optional[GemmConfig] = None,
+) -> (Tensor, Tensor):
+    if config is None:
+        config = GemmConfig(
+            tile_m=256,
+            tile_n=192,
+            cluster_m=2,
+            cluster_n=1,
+            pingpong=False,
+            raster_order=2,
+            max_swizzle_size=1,
+        )
+    out, lse_partial = torch.ops.quack.gemm_lse_impl.default(
+        A,
+        B,
+        None,  # bias
+        softcap,
+        config.tile_m,
+        config.tile_n,
+        config.cluster_m,
+        config.cluster_n,
+        config.pingpong,
+        config.raster_order,
+        config.max_swizzle_size,
+    )
+    lse = logsumexp(lse_partial)
+    return out, lse
+
+
+@torch.library.custom_op(
+    "quack::gemm_lse",
+    mutates_args=(),
+    device_types="cuda",
+    schema="(Tensor A, Tensor B, float softcap=0.0) -> (Tensor, Tensor)",
+)
+def gemm_lse(A: Tensor, B: Tensor, softcap: float = 0.0) -> (Tensor, Tensor):
+    return gemm_lse_tuned(A, B, softcap)
+
+
+@torch.library.register_fake("quack::gemm_lse")
+def gemm_lse_ref(A: Tensor, B: Tensor, softcap: float = 0.0) -> (Tensor, Tensor):
+    # A: (M, K), B: (K, N)
+    out = torch.mm(A, B)
+    if softcap > 0:
+        out_fp32 = torch.tanh(out.to(torch.float32) / softcap) * softcap
+        out = out_fp32.to(out.dtype)
+    else:
+        out_fp32 = out.to(torch.float32)
+    lse = torch.logsumexp(out_fp32, dim=-1)
+    return out, lse
+
+
+@torch.library.custom_op(
+    "quack::gemm_lse_t",
+    mutates_args=(),
+    device_types="cuda",
+    schema="(Tensor A, Tensor B, float softcap=0.0) -> (Tensor, Tensor)",
+)
+def gemm_lse_t(A: Tensor, B: Tensor, softcap: float = 0.0) -> (Tensor, Tensor):
+    return gemm_lse_tuned(A, B.T, softcap)
+
+
+@torch.library.register_fake("quack::gemm_lse_t")
+def gemm_lse_t_ref(A: Tensor, B: Tensor, softcap: float = 0.0) -> (Tensor, Tensor):
+    return gemm_lse_ref(A, B.T, softcap)
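A hedged usage sketch of the custom ops registered above. It assumes a CUDA device, the quack wheel installed, and that the underlying `torch.ops.quack.*_impl` kernels support the chosen dtype (bfloat16 here is an assumption); shapes follow the reference implementations, `A: (M, K)`, `B: (K, N)`.

```python
import torch
import quack.gemm_interface as qgi  # importing this module registers torch.ops.quack.*

A = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)
B = torch.randn(512, 2048, device="cuda", dtype=torch.bfloat16)

out = qgi.gemm(A, B)                                     # autotuned plain GEMM, (1024, 2048)
out_add = qgi.gemm_add(A, B, out)                        # out + A @ B
act, preact = qgi.gemm_swiglu(A, B, store_preact=True)   # SwiGLU epilogue: act is (1024, 1024)
logits, lse = qgi.gemm_lse(A, B, softcap=0.0)            # GEMM plus row-wise logsumexp, lse is (1024,)
```

The `gemm_dswiglu` op follows the same pattern for the backward pass, taking the saved pre-activation and returning the gradient with respect to the SwiGLU input together with the recomputed post-activation (see `gemm_dswiglu_ref`).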
quack/linear.py
ADDED
@@ -0,0 +1,176 @@
+# Copyright (c) 2025, Tri Dao
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.amp import custom_fwd, custom_bwd
+
+
+from gemm_cublas import gemm as gemm_cb, gemm_add_ as gemm_add_cb_
+# from gemm_cublas.interface import gemm_tuned as gemm_cb, gemm_add_tuned_ as gemm_add_cb_
+
+from quack import gemm, gemm_lse  # TODO: implement these
+
+
+def linear_fwd_convert_type(*tensors):
+    autocast_dtype = torch.get_autocast_dtype("cuda")
+    if torch.is_autocast_enabled():
+        tensors = tuple(t.to(dtype=autocast_dtype) for t in tensors)
+    return tensors
+
+
+def linear_fwd_postprocess(ctx, x, weight, weight_og, needs_input_grad, needs_weight_grad):
+    if not needs_input_grad:
+        weight, weight_og = None, None
+    if not needs_weight_grad:
+        x = None
+    ctx.save_for_backward(x, weight, weight_og if ctx.fuse_grad_accum else None)
+
+
+def linear_bwd_compute_input_grad(ctx, dout, weight, use_tuned_gemm=True, sm_carveout=0):
+    if ctx.needs_input_grad[0]:
+        assert weight is not None
+        # return gemm(dout, weight) if use_tuned_gemm else (dout @ weight)
+        return (
+            gemm(dout, weight, sm_carveout=sm_carveout)
+            if use_tuned_gemm
+            else gemm_cb(dout, weight, sm_carveout=sm_carveout)
+        )
+    else:
+        return None
+
+
+def linear_bwd_compute_weight_grad(ctx, dout, x, weight_og, sm_carveout=0):
+    if ctx.needs_input_grad[1]:
+        assert x is not None
+        x = x.reshape(-1, x.shape[-1])
+        # fuse_grad_accum is not compatible with torch.compile
+        if not ctx.fuse_grad_accum or weight_og.grad is None or torch.compiler.is_compiling():
+            dweight = gemm_cb(dout.T, x, out_dtype=ctx.weight_dtype, sm_carveout=sm_carveout)
+        else:
+            # print("Using fuse grad accum in Linear", dout.shape, x.shape, weight_og.grad.shape)
+            gemm_add_cb_(dout.T, x, weight_og.grad, sm_carveout=sm_carveout)
+            dweight = weight_og.grad
+            weight_og.grad = None  # So that pytorch doesn't add dweight to weight_og.grad again
+    else:
+        dweight = None
+    return dweight
+
+
+class LinearFunc(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd(device_type="cuda")
+    def forward(ctx, x, weight, fuse_grad_accum=False):
+        """
+        x: (..., in_features)
+        weight: (out_features, in_features)
+        out: (..., out_features)
+        """
+        ctx.weight_dtype = weight.dtype
+        ctx.fuse_grad_accum = fuse_grad_accum
+        weight_og = weight
+        x, weight = linear_fwd_convert_type(x, weight)
+        batch_shape = x.shape[:-1]
+        x = x.reshape(-1, x.shape[-1])
+        # out = F.linear(x, weight)
+        out = gemm(x, weight.T)
+        linear_fwd_postprocess(
+            ctx,
+            x,
+            weight,
+            weight_og,
+            needs_input_grad=ctx.needs_input_grad[0],
+            needs_weight_grad=ctx.needs_input_grad[1],
+        )
+        return out.reshape(*batch_shape, out.shape[-1])
+
+    @staticmethod
+    @custom_bwd(device_type="cuda")
+    def backward(ctx, dout):
+        """
+        dout: (..., out_features)
+        """
+        x, weight, weight_og = ctx.saved_tensors  # weight_og is None if not ctx.fuse_grad_accum
+        batch_shape = dout.shape[:-1]
+        dout = dout.reshape(-1, dout.shape[-1])
+        dx = linear_bwd_compute_input_grad(ctx, dout, weight, use_tuned_gemm=True)
+        dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None
+        dweight = linear_bwd_compute_weight_grad(ctx, dout, x, weight_og)
+        return dx, dweight, None
+
+
+def linear_func(x, weight, fuse_grad_accum=False):
+    return LinearFunc.apply(x, weight, fuse_grad_accum)
+
+
+class LinearLSEFunc(torch.autograd.Function):
+    @staticmethod
+    @custom_fwd(device_type="cuda")
+    def forward(ctx, x, weight, fuse_grad_accum=False):
+        """
+        x: (..., in_features)
+        weight: (out_features, in_features)
+        out: (..., out_features)
+        """
+        needs_weight_grad = weight.requires_grad
+        needs_input_grad = x.requires_grad
+        ctx.weight_dtype = weight.dtype
+        ctx.fuse_grad_accum = fuse_grad_accum
+        weight_og = weight
+        x, weight = linear_fwd_convert_type(x, weight)
+        batch_shape = x.shape[:-1]
+        x = x.reshape(-1, x.shape[-1])
+        out, lse = gemm_lse(x, weight.T)
+        lse = lse.reshape(*batch_shape)
+        linear_fwd_postprocess(ctx, x, weight, weight_og, needs_weight_grad, needs_input_grad)
+        ctx.mark_non_differentiable(lse)
+        return out.reshape(*batch_shape, out.shape[-1]), lse
+
+    @staticmethod
+    @custom_bwd(device_type="cuda")
+    def backward(ctx, dout, dlse_ignored):
+        """
+        dout: (..., out_features)
+        """
+        x, weight, weight_og = ctx.saved_tensors  # weight_og is None if not ctx.fuse_grad_accum
+        batch_shape = dout.shape[:-1]
+        dout = dout.reshape(-1, dout.shape[-1])
+        # cuBLAS seems faster for this so we just use it instead of cutlass gemm
+        dx = linear_bwd_compute_input_grad(ctx, dout, weight, use_tuned_gemm=False)
+        dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None
+        dweight = linear_bwd_compute_weight_grad(ctx, dout, x, weight_og)
+        return dx, dweight, None
+
+
+def linear_lse_func(x, weight, fuse_grad_accum=False):
+    return LinearLSEFunc.apply(x, weight, fuse_grad_accum)
+
+
+class Linear(nn.Linear):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = False,
+        device=None,
+        dtype=None,
+        fuse_grad_accum: bool = False,
+    ) -> None:
+        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
+        self.fuse_grad_accum = fuse_grad_accum
+
+    def forward(self, input: Tensor) -> Tensor:
+        if self.bias is None and input.is_cuda:
+            return linear_func(input, self.weight, fuse_grad_accum=self.fuse_grad_accum)
+        else:
+            return F.linear(input, self.weight, self.bias)
+
+
+class LinearLSE(Linear):
+    def forward(self, input: Tensor) -> Tensor:
+        if self.bias is None and input.is_cuda:
+            return linear_lse_func(input, self.weight, fuse_grad_accum=self.fuse_grad_accum)
+        else:
+            out = F.linear(input, self.weight, self.bias)
+            lse = torch.logsumexp(out, dim=-1)
+            return out, lse
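Finally, a hedged sketch of the drop-in modules above, assuming a CUDA device and that both the quack wheel and its `gemm_cublas` dependency are installed; `bias` must stay disabled to take the fused path, and the bfloat16 dtype is an assumption.

```python
import torch
from quack.linear import Linear, LinearLSE

layer = Linear(1024, 4096, device="cuda", dtype=torch.bfloat16, fuse_grad_accum=False)
x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = layer(x)               # (8, 512, 4096); forward runs quack.gemm
y.sum().backward()         # dgrad via quack.gemm, wgrad via cuBLAS (gemm_cublas)

lse_head = LinearLSE(1024, 32000, device="cuda", dtype=torch.bfloat16)
logits, lse = lse_head(x)  # logits: (8, 512, 32000); lse: (8, 512), marked non-differentiable
```

With `fuse_grad_accum=True`, the weight-gradient GEMM accumulates directly into an existing `weight.grad` buffer via `gemm_add_` and then clears it so autograd does not add the gradient twice, as in `linear_bwd_compute_weight_grad` above; the code notes this path is not compatible with `torch.compile`.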