x-transformers 2.4.5.tar.gz → 2.4.7.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.4.5 → x_transformers-2.4.7}/PKG-INFO +1 -1
- {x_transformers-2.4.5 → x_transformers-2.4.7}/pyproject.toml +1 -1
- {x_transformers-2.4.5 → x_transformers-2.4.7}/tests/test_x_transformers.py +9 -2
- {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/up_wrapper.py +30 -7
- {x_transformers-2.4.5 → x_transformers-2.4.7}/.github/FUNDING.yml +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/.gitignore +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/LICENSE +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/README.md +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/data/README.md +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/data/enwik8.gz +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/all-attention.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/attention-on-attention.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/deepnorm.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/fcm.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/ffglu.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/flash-attention.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/gate_values.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/gating.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/macaron-1.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/macaron-2.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/memory-transformer.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/normformer.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/pia.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/resi_dual.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/residual_attn.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/rezero.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/rotary.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/sandwich-2.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/sandwich.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/sandwich_norm.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/scalenorm.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/talking-heads.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/topk-attention.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/images/xval.png +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/train_belief_state.py +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/train_copy.py +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/train_enwik8.py +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/train_length_extrapolate.py +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/train_parity.py +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/__init__.py +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/attend.py +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/continuous.py +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/dpo.py +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/x_transformers.py +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.4.5 → x_transformers-2.4.7}/x_transformers/xval.py +0 -0
tests/test_x_transformers.py

```diff
@@ -1100,7 +1100,10 @@ def add_attn_pool():
 
     assert intermediates.attn_pooled_tokens.shape[1] == 3
 
-def test_up():
+@pytest.mark.parametrize('keep_buffer_on_cpu', (False, True))
+def test_up(
+    keep_buffer_on_cpu
+):
     from x_transformers.up_wrapper import UniversalPretrainWrapper
 
     model = TransformerWrapper(
@@ -1115,7 +1118,11 @@ def test_up():
         ),
     )
 
-    up_wrapper = UniversalPretrainWrapper(model, seq_len = 16)
+    up_wrapper = UniversalPretrainWrapper(
+        model,
+        seq_len = 16,
+        keep_buffer_on_cpu = keep_buffer_on_cpu
+    )
 
     loss = up_wrapper()
     loss.backward()
```
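In plain terms, the updated test exercises the wrapper as shown below (a minimal sketch; the TransformerWrapper and Decoder hyperparameters are illustrative placeholders, not values taken from this diff):

```python
from x_transformers import TransformerWrapper, Decoder
from x_transformers.up_wrapper import UniversalPretrainWrapper

# a small decoder-only model; dims, depth and heads are placeholders
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 16,
    attn_layers = Decoder(dim = 32, depth = 2, heads = 4)
)

# new in 2.4.7: keep_buffer_on_cpu leaves the synthetic-data buffer
# as a plain CPU tensor instead of a registered module buffer
up_wrapper = UniversalPretrainWrapper(
    model,
    seq_len = 16,
    keep_buffer_on_cpu = True
)

loss = up_wrapper()   # builds a synthetic batch and returns the autoregressive loss
loss.backward()
```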
x_transformers/up_wrapper.py

```diff
@@ -6,7 +6,7 @@ from functools import partial
 from random import randrange, uniform
 
 import torch
-from torch import nn, cat, randperm
+from torch import nn, cat, tensor, randperm
 from torch.nn import LSTM, Module
 
 from x_transformers.x_transformers import (
@@ -76,6 +76,13 @@ class SyntheticDataGenerator(Module):
 
         self.apply(self.init_)
 
+    def reset_(self):
+        for m in self.modules():
+            if hasattr(m, 'reset_parameters'):
+                m.reset_parameters()
+
+        self.apply(self.init_)
+
     @torch.no_grad()
     def init_(self, m):
         if isinstance(m, nn.Linear):
```
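The new SyntheticDataGenerator.reset_ leans on the fact that standard PyTorch layers such as nn.Linear and nn.LSTM expose reset_parameters(), which re-draws their weights in place, and then re-applies the generator's own init_. A standalone illustration of that PyTorch behaviour (generic code, not x-transformers code):

```python
import torch
from torch import nn

layer = nn.Linear(8, 8)
before = layer.weight.clone()

# re-draw the layer's weights in place; reset_() does this for every
# submodule that defines reset_parameters
layer.reset_parameters()

assert not torch.equal(before, layer.weight)  # weights were re-initialized
```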
x_transformers/up_wrapper.py (continued)

```diff
@@ -138,7 +145,9 @@ class UniversalPretrainWrapper(Module):
         num_reset = 20,
         batch_size = 32,
         seq_len = 512,
-        seed_length = 8
+        seed_length = 8,
+        reset_turing_machine_every = 0,
+        keep_buffer_on_cpu = False
     ):
         super().__init__()
 
@@ -157,6 +166,8 @@ class UniversalPretrainWrapper(Module):
             max_seq_len = seq_len
         )
 
+        self.reset_turing_machine_every = reset_turing_machine_every
+
         self.seq_len = seq_len
         self.data_generator = data_generator
 
@@ -175,11 +186,16 @@ class UniversalPretrainWrapper(Module):
 
         init_data_buffer = self.random_sequences_fn(buffer_size // 2, buffer_size // 2)
 
-        self.register_buffer('synth_data_buffer', init_data_buffer)
+        if keep_buffer_on_cpu:
+            self.synth_data_buffer = init_data_buffer
+        else:
+            self.register_buffer('synth_data_buffer', init_data_buffer)
+
+        self.register_buffer('step', tensor(0))
 
     @property
     def device(self):
-        return self.synth_data_buffer.device
+        return self.step.device
 
     def get_rand_sequences_from_buffer(self, size = None):
         size = default(size, self.batch_size)
```
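The keep_buffer_on_cpu branch hinges on how PyTorch treats module state: tensors added with register_buffer follow the module through .to() and .cuda(), while plain tensor attributes stay where they are. Registering the small step tensor as a buffer then gives the device property something that always reflects where the module lives. A generic illustration of that behaviour (not x-transformers code):

```python
import torch
from torch import nn, tensor

class Holder(nn.Module):
    def __init__(self, keep_on_cpu: bool):
        super().__init__()
        data = torch.randn(4, 4)

        if keep_on_cpu:
            self.data = data                     # plain attribute: ignored by .to() / .cuda()
        else:
            self.register_buffer('data', data)   # registered buffer: moved with the module

        self.register_buffer('step', tensor(0))  # tiny buffer that tracks the module's device

if torch.cuda.is_available():
    m = Holder(keep_on_cpu = True).cuda()
    print(m.data.device)   # cpu    - the large buffer stayed behind
    print(m.step.device)   # cuda:0 - the step counter followed the module
```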
x_transformers/up_wrapper.py (continued)

```diff
@@ -206,10 +222,17 @@ class UniversalPretrainWrapper(Module):
 
         generated = self.data_generator.generate(
             self.seq_len,
-            condition = conditions,
-            seed = seeds
+            condition = conditions.to(self.device),
+            seed = seeds.to(self.device)
         )
 
+        self.step.add_(1)
+
+        # maybe reset turing machine
+
+        if self.reset_turing_machine_every > 0 and divisible_by(self.step.item(), self.reset_turing_machine_every):
+            self.data_generator.reset_()
+
         # reset
 
         if self.num_reset > 0:
@@ -226,6 +249,6 @@ class UniversalPretrainWrapper(Module):
 
         # sample yet again according to pseudocode
 
-        data = self.get_rand_sequences_from_buffer()
+        data = self.get_rand_sequences_from_buffer().to(self.device)
 
         return self.ar_wrapped(data)
```
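Putting the two new options together, a training loop could look like the sketch below. Only the UniversalPretrainWrapper arguments come from this diff; the model size, optimizer settings and step count are hypothetical. With reset_turing_machine_every = 1000, every thousandth call re-initializes the synthetic data generator via reset_():

```python
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.up_wrapper import UniversalPretrainWrapper

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

up_wrapper = UniversalPretrainWrapper(
    model,
    seq_len = 512,
    reset_turing_machine_every = 1000,  # new in 2.4.7: periodically re-initialize the generator
    keep_buffer_on_cpu = True           # new in 2.4.7: keep the synthetic-data buffer off the GPU
)

optim = torch.optim.Adam(up_wrapper.parameters(), lr = 3e-4)

for _ in range(10_000):                 # illustrative step count
    loss = up_wrapper()
    loss.backward()
    optim.step()
    optim.zero_grad()
```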