x-transformers 2.1.4__tar.gz → 2.1.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.1.4 → x_transformers-2.1.6}/PKG-INFO +1 -1
- {x_transformers-2.1.4 → x_transformers-2.1.6}/pyproject.toml +1 -1
- {x_transformers-2.1.4 → x_transformers-2.1.6}/tests/test_x_transformers.py +5 -3
- {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/__init__.py +1 -0
- x_transformers-2.1.4/x_transformers/belief_state.py → x_transformers-2.1.6/x_transformers/belief_state_wrapper.py +54 -8
- {x_transformers-2.1.4 → x_transformers-2.1.6}/.github/FUNDING.yml +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/.gitignore +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/LICENSE +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/README.md +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/data/README.md +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/data/enwik8.gz +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/all-attention.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/attention-on-attention.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/deepnorm.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/fcm.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/ffglu.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/flash-attention.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/gate_values.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/gating.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/macaron-1.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/macaron-2.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/memory-transformer.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/normformer.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/pia.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/resi_dual.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/residual_attn.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/rezero.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/rotary.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/sandwich-2.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/sandwich.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/sandwich_norm.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/scalenorm.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/talking-heads.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/topk-attention.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/images/xval.png +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/train_copy.py +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/train_enwik8.py +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/train_length_extrapolate.py +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/train_parity.py +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/attend.py +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/continuous.py +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/dpo.py +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/x_transformers.py +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/xval.py +0 -0
{x_transformers-2.1.4 → x_transformers-2.1.6}/tests/test_x_transformers.py

@@ -695,7 +695,7 @@ def test_lime(
     model(x)

 def test_belief_state_wrapper():
-    from x_transformers.
+    from x_transformers.belief_state_wrapper import BeliefStateWrapper

     forward_model = TransformerWrapper(
         num_tokens = 20000,
@@ -721,9 +721,11 @@ def test_belief_state_wrapper():

     model = BeliefStateWrapper(
         forward_decoder = forward_model,
-        backward_decoder = backward_model
+        backward_decoder = backward_model,
+        backward_ar_loss_weight = 0.5
     )

     seq = torch.randint(0, 20000, (2, 16))

-    loss = model(seq)
+    loss = model(seq, backward = False)
+    loss.backward()
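For context, a minimal sketch of how the updated test exercises the wrapper end to end. The `max_seq_len`, `attn_layers`, and loop-free usage below are assumptions; the hunks above only show `num_tokens = 20000` and the wrapper call itself.

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

# assumed settings for both decoders; only num_tokens = 20000 appears in the diff
forward_model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
)

backward_model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
)

model = BeliefStateWrapper(
    forward_decoder = forward_model,
    backward_decoder = backward_model,
    backward_ar_loss_weight = 0.5    # new in 2.1.6: weigh the backward decoder's loss at half
)

seq = torch.randint(0, 20000, (2, 16))

loss = model(seq, backward = False)  # defer the internal backward pass
loss.backward()                      # caller triggers it explicitly, as in the test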
{x_transformers-2.1.4 → x_transformers-2.1.6}/x_transformers/__init__.py

@@ -14,6 +14,7 @@ from x_transformers.x_transformers import (

 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
+from x_transformers.belief_state_wrapper import BeliefStateWrapper

 from x_transformers.continuous import (
     ContinuousTransformerWrapper,
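Since `__init__.py` now re-exports the wrapper, the class should also be importable from the package root:

from x_transformers import BeliefStateWrapper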
x_transformers-2.1.4/x_transformers/belief_state.py → x_transformers-2.1.6/x_transformers/belief_state_wrapper.py

@@ -6,7 +6,7 @@
 import torch
 from torch.autograd import Function
 from torch.nn import Module, ModuleList
-from torch import nn, cat, stack, arange, cartesian_prod
+from torch import nn, cat, stack, tensor, arange, cartesian_prod
 import torch.nn.functional as F

 from x_transformers.x_transformers import (
@@ -35,11 +35,13 @@ class BeliefStateWrapper(Module):
     def __init__(
         self,
         forward_decoder: TransformerWrapper,
-        backward_decoder: TransformerWrapper
+        backward_decoder: TransformerWrapper,
+        train_frac_forward_backward_pairs: float = 1.,
+        backward_ar_loss_weight: float = 1. # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
     ):
         super().__init__()
         assert forward_decoder.emb_dim == backward_decoder.emb_dim, 'forward and backwards model must have the same embedding dimension'
-        assert forward_decoder.num_tokens == backward_decoder.num_tokens, 'forward and backwards model must have the same
+        assert forward_decoder.num_tokens == backward_decoder.num_tokens, 'forward and backwards model must have the same number of tokens'

         dim = forward_decoder.emb_dim
         num_tokens = forward_decoder.num_tokens
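A hedged sketch of what the two new constructor arguments control, with hypothetical values (neither value comes from the diff):

model = BeliefStateWrapper(
    forward_decoder = forward_model,
    backward_decoder = backward_model,
    train_frac_forward_backward_pairs = 0.25,  # per step, train on a random 25% of valid (prefix, suffix) pairs, for memory
    backward_ar_loss_weight = 0.5              # previous-token (backward) loss counts half as much as next-token (forward) loss
)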
@@ -47,6 +49,7 @@ class BeliefStateWrapper(Module):
         # the suffix token

         self.suffix_token = nn.Parameter(torch.zeros(dim))
+        nn.init.normal_(self.suffix_token, std = 0.02)

         # the text prediction head, which predicts for the combinations of prefix and suffix the next and previous token for forwards and backward sequences

@@ -61,9 +64,24 @@ class BeliefStateWrapper(Module):
         self.forward_decoder = forward_decoder
         self.backward_decoder = backward_decoder

+        # what fraction of forward backward pairs to train on
+        # for further memory efficiency
+
+        assert 0 < train_frac_forward_backward_pairs <= 1.
+        self.train_frac_fb_pairs = train_frac_forward_backward_pairs
+        self.needs_subsample_fb_pairs = train_frac_forward_backward_pairs < 1.
+
+        # loss weighting
+
+        self.backward_ar_loss_weight = backward_ar_loss_weight
+        self.needs_loss_weight = backward_ar_loss_weight != 1.
+
+        self.register_buffer('loss_weights', tensor([1., self.backward_ar_loss_weight]))
+
     def forward(
         self,
-        seq
+        seq,
+        backward = True
     ):
         batch, seq_len, device = *seq.shape, seq.device

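The new `backward` flag implies two calling modes; a hedged sketch, where the optimizer and loop bookkeeping are assumptions not shown in the diff:

# default (backward = True): the wrapper runs the backward pass inside forward()
loss = model(seq)              # gradients are already populated on return
optimizer.step()
optimizer.zero_grad()

# deferred (backward = False): .backward on the returned loss is patched,
# so gradients flow only when the caller invokes it
loss = model(seq, backward = False)
loss.backward()
optimizer.step()
optimizer.zero_grad()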
@@ -107,6 +125,17 @@ class BeliefStateWrapper(Module):

         fb_pairs = fb_pairs[valid_mask]

+        # maybe subsample fb pairs
+
+        if self.needs_subsample_fb_pairs:
+            num_pairs = fb_pairs.shape[0]
+
+            num_subsampled = max(int(num_pairs * self.train_frac_fb_pairs), 1)
+
+            rand_subsampled_indices = torch.randperm(num_pairs, device = device)[:num_subsampled]
+
+            fb_pairs = fb_pairs[rand_subsampled_indices]
+
         # get labels for both

         fi, bi = fb_pairs.unbind(dim = -1)
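A self-contained sketch of the subsampling step, using a toy set of (forward, backward) index pairs; the variable names mirror the diff but the shapes are illustrative:

import torch

train_frac_fb_pairs = 0.25
fb_pairs = torch.cartesian_prod(torch.arange(8), torch.arange(8))  # (64, 2) candidate pairs

num_pairs = fb_pairs.shape[0]
num_subsampled = max(int(num_pairs * train_frac_fb_pairs), 1)      # always keep at least one pair

rand_subsampled_indices = torch.randperm(num_pairs)[:num_subsampled]
fb_pairs = fb_pairs[rand_subsampled_indices]                       # (16, 2) randomly retained pairs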
@@ -129,14 +158,31 @@ class BeliefStateWrapper(Module):

         fb_loss = F.cross_entropy(
             rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
-            rearrange(labels, 'b n fb -> b (fb n)')
+            rearrange(labels, 'b n fb -> b (fb n)'),
+            reduction = 'none' if self.needs_loss_weight else 'mean'
         )

+        # maybe loss weighting
+
+        if self.needs_loss_weight:
+            fb_loss = rearrange(fb_loss, 'b (fb n) -> b fb n')
+            fb_loss = fb_loss * self.fwd_bwd_loss_weights
+            fb_loss = fb_loss.mean()
+
         # backwards

-        fb_loss
+        orig_backward = getattr(fb_loss, 'backward')
+
+        def patched_backward_fn(*args, **kwargs):
+            orig_backward(*args, **kwargs)
+            orig_forward_embeds.backward(forward_embeds.grad)
+            orig_backward_embeds.backward(backward_embeds.grad)
+
+        # can allow the researcher to call .backward from the outside

-
-
+        if backward:
+            patched_backward_fn()
+        else:
+            setattr(fb_loss, 'backward', patched_backward_fn)

         return fb_loss
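The patched `.backward` appears to follow a detach-and-reinject pattern: the pair losses are computed on detached embedding copies, and their accumulated `.grad` is pushed back into the original decoder graphs afterwards. A minimal standalone sketch of that pattern, with purely illustrative names:

import torch

params = torch.randn(4, 8, requires_grad = True)

orig_embeds = params * 2          # stands in for the decoder's embedding output
embeds = orig_embeds.detach()     # cut the graph, as the wrapper does for fwd/bwd embeds
embeds.requires_grad_()

loss = (embeds ** 2).mean()

loss.backward()                   # gradients stop at the detached copy
orig_embeds.backward(embeds.grad) # reinject them into the upstream graph

assert params.grad is not None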