quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +8 -1
- quack/activation.py +288 -0
- quack/autotuner.py +310 -0
- quack/cross_entropy.py +325 -175
- quack/cute_dsl_utils.py +119 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +1657 -842
- quack/fast_math.py +80 -0
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +69 -0
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +569 -0
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +5 -3
- quack/linear.py +240 -0
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +74 -0
- quack/pipeline.py +151 -0
- quack/reduce.py +241 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +583 -231
- quack/softmax.py +27 -15
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2091 -0
- quack/tensormap_manager.py +115 -0
- quack/tile_scheduler.py +937 -0
- quack/topk.py +227 -0
- quack/utils.py +203 -230
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/METADATA +2 -2
- quack_kernels-0.2.0.dist-info/RECORD +37 -0
- quack_kernels-0.1.10.dist-info/RECORD +0 -13
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao.
|
|
2
|
+
from typing import Optional, Tuple, Dict, Any
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
import cutlass.cute as cute
|
|
8
|
+
from cutlass import Int32
|
|
9
|
+
from cutlass.cute.runtime import from_dlpack, make_ptr
|
|
10
|
+
|
|
11
|
+
from quack.cute_dsl_utils import torch2cute_dtype_map
|
|
12
|
+
from quack.dense_gemm_sm90 import TileSchedulerOptions
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class GemmTensorInfo:
    """Bookkeeping record for a single GEMM operand.

    Only ``tensor`` is supplied at construction; the remaining fields start as
    None and are filled in step by step by the GemmWrapperBase helpers.
    """

    # Torch tensor for this operand, or None when the operand is absent.
    tensor: Optional[Tensor]
    # cute dtype looked up from torch2cute_dtype_map (GemmWrapperBase.extract_dtypes).
    dtype: Optional[Any] = None
    # Name of the contiguous mode (GemmWrapperBase.determine_major_orders).
    major: Optional[str] = None
    # from_dlpack view of `tensor` (GemmWrapperBase.create_cute_tensors).
    cute_tensor: Optional[cute.Tensor] = None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class GemmWrapperBase:
    """Static helpers shared by the host-side GEMM wrappers.

    Operands arrive as batched 3D tensors: A is (L, M, K), B is (L, N, K), and
    D, C and any additional epilogue tensors are (L, M, N). The helpers do the
    following, in order:

    - validate the operands;
    - permute them so the batch mode L comes last;
    - attach cute dtypes and major modes;
    - wrap them as cute tensors.

    ``get_compile_key`` builds a hashable key describing a kernel configuration.
    """

    @staticmethod
    def validate_tensor_3d(tensor: Tensor, name: str) -> None:
        """Assert that ``tensor`` is a 3D CUDA tensor whose dtype has a cute mapping."""
        assert tensor.dim() == 3 and tensor.is_cuda, f"{name} must be a 3D CUDA tensor"
        assert tensor.dtype in torch2cute_dtype_map, f"Unsupported dtype for {name}"

    @staticmethod
    def validate_shape(tensor: Tensor, expected_shape: Tuple[int, ...], name: str) -> None:
        """Assert that ``tensor.shape`` equals ``expected_shape``."""
        actual_shape = tensor.shape
        assert actual_shape == expected_shape, (
            f"{name} must have shape {expected_shape}, got {actual_shape}"
        )

    @staticmethod
    def get_major_order(tensor: Tensor, dims: Tuple[str, str, str]) -> str:
        """Return the name of the contiguous mode of an already-permuted tensor.

        ``tensor`` is laid out as (dims[0], dims[1], dims[2]). A unit stride on
        mode 1 means dims[1] is innermost; otherwise dims[0] is reported.
        """
        if tensor.stride(1) == 1:
            return dims[1]
        return dims[0]

    @staticmethod
    def create_cute_tensor(
        tensor: Optional[Tensor],
        major: Optional[str],
        dims: Tuple[str, str, str],
        assumed_align: int = 16,
    ) -> Optional[cute.Tensor]:
        """Wrap a permuted (dims[0], dims[1], dims[2]) tensor as a dynamic-layout cute tensor.

        The leading dim handed to ``mark_layout_dynamic`` is 1 when ``major``
        names dims[1] and 0 otherwise. A None ``tensor`` yields None.
        """
        if tensor is None:
            return None
        leading_dim = 1 if major == dims[1] else 0
        cute_view = from_dlpack(tensor.detach(), assumed_align=assumed_align)
        return cute_view.mark_layout_dynamic(leading_dim=leading_dim)

    @staticmethod
    def validate_and_prepare_tensors(
        A: Tensor,
        B: Tensor,
        D: Optional[Tensor] = None,
        C: Optional[Tensor] = None,
        additional_tensors: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[int, int, int, int, Dict[str, GemmTensorInfo]]:
        """Validate the GEMM operands and wrap them in GemmTensorInfo records.

        Returns ``(L, M, K, N, tensors)``. ``tensors`` always has entries for
        "A", "B", "D" and "C" (the latter two may wrap None). It also has one
        entry per non-None additional tensor, each of which must be (L, M, N).
        """
        check_3d = GemmWrapperBase.validate_tensor_3d
        check_shape = GemmWrapperBase.validate_shape

        check_3d(A, "A")
        L, M, K = A.shape
        check_3d(B, "B")
        _, N, _ = B.shape
        assert B.dtype == A.dtype, "A and B must have the same dtype"
        check_shape(B, (L, N, K), "B")

        tensors = {
            name: GemmTensorInfo(t) for name, t in (("A", A), ("B", B), ("D", D), ("C", C))
        }
        mn_shape = (L, M, N)
        for name, t in (("D", D), ("C", C)):
            if t is not None:
                check_3d(t, name)
                check_shape(t, mn_shape, name)
        for name, t in (additional_tensors or {}).items():
            if t is None:
                continue
            check_3d(t, name)
            check_shape(t, mn_shape, name)
            tensors[name] = GemmTensorInfo(t)
        return L, M, K, N, tensors

    @staticmethod
    def permute_tensors(tensors: Dict[str, GemmTensorInfo]) -> None:
        """Move the batch mode last for every present tensor: (L, X, Y) -> (X, Y, L)."""
        for info in tensors.values():
            if info.tensor is None:
                continue
            info.tensor = info.tensor.permute(1, 2, 0)

    @staticmethod
    def extract_dtypes(tensors: Dict[str, GemmTensorInfo]) -> None:
        """Record the cute dtype of every present tensor."""
        for info in tensors.values():
            if info.tensor is None:
                continue
            info.dtype = torch2cute_dtype_map[info.tensor.dtype]

    @staticmethod
    def determine_major_orders(
        tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]]
    ) -> None:
        """Set ``major`` for each configured tensor that is present."""
        for name, dims in major_configs.items():
            if name not in tensors or tensors[name].tensor is None:
                continue
            info = tensors[name]
            info.major = GemmWrapperBase.get_major_order(info.tensor, dims)

    @staticmethod
    def create_cute_tensors(
        tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]]
    ) -> None:
        """Attach a cute view to each present tensor that has a major config."""
        for name, info in tensors.items():
            if info.tensor is None or name not in major_configs:
                continue
            info.cute_tensor = GemmWrapperBase.create_cute_tensor(
                info.tensor, info.major, major_configs[name]
            )

    @staticmethod
    def create_scheduler_args(
        max_active_clusters: int, tile_count_semaphore: Optional[Tensor] = None
    ) -> TileSchedulerOptions:
        """Build TileSchedulerOptions, turning the optional semaphore into a gmem Int32 pointer."""
        semaphore_ptr = None
        if tile_count_semaphore is not None:
            semaphore_ptr = make_ptr(
                Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4
            )
        return TileSchedulerOptions(Int32(max_active_clusters), tile_count_semaphore=semaphore_ptr)

    @staticmethod
    def get_compile_key(
        tensors: Dict[str, GemmTensorInfo],
        activation: Optional[str],
        tile_shape_mn: Tuple[int, int],
        cluster_shape_mnk: Tuple[int, int, int],
        pingpong: bool,
        persistent: bool,
        has_semaphore: bool,
        *args,
        key_tensor_names: Tuple[str, ...] = ("A", "B", "D", "C"),
    ) -> Tuple:
        """Return a tuple identifying a kernel configuration.

        The tuple contains, in order:

        - the dtypes of the key tensors present in ``tensors``;
        - the activation and the tile/cluster shapes;
        - the major modes of the same tensors;
        - the scheduling flags;
        - any extra positional ``args``.
        """
        present = [name for name in key_tensor_names if name in tensors]
        return (
            *(tensors[name].dtype for name in present),
            activation,
            tile_shape_mn,
            cluster_shape_mnk,
            *(tensors[name].major for name in present),
            pingpong,
            persistent,
            has_semaphore,
            *args,
        )
