x-transformers 1.27.20__py3-none-any.whl → 1.27.22__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- x_transformers/xval.py +22 -7
- {x_transformers-1.27.20.dist-info → x_transformers-1.27.22.dist-info}/METADATA +1 -1
- {x_transformers-1.27.20.dist-info → x_transformers-1.27.22.dist-info}/RECORD +6 -6
- {x_transformers-1.27.20.dist-info → x_transformers-1.27.22.dist-info}/LICENSE +0 -0
- {x_transformers-1.27.20.dist-info → x_transformers-1.27.22.dist-info}/WHEEL +0 -0
- {x_transformers-1.27.20.dist-info → x_transformers-1.27.22.dist-info}/top_level.txt +0 -0
x_transformers/xval.py
CHANGED
@@ -260,10 +260,19 @@ class XValAutoregressiveWrapper(nn.Module):
|
|
260
260
|
inp, target = x[:, :-1], x[:, 1:]
|
261
261
|
x_num_inp, x_num_target = x_num[:, :-1], x_num[:, 1:]
|
262
262
|
|
263
|
+
# ignore index
|
264
|
+
|
265
|
+
target_mask = target != self.ignore_index
|
266
|
+
|
267
|
+
# key padding mask
|
268
|
+
|
263
269
|
mask = kwargs.get('mask', None)
|
264
|
-
if exists(mask)
|
265
|
-
|
266
|
-
|
270
|
+
if exists(mask):
|
271
|
+
target_mask &= mask
|
272
|
+
|
273
|
+
if mask.shape[1] == x.shape[1]:
|
274
|
+
mask = mask[:, :-1]
|
275
|
+
kwargs['mask'] = mask
|
267
276
|
|
268
277
|
logits, numerical_pred = self.net(inp, x_num_inp, **kwargs)
|
269
278
|
|
@@ -271,17 +280,23 @@ class XValAutoregressiveWrapper(nn.Module):
|
|
271
280
|
|
272
281
|
cross_entropy_loss = F.cross_entropy(logits, target, reduction = 'none', ignore_index = self.ignore_index)
|
273
282
|
|
274
|
-
|
283
|
+
# protect against nan in `x_num` input tensor
|
284
|
+
|
285
|
+
target_is_number_mask = target == self.net.numerical_token_id
|
286
|
+
x_num_target = x_num_target.masked_fill(~target_is_number_mask, 0.)
|
287
|
+
|
288
|
+
# numerical mse loss
|
275
289
|
|
276
290
|
numerical_mse_loss = F.mse_loss(numerical_pred, x_num_target, reduction = 'none')
|
277
291
|
|
278
292
|
numerical_mse_loss = numerical_mse_loss * target_mask
|
293
|
+
numerical_mse_loss = numerical_mse_loss.masked_fill(~target_is_number_mask, 0.)
|
279
294
|
|
280
|
-
|
295
|
+
# combine losses
|
281
296
|
|
282
|
-
|
283
|
-
loss = loss[mask]
|
297
|
+
loss = cross_entropy_loss + numerical_mse_loss * self.numerical_loss_weight
|
284
298
|
|
299
|
+
loss = loss[target_mask]
|
285
300
|
loss = loss.mean()
|
286
301
|
|
287
302
|
if not return_loss_breakdown:
|
@@ -6,9 +6,9 @@ x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
|
6
6
|
x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
|
7
7
|
x_transformers/x_transformers.py,sha256=kQhRUMGDsinzkdYcOfE1GriJ057j7D4xSjbH79FFRSE,63574
|
8
8
|
x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
|
9
|
-
x_transformers/xval.py,sha256=
|
10
|
-
x_transformers-1.27.
|
11
|
-
x_transformers-1.27.
|
12
|
-
x_transformers-1.27.
|
13
|
-
x_transformers-1.27.
|
14
|
-
x_transformers-1.27.
|
9
|
+
x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
|
10
|
+
x_transformers-1.27.22.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
11
|
+
x_transformers-1.27.22.dist-info/METADATA,sha256=RTbXIIpFRnve8FVp8vLQ4LE-9x59IV6ADnu34gGAZXA,662
|
12
|
+
x_transformers-1.27.22.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
13
|
+
x_transformers-1.27.22.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
14
|
+
x_transformers-1.27.22.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|