x-transformers 1.27.21__py3-none-any.whl → 1.27.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/xval.py +16 -10
- {x_transformers-1.27.21.dist-info → x_transformers-1.27.22.dist-info}/METADATA +1 -1
- {x_transformers-1.27.21.dist-info → x_transformers-1.27.22.dist-info}/RECORD +6 -6
- {x_transformers-1.27.21.dist-info → x_transformers-1.27.22.dist-info}/LICENSE +0 -0
- {x_transformers-1.27.21.dist-info → x_transformers-1.27.22.dist-info}/WHEEL +0 -0
- {x_transformers-1.27.21.dist-info → x_transformers-1.27.22.dist-info}/top_level.txt +0 -0
x_transformers/xval.py
CHANGED
@@ -260,10 +260,19 @@ class XValAutoregressiveWrapper(nn.Module):
|
|
260
260
|
inp, target = x[:, :-1], x[:, 1:]
|
261
261
|
x_num_inp, x_num_target = x_num[:, :-1], x_num[:, 1:]
|
262
262
|
|
263
|
+
# ignore index
|
264
|
+
|
265
|
+
target_mask = target != self.ignore_index
|
266
|
+
|
267
|
+
# key padding mask
|
268
|
+
|
263
269
|
mask = kwargs.get('mask', None)
|
264
|
-
if exists(mask)
|
265
|
-
|
266
|
-
|
270
|
+
if exists(mask):
|
271
|
+
target_mask &= mask
|
272
|
+
|
273
|
+
if mask.shape[1] == x.shape[1]:
|
274
|
+
mask = mask[:, :-1]
|
275
|
+
kwargs['mask'] = mask
|
267
276
|
|
268
277
|
logits, numerical_pred = self.net(inp, x_num_inp, **kwargs)
|
269
278
|
|
@@ -276,21 +285,18 @@ class XValAutoregressiveWrapper(nn.Module):
|
|
276
285
|
target_is_number_mask = target == self.net.numerical_token_id
|
277
286
|
x_num_target = x_num_target.masked_fill(~target_is_number_mask, 0.)
|
278
287
|
|
279
|
-
# ignore index
|
280
|
-
|
281
|
-
target_mask = target != self.ignore_index
|
282
|
-
|
283
288
|
# numerical mse loss
|
284
289
|
|
285
290
|
numerical_mse_loss = F.mse_loss(numerical_pred, x_num_target, reduction = 'none')
|
286
291
|
|
287
292
|
numerical_mse_loss = numerical_mse_loss * target_mask
|
293
|
+
numerical_mse_loss = numerical_mse_loss.masked_fill(~target_is_number_mask, 0.)
|
288
294
|
|
289
|
-
|
295
|
+
# combine losses
|
290
296
|
|
291
|
-
|
292
|
-
loss = loss[mask]
|
297
|
+
loss = cross_entropy_loss + numerical_mse_loss * self.numerical_loss_weight
|
293
298
|
|
299
|
+
loss = loss[target_mask]
|
294
300
|
loss = loss.mean()
|
295
301
|
|
296
302
|
if not return_loss_breakdown:
|
@@ -6,9 +6,9 @@ x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
|
6
6
|
x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
|
7
7
|
x_transformers/x_transformers.py,sha256=kQhRUMGDsinzkdYcOfE1GriJ057j7D4xSjbH79FFRSE,63574
|
8
8
|
x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
|
9
|
-
x_transformers/xval.py,sha256=
|
10
|
-
x_transformers-1.27.
|
11
|
-
x_transformers-1.27.
|
12
|
-
x_transformers-1.27.
|
13
|
-
x_transformers-1.27.
|
14
|
-
x_transformers-1.27.
|
9
|
+
x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
|
10
|
+
x_transformers-1.27.22.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
11
|
+
x_transformers-1.27.22.dist-info/METADATA,sha256=RTbXIIpFRnve8FVp8vLQ4LE-9x59IV6ADnu34gGAZXA,662
|
12
|
+
x_transformers-1.27.22.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
13
|
+
x_transformers-1.27.22.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
14
|
+
x_transformers-1.27.22.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|