quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,158 @@
+ # Copyright (c) 2025, Tri Dao.
+ from typing import Optional, Tuple, Dict, Any
+ from dataclasses import dataclass
+
+ from torch import Tensor
+
+ import cutlass.cute as cute
+ from cutlass import Int32
+ from cutlass.cute.runtime import from_dlpack, make_ptr
+
+ from quack.cute_dsl_utils import torch2cute_dtype_map
+ from quack.dense_gemm_sm90 import TileSchedulerOptions
+
+
+ @dataclass
+ class GemmTensorInfo:
+     tensor: Optional[Tensor]
+     dtype: Optional[Any] = None
+     major: Optional[str] = None
+     cute_tensor: Optional[cute.Tensor] = None
+
+
+ class GemmWrapperBase:
+     @staticmethod
+     def validate_tensor_3d(tensor: Tensor, name: str) -> None:
+         assert tensor.dim() == 3 and tensor.is_cuda, f"{name} must be a 3D CUDA tensor"
+         assert tensor.dtype in torch2cute_dtype_map, f"Unsupported dtype for {name}"
+
+     @staticmethod
+     def validate_shape(tensor: Tensor, expected_shape: Tuple[int, ...], name: str) -> None:
+         assert tensor.shape == expected_shape, (
+             f"{name} must have shape {expected_shape}, got {tensor.shape}"
+         )
+
+     @staticmethod
+     def get_major_order(tensor: Tensor, dims: Tuple[str, str, str]) -> str:
+         # Tensor is already permuted to (dims[0], dims[1], dims[2])
+         # stride(1) == 1 means dims[1] is contiguous (innermost)
+         return dims[1] if tensor.stride(1) == 1 else dims[0]
+
+     @staticmethod
+     def create_cute_tensor(
+         tensor: Optional[Tensor],
+         major: Optional[str],
+         dims: Tuple[str, str, str],
+         assumed_align: int = 16,
+     ) -> Optional[cute.Tensor]:
+         if tensor is None:
+             return None
+         # Tensor is already permuted to (dims[0], dims[1], dims[2])
+         # If major is dims[1], leading_dim is 1; if major is dims[0], leading_dim is 0
+         leading_dim = 1 if major == dims[1] else 0
+         return from_dlpack(tensor.detach(), assumed_align=assumed_align).mark_layout_dynamic(
+             leading_dim=leading_dim
+         )
+
+     @staticmethod
+     def validate_and_prepare_tensors(
+         A: Tensor,
+         B: Tensor,
+         D: Optional[Tensor] = None,
+         C: Optional[Tensor] = None,
+         additional_tensors: Optional[Dict[str, Tensor]] = None,
+     ) -> Tuple[int, int, int, int, Dict[str, GemmTensorInfo]]:
+         GemmWrapperBase.validate_tensor_3d(A, "A")
+         L, M, K = A.shape
+         GemmWrapperBase.validate_tensor_3d(B, "B")
+         _, N, _ = B.shape
+         assert B.dtype == A.dtype, "A and B must have the same dtype"
+         GemmWrapperBase.validate_shape(B, (L, N, K), "B")
+         tensors = {
+             "A": GemmTensorInfo(A),
+             "B": GemmTensorInfo(B),
+             "D": GemmTensorInfo(D),
+             "C": GemmTensorInfo(C),
+         }
+         if D is not None:
+             GemmWrapperBase.validate_tensor_3d(D, "D")
+             GemmWrapperBase.validate_shape(D, (L, M, N), "D")
+         if C is not None:
+             GemmWrapperBase.validate_tensor_3d(C, "C")
+             GemmWrapperBase.validate_shape(C, (L, M, N), "C")
+         if additional_tensors:
+             for name, tensor in additional_tensors.items():
+                 if tensor is not None:
+                     GemmWrapperBase.validate_tensor_3d(tensor, name)
+                     GemmWrapperBase.validate_shape(tensor, (L, M, N), name)
+                 tensors[name] = GemmTensorInfo(tensor)
+
+         return L, M, K, N, tensors
+
+     @staticmethod
+     def permute_tensors(tensors: Dict[str, GemmTensorInfo]) -> None:
+         for info in tensors.values():
+             if info.tensor is not None:
+                 info.tensor = info.tensor.permute(1, 2, 0)
+
+     @staticmethod
+     def extract_dtypes(tensors: Dict[str, GemmTensorInfo]) -> None:
+         for info in tensors.values():
+             if info.tensor is not None:
+                 info.dtype = torch2cute_dtype_map[info.tensor.dtype]
+
+     @staticmethod
+     def determine_major_orders(
+         tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]]
+     ) -> None:
+         for name, dims in major_configs.items():
+             if name in tensors and tensors[name].tensor is not None:
+                 tensors[name].major = GemmWrapperBase.get_major_order(tensors[name].tensor, dims)
+
+     @staticmethod
+     def create_cute_tensors(
+         tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]]
+     ) -> None:
+         for name, info in tensors.items():
+             if info.tensor is not None and name in major_configs:
+                 info.cute_tensor = GemmWrapperBase.create_cute_tensor(
+                     info.tensor, info.major, major_configs[name]
+                 )
+
+     @staticmethod
+     def create_scheduler_args(
+         max_active_clusters: int, tile_count_semaphore: Optional[Tensor] = None
+     ) -> TileSchedulerOptions:
+         return TileSchedulerOptions(
+             Int32(max_active_clusters),
+             tile_count_semaphore=make_ptr(
+                 Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4
+             )
+             if tile_count_semaphore is not None
+             else None,
+         )
+
+     @staticmethod
+     def get_compile_key(
+         tensors: Dict[str, GemmTensorInfo],
+         activation: Optional[str],
+         tile_shape_mn: Tuple[int, int],
+         cluster_shape_mnk: Tuple[int, int, int],
+         pingpong: bool,
+         persistent: bool,
+         has_semaphore: bool,
+         *args,
+         key_tensor_names: Tuple[str, ...] = ("A", "B", "D", "C"),
+     ) -> Tuple:
+         key_parts = []
+         for name in key_tensor_names:
+             if name in tensors:
+                 key_parts.append(tensors[name].dtype)
+         key_parts.append(activation)
+         key_parts.extend([tile_shape_mn, cluster_shape_mnk])
+         for name in key_tensor_names:
+             if name in tensors:
+                 key_parts.append(tensors[name].major)
+         key_parts.extend([pingpong, persistent, has_semaphore])
+         key_parts.extend(args)
+         return tuple(key_parts)
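For orientation, here is a minimal usage sketch of the helper class above. The call order (validate, permute, extract dtypes, determine majors, build CuTe tensors, build the compile key) follows the methods as added in this diff; the concrete shapes, the dimension-label strings in major_configs, and the in-scope GemmWrapperBase import are illustrative assumptions, not part of the package.

import torch
# Assumes GemmWrapperBase from the new module above is importable in the current scope.
A = torch.randn(2, 128, 64, device="cuda", dtype=torch.bfloat16)   # (L, M, K)
B = torch.randn(2, 256, 64, device="cuda", dtype=torch.bfloat16)   # (L, N, K)
D = torch.empty(2, 128, 256, device="cuda", dtype=torch.bfloat16)  # (L, M, N)

L, M, K, N, tensors = GemmWrapperBase.validate_and_prepare_tensors(A, B, D)
GemmWrapperBase.permute_tensors(tensors)        # batch dim moves to the last position
GemmWrapperBase.extract_dtypes(tensors)
# Dimension labels below are placeholders; the real wrapper supplies its own configs.
major_configs = {"A": ("m", "k", "l"), "B": ("n", "k", "l"), "D": ("m", "n", "l")}
GemmWrapperBase.determine_major_orders(tensors, major_configs)
GemmWrapperBase.create_cute_tensors(tensors, major_configs)
compile_key = GemmWrapperBase.get_compile_key(
    tensors, None, (128, 256), (1, 1, 1),
    pingpong=False, persistent=True, has_semaphore=False,
)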
quack/layernorm.py CHANGED
@@ -10,7 +10,9 @@ import cutlass
  import cutlass.cute as cute
  from cutlass.cute.runtime import from_dlpack
  import quack.utils as utils
- from quack.reduction_base import ReductionBase, torch2cute_dtype_map
+ from quack.reduce import row_reduce
+ from quack.reduction_base import ReductionBase
+ from quack.cute_dsl_utils import torch2cute_dtype_map


  class LayerNorm(ReductionBase):
@@ -190,7 +192,7 @@ class LayerNorm(ReductionBase):
  cute.autovec_copy(tXsX, tXrX)
  x = tXrX.load().to(cute.Float32)
  threads_per_row = tv_layout.shape[0][0]
- sum_x = utils.row_reduce(
+ sum_x = row_reduce(
      x,
      cute.ReductionOp.ADD,
      threads_per_row,
@@ -207,7 +209,7 @@ class LayerNorm(ReductionBase):
  cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
  x = tXrX.load().to(cute.Float32)

- sum_sq_x_sub_mean = utils.row_reduce(
+ sum_sq_x_sub_mean = row_reduce(
      (x - mean) * (x - mean),
      cute.ReductionOp.ADD,
      threads_per_row,
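For downstream code tracking this refactor, the import changes visible in the hunks above boil down to the following (paths taken directly from the diff):

# quack-kernels 0.2.0 import layout, as reflected by this diff:
from quack.reduce import row_reduce                     # previously called via quack.utils.row_reduce
from quack.reduction_base import ReductionBase          # no longer re-exports torch2cute_dtype_map
from quack.cute_dsl_utils import torch2cute_dtype_map   # new home of the dtype map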
quack/linear.py ADDED
@@ -0,0 +1,240 @@
+ # Copyright (c) 2025, Tri Dao
+ from functools import partial
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import Tensor
+ from torch.amp import custom_fwd, custom_bwd
+
+
+ from quack.gemm_interface import gemm, gemm_add_inplace, gemm_act, gemm_dact
+
+
+ def linear_fwd_convert_type(*tensors):
+     autocast_dtype = torch.get_autocast_dtype("cuda")
+     if torch.is_autocast_enabled():
+         tensors = tuple(t.to(dtype=autocast_dtype) for t in tensors)
+     return tensors
+
+
+ def linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad):
+     needs_input_grad, needs_weight_grad = needs_x_w_grad
+     if not needs_input_grad:
+         weight, weight_og = None, None
+     if not needs_weight_grad:
+         x = None
+     ctx.save_for_backward(x, weight, weight_og if ctx.fuse_grad_accum else None)
+
+
+ def linear_bwd_compute_input_grad(ctx, dout, weight, matmul_fn):
+     if ctx.needs_input_grad[0]:
+         assert weight is not None
+         return matmul_fn(dout, weight)
+     else:
+         return None
+
+
+ def linear_bwd_compute_weight_grad(ctx, dout, x, weight_og, matmul_fn, matmul_inplace_fn):
+     if ctx.needs_input_grad[1]:
+         assert x is not None
+         x = x.reshape(-1, x.shape[-1])
+         # fuse_grad_accum is not compatible with torch.compile
+         if not ctx.fuse_grad_accum or weight_og.grad is None or torch.compiler.is_compiling():
+             dweight = matmul_fn(dout.T, x, out_dtype=ctx.weight_dtype)
+         else:
+             # print("Using fuse grad accum in Linear", dout.shape, x.shape, weight_og.grad.shape)
+             matmul_inplace_fn(dout.T, x, weight_og.grad)
+             dweight = weight_og.grad
+             weight_og.grad = None  # So that pytorch doesn't add dweight to weight_og.grad again
+     else:
+         dweight = None
+     return dweight
+
+
+ class LinearFunc(torch.autograd.Function):
+     matmul_fwd_fn = gemm
+     matmul_bwd_dx = partial(gemm, dynamic_scheduler=True)
+     matmul_bwd_dw = partial(gemm, dynamic_scheduler=True)
+     matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)
+
+     # Use classmethod instead of staticmethod to allow inheritance
+     @classmethod
+     @custom_fwd(device_type="cuda")
+     def forward(cls, ctx, x, weight, fuse_grad_accum=False):
+         """
+         x: (..., in_features)
+         weight: (out_features, in_features)
+         out: (..., out_features)
+         """
+         ctx.weight_dtype = weight.dtype
+         ctx.fuse_grad_accum = fuse_grad_accum
+         weight_og = weight
+         x, weight = linear_fwd_convert_type(x, weight)
+         batch_shape = x.shape[:-1]
+         x = x.reshape(-1, x.shape[-1])
+         # out = F.linear(x, weight)
+         out = cls.matmul_fwd_fn(x, weight.T)
+         linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2])
+         return out.reshape(*batch_shape, out.shape[-1])
+
+     @classmethod
+     @custom_bwd(device_type="cuda")
+     def backward(cls, ctx, dout, *args):
+         """
+         dout: (..., out_features)
+         """
+         x, weight, weight_og = ctx.saved_tensors  # weight_og is None if not ctx.fuse_grad_accum
+         batch_shape = dout.shape[:-1]
+         dout = dout.reshape(-1, dout.shape[-1])
+         dx = linear_bwd_compute_input_grad(ctx, dout, weight, cls.matmul_bwd_dx)
+         dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None
+         dweight = linear_bwd_compute_weight_grad(
+             ctx, dout, x, weight_og, cls.matmul_bwd_dw, cls.matmul_bwd_dw_inplace
+         )
+         # return extra Nones for other classes that inherit from LinearFunc
+         return dx, dweight, *([None] * 10)
+
+
+ class LinearUntunedFunc(LinearFunc):
+     # Passing in tuned=False to disable tuning at runtime
+     matmul_fwd_fn = partial(gemm, tuned=False)
+     matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False)
+     matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
+     matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)
+
+
+ def linear_func(x, weight, fuse_grad_accum=False, tuned=True):
+     fn_cls = LinearFunc if tuned else LinearUntunedFunc
+     return fn_cls.apply(x, weight, fuse_grad_accum)
+
+
+ class LinearActFunc(LinearFunc):
+     matmul_fwd_fn = gemm_act
+
+     # Use classmethod instead of staticmethod to allow inheritance
+     @classmethod
+     @custom_fwd(device_type="cuda")
+     def forward(cls, ctx, x, weight, activation, store_preact=True, fuse_grad_accum=False):
+         """
+         x: (..., in_features)
+         weight: (out_features, in_features)
+         out: (..., out_features)
+         Return both out and post-activation, but only out is differentiable.
+         """
+         ctx.weight_dtype = weight.dtype
+         ctx.fuse_grad_accum = fuse_grad_accum
+         weight_og = weight
+         x, weight = linear_fwd_convert_type(x, weight)
+         batch_shape = x.shape[:-1]
+         x = x.reshape(-1, x.shape[-1])
+         out, postact = cls.matmul_fwd_fn(
+             x, weight.T, activation=activation, store_preact=store_preact
+         )
+         linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2])
+         if out is not None:
+             out = out.reshape(*batch_shape, out.shape[-1])
+         ctx.mark_non_differentiable(postact)
+         ctx.set_materialize_grads(False)  # We don't want to materialize grads for postact
+         return out, postact.reshape(*batch_shape, postact.shape[-1])
+
+
+ class LinearActUntunedFunc(LinearActFunc):
+     # Passing in tuned=False to disable tuning at runtime
+     matmul_fwd_fn = partial(gemm_act, tuned=False)
+     matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False)
+     matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
+     matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)
+
+
+ def linear_act_func(x, weight, activation, store_preact=True, fuse_grad_accum=False, tuned=True):
+     fn_cls = LinearActFunc if tuned else LinearActUntunedFunc
+     return fn_cls.apply(x, weight, activation, store_preact, fuse_grad_accum)
+
+
+ class DActLinearFunc(LinearFunc):
+     matmul_bwd_dx = partial(gemm_dact, dynamic_scheduler=True)
+
+     # Use classmethod instead of staticmethod to allow inheritance
+     @classmethod
+     @custom_fwd(device_type="cuda")
+     def forward(cls, ctx, preact, weight, x, activation, fuse_grad_accum=False):
+         """
+         x: (..., in_features)
+         weight: (out_features, in_features)
+         out: (..., out_features)
+         Takes in an extra preact argument which is the pre-activation, to be used in the backward pass.
+         """
+         ctx.weight_dtype = weight.dtype
+         ctx.fuse_grad_accum = fuse_grad_accum
+         weight_og = weight
+         x, weight = linear_fwd_convert_type(x, weight)
+         batch_shape = x.shape[:-1]
+         x = x.reshape(-1, x.shape[-1])
+         out = cls.matmul_fwd_fn(x, weight.T)
+         # Store preact instead of x, we will recompute x in the backward pass
+         linear_fwd_postprocess(
+             ctx, preact, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2]
+         )
+         ctx.activation = activation
+         return out.reshape(*batch_shape, out.shape[-1])
+
+     @classmethod
+     @custom_bwd(device_type="cuda")
+     def backward(cls, ctx, dout):
+         """
+         dout: (..., out_features)
+         """
+         # weight_og is None if not ctx.fuse_grad_accum
+         preact, weight, weight_og = ctx.saved_tensors
+         batch_shape = dout.shape[:-1]
+         dout = dout.reshape(-1, dout.shape[-1])
+         preact = preact.reshape(-1, preact.shape[-1])
+         if ctx.needs_input_grad[0]:
+             assert weight is not None
+             dpreact, x = cls.matmul_bwd_dx(dout, weight, preact, activation=ctx.activation)
+         else:
+             dpreact, x = None, None
+         dpreact = dpreact.reshape(*batch_shape, dpreact.shape[-1]) if dpreact is not None else None
+         dweight = linear_bwd_compute_weight_grad(
+             ctx, dout, x, weight_og, cls.matmul_bwd_dw, cls.matmul_bwd_dw_inplace
+         )
+         return dpreact, dweight, *([None] * 3)
+
+
+ class DActLinearUntunedFunc(DActLinearFunc):
+     # Passing in tuned=False to disable tuning at runtime
+     matmul_fwd_fn = partial(gemm, tuned=False)
+     matmul_bwd_dx = partial(gemm_dact, dynamic_scheduler=True, tuned=False)
+     matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
+     matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)
+
+
+ def act_linear_func(preact, weight, x, activation, fuse_grad_accum=False, tuned=True):
+     fn_cls = DActLinearFunc if tuned else DActLinearUntunedFunc
+     return fn_cls.apply(preact, weight, x, activation, fuse_grad_accum)
+
+
+ class Linear(nn.Linear):
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias: bool = False,
+         device=None,
+         dtype=None,
+         fuse_grad_accum: bool = False,
+     ) -> None:
+         super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
+         self.fuse_grad_accum = fuse_grad_accum
+
+     def forward(self, input: Tensor) -> Tensor:
+         if (
+             self.bias is None
+             and input.is_cuda
+             and self.in_features % 8 == 0
+             and self.out_features % 8 == 0
+         ):
+             return linear_func(input, self.weight, fuse_grad_accum=self.fuse_grad_accum)
+         else:
+             return F.linear(input, self.weight, self.bias)
+
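A short usage sketch of the new quack.linear module, assuming a CUDA device; the shapes are illustrative. Linear takes the fused GEMM path only when it is bias-free, on CUDA, and both feature dimensions are multiples of 8, otherwise it falls back to F.linear, exactly as in the forward above.

import torch
from quack.linear import Linear, linear_func

# Bias-free layer with 8-aligned dims, so forward() dispatches to linear_func -> LinearFunc.
layer = Linear(1024, 4096, device="cuda", dtype=torch.bfloat16, fuse_grad_accum=False)
x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = layer(x)          # (8, 512, 4096)
out.sum().backward()    # dx and dweight computed via the dynamic-scheduler GEMMs

# Equivalent functional form; tuned=False selects the untuned kernel variants.
out2 = linear_func(x, layer.weight, fuse_grad_accum=False, tuned=False)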
@@ -0,0 +1,275 @@
+ # Copyright (c) 2025, Tri Dao
+ from typing import Optional, Literal
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import Tensor
+ from torch.amp import custom_fwd, custom_bwd
+
+ from quack.cross_entropy import cross_entropy, cross_entropy_fwd_out
+ from quack.gemm_interface import gemm, gemm_add, gemm_add_inplace
+ from quack.linear import linear_fwd_convert_type
+
+
+ def linear_cross_entropy_func(
+     x: Tensor,  # (..., d)
+     weight: Tensor,  # (V, d)
+     bias: Optional[Tensor],  # (V,) or None
+     target: Tensor,  # (...,), int or long
+     ignore_index: int = -100,
+     reduction: Literal["none", "mean", "sum"] = "mean",
+     inplace_backward: bool = False,
+ ) -> Tensor:
+     y = F.linear(x, weight, bias)  # (..., V)
+     return cross_entropy(
+         y, target, ignore_index=ignore_index, reduction=reduction, inplace_backward=inplace_backward
+     )
+
+
+ def linear_cross_entropy_func_ref(
+     x: Tensor,  # (..., d)
+     weight: Tensor,  # (V, d)
+     bias: Optional[Tensor],  # (V,) or None
+     target: Tensor,  # (...,), int or long
+     ignore_index: int = -100,
+     reduction: Literal["none", "mean", "sum"] = "mean",
+ ) -> Tensor:
+     y = F.linear(x, weight, bias)  # (..., V)
+     return F.cross_entropy(y, target, ignore_index=ignore_index, reduction=reduction)
+
+
+ def chunked_linear_cross_entropy_fwd(
+     x: Tensor,  # (B*L, d) where B is batch, L is seqlen
+     weight: Tensor,  # (V, d) where V is vocab size
+     target: Tensor,  # (B*L,)
+     chunk_size: int = 4096,
+     ignore_index: int = -100,
+     tuned: bool = True,
+ ) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
+     """
+     Chunked forward pass for linear cross entropy.
+
+     Splits input along batch dimension, computes matmul and cross_entropy_fwd
+     for each chunk, stores dx for each chunk, and accumulates dw.
+
+     Returns:
+         loss: (B*L,) loss values
+         dx: (B*L, d) gradient w.r.t. input
+         dw: (V, d) gradient w.r.t. weight (accumulated across chunks except last)
+         last_dlogits_chunk: (chunk_len, V) gradient of last chunk's logits (for deferred dw computation)
+         last_x_chunk: (chunk_len, d) last chunk's input (for deferred dw computation)
+     """
+     B_L, d = x.shape
+     V, _ = weight.shape
+     device = x.device
+     num_chunks = (B_L + chunk_size - 1) // chunk_size
+     # Since we use gemm with TMA we require some alignment
+     assert chunk_size % 8 == 0, "chunk_size must be multiple of 8"
+     assert B_L % 8 == 0
+     # Pre-allocate outputs
+     loss = torch.empty(B_L, device=device, dtype=torch.float32)
+     logits_chunk_preallocated = torch.empty((chunk_size, V), device=device, dtype=x.dtype)
+     dx = torch.empty_like(x)
+     # Last chunk of dw will be deferred to the backward pass
+     dw = torch.empty_like(weight, dtype=torch.float32) if num_chunks > 1 else None
+     last_dlogits_chunk = None
+     last_x_chunk = None
+
+     # Process in chunks
+     for i, (x_chunk, target_chunk, loss_chunk, dx_chunk) in enumerate(
+         zip(*(t.split(chunk_size) for t in (x, target, loss, dx)))
+     ):
+         chunk_len = x_chunk.shape[0]
+         logits_chunk = logits_chunk_preallocated[:chunk_len]  # (chunk_len, V)
+         torch.mm(x_chunk, weight.mT, out=logits_chunk)
+         # Compute cross entropy forward with gradients
+         dlogits_chunk = logits_chunk  # inplace_backward
+         cross_entropy_fwd_out(
+             logits_chunk,
+             target_chunk,
+             None,  # target_logit
+             loss=loss_chunk,
+             lse=None,  # we don't need lse here
+             dx=dlogits_chunk,
+             ignore_index=ignore_index,
+         )
+         # Compute dx for this chunk: dlogits @ weight
+         torch.mm(dlogits_chunk, weight, out=dx_chunk)  # (chunk_len, d)
+         # Compute dw for all chunks except the last
+         if i == num_chunks - 1:
+             # Last chunk: save for backward pass
+             last_dlogits_chunk = dlogits_chunk
+             last_x_chunk = x_chunk
+         elif i == 0:
+             # First chunk: dw = dlogits.T @ x_chunk
+             gemm(dlogits_chunk.T, x_chunk, out=dw, tuned=tuned)
+         else:
+             # Middle chunks: dw += dlogits.T @ x_chunk
+             gemm_add_inplace(dlogits_chunk.T, x_chunk, dw, tuned=tuned)
+     return loss, dx, dw, last_dlogits_chunk, last_x_chunk
+
+
+ class ChunkedLinearCrossEntropyFunction(torch.autograd.Function):
+     @staticmethod
+     @custom_fwd(device_type="cuda")
+     def forward(
+         ctx,
+         x: Tensor,
+         weight: Tensor,
+         target: Tensor,
+         ignore_index: int = -100,
+         reduction: Literal["mean", "sum"] = "mean",
+         chunk_size: int = 4096,
+         tuned: bool = True,
+     ):
+         """
+         Forward pass computes loss and stores dx and dw for backward.
+         """
+         ctx.weight_dtype = weight.dtype
+         x, weight = linear_fwd_convert_type(x, weight)
+         batch_shape = x.shape[:-1]
+         x = x.reshape(-1, x.shape[-1])
+         # TODO: don't need to compute bwd if neither x nor weight requires grad, or not training
+         loss, dx, dw, last_dlogits_chunk, last_x_chunk = chunked_linear_cross_entropy_fwd(
+             x, weight, target, chunk_size, ignore_index, tuned=tuned
+         )
+         loss_sum = loss.sum()
+         loss_scale = None if reduction == "sum" else 1.0 / (target != ignore_index).sum().float()
+         ctx.save_for_backward(dx, dw, last_dlogits_chunk, last_x_chunk, loss_scale)
+         ctx.batch_shape = batch_shape
+         ctx.ignore_index = ignore_index
+         ctx.reduction = reduction
+         ctx.tuned = tuned
+         return loss_sum if loss_scale is None else loss_sum * loss_scale
+
+     @staticmethod
+     @custom_bwd(device_type="cuda")
+     def backward(ctx, dloss):
+         """
+         Backward pass scales pre-computed gradients by dloss and completes
+         the last chunk's dw computation.
+         dloss is a scalar.
+         """
+         dx, dw, last_dlogits_chunk, last_x_chunk, loss_scale = ctx.saved_tensors
+         tuned = ctx.tuned
+         if loss_scale is not None:
+             dloss = dloss * loss_scale
+         # TODO: the case where x or weight doesn't require grad
+         dx.mul_(dloss)
+         dx = dx.reshape(*ctx.batch_shape, dx.shape[-1])
+         # Complete dw computation: dw = dloss * dw + dloss * (last_dlogits_chunk.T @ last_x_chunk)
+         if dw is None:
+             # Only had one chunk, compute dw directly with dloss scaling
+             dw = gemm(
+                 last_dlogits_chunk.T,
+                 last_x_chunk,
+                 out_dtype=ctx.weight_dtype,
+                 alpha=dloss,
+                 tuned=tuned,
+             )
+         else:
+             # Add last chunk's contribution with dloss scaling
+             # dw = dloss * dw + dloss * (last_dlogits_chunk.T @ last_x_chunk)
+             # We use alpha=dloss, beta=dloss
+             if ctx.weight_dtype == dw.dtype:
+                 gemm_add_inplace(
+                     last_dlogits_chunk.T, last_x_chunk, dw, alpha=dloss, beta=dloss, tuned=tuned
+                 )
+             else:
+                 dw = gemm_add(
+                     last_dlogits_chunk.T,
+                     last_x_chunk,
+                     dw,
+                     alpha=dloss,
+                     beta=dloss,
+                     out_dtype=ctx.weight_dtype,
+                     tuned=tuned,
+                 )
+         return dx, dw, None, None, None, None, None
+
+
+ def chunked_linear_cross_entropy(
+     x: Tensor,
+     weight: Tensor,
+     target: Tensor,
+     chunk_size: int = 4096,
+     ignore_index: int = -100,
+     reduction: Literal["mean", "sum"] = "mean",
+     tuned: bool = True,
+ ) -> Tensor:
+     """
+     Chunked linear cross entropy with automatic differentiation support.
+
+     Args:
+         x: Input tensor of shape (B*L, d)
+         weight: Weight tensor of shape (V, d)
+         target: Target indices of shape (B*L,)
+         chunk_size: Size of chunks to process
+         ignore_index: Index to ignore in loss computation
+         reduction: Type of reduction to apply
+         tuned: Whether to use tuned kernels
+
+     Returns:
+         Loss tensor with specified reduction
+     """
+     if reduction not in ["mean", "sum"]:
+         raise ValueError(f"Invalid reduction: {reduction}")
+     loss = ChunkedLinearCrossEntropyFunction.apply(
+         x, weight, target, ignore_index, reduction, chunk_size, tuned
+     )
+     return loss
+
+
+ class LinearCrossEntropy(nn.Linear):
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias: bool = False,
+         ignore_index: int = -100,
+         reduction: Literal["none", "mean", "sum"] = "mean",
+         chunk_size: Optional[int] = None,
+         inplace_backward: bool = False,
+         tuned: bool = True,
+         device=None,
+         dtype=None,
+     ) -> None:
+         super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
+         self.ignore_index = ignore_index
+         self.reduction = reduction
+         self.chunk_size = chunk_size
+         self.inplace_backward = inplace_backward
+         self.tuned = tuned
+
+     def forward(self, input: Tensor, target: Tensor) -> Tensor:
+         if (
+             self.bias is None
+             and input.is_cuda
+             and input.stride(-1) == 1
+             and self.in_features % 8 == 0
+             and self.out_features % 8 == 0
+             and input.shape[:-1].numel() % 8 == 0
+             and self.chunk_size is not None
+             and self.chunk_size % 8 == 0
+             and self.reduction in ["mean", "sum"]
+         ):
+             return chunked_linear_cross_entropy(
+                 input,
+                 self.weight,
+                 target,
+                 chunk_size=self.chunk_size,
+                 ignore_index=self.ignore_index,
+                 reduction=self.reduction,
+                 tuned=self.tuned,
+             )
+         else:
+             return linear_cross_entropy_func(
+                 input,
+                 self.weight,
+                 self.bias,
+                 target,
+                 ignore_index=self.ignore_index,
+                 reduction=self.reduction,
+                 inplace_backward=self.inplace_backward,
+             )
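A usage sketch of the chunked linear + cross-entropy head added above, assuming a CUDA device. The diff does not show this file's path, so the import is left implicit and LinearCrossEntropy is assumed to be in scope; shapes and vocabulary size are illustrative. The chunked path requires no bias, 8-aligned sizes, a chunk_size, and a "mean" or "sum" reduction; the forward precomputes dx and dw per chunk, and the backward only rescales them, as described in the docstrings above.

import torch
# Assumes LinearCrossEntropy from the new module above is importable in the current scope.
V, d = 32768, 1024
head = LinearCrossEntropy(d, V, chunk_size=4096, reduction="mean",
                          device="cuda", dtype=torch.bfloat16)
x = torch.randn(16 * 2048, d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, V, (16 * 2048,), device="cuda")
loss = head(x, target)   # bias-free, CUDA, 8-aligned, chunk_size set -> chunked path
loss.backward()          # scales the precomputed dx/dw and finishes the last chunk's dw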