quack-kernels 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between the two versions.
quack/fast_math.py ADDED
@@ -0,0 +1,97 @@
+ # Copyright (c) 2025, Tri Dao.
+
+ from typing import Tuple
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass import Int32, Uint32
+ from cutlass.cutlass_dsl import T, dsl_user_op
+ from cutlass._mlir.dialects import llvm
+
+
+ @cute.jit
+ def clz(x: Int32) -> Int32:
+     # for i in cutlass.range_constexpr(32):
+     #     if (1 << (31 - i)) & x:
+     #         return Int32(i)
+     # return Int32(32)
+     # Early exit is not supported yet
+     res = Int32(32)
+     done = False
+     for i in cutlass.range(32):
+         if ((1 << (31 - i)) & x) and not done:
+             res = Int32(i)
+             done = True
+     return res
+
+
+ def find_log2(x: Int32) -> Int32:
+     a: Int32 = Int32(31 - clz(x))
+     return a + ((x & (x - 1)) != 0)  # Round up, add 1 if not a power of 2.
+
+
+ @dsl_user_op
+ def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32:
+     return Uint32(
+         llvm.inline_asm(
+             T.i32(),
+             [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)],
+             "mul.hi.u32 $0, $1, $2;",
+             "=r,r,r",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
+ class FastDivmod:
+     def __init__(
+         self, divisor: Int32, multiplier: Uint32, shift_right: Uint32, *, loc=None, ip=None
+     ):
+         self.divisor = divisor
+         self.multiplier = multiplier
+         self.shift_right = shift_right
+         self._loc = loc
+
+     # called by host
+     @staticmethod
+     def create(divisor: Int32, *, loc=None, ip=None) -> "FastDivmod":
+         """Construct the FastDivmod object, in host code.
+         This precomputes some values based on the divisor and is computationally expensive.
+         """
+         p = Uint32(31 + find_log2(divisor))
+         divisor_u32 = Uint32(divisor)
+         multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32)
+         shift_right = Uint32(p - 32)
+         return FastDivmod(divisor, multiplier, shift_right, loc=loc, ip=ip)
+
+     @cute.jit
+     def div(self, dividend: Int32) -> Int32:
+         return (
+             Int32(umulhi(dividend, self.multiplier) >> self.shift_right)
+             if self.divisor != 1
+             else dividend
+         )
+
+     def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]:
+         quotient = self.div(dividend)
+         remainder = dividend - quotient * self.divisor
+         return quotient, remainder
+
+     def __extract_mlir_values__(self):
+         values, self._values_pos = [], []
+         for obj in [self.divisor, self.multiplier, self.shift_right]:
+             obj_values = cutlass.extract_mlir_values(obj)
+             values += obj_values
+             self._values_pos.append(len(obj_values))
+         return values
+
+     def __new_from_mlir_values__(self, values):
+         obj_list = []
+         for obj, n_items in zip(
+             [self.divisor, self.multiplier, self.shift_right], self._values_pos
+         ):
+             obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
+             values = values[n_items:]
+         return FastDivmod(*(tuple(obj_list)), loc=self._loc)
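
For context, FastDivmod uses the standard magic-number division trick: on the host, create picks p = 31 + ceil(log2(divisor)) and multiplier = ceil(2^p / divisor); on the device, div recovers the quotient as umulhi(dividend, multiplier) >> (p - 32), with a guard for divisor == 1. A minimal plain-Python sketch of the same arithmetic (ordinary ints instead of the CUTE DSL types above), useful only as a sanity check of the precomputation, not part of the package:

    def fast_divmod_check(dividend: int, divisor: int) -> tuple[int, int]:
        # Host-side precomputation, mirroring FastDivmod.create above.
        log2_ceil = (divisor - 1).bit_length()            # ceil(log2(divisor))
        p = 31 + log2_ceil
        multiplier = ((1 << p) + divisor - 1) // divisor  # ceil(2^p / divisor), fits in uint32
        shift_right = p - 32
        # Device-side evaluation, mirroring FastDivmod.div/divmod:
        # umulhi(a, b) is the high 32 bits of the 64-bit unsigned product.
        quotient = ((dividend * multiplier) >> 32) >> shift_right if divisor != 1 else dividend
        remainder = dividend - quotient * divisor
        return quotient, remainder

    assert fast_divmod_check(12345, 7) == divmod(12345, 7)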
quack/gemm_config.py ADDED
@@ -0,0 +1,61 @@
+ # Copyright (C) 2025, Tri Dao.
+ import itertools
+ from typing import Optional
+ from pydantic import BaseModel
+
+
+ class GemmConfig(BaseModel, frozen=True):
+     tile_m: int = 256
+     tile_n: int = 128
+     cluster_m: int = 2
+     cluster_n: int = 1
+     swap_ab: bool = False
+     pingpong: bool = False
+     raster_order: int = 2
+     max_swizzle_size: int = 1
+
+
+ def get_all_configs(
+     epilogue: Optional[str],
+     tune_pingpong=True,
+     tune_raster_order=True,
+ ) -> list[GemmConfig]:
+     tile_n_vals = [128, 144, 160, 176, 192, 208]
+     tile_mn_vals = [(256, tile_n) for tile_n in tile_n_vals]
+     if epilogue in ["swiglu"]:
+         tile_mn_vals = [(m, n) for m, n in tile_mn_vals if n % 32 == 0]
+     cluster = [(1, 1), (1, 2), (2, 1)]
+     # cluster = [(1, 2), (2, 1)]
+     if epilogue in ["lse"]:
+         cluster = [(1, 2), (2, 1)]
+     swap_ab_vals = [False, True]
+     if epilogue in ["lse", "swiglu"]:
+         swap_ab_vals = [False]
+     pingpong_vals = [False, True] if tune_pingpong else [False]
+     raster_swizzle = (
+         [(0, 1)]
+         if not tune_raster_order
+         else [(1, 1), (1, 2), (1, 4), (1, 8), (2, 1), (2, 2), (2, 4), (2, 8)]
+     )
+     return [
+         GemmConfig(
+             tile_m=tile_m if not pingpong else 128,
+             tile_n=tile_n,
+             cluster_m=cluster_m,
+             cluster_n=cluster_n,
+             swap_ab=swap_ab,
+             pingpong=pingpong,
+             raster_order=raster_order,
+             max_swizzle_size=max_swizzle_size,
+         )
+         for (tile_m, tile_n), (cluster_m, cluster_n), swap_ab, pingpong, (
+             raster_order,
+             max_swizzle_size,
+         ) in itertools.product(
+             tile_mn_vals,
+             cluster,
+             swap_ab_vals,
+             pingpong_vals,
+             raster_swizzle,
+         )
+     ]
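
As a rough sanity check of the search-space size implied by the products above: with the defaults (tune_pingpong=True, tune_raster_order=True) and no epilogue, the sweep is 6 tile shapes x 3 clusters x 2 swap_ab x 2 pingpong x 8 (raster_order, max_swizzle_size) pairs = 576 candidates. An illustrative (not part of the package) way to inspect the candidate list:

    # Illustrative only: count the configs the autotuner would sweep.
    from quack.gemm_config import get_all_configs

    configs = get_all_configs(epilogue=None)
    print(len(configs))                  # 6 * 3 * 2 * 2 * 8 = 576 with the defaults above
    print(len(get_all_configs("lse")))   # smaller: clusters and swap_ab are restricted for "lse"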
@@ -0,0 +1,321 @@
+ # Copyright (c) 2025, Tri Dao
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor
+
+ from quack.gemm_config import GemmConfig, get_all_configs
+
+ from quack.autotuner import autotune, AutotuneConfig
+ from quack.lse import logsumexp
+
+
+ def gemm_swiglu_out_ref(
+     A: Tensor, B: Tensor, out: Optional[Tensor], store_preact: bool
+ ) -> (Tensor, Tensor):
+     preact = torch.mm(A, B)
+     out_ = F.silu(preact[..., ::2]) * preact[..., 1::2]
+     if out is not None:
+         out.copy_(out_)
+     else:
+         out = out_
+     if not store_preact:
+         preact = None
+     return out, preact
+
+
+ @autotune(
+     configs=[AutotuneConfig(config=c) for c in get_all_configs(epilogue=None)], key=["sm_carveout"]
+ )
+ def gemm_tuned(
+     A: Tensor,
+     B: Tensor,
+     sm_carveout: int = 0,
+     config: Optional[GemmConfig] = None,
+ ) -> (Tensor, Optional[Tensor]):
+     if config is None:
+         config = GemmConfig(
+             tile_m=256,
+             tile_n=192,
+             cluster_m=2,
+             cluster_n=1,
+             pingpong=False,
+             raster_order=2,
+             max_swizzle_size=1,
+         )
+     out = torch.ops.quack.gemm_impl.default(
+         A if not config.swap_ab else B.T,
+         B if not config.swap_ab else A.T,
+         sm_carveout,
+         config.tile_m,
+         config.tile_n,
+         config.cluster_m,
+         config.cluster_n,
+         not config.swap_ab,  # C_rowmajor
+         config.pingpong,
+         config.raster_order,
+         config.max_swizzle_size,
+     )
+     return out if not config.swap_ab else out.T
+
+
+ @torch.library.custom_op("quack::gemm", mutates_args=(), device_types="cuda")
+ def gemm(A: Tensor, B: Tensor, sm_carveout: int = 0) -> Tensor:
+     return gemm_tuned(A, B, sm_carveout)
+
+
+ @torch.library.register_fake("quack::gemm")
+ def gemm_ref(A: Tensor, B: Tensor, sm_carveout: int = 0) -> Tensor:
+     return torch.mm(A, B)
+
+
+ @autotune(configs=[AutotuneConfig(config=c) for c in get_all_configs("add")])
+ def gemm_add_tuned(
+     A: Tensor,
+     B: Tensor,
+     C: Tensor,
+     config: Optional[GemmConfig] = None,
+ ) -> (Tensor, Optional[Tensor]):
+     if config is None:
+         config = GemmConfig(
+             tile_m=256,
+             tile_n=192,
+             cluster_m=2,
+             cluster_n=1,
+             pingpong=False,
+             raster_order=2,
+             max_swizzle_size=1,
+         )
+     out = torch.ops.quack.gemm_add_impl.default(
+         A if not config.swap_ab else B.T,
+         B if not config.swap_ab else A.T,
+         C if not config.swap_ab else C.T,
+         config.tile_m,
+         config.tile_n,
+         config.cluster_m,
+         config.cluster_n,
+         config.pingpong,
+         config.raster_order,
+         config.max_swizzle_size,
+     )
+     return out if not config.swap_ab else out.T
+
+
+ @torch.library.custom_op("quack::gemm_add", mutates_args=(), device_types="cuda")
+ def gemm_add(A: Tensor, B: Tensor, C: Tensor) -> Tensor:
+     return gemm_add_tuned(A, B, C)
+
+
+ @torch.library.register_fake("quack::gemm_add")
+ def gemm_add_ref(A: Tensor, B: Tensor, C: Tensor) -> Tensor:
+     return C + torch.mm(A, B)
+
+
+ @torch.library.custom_op("quack::gemm_add_t", mutates_args=(), device_types="cuda")
+ def gemm_t_add(A: Tensor, B: Tensor, C: Tensor) -> Tensor:
+     return gemm_add_tuned(A, B.T, C)
+
+
+ @torch.library.register_fake("quack::gemm_add_t")
+ def gemm_t_add_ref(A: Tensor, B: Tensor, C: Tensor) -> Tensor:
+     return gemm_add_ref(A, B.T, C)
+
+
+ @autotune(
+     configs=[AutotuneConfig(config=c) for c in get_all_configs("swiglu")], key=["store_preact"]
+ )
+ def gemm_swiglu_tuned(
+     A: Tensor,
+     B: Tensor,
+     store_preact: bool = True,
+     config: Optional[GemmConfig] = None,
+ ) -> (Tensor, Optional[Tensor]):
+     if config is None:
+         config = GemmConfig(
+             tile_m=256,
+             tile_n=192,
+             cluster_m=2,
+             cluster_n=1,
+             pingpong=False,
+             raster_order=2,
+             max_swizzle_size=1,
+         )
+     # out, preact
+     return torch.ops.quack.gemm_swiglu_impl.default(
+         A,
+         B,
+         store_preact,
+         config.tile_m,
+         config.tile_n,
+         config.cluster_m,
+         config.cluster_n,
+         config.pingpong,
+         config.raster_order,
+         config.max_swizzle_size,
+     )
+
+
+ # Specifying the schema manually here since torch.library._infer_schema doesn't work when return
+ # type is a tuple of Tensor
+ @torch.library.custom_op(
+     "quack::gemm_swiglu",
+     mutates_args=(),
+     device_types="cuda",
+     schema="(Tensor A, Tensor B, bool store_preact) -> (Tensor, Tensor)",
+ )
+ def gemm_swiglu(A: Tensor, B: Tensor, store_preact: bool = True) -> (Tensor, Tensor):
+     return gemm_swiglu_tuned(A, B, store_preact=store_preact)
+
+
+ @torch.library.register_fake("quack::gemm_swiglu")
+ def gemm_swiglu_ref(A: Tensor, B: Tensor, store_preact: bool) -> (Tensor, Tensor):
+     return gemm_swiglu_out_ref(A, B, None, store_preact)
+
+
+ # @torch.library.custom_op("quack::gemm_swiglu_t", mutates_args=(), device_types="cuda",
+ #     schema="(Tensor A, Tensor B, bool store_preact) -> (Tensor, Tensor)")
+ # def gemm_swiglu_t(A: Tensor, B: Tensor, store_preact: bool = True) -> (Tensor, Tensor):
+ #     return gemm_swiglu_tuned(A, B.T, store_preact=store_preact)
+
+
+ # @torch.library.register_fake("quack::gemm_swiglu_t")
+ # def gemm_swiglu_t_ref(A: Tensor, B: Tensor, store_preact: bool) -> (Tensor, Tensor):
+ #     return gemm_swiglu_ref(A, B.T, store_preact)
+
+
+ @autotune(
+     configs=[AutotuneConfig(config=c) for c in get_all_configs("dswiglu")], key=["sm_carveout"]
+ )
+ def gemm_dswiglu_tuned(
+     A: Tensor,
+     B: Tensor,
+     preact: Tensor,
+     sm_carveout: int = 0,
+     config: Optional[GemmConfig] = None,
+ ) -> (Tensor, Tensor):
+     if config is None:
+         config = GemmConfig(
+             tile_m=128,
+             tile_n=192,
+             cluster_m=2,
+             cluster_n=1,
+             pingpong=True,
+             raster_order=2,
+             max_swizzle_size=1,
+         )
+     out, postact = torch.ops.quack.gemm_dswiglu_impl.default(
+         A if not config.swap_ab else B.T,
+         B if not config.swap_ab else A.T,
+         preact if not config.swap_ab else preact.T,
+         sm_carveout,
+         config.tile_m,
+         config.tile_n,
+         config.cluster_m,
+         config.cluster_n,
+         not config.swap_ab,  # C_rowmajor
+         config.pingpong,
+         config.raster_order,
+         config.max_swizzle_size,
+     )
+     return (out, postact) if not config.swap_ab else (out.T, postact.T)
+
+
+ # Specifying the schema manually here since torch.library._infer_schema doesn't work when return
+ # type is a tuple of Tensor
+ @torch.library.custom_op(
+     "quack::gemm_dswiglu",
+     mutates_args=(),
+     device_types="cuda",
+     schema="(Tensor A, Tensor B, Tensor preact, int sm_carveout=0) -> (Tensor, Tensor)",
+ )
+ def gemm_dswiglu(A: Tensor, B: Tensor, preact: Tensor, sm_carveout: int = 0) -> (Tensor, Tensor):
+     return gemm_dswiglu_tuned(A, B, preact, sm_carveout)
+
+
+ @torch.library.register_fake("quack::gemm_dswiglu")
+ def gemm_dswiglu_ref(
+     A: Tensor, B: Tensor, preact: Tensor, sm_carveout: int = 0
+ ) -> (Tensor, Tensor):
+     # A: (M, K), B: (K, N), preact: (M, 2 * N)
+     dout = torch.mm(A, B)
+     p0, p1 = preact[..., ::2], preact[..., 1::2]
+     sigmoid = torch.sigmoid(p0)
+     silu = F.silu(p0)
+     postact = silu * p1
+     d0 = sigmoid * (1 + p0 * (1 - sigmoid)) * p1 * dout
+     d1 = F.silu(p0) * dout
+     out = torch.stack([d0, d1], dim=-1).reshape(d0.shape[:-1] + (2 * d0.shape[-1],))
+     return out, postact
+
+
+ @autotune(configs=[AutotuneConfig(config=c) for c in get_all_configs("lse")])
+ def gemm_lse_tuned(
+     A: Tensor,
+     B: Tensor,
+     softcap: float = 0.0,
+     config: Optional[GemmConfig] = None,
+ ) -> (Tensor, Tensor):
+     if config is None:
+         config = GemmConfig(
+             tile_m=256,
+             tile_n=192,
+             cluster_m=2,
+             cluster_n=1,
+             pingpong=False,
+             raster_order=2,
+             max_swizzle_size=1,
+         )
+     out, lse_partial = torch.ops.quack.gemm_lse_impl.default(
+         A,
+         B,
+         None,  # bias
+         softcap,
+         config.tile_m,
+         config.tile_n,
+         config.cluster_m,
+         config.cluster_n,
+         config.pingpong,
+         config.raster_order,
+         config.max_swizzle_size,
+     )
+     lse = logsumexp(lse_partial)
+     return out, lse
+
+
+ @torch.library.custom_op(
+     "quack::gemm_lse",
+     mutates_args=(),
+     device_types="cuda",
+     schema="(Tensor A, Tensor B, float softcap=0.0) -> (Tensor, Tensor)",
+ )
+ def gemm_lse(A: Tensor, B: Tensor, softcap: float = 0.0) -> (Tensor, Tensor):
+     return gemm_lse_tuned(A, B, softcap)
+
+
+ @torch.library.register_fake("quack::gemm_lse")
+ def gemm_lse_ref(A: Tensor, B: Tensor, softcap: float = 0.0) -> (Tensor, Tensor):
+     # A: (M, K), B: (K, N)
+     out = torch.mm(A, B)
+     if softcap > 0:
+         out_fp32 = torch.tanh(out.to(torch.float32) / softcap) * softcap
+         out = out_fp32.to(out.dtype)
+     else:
+         out_fp32 = out.to(torch.float32)
+     lse = torch.logsumexp(out_fp32, dim=-1)
+     return out, lse
+
+
+ @torch.library.custom_op(
+     "quack::gemm_lse_t",
+     mutates_args=(),
+     device_types="cuda",
+     schema="(Tensor A, Tensor B, float softcap=0.0) -> (Tensor, Tensor)",
+ )
+ def gemm_lse_t(A: Tensor, B: Tensor, softcap: float = 0.0) -> (Tensor, Tensor):
+     return gemm_lse_tuned(A, B.T, softcap)
+
+
+ @torch.library.register_fake("quack::gemm_lse_t")
+ def gemm_lse_t_ref(A: Tensor, B: Tensor, softcap: float = 0.0) -> (Tensor, Tensor):
+     return gemm_lse_ref(A, B.T, softcap)
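
This module wraps the autotuned kernels as torch custom ops with fake (meta) implementations registered for tracing. A short usage sketch; the `from quack import gemm, gemm_lse` re-export is inferred from the import in quack/linear.py below, and the shapes/dtype here are illustrative assumptions, not requirements stated in the diff:

    # Illustrative usage sketch, not part of the package.
    import torch
    from quack import gemm, gemm_lse  # re-exports inferred from quack/linear.py

    A = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)   # assumed dtype/shapes
    B = torch.randn(512, 2048, device="cuda", dtype=torch.bfloat16)

    C = gemm(A, B)             # autotuned GEMM, shape (1024, 2048)
    out, lse = gemm_lse(A, B)  # GEMM fused with row-wise logsumexp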
quack/linear.py ADDED
@@ -0,0 +1,176 @@
+ # Copyright (c) 2025, Tri Dao
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import Tensor
+ from torch.amp import custom_fwd, custom_bwd
+
+
+ from gemm_cublas import gemm as gemm_cb, gemm_add_ as gemm_add_cb_
+ # from gemm_cublas.interface import gemm_tuned as gemm_cb, gemm_add_tuned_ as gemm_add_cb_
+
+ from quack import gemm, gemm_lse  # TODO: implement these
+
+
+ def linear_fwd_convert_type(*tensors):
+     autocast_dtype = torch.get_autocast_dtype("cuda")
+     if torch.is_autocast_enabled():
+         tensors = tuple(t.to(dtype=autocast_dtype) for t in tensors)
+     return tensors
+
+
+ def linear_fwd_postprocess(ctx, x, weight, weight_og, needs_input_grad, needs_weight_grad):
+     if not needs_input_grad:
+         weight, weight_og = None, None
+     if not needs_weight_grad:
+         x = None
+     ctx.save_for_backward(x, weight, weight_og if ctx.fuse_grad_accum else None)
+
+
+ def linear_bwd_compute_input_grad(ctx, dout, weight, use_tuned_gemm=True, sm_carveout=0):
+     if ctx.needs_input_grad[0]:
+         assert weight is not None
+         # return gemm(dout, weight) if use_tuned_gemm else (dout @ weight)
+         return (
+             gemm(dout, weight, sm_carveout=sm_carveout)
+             if use_tuned_gemm
+             else gemm_cb(dout, weight, sm_carveout=sm_carveout)
+         )
+     else:
+         return None
+
+
+ def linear_bwd_compute_weight_grad(ctx, dout, x, weight_og, sm_carveout=0):
+     if ctx.needs_input_grad[1]:
+         assert x is not None
+         x = x.reshape(-1, x.shape[-1])
+         # fuse_grad_accum is not compatible with torch.compile
+         if not ctx.fuse_grad_accum or weight_og.grad is None or torch.compiler.is_compiling():
+             dweight = gemm_cb(dout.T, x, out_dtype=ctx.weight_dtype, sm_carveout=sm_carveout)
+         else:
+             # print("Using fuse grad accum in Linear", dout.shape, x.shape, weight_og.grad.shape)
+             gemm_add_cb_(dout.T, x, weight_og.grad, sm_carveout=sm_carveout)
+             dweight = weight_og.grad
+             weight_og.grad = None  # So that pytorch doesn't add dweight to weight_og.grad again
+     else:
+         dweight = None
+     return dweight
+
+
+ class LinearFunc(torch.autograd.Function):
+     @staticmethod
+     @custom_fwd(device_type="cuda")
+     def forward(ctx, x, weight, fuse_grad_accum=False):
+         """
+         x: (..., in_features)
+         weight: (out_features, in_features)
+         out: (..., out_features)
+         """
+         ctx.weight_dtype = weight.dtype
+         ctx.fuse_grad_accum = fuse_grad_accum
+         weight_og = weight
+         x, weight = linear_fwd_convert_type(x, weight)
+         batch_shape = x.shape[:-1]
+         x = x.reshape(-1, x.shape[-1])
+         # out = F.linear(x, weight)
+         out = gemm(x, weight.T)
+         linear_fwd_postprocess(
+             ctx,
+             x,
+             weight,
+             weight_og,
+             needs_input_grad=ctx.needs_input_grad[0],
+             needs_weight_grad=ctx.needs_input_grad[1],
+         )
+         return out.reshape(*batch_shape, out.shape[-1])
+
+     @staticmethod
+     @custom_bwd(device_type="cuda")
+     def backward(ctx, dout):
+         """
+         dout: (..., out_features)
+         """
+         x, weight, weight_og = ctx.saved_tensors  # weight_og is None if not ctx.fuse_grad_accum
+         batch_shape = dout.shape[:-1]
+         dout = dout.reshape(-1, dout.shape[-1])
+         dx = linear_bwd_compute_input_grad(ctx, dout, weight, use_tuned_gemm=True)
+         dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None
+         dweight = linear_bwd_compute_weight_grad(ctx, dout, x, weight_og)
+         return dx, dweight, None
+
+
+ def linear_func(x, weight, fuse_grad_accum=False):
+     return LinearFunc.apply(x, weight, fuse_grad_accum)
+
+
+ class LinearLSEFunc(torch.autograd.Function):
+     @staticmethod
+     @custom_fwd(device_type="cuda")
+     def forward(ctx, x, weight, fuse_grad_accum=False):
+         """
+         x: (..., in_features)
+         weight: (out_features, in_features)
+         out: (..., out_features)
+         """
+         needs_weight_grad = weight.requires_grad
+         needs_input_grad = x.requires_grad
+         ctx.weight_dtype = weight.dtype
+         ctx.fuse_grad_accum = fuse_grad_accum
+         weight_og = weight
+         x, weight = linear_fwd_convert_type(x, weight)
+         batch_shape = x.shape[:-1]
+         x = x.reshape(-1, x.shape[-1])
+         out, lse = gemm_lse(x, weight.T)
+         lse = lse.reshape(*batch_shape)
+         linear_fwd_postprocess(ctx, x, weight, weight_og, needs_weight_grad, needs_input_grad)
+         ctx.mark_non_differentiable(lse)
+         return out.reshape(*batch_shape, out.shape[-1]), lse
+
+     @staticmethod
+     @custom_bwd(device_type="cuda")
+     def backward(ctx, dout, dlse_ignored):
+         """
+         dout: (..., out_features)
+         """
+         x, weight, weight_og = ctx.saved_tensors  # weight_og is None if not ctx.fuse_grad_accum
+         batch_shape = dout.shape[:-1]
+         dout = dout.reshape(-1, dout.shape[-1])
+         # cuBLAS seems faster for this so we just use it instead of cutlass gemm
+         dx = linear_bwd_compute_input_grad(ctx, dout, weight, use_tuned_gemm=False)
+         dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None
+         dweight = linear_bwd_compute_weight_grad(ctx, dout, x, weight_og)
+         return dx, dweight, None
+
+
+ def linear_lse_func(x, weight, fuse_grad_accum=False):
+     return LinearLSEFunc.apply(x, weight, fuse_grad_accum)
+
+
+ class Linear(nn.Linear):
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias: bool = False,
+         device=None,
+         dtype=None,
+         fuse_grad_accum: bool = False,
+     ) -> None:
+         super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
+         self.fuse_grad_accum = fuse_grad_accum
+
+     def forward(self, input: Tensor) -> Tensor:
+         if self.bias is None and input.is_cuda:
+             return linear_func(input, self.weight, fuse_grad_accum=self.fuse_grad_accum)
+         else:
+             return F.linear(input, self.weight, self.bias)
+
+
+ class LinearLSE(Linear):
+     def forward(self, input: Tensor) -> Tensor:
+         if self.bias is None and input.is_cuda:
+             return linear_lse_func(input, self.weight, fuse_grad_accum=self.fuse_grad_accum)
+         else:
+             out = F.linear(input, self.weight, self.bias)
+             lse = torch.logsumexp(out, dim=-1)
+             return out, lse
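
A short usage sketch of the modules above; the feature sizes and dtype are illustrative assumptions, and, as the forward methods show, the custom GEMM path is taken only for bias-free layers on CUDA inputs:

    # Illustrative usage sketch, not part of the package.
    import torch
    from quack.linear import Linear, LinearLSE

    layer = Linear(4096, 11008, bias=False, device="cuda", dtype=torch.bfloat16)  # assumed sizes/dtype
    x = torch.randn(2, 512, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    y = layer(x)            # dispatches to linear_func -> quack.gemm
    y.sum().backward()      # input grad via quack.gemm, weight grad via gemm_cublas

    lse_layer = LinearLSE(4096, 32768, bias=False, device="cuda", dtype=torch.bfloat16)
    logits, lse = lse_layer(x)  # fused GEMM + row-wise logsumexp; lse is marked non-differentiable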