x-transformers 2.8.0.tar.gz → 2.8.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.8.0 → x_transformers-2.8.1}/PKG-INFO +1 -1
- {x_transformers-2.8.0 → x_transformers-2.8.1}/pyproject.toml +1 -1
- {x_transformers-2.8.0 → x_transformers-2.8.1}/tests/test_x_transformers.py +20 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/gpt_vae.py +13 -5
- {x_transformers-2.8.0 → x_transformers-2.8.1}/.github/FUNDING.yml +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/.gitignore +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/LICENSE +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/README.md +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/data/README.md +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/data/enwik8.gz +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/all-attention.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/attention-on-attention.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/deepnorm.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/fcm.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/ffglu.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/flash-attention.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/gate_values.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/gating.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/macaron-1.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/macaron-2.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/memory-transformer.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/normformer.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/pia.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/resi_dual.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/residual_attn.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/rezero.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/rotary.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/sandwich-2.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/sandwich.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/sandwich_norm.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/scalenorm.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/talking-heads.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/topk-attention.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/images/xval.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/train_belief_state.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/train_copy.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/train_enwik8.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/train_gpt_vae.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/train_length_extrapolate.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/train_parity.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/__init__.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/attend.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/continuous.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/dpo.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/x_transformers.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.1}/x_transformers/xval.py +0 -0
tests/test_x_transformers.py

@@ -1340,3 +1340,23 @@ def test_qk_clip_attn_layers():
     out, intermediates = model(seq, return_intermediates = True)
 
     model.attn_qk_clip_(intermediates)
+
+def test_vae():
+    from x_transformers.gpt_vae import GPTVAE
+
+    model = GPTVAE(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        dim = 512,
+        depth = 4,
+        enc_depth = 2
+    )
+
+    seq = torch.randint(0, 256, (1, 1024))
+
+    loss = model(seq)
+    loss.backward()
+
+    style = torch.randint(0, 256, (1, 1024))
+
+    out = model.generate(seq[:, :512], 512, seq_for_latents = style)
x_transformers/gpt_vae.py

@@ -68,10 +68,13 @@ class GPTVAE(Module):
 
         self.to_latent_mean_log_variance = nn.Sequential(
            nn.Linear(dim, dim_latent * 2),
-            Rearrange('b (two d) -> two b
+            Rearrange('b (two d) -> two b d', two = 2)
         )
 
-        self.from_latent_to_prepend_token = nn.
+        self.from_latent_to_prepend_token = nn.Sequential(
+            nn.Linear(dim_latent, dim),
+            Rearrange('b d -> b 1 d')
+        )
 
         self.decoder = TransformerWrapper(
             num_tokens = num_tokens,
@@ -126,11 +129,19 @@ class GPTVAE(Module):
         prompts,
         seq_len,
         latents = None,
+        seq_for_latents = None,
         **generate_kwargs
     ):
         assert prompts.ndim in {1, 2}
         batch = prompts.shape[0] if prompts.ndim == 2 else 1
 
+        # if seq_for_latents passed in, derive latents from it
+
+        if exists(seq_for_latents):
+            assert not exists(latents), 'latents should not be passed in if given the seq from which to derive them'
+
+            latents = self.encode_to_latents(seq_for_latents)
+
         # prepend embeds
 
         prepend_embeds = None
@@ -143,9 +154,6 @@ class GPTVAE(Module):
 
             prepend_embeds = self.from_latent_to_prepend_token(latents)
 
-        if exists(prepend_embeds):
-            prepend_embeds = rearrange(prepend_embeds, 'b d -> b 1 d')
-
         # generated
 
         generated = self.ar_wrapped_decoder.generate(
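Taken together, the gpt_vae.py hunks fold the `b d -> b 1 d` reshape into the `from_latent_to_prepend_token` module itself (via `einops.layers.torch.Rearrange`), which is why the manual `rearrange` call inside `generate` can be dropped. The following is a minimal standalone sketch of that equivalence, not code from the package; the names `to_prepend_old` / `to_prepend_new` and the dimensions are illustrative only, and the "old" module is inferred from the removed rearrange shown above.

import torch
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange

dim_latent, dim = 64, 512

# pre-2.8.1 style (inferred): a bare projection, with the (b, d) -> (b, 1, d)
# reshape done by hand at the call site in generate()
to_prepend_old = nn.Linear(dim_latent, dim)

# 2.8.1 style: the reshape lives inside the module, so callers get (b, 1, d) directly
to_prepend_new = nn.Sequential(
    nn.Linear(dim_latent, dim),
    Rearrange('b d -> b 1 d')
)

latents = torch.randn(2, dim_latent)

old_out = rearrange(to_prepend_old(latents), 'b d -> b 1 d')  # manual reshape
new_out = to_prepend_new(latents)                             # reshape handled internally

assert old_out.shape == new_out.shape == (2, 1, dim)

Either way the prepended latent token has shape (batch, 1, dim); 2.8.1 simply moves the reshape to one place so generate() no longer needs the extra guard.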