quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +8 -1
- quack/activation.py +288 -0
- quack/autotuner.py +310 -0
- quack/cross_entropy.py +325 -175
- quack/cute_dsl_utils.py +119 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +1657 -842
- quack/fast_math.py +80 -0
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +69 -0
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +569 -0
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +5 -3
- quack/linear.py +240 -0
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +74 -0
- quack/pipeline.py +151 -0
- quack/reduce.py +241 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +583 -231
- quack/softmax.py +27 -15
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2091 -0
- quack/tensormap_manager.py +115 -0
- quack/tile_scheduler.py +937 -0
- quack/topk.py +227 -0
- quack/utils.py +203 -230
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/METADATA +2 -2
- quack_kernels-0.2.0.dist-info/RECORD +37 -0
- quack_kernels-0.1.10.dist-info/RECORD +0 -13
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/top_level.txt +0 -0
quack/cross_entropy.py
CHANGED
|
@@ -1,17 +1,22 @@
|
|
|
1
1
|
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
|
2
2
|
|
|
3
3
|
import math
|
|
4
|
-
from typing import Optional, Type
|
|
4
|
+
from typing import Optional, Type, Literal
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torch import Tensor
|
|
5
8
|
|
|
6
9
|
import cuda.bindings.driver as cuda
|
|
7
10
|
|
|
8
11
|
import cutlass
|
|
9
12
|
import cutlass.cute as cute
|
|
13
|
+
from cutlass import Int32, Float32, Boolean, const_expr
|
|
14
|
+
from cutlass.cute.runtime import from_dlpack
|
|
10
15
|
|
|
11
16
|
import quack.utils as utils
|
|
12
|
-
import
|
|
13
|
-
from
|
|
14
|
-
from quack.
|
|
17
|
+
from quack.reduce import row_reduce, online_softmax_reduce
|
|
18
|
+
from quack.reduction_base import ReductionBase
|
|
19
|
+
from quack.cute_dsl_utils import torch2cute_dtype_map
|
|
15
20
|
|
|
16
21
|
|
|
17
22
|
class CrossEntropy(ReductionBase):
|
|
@@ -21,7 +26,7 @@ class CrossEntropy(ReductionBase):
|
|
|
21
26
|
dtype,
|
|
22
27
|
N,
|
|
23
28
|
stage=2 if not online_softmax else 1,
|
|
24
|
-
reduction_dtype=
|
|
29
|
+
reduction_dtype=Float32 if not online_softmax else cutlass.Int64,
|
|
25
30
|
)
|
|
26
31
|
self.online_softmax = online_softmax
|
|
27
32
|
self.reload_from = None if N <= 16384 or online_softmax else "smem"
|
|
@@ -40,7 +45,7 @@ class CrossEntropy(ReductionBase):
|
|
|
40
45
|
|
|
41
46
|
def _set_cluster_n(self):
|
|
42
47
|
N = self.N
|
|
43
|
-
if
|
|
48
|
+
if const_expr(self.dtype.width == 16):
|
|
44
49
|
cluster_n = (
|
|
45
50
|
1
|
|
46
51
|
if N <= 16 * 1024
|
|
@@ -65,21 +70,30 @@ class CrossEntropy(ReductionBase):
|
|
|
65
70
|
@cute.jit
|
|
66
71
|
def __call__(
|
|
67
72
|
self,
|
|
68
|
-
mX: cute.Tensor,
|
|
69
|
-
mTarget: cute.Tensor,
|
|
70
|
-
|
|
71
|
-
|
|
73
|
+
mX: cute.Tensor, # (M, N)
|
|
74
|
+
mTarget: cute.Tensor, # (M,)
|
|
75
|
+
mTargetLogit: Optional[cute.Tensor], # (M, K) or (M,). If None, we use mX
|
|
76
|
+
mLoss: cute.Tensor, # (M,)
|
|
77
|
+
mLSE: Optional[cute.Tensor], # (M,)
|
|
78
|
+
mdX: Optional[cute.Tensor], # (M, N) - if provided, compute gradient
|
|
79
|
+
ignore_index: Int32, # Index to ignore in loss computation
|
|
72
80
|
stream: cuda.CUstream,
|
|
73
81
|
):
|
|
74
82
|
assert mX.element_type == self.dtype
|
|
83
|
+
if const_expr(mTargetLogit is None):
|
|
84
|
+
mTargetLogit = mX
|
|
75
85
|
self._set_cluster_n()
|
|
76
|
-
|
|
86
|
+
# e.g. if self.N isn't divisible by 8 for bf16, we might use 64 bits (4 elements) copy
|
|
87
|
+
num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
|
|
88
|
+
tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
|
|
77
89
|
num_threads = cute.size(tv_layout, mode=[0])
|
|
78
90
|
num_warps = num_threads // cute.arch.WARP_SIZE
|
|
79
|
-
self.kernel(
|
|
91
|
+
self.kernel(
|
|
92
|
+
mX, mTarget, mTargetLogit, mLoss, mLSE, mdX, ignore_index, tv_layout, tiler_mn
|
|
93
|
+
).launch(
|
|
80
94
|
grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
|
|
81
95
|
block=[num_threads, 1, 1],
|
|
82
|
-
cluster=([1, self.cluster_n, 1] if
|
|
96
|
+
cluster=([1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None),
|
|
83
97
|
smem=self._smem_size_in_bytes(tiler_mn, num_warps),
|
|
84
98
|
stream=stream,
|
|
85
99
|
)
|
|
@@ -89,17 +103,20 @@ class CrossEntropy(ReductionBase):
|
|
|
89
103
|
self,
|
|
90
104
|
mX: cute.Tensor, # (M, N)
|
|
91
105
|
mTarget: cute.Tensor, # (M,)
|
|
106
|
+
mTargetLogit: cute.Tensor, # (M, K) or (M,)
|
|
92
107
|
mLoss: cute.Tensor, # (M,)
|
|
93
108
|
mLSE: Optional[cute.Tensor], # (M,)
|
|
109
|
+
mdX: Optional[cute.Tensor], # (M, N) - if provided, compute gradient
|
|
110
|
+
ignore_index: Int32, # Index to ignore in loss computation
|
|
94
111
|
tv_layout: cute.Layout,
|
|
95
112
|
tiler_mn: cute.Shape,
|
|
96
113
|
):
|
|
97
114
|
tidx, _, _ = cute.arch.thread_idx()
|
|
98
115
|
bidx, _, _ = cute.arch.block_idx()
|
|
99
|
-
if
|
|
116
|
+
if const_expr(self.cluster_n > 1):
|
|
100
117
|
cluster_y = cute.arch.block_idx()[1]
|
|
101
118
|
else:
|
|
102
|
-
cluster_y =
|
|
119
|
+
cluster_y = const_expr(0)
|
|
103
120
|
|
|
104
121
|
shape: cute.Shape = mX.shape
|
|
105
122
|
idX = cute.make_identity_tensor(shape)
|
|
@@ -118,12 +135,14 @@ class CrossEntropy(ReductionBase):
|
|
|
118
135
|
reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
|
|
119
136
|
|
|
120
137
|
# declare the atoms which will be used later for memory copy
|
|
138
|
+
num_copy_elems_X = tv_layout.shape[1][0]
|
|
139
|
+
num_copy_bits_X = mX.element_type.width * num_copy_elems_X
|
|
121
140
|
copy_atom_load_X = cute.make_copy_atom(
|
|
122
|
-
cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=
|
|
141
|
+
cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=num_copy_bits_X
|
|
123
142
|
)
|
|
124
143
|
thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
|
|
125
144
|
|
|
126
|
-
####
|
|
145
|
+
#### Partition to get thread view
|
|
127
146
|
tXgX = thr_copy_X.partition_S(gX)
|
|
128
147
|
tXsX = thr_copy_X.partition_D(sX)
|
|
129
148
|
tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
|
|
@@ -133,14 +152,14 @@ class CrossEntropy(ReductionBase):
|
|
|
133
152
|
self._initialize_cluster(tidx, mbar_ptr, num_warps)
|
|
134
153
|
|
|
135
154
|
row = tXcX[0][0]
|
|
136
|
-
target =
|
|
137
|
-
if row < shape[0]
|
|
138
|
-
target =
|
|
155
|
+
target = Int32.zero
|
|
156
|
+
if row < shape[0]:
|
|
157
|
+
target = Int32(mTarget[row])
|
|
139
158
|
|
|
140
|
-
is_even_N =
|
|
159
|
+
is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
|
|
141
160
|
tXpX = (
|
|
142
161
|
utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
|
|
143
|
-
if
|
|
162
|
+
if const_expr(not is_even_N)
|
|
144
163
|
else None
|
|
145
164
|
)
|
|
146
165
|
if row < shape[0]:
|
|
@@ -148,58 +167,62 @@ class CrossEntropy(ReductionBase):
|
|
|
148
167
|
cute.arch.cp_async_commit_group()
|
|
149
168
|
cute.arch.cp_async_wait_group(0)
|
|
150
169
|
# Fill OOB values with -inf
|
|
151
|
-
if
|
|
170
|
+
if const_expr(not is_even_N):
|
|
152
171
|
utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
|
|
153
172
|
cute.autovec_copy(tXsX, tXrX)
|
|
154
|
-
x = tXrX.load().to(
|
|
155
|
-
|
|
156
|
-
target_logit =
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
173
|
+
x = tXrX.load().to(Float32)
|
|
174
|
+
|
|
175
|
+
target_logit = Float32.zero
|
|
176
|
+
should_ignore = Boolean(target == ignore_index)
|
|
177
|
+
if row < shape[0] and tXcX[0][1] == 0 and not should_ignore:
|
|
178
|
+
# Only load target logit if not ignoring this index
|
|
179
|
+
if const_expr(cute.rank(mTargetLogit.shape) == 2):
|
|
180
|
+
# Use Int64 for indexing to deal with large tensors
|
|
181
|
+
mTargetLogit_off = utils.domain_offset_i64((row, 0), mTargetLogit)
|
|
182
|
+
target_logit = Float32(mTargetLogit_off[0, target])
|
|
183
|
+
else:
|
|
184
|
+
assert cute.rank(mTargetLogit.shape) == 1
|
|
185
|
+
target_logit = Float32(mTargetLogit[row])
|
|
161
186
|
|
|
162
187
|
threads_per_row = tv_layout.shape[0][0]
|
|
163
|
-
if
|
|
164
|
-
max_x =
|
|
188
|
+
if const_expr(not self.online_softmax):
|
|
189
|
+
max_x = row_reduce(
|
|
165
190
|
x,
|
|
166
191
|
cute.ReductionOp.MAX,
|
|
167
192
|
threads_per_row,
|
|
168
193
|
reduction_buffer[None, None, 0],
|
|
169
|
-
mbar_ptr + 0 if
|
|
170
|
-
init_val=-
|
|
171
|
-
hook_fn=(
|
|
172
|
-
cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
|
|
173
|
-
),
|
|
194
|
+
mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None,
|
|
195
|
+
init_val=-Float32.inf,
|
|
196
|
+
hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
|
|
174
197
|
)
|
|
175
|
-
if
|
|
198
|
+
if const_expr(self.reload_from == "smem"):
|
|
176
199
|
cute.autovec_copy(tXsX, tXrX)
|
|
177
|
-
x = tXrX.load().to(
|
|
200
|
+
x = tXrX.load().to(Float32)
|
|
178
201
|
log2_e = math.log2(math.e)
|
|
179
202
|
# exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
|
|
180
203
|
# a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
|
|
181
204
|
# exp_x = utils.exp2f((x - max_x) * log2_e)
|
|
182
205
|
# This would use ffma instead of fadd then fmul
|
|
183
206
|
exp_x = utils.exp2f(x * log2_e - (max_x * log2_e))
|
|
184
|
-
denom =
|
|
207
|
+
denom = row_reduce(
|
|
185
208
|
exp_x,
|
|
186
209
|
cute.ReductionOp.ADD,
|
|
187
210
|
threads_per_row,
|
|
188
211
|
reduction_buffer[None, None, 1],
|
|
189
|
-
mbar_ptr + 1 if
|
|
212
|
+
mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None,
|
|
190
213
|
init_val=0.0,
|
|
191
214
|
)
|
|
192
215
|
else:
|
|
193
|
-
max_x, denom,
|
|
216
|
+
max_x, denom, exp_x = online_softmax_reduce(
|
|
194
217
|
x,
|
|
195
218
|
threads_per_row,
|
|
196
219
|
reduction_buffer[None, None, 0],
|
|
197
220
|
mbar_ptr,
|
|
198
|
-
hook_fn=(
|
|
199
|
-
|
|
200
|
-
),
|
|
221
|
+
hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
|
|
222
|
+
return_exp_x=const_expr(mdX is not None),
|
|
201
223
|
)
|
|
202
224
|
|
|
225
|
+
# Write loss and lse to gmem
|
|
203
226
|
if (
|
|
204
227
|
tXcX[0][1] == 0
|
|
205
228
|
and row < shape[0]
|
|
@@ -207,40 +230,89 @@ class CrossEntropy(ReductionBase):
|
|
|
207
230
|
):
|
|
208
231
|
ln_2 = math.log(2.0)
|
|
209
232
|
lse = max_x + utils.log2f(denom) * ln_2
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
233
|
+
# Set loss to 0 if this index should be ignored, otherwise compute normally
|
|
234
|
+
loss_val = (lse - target_logit) if not should_ignore else Float32.zero
|
|
235
|
+
mLoss[row] = mLoss.element_type(loss_val)
|
|
236
|
+
if const_expr(mLSE is not None):
|
|
213
237
|
mLSE[row] = lse
|
|
214
238
|
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
239
|
+
# Compute gradient if mdX is provided
|
|
240
|
+
if const_expr(mdX is not None):
|
|
241
|
+
# Compute probabilities: exp(x) / sum(exp(x))
|
|
242
|
+
# If ignored, gradient should be zero
|
|
243
|
+
denom_inv = (
|
|
244
|
+
1.0 / denom
|
|
245
|
+
if not (denom == 0.0 or denom != denom or should_ignore)
|
|
246
|
+
else Float32.zero
|
|
247
|
+
)
|
|
248
|
+
probs = exp_x * denom_inv
|
|
249
|
+
mdX_off = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mdX)
|
|
250
|
+
gdX = cute.local_tile(mdX_off, tiler_mn, (0, cluster_y))
|
|
251
|
+
# Setup copy atom for storing gradient
|
|
252
|
+
copy_atom_store = cute.make_copy_atom(
|
|
253
|
+
cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_copy_bits_X
|
|
254
|
+
)
|
|
255
|
+
thr_copy_dX = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx)
|
|
256
|
+
tXgdX = thr_copy_dX.partition_D(gdX)
|
|
257
|
+
tXrdX = cute.make_fragment_like(tXgdX)
|
|
258
|
+
tXcFull = thr_copy_X.partition_S(cX)
|
|
259
|
+
# Compute gradient: probs for all classes, (probs - 1) for target class
|
|
260
|
+
# If ignored, gradient is already zero
|
|
261
|
+
tXrdX_f32 = cute.make_fragment_like(tXrX, Float32)
|
|
262
|
+
tXrdX_f32.store(probs)
|
|
263
|
+
if not should_ignore:
|
|
264
|
+
for i in cutlass.range(cute.size(tXrX), unroll_full=True):
|
|
265
|
+
tXrdX_f32[i] = tXrdX_f32[i] if tXcFull[i][1] != target else tXrdX_f32[i] - 1.0
|
|
266
|
+
tXrdX.store(tXrdX_f32.load().to(tXrdX.element_type))
|
|
267
|
+
tXpdX = (
|
|
268
|
+
utils.predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1])
|
|
269
|
+
if not is_even_N
|
|
270
|
+
else None
|
|
271
|
+
)
|
|
272
|
+
if row < shape[0]:
|
|
273
|
+
cute.copy(copy_atom_store, tXrdX, tXgdX, pred=tXpdX)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
@torch.library.custom_op("quack::cross_entropy_fwd_out", mutates_args={"loss", "lse", "dx"})
|
|
277
|
+
def cross_entropy_fwd_out(
|
|
278
|
+
x: Tensor,
|
|
279
|
+
target: Tensor,
|
|
280
|
+
target_logit: Optional[Tensor],
|
|
281
|
+
loss: Tensor,
|
|
282
|
+
lse: Optional[Tensor],
|
|
283
|
+
dx: Optional[Tensor],
|
|
284
|
+
ignore_index: int = -100,
|
|
285
|
+
) -> None:
|
|
221
286
|
"""Cross entropy forward pass.
|
|
222
287
|
|
|
223
288
|
Args:
|
|
224
289
|
x: Input logits tensor of shape (M, N)
|
|
225
290
|
target: Target class indices tensor of shape (M,)
|
|
291
|
+
target_logit: (M, K) or (M,).
|
|
292
|
+
If provided, the target logit will be read from this tensor instead of x.
|
|
293
|
+
loss: Output loss tensor of shape (M,)
|
|
294
|
+
lse: Optional output log-sum-exp tensor of shape (M,)
|
|
295
|
+
dx: Optional output gradient tensor of shape (M, N)
|
|
296
|
+
ignore_index: Index to ignore in loss computation
|
|
226
297
|
|
|
227
298
|
Returns:
|
|
228
|
-
|
|
299
|
+
None (mutates loss, lse, and optionally dx in-place)
|
|
229
300
|
"""
|
|
230
301
|
assert x.dim() == 2, "Input must be 2D"
|
|
231
302
|
assert target.dim() == 1, "Target must be 1D"
|
|
232
303
|
assert x.shape[0] == target.shape[0], "Batch dimensions must match"
|
|
233
304
|
assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
|
|
234
|
-
assert x.dtype in [
|
|
235
|
-
torch.float16,
|
|
236
|
-
torch.bfloat16,
|
|
237
|
-
torch.float32,
|
|
238
|
-
], "Unsupported input dtype"
|
|
305
|
+
assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
|
|
239
306
|
assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
307
|
+
if target_logit is not None:
|
|
308
|
+
assert target_logit.shape[0] == x.shape[0]
|
|
309
|
+
assert target_logit.is_cuda, "Target logits must be on CUDA device"
|
|
310
|
+
assert target_logit.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
|
311
|
+
if dx is not None:
|
|
312
|
+
assert dx.shape == x.shape, "dx must have same shape as x"
|
|
313
|
+
assert dx.is_cuda, "dx must be on CUDA device"
|
|
314
|
+
assert dx.dtype == x.dtype, "dx must have same dtype as x"
|
|
315
|
+
N = x.size(1)
|
|
244
316
|
dtype = torch2cute_dtype_map[x.dtype]
|
|
245
317
|
convert_from_dlpack = lambda tensor: (
|
|
246
318
|
from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
|
|
@@ -248,28 +320,86 @@ def _cross_entropy(
|
|
|
248
320
|
)
|
|
249
321
|
)
|
|
250
322
|
x_tensor = convert_from_dlpack(x)
|
|
251
|
-
loss_tensor = from_dlpack(loss.detach(), assumed_align=4).
|
|
323
|
+
loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_layout_dynamic()
|
|
252
324
|
lse_tensor = (
|
|
253
|
-
from_dlpack(lse.detach(), assumed_align=4).
|
|
325
|
+
from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic()
|
|
254
326
|
if lse is not None
|
|
255
327
|
else None
|
|
256
328
|
)
|
|
257
|
-
target_tensor = from_dlpack(target.detach(), assumed_align=8).
|
|
329
|
+
target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_layout_dynamic()
|
|
330
|
+
target_logit_tensor = (
|
|
331
|
+
from_dlpack(target_logit.detach(), assumed_align=4).mark_layout_dynamic(
|
|
332
|
+
leading_dim=target_logit.ndim - 1
|
|
333
|
+
)
|
|
334
|
+
if target_logit is not None
|
|
335
|
+
else None
|
|
336
|
+
)
|
|
337
|
+
dx_tensor = convert_from_dlpack(dx) if dx is not None else None
|
|
258
338
|
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
|
259
339
|
|
|
260
|
-
compile_key = (
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
340
|
+
compile_key = (
|
|
341
|
+
dtype,
|
|
342
|
+
N,
|
|
343
|
+
target_logit.dtype if target_logit is not None else None,
|
|
344
|
+
lse.dtype if lse is not None else None,
|
|
345
|
+
dx is not None,
|
|
346
|
+
loss.stride(),
|
|
347
|
+
lse.stride() if lse is not None else None,
|
|
348
|
+
target.stride(),
|
|
349
|
+
target_logit.stride(-1) if target_logit is not None else None,
|
|
350
|
+
)
|
|
351
|
+
if compile_key not in cross_entropy_fwd_out.compile_cache:
|
|
352
|
+
# If there's dx, it's faster to not use online softmax since we want the exp(x - max)
|
|
353
|
+
cross_entropy_op = CrossEntropy(dtype, N, online_softmax=dx is None)
|
|
354
|
+
cross_entropy_fwd_out.compile_cache[compile_key] = cute.compile(
|
|
355
|
+
cross_entropy_op,
|
|
356
|
+
x_tensor,
|
|
357
|
+
target_tensor,
|
|
358
|
+
target_logit_tensor,
|
|
359
|
+
loss_tensor,
|
|
360
|
+
lse_tensor,
|
|
361
|
+
dx_tensor,
|
|
362
|
+
Int32(ignore_index),
|
|
363
|
+
stream,
|
|
265
364
|
)
|
|
266
|
-
|
|
267
|
-
x_tensor,
|
|
365
|
+
cross_entropy_fwd_out.compile_cache[compile_key](
|
|
366
|
+
x_tensor,
|
|
367
|
+
target_tensor,
|
|
368
|
+
target_logit_tensor,
|
|
369
|
+
loss_tensor,
|
|
370
|
+
lse_tensor,
|
|
371
|
+
dx_tensor,
|
|
372
|
+
Int32(ignore_index),
|
|
373
|
+
stream,
|
|
268
374
|
)
|
|
269
|
-
return loss if not return_lse else (loss, lse)
|
|
270
375
|
|
|
271
376
|
|
|
272
|
-
|
|
377
|
+
# Compiled forward kernels, looked up and filled in by cross_entropy_fwd_out under its compile_key.
cross_entropy_fwd_out.compile_cache = {}
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def cross_entropy_fwd(
    x: torch.Tensor,
    target: torch.Tensor,
    target_logit: Optional[torch.Tensor] = None,
    ignore_index: int = -100,
    return_lse: bool = False,
    return_dx: bool = False,
    inplace_backward: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
    """Allocate the output buffers and run ``cross_entropy_fwd_out``.

    Args:
        x: Input logits; one loss value is produced per row (``x.size(0)``).
        target: Target class indices, forwarded unchanged to the op.
        target_logit: Optional tensor forwarded to the op as the source of the
            target logit.
        ignore_index: Target value forwarded to the op as the index to ignore.
        return_lse: If True, also allocate and return a float32 log-sum-exp
            buffer of length ``x.size(0)``.
        return_dx: If True, also return the gradient buffer filled by the op.
        inplace_backward: Only consulted when ``return_dx`` is True. The
            gradient is then written into ``x`` itself instead of a fresh
            ``torch.empty_like(x)`` buffer.

    Returns:
        ``loss`` alone when neither optional output is requested, otherwise a
        tuple in the order ``(loss, lse, dx)`` with unrequested entries omitted.
    """
    num_rows = x.size(0)
    device = x.device
    loss = torch.empty(num_rows, device=device, dtype=torch.float32)
    lse = torch.empty(num_rows, device=device, dtype=torch.float32) if return_lse else None
    dx = None
    if return_dx:
        # With inplace_backward the op overwrites x with its own gradient.
        dx = x if inplace_backward else torch.empty_like(x)
    cross_entropy_fwd_out(x, target, target_logit, loss, lse, dx, ignore_index)

    # Collect the requested outputs in the fixed (loss, lse, dx) order.
    outputs = [loss]
    if return_lse:
        outputs.append(lse)
    if return_dx:
        outputs.append(dx)
    return loss if len(outputs) == 1 else tuple(outputs)
|
|
273
403
|
|
|
274
404
|
|
|
275
405
|
class CrossEntropyBackward:
|
|
@@ -279,7 +409,7 @@ class CrossEntropyBackward:
|
|
|
279
409
|
self.vecsize = 128 // dtype.width
|
|
280
410
|
|
|
281
411
|
def _calculate_threads_per_row(self):
|
|
282
|
-
N = self.N
|
|
412
|
+
N = min(self.N, 16384) # We split by blocks of 16k
|
|
283
413
|
return (
|
|
284
414
|
8
|
|
285
415
|
if N <= 64
|
|
@@ -290,13 +420,14 @@ class CrossEntropyBackward:
|
|
|
290
420
|
)
|
|
291
421
|
)
|
|
292
422
|
|
|
293
|
-
def _get_tv_layout(self):
|
|
294
|
-
|
|
295
|
-
vecsize
|
|
423
|
+
def _get_tv_layout(self, num_copy_bits=128):
|
|
424
|
+
vecsize = num_copy_bits // self.dtype.width
|
|
425
|
+
assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
|
|
426
|
+
N = min(self.N, 16384)
|
|
296
427
|
num_threads = 128 if N <= 16384 else 256
|
|
297
428
|
threads_per_row = self._calculate_threads_per_row()
|
|
298
429
|
cols_per_block = num_threads // threads_per_row
|
|
299
|
-
num_blocks_N = cute.ceil_div(
|
|
430
|
+
num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row)
|
|
300
431
|
tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
|
|
301
432
|
tv_layout = cute.make_layout(
|
|
302
433
|
((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
|
|
@@ -315,40 +446,27 @@ class CrossEntropyBackward:
|
|
|
315
446
|
mDLoss: cute.Tensor,
|
|
316
447
|
mdX: cute.Tensor,
|
|
317
448
|
mLSE: cute.Tensor,
|
|
449
|
+
ignore_index: Int32, # Index to ignore in gradient computation
|
|
318
450
|
stream: cuda.CUstream,
|
|
319
451
|
):
|
|
320
452
|
assert mX.element_type == self.dtype
|
|
321
453
|
assert mdX.element_type == self.dtype
|
|
322
|
-
|
|
323
|
-
|
|
454
|
+
# e.g. if self.N isn't divisible by 8 for bf16, we might use 64 bits (4 elements) copy
|
|
455
|
+
num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
|
|
456
|
+
tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
|
|
324
457
|
num_threads = cute.size(tv_layout, mode=[0])
|
|
325
|
-
|
|
326
|
-
mDLoss =
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,))),
|
|
333
|
-
)
|
|
334
|
-
mLSE = cute.make_tensor(
|
|
335
|
-
mLSE.iterator,
|
|
336
|
-
cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,))),
|
|
337
|
-
)
|
|
338
|
-
|
|
458
|
+
# (M,) -> (M, N) with stride 0 in the N dimension
|
|
459
|
+
mDLoss, mTarget, mLSE = [
|
|
460
|
+
cute.make_tensor(
|
|
461
|
+
X.iterator, cute.append(X.layout, cute.make_layout((self.N,), stride=(0,)))
|
|
462
|
+
)
|
|
463
|
+
for X in (mDLoss, mTarget, mLSE)
|
|
464
|
+
]
|
|
339
465
|
smem_size = cute.size_in_bytes(
|
|
340
466
|
mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0))
|
|
341
467
|
)
|
|
342
|
-
|
|
343
468
|
self.kernel(
|
|
344
|
-
mX,
|
|
345
|
-
mTarget,
|
|
346
|
-
mDLoss,
|
|
347
|
-
mdX,
|
|
348
|
-
mLSE,
|
|
349
|
-
mX.shape,
|
|
350
|
-
tv_layout,
|
|
351
|
-
tiler_mn,
|
|
469
|
+
mX, mTarget, mDLoss, mdX, mLSE, ignore_index, mX.shape, tv_layout, tiler_mn
|
|
352
470
|
).launch(
|
|
353
471
|
grid=[
|
|
354
472
|
cute.ceil_div(mX.shape[0], tiler_mn[0]),
|
|
@@ -368,6 +486,7 @@ class CrossEntropyBackward:
|
|
|
368
486
|
mDLoss: cute.Tensor, # (M,)
|
|
369
487
|
mdX: cute.Tensor, # (M, N)
|
|
370
488
|
mLSE: cute.Tensor, # (M,)
|
|
489
|
+
ignore_index: Int32, # Index to ignore in gradient computation
|
|
371
490
|
shape: cute.Shape,
|
|
372
491
|
tv_layout: cute.Layout,
|
|
373
492
|
tiler_mn: cute.Shape,
|
|
@@ -388,76 +507,67 @@ class CrossEntropyBackward:
|
|
|
388
507
|
gX, gdX = [cute.local_tile(mT, tiler_mn, (0, bidy)) for mT in (mX, mdX)]
|
|
389
508
|
cX = cute.local_tile(idX, tiler_mn, (bidx, bidy))
|
|
390
509
|
|
|
510
|
+
num_copy_elems_X = tv_layout.shape[1][0]
|
|
511
|
+
num_copy_bits_X = mX.element_type.width * num_copy_elems_X
|
|
391
512
|
copy_atom_load_X = cute.make_copy_atom(
|
|
392
|
-
cute.nvgpu.
|
|
513
|
+
cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=num_copy_bits_X
|
|
393
514
|
)
|
|
394
|
-
|
|
395
|
-
cute.nvgpu.
|
|
515
|
+
copy_atom_store_dX = cute.make_copy_atom(
|
|
516
|
+
cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=num_copy_bits_X
|
|
396
517
|
)
|
|
397
|
-
copy_atom_store_O = cute.make_copy_atom(
|
|
398
|
-
cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=128
|
|
399
|
-
)
|
|
400
|
-
|
|
401
518
|
thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
|
|
402
|
-
|
|
403
|
-
copy_atom_load_X_async, tv_layout, tiler_mn
|
|
404
|
-
).get_slice(tidx)
|
|
405
|
-
thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
|
|
406
|
-
|
|
407
|
-
#### Thread View
|
|
408
|
-
tXgX = thr_copy_X_async.partition_S(gX)
|
|
409
|
-
tXsX = thr_copy_X_async.partition_S(sX)
|
|
519
|
+
thr_copy_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
|
|
410
520
|
|
|
521
|
+
#### Partition to get thread view
|
|
522
|
+
tXgX = thr_copy_X.partition_S(gX)
|
|
523
|
+
tXsX = thr_copy_X.partition_S(sX)
|
|
411
524
|
tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
|
|
412
|
-
tXcFull = thr_copy_X.partition_S(cX)
|
|
413
|
-
|
|
414
|
-
tXgO = thr_copy_O.partition_D(gdX)
|
|
415
|
-
|
|
525
|
+
tXcFull = thr_copy_X.partition_S(cX)
|
|
526
|
+
tXgdX = thr_copy_dX.partition_D(gdX)
|
|
416
527
|
# allocate fragments for gmem->rmem
|
|
417
|
-
tXrX,
|
|
528
|
+
tXrX, tXrdX = [cute.make_fragment_like(thr) for thr in (tXgX, tXgdX)]
|
|
418
529
|
|
|
419
|
-
is_even_N =
|
|
530
|
+
is_even_N = const_expr(shape[1] % tiler_mn[1] == 0)
|
|
420
531
|
row = tXcX[0][0]
|
|
421
|
-
|
|
422
532
|
tXpX = (
|
|
423
|
-
utils.predicate_k(
|
|
424
|
-
if not is_even_N
|
|
425
|
-
else None
|
|
533
|
+
utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
|
|
426
534
|
)
|
|
427
|
-
|
|
428
535
|
if row < shape[0]:
|
|
429
|
-
cute.copy(
|
|
536
|
+
cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
|
|
430
537
|
cute.arch.cp_async_commit_group()
|
|
431
538
|
cute.arch.cp_async_wait_group(0)
|
|
432
|
-
if
|
|
539
|
+
if const_expr(not is_even_N):
|
|
433
540
|
utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
|
|
434
|
-
|
|
435
541
|
cute.autovec_copy(tXsX, tXrX)
|
|
436
|
-
x = tXrX.load().to(
|
|
542
|
+
x = tXrX.load().to(Float32)
|
|
437
543
|
|
|
438
|
-
|
|
439
|
-
dloss =
|
|
440
|
-
lse =
|
|
544
|
+
target = Int32.zero
|
|
545
|
+
dloss = Float32.zero
|
|
546
|
+
lse = Float32.zero
|
|
441
547
|
if row < shape[0]:
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
548
|
+
target = Int32(mTarget[row])
|
|
549
|
+
should_ignore = Boolean(target == ignore_index)
|
|
550
|
+
# Set dloss to 0 if this index should be ignored
|
|
551
|
+
dloss = Float32(mDLoss[row]) if not should_ignore else Float32.zero
|
|
552
|
+
lse = Float32(mLSE[row])
|
|
445
553
|
|
|
446
554
|
log2_e = math.log2(math.e)
|
|
447
|
-
probs = utils.exp2f(
|
|
555
|
+
probs = utils.exp2f(x * log2_e - lse * log2_e)
|
|
448
556
|
prob_shifted = probs - 1.0
|
|
449
557
|
mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
|
|
450
558
|
for i in cutlass.range(cute.size(tXcFull), unroll_full=True):
|
|
451
|
-
mask[i] = tXcFull[i][1] ==
|
|
559
|
+
mask[i] = tXcFull[i][1] == target
|
|
452
560
|
grad = cute.where(mask.load(), prob_shifted, probs)
|
|
453
561
|
grad = grad * dloss
|
|
454
562
|
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
utils.predicate_k(
|
|
563
|
+
tXrdX.store(grad.to(tXrdX.element_type))
|
|
564
|
+
tXpdX = (
|
|
565
|
+
utils.predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1])
|
|
566
|
+
if not is_even_N
|
|
567
|
+
else None
|
|
458
568
|
)
|
|
459
569
|
if row < shape[0]:
|
|
460
|
-
cute.copy(
|
|
570
|
+
cute.copy(copy_atom_store_dX, tXrdX, tXgdX, pred=tXpdX)
|
|
461
571
|
|
|
462
572
|
|
|
463
573
|
def _cross_entropy_backward(
|
|
@@ -465,8 +575,9 @@ def _cross_entropy_backward(
|
|
|
465
575
|
target: torch.Tensor,
|
|
466
576
|
dloss: torch.Tensor,
|
|
467
577
|
lse: torch.Tensor,
|
|
468
|
-
|
|
469
|
-
|
|
578
|
+
dx: torch.Tensor,
|
|
579
|
+
ignore_index=-100,
|
|
580
|
+
) -> None:
|
|
470
581
|
"""Cross entropy backward pass.
|
|
471
582
|
Args:
|
|
472
583
|
x: Input logits tensor of shape (M, N)
|
|
@@ -486,15 +597,10 @@ def _cross_entropy_backward(
|
|
|
486
597
|
assert (
|
|
487
598
|
x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda
|
|
488
599
|
), "Tensors must be on CUDA device"
|
|
489
|
-
assert x.dtype in [
|
|
490
|
-
torch.float16,
|
|
491
|
-
torch.bfloat16,
|
|
492
|
-
torch.float32,
|
|
493
|
-
], "Unsupported input dtype"
|
|
600
|
+
assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
|
|
494
601
|
assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
|
|
495
602
|
|
|
496
|
-
|
|
497
|
-
dx = torch.empty_like(x) if not inplace_backward else x
|
|
603
|
+
N = x.size(1)
|
|
498
604
|
dtype = torch2cute_dtype_map[x.dtype]
|
|
499
605
|
|
|
500
606
|
convert_from_dlpack = lambda tensor: (
|
|
@@ -504,14 +610,12 @@ def _cross_entropy_backward(
|
|
|
504
610
|
)
|
|
505
611
|
x_tensor = convert_from_dlpack(x)
|
|
506
612
|
dx_tensor = convert_from_dlpack(dx)
|
|
507
|
-
dloss_tensor = from_dlpack(dloss.detach(), assumed_align=
|
|
508
|
-
lse_tensor = from_dlpack(lse.detach(), assumed_align=
|
|
509
|
-
target_tensor = from_dlpack(target.detach(), assumed_align=
|
|
510
|
-
mode=0
|
|
511
|
-
)
|
|
613
|
+
dloss_tensor = from_dlpack(dloss.detach(), assumed_align=4).mark_layout_dynamic()
|
|
614
|
+
lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic()
|
|
615
|
+
target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_layout_dynamic()
|
|
512
616
|
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
|
513
617
|
|
|
514
|
-
compile_key = (dtype, N)
|
|
618
|
+
compile_key = (dtype, N, target.dtype, dloss.stride(), lse.stride(), target.stride())
|
|
515
619
|
if compile_key not in _cross_entropy_backward.compile_cache:
|
|
516
620
|
cross_entropy_backward_op = CrossEntropyBackward(dtype, N)
|
|
517
621
|
_cross_entropy_backward.compile_cache[compile_key] = cute.compile(
|
|
@@ -521,48 +625,95 @@ def _cross_entropy_backward(
|
|
|
521
625
|
dloss_tensor,
|
|
522
626
|
dx_tensor,
|
|
523
627
|
lse_tensor,
|
|
628
|
+
Int32(ignore_index),
|
|
524
629
|
stream,
|
|
525
630
|
)
|
|
526
631
|
_cross_entropy_backward.compile_cache[compile_key](
|
|
527
|
-
x_tensor, target_tensor, dloss_tensor, dx_tensor, lse_tensor, stream
|
|
632
|
+
x_tensor, target_tensor, dloss_tensor, dx_tensor, lse_tensor, Int32(ignore_index), stream
|
|
528
633
|
)
|
|
529
|
-
return dx
|
|
530
634
|
|
|
531
635
|
|
|
532
636
|
# Compiled backward kernels, looked up and filled in by _cross_entropy_backward under its compile_key.
_cross_entropy_backward.compile_cache = {}
|
|
533
637
|
|
|
534
638
|
|
|
639
|
+
@torch.library.custom_op("quack::cross_entropy_bwd_out", mutates_args={"dx"})
def cross_entropy_bwd_out(
    x: torch.Tensor,
    target: torch.Tensor,
    dloss: torch.Tensor,
    lse: torch.Tensor,
    dx: torch.Tensor,
    ignore_index: int = -100,
) -> None:
    """Custom-op entry point for the cross entropy backward pass.

    Forwards every argument to ``_cross_entropy_backward``. The result is
    delivered through ``dx`` (declared in ``mutates_args``); nothing is
    returned.
    """
    _cross_entropy_backward(
        x=x,
        target=target,
        dloss=dloss,
        lse=lse,
        dx=dx,
        ignore_index=ignore_index,
    )
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
def cross_entropy_bwd(
    x: torch.Tensor,
    target: torch.Tensor,
    dloss: torch.Tensor,
    lse: torch.Tensor,
    ignore_index: int = -100,
    inplace_backward: bool = False,
) -> torch.Tensor:
    """Run the cross entropy backward pass and return the gradient tensor.

    Args:
        x: Input logits saved from the forward pass.
        target: Target class indices.
        dloss: Upstream gradient of the loss.
        lse: Log-sum-exp saved from the forward pass.
        ignore_index: Target value forwarded to the kernel as the index to ignore.
        inplace_backward: If True and not running under ``torch.compile``, the
            gradient is written into ``x`` and ``x`` is returned. While
            compiling, the flag is not honored and a fresh buffer is used.

    Returns:
        The gradient tensor: ``x`` itself on the in-place path, otherwise a new
        tensor created with ``torch.empty_like(x)``.
    """
    if inplace_backward and not torch.compiler.is_compiling():
        # Eager in-place path: call the implementation directly rather than
        # the custom op, with the output aliasing the input.
        dx = x
        _cross_entropy_backward(
            x=x, target=target, dloss=dloss, lse=lse, dx=dx, ignore_index=ignore_index
        )
    else:
        dx = torch.empty_like(x)
        cross_entropy_bwd_out(
            x=x, target=target, dloss=dloss, lse=lse, dx=dx, ignore_index=ignore_index
        )
    return dx
|
|
670
|
+
|
|
671
|
+
|
|
535
672
|
class CrossEntropyFunction(torch.autograd.Function):
    """Autograd binding between ``cross_entropy_fwd`` and ``cross_entropy_bwd``."""

    @staticmethod
    def forward(ctx, x, target, lse_partial=None, ignore_index=-100, inplace_backward=False):
        """Compute the per-row loss and stash what backward needs on ``ctx``."""
        if lse_partial is not None:
            # A partial lse was already computed: to get the final lse, treat
            # lse_partial as the logits and read the target logit from x.
            loss, lse = cross_entropy_fwd(
                lse_partial, target, target_logit=x, ignore_index=ignore_index, return_lse=True
            )
        else:
            loss, lse = cross_entropy_fwd(x, target, ignore_index=ignore_index, return_lse=True)
        ctx.save_for_backward(x, target, lse)
        ctx.ignore_index = ignore_index
        ctx.inplace_backward = inplace_backward
        return loss

    @staticmethod
    def backward(ctx, dloss):
        """Return the gradient for ``x``; the other forward inputs get ``None``."""
        x, target, lse = ctx.saved_tensors
        grad_x = cross_entropy_bwd(
            x, target, dloss, lse, ctx.ignore_index, inplace_backward=ctx.inplace_backward
        )
        # One slot per forward input: x, target, lse_partial, ignore_index, inplace_backward.
        return grad_x, None, None, None, None
|
|
548
695
|
|
|
549
696
|
|
|
550
697
|
def cross_entropy(
|
|
551
698
|
x: torch.Tensor,
|
|
552
699
|
target: torch.Tensor,
|
|
553
|
-
|
|
554
|
-
|
|
700
|
+
lse_partial: Optional[torch.Tensor] = None,
|
|
701
|
+
ignore_index: int = -100,
|
|
702
|
+
reduction: Literal["none", "mean", "sum"] = "mean",
|
|
703
|
+
inplace_backward: bool = False,
|
|
555
704
|
) -> torch.Tensor:
|
|
556
705
|
"""Cross entropy loss with automatic differentiation support.
|
|
557
706
|
|
|
558
707
|
Args:
|
|
559
708
|
x: Input logits tensor of shape (M, N)
|
|
560
709
|
target: Target class indices tensor of shape (M,)
|
|
561
|
-
|
|
710
|
+
lse_partial: Optional precomputed log-sum-exp partial results
|
|
562
711
|
reduction: Specifies the reduction to apply to the output:
|
|
563
712
|
'none': no reduction will be applied (default)
|
|
564
713
|
'mean': the sum of the output will be divided by the number of elements
|
|
565
714
|
'sum': the output will be summed
|
|
715
|
+
inplace_backward: Whether to perform backward pass in-place
|
|
716
|
+
ignore_index: Index to ignore in loss computation (loss will be 0 for these indices)
|
|
566
717
|
|
|
567
718
|
Returns:
|
|
568
719
|
Cross entropy loss tensor:
|
|
@@ -570,10 +721,9 @@ def cross_entropy(
|
|
|
570
721
|
- If reduction='mean': scalar tensor with mean loss
|
|
571
722
|
- If reduction='sum': scalar tensor with sum of losses
|
|
572
723
|
"""
|
|
573
|
-
loss = CrossEntropyFunction.apply(x, target, inplace_backward)
|
|
574
|
-
|
|
724
|
+
loss = CrossEntropyFunction.apply(x, target, lse_partial, ignore_index, inplace_backward)
|
|
575
725
|
if reduction == "mean":
|
|
576
|
-
return loss.
|
|
726
|
+
return loss.sum() / (target != ignore_index).sum().float()
|
|
577
727
|
elif reduction == "sum":
|
|
578
728
|
return loss.sum()
|
|
579
729
|
elif reduction == "none":
|