quack-kernels 0.1.11__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in its public registry, and is provided for informational purposes only.
@@ -0,0 +1,158 @@
+ # Copyright (c) 2025, Tri Dao.
+ from typing import Optional, Tuple, Dict, Any
+ from dataclasses import dataclass
+
+ from torch import Tensor
+
+ import cutlass.cute as cute
+ from cutlass import Int32
+ from cutlass.cute.runtime import from_dlpack, make_ptr
+
+ from quack.cute_dsl_utils import torch2cute_dtype_map
+ from quack.dense_gemm_sm90 import TileSchedulerOptions
+
+
+ @dataclass
+ class GemmTensorInfo:
+     tensor: Optional[Tensor]
+     dtype: Optional[Any] = None
+     major: Optional[str] = None
+     cute_tensor: Optional[cute.Tensor] = None
+
+
+ class GemmWrapperBase:
+     @staticmethod
+     def validate_tensor_3d(tensor: Tensor, name: str) -> None:
+         assert tensor.dim() == 3 and tensor.is_cuda, f"{name} must be a 3D CUDA tensor"
+         assert tensor.dtype in torch2cute_dtype_map, f"Unsupported dtype for {name}"
+
+     @staticmethod
+     def validate_shape(tensor: Tensor, expected_shape: Tuple[int, ...], name: str) -> None:
+         assert tensor.shape == expected_shape, (
+             f"{name} must have shape {expected_shape}, got {tensor.shape}"
+         )
+
+     @staticmethod
+     def get_major_order(tensor: Tensor, dims: Tuple[str, str, str]) -> str:
+         # Tensor is already permuted to (dims[0], dims[1], dims[2])
+         # stride(1) == 1 means dims[1] is contiguous (innermost)
+         return dims[1] if tensor.stride(1) == 1 else dims[0]
+
+     @staticmethod
+     def create_cute_tensor(
+         tensor: Optional[Tensor],
+         major: Optional[str],
+         dims: Tuple[str, str, str],
+         assumed_align: int = 16,
+     ) -> Optional[cute.Tensor]:
+         if tensor is None:
+             return None
+         # Tensor is already permuted to (dims[0], dims[1], dims[2])
+         # If major is dims[1], leading_dim is 1; if major is dims[0], leading_dim is 0
+         leading_dim = 1 if major == dims[1] else 0
+         return from_dlpack(tensor.detach(), assumed_align=assumed_align).mark_layout_dynamic(
+             leading_dim=leading_dim
+         )
+
+     @staticmethod
+     def validate_and_prepare_tensors(
+         A: Tensor,
+         B: Tensor,
+         D: Optional[Tensor] = None,
+         C: Optional[Tensor] = None,
+         additional_tensors: Optional[Dict[str, Tensor]] = None,
+     ) -> Tuple[int, int, int, int, Dict[str, GemmTensorInfo]]:
+         GemmWrapperBase.validate_tensor_3d(A, "A")
+         L, M, K = A.shape
+         GemmWrapperBase.validate_tensor_3d(B, "B")
+         _, N, _ = B.shape
+         assert B.dtype == A.dtype, "A and B must have the same dtype"
+         GemmWrapperBase.validate_shape(B, (L, N, K), "B")
+         tensors = {
+             "A": GemmTensorInfo(A),
+             "B": GemmTensorInfo(B),
+             "D": GemmTensorInfo(D),
+             "C": GemmTensorInfo(C),
+         }
+         if D is not None:
+             GemmWrapperBase.validate_tensor_3d(D, "D")
+             GemmWrapperBase.validate_shape(D, (L, M, N), "D")
+         if C is not None:
+             GemmWrapperBase.validate_tensor_3d(C, "C")
+             GemmWrapperBase.validate_shape(C, (L, M, N), "C")
+         if additional_tensors:
+             for name, tensor in additional_tensors.items():
+                 if tensor is not None:
+                     GemmWrapperBase.validate_tensor_3d(tensor, name)
+                     GemmWrapperBase.validate_shape(tensor, (L, M, N), name)
+                 tensors[name] = GemmTensorInfo(tensor)
+
+         return L, M, K, N, tensors
+
+     @staticmethod
+     def permute_tensors(tensors: Dict[str, GemmTensorInfo]) -> None:
+         for info in tensors.values():
+             if info.tensor is not None:
+                 info.tensor = info.tensor.permute(1, 2, 0)
+
+     @staticmethod
+     def extract_dtypes(tensors: Dict[str, GemmTensorInfo]) -> None:
+         for info in tensors.values():
+             if info.tensor is not None:
+                 info.dtype = torch2cute_dtype_map[info.tensor.dtype]
+
+     @staticmethod
+     def determine_major_orders(
+         tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]]
+     ) -> None:
+         for name, dims in major_configs.items():
+             if name in tensors and tensors[name].tensor is not None:
+                 tensors[name].major = GemmWrapperBase.get_major_order(tensors[name].tensor, dims)
+
+     @staticmethod
+     def create_cute_tensors(
+         tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]]
+     ) -> None:
+         for name, info in tensors.items():
+             if info.tensor is not None and name in major_configs:
+                 info.cute_tensor = GemmWrapperBase.create_cute_tensor(
+                     info.tensor, info.major, major_configs[name]
+                 )
+
+     @staticmethod
+     def create_scheduler_args(
+         max_active_clusters: int, tile_count_semaphore: Optional[Tensor] = None
+     ) -> TileSchedulerOptions:
+         return TileSchedulerOptions(
+             Int32(max_active_clusters),
+             tile_count_semaphore=make_ptr(
+                 Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4
+             )
+             if tile_count_semaphore is not None
+             else None,
+         )
+
+     @staticmethod
+     def get_compile_key(
+         tensors: Dict[str, GemmTensorInfo],
+         activation: Optional[str],
+         tile_shape_mn: Tuple[int, int],
+         cluster_shape_mnk: Tuple[int, int, int],
+         pingpong: bool,
+         persistent: bool,
+         has_semaphore: bool,
+         *args,
+         key_tensor_names: Tuple[str, ...] = ("A", "B", "D", "C"),
+     ) -> Tuple:
+         key_parts = []
+         for name in key_tensor_names:
+             if name in tensors:
+                 key_parts.append(tensors[name].dtype)
+         key_parts.append(activation)
+         key_parts.extend([tile_shape_mn, cluster_shape_mnk])
+         for name in key_tensor_names:
+             if name in tensors:
+                 key_parts.append(tensors[name].major)
+         key_parts.extend([pingpong, persistent, has_semaphore])
+         key_parts.extend(args)
+         return tuple(key_parts)
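
For orientation, the new GemmTensorInfo/GemmWrapperBase helpers above are meant to be chained by a GEMM wrapper before a kernel launch: validate the 3D (L, M/N, K) operands, permute them so the batch dimension comes last, record dtypes and major orders, wrap everything for the CuTe DSL, and build a compile-cache key. The sketch below is hypothetical usage only; the module path is omitted because the diff does not name the new file, and the lowercase major_configs labels, tile shape, and cluster shape are illustrative assumptions rather than values taken from this package.

import torch

# Assumes GemmWrapperBase has been imported from the new module added in this diff.
L, M, N, K = 1, 1024, 1024, 512
A = torch.randn(L, M, K, device="cuda", dtype=torch.bfloat16)
B = torch.randn(L, N, K, device="cuda", dtype=torch.bfloat16)
D = torch.empty(L, M, N, device="cuda", dtype=torch.bfloat16)

L, M, K, N, tensors = GemmWrapperBase.validate_and_prepare_tensors(A, B, D)
GemmWrapperBase.permute_tensors(tensors)   # e.g. A: (L, M, K) -> (M, K, L)
GemmWrapperBase.extract_dtypes(tensors)

# Assumed dimension labels; get_major_order picks dims[1] when stride(1) == 1.
major_configs = {"A": ("m", "k", "l"), "B": ("n", "k", "l"), "D": ("m", "n", "l")}
GemmWrapperBase.determine_major_orders(tensors, major_configs)
GemmWrapperBase.create_cute_tensors(tensors, major_configs)

compile_key = GemmWrapperBase.get_compile_key(
    tensors, None, (128, 256), (1, 1, 1),
    pingpong=False, persistent=True, has_semaphore=False,
)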
quack/layernorm.py CHANGED
@@ -10,7 +10,9 @@ import cutlass
  import cutlass.cute as cute
  from cutlass.cute.runtime import from_dlpack
  import quack.utils as utils
- from quack.reduction_base import ReductionBase, torch2cute_dtype_map
+ from quack.reduce import row_reduce
+ from quack.reduction_base import ReductionBase
+ from quack.cute_dsl_utils import torch2cute_dtype_map


  class LayerNorm(ReductionBase):
@@ -190,7 +192,7 @@ class LayerNorm(ReductionBase):
          cute.autovec_copy(tXsX, tXrX)
          x = tXrX.load().to(cute.Float32)
          threads_per_row = tv_layout.shape[0][0]
-         sum_x = utils.row_reduce(
+         sum_x = row_reduce(
              x,
              cute.ReductionOp.ADD,
              threads_per_row,
@@ -207,7 +209,7 @@ class LayerNorm(ReductionBase):
          cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
          x = tXrX.load().to(cute.Float32)

-         sum_sq_x_sub_mean = utils.row_reduce(
+         sum_sq_x_sub_mean = row_reduce(
              (x - mean) * (x - mean),
              cute.ReductionOp.ADD,
              threads_per_row,
quack/linear.py CHANGED
@@ -1,4 +1,6 @@
  # Copyright (c) 2025, Tri Dao
+ from functools import partial
+
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
@@ -6,10 +8,7 @@ from torch import Tensor
  from torch.amp import custom_fwd, custom_bwd


- from gemm_cublas import gemm as gemm_cb, gemm_add_ as gemm_add_cb_
- # from gemm_cublas.interface import gemm_tuned as gemm_cb, gemm_add_tuned_ as gemm_add_cb_
-
- from quack import gemm, gemm_lse # TODO: implement these
+ from quack.gemm_interface import gemm, gemm_add_inplace, gemm_act, gemm_dact


  def linear_fwd_convert_type(*tensors):
@@ -19,7 +18,8 @@ def linear_fwd_convert_type(*tensors):
      return tensors


- def linear_fwd_postprocess(ctx, x, weight, weight_og, needs_input_grad, needs_weight_grad):
+ def linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad):
+     needs_input_grad, needs_weight_grad = needs_x_w_grad
      if not needs_input_grad:
          weight, weight_og = None, None
      if not needs_weight_grad:
@@ -27,29 +27,24 @@ def linear_fwd_postprocess(ctx, x, weight, weight_og, needs_input_grad, needs_we
      ctx.save_for_backward(x, weight, weight_og if ctx.fuse_grad_accum else None)


- def linear_bwd_compute_input_grad(ctx, dout, weight, use_tuned_gemm=True, sm_carveout=0):
+ def linear_bwd_compute_input_grad(ctx, dout, weight, matmul_fn):
      if ctx.needs_input_grad[0]:
          assert weight is not None
-         # return gemm(dout, weight) if use_tuned_gemm else (dout @ weight)
-         return (
-             gemm(dout, weight, sm_carveout=sm_carveout)
-             if use_tuned_gemm
-             else gemm_cb(dout, weight, sm_carveout=sm_carveout)
-         )
+         return matmul_fn(dout, weight)
      else:
          return None


- def linear_bwd_compute_weight_grad(ctx, dout, x, weight_og, sm_carveout=0):
+ def linear_bwd_compute_weight_grad(ctx, dout, x, weight_og, matmul_fn, matmul_inplace_fn):
      if ctx.needs_input_grad[1]:
          assert x is not None
          x = x.reshape(-1, x.shape[-1])
          # fuse_grad_accum is not compatible with torch.compile
          if not ctx.fuse_grad_accum or weight_og.grad is None or torch.compiler.is_compiling():
-             dweight = gemm_cb(dout.T, x, out_dtype=ctx.weight_dtype, sm_carveout=sm_carveout)
+             dweight = matmul_fn(dout.T, x, out_dtype=ctx.weight_dtype)
          else:
              # print("Using fuse grad accum in Linear", dout.shape, x.shape, weight_og.grad.shape)
-             gemm_add_cb_(dout.T, x, weight_og.grad, sm_carveout=sm_carveout)
+             matmul_inplace_fn(dout.T, x, weight_og.grad)
              dweight = weight_og.grad
              weight_og.grad = None # So that pytorch doesn't add dweight to weight_og.grad again
      else:
@@ -58,9 +53,15 @@ def linear_bwd_compute_weight_grad(ctx, dout, x, weight_og, sm_carveout=0):


  class LinearFunc(torch.autograd.Function):
-     @staticmethod
+     matmul_fwd_fn = gemm
+     matmul_bwd_dx = partial(gemm, dynamic_scheduler=True)
+     matmul_bwd_dw = partial(gemm, dynamic_scheduler=True)
+     matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)
+
+     # Use classmethod instead of staticmethod to allow inheritance
+     @classmethod
      @custom_fwd(device_type="cuda")
-     def forward(ctx, x, weight, fuse_grad_accum=False):
+     def forward(cls, ctx, x, weight, fuse_grad_accum=False):
          """
          x: (..., in_features)
          weight: (out_features, in_features)
@@ -73,77 +74,145 @@ class LinearFunc(torch.autograd.Function):
          batch_shape = x.shape[:-1]
          x = x.reshape(-1, x.shape[-1])
          # out = F.linear(x, weight)
-         out = gemm(x, weight.T)
-         linear_fwd_postprocess(
-             ctx,
-             x,
-             weight,
-             weight_og,
-             needs_input_grad=ctx.needs_input_grad[0],
-             needs_weight_grad=ctx.needs_input_grad[1],
-         )
+         out = cls.matmul_fwd_fn(x, weight.T)
+         linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2])
          return out.reshape(*batch_shape, out.shape[-1])

-     @staticmethod
+     @classmethod
      @custom_bwd(device_type="cuda")
-     def backward(ctx, dout):
+     def backward(cls, ctx, dout, *args):
          """
          dout: (..., out_features)
          """
          x, weight, weight_og = ctx.saved_tensors # weight_og is None if not ctx.fuse_grad_accum
          batch_shape = dout.shape[:-1]
          dout = dout.reshape(-1, dout.shape[-1])
-         dx = linear_bwd_compute_input_grad(ctx, dout, weight, use_tuned_gemm=True)
+         dx = linear_bwd_compute_input_grad(ctx, dout, weight, cls.matmul_bwd_dx)
          dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None
-         dweight = linear_bwd_compute_weight_grad(ctx, dout, x, weight_og)
-         return dx, dweight, None
+         dweight = linear_bwd_compute_weight_grad(
+             ctx, dout, x, weight_og, cls.matmul_bwd_dw, cls.matmul_bwd_dw_inplace
+         )
+         # return extra Nones for other classes that inherit from LinearFunc
+         return dx, dweight, *([None] * 10)
+
+
+ class LinearUntunedFunc(LinearFunc):
+     # Passing in tuned=False to disable tuning at runtime
+     matmul_fwd_fn = partial(gemm, tuned=False)
+     matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False)
+     matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
+     matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)


- def linear_func(x, weight, fuse_grad_accum=False):
-     return LinearFunc.apply(x, weight, fuse_grad_accum)
+ def linear_func(x, weight, fuse_grad_accum=False, tuned=True):
+     fn_cls = LinearFunc if tuned else LinearUntunedFunc
+     return fn_cls.apply(x, weight, fuse_grad_accum)


- class LinearLSEFunc(torch.autograd.Function):
-     @staticmethod
+ class LinearActFunc(LinearFunc):
+     matmul_fwd_fn = gemm_act
+
+     # Use classmethod instead of staticmethod to allow inheritance
+     @classmethod
      @custom_fwd(device_type="cuda")
-     def forward(ctx, x, weight, fuse_grad_accum=False):
+     def forward(cls, ctx, x, weight, activation, store_preact=True, fuse_grad_accum=False):
          """
          x: (..., in_features)
          weight: (out_features, in_features)
          out: (..., out_features)
+         Return both out and post-activation, but only out is differentiable.
          """
-         needs_weight_grad = weight.requires_grad
-         needs_input_grad = x.requires_grad
          ctx.weight_dtype = weight.dtype
          ctx.fuse_grad_accum = fuse_grad_accum
          weight_og = weight
          x, weight = linear_fwd_convert_type(x, weight)
          batch_shape = x.shape[:-1]
          x = x.reshape(-1, x.shape[-1])
-         out, lse = gemm_lse(x, weight.T)
-         lse = lse.reshape(*batch_shape)
-         linear_fwd_postprocess(ctx, x, weight, weight_og, needs_weight_grad, needs_input_grad)
-         ctx.mark_non_differentiable(lse)
-         return out.reshape(*batch_shape, out.shape[-1]), lse
+         out, postact = cls.matmul_fwd_fn(
+             x, weight.T, activation=activation, store_preact=store_preact
+         )
+         linear_fwd_postprocess(ctx, x, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2])
+         if out is not None:
+             out = out.reshape(*batch_shape, out.shape[-1])
+         ctx.mark_non_differentiable(postact)
+         ctx.set_materialize_grads(False) # We don't want to materialize grads for postact
+         return out, postact.reshape(*batch_shape, postact.shape[-1])

-     @staticmethod
+
+ class LinearActUntunedFunc(LinearActFunc):
+     # Passing in tuned=False to disable tuning at runtime
+     matmul_fwd_fn = partial(gemm_act, tuned=False)
+     matmul_bwd_dx = partial(gemm, dynamic_scheduler=True, tuned=False)
+     matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
+     matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)
+
+
+ def linear_act_func(x, weight, activation, store_preact=True, fuse_grad_accum=False, tuned=True):
+     fn_cls = LinearActFunc if tuned else LinearActUntunedFunc
+     return fn_cls.apply(x, weight, activation, store_preact, fuse_grad_accum)
+
+
+ class DActLinearFunc(LinearFunc):
+     matmul_bwd_dx = partial(gemm_dact, dynamic_scheduler=True)
+
+     # Use classmethod instead of staticmethod to allow inheritance
+     @classmethod
+     @custom_fwd(device_type="cuda")
+     def forward(cls, ctx, preact, weight, x, activation, fuse_grad_accum=False):
+         """
+         x: (..., in_features)
+         weight: (out_features, in_features)
+         out: (..., out_features)
+         Takes in an extra preact argument which is the pre-activation, to be used in the backward pass.
+         """
+         ctx.weight_dtype = weight.dtype
+         ctx.fuse_grad_accum = fuse_grad_accum
+         weight_og = weight
+         x, weight = linear_fwd_convert_type(x, weight)
+         batch_shape = x.shape[:-1]
+         x = x.reshape(-1, x.shape[-1])
+         out = cls.matmul_fwd_fn(x, weight.T)
+         # Store preact instead of x, we will recompute x in the backward pass
+         linear_fwd_postprocess(
+             ctx, preact, weight, weight_og, needs_x_w_grad=ctx.needs_input_grad[:2]
+         )
+         ctx.activation = activation
+         return out.reshape(*batch_shape, out.shape[-1])
+
+     @classmethod
      @custom_bwd(device_type="cuda")
-     def backward(ctx, dout, dlse_ignored):
+     def backward(cls, ctx, dout):
          """
          dout: (..., out_features)
          """
-         x, weight, weight_og = ctx.saved_tensors # weight_og is None if not ctx.fuse_grad_accum
+         # weight_og is None if not ctx.fuse_grad_accum
+         preact, weight, weight_og = ctx.saved_tensors
          batch_shape = dout.shape[:-1]
          dout = dout.reshape(-1, dout.shape[-1])
-         # cuBLAS seems faster for this so we just use it instead of cutlass gemm
-         dx = linear_bwd_compute_input_grad(ctx, dout, weight, use_tuned_gemm=False)
-         dx = dx.reshape(*batch_shape, dx.shape[-1]) if dx is not None else None
-         dweight = linear_bwd_compute_weight_grad(ctx, dout, x, weight_og)
-         return dx, dweight, None
+         preact = preact.reshape(-1, preact.shape[-1])
+         if ctx.needs_input_grad[0]:
+             assert weight is not None
+             dpreact, x = cls.matmul_bwd_dx(dout, weight, preact, activation=ctx.activation)
+         else:
+             dpreact, x = None, None
+         dpreact = dpreact.reshape(*batch_shape, dpreact.shape[-1]) if dpreact is not None else None
+         dweight = linear_bwd_compute_weight_grad(
+             ctx, dout, x, weight_og, cls.matmul_bwd_dw, cls.matmul_bwd_dw_inplace
+         )
+         return dpreact, dweight, *([None] * 3)
+

+ class DActLinearUntunedFunc(DActLinearFunc):
+     # Passing in tuned=False to disable tuning at runtime
+     matmul_fwd_fn = partial(gemm, tuned=False)
+     matmul_bwd_dx = partial(gemm_dact, dynamic_scheduler=True, tuned=False)
+     matmul_bwd_dw = partial(gemm, dynamic_scheduler=True, tuned=False)
+     matmul_bwd_dw_inplace = partial(gemm_add_inplace, dynamic_scheduler=True)

- def linear_lse_func(x, weight, fuse_grad_accum=False):
-     return LinearLSEFunc.apply(x, weight, fuse_grad_accum)
+
+ def act_linear_func(preact, weight, x, activation, fuse_grad_accum=False, tuned=True):
+     fn_cls = DActLinearFunc if tuned else DActLinearUntunedFunc
+     return fn_cls.apply(preact, weight, x, activation, fuse_grad_accum)


  class Linear(nn.Linear):
@@ -160,17 +229,12 @@ class Linear(nn.Linear):
          self.fuse_grad_accum = fuse_grad_accum

      def forward(self, input: Tensor) -> Tensor:
-         if self.bias is None and input.is_cuda:
+         if (
+             self.bias is None
+             and input.is_cuda
+             and self.in_features % 8 == 0
+             and self.out_features % 8 == 0
+         ):
              return linear_func(input, self.weight, fuse_grad_accum=self.fuse_grad_accum)
          else:
              return F.linear(input, self.weight, self.bias)
-
-
- class LinearLSE(Linear):
-     def forward(self, input: Tensor) -> Tensor:
-         if self.bias is None and input.is_cuda:
-             return linear_lse_func(input, self.weight, fuse_grad_accum=self.fuse_grad_accum)
-         else:
-             out = F.linear(input, self.weight, self.bias)
-             lse = torch.logsumexp(out, dim=-1)
-             return out, lse
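
Taken together, the quack/linear.py changes drop the cuBLAS-backed LSE path and route everything through quack.gemm_interface, with activation fusion handled by the new LinearActFunc/DActLinearFunc pair: the first GEMM returns the pre-activation plus the post-activation, and the second GEMM's backward recomputes the post-activation and fuses the activation gradient via gemm_dact. The sketch below is one plausible way to chain the new entry points in an MLP block; the "gelu" activation name, the shapes, and the weights are illustrative assumptions, not values from this diff.

import torch
from quack.linear import linear_act_func, act_linear_func

x = torch.randn(2, 512, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w1 = torch.randn(11008, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w2 = torch.randn(4096, 11008, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# First GEMM: returns the pre-activation (differentiable) and the post-activation
# (marked non-differentiable); store_preact=True keeps preact for the backward pass.
preact, h = linear_act_func(x, w1, activation="gelu", store_preact=True)

# Second GEMM: consumes the post-activation h and the saved preact, so its backward
# can recompute h and fuse the activation gradient with the input-grad GEMM.
out = act_linear_func(preact, w2, h, activation="gelu")

out.sum().backward()  # populates x.grad, w1.grad, w2.grad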