|
quack/layernorm.py
CHANGED
|
@@ -10,7 +10,9 @@ import cutlass
|
|
|
10
10
|
import cutlass.cute as cute
|
|
11
11
|
from cutlass.cute.runtime import from_dlpack
|
|
12
12
|
import quack.utils as utils
|
|
13
|
-
from quack.
|
|
13
|
+
from quack.reduce import row_reduce
|
|
14
|
+
from quack.reduction_base import ReductionBase
|
|
15
|
+
from quack.cute_dsl_utils import torch2cute_dtype_map
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
class LayerNorm(ReductionBase):
|
|
@@ -190,7 +192,7 @@ class LayerNorm(ReductionBase):
|
|
|
190
192
|
cute.autovec_copy(tXsX, tXrX)
|
|
191
193
|
x = tXrX.load().to(cute.Float32)
|
|
192
194
|
threads_per_row = tv_layout.shape[0][0]
|
|
193
|
-
sum_x =
|
|
195
|
+
sum_x = row_reduce(
|
|
194
196
|
x,
|
|
195
197
|
cute.ReductionOp.ADD,
|
|
196
198
|
threads_per_row,
|
|
@@ -207,7 +209,7 @@ class LayerNorm(ReductionBase):
|
|
|
207
209
|
cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
|
|
208
210
|
x = tXrX.load().to(cute.Float32)
|
|
209
211
|
|
|
210
|
-
sum_sq_x_sub_mean =
|
|
212
|
+
sum_sq_x_sub_mean = row_reduce(
|
|
211
213
|
(x - mean) * (x - mean),
|
|
212
214
|
cute.ReductionOp.ADD,
|
|
213
215
|
threads_per_row,
|
quack/linear.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao
|
|
2
|
+
from functools import partial
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
from torch.amp import custom_fwd, custom_bwd
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
from quack.gemm_interface import gemm, gemm_add_inplace, gemm_act, gemm_dact
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def linear_fwd_convert_type(*tensors):
|
|
15
|
+
autocast_dtype = torch.get_autocast_dtype("cuda")
|
|
16
|
+
if torch.is_autocast_enabled():
|
|
17
|
+
tensors = tuple(t.to(dtype=autocast_dtype) for t in tensors)
|
|
18
|
+
return tensors
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad):
    """Save exactly the tensors the backward pass will use.

    Args:
        ctx: autograd context; ``ctx.fuse_grad_accum`` must already be set.
        x: input saved for the weight gradient (``dout.T @ x``).
        weight: (possibly autocast-converted) weight saved for the input
            gradient (``dout @ weight``).
        weight_og: original weight. When ``ctx.fuse_grad_accum`` is set, the
            weight-gradient path accumulates into its ``.grad`` in place.
        needs_x_w_grad: ``(needs_input_grad, needs_weight_grad)`` pair.

    Calls ``ctx.save_for_backward(x, weight, weight_og)``, with each entry
    replaced by None when it is not needed.
    """
    needs_input_grad, needs_weight_grad = needs_x_w_grad
    # weight only feeds the input gradient.
    if not needs_input_grad:
        weight = None
    # x and weight_og only feed the weight gradient. weight_og must be kept even
    # when no input gradient is needed: otherwise fused grad accumulation would
    # dereference None in backward (e.g. when x does not require grad).
    if not needs_weight_grad:
        x, weight_og = None, None
    ctx.save_for_backward(x, weight, weight_og if ctx.fuse_grad_accum else None)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def linear_bwd_compute_input_grad(ctx, dout, weight, matmul_fn):
    """Return the input gradient ``matmul_fn(dout, weight)``, or None if not required.

    It is computed only when ``ctx.needs_input_grad[0]`` is set, in which case
    the forward pass must have saved ``weight``.
    """
    if not ctx.needs_input_grad[0]:
        return None
    assert weight is not None
    return matmul_fn(dout, weight)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def linear_bwd_compute_weight_grad(ctx, dout, x, weight_og, matmul_fn, matmul_inplace_fn):
|
|
39
|
+
if ctx.needs_input_grad[1]:
|
|
40
|
+
assert x is not None
|
|
41
|
+
x = x.reshape(-1, x.shape[-1])
|
|
42
|
+
# fuse_grad_accum is not compatible with torch.compile
|
|
43
|
+
if not ctx.fuse_grad_accum or weight_og.grad is None or torch.compiler.is_compiling():
|
|
44
|
+
dweight = matmul_fn(dout.T, x, out_dtype=ctx.weight_dtype)
|
|
45
|
+
else:
|
|
46
|
+
# print("Using fuse grad accum in Linear", dout.shape, x.shape, weight_og.grad.shape)
|
|
47
|
+
matmul_inplace_fn(dout.T, x, weight_og.grad)
|
|
48
|
+
dweight = weight_og.grad
|
|
49
|
+
weight_og.grad = None # So that pytorch doesn't add dweight to weight_og.grad again
|
|
50
|
+
else:
|
|
51
|
+
dweight = None
|
|
52
|
+
return dweight
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class LinearFunc(torch.autograd.Function):
    """Autograd function for a bias-free linear layer, ``out = x @ weight.T``.

    The GEMM callables are class attributes, and forward/backward are
    classmethods, so subclasses can swap kernels without redefining the logic.
    """

    matmul_fwd_fn = gemm
    matmul_bwd_dx = partial(gemm, dynamic_scheduler=True)
    matmul_bwd_dw = partial(gemm, dynamic_scheduler=True)
    matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)

    # Use classmethod instead of staticmethod to allow inheritance
    @classmethod
    @custom_fwd(device_type="cuda")
    def forward(cls, ctx, x, weight, fuse_grad_accum=False):
        """
        x: (..., in_features)
        weight: (out_features, in_features)
        out: (..., out_features)
        """
        ctx.weight_dtype = weight.dtype
        ctx.fuse_grad_accum = fuse_grad_accum
        weight_og = weight
        x, weight = linear_fwd_convert_type(x, weight)
        lead_shape = x.shape[:-1]
        x_2d = x.reshape(-1, x.shape[-1])
        out = cls.matmul_fwd_fn(x_2d, weight.T)  # equivalent to F.linear(x_2d, weight)
        linear_fwd_postprocess(
            ctx, x_2d, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2]
        )
        return out.reshape(*lead_shape, out.shape[-1])

    @classmethod
    @custom_bwd(device_type="cuda")
    def backward(cls, ctx, dout, *args):
        """
        dout: (..., out_features)
        """
        # weight_og is None if not ctx.fuse_grad_accum
        x, weight, weight_og = ctx.saved_tensors
        lead_shape = dout.shape[:-1]
        dout_2d = dout.reshape(-1, dout.shape[-1])
        dx = linear_bwd_compute_input_grad(ctx, dout_2d, weight, cls.matmul_bwd_dx)
        if dx is not None:
            dx = dx.reshape(*lead_shape, dx.shape[-1])
        dweight = linear_bwd_compute_weight_grad(
            ctx, dout_2d, x, weight_og, cls.matmul_bwd_dw, cls.matmul_bwd_dw_inplace
        )
        # return extra Nones for other classes that inherit from LinearFunc
        return (dx, dweight) + (None,) * 10
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class LinearUntunedFunc(LinearFunc):
    """LinearFunc variant that passes ``tuned=False`` to every GEMM it launches."""

    # Passing in tuned=False to disable tuning at runtime
    matmul_fwd_fn = partial(gemm, tuned=False)
    matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False)
    matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
    # The fused grad-accumulation GEMM opts out of tuning as well, for consistency
    # with the other kernels of this class.
    matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True, tuned=False)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def linear_func(x, weight, fuse_grad_accum=False, tuned=True):
    """Compute ``x @ weight.T`` with autograd via LinearFunc.

    ``tuned=False`` selects LinearUntunedFunc instead.
    """
    if tuned:
        return LinearFunc.apply(x, weight, fuse_grad_accum)
    return LinearUntunedFunc.apply(x, weight, fuse_grad_accum)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class LinearActFunc(LinearFunc):
    """Linear layer fused with an activation epilogue via ``gemm_act``.

    Forward returns ``(out, postact)``. ``postact`` is marked non-differentiable,
    so gradients flow only through ``out``. ``out`` may be None; presumably
    this happens when ``store_preact`` is False. Backward is inherited from
    LinearFunc.
    """

    matmul_fwd_fn = gemm_act

    # Use classmethod instead of staticmethod to allow inheritance
    @classmethod
    @custom_fwd(device_type="cuda")
    def forward(cls, ctx, x, weight, activation, store_preact=True, fuse_grad_accum=False):
        """
        x: (..., in_features)
        weight: (out_features, in_features)
        out: (..., out_features)
        Return both out and post-activation, but only out is differentiable.
        """
        ctx.weight_dtype = weight.dtype
        ctx.fuse_grad_accum = fuse_grad_accum
        weight_og = weight
        x, weight = linear_fwd_convert_type(x, weight)
        lead_shape = x.shape[:-1]
        x_2d = x.reshape(-1, x.shape[-1])
        out, postact = cls.matmul_fwd_fn(
            x_2d, weight.T, activation=activation, store_preact=store_preact
        )
        linear_fwd_postprocess(
            ctx, x_2d, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2]
        )
        out = None if out is None else out.reshape(*lead_shape, out.shape[-1])
        ctx.mark_non_differentiable(postact)
        ctx.set_materialize_grads(False)  # We don't want to materialize grads for postact
        return out, postact.reshape(*lead_shape, postact.shape[-1])
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class LinearActUntunedFunc(LinearActFunc):
    """LinearActFunc variant that passes ``tuned=False`` to every GEMM it launches."""

    # Passing in tuned=False to disable tuning at runtime
    matmul_fwd_fn = partial(gemm_act, tuned=False)
    matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False)
    matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
    # The fused grad-accumulation GEMM opts out of tuning as well, for consistency
    # with the other kernels of this class.
    matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True, tuned=False)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def linear_act_func(x, weight, activation, store_preact=True, fuse_grad_accum=False, tuned=True):
    """Fused linear + activation; returns ``(out, postact)`` from LinearActFunc.

    ``tuned=False`` selects LinearActUntunedFunc instead.
    """
    if tuned:
        return LinearActFunc.apply(x, weight, activation, store_preact, fuse_grad_accum)
    return LinearActUntunedFunc.apply(x, weight, activation, store_preact, fuse_grad_accum)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class DActLinearFunc(LinearFunc):
    """Linear layer whose input is an activation output, differentiated w.r.t. the pre-activation.

    ``forward(preact, weight, x, activation)`` computes ``x @ weight.T``, where
    ``x`` is presumably ``activation(preact)``, computed by the caller.

    To avoid keeping ``x`` alive, forward saves ``preact`` whenever the
    gradient w.r.t. ``preact`` is needed. Backward then calls ``gemm_dact``,
    which returns both that gradient and the recomputed ``x`` used for the
    weight gradient. If only the weight gradient is needed, ``x`` itself is
    saved, because ``gemm_dact`` is not run in that case.
    """

    matmul_bwd_dx = partial(gemm_dact, dynamic_scheduler=True)

    # Use classmethod instead of staticmethod to allow inheritance
    @classmethod
    @custom_fwd(device_type="cuda")
    def forward(cls, ctx, preact, weight, x, activation, fuse_grad_accum=False):
        """
        x: (..., in_features)
        weight: (out_features, in_features)
        out: (..., out_features)
        Takes in an extra preact argument which is the pre-activation, to be used in the backward pass.
        """
        ctx.weight_dtype = weight.dtype
        ctx.fuse_grad_accum = fuse_grad_accum
        weight_og = weight
        x, weight = linear_fwd_convert_type(x, weight)
        batch_shape = x.shape[:-1]
        x = x.reshape(-1, x.shape[-1])
        out = cls.matmul_fwd_fn(x, weight.T)
        needs_preact_grad, needs_weight_grad = ctx.needs_input_grad[:2]
        # dpreact requires preact and weight whether or not weight needs a gradient,
        # and gemm_dact also recomputes x for dweight. When dpreact is not needed,
        # gemm_dact is skipped in backward, so store x directly for dweight.
        to_save = preact if needs_preact_grad else x
        linear_fwd_postprocess(
            ctx,
            to_save,
            weight,
            weight_og,
            needs_x_w_grad=(needs_preact_grad, needs_preact_grad or needs_weight_grad),
        )
        ctx.activation = activation
        return out.reshape(*batch_shape, out.shape[-1])

    @classmethod
    @custom_bwd(device_type="cuda")
    def backward(cls, ctx, dout):
        """
        dout: (..., out_features)
        Returns gradients for (preact, weight, x, activation, fuse_grad_accum).
        """
        # `saved` is preact if dpreact is needed, otherwise x (or None).
        # weight_og is None if not ctx.fuse_grad_accum
        saved, weight, weight_og = ctx.saved_tensors
        batch_shape = dout.shape[:-1]
        dout = dout.reshape(-1, dout.shape[-1])
        dpreact, x = None, saved
        if ctx.needs_input_grad[0]:
            assert weight is not None and saved is not None
            preact = saved.reshape(-1, saved.shape[-1])
            # gemm_dact yields dpreact and the recomputed post-activation x
            dpreact, x = cls.matmul_bwd_dx(dout, weight, preact, activation=ctx.activation)
            dpreact = dpreact.reshape(*batch_shape, dpreact.shape[-1])
        dweight = linear_bwd_compute_weight_grad(
            ctx, dout, x, weight_og, cls.matmul_bwd_dw, cls.matmul_bwd_dw_inplace
        )
        return dpreact, dweight, *([None] * 3)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class DActLinearUntunedFunc(DActLinearFunc):
    """DActLinearFunc variant that passes ``tuned=False`` to every GEMM it launches."""

    # Passing in tuned=False to disable tuning at runtime
    matmul_fwd_fn = partial(gemm, tuned=False)
    matmul_bwd_dx = partial(gemm_dact, dynamic_scheduler=True, tuned=False)
    matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
    # The fused grad-accumulation GEMM opts out of tuning as well, for consistency
    # with the other kernels of this class.
    matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True, tuned=False)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def act_linear_func(preact, weight, x, activation, fuse_grad_accum=False, tuned=True):
    """Apply DActLinearFunc: ``x @ weight.T`` with gradients routed to ``preact``.

    ``tuned=False`` selects DActLinearUntunedFunc instead.
    """
    if tuned:
        return DActLinearFunc.apply(preact, weight, x, activation, fuse_grad_accum)
    return DActLinearUntunedFunc.apply(preact, weight, x, activation, fuse_grad_accum)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class Linear(nn.Linear):
    """``nn.Linear`` that routes eligible calls through quack's ``linear_func``.

    The quack path is taken only when all of these hold:

    - there is no bias;
    - the input is on CUDA;
    - both feature sizes are multiples of 8 (presumably a GEMM alignment
      requirement).

    Every other call falls back to ``F.linear``. Note that ``bias`` defaults to
    False here, unlike ``nn.Linear``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        device=None,
        dtype=None,
        fuse_grad_accum: bool = False,
    ) -> None:
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        # Forwarded to linear_func on the quack path.
        self.fuse_grad_accum = fuse_grad_accum

    def forward(self, input: Tensor) -> Tensor:
        use_quack = (
            self.bias is None
            and input.is_cuda
            and self.in_features % 8 == 0
            and self.out_features % 8 == 0
        )
        if not use_quack:
            return F.linear(input, self.weight, self.bias)
        return linear_func(input, self.weight, fuse_grad_accum=self.fuse_grad_accum)
|
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao
|
|
2
|
+
from typing import Optional, Literal
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
from torch.amp import custom_fwd, custom_bwd
|
|
9
|
+
|
|
10
|
+
from quack.cross_entropy import cross_entropy, cross_entropy_fwd_out
|
|
11
|
+
from quack.gemm_interface import gemm, gemm_add, gemm_add_inplace
|
|
12
|
+
from quack.linear import linear_fwd_convert_type
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def linear_cross_entropy_func(
    x: Tensor,  # (..., d)
    weight: Tensor,  # (V, d)
    bias: Optional[Tensor],  # (V,) or None
    target: Tensor,  # (...,), int or long
    ignore_index: int = -100,
    reduction: Literal["none", "mean", "sum"] = "mean",
    inplace_backward: bool = False,
) -> Tensor:
    """Unchunked linear + cross entropy using quack's ``cross_entropy``.

    The ``(..., V)`` logits are flattened to ``(num_tokens, V)`` and the
    targets to ``(num_tokens,)``, so inputs with any number of leading batch
    dims are accepted (a 2D input is unaffected). With ``reduction="none"``
    the per-token loss is reshaped back to ``target.shape``.
    """
    y = F.linear(x, weight, bias)  # (..., V)
    loss = cross_entropy(
        y.reshape(-1, y.shape[-1]),
        target.reshape(-1),
        ignore_index=ignore_index,
        reduction=reduction,
        inplace_backward=inplace_backward,
    )
    # NOTE(review): assumes quack's cross_entropy returns one loss per row for
    # reduction="none"; the reshape is a no-op for 1D targets.
    return loss.reshape(target.shape) if reduction == "none" else loss
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def linear_cross_entropy_func_ref(
    x: Tensor,  # (..., d)
    weight: Tensor,  # (V, d)
    bias: Optional[Tensor],  # (V,) or None
    target: Tensor,  # (...,), int or long
    ignore_index: int = -100,
    reduction: Literal["none", "mean", "sum"] = "mean",
) -> Tensor:
    """Reference linear + cross entropy built only from PyTorch ops.

    ``F.cross_entropy`` expects the class dimension at dim 1. The ``(..., V)``
    logits are therefore flattened to ``(num_tokens, V)``, and the targets to
    ``(num_tokens,)``, before the loss is computed. With ``reduction="none"``
    the per-token loss is reshaped back to ``target.shape``.
    """
    y = F.linear(x, weight, bias)  # (..., V)
    loss = F.cross_entropy(
        y.reshape(-1, y.shape[-1]),
        target.reshape(-1),
        ignore_index=ignore_index,
        reduction=reduction,
    )
    return loss.reshape(target.shape) if reduction == "none" else loss
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def chunked_linear_cross_entropy_fwd(
    x: Tensor,  # (B*L, d) where B is batch, L is seqlen
    weight: Tensor,  # (V, d) where V is vocab size
    target: Tensor,  # (B*L,)
    chunk_size: int = 4096,
    ignore_index: int = -100,
    tuned: bool = True,
) -> tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    """
    Chunked forward pass for linear cross entropy.

    The rows of ``x`` are split into chunks of ``chunk_size``. For each chunk:

    - compute the logits with a matmul;
    - run the cross-entropy forward, which writes the logits gradient in place
      over the logits;
    - compute that chunk's dx;
    - accumulate dw.

    The caller (ChunkedLinearCrossEntropyFunction.backward) scales these
    gradients by the upstream gradient and the reduction factor.

    Returns:
        loss: (B*L,) fp32 loss values
        dx: (B*L, d) gradient w.r.t. input
        dw: (V, d) fp32 gradient w.r.t. weight, accumulated across all chunks
            except the last; None when there is only one chunk
        last_dlogits_chunk: (chunk_len, V) gradient of last chunk's logits (for deferred dw computation)
        last_x_chunk: (chunk_len, d) last chunk's input (for deferred dw computation)
    """
    B_L, d = x.shape
    V, _ = weight.shape
    device = x.device
    num_chunks = (B_L + chunk_size - 1) // chunk_size
    # Since we use gemm with TMA we require some alignment
    assert chunk_size % 8 == 0, "chunk_size must be multiple of 8"
    assert B_L % 8 == 0
    # Pre-allocate outputs
    loss = torch.empty(B_L, device=device, dtype=torch.float32)
    # One logits buffer is reused by every chunk. It never needs more rows than
    # the input has, so small inputs don't pay for a full (chunk_size, V) buffer.
    logits_rows = min(chunk_size, B_L)
    logits_chunk_preallocated = torch.empty((logits_rows, V), device=device, dtype=x.dtype)
    dx = torch.empty_like(x)
    # Last chunk of dw will be deferred to the backward pass
    dw = torch.empty_like(weight, dtype=torch.float32) if num_chunks > 1 else None
    last_dlogits_chunk = None
    last_x_chunk = None

    # Process in chunks
    for i, (x_chunk, target_chunk, loss_chunk, dx_chunk) in enumerate(
        zip(*(t.split(chunk_size) for t in (x, target, loss, dx)))
    ):
        chunk_len = x_chunk.shape[0]
        logits_chunk = logits_chunk_preallocated[:chunk_len]  # (chunk_len, V)
        torch.mm(x_chunk, weight.mT, out=logits_chunk)
        # Compute cross entropy forward with gradients
        dlogits_chunk = logits_chunk  # inplace_backward
        cross_entropy_fwd_out(
            logits_chunk,
            target_chunk,
            None,  # target_logit
            loss=loss_chunk,
            lse=None,  # we don't need lse here
            dx=dlogits_chunk,
            ignore_index=ignore_index,
        )
        # Compute dx for this chunk: dlogits @ weight
        torch.mm(dlogits_chunk, weight, out=dx_chunk)  # (chunk_len, d)
        # Compute dw for all chunks except the last
        if i == num_chunks - 1:
            # Last chunk: save for backward pass. This is a view of the shared
            # buffer, which is safe because no later chunk overwrites it.
            last_dlogits_chunk = dlogits_chunk
            last_x_chunk = x_chunk
        elif i == 0:
            # First chunk: dw = dlogits.T @ x_chunk
            gemm(dlogits_chunk.T, x_chunk, out=dw, tuned=tuned)
        else:
            # Middle chunks: dw += dlogits.T @ x_chunk
            gemm_add_inplace(dlogits_chunk.T, x_chunk, dw, tuned=tuned)
    return loss, dx, dw, last_dlogits_chunk, last_x_chunk
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class ChunkedLinearCrossEntropyFunction(torch.autograd.Function):
    """Bias-free linear + cross entropy computed in row chunks.

    All gradients are produced during forward by
    ``chunked_linear_cross_entropy_fwd``. Backward only scales them by the
    incoming gradient and finishes the deferred weight gradient of the last
    chunk.
    """

    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(
        ctx,
        x: Tensor,
        weight: Tensor,
        target: Tensor,
        ignore_index: int = -100,
        reduction: Literal["mean", "sum"] = "mean",
        chunk_size: int = 4096,
        tuned: bool = True,
    ):
        """
        Forward pass computes loss and stores dx and dw for backward.

        x: (..., d); target: (...,) with the same leading shape. Both are
        flattened over the leading dims, so rows of x and entries of target
        stay aligned when split into chunks.
        """
        ctx.weight_dtype = weight.dtype
        x, weight = linear_fwd_convert_type(x, weight)
        batch_shape = x.shape[:-1]
        x = x.reshape(-1, x.shape[-1])
        # Flatten target like x. An un-flattened (B, L) target would be chunked
        # along B and paired with the wrong rows of the (B*L, d) input.
        target = target.reshape(-1)
        assert target.shape[0] == x.shape[0], "target must have one entry per row of x"
        # TODO: don't need to compute bwd if neither x nor weight requires grad, or not training
        loss, dx, dw, last_dlogits_chunk, last_x_chunk = chunked_linear_cross_entropy_fwd(
            x, weight, target, chunk_size, ignore_index, tuned=tuned
        )
        loss_sum = loss.sum()
        loss_scale = None if reduction == "sum" else 1.0 / (target != ignore_index).sum().float()
        ctx.save_for_backward(dx, dw, last_dlogits_chunk, last_x_chunk, loss_scale)
        ctx.batch_shape = batch_shape
        ctx.ignore_index = ignore_index
        ctx.reduction = reduction
        ctx.tuned = tuned
        return loss_sum if loss_scale is None else loss_sum * loss_scale

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, dloss):
        """
        Backward pass scales pre-computed gradients by dloss and completes
        the last chunk's dw computation.
        dloss is a scalar.

        NOTE: the saved dx (and, when dw is already in the weight dtype, the saved
        dw) are scaled in place, so backward must not run twice on the same
        graph (e.g. with retain_graph=True).
        """
        dx, dw, last_dlogits_chunk, last_x_chunk, loss_scale = ctx.saved_tensors
        tuned = ctx.tuned
        if loss_scale is not None:
            dloss = dloss * loss_scale
        # TODO: the case where x or weight doesn't require grad
        dx.mul_(dloss)
        dx = dx.reshape(*ctx.batch_shape, dx.shape[-1])
        # Complete dw computation: dw = dloss * dw + dloss * (last_dlogits_chunk.T @ last_x_chunk)
        if dw is None:
            # Only had one chunk, compute dw directly with dloss scaling
            dw = gemm(
                last_dlogits_chunk.T,
                last_x_chunk,
                out_dtype=ctx.weight_dtype,
                alpha=dloss,
                tuned=tuned,
            )
        else:
            # Add last chunk's contribution with dloss scaling
            # dw = dloss * dw + dloss * (last_dlogits_chunk.T @ last_x_chunk)
            # We use alpha=dloss, beta=dloss
            if ctx.weight_dtype == dw.dtype:
                gemm_add_inplace(
                    last_dlogits_chunk.T, last_x_chunk, dw, alpha=dloss, beta=dloss, tuned=tuned
                )
            else:
                dw = gemm_add(
                    last_dlogits_chunk.T,
                    last_x_chunk,
                    dw,
                    alpha=dloss,
                    beta=dloss,
                    out_dtype=ctx.weight_dtype,
                    tuned=tuned,
                )
        return dx, dw, None, None, None, None, None
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def chunked_linear_cross_entropy(
    x: Tensor,
    weight: Tensor,
    target: Tensor,
    chunk_size: int = 4096,
    ignore_index: int = -100,
    reduction: Literal["mean", "sum"] = "mean",
    tuned: bool = True,
) -> Tensor:
    """
    Differentiable chunked linear + cross entropy (no bias).

    Args:
        x: Input tensor of shape (B*L, d)
        weight: Weight tensor of shape (V, d)
        target: Target indices of shape (B*L,)
        chunk_size: Number of rows processed per chunk
        ignore_index: Target value excluded from the loss
        reduction: "mean" or "sum"
        tuned: Whether to use tuned kernels

    Returns:
        Scalar loss tensor with the requested reduction.

    Raises:
        ValueError: if ``reduction`` is neither "mean" nor "sum".
    """
    if reduction in ("mean", "sum"):
        return ChunkedLinearCrossEntropyFunction.apply(
            x, weight, target, ignore_index, reduction, chunk_size, tuned
        )
    raise ValueError(f"Invalid reduction: {reduction}")
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
class LinearCrossEntropy(nn.Linear):
    """``nn.Linear`` whose forward returns the cross-entropy loss of its logits.

    The chunked path (``chunked_linear_cross_entropy``) is taken only when all
    of these hold:

    - there is no bias;
    - the input is a CUDA tensor with a unit last-dim stride;
    - the feature sizes and the number of rows are multiples of 8;
    - ``chunk_size`` is set and is a multiple of 8;
    - the reduction is "mean" or "sum".

    Otherwise it falls back to ``linear_cross_entropy_func``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        ignore_index: int = -100,
        reduction: Literal["none", "mean", "sum"] = "mean",
        chunk_size: Optional[int] = None,
        inplace_backward: bool = False,
        tuned: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.chunk_size = chunk_size  # None disables the chunked path
        self.inplace_backward = inplace_backward  # only used by the fallback path
        self.tuned = tuned  # only used by the chunked path

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        use_chunked = (
            self.bias is None
            and input.is_cuda
            and input.stride(-1) == 1
            and self.in_features % 8 == 0
            and self.out_features % 8 == 0
            and input.shape[:-1].numel() % 8 == 0
            and self.chunk_size is not None
            and self.chunk_size % 8 == 0
            and self.reduction in ("mean", "sum")
        )
        if not use_chunked:
            return linear_cross_entropy_func(
                input,
                self.weight,
                self.bias,
                target,
                ignore_index=self.ignore_index,
                reduction=self.reduction,
                inplace_backward=self.inplace_backward,
            )
        return chunked_linear_cross_entropy(
            input,
            self.weight,
            target,
            chunk_size=self.chunk_size,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            tuned=self.tuned,
        )
|