x-transformers 2.1.32__py3-none-any.whl → 2.1.35__py3-none-any.whl
This diff shows the changes between two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- x_transformers/belief_state_wrapper.py +51 -21
- x_transformers/x_transformers.py +1 -1
- {x_transformers-2.1.32.dist-info → x_transformers-2.1.35.dist-info}/METADATA +1 -1
- {x_transformers-2.1.32.dist-info → x_transformers-2.1.35.dist-info}/RECORD +6 -6
- {x_transformers-2.1.32.dist-info → x_transformers-2.1.35.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.32.dist-info → x_transformers-2.1.35.dist-info}/licenses/LICENSE +0 -0
x_transformers/belief_state_wrapper.py
CHANGED
@@ -5,6 +5,7 @@
 # https://www.youtube.com/watch?v=aqhbRtB2Fyg

 from __future__ import annotations
+from random import random

 import torch
 from torch.autograd import Function
@@ -53,6 +54,26 @@ def flip(x, dim = 1, lens = None):

     return x.gather(dim, flip_indices)

+# detach multiple tensors and backward the gradients once
+
+class DetachMultiple(Function):
+
+    @classmethod
+    def forward(self, ctx, *tensors):
+        detached_tensors = tuple(t.detach() for t in tensors)
+
+        for detached_tensor in detached_tensors:
+            detached_tensor.requires_grad_()
+
+        return detached_tensors
+
+    @classmethod
+    def backward(self, ctx, *grads):
+
+        return grads
+
+detach_multiple = DetachMultiple.apply
+
 # wrappers

 class BeliefStateWrapper(Module):
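The new DetachMultiple is a straight-through autograd Function: its forward hands back detached copies that require grad, and its backward forwards the incoming gradients unchanged to the original tensors, so a single loss.backward() still reaches the upstream graph. A minimal sketch of that behavior, with illustrative tensors that are not from the library:

import torch
from x_transformers.belief_state_wrapper import detach_multiple

x = torch.randn(4, requires_grad = True)
y = torch.randn(4, requires_grad = True)

# outputs are detached copies, but the Function's backward passes
# gradients straight through to the original x and y
xd, yd = detach_multiple(x, y)

loss = (xd * yd).sum()
loss.backward()

assert x.grad is not None and y.grad is not None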
@@ -69,6 +90,8 @@ class BeliefStateWrapper(Module):
         backward_ar_loss_weight: float = 1., # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
         pred_distance = False,
         pred_distance_loss_weight: float = 1.,
+        cond_on_distance = False,
+        cond_on_distance_prob = 0.5,
         max_pred_distance = None
     ):
         super().__init__()
@@ -111,6 +134,21 @@

         self.pred_distance_loss_weight = pred_distance_loss_weight

+        # conditioning on distance
+
+        assert 0. < cond_on_distance_prob < 1.
+
+        self.cond_on_distance = cond_on_distance
+        self.cond_on_distance_prob = cond_on_distance_prob
+
+        if cond_on_distance:
+            self.to_distance_cond = nn.Sequential(
+                Rearrange('... -> ... 1'),
+                nn.Linear(1, dim),
+                nn.LeakyReLU(),
+                nn.Linear(dim, dim * 2),
+            )
+
         # the two decoders, one which is causal forward, the other causal backwards

         self.forward_decoder = forward_decoder
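For reference, a hedged sketch of the shapes involved in to_distance_cond; it assumes fb_embeds is the concatenation of forward and backward hidden states (hence a feature size of dim * 2), and the dim value here is illustrative:

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim = 8

to_distance_cond = nn.Sequential(
    Rearrange('... -> ... 1'),   # scalar distance becomes a size-1 feature
    nn.Linear(1, dim),
    nn.LeakyReLU(),
    nn.Linear(dim, dim * 2),     # matches the dim * 2 features of fb_embeds
)

distance = torch.tensor([3., 7.])            # (bi - fi) values, as floats
distance_cond = to_distance_cond(distance)   # shape (2, dim * 2)

fb_embeds = torch.randn(2, dim * 2)
fb_embeds = fb_embeds * distance_cond        # multiplicative conditioning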
@@ -250,8 +288,6 @@
         self,
         seq,
         lens: Tensor | None = None, # Int['b']
-        return_loss_only = False,
-        loss_scale = 1.,
         loss_weight_by_fb_indices: callable | None = None
     ):
         batch, seq_len, device = *seq.shape, seq.device
@@ -284,11 +320,7 @@

         # trick to reduce memory on backwards pass

-        orig_forward_embeds, forward_embeds = forward_embeds, forward_embeds.detach()
-        orig_backward_embeds, backward_embeds = backward_embeds, backward_embeds.detach()
-
-        forward_embeds.requires_grad_()
-        backward_embeds.requires_grad_()
+        forward_embeds, backward_embeds = detach_multiple(forward_embeds, backward_embeds)

         # belief state objective

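The removed lines implemented the memory trick by hand; a self-contained illustration of that older pattern, using a stand-in tensor rather than the wrapper's real embeddings:

import torch

x = torch.randn(4, requires_grad = True)
embeds = x * 2                       # stand-in for forward_embeds

# detach to truncate the graph, then re-enable grad on the copy
orig_embeds, embeds = embeds, embeds.detach()
embeds.requires_grad_()

loss = embeds.sum()
loss.backward()                      # gradients stop at the detached copy
orig_embeds.backward(embeds.grad)    # manually pushed into the real graph

assert x.grad is not None

detach_multiple folds the detach, requires_grad_(), and gradient re-routing into a single autograd node, which is why the explicit orig_*.backward(...) calls disappear in the last hunk below.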
@@ -344,9 +376,19 @@
             ignore_index = -1
         )

-        # maybe predict distance
+        # maybe condition on distance
+
+        cond_on_distance = self.cond_on_distance and (random() < self.cond_on_distance_prob)
+
+        if cond_on_distance:
+            distance = (bi - fi).float()
+            distance_cond = self.to_distance_cond(distance)
+
+            fb_embeds = fb_embeds * distance_cond

-        if exists(self.to_distance_logits):
+        # maybe predict distance
+
+        if exists(self.to_distance_logits) and not cond_on_distance:
             distance_logits = self.to_distance_logits(fb_embeds)

             distance_labels = (bi - fi).clamp(max = self.max_pred_distance - 1)
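Each forward pass now either conditions on the true distance (with probability cond_on_distance_prob) or, when distance prediction is enabled, predicts it, never both at once. The labels clamp long distances into the final bucket; a small worked example with illustrative values:

import torch
from random import random

cond_on_distance_prob = 0.5
cond_on_distance = random() < cond_on_distance_prob   # per-pass coin flip

fi = torch.tensor([0, 1, 2])   # forward decoder positions
bi = torch.tensor([5, 9, 2])   # backward decoder positions
max_pred_distance = 8

distance_labels = (bi - fi).clamp(max = max_pred_distance - 1)
# tensor([5, 7, 0]): the raw distance of 8 collapses into the last class, 7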
@@ -362,11 +404,6 @@
                 pred_dist_loss * self.pred_distance_loss_weight
             )

-        # maybe early return loss
-
-        if return_loss_only:
-            return loss
-
         # maybe loss weighting

         needs_loss_weight = default(self.needs_loss_weight, exists(loss_weight_by_fb_indices))
@@ -392,11 +429,4 @@

         loss = loss.mean()

-        # backwards
-
-        (loss * loss_scale).backward()
-
-        orig_forward_embeds.backward(forward_embeds.grad)
-        orig_backward_embeds.backward(backward_embeds.grad)
-
         return loss
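With return_loss_only, loss_scale, and the internal .backward() removed, forward now just returns the loss, and the caller owns scaling and backpropagation. A hedged sketch of a training step under the new contract; only the constructor arguments visible in this diff are shown, the rest are assumed:

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

forward_decoder = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
)

model = BeliefStateWrapper(
    forward_decoder = forward_decoder,
    cond_on_distance = True,       # new in this release
    cond_on_distance_prob = 0.5
)

seq = torch.randint(0, 256, (2, 128))

loss = model(seq)    # no internal .backward() anymore
loss.backward()      # the caller scales and backprops as it sees fit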
x_transformers/x_transformers.py
CHANGED
@@ -2055,7 +2055,7 @@ class AttentionLayers(Module):
         elif use_simple_rmsnorm:
             norm_class = SimpleRMSNorm
         elif use_dynamic_tanh:
-            assert pre_norm, 'only tested for pre-norm'
+            assert pre_norm, 'dynamic tanh norm only tested for pre-norm'
             norm_class = partial(DynamicTanh, init_alpha = dynamic_tanh_init_alpha)
         elif use_adaptive_layernorm:
             norm_need_condition = True
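The reworded assert guards the dynamic tanh norm option just above its norm_class assignment. A hedged sketch of how that branch is reached; use_dynamic_tanh and dynamic_tanh_init_alpha come from the surrounding diff context, the remaining arguments are assumed:

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(
        dim = 64,
        depth = 2,
        heads = 4,
        use_dynamic_tanh = True,         # selects DynamicTanh as norm_class
        dynamic_tanh_init_alpha = 1.     # threaded through via partial above
        # pre_norm defaults to True; pre_norm = False would trip the assert
    )
)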
{x_transformers-2.1.32.dist-info → x_transformers-2.1.35.dist-info}/RECORD
CHANGED
@@ -1,16 +1,16 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=
+x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=oyYk31qhDyt6cCuDeaHBl4XhUo5bfIwmYE_z1E1OpXU,111502
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
+x_transformers-2.1.35.dist-info/METADATA,sha256=tLbl-c1QtaOphTa1DpdNfh4dXzFwTt9Fvdh94tnwdTs,88161
+x_transformers-2.1.35.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.35.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.35.dist-info/RECORD,,
{x_transformers-2.1.32.dist-info → x_transformers-2.1.35.dist-info}/WHEEL
File without changes
{x_transformers-2.1.32.dist-info → x_transformers-2.1.35.dist-info}/licenses/LICENSE
File without changes