x-transformers 2.1.34__py3-none-any.whl → 2.1.35__py3-none-any.whl

This diff shows the differences between the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -54,6 +54,26 @@ def flip(x, dim = 1, lens = None):
54
54
 
55
55
  return x.gather(dim, flip_indices)
56
56
 
57
+ # detach multiple tensors and backward the gradients once
58
+
59
+ class DetachMultiple(Function):
60
+
61
+ @classmethod
62
+ def forward(self, ctx, *tensors):
63
+ detached_tensors = tuple(t.detach() for t in tensors)
64
+
65
+ for detached_tensor in detached_tensors:
66
+ detached_tensor.requires_grad_()
67
+
68
+ return detached_tensors
69
+
70
+ @classmethod
71
+ def backward(self, ctx, *grads):
72
+
73
+ return grads
74
+
75
+ detach_multiple = DetachMultiple.apply
76
+
57
77
  # wrappers
58
78
 
59
79
  class BeliefStateWrapper(Module):
@@ -268,8 +288,6 @@ class BeliefStateWrapper(Module):
268
288
  self,
269
289
  seq,
270
290
  lens: Tensor | None = None, # Int['b']
271
- return_loss_only = False,
272
- loss_scale = 1.,
273
291
  loss_weight_by_fb_indices: callable | None = None
274
292
  ):
275
293
  batch, seq_len, device = *seq.shape, seq.device
@@ -302,11 +320,7 @@ class BeliefStateWrapper(Module):
302
320
 
303
321
  # trick to reduce memory on backwards pass
304
322
 
305
- orig_forward_embeds, forward_embeds = forward_embeds, forward_embeds.detach()
306
- orig_backward_embeds, backward_embeds = backward_embeds, backward_embeds.detach()
307
-
308
- forward_embeds.requires_grad_()
309
- backward_embeds.requires_grad_()
323
+ forward_embeds, backward_embeds = detach_multiple(forward_embeds, backward_embeds)
310
324
 
311
325
  # belief state objective
312
326
 
@@ -390,11 +404,6 @@ class BeliefStateWrapper(Module):
390
404
  pred_dist_loss * self.pred_distance_loss_weight
391
405
  )
392
406
 
393
- # maybe early return loss
394
-
395
- if return_loss_only:
396
- return loss
397
-
398
407
  # maybe loss weighting
399
408
 
400
409
  needs_loss_weight = default(self.needs_loss_weight, exists(loss_weight_by_fb_indices))
@@ -420,11 +429,4 @@ class BeliefStateWrapper(Module):
420
429
 
421
430
  loss = loss.mean()
422
431
 
423
- # backwards
424
-
425
- (loss * loss_scale).backward()
426
-
427
- orig_forward_embeds.backward(forward_embeds.grad)
428
- orig_backward_embeds.backward(backward_embeds.grad)
429
-
430
432
  return loss
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.1.34
3
+ Version: 2.1.35
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -1,7 +1,7 @@
1
1
  x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
2
2
  x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
3
3
  x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
4
- x_transformers/belief_state_wrapper.py,sha256=bx9AIyyYQDRfq6UxOQMBEEqSoqVm5cYwqawSmJe5bqk,13414
4
+ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
5
5
  x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
6
6
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
7
7
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
10
10
  x_transformers/x_transformers.py,sha256=oyYk31qhDyt6cCuDeaHBl4XhUo5bfIwmYE_z1E1OpXU,111502
11
11
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
12
12
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
13
- x_transformers-2.1.34.dist-info/METADATA,sha256=JXvbgbczvnKYCC0ccZxd5_pfT5U68nBBz8aXQRQtukw,88161
14
- x_transformers-2.1.34.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
- x_transformers-2.1.34.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
16
- x_transformers-2.1.34.dist-info/RECORD,,
13
+ x_transformers-2.1.35.dist-info/METADATA,sha256=tLbl-c1QtaOphTa1DpdNfh4dXzFwTt9Fvdh94tnwdTs,88161
14
+ x_transformers-2.1.35.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
+ x_transformers-2.1.35.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
16
+ x_transformers-2.1.35.dist-info/RECORD,,