x_transformers-2.1.32-py3-none-any.whl → x_transformers-2.1.34-py3-none-any.whl

This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
x_transformers/belief_state_wrapper.py

@@ -5,6 +5,7 @@
 # https://www.youtube.com/watch?v=aqhbRtB2Fyg
 
 from __future__ import annotations
+from random import random
 
 import torch
 from torch.autograd import Function
@@ -69,6 +70,8 @@ class BeliefStateWrapper(Module):
         backward_ar_loss_weight: float = 1., # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
         pred_distance = False,
         pred_distance_loss_weight: float = 1.,
+        cond_on_distance = False,
+        cond_on_distance_prob = 0.5,
         max_pred_distance = None
     ):
         super().__init__()
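The two new keyword arguments slot into the existing constructor, so prior call sites are unaffected. A minimal usage sketch follows; everything beyond the names visible in this diff (the TransformerWrapper/Decoder setup and the wrapper's call convention) is assumed from the repository's README and may differ:

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

# hypothetical toy model, purely for illustration
forward_model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 4, heads = 8)
)

wrapper = BeliefStateWrapper(
    forward_decoder = forward_model,
    cond_on_distance = True,      # new in 2.1.34: sometimes feed (bi - fi) in as a condition
    cond_on_distance_prob = 0.5   # new in 2.1.34: fraction of training steps that condition
)

seq = torch.randint(0, 256, (2, 64))
loss = wrapper(seq)   # combined forward/backward autoregressive loss
loss.backward()

Note the assertion added below requires 0 < cond_on_distance_prob < 1 strictly, so the conditioning is always stochastic, never permanently on or off.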
@@ -111,6 +114,21 @@ class BeliefStateWrapper(Module):
 
         self.pred_distance_loss_weight = pred_distance_loss_weight
 
+        # conditioning on distance
+
+        assert 0. < cond_on_distance_prob < 1.
+
+        self.cond_on_distance = cond_on_distance
+        self.cond_on_distance_prob = cond_on_distance_prob
+
+        if cond_on_distance:
+            self.to_distance_cond = nn.Sequential(
+                Rearrange('... -> ... 1'),
+                nn.Linear(1, dim),
+                nn.LeakyReLU(),
+                nn.Linear(dim, dim * 2),
+            )
+
         # the two decoders, one which is causal forward, the other causal backwards
 
         self.forward_decoder = forward_decoder
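Shape-wise, the new module lifts a scalar distance per position into the fused embedding space: Rearrange appends a trailing feature axis, and the two-layer MLP projects it up to dim * 2, which matches fb_embeds, the concatenation of the forward and backward hidden states. A self-contained sketch of just this module, assuming dim = 512 for illustration:

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim = 512  # assumed model dimension

# standalone replica of the new conditioning MLP, for shape checking
to_distance_cond = nn.Sequential(
    Rearrange('... -> ... 1'),  # scalar distance gains a trailing feature axis
    nn.Linear(1, dim),
    nn.LeakyReLU(),
    nn.Linear(dim, dim * 2),    # dim * 2 matches fb_embeds, the fwd/bwd concat
)

distance = torch.tensor([3., 7., 1.])    # e.g. (bi - fi) for three sampled pairs
print(to_distance_cond(distance).shape)  # torch.Size([3, 1024])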
@@ -344,9 +362,19 @@ class BeliefStateWrapper(Module):
             ignore_index = -1
         )
 
-        # maybe predict terminal
+        # maybe condition on distance
+
+        cond_on_distance = self.cond_on_distance and (random() < self.cond_on_distance_prob)
+
+        if cond_on_distance:
+            distance = (bi - fi).float()
+            distance_cond = self.to_distance_cond(distance)
+
+            fb_embeds = fb_embeds * distance_cond
+
+        # maybe predict distance
 
-        if exists(self.to_distance_logits):
+        if exists(self.to_distance_logits) and not cond_on_distance:
             distance_logits = self.to_distance_logits(fb_embeds)
 
             distance_labels = (bi - fi).clamp(max = self.max_pred_distance - 1)
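The forward pass now flips a coin each training step: with probability cond_on_distance_prob it computes the token distance bi - fi, maps it through to_distance_cond, and gates fb_embeds multiplicatively (a FiLM-style scale without the additive term). On those steps the distance-prediction head is skipped, so the model is never asked to predict a quantity it was just handed. The pattern in isolation, as a hedged sketch with hypothetical names:

from random import random

def maybe_condition_on_distance(fb_embeds, fi, bi, to_distance_cond, prob = 0.5):
    # hypothetical standalone version of the logic above: with probability
    # `prob`, scale the fused embeds by a learned function of token distance
    cond = random() < prob
    if cond:
        distance = (bi - fi).float()  # scalar distance per sampled position pair
        fb_embeds = fb_embeds * to_distance_cond(distance)
    return fb_embeds, cond            # flag also gates the prediction loss

Sampling the coin per step rather than always conditioning keeps the unconditioned path trained, similar in spirit to condition dropout in classifier-free guidance, and the `and not cond_on_distance` guard ensures any given step either receives the distance as input or learns to predict it, never both.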
x_transformers/x_transformers.py

@@ -2055,7 +2055,7 @@ class AttentionLayers(Module):
         elif use_simple_rmsnorm:
             norm_class = SimpleRMSNorm
         elif use_dynamic_tanh:
-            assert pre_norm, 'only tested for pre-norm'
+            assert pre_norm, 'dynamic tanh norm only tested for pre-norm'
             norm_class = partial(DynamicTanh, init_alpha = dynamic_tanh_init_alpha)
         elif use_adaptive_layernorm:
             norm_need_condition = True
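The only other source change is a clearer assertion message for the DynamicTanh branch. For context, dynamic tanh (DyT, from the "Transformers without Normalization" line of work) swaps LayerNorm for an elementwise tanh of learnably scaled inputs. A minimal sketch of the idea, not necessarily x-transformers' exact implementation; init_alpha mirrors the dynamic_tanh_init_alpha hyperparameter seen above:

import torch
from torch import nn

class DynamicTanh(nn.Module):
    # DyT(x) = gamma * tanh(alpha * x) + beta, with a scalar learnable alpha
    # and per-feature affine parameters, used in place of LayerNorm
    def __init__(self, dim, init_alpha = 1.):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        return torch.tanh(self.alpha * x) * self.gamma + self.beta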
x_transformers-2.1.32.dist-info/METADATA → x_transformers-2.1.34.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.32
+Version: 2.1.34
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-2.1.32.dist-info/RECORD → x_transformers-2.1.34.dist-info/RECORD

@@ -1,16 +1,16 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=PAGbBLU8xGmtiv_G6-RhX1Kb1GwxRmIxxkfHUI2l25U,12538
+x_transformers/belief_state_wrapper.py,sha256=bx9AIyyYQDRfq6UxOQMBEEqSoqVm5cYwqawSmJe5bqk,13414
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=YQODc4PDB_ddgm7vi0uktV5GGetgEuwADzt3CaIdAXs,111484
+x_transformers/x_transformers.py,sha256=oyYk31qhDyt6cCuDeaHBl4XhUo5bfIwmYE_z1E1OpXU,111502
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.32.dist-info/METADATA,sha256=Jmgj9CByp1_kveLWVm5-QM_n55A9s3MFzUiV1ciD034,88161
-x_transformers-2.1.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.1.32.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.1.32.dist-info/RECORD,,
+x_transformers-2.1.34.dist-info/METADATA,sha256=JXvbgbczvnKYCC0ccZxd5_pfT5U68nBBz8aXQRQtukw,88161
+x_transformers-2.1.34.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.34.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.34.dist-info/RECORD,,