x-transformers 2.1.32-py3-none-any.whl → 2.1.34-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/belief_state_wrapper.py +30 -2
- x_transformers/x_transformers.py +1 -1
- {x_transformers-2.1.32.dist-info → x_transformers-2.1.34.dist-info}/METADATA +1 -1
- {x_transformers-2.1.32.dist-info → x_transformers-2.1.34.dist-info}/RECORD +6 -6
- {x_transformers-2.1.32.dist-info → x_transformers-2.1.34.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.32.dist-info → x_transformers-2.1.34.dist-info}/licenses/LICENSE +0 -0
x_transformers/belief_state_wrapper.py
CHANGED
@@ -5,6 +5,7 @@
 # https://www.youtube.com/watch?v=aqhbRtB2Fyg

 from __future__ import annotations
+from random import random

 import torch
 from torch.autograd import Function
@@ -69,6 +70,8 @@ class BeliefStateWrapper(Module):
         backward_ar_loss_weight: float = 1., # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
         pred_distance = False,
         pred_distance_loss_weight: float = 1.,
+        cond_on_distance = False,
+        cond_on_distance_prob = 0.5,
         max_pred_distance = None
     ):
         super().__init__()
@@ -111,6 +114,21 @@

         self.pred_distance_loss_weight = pred_distance_loss_weight

+        # conditioning on distance
+
+        assert 0. < cond_on_distance_prob < 1.
+
+        self.cond_on_distance = cond_on_distance
+        self.cond_on_distance_prob = cond_on_distance_prob
+
+        if cond_on_distance:
+            self.to_distance_cond = nn.Sequential(
+                Rearrange('... -> ... 1'),
+                nn.Linear(1, dim),
+                nn.LeakyReLU(),
+                nn.Linear(dim, dim * 2),
+            )
+
         # the two decoders, one which is causal forward, the other causal backwards

         self.forward_decoder = forward_decoder
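The new to_distance_cond module lifts a scalar distance to a vector of width dim * 2, which the next hunk uses to gate fb_embeds multiplicatively. A minimal standalone sketch of the shapes involved, assuming fb_embeds is the dim * 2 concatenation of forward and backward embeddings; the batch and sequence sizes here are illustrative only:

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim = 512

# same layers as the diff above, built standalone
to_distance_cond = nn.Sequential(
    Rearrange('... -> ... 1'),  # scalar distance -> trailing feature axis of size 1
    nn.Linear(1, dim),
    nn.LeakyReLU(),
    nn.Linear(dim, dim * 2),    # width matches the concatenated fwd/bwd embeddings
)

fb_embeds = torch.randn(2, 128, dim * 2)           # (batch, pairs, dim * 2) - illustrative
distance = torch.randint(1, 64, (2, 128)).float()  # stand-in for (bi - fi)

gated = fb_embeds * to_distance_cond(distance)     # multiplicative (FiLM-like) conditioning
assert gated.shape == fb_embeds.shape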
@@ -344,9 +362,19 @@
             ignore_index = -1
         )

-        # maybe predict distance
+        # maybe condition on distance
+
+        cond_on_distance = self.cond_on_distance and (random() < self.cond_on_distance_prob)
+
+        if cond_on_distance:
+            distance = (bi - fi).float()
+            distance_cond = self.to_distance_cond(distance)
+
+            fb_embeds = fb_embeds * distance_cond
+
+        # maybe predict distance

-        if exists(self.to_distance_logits):
+        if exists(self.to_distance_logits) and not cond_on_distance:
             distance_logits = self.to_distance_logits(fb_embeds)

             distance_labels = (bi - fi).clamp(max = self.max_pred_distance - 1)
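Taken together, the changes to this file add an opt-in distance-conditioning path: with probability cond_on_distance_prob, the fused embeddings are gated by the encoded forward/backward distance, and distance prediction is skipped on those steps (predicting what was just given as a condition would be trivial). A minimal usage sketch, assuming the TransformerWrapper/Decoder setup from the x-transformers README; the sizes and the loss-returning forward call follow the repo's wrapper examples and are illustrative rather than canonical:

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

forward_decoder = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

model = BeliefStateWrapper(
    forward_decoder = forward_decoder,
    cond_on_distance = True,      # new in 2.1.34: sometimes condition embeds on (bi - fi)
    cond_on_distance_prob = 0.5,  # new: fraction of steps that condition, must lie in (0, 1)
)

seq = torch.randint(0, 256, (1, 512))
loss = model(seq)  # combined fwd/bwd training loss, per the wrapper's forward
loss.backward()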
x_transformers/x_transformers.py
CHANGED
@@ -2055,7 +2055,7 @@ class AttentionLayers(Module):
         elif use_simple_rmsnorm:
             norm_class = SimpleRMSNorm
         elif use_dynamic_tanh:
-            assert pre_norm, 'only tested for pre-norm'
+            assert pre_norm, 'dynamic tanh norm only tested for pre-norm'
             norm_class = partial(DynamicTanh, init_alpha = dynamic_tanh_init_alpha)
         elif use_adaptive_layernorm:
             norm_need_condition = True
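For context, this assert fires when DynamicTanh normalization is combined with post-norm; the change only sharpens the message. A configuration sketch using the kwarg names visible in this hunk (use_dynamic_tanh, dynamic_tanh_init_alpha) plus standard Decoder arguments; the init_alpha value is made up for illustration:

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_dynamic_tanh = True,       # selects DynamicTanh as the norm class
        dynamic_tanh_init_alpha = 1.,  # illustrative value for init_alpha
        # pre_norm defaults to True; pre_norm = False would now raise
        # 'dynamic tanh norm only tested for pre-norm'
    )
)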
{x_transformers-2.1.32.dist-info → x_transformers-2.1.34.dist-info}/RECORD
CHANGED
@@ -1,16 +1,16 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=
+x_transformers/belief_state_wrapper.py,sha256=bx9AIyyYQDRfq6UxOQMBEEqSoqVm5cYwqawSmJe5bqk,13414
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=oyYk31qhDyt6cCuDeaHBl4XhUo5bfIwmYE_z1E1OpXU,111502
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
+x_transformers-2.1.34.dist-info/METADATA,sha256=JXvbgbczvnKYCC0ccZxd5_pfT5U68nBBz8aXQRQtukw,88161
+x_transformers-2.1.34.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.34.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.34.dist-info/RECORD,,
{x_transformers-2.1.32.dist-info → x_transformers-2.1.34.dist-info}/WHEEL
File without changes
{x_transformers-2.1.32.dist-info → x_transformers-2.1.34.dist-info}/licenses/LICENSE
File without changes