quack-kernels 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -3
- quack/activation.py +279 -0
- quack/autotuner.py +2 -1
- quack/cross_entropy.py +330 -184
- quack/cute_dsl_utils.py +83 -4
- quack/dense_gemm_sm100.py +1 -1
- quack/dense_gemm_sm90.py +911 -1140
- quack/fast_math.py +10 -27
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +43 -35
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +491 -243
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +6 -4
- quack/linear.py +128 -64
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +30 -160
- quack/pipeline.py +2 -17
- quack/reduce.py +240 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +614 -228
- quack/softmax.py +28 -16
- quack/symmetric_dense_gemm_sm90.py +6 -3
- quack/tensormap_manager.py +1 -0
- quack/tile_scheduler.py +64 -61
- quack/topk.py +14 -8
- quack/utils.py +14 -322
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/METADATA +3 -3
- quack_kernels-0.2.1.dist-info/RECORD +37 -0
- quack/lse.py +0 -62
- quack_kernels-0.1.11.dist-info/RECORD +0 -31
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/top_level.txt +0 -0
quack/cross_entropy.py
CHANGED
@@ -1,17 +1,22 @@
 # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.

 import math
-from typing import Optional, Type
+from typing import Optional, Type, Literal
+
+import torch
+from torch import Tensor

 import cuda.bindings.driver as cuda

 import cutlass
 import cutlass.cute as cute
+from cutlass import Int32, Float32, Boolean, const_expr
+from cutlass.cute.runtime import from_dlpack

 import quack.utils as utils
-import
-from
-from quack.
+from quack.reduce import row_reduce, online_softmax_reduce
+from quack.reduction_base import ReductionBase
+from quack.cute_dsl_utils import torch2cute_dtype_map


 class CrossEntropy(ReductionBase):
@@ -21,7 +26,7 @@ class CrossEntropy(ReductionBase):
             dtype,
             N,
             stage=2 if not online_softmax else 1,
-            reduction_dtype=
+            reduction_dtype=Float32 if not online_softmax else cutlass.Int64,
         )
         self.online_softmax = online_softmax
         self.reload_from = None if N <= 16384 or online_softmax else "smem"
@@ -40,7 +45,7 @@ class CrossEntropy(ReductionBase):

     def _set_cluster_n(self):
         N = self.N
-        if
+        if const_expr(self.dtype.width == 16):
             cluster_n = (
                 1
                 if N <= 16 * 1024
@@ -65,21 +70,30 @@ class CrossEntropy(ReductionBase):
     @cute.jit
     def __call__(
         self,
-        mX: cute.Tensor,
-        mTarget: cute.Tensor,
-
-
+        mX: cute.Tensor,  # (M, N)
+        mTarget: cute.Tensor,  # (M,)
+        mTargetLogit: Optional[cute.Tensor],  # (M, K) or (M,). If None, we use mX
+        mLoss: cute.Tensor,  # (M,)
+        mLSE: Optional[cute.Tensor],  # (M,)
+        mdX: Optional[cute.Tensor],  # (M, N) - if provided, compute gradient
+        ignore_index: Int32,  # Index to ignore in loss computation
         stream: cuda.CUstream,
     ):
         assert mX.element_type == self.dtype
+        if const_expr(mTargetLogit is None):
+            mTargetLogit = mX
         self._set_cluster_n()
-
+        # e.g. if self.N isn't divisible by 8 for bf16, we might use 64 bits (4 elements) copy
+        num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
+        tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
         num_threads = cute.size(tv_layout, mode=[0])
         num_warps = num_threads // cute.arch.WARP_SIZE
-        self.kernel(
+        self.kernel(
+            mX, mTarget, mTargetLogit, mLoss, mLSE, mdX, ignore_index, tv_layout, tiler_mn
+        ).launch(
             grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
             block=[num_threads, 1, 1],
-            cluster=([1, self.cluster_n, 1] if
+            cluster=([1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None),
             smem=self._smem_size_in_bytes(tiler_mn, num_warps),
             stream=stream,
         )
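
Note on the copy-width selection above: `math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width` picks the widest vectorized copy whose element count still divides N. A quick sanity check of that arithmetic in plain Python (`copy_bits` is a hypothetical helper mirroring the expression, not part of the package):

import math

def copy_bits(N: int, dtype_width: int) -> int:
    # Widest copy (in bits) whose element count still divides N.
    return math.gcd(N, 128 // dtype_width) * dtype_width

# bf16 (16-bit elements): at most 128 // 16 = 8 elements per copy
assert copy_bits(8192, 16) == 128  # N % 8 == 0 -> full 128-bit copies
assert copy_bits(8196, 16) == 64   # gcd(8196, 8) = 4 -> 64-bit copies
assert copy_bits(8194, 16) == 32   # gcd(8194, 8) = 2 -> 32-bit copies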
@@ -89,17 +103,20 @@ class CrossEntropy(ReductionBase):
         self,
         mX: cute.Tensor,  # (M, N)
         mTarget: cute.Tensor,  # (M,)
+        mTargetLogit: cute.Tensor,  # (M, K) or (M,)
         mLoss: cute.Tensor,  # (M,)
         mLSE: Optional[cute.Tensor],  # (M,)
+        mdX: Optional[cute.Tensor],  # (M, N) - if provided, compute gradient
+        ignore_index: Int32,  # Index to ignore in loss computation
         tv_layout: cute.Layout,
         tiler_mn: cute.Shape,
     ):
         tidx, _, _ = cute.arch.thread_idx()
         bidx, _, _ = cute.arch.block_idx()
-        if
+        if const_expr(self.cluster_n > 1):
             cluster_y = cute.arch.block_idx()[1]
         else:
-            cluster_y =
+            cluster_y = const_expr(0)

         shape: cute.Shape = mX.shape
         idX = cute.make_identity_tensor(shape)
@@ -118,12 +135,14 @@ class CrossEntropy(ReductionBase):
         reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)

         # declare the atoms which will be used later for memory copy
+        num_copy_elems_X = tv_layout.shape[1][0]
+        num_copy_bits_X = mX.element_type.width * num_copy_elems_X
         copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=
+            cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=num_copy_bits_X
         )
         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)

-        ####
+        #### Partition to get thread view
         tXgX = thr_copy_X.partition_S(gX)
         tXsX = thr_copy_X.partition_D(sX)
         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
@@ -133,14 +152,14 @@ class CrossEntropy(ReductionBase):
         self._initialize_cluster(tidx, mbar_ptr, num_warps)

         row = tXcX[0][0]
-        target =
-        if row < shape[0]
-            target =
+        target = Int32.zero
+        if row < shape[0]:
+            target = Int32(mTarget[row])

-        is_even_N =
+        is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
         tXpX = (
             utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
-            if
+            if const_expr(not is_even_N)
             else None
         )
         if row < shape[0]:
@@ -148,99 +167,148 @@ class CrossEntropy(ReductionBase):
             cute.arch.cp_async_commit_group()
             cute.arch.cp_async_wait_group(0)
         # Fill OOB values with -inf
-        if
+        if const_expr(not is_even_N):
             utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
         cute.autovec_copy(tXsX, tXrX)
-        x = tXrX.load().to(
-
-        target_logit =
-
-
-
-
+        x = tXrX.load().to(Float32)
+
+        target_logit = Float32.zero
+        should_ignore = Boolean(target == ignore_index)
+        if row < shape[0] and tXcX[0][1] == 0 and not should_ignore:
+            # Only load target logit if not ignoring this index
+            if const_expr(cute.rank(mTargetLogit.shape) == 2):
+                # Use Int64 for indexing to deal with large tensors
+                mTargetLogit_off = utils.domain_offset_i64((row, 0), mTargetLogit)
+                target_logit = Float32(mTargetLogit_off[0, target])
+            else:
+                assert cute.rank(mTargetLogit.shape) == 1
+                target_logit = Float32(mTargetLogit[row])

         threads_per_row = tv_layout.shape[0][0]
-        if
-            max_x =
+        if const_expr(not self.online_softmax):
+            max_x = row_reduce(
                 x,
                 cute.ReductionOp.MAX,
                 threads_per_row,
                 reduction_buffer[None, None, 0],
-                mbar_ptr + 0 if
-                init_val=-
-                hook_fn=(
-                    cute.arch.cluster_wait if cutlass.const_expr(self.cluster_n > 1) else None
-                ),
+                mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None,
+                init_val=-Float32.inf,
+                hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
             )
-            if
+            if const_expr(self.reload_from == "smem"):
                 cute.autovec_copy(tXsX, tXrX)
-                x = tXrX.load().to(
+                x = tXrX.load().to(Float32)
             log2_e = math.log2(math.e)
-            # exp_x = cute.math.exp2((x - max_x) * log2_e, fastmath=True)
-            # a bit faster, probably because it's calling ex2.approx.ftz instead of ex2.approx?
-            # exp_x = utils.exp2f((x - max_x) * log2_e)
             # This would use ffma instead of fadd then fmul
-            exp_x =
-            denom =
+            exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=False)
+            denom = row_reduce(
                 exp_x,
                 cute.ReductionOp.ADD,
                 threads_per_row,
                 reduction_buffer[None, None, 1],
-                mbar_ptr + 1 if
+                mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None,
                 init_val=0.0,
             )
         else:
-            max_x, denom,
+            max_x, denom, exp_x = online_softmax_reduce(
                 x,
                 threads_per_row,
                 reduction_buffer[None, None, 0],
                 mbar_ptr,
-                hook_fn=(
-
-                ),
+                hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
+                return_exp_x=const_expr(mdX is not None),
             )

+        # Write loss and lse to gmem
         if (
             tXcX[0][1] == 0
             and row < shape[0]
             and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
         ):
-
-
-            loss_val = lse - target_logit
-            mLoss[row] =
-            if
+            lse = max_x + cute.math.log(denom, fastmath=True)
+            # Set loss to 0 if this index should be ignored, otherwise compute normally
+            loss_val = (lse - target_logit) if not should_ignore else Float32.zero
+            mLoss[row] = mLoss.element_type(loss_val)
+            if const_expr(mLSE is not None):
                 mLSE[row] = lse

-
-
-
-
-
-
+        # Compute gradient if mdX is provided
+        if const_expr(mdX is not None):
+            # Compute probabilities: exp(x) / sum(exp(x))
+            # If ignored, gradient should be zero
+            denom_inv = (
+                1.0 / denom
+                if not (denom == 0.0 or denom != denom or should_ignore)
+                else Float32.zero
+            )
+            probs = exp_x * denom_inv
+            mdX_off = utils.domain_offset_i64((bidx * tiler_mn[0], 0), mdX)
+            gdX = cute.local_tile(mdX_off, tiler_mn, (0, cluster_y))
+            # Setup copy atom for storing gradient
+            copy_atom_store = cute.make_copy_atom(
+                cute.nvgpu.CopyUniversalOp(), mdX.element_type, num_bits_per_copy=num_copy_bits_X
+            )
+            thr_copy_dX = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx)
+            tXgdX = thr_copy_dX.partition_D(gdX)
+            tXrdX = cute.make_fragment_like(tXgdX)
+            tXcFull = thr_copy_X.partition_S(cX)
+            # Compute gradient: probs for all classes, (probs - 1) for target class
+            # If ignored, gradient is already zero
+            tXrdX_f32 = cute.make_fragment_like(tXrX, Float32)
+            tXrdX_f32.store(probs)
+            if not should_ignore:
+                for i in cutlass.range(cute.size(tXrX), unroll_full=True):
+                    tXrdX_f32[i] = tXrdX_f32[i] if tXcFull[i][1] != target else tXrdX_f32[i] - 1.0
+            tXrdX.store(tXrdX_f32.load().to(tXrdX.element_type))
+            tXpdX = (
+                utils.predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1])
+                if not is_even_N
+                else None
+            )
+            if row < shape[0]:
+                cute.copy(copy_atom_store, tXrdX, tXgdX, pred=tXpdX)
+
+
+@torch.library.custom_op("quack::cross_entropy_fwd_out", mutates_args={"loss", "lse", "dx"})
+def cross_entropy_fwd_out(
+    x: Tensor,
+    target: Tensor,
+    target_logit: Optional[Tensor],
+    loss: Tensor,
+    lse: Optional[Tensor],
+    dx: Optional[Tensor],
+    ignore_index: int = -100,
+) -> None:
     """Cross entropy forward pass.

     Args:
         x: Input logits tensor of shape (M, N)
         target: Target class indices tensor of shape (M,)
+        target_logit: (M, K) or (M,).
+            If provided, the target logit will be read from this tensor instead of x.
+        loss: Output loss tensor of shape (M,)
+        lse: Optional output log-sum-exp tensor of shape (M,)
+        dx: Optional output gradient tensor of shape (M, N)
+        ignore_index: Index to ignore in loss computation

     Returns:
-
+        None (mutates loss, lse, and optionally dx in-place)
     """
     assert x.dim() == 2, "Input must be 2D"
     assert target.dim() == 1, "Target must be 1D"
     assert x.shape[0] == target.shape[0], "Batch dimensions must match"
     assert x.is_cuda and target.is_cuda, "Tensors must be on CUDA device"
-    assert x.dtype in [
-        torch.float16,
-        torch.bfloat16,
-        torch.float32,
-    ], "Unsupported input dtype"
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
     assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"
-
-
-
-
+    if target_logit is not None:
+        assert target_logit.shape[0] == x.shape[0]
+        assert target_logit.is_cuda, "Target logits must be on CUDA device"
+        assert target_logit.dtype in [torch.float16, torch.bfloat16, torch.float32]
+    if dx is not None:
+        assert dx.shape == x.shape, "dx must have same shape as x"
+        assert dx.is_cuda, "dx must be on CUDA device"
+        assert dx.dtype == x.dtype, "dx must have same dtype as x"
+    N = x.size(1)
     dtype = torch2cute_dtype_map[x.dtype]
     convert_from_dlpack = lambda tensor: (
         from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(
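
The gradient path added above computes softmax probabilities and subtracts 1 at the target column, leaving ignored rows at zero. The same per-row math in plain PyTorch, as a reference sketch (hypothetical helper assuming unit dloss, not part of the package):

import torch
import torch.nn.functional as F

def ce_fwd_reference(x, target, ignore_index=-100):
    rows = torch.arange(x.size(0), device=x.device)
    ignored = target == ignore_index
    lse = torch.logsumexp(x.float(), dim=-1)                     # (M,)
    tgt = x.float()[rows, target.clamp(min=0)]                   # (M,)
    loss = torch.where(ignored, torch.zeros_like(lse), lse - tgt)
    dx = F.softmax(x.float(), dim=-1)                            # probs
    dx[rows, target.clamp(min=0)] -= 1.0                         # probs - one-hot
    dx[ignored] = 0.0                                            # zero grad for ignored rows
    return loss, lse, dx.to(x.dtype)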
@@ -248,28 +316,86 @@ def _cross_entropy(
         )
     )
     x_tensor = convert_from_dlpack(x)
-    loss_tensor = from_dlpack(loss.detach(), assumed_align=4).
+    loss_tensor = from_dlpack(loss.detach(), assumed_align=4).mark_layout_dynamic()
     lse_tensor = (
-        from_dlpack(lse.detach(), assumed_align=4).
+        from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic()
         if lse is not None
         else None
     )
-    target_tensor = from_dlpack(target.detach(), assumed_align=8).
+    target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_layout_dynamic()
+    target_logit_tensor = (
+        from_dlpack(target_logit.detach(), assumed_align=4).mark_layout_dynamic(
+            leading_dim=target_logit.ndim - 1
+        )
+        if target_logit is not None
+        else None
+    )
+    dx_tensor = convert_from_dlpack(dx) if dx is not None else None
     stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

-    compile_key = (
-
-
-
-
+    compile_key = (
+        dtype,
+        N,
+        target_logit.dtype if target_logit is not None else None,
+        lse.dtype if lse is not None else None,
+        dx is not None,
+        loss.stride(),
+        lse.stride() if lse is not None else None,
+        target.stride(),
+        target_logit.stride(-1) if target_logit is not None else None,
+    )
+    if compile_key not in cross_entropy_fwd_out.compile_cache:
+        # If there's dx, it's faster to not use online softmax since we want the exp(x - max)
+        cross_entropy_op = CrossEntropy(dtype, N, online_softmax=dx is None)
+        cross_entropy_fwd_out.compile_cache[compile_key] = cute.compile(
+            cross_entropy_op,
+            x_tensor,
+            target_tensor,
+            target_logit_tensor,
+            loss_tensor,
+            lse_tensor,
+            dx_tensor,
+            Int32(ignore_index),
+            stream,
         )
-
-        x_tensor,
+    cross_entropy_fwd_out.compile_cache[compile_key](
+        x_tensor,
+        target_tensor,
+        target_logit_tensor,
+        loss_tensor,
+        lse_tensor,
+        dx_tensor,
+        Int32(ignore_index),
+        stream,
     )
-    return loss if not return_lse else (loss, lse)


-
+cross_entropy_fwd_out.compile_cache = {}
+
+
+def cross_entropy_fwd(
+    x: torch.Tensor,
+    target: torch.Tensor,
+    target_logit: Optional[torch.Tensor] = None,
+    ignore_index: int = -100,
+    return_lse: bool = False,
+    return_dx: bool = False,
+    inplace_backward: bool = False,
+) -> torch.Tensor | tuple[torch.Tensor]:
+    M = x.size(0)
+    device = x.device
+    loss = torch.empty(M, device=device, dtype=torch.float32)
+    lse = torch.empty(M, device=device, dtype=torch.float32) if return_lse else None
+    dx = (torch.empty_like(x) if not inplace_backward else x) if return_dx else None
+    cross_entropy_fwd_out(x, target, target_logit, loss, lse, dx, ignore_index)
+    if return_lse and return_dx:
+        return loss, lse, dx
+    elif return_lse:
+        return loss, lse
+    elif return_dx:
+        return loss, dx
+    else:
+        return loss


 class CrossEntropyBackward:
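
The new `cross_entropy_fwd` wrapper allocates the outputs and dispatches through the `quack::cross_entropy_fwd_out` custom op, so loss, log-sum-exp, and gradient can come out of a single fused launch. A usage sketch (shapes illustrative):

import torch
from quack.cross_entropy import cross_entropy_fwd

x = torch.randn(4096, 32768, device="cuda", dtype=torch.bfloat16)
target = torch.randint(0, 32768, (4096,), device="cuda")

# fp32 loss and lse of shape (M,); dx holds softmax(x) minus the
# one-hot target, i.e. the input gradient for unit dloss.
loss, lse, dx = cross_entropy_fwd(x, target, return_lse=True, return_dx=True)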
@@ -279,7 +405,7 @@ class CrossEntropyBackward:
         self.vecsize = 128 // dtype.width

     def _calculate_threads_per_row(self):
-        N = self.N
+        N = min(self.N, 16384)  # We split by blocks of 16k
         return (
             8
             if N <= 64
@@ -290,13 +416,14 @@ class CrossEntropyBackward:
             )
         )

-    def _get_tv_layout(self):
-
-        vecsize
+    def _get_tv_layout(self, num_copy_bits=128):
+        vecsize = num_copy_bits // self.dtype.width
+        assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
+        N = min(self.N, 16384)
         num_threads = 128 if N <= 16384 else 256
         threads_per_row = self._calculate_threads_per_row()
         cols_per_block = num_threads // threads_per_row
-        num_blocks_N = cute.ceil_div(
+        num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row)
         tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
         tv_layout = cute.make_layout(
             ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
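
For intuition, `_get_tv_layout` now tiles each row as vecsize-element copies spread over threads_per_row threads, repeated num_blocks_N times. A worked example in plain Python (the threads_per_row value is hypothetical, since the selection ladder is elided above):

from math import ceil

vecsize = 128 // 16                   # bf16 with 128-bit copies -> 8 elements
N = min(8192, 16384)                  # rows longer than 16k are split into blocks
threads_per_row = 32                  # hypothetical pick for this N
num_blocks_N = ceil(N // vecsize / threads_per_row)  # = 32
tiler_n = vecsize * num_blocks_N * threads_per_row   # = 8192
assert tiler_n == N                   # the tile covers the row exactly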
@@ -315,40 +442,27 @@ class CrossEntropyBackward:
         mDLoss: cute.Tensor,
         mdX: cute.Tensor,
         mLSE: cute.Tensor,
+        ignore_index: Int32,  # Index to ignore in gradient computation
         stream: cuda.CUstream,
     ):
         assert mX.element_type == self.dtype
         assert mdX.element_type == self.dtype
-
-
+        # e.g. if self.N isn't divisible by 8 for bf16, we might use 64 bits (4 elements) copy
+        num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
+        tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
         num_threads = cute.size(tv_layout, mode=[0])
-
-        mDLoss =
-
-
-
-
-
-            cute.append(mTarget.layout, cute.make_layout((self.N,), stride=(0,))),
-        )
-        mLSE = cute.make_tensor(
-            mLSE.iterator,
-            cute.append(mLSE.layout, cute.make_layout((self.N,), stride=(0,))),
-        )
-
+        # (M,) -> (M, N) with stride 0 in the N dimension
+        mDLoss, mTarget, mLSE = [
+            cute.make_tensor(
+                X.iterator, cute.append(X.layout, cute.make_layout((self.N,), stride=(0,)))
+            )
+            for X in (mDLoss, mTarget, mLSE)
+        ]
         smem_size = cute.size_in_bytes(
             mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0))
         )
-
         self.kernel(
-            mX,
-            mTarget,
-            mDLoss,
-            mdX,
-            mLSE,
-            mX.shape,
-            tv_layout,
-            tiler_mn,
+            mX, mTarget, mDLoss, mdX, mLSE, ignore_index, mX.shape, tv_layout, tiler_mn
         ).launch(
             grid=[
                 cute.ceil_div(mX.shape[0], tiler_mn[0]),
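
The list comprehension above applies one trick to all three 1-D operands: appending a stride-0 mode so an (M,) tensor can be addressed as (M, N) without copying. The same idea expressed in PyTorch, for intuition:

import torch

m = torch.arange(4.0)               # (M,) = (4,)
m2d = m.as_strided((4, 8), (1, 0))  # viewed as (M, N), stride 0 along N
assert m2d.stride() == (1, 0)
assert torch.equal(m2d[:, 5], m)    # every column aliases the same (M,) data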
@@ -368,6 +482,7 @@ class CrossEntropyBackward:
         mDLoss: cute.Tensor,  # (M,)
         mdX: cute.Tensor,  # (M, N)
         mLSE: cute.Tensor,  # (M,)
+        ignore_index: Int32,  # Index to ignore in gradient computation
         shape: cute.Shape,
         tv_layout: cute.Layout,
         tiler_mn: cute.Shape,
@@ -388,76 +503,67 @@ class CrossEntropyBackward:
         gX, gdX = [cute.local_tile(mT, tiler_mn, (0, bidy)) for mT in (mX, mdX)]
         cX = cute.local_tile(idX, tiler_mn, (bidx, bidy))

+        num_copy_elems_X = tv_layout.shape[1][0]
+        num_copy_bits_X = mX.element_type.width * num_copy_elems_X
         copy_atom_load_X = cute.make_copy_atom(
-            cute.nvgpu.
-        )
-        copy_atom_load_X_async = cute.make_copy_atom(
-            cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=128
+            cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=num_copy_bits_X
         )
-
-            cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=
+        copy_atom_store_dX = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(), gdX.element_type, num_bits_per_copy=num_copy_bits_X
         )
-
         thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
-
-            copy_atom_load_X_async, tv_layout, tiler_mn
-        ).get_slice(tidx)
-        thr_copy_O = cute.make_tiled_copy(copy_atom_store_O, tv_layout, tiler_mn).get_slice(tidx)
-
-        #### Thread View
-        tXgX = thr_copy_X_async.partition_S(gX)
-        tXsX = thr_copy_X_async.partition_S(sX)
+        thr_copy_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)

+        #### Partition to get thread view
+        tXgX = thr_copy_X.partition_S(gX)
+        tXsX = thr_copy_X.partition_S(sX)
         tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
-        tXcFull = thr_copy_X.partition_S(cX)
-
-        tXgO = thr_copy_O.partition_D(gdX)
-
+        tXcFull = thr_copy_X.partition_S(cX)
+        tXgdX = thr_copy_dX.partition_D(gdX)
         # allocate fragments for gmem->rmem
-        tXrX,
+        tXrX, tXrdX = [cute.make_fragment_like(thr) for thr in (tXgX, tXgdX)]

-        is_even_N =
+        is_even_N = const_expr(shape[1] % tiler_mn[1] == 0)
         row = tXcX[0][0]
-
         tXpX = (
-            utils.predicate_k(
-            if not is_even_N
-            else None
+            utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None
         )
-
         if row < shape[0]:
-            cute.copy(
+            cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
             cute.arch.cp_async_commit_group()
             cute.arch.cp_async_wait_group(0)
-        if
+        if const_expr(not is_even_N):
             utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
-
         cute.autovec_copy(tXsX, tXrX)
-        x = tXrX.load().to(
+        x = tXrX.load().to(Float32)

-
-        dloss =
-        lse =
+        target = Int32.zero
+        dloss = Float32.zero
+        lse = Float32.zero
         if row < shape[0]:
-
-
-
+            target = Int32(mTarget[row])
+            should_ignore = Boolean(target == ignore_index)
+            # Set dloss to 0 if this index should be ignored
+            dloss = Float32(mDLoss[row]) if not should_ignore else Float32.zero
+            lse = Float32(mLSE[row])

         log2_e = math.log2(math.e)
-        probs =
+        probs = cute.math.exp2(x * log2_e - (lse * log2_e), fastmath=True)
         prob_shifted = probs - 1.0
         mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
         for i in cutlass.range(cute.size(tXcFull), unroll_full=True):
-            mask[i] = tXcFull[i][1] ==
+            mask[i] = tXcFull[i][1] == target
         grad = cute.where(mask.load(), prob_shifted, probs)
         grad = grad * dloss

-
-
-            utils.predicate_k(
+        tXrdX.store(grad.to(tXrdX.element_type))
+        tXpdX = (
+            utils.predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1])
+            if not is_even_N
+            else None
         )
         if row < shape[0]:
-            cute.copy(
+            cute.copy(copy_atom_store_dX, tXrdX, tXgdX, pred=tXpdX)


 def _cross_entropy_backward(
@@ -465,8 +571,9 @@ def _cross_entropy_backward(
     target: torch.Tensor,
     dloss: torch.Tensor,
     lse: torch.Tensor,
-
-
+    dx: torch.Tensor,
+    ignore_index=-100,
+) -> None:
     """Cross entropy backward pass.
     Args:
         x: Input logits tensor of shape (M, N)
@@ -483,18 +590,13 @@ def _cross_entropy_backward(
     assert x.shape[0] == target.shape[0], "Batch dimensions must match"
     assert x.shape[0] == dloss.shape[0], "Batch dimensions must match"
     assert x.shape[0] == lse.shape[0], "Batch dimensions must match"
-    assert (
-
-    )
-    assert x.dtype in [
-        torch.float16,
-        torch.bfloat16,
-        torch.float32,
-    ], "Unsupported input dtype"
+    assert x.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda, (
+        "Tensors must be on CUDA device"
+    )
+    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype"
     assert target.dtype in [torch.int32, torch.int64], "Target must be int32 or int64"

-
-    dx = torch.empty_like(x) if not inplace_backward else x
+    N = x.size(1)
     dtype = torch2cute_dtype_map[x.dtype]

     convert_from_dlpack = lambda tensor: (
@@ -504,14 +606,12 @@ def _cross_entropy_backward(
     )
     x_tensor = convert_from_dlpack(x)
    dx_tensor = convert_from_dlpack(dx)
-    dloss_tensor = from_dlpack(dloss.detach(), assumed_align=
-    lse_tensor = from_dlpack(lse.detach(), assumed_align=
-    target_tensor = from_dlpack(target.detach(), assumed_align=
-        mode=0
-    )
+    dloss_tensor = from_dlpack(dloss.detach(), assumed_align=4).mark_layout_dynamic()
+    lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic()
+    target_tensor = from_dlpack(target.detach(), assumed_align=8).mark_layout_dynamic()
     stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

-    compile_key = (dtype, N)
+    compile_key = (dtype, N, target.dtype, dloss.stride(), lse.stride(), target.stride())
     if compile_key not in _cross_entropy_backward.compile_cache:
         cross_entropy_backward_op = CrossEntropyBackward(dtype, N)
         _cross_entropy_backward.compile_cache[compile_key] = cute.compile(
@@ -521,48 +621,95 @@ def _cross_entropy_backward(
             dloss_tensor,
             dx_tensor,
             lse_tensor,
+            Int32(ignore_index),
             stream,
         )
     _cross_entropy_backward.compile_cache[compile_key](
-        x_tensor, target_tensor, dloss_tensor, dx_tensor, lse_tensor, stream
+        x_tensor, target_tensor, dloss_tensor, dx_tensor, lse_tensor, Int32(ignore_index), stream
     )
-    return dx


 _cross_entropy_backward.compile_cache = {}


+@torch.library.custom_op("quack::cross_entropy_bwd_out", mutates_args={"dx"})
+def cross_entropy_bwd_out(
+    x: torch.Tensor,
+    target: torch.Tensor,
+    dloss: torch.Tensor,
+    lse: torch.Tensor,
+    dx: torch.Tensor,
+    ignore_index: int = -100,
+) -> None:
+    _cross_entropy_backward(x, target, dloss, lse, dx, ignore_index)
+
+
+def cross_entropy_bwd(
+    x: torch.Tensor,
+    target: torch.Tensor,
+    dloss: torch.Tensor,
+    lse: torch.Tensor,
+    ignore_index: int = -100,
+    inplace_backward: bool = False,
+) -> None:
+    if inplace_backward and not torch.compiler.is_compiling():
+        dx = x
+        _cross_entropy_backward(
+            x=x, target=target, dloss=dloss, lse=lse, dx=x, ignore_index=ignore_index
+        )
+    else:
+        dx = torch.empty_like(x)
+        cross_entropy_bwd_out(
+            x=x, target=target, dloss=dloss, lse=lse, dx=dx, ignore_index=ignore_index
+        )
+    return dx
+
+
 class CrossEntropyFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x, target, inplace_backward=False):
-
+    def forward(ctx, x, target, lse_partial=None, ignore_index=-100, inplace_backward=False):
+        if lse_partial is None:
+            loss, lse = cross_entropy_fwd(x, target, ignore_index=ignore_index, return_lse=True)
+        else:
+            # if we already compute partial lse, then to compute the final lse we treat
+            # @lse_partial as @x and @x as @target_logit
+            loss, lse = cross_entropy_fwd(
+                lse_partial, target, target_logit=x, ignore_index=ignore_index, return_lse=True
+            )
         ctx.save_for_backward(x, target, lse)
+        ctx.ignore_index = ignore_index
         ctx.inplace_backward = inplace_backward
         return loss

     @staticmethod
     def backward(ctx, dloss):
         x, target, lse = ctx.saved_tensors
-        dx =
-
+        dx = cross_entropy_bwd(
+            x, target, dloss, lse, ctx.ignore_index, inplace_backward=ctx.inplace_backward
+        )
+        return dx, None, None, None, None


 def cross_entropy(
     x: torch.Tensor,
     target: torch.Tensor,
-
-
+    lse_partial: Optional[torch.Tensor] = None,
+    ignore_index: int = -100,
+    reduction: Literal["none", "mean", "sum"] = "mean",
+    inplace_backward: bool = False,
 ) -> torch.Tensor:
     """Cross entropy loss with automatic differentiation support.

     Args:
         x: Input logits tensor of shape (M, N)
         target: Target class indices tensor of shape (M,)
-
+        lse_partial: Optional precomputed log-sum-exp partial results
         reduction: Specifies the reduction to apply to the output:
             'none': no reduction will be applied (default)
             'mean': the sum of the output will be divided by the number of elements
             'sum': the output will be summed
+        inplace_backward: Whether to perform backward pass in-place
+        ignore_index: Index to ignore in loss computation (loss will be 0 for these indices)

     Returns:
         Cross entropy loss tensor:
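
The `lse_partial` path in `CrossEntropyFunction.forward` relies on log-sum-exp composing over column chunks: the logsumexp of per-chunk logsumexps equals the full-row logsumexp, so the kernel can treat the (M, K) partials as logits while reading the true target logit from x. A quick numerical check (plain PyTorch, shapes illustrative):

import torch

x = torch.randn(4, 1024)
chunks = x.chunk(8, dim=-1)
lse_partial = torch.stack([c.logsumexp(dim=-1) for c in chunks], dim=-1)  # (M, K)
assert torch.allclose(lse_partial.logsumexp(dim=-1), x.logsumexp(dim=-1), atol=1e-5)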
@@ -570,10 +717,9 @@ def cross_entropy(
         - If reduction='mean': scalar tensor with mean loss
         - If reduction='sum': scalar tensor with sum of losses
     """
-    loss = CrossEntropyFunction.apply(x, target, inplace_backward)
-
+    loss = CrossEntropyFunction.apply(x, target, lse_partial, ignore_index, inplace_backward)
     if reduction == "mean":
-        return loss.
+        return loss.sum() / (target != ignore_index).sum().float()
     elif reduction == "sum":
         return loss.sum()
     elif reduction == "none":