quack-kernels 0.1.11-py3-none-any.whl → 0.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -3
- quack/activation.py +279 -0
- quack/autotuner.py +2 -1
- quack/cross_entropy.py +330 -184
- quack/cute_dsl_utils.py +83 -4
- quack/dense_gemm_sm100.py +1 -1
- quack/dense_gemm_sm90.py +911 -1140
- quack/fast_math.py +10 -27
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +43 -35
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +491 -243
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +6 -4
- quack/linear.py +128 -64
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +30 -160
- quack/pipeline.py +2 -17
- quack/reduce.py +240 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +614 -228
- quack/softmax.py +28 -16
- quack/symmetric_dense_gemm_sm90.py +6 -3
- quack/tensormap_manager.py +1 -0
- quack/tile_scheduler.py +64 -61
- quack/topk.py +14 -8
- quack/utils.py +14 -322
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/METADATA +3 -3
- quack_kernels-0.2.1.dist-info/RECORD +37 -0
- quack/lse.py +0 -62
- quack_kernels-0.1.11.dist-info/RECORD +0 -31
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/top_level.txt +0 -0
quack/gemm_wrapper_utils.py
ADDED
@@ -0,0 +1,158 @@
+# Copyright (c) 2025, Tri Dao.
+from typing import Optional, Tuple, Dict, Any
+from dataclasses import dataclass
+
+from torch import Tensor
+
+import cutlass.cute as cute
+from cutlass import Int32
+from cutlass.cute.runtime import from_dlpack, make_ptr
+
+from quack.cute_dsl_utils import torch2cute_dtype_map
+from quack.dense_gemm_sm90 import TileSchedulerOptions
+
+
+@dataclass
+class GemmTensorInfo:
+    tensor: Optional[Tensor]
+    dtype: Optional[Any] = None
+    major: Optional[str] = None
+    cute_tensor: Optional[cute.Tensor] = None
+
+
+class GemmWrapperBase:
+    @staticmethod
+    def validate_tensor_3d(tensor: Tensor, name: str) -> None:
+        assert tensor.dim() == 3 and tensor.is_cuda, f"{name} must be a 3D CUDA tensor"
+        assert tensor.dtype in torch2cute_dtype_map, f"Unsupported dtype for {name}"
+
+    @staticmethod
+    def validate_shape(tensor: Tensor, expected_shape: Tuple[int, ...], name: str) -> None:
+        assert tensor.shape == expected_shape, (
+            f"{name} must have shape {expected_shape}, got {tensor.shape}"
+        )
+
+    @staticmethod
+    def get_major_order(tensor: Tensor, dims: Tuple[str, str, str]) -> str:
+        # Tensor is already permuted to (dims[0], dims[1], dims[2])
+        # stride(1) == 1 means dims[1] is contiguous (innermost)
+        return dims[1] if tensor.stride(1) == 1 else dims[0]
+
+    @staticmethod
+    def create_cute_tensor(
+        tensor: Optional[Tensor],
+        major: Optional[str],
+        dims: Tuple[str, str, str],
+        assumed_align: int = 16,
+    ) -> Optional[cute.Tensor]:
+        if tensor is None:
+            return None
+        # Tensor is already permuted to (dims[0], dims[1], dims[2])
+        # If major is dims[1], leading_dim is 1; if major is dims[0], leading_dim is 0
+        leading_dim = 1 if major == dims[1] else 0
+        return from_dlpack(tensor.detach(), assumed_align=assumed_align).mark_layout_dynamic(
+            leading_dim=leading_dim
+        )
+
+    @staticmethod
+    def validate_and_prepare_tensors(
+        A: Tensor,
+        B: Tensor,
+        D: Optional[Tensor] = None,
+        C: Optional[Tensor] = None,
+        additional_tensors: Optional[Dict[str, Tensor]] = None,
+    ) -> Tuple[int, int, int, int, Dict[str, GemmTensorInfo]]:
+        GemmWrapperBase.validate_tensor_3d(A, "A")
+        L, M, K = A.shape
+        GemmWrapperBase.validate_tensor_3d(B, "B")
+        _, N, _ = B.shape
+        assert B.dtype == A.dtype, "A and B must have the same dtype"
+        GemmWrapperBase.validate_shape(B, (L, N, K), "B")
+        tensors = {
+            "A": GemmTensorInfo(A),
+            "B": GemmTensorInfo(B),
+            "D": GemmTensorInfo(D),
+            "C": GemmTensorInfo(C),
+        }
+        if D is not None:
+            GemmWrapperBase.validate_tensor_3d(D, "D")
+            GemmWrapperBase.validate_shape(D, (L, M, N), "D")
+        if C is not None:
+            GemmWrapperBase.validate_tensor_3d(C, "C")
+            GemmWrapperBase.validate_shape(C, (L, M, N), "C")
+        if additional_tensors:
+            for name, tensor in additional_tensors.items():
+                if tensor is not None:
+                    GemmWrapperBase.validate_tensor_3d(tensor, name)
+                    GemmWrapperBase.validate_shape(tensor, (L, M, N), name)
+                tensors[name] = GemmTensorInfo(tensor)
+
+        return L, M, K, N, tensors
+
+    @staticmethod
+    def permute_tensors(tensors: Dict[str, GemmTensorInfo]) -> None:
+        for info in tensors.values():
+            if info.tensor is not None:
+                info.tensor = info.tensor.permute(1, 2, 0)
+
+    @staticmethod
+    def extract_dtypes(tensors: Dict[str, GemmTensorInfo]) -> None:
+        for info in tensors.values():
+            if info.tensor is not None:
+                info.dtype = torch2cute_dtype_map[info.tensor.dtype]
+
+    @staticmethod
+    def determine_major_orders(
+        tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]]
+    ) -> None:
+        for name, dims in major_configs.items():
+            if name in tensors and tensors[name].tensor is not None:
+                tensors[name].major = GemmWrapperBase.get_major_order(tensors[name].tensor, dims)
+
+    @staticmethod
+    def create_cute_tensors(
+        tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]]
+    ) -> None:
+        for name, info in tensors.items():
+            if info.tensor is not None and name in major_configs:
+                info.cute_tensor = GemmWrapperBase.create_cute_tensor(
+                    info.tensor, info.major, major_configs[name]
+                )
+
+    @staticmethod
+    def create_scheduler_args(
+        max_active_clusters: int, tile_count_semaphore: Optional[Tensor] = None
+    ) -> TileSchedulerOptions:
+        return TileSchedulerOptions(
+            Int32(max_active_clusters),
+            tile_count_semaphore=make_ptr(
+                Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4
+            )
+            if tile_count_semaphore is not None
+            else None,
+        )
+
+    @staticmethod
+    def get_compile_key(
+        tensors: Dict[str, GemmTensorInfo],
+        activation: Optional[str],
+        tile_shape_mn: Tuple[int, int],
+        cluster_shape_mnk: Tuple[int, int, int],
+        pingpong: bool,
+        persistent: bool,
+        has_semaphore: bool,
+        *args,
+        key_tensor_names: Tuple[str, ...] = ("A", "B", "D", "C"),
+    ) -> Tuple:
+        key_parts = []
+        for name in key_tensor_names:
+            if name in tensors:
+                key_parts.append(tensors[name].dtype)
+        key_parts.append(activation)
+        key_parts.extend([tile_shape_mn, cluster_shape_mnk])
+        for name in key_tensor_names:
+            if name in tensors:
+                key_parts.append(tensors[name].major)
+        key_parts.extend([pingpong, persistent, has_semaphore])
+        key_parts.extend(args)
+        return tuple(key_parts)
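
The helpers above are consumed by the refactored GEMM wrappers (see quack/gemm_interface.py in this release). As a rough orientation, the sketch below chains them in the order their docstrings imply; the (m, k, l)/(n, k, l) dimension labels in major_configs and the batch-first (L, M, K) input layout are assumptions inferred from validate_and_prepare_tensors, not a documented API.

# Hypothetical usage sketch of GemmWrapperBase; dim labels and shapes are assumptions.
import torch
from quack.gemm_wrapper_utils import GemmWrapperBase

L, M, N, K = 1, 4096, 4096, 1024
A = torch.randn(L, M, K, device="cuda", dtype=torch.bfloat16)   # (l, m, k)
B = torch.randn(L, N, K, device="cuda", dtype=torch.bfloat16)   # (l, n, k)
D = torch.empty(L, M, N, device="cuda", dtype=torch.bfloat16)   # (l, m, n)

L, M, K, N, tensors = GemmWrapperBase.validate_and_prepare_tensors(A, B, D)
GemmWrapperBase.permute_tensors(tensors)   # (l, m, k) -> (m, k, l), etc.
GemmWrapperBase.extract_dtypes(tensors)

# Assumed labels after permutation: A is (m, k, l), B is (n, k, l), D/C are (m, n, l).
major_configs = {
    "A": ("m", "k", "l"),
    "B": ("n", "k", "l"),
    "D": ("m", "n", "l"),
    "C": ("m", "n", "l"),
}
GemmWrapperBase.determine_major_orders(tensors, major_configs)
GemmWrapperBase.create_cute_tensors(tensors, major_configs)

# A compile-cache key a wrapper could use to memoize kernels per dtype/layout/config.
key = GemmWrapperBase.get_compile_key(
    tensors, None, (128, 256), (1, 1, 1),
    pingpong=False, persistent=True, has_semaphore=False,
)
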
quack/layernorm.py
CHANGED
@@ -10,7 +10,9 @@ import cutlass
 import cutlass.cute as cute
 from cutlass.cute.runtime import from_dlpack
 import quack.utils as utils
-from quack.
+from quack.reduce import row_reduce
+from quack.reduction_base import ReductionBase
+from quack.cute_dsl_utils import torch2cute_dtype_map


 class LayerNorm(ReductionBase):
@@ -190,7 +192,7 @@ class LayerNorm(ReductionBase):
         cute.autovec_copy(tXsX, tXrX)
         x = tXrX.load().to(cute.Float32)
         threads_per_row = tv_layout.shape[0][0]
-        sum_x =
+        sum_x = row_reduce(
             x,
             cute.ReductionOp.ADD,
             threads_per_row,
@@ -207,7 +209,7 @@ class LayerNorm(ReductionBase):
         cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
         x = tXrX.load().to(cute.Float32)

-        sum_sq_x_sub_mean =
+        sum_sq_x_sub_mean = row_reduce(
             (x - mean) * (x - mean),
             cute.ReductionOp.ADD,
             threads_per_row,
@@ -215,7 +217,7 @@ class LayerNorm(ReductionBase):
             mbar_ptr + 1 if cutlass.const_expr(self.cluster_n > 1) else None,
             init_val=0.0,
         )
-        rstd =
+        rstd = cute.math.rsqrt(sum_sq_x_sub_mean / shape[1] + eps, fastmath=True)
         if cutlass.const_expr(mRstd is not None):
             # Only the thread corresponding to column 0 writes out the rstd to gmem
             if (
quack/linear.py
CHANGED
@@ -1,4 +1,6 @@
 # Copyright (c) 2025, Tri Dao
+from functools import partial
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -6,10 +8,7 @@ from torch import Tensor
 from torch.amp import custom_fwd, custom_bwd


-from
-# from gemm_cublas.interface import gemm_tuned as gemm_cb, gemm_add_tuned_ as gemm_add_cb_
-
-from quack import gemm, gemm_lse  # TODO: implement these
+from quack.gemm_interface import gemm, gemm_add_inplace, gemm_act, gemm_dact


 def linear_fwd_convert_type(*tensors):
@@ -19,7 +18,8 @@ def linear_fwd_convert_type(*tensors):
     return tensors


-def linear_fwd_postprocess(ctx, x, weight, weight_og,
+def linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad):
+    needs_input_grad, needs_weight_grad = needs_x_w_grad
     if not needs_input_grad:
         weight, weight_og = None, None
     if not needs_weight_grad:
@@ -27,29 +27,24 @@ def linear_fwd_postprocess(ctx, x, weight, weight_og, needs_input_grad, needs_we
     ctx.save_for_backward(x, weight, weight_og if ctx.fuse_grad_accum else None)


-def linear_bwd_compute_input_grad(ctx, dout, weight,
+def linear_bwd_compute_input_grad(ctx, dout, weight, matmul_fn):
     if ctx.needs_input_grad[0]:
         assert weight is not None
-
-        return (
-            gemm(dout, weight, sm_carveout=sm_carveout)
-            if use_tuned_gemm
-            else gemm_cb(dout, weight, sm_carveout=sm_carveout)
-        )
+        return matmul_fn(dout, weight)
     else:
         return None


-def linear_bwd_compute_weight_grad(ctx, dout, x, weight_og,
+def linear_bwd_compute_weight_grad(ctx, dout, x, weight_og, matmul_fn, matmul_inplace_fn):
     if ctx.needs_input_grad[1]:
         assert x is not None
         x = x.reshape(-1, x.shape[-1])
         # fuse_grad_accum is not compatible with torch.compile
         if not ctx.fuse_grad_accum or weight_og.grad is None or torch.compiler.is_compiling():
-            dweight =
+            dweight = matmul_fn(dout.T, x, out_dtype=ctx.weight_dtype)
         else:
             # print("Using fuse grad accum in Linear", dout.shape, x.shape, weight_og.grad.shape)
-
+            matmul_inplace_fn(dout.T, x, weight_og.grad)
             dweight = weight_og.grad
             weight_og.grad = None  # So that pytorch doesn't add dweight to weight_og.grad again
     else:
@@ -58,9 +53,15 @@ def linear_bwd_compute_weight_grad(ctx, dout, x, weight_og, sm_carveout=0):


 class LinearFunc(torch.autograd.Function):
-
+    matmul_fwd_fn = gemm
+    matmul_bwd_dx = partial(gemm, dynamic_scheduler=True)
+    matmul_bwd_dw = partial(gemm, dynamic_scheduler=True)
+    matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)
+
+    # Use classmethod instead of staticmethod to allow inheritance
+    @classmethod
     @custom_fwd(device_type="cuda")
-    def forward(ctx, x, weight, fuse_grad_accum=False):
+    def forward(cls, ctx, x, weight, fuse_grad_accum=False):
         """
         x: (..., in_features)
         weight: (out_features, in_features)
@@ -73,77 +74,145 @@ class LinearFunc(torch.autograd.Function):
         batch_shape = x.shape[:-1]
         x = x.reshape(-1, x.shape[-1])
         # out = F.linear(x, weight)
-        out =
-        linear_fwd_postprocess(
-            ctx,
-            x,
-            weight,
-            weight_og,
-            needs_input_grad=ctx.needs_input_grad[0],
-            needs_weight_grad=ctx.needs_input_grad[1],
-        )
+        out = cls.matmul_fwd_fn(x, weight.T)
+        linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2])
         return out.reshape(*batch_shape, out.shape[-1])

-    @
+    @classmethod
     @custom_bwd(device_type="cuda")
-    def backward(ctx, dout):
+    def backward(cls, ctx, dout, *args):
         """
         dout: (..., out_features)
         """
         x, weight, weight_og = ctx.saved_tensors  # weight_og is None if not ctx.fuse_grad_accum
         batch_shape = dout.shape[:-1]
         dout = dout.reshape(-1, dout.shape[-1])
-        dx = linear_bwd_compute_input_grad(ctx, dout, weight,
+        dx = linear_bwd_compute_input_grad(ctx, dout, weight, cls.matmul_bwd_dx)
         dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None
-        dweight = linear_bwd_compute_weight_grad(
-
+        dweight = linear_bwd_compute_weight_grad(
+            ctx, dout, x, weight_og, cls.matmul_bwd_dw, cls.matmul_bwd_dw_inplace
+        )
+        # return extra Nones for other classes that inherit from LinearFunc
+        return dx, dweight, *([None] * 10)
+
+
+class LinearUntunedFunc(LinearFunc):
+    # Passing in tuned=False to disable tuning at runtime
+    matmul_fwd_fn = partial(gemm, tuned=False)
+    matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False)
+    matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
+    matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)


-def linear_func(x, weight, fuse_grad_accum=False):
-
+def linear_func(x, weight, fuse_grad_accum=False, tuned=True):
+    fn_cls = LinearFunc if tuned else LinearUntunedFunc
+    return fn_cls.apply(x, weight, fuse_grad_accum)


-class
-
+class LinearActFunc(LinearFunc):
+    matmul_fwd_fn = gemm_act
+
+    # Use classmethod instead of staticmethod to allow inheritance
+    @classmethod
     @custom_fwd(device_type="cuda")
-    def forward(ctx, x, weight, fuse_grad_accum=False):
+    def forward(cls, ctx, x, weight, activation, store_preact=True, fuse_grad_accum=False):
         """
         x: (..., in_features)
         weight: (out_features, in_features)
         out: (..., out_features)
+        Return both out and post-activation, but only out is differentiable.
         """
-        needs_weight_grad = weight.requires_grad
-        needs_input_grad = x.requires_grad
         ctx.weight_dtype = weight.dtype
         ctx.fuse_grad_accum = fuse_grad_accum
         weight_og = weight
         x, weight = linear_fwd_convert_type(x, weight)
         batch_shape = x.shape[:-1]
         x = x.reshape(-1, x.shape[-1])
-        out,
-
-
-        ctx.
-
+        out, postact = cls.matmul_fwd_fn(
+            x, weight.T, activation=activation, store_preact=store_preact
+        )
+        linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2])
+        if out is not None:
+            out = out.reshape(*batch_shape, out.shape[-1])
+        ctx.mark_non_differentiable(postact)
+        ctx.set_materialize_grads(False)  # We don't want to materialize grads for postact
+        return out, postact.reshape(*batch_shape, postact.shape[-1])

-
+
+class LinearActUntunedFunc(LinearActFunc):
+    # Passing in tuned=False to disable tuning at runtime
+    matmul_fwd_fn = partial(gemm_act, tuned=False)
+    matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False)
+    matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
+    matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)
+
+
+def linear_act_func(x, weight, activation, store_preact=True, fuse_grad_accum=False, tuned=True):
+    fn_cls = LinearActFunc if tuned else LinearActUntunedFunc
+    return fn_cls.apply(x, weight, activation, store_preact, fuse_grad_accum)
+
+
+class DActLinearFunc(LinearFunc):
+    matmul_bwd_dx = partial(gemm_dact, dynamic_scheduler=True)
+
+    # Use classmethod instead of staticmethod to allow inheritance
+    @classmethod
+    @custom_fwd(device_type="cuda")
+    def forward(cls, ctx, preact, weight, x, activation, fuse_grad_accum=False):
+        """
+        x: (..., in_features)
+        weight: (out_features, in_features)
+        out: (..., out_features)
+        Takes in an extra preact argument which is the pre-activation, to be used in the backward pass.
+        """
+        ctx.weight_dtype = weight.dtype
+        ctx.fuse_grad_accum = fuse_grad_accum
+        weight_og = weight
+        x, weight = linear_fwd_convert_type(x, weight)
+        batch_shape = x.shape[:-1]
+        x = x.reshape(-1, x.shape[-1])
+        out = cls.matmul_fwd_fn(x, weight.T)
+        # Store preact instead of x, we will recompute x in the backward pass
+        linear_fwd_postprocess(
+            ctx, preact, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2]
+        )
+        ctx.activation = activation
+        return out.reshape(*batch_shape, out.shape[-1])
+
+    @classmethod
     @custom_bwd(device_type="cuda")
-    def backward(ctx, dout
+    def backward(cls, ctx, dout):
         """
         dout: (..., out_features)
         """
-
+        # weight_og is None if not ctx.fuse_grad_accum
+        preact, weight, weight_og = ctx.saved_tensors
         batch_shape = dout.shape[:-1]
         dout = dout.reshape(-1, dout.shape[-1])
-
-
-
-
-
+        preact = preact.reshape(-1, preact.shape[-1])
+        if ctx.needs_input_grad[0]:
+            assert weight is not None
+            dpreact, x = cls.matmul_bwd_dx(dout, weight, preact, activation=ctx.activation)
+        else:
+            dpreact, x = None, None
+        dpreact = dpreact.reshape(*batch_shape, dpreact.shape[-1]) if dpreact is not None else None
+        dweight = linear_bwd_compute_weight_grad(
+            ctx, dout, x, weight_og, cls.matmul_bwd_dw, cls.matmul_bwd_dw_inplace
+        )
+        return dpreact, dweight, *([None] * 3)
+

+class DActLinearUntunedFunc(DActLinearFunc):
+    # Passing in tuned=False to disable tuning at runtime
+    matmul_fwd_fn = partial(gemm, tuned=False)
+    matmul_bwd_dx = partial(gemm_dact, dynamic_scheduler=True, tuned=False)
+    matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
+    matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)

-
-
+
+def act_linear_func(preact, weight, x, activation, fuse_grad_accum=False, tuned=True):
+    fn_cls = DActLinearFunc if tuned else DActLinearUntunedFunc
+    return fn_cls.apply(preact, weight, x, activation, fuse_grad_accum)


 class Linear(nn.Linear):
@@ -160,17 +229,12 @@ class Linear(nn.Linear):
         self.fuse_grad_accum = fuse_grad_accum

     def forward(self, input: Tensor) -> Tensor:
-        if
+        if (
+            self.bias is None
+            and input.is_cuda
+            and self.in_features % 8 == 0
+            and self.out_features % 8 == 0
+        ):
             return linear_func(input, self.weight, fuse_grad_accum=self.fuse_grad_accum)
         else:
             return F.linear(input, self.weight, self.bias)
-
-
-class LinearLSE(Linear):
-    def forward(self, input: Tensor) -> Tensor:
-        if self.bias is None and input.is_cuda:
-            return linear_lse_func(input, self.weight, fuse_grad_accum=self.fuse_grad_accum)
-        else:
-            out = F.linear(input, self.weight, self.bias)
-            lse = torch.logsumexp(out, dim=-1)
-            return out, lse