x-transformers 2.1.32__py3-none-any.whl → 2.1.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,6 +5,7 @@
5
5
  # https://www.youtube.com/watch?v=aqhbRtB2Fyg
6
6
 
7
7
  from __future__ import annotations
8
+ from random import random
8
9
 
9
10
  import torch
10
11
  from torch.autograd import Function
@@ -53,6 +54,26 @@ def flip(x, dim = 1, lens = None):
53
54
 
54
55
  return x.gather(dim, flip_indices)
55
56
 
57
+ # detach multiple tensors and backward the gradients once
58
+
59
+ class DetachMultiple(Function):
60
+
61
+ @classmethod
62
+ def forward(self, ctx, *tensors):
63
+ detached_tensors = tuple(t.detach() for t in tensors)
64
+
65
+ for detached_tensor in detached_tensors:
66
+ detached_tensor.requires_grad_()
67
+
68
+ return detached_tensors
69
+
70
+ @classmethod
71
+ def backward(self, ctx, *grads):
72
+
73
+ return grads
74
+
75
+ detach_multiple = DetachMultiple.apply
76
+
56
77
  # wrappers
57
78
 
58
79
  class BeliefStateWrapper(Module):
@@ -69,6 +90,8 @@ class BeliefStateWrapper(Module):
69
90
  backward_ar_loss_weight: float = 1., # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
70
91
  pred_distance = False,
71
92
  pred_distance_loss_weight: float = 1.,
93
+ cond_on_distance = False,
94
+ cond_on_distance_prob = 0.5,
72
95
  max_pred_distance = None
73
96
  ):
74
97
  super().__init__()
@@ -111,6 +134,21 @@ class BeliefStateWrapper(Module):
111
134
 
112
135
  self.pred_distance_loss_weight = pred_distance_loss_weight
113
136
 
137
+ # conditioning on distance
138
+
139
+ assert 0. < cond_on_distance_prob < 1.
140
+
141
+ self.cond_on_distance = cond_on_distance
142
+ self.cond_on_distance_prob = cond_on_distance_prob
143
+
144
+ if cond_on_distance:
145
+ self.to_distance_cond = nn.Sequential(
146
+ Rearrange('... -> ... 1'),
147
+ nn.Linear(1, dim),
148
+ nn.LeakyReLU(),
149
+ nn.Linear(dim, dim * 2),
150
+ )
151
+
114
152
  # the two decoders, one which is causal forward, the other causal backwards
115
153
 
116
154
  self.forward_decoder = forward_decoder
@@ -250,8 +288,6 @@ class BeliefStateWrapper(Module):
250
288
  self,
251
289
  seq,
252
290
  lens: Tensor | None = None, # Int['b']
253
- return_loss_only = False,
254
- loss_scale = 1.,
255
291
  loss_weight_by_fb_indices: callable | None = None
256
292
  ):
257
293
  batch, seq_len, device = *seq.shape, seq.device
@@ -284,11 +320,7 @@ class BeliefStateWrapper(Module):
284
320
 
285
321
  # trick to reduce memory on backwards pass
286
322
 
287
- orig_forward_embeds, forward_embeds = forward_embeds, forward_embeds.detach()
288
- orig_backward_embeds, backward_embeds = backward_embeds, backward_embeds.detach()
289
-
290
- forward_embeds.requires_grad_()
291
- backward_embeds.requires_grad_()
323
+ forward_embeds, backward_embeds = detach_multiple(forward_embeds, backward_embeds)
292
324
 
293
325
  # belief state objective
294
326
 
@@ -344,9 +376,19 @@ class BeliefStateWrapper(Module):
344
376
  ignore_index = -1
345
377
  )
346
378
 
347
- # maybe predict terminal
379
+ # maybe condition on distance
380
+
381
+ cond_on_distance = self.cond_on_distance and (random() < self.cond_on_distance_prob)
382
+
383
+ if cond_on_distance:
384
+ distance = (bi - fi).float()
385
+ distance_cond = self.to_distance_cond(distance)
386
+
387
+ fb_embeds = fb_embeds * distance_cond
348
388
 
349
- if exists(self.to_distance_logits):
389
+ # maybe predict distance
390
+
391
+ if exists(self.to_distance_logits) and not cond_on_distance:
350
392
  distance_logits = self.to_distance_logits(fb_embeds)
351
393
 
352
394
  distance_labels = (bi - fi).clamp(max = self.max_pred_distance - 1)
@@ -362,11 +404,6 @@ class BeliefStateWrapper(Module):
362
404
  pred_dist_loss * self.pred_distance_loss_weight
363
405
  )
364
406
 
365
- # maybe early return loss
366
-
367
- if return_loss_only:
368
- return loss
369
-
370
407
  # maybe loss weighting
371
408
 
372
409
  needs_loss_weight = default(self.needs_loss_weight, exists(loss_weight_by_fb_indices))
@@ -392,11 +429,4 @@ class BeliefStateWrapper(Module):
392
429
 
393
430
  loss = loss.mean()
394
431
 
395
- # backwards
396
-
397
- (loss * loss_scale).backward()
398
-
399
- orig_forward_embeds.backward(forward_embeds.grad)
400
- orig_backward_embeds.backward(backward_embeds.grad)
401
-
402
432
  return loss
@@ -2055,7 +2055,7 @@ class AttentionLayers(Module):
2055
2055
  elif use_simple_rmsnorm:
2056
2056
  norm_class = SimpleRMSNorm
2057
2057
  elif use_dynamic_tanh:
2058
- assert pre_norm, 'only tested for pre-norm'
2058
+ assert pre_norm, 'dynamic tanh norm only tested for pre-norm'
2059
2059
  norm_class = partial(DynamicTanh, init_alpha = dynamic_tanh_init_alpha)
2060
2060
  elif use_adaptive_layernorm:
2061
2061
  norm_need_condition = True
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.1.32
3
+ Version: 2.1.35
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -1,16 +1,16 @@
1
1
  x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
2
2
  x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
3
3
  x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
4
- x_transformers/belief_state_wrapper.py,sha256=PAGbBLU8xGmtiv_G6-RhX1Kb1GwxRmIxxkfHUI2l25U,12538
4
+ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
5
5
  x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
6
6
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
7
7
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
8
8
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
9
9
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
10
- x_transformers/x_transformers.py,sha256=YQODc4PDB_ddgm7vi0uktV5GGetgEuwADzt3CaIdAXs,111484
10
+ x_transformers/x_transformers.py,sha256=oyYk31qhDyt6cCuDeaHBl4XhUo5bfIwmYE_z1E1OpXU,111502
11
11
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
12
12
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
13
- x_transformers-2.1.32.dist-info/METADATA,sha256=Jmgj9CByp1_kveLWVm5-QM_n55A9s3MFzUiV1ciD034,88161
14
- x_transformers-2.1.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
- x_transformers-2.1.32.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
16
- x_transformers-2.1.32.dist-info/RECORD,,
13
+ x_transformers-2.1.35.dist-info/METADATA,sha256=tLbl-c1QtaOphTa1DpdNfh4dXzFwTt9Fvdh94tnwdTs,88161
14
+ x_transformers-2.1.35.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
+ x_transformers-2.1.35.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
16
+ x_transformers-2.1.35.dist-info/RECORD,,