x-transformers 1.27.20.tar.gz → 1.27.22.tar.gz

This diff compares the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (19)
  1. {x_transformers-1.27.20/x_transformers.egg-info → x_transformers-1.27.22}/PKG-INFO +1 -1
  2. {x_transformers-1.27.20 → x_transformers-1.27.22}/README.md +1 -2
  3. {x_transformers-1.27.20 → x_transformers-1.27.22}/setup.py +1 -1
  4. {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/xval.py +22 -7
  5. {x_transformers-1.27.20 → x_transformers-1.27.22/x_transformers.egg-info}/PKG-INFO +1 -1
  6. {x_transformers-1.27.20 → x_transformers-1.27.22}/LICENSE +0 -0
  7. {x_transformers-1.27.20 → x_transformers-1.27.22}/setup.cfg +0 -0
  8. {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/__init__.py +0 -0
  9. {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/attend.py +0 -0
  10. {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/autoregressive_wrapper.py +0 -0
  11. {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/continuous.py +0 -0
  12. {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/dpo.py +0 -0
  13. {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/nonautoregressive_wrapper.py +0 -0
  14. {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/x_transformers.py +0 -0
  15. {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  16. {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers.egg-info/SOURCES.txt +0 -0
  17. {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers.egg-info/dependency_links.txt +0 -0
  18. {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers.egg-info/requires.txt +0 -0
  19. {x_transformers-1.27.20 → x_transformers-1.27.22}/x_transformers.egg-info/top_level.txt +0 -0
PKG-INFO:

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.27.20
+Version: 1.27.22
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
README.md:

@@ -1437,11 +1437,10 @@ model = XValAutoregressiveWrapper(model)
 
 ids = torch.randint(0, 4, (1, 777))
 nums = torch.randn(1, 777)
-mask = torch.ones(1, 777).bool()
 
 # train on a lot of data above
 
-loss = model(ids, nums, mask = mask)
+loss = model(ids, nums)
 loss.backward()
 
 # then generate
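This hunk is from the xval section of the README: as of this release the basic training example no longer passes an explicit key padding mask. A minimal runnable sketch of the updated flow is below; the wrapper construction and its hyperparameters are illustrative assumptions drawn from the surrounding README, not part of this diff:

```python
import torch
from x_transformers import Decoder
from x_transformers.xval import XValTransformerWrapper, XValAutoregressiveWrapper

# assumed wrapper setup -- hyperparameters are illustrative, not from this diff
model = XValTransformerWrapper(
    num_tokens = 4,
    numerical_token_id = 3,   # token id that stands in for a continuous number
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
)

model = XValAutoregressiveWrapper(model)

ids  = torch.randint(0, 4, (1, 777))   # token ids, some equal to the numerical token
nums = torch.randn(1, 777)             # the continuous value carried at each position

# train on a lot of data above -- as of 1.27.22 no mask argument is needed here
loss = model(ids, nums)
loss.backward()
```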
setup.py:

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.27.20',
+  version = '1.27.22',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
x_transformers/xval.py:

@@ -260,10 +260,19 @@ class XValAutoregressiveWrapper(nn.Module):
         inp, target = x[:, :-1], x[:, 1:]
         x_num_inp, x_num_target = x_num[:, :-1], x_num[:, 1:]
 
+        # ignore index
+
+        target_mask = target != self.ignore_index
+
+        # key padding mask
+
         mask = kwargs.get('mask', None)
-        if exists(mask) and mask.shape[1] == x.shape[1]:
-            mask = mask[:, :-1]
-            kwargs['mask'] = mask
+        if exists(mask):
+            target_mask &= mask
+
+            if mask.shape[1] == x.shape[1]:
+                mask = mask[:, :-1]
+                kwargs['mask'] = mask
 
         logits, numerical_pred = self.net(inp, x_num_inp, **kwargs)
 
x_transformers/xval.py (continued):

@@ -271,17 +280,23 @@
 
         cross_entropy_loss = F.cross_entropy(logits, target, reduction = 'none', ignore_index = self.ignore_index)
 
-        target_mask = target != self.ignore_index
+        # protect against nan in `x_num` input tensor
+
+        target_is_number_mask = target == self.net.numerical_token_id
+        x_num_target = x_num_target.masked_fill(~target_is_number_mask, 0.)
+
+        # numerical mse loss
 
         numerical_mse_loss = F.mse_loss(numerical_pred, x_num_target, reduction = 'none')
 
         numerical_mse_loss = numerical_mse_loss * target_mask
+        numerical_mse_loss = numerical_mse_loss.masked_fill(~target_is_number_mask, 0.)
 
-        loss = cross_entropy_loss + numerical_mse_loss * self.numerical_loss_weight
+        # combine losses
 
-        if exists(mask):
-            loss = loss[mask]
+        loss = cross_entropy_loss + numerical_mse_loss * self.numerical_loss_weight
 
+        loss = loss[target_mask]
         loss = loss.mean()
 
         if not return_loss_breakdown:
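The reason for the new `target_is_number_mask` guards: multiplying a loss by a boolean mask does not remove NaNs, since `0 * nan` is still `nan`. Previously a NaN in `x_num` at a non-number position flowed through `F.mse_loss` and poisoned the summed loss; zeroing `x_num_target` before the loss (and the loss itself after) prevents that. A small demonstration with made-up values:

```python
import torch
import torch.nn.functional as F

numerical_token_id = 3
target = torch.tensor([[3, 5, 3]])                        # positions 0 and 2 are numbers
x_num_target = torch.tensor([[1.5, float('nan'), -0.2]])  # nan at a non-number position
numerical_pred = torch.zeros(1, 3)

target_is_number_mask = target == numerical_token_id

# old behavior: the nan survives multiplication by the mask
loss = F.mse_loss(numerical_pred, x_num_target, reduction = 'none')
print((loss * target_is_number_mask).sum())   # tensor(nan)

# patched behavior: zero non-number positions before and after the loss
x_num_target = x_num_target.masked_fill(~target_is_number_mask, 0.)
loss = F.mse_loss(numerical_pred, x_num_target, reduction = 'none')
loss = loss.masked_fill(~target_is_number_mask, 0.)
print(loss.sum())                             # tensor(2.2900)
```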
x_transformers.egg-info/PKG-INFO:

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.27.20
+Version: 1.27.22
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang