x-transformers 1.27.21__py3-none-any.whl → 1.27.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/xval.py CHANGED
@@ -260,10 +260,19 @@ class XValAutoregressiveWrapper(nn.Module):
260
260
  inp, target = x[:, :-1], x[:, 1:]
261
261
  x_num_inp, x_num_target = x_num[:, :-1], x_num[:, 1:]
262
262
 
263
+ # ignore index
264
+
265
+ target_mask = target != self.ignore_index
266
+
267
+ # key padding mask
268
+
263
269
  mask = kwargs.get('mask', None)
264
- if exists(mask) and mask.shape[1] == x.shape[1]:
265
- mask = mask[:, :-1]
266
- kwargs['mask'] = mask
270
+ if exists(mask):
271
+ target_mask &= mask
272
+
273
+ if mask.shape[1] == x.shape[1]:
274
+ mask = mask[:, :-1]
275
+ kwargs['mask'] = mask
267
276
 
268
277
  logits, numerical_pred = self.net(inp, x_num_inp, **kwargs)
269
278
 
@@ -276,21 +285,18 @@ class XValAutoregressiveWrapper(nn.Module):
276
285
  target_is_number_mask = target == self.net.numerical_token_id
277
286
  x_num_target = x_num_target.masked_fill(~target_is_number_mask, 0.)
278
287
 
279
- # ignore index
280
-
281
- target_mask = target != self.ignore_index
282
-
283
288
  # numerical mse loss
284
289
 
285
290
  numerical_mse_loss = F.mse_loss(numerical_pred, x_num_target, reduction = 'none')
286
291
 
287
292
  numerical_mse_loss = numerical_mse_loss * target_mask
293
+ numerical_mse_loss = numerical_mse_loss.masked_fill(~target_is_number_mask, 0.)
288
294
 
289
- loss = cross_entropy_loss + numerical_mse_loss * self.numerical_loss_weight
295
+ # combine losses
290
296
 
291
- if exists(mask):
292
- loss = loss[mask]
297
+ loss = cross_entropy_loss + numerical_mse_loss * self.numerical_loss_weight
293
298
 
299
+ loss = loss[target_mask]
294
300
  loss = loss.mean()
295
301
 
296
302
  if not return_loss_breakdown:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.27.21
3
+ Version: 1.27.22
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -6,9 +6,9 @@ x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
6
6
  x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
7
7
  x_transformers/x_transformers.py,sha256=kQhRUMGDsinzkdYcOfE1GriJ057j7D4xSjbH79FFRSE,63574
8
8
  x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
9
- x_transformers/xval.py,sha256=to8AkYfgvR5QC-ZXDnq8PPfzOyg0yutHt_QW36jsy98,8403
10
- x_transformers-1.27.21.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
11
- x_transformers-1.27.21.dist-info/METADATA,sha256=g0OEysaFW4RomYhoqjDKu7W1brHg70m_N40_nmTLRtA,662
12
- x_transformers-1.27.21.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
13
- x_transformers-1.27.21.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
14
- x_transformers-1.27.21.dist-info/RECORD,,
9
+ x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
10
+ x_transformers-1.27.22.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
11
+ x_transformers-1.27.22.dist-info/METADATA,sha256=RTbXIIpFRnve8FVp8vLQ4LE-9x59IV6ADnu34gGAZXA,662
12
+ x_transformers-1.27.22.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
13
+ x_transformers-1.27.22.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
14
+ x_transformers-1.27.22.dist-info/RECORD,,