x-transformers 2.1.5__tar.gz → 2.1.7__tar.gz
This diff shows the changes between two publicly released versions of the package, as they appear in the public registry, and is provided for informational purposes only.
- {x_transformers-2.1.5 → x_transformers-2.1.7}/PKG-INFO +1 -1
- {x_transformers-2.1.5 → x_transformers-2.1.7}/pyproject.toml +1 -1
- {x_transformers-2.1.5 → x_transformers-2.1.7}/tests/test_x_transformers.py +5 -3
- {x_transformers-2.1.5 → x_transformers-2.1.7}/x_transformers/__init__.py +1 -0
- x_transformers-2.1.5/x_transformers/belief_state.py → x_transformers-2.1.7/x_transformers/belief_state_wrapper.py +33 -7
- {x_transformers-2.1.5 → x_transformers-2.1.7}/.github/FUNDING.yml +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/.gitignore +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/LICENSE +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/README.md +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/data/README.md +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/data/enwik8.gz +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/all-attention.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/attention-on-attention.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/deepnorm.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/fcm.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/ffglu.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/flash-attention.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/gate_values.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/gating.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/macaron-1.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/macaron-2.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/memory-transformer.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/normformer.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/pia.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/resi_dual.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/residual_attn.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/rezero.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/rotary.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/sandwich-2.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/sandwich.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/sandwich_norm.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/scalenorm.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/talking-heads.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/topk-attention.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/images/xval.png +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/train_copy.py +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/train_enwik8.py +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/train_length_extrapolate.py +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/train_parity.py +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/x_transformers/attend.py +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/x_transformers/continuous.py +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/x_transformers/dpo.py +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/x_transformers/x_transformers.py +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.1.5 → x_transformers-2.1.7}/x_transformers/xval.py +0 -0
{x_transformers-2.1.5 → x_transformers-2.1.7}/tests/test_x_transformers.py +5 -3

@@ -695,7 +695,7 @@ def test_lime(
     model(x)

 def test_belief_state_wrapper():
-    from x_transformers.belief_state import BeliefStateWrapper
+    from x_transformers.belief_state_wrapper import BeliefStateWrapper

     forward_model = TransformerWrapper(
         num_tokens = 20000,
@@ -721,9 +721,11 @@ def test_belief_state_wrapper():

     model = BeliefStateWrapper(
         forward_decoder = forward_model,
-        backward_decoder = backward_model
+        backward_decoder = backward_model,
+        backward_ar_loss_weight = 0.5
     )

     seq = torch.randint(0, 20000, (2, 16))

-    loss = model(seq)
+    loss = model(seq, backward = False)
+    loss.backward()
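For context, here is a usage sketch in the spirit of the updated test. The `TransformerWrapper`/`Decoder` configurations are placeholders; only `backward_ar_loss_weight = 0.5`, `backward = False`, and the manual `loss.backward()` call are taken from the diff above, and importing `BeliefStateWrapper` from the package root assumes 2.1.7 (see the `__init__.py` change below).

```python
import torch
from x_transformers import TransformerWrapper, Decoder, BeliefStateWrapper

# two decoders with the same embedding dimension; dims/depths here are placeholders
forward_model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2)
)

backward_model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2)
)

model = BeliefStateWrapper(
    forward_decoder = forward_model,
    backward_decoder = backward_model,
    backward_ar_loss_weight = 0.5   # down-weight the backward decoder's autoregressive loss
)

seq = torch.randint(0, 20000, (2, 16))

loss = model(seq, backward = False)  # skip the wrapper's internal backward pass
loss.backward()                      # patched .backward also backprops the cached embedding grads
```

With `backward = False`, the wrapper no longer runs the backward pass itself; calling the returned loss's patched `.backward()` triggers it, along with the extra embedding backward steps shown in the belief_state_wrapper diff further down.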
{x_transformers-2.1.5 → x_transformers-2.1.7}/x_transformers/__init__.py +1 -0

@@ -14,6 +14,7 @@ from x_transformers.x_transformers import (

 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
+from x_transformers.belief_state_wrapper import BeliefStateWrapper

 from x_transformers.continuous import (
     ContinuousTransformerWrapper,
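With this export in place, the wrapper should be importable from the package root in 2.1.7:

```python
# equivalent to the submodule import used in the test above
from x_transformers import BeliefStateWrapper
```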
x_transformers-2.1.5/x_transformers/belief_state.py → x_transformers-2.1.7/x_transformers/belief_state_wrapper.py +33 -7

@@ -6,7 +6,7 @@
 import torch
 from torch.autograd import Function
 from torch.nn import Module, ModuleList
-from torch import nn, cat, stack, arange, cartesian_prod
+from torch import nn, cat, stack, tensor, arange, cartesian_prod
 import torch.nn.functional as F

 from x_transformers.x_transformers import (
@@ -36,7 +36,8 @@ class BeliefStateWrapper(Module):
         self,
         forward_decoder: TransformerWrapper,
         backward_decoder: TransformerWrapper,
-        train_frac_forward_backward_pairs: float = 1
+        train_frac_forward_backward_pairs: float = 1.,
+        backward_ar_loss_weight: float = 1. # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
     ):
         super().__init__()
         assert forward_decoder.emb_dim == backward_decoder.emb_dim, 'forward and backwards model must have the same embedding dimension'
@@ -70,9 +71,17 @@ class BeliefStateWrapper(Module):
         self.train_frac_fb_pairs = train_frac_forward_backward_pairs
         self.needs_subsample_fb_pairs = train_frac_forward_backward_pairs < 1.

+        # loss weighting
+
+        self.backward_ar_loss_weight = backward_ar_loss_weight
+        self.needs_loss_weight = backward_ar_loss_weight != 1.
+
+        self.register_buffer('loss_weights', tensor([1., self.backward_ar_loss_weight]))
+
     def forward(
         self,
-        seq
+        seq,
+        backward = True
     ):
         batch, seq_len, device = *seq.shape, seq.device

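As an aside on the buffer registration above: here is a minimal sketch (a toy module, not the wrapper itself) of standard `torch.nn.Module` buffer behavior, which is presumably why `loss_weights` is registered rather than kept as a plain tensor: buffers follow the module across `.to()` / `.cuda()` calls and are saved in its `state_dict`.

```python
import torch
from torch import nn, tensor

class ToyLossWeights(nn.Module):
    def __init__(self, backward_ar_loss_weight = 0.5):
        super().__init__()
        # index 0 weighs the forward stream, index 1 the backward stream
        self.register_buffer('loss_weights', tensor([1., backward_ar_loss_weight]))

m = ToyLossWeights()
print(m.loss_weights)                     # tensor([1.0000, 0.5000])
print('loss_weights' in m.state_dict())   # True - buffers are checkpointed with the module
# m.to('cuda') would move loss_weights along with any parameters
```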
@@ -149,14 +158,31 @@ class BeliefStateWrapper(Module):

         fb_loss = F.cross_entropy(
             rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
-            rearrange(labels, 'b n fb -> b (fb n)')
+            rearrange(labels, 'b n fb -> b (fb n)'),
+            reduction = 'none' if self.needs_loss_weight else 'mean'
         )

+        # maybe loss weighting
+
+        if self.needs_loss_weight:
+            fb_loss = rearrange(fb_loss, 'b (fb n) -> b fb n', fb = 2)
+            fb_loss = einx.multiply('b fb n, fb', fb_loss, self.loss_weights)
+            fb_loss = fb_loss.mean()
+
         # backwards

-        fb_loss
+        orig_backward = getattr(fb_loss, 'backward')
+
+        def patched_backward_fn(*args, **kwargs):
+            orig_backward(*args, **kwargs)
+            orig_forward_embeds.backward(forward_embeds.grad)
+            orig_backward_embeds.backward(backward_embeds.grad)
+
+        # can allow the researcher to call .backward from the outside

-
-
+        if backward:
+            patched_backward_fn()
+        else:
+            setattr(fb_loss, 'backward', patched_backward_fn)

         return fb_loss
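To see the two new mechanisms in isolation, here is a self-contained sketch with made-up shapes (batch 2, 8 forward/backward pairs, a 256-token vocabulary; none of the wrapper's real internals): per-stream loss weighting via `reduction = 'none'`, and shadowing `.backward` on the returned loss so the caller can trigger it later. Plain broadcasting stands in for the `einx.multiply` call.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

logits = torch.randn(2, 8, 2 * 256, requires_grad = True)   # (batch, pairs, fb * num_tokens)
labels = torch.randint(0, 256, (2, 8, 2))                    # (batch, pairs, fb)
loss_weights = torch.tensor([1., 0.5])                       # [forward weight, backward weight]

# cross entropy over both prediction streams, kept per-element for weighting
fb_loss = F.cross_entropy(
    rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
    rearrange(labels, 'b n fb -> b (fb n)'),
    reduction = 'none'
)

# weigh the forward and backward streams differently, then reduce
fb_loss = rearrange(fb_loss, 'b (fb n) -> b fb n', fb = 2)
fb_loss = fb_loss * rearrange(loss_weights, 'fb -> fb 1')    # stands in for einx.multiply('b fb n, fb', ...)
fb_loss = fb_loss.mean()

# defer the backward pass by shadowing the tensor's own .backward
orig_backward = fb_loss.backward

def patched_backward_fn(*args, **kwargs):
    orig_backward(*args, **kwargs)
    # ... the wrapper additionally backprops the cached forward/backward embedding grads here

setattr(fb_loss, 'backward', patched_backward_fn)

fb_loss.backward()          # caller-side: runs the original backward plus the extra steps
print(logits.grad.shape)    # torch.Size([2, 8, 512])
```

Assigning to `.backward` works because an instance attribute on a tensor shadows the class method; this is exactly what the `setattr(fb_loss, 'backward', patched_backward_fn)` branch above relies on when `backward = False`.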