x-transformers 2.8.0__tar.gz → 2.8.2__tar.gz
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- {x_transformers-2.8.0 → x_transformers-2.8.2}/PKG-INFO +1 -1
- {x_transformers-2.8.0 → x_transformers-2.8.2}/pyproject.toml +1 -1
- {x_transformers-2.8.0 → x_transformers-2.8.2}/tests/test_x_transformers.py +20 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/x_transformers/gpt_vae.py +31 -19
- {x_transformers-2.8.0 → x_transformers-2.8.2}/.github/FUNDING.yml +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/.gitignore +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/LICENSE +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/README.md +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/data/README.md +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/data/enwik8.gz +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/all-attention.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/attention-on-attention.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/deepnorm.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/fcm.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/ffglu.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/flash-attention.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/gate_values.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/gating.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/macaron-1.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/macaron-2.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/memory-transformer.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/normformer.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/pia.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/resi_dual.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/residual_attn.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/rezero.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/rotary.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/sandwich-2.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/sandwich.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/sandwich_norm.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/scalenorm.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/talking-heads.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/topk-attention.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/images/xval.png +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/train_belief_state.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/train_copy.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/train_enwik8.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/train_gpt_vae.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/train_length_extrapolate.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/train_parity.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/x_transformers/__init__.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/x_transformers/attend.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/x_transformers/continuous.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/x_transformers/dpo.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/x_transformers/x_transformers.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.8.0 → x_transformers-2.8.2}/x_transformers/xval.py +0 -0
@@ -1340,3 +1340,23 @@ def test_qk_clip_attn_layers():
     out, intermediates = model(seq, return_intermediates = True)
 
     model.attn_qk_clip_(intermediates)
+
+def test_vae():
+    from x_transformers.gpt_vae import GPTVAE
+
+    model = GPTVAE(
+        num_tokens = 256,
+        max_seq_len = 1024,
+        dim = 512,
+        depth = 4,
+        enc_depth = 2
+    )
+
+    seq = torch.randint(0, 256, (1, 1024))
+
+    loss = model(seq)
+    loss.backward()
+
+    style = torch.randint(0, 256, (1, 1024))
+
+    out = model.generate(seq[:, :512], 512, seq_for_latents = style)
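The new test doubles as a usage reference for the `seq_for_latents` keyword added to `GPTVAE.generate`. Below is a self-contained sketch of the same flow plus a two-step variant; the hyperparameters are copied from the test, `import torch` is the only addition, and calling `encode_to_latents` directly (and passing its result as `latents`) is an assumption based on that method being invoked inside `generate` in the diff below.

```python
import torch
from x_transformers.gpt_vae import GPTVAE

# hyperparameters copied from the added test
model = GPTVAE(
    num_tokens = 256,
    max_seq_len = 1024,
    dim = 512,
    depth = 4,
    enc_depth = 2    # depth of the built-in encoder (see Encoder(depth = enc_depth, ...) in the constructor diff)
)

seq = torch.randint(0, 256, (1, 1024))

# one training step - the forward pass returns a scalar loss
loss = model(seq)
loss.backward()

# generation conditioned on a "style" sequence: latents are derived from
# seq_for_latents by the encoder and prepended to the decoder input
style = torch.randint(0, 256, (1, 1024))
out = model.generate(seq[:, :512], 512, seq_for_latents = style)

# assumed two-step variant: encode once, then pass latents directly.
# generate asserts that latents and seq_for_latents are not both given.
latents = model.encode_to_latents(style)
out = model.generate(seq[:, :512], 512, latents = latents)
```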
@@ -46,32 +46,39 @@ class GPTVAE(Module):
         vae_kl_loss_weight = 1.,
         latents_dropout_prob = 0.5, # what percentage of the time to dropout the latents completely
         pad_id = -1,
+        encoder: Module | None = None,
         **kwargs
     ):
         super().__init__()
         dim_latent = default(dim_latent, dim)
 
-        [14 removed lines (old 54-67); content not captured in this view]
+        if not exists(encoder):
+            encoder = TransformerWrapper(
+                num_tokens = num_tokens,
+                max_seq_len = max_seq_len + 1,
+                return_only_embed = True,
+                average_pool_embed = True,
+                attn_layers = Encoder(
+                    dim = dim,
+                    depth = enc_depth,
+                    attn_dim_head = attn_dim_head,
+                    heads = heads,
+                    **kwargs,
+                    **enc_kwargs
+                ),
+            )
+
+        self.encoder = encoder
 
         self.to_latent_mean_log_variance = nn.Sequential(
             nn.Linear(dim, dim_latent * 2),
-            Rearrange('b (two d) -> two b
+            Rearrange('b (two d) -> two b d', two = 2)
         )
 
-        self.from_latent_to_prepend_token = nn.
+        self.from_latent_to_prepend_token = nn.Sequential(
+            nn.Linear(dim_latent, dim),
+            Rearrange('b d -> b 1 d')
+        )
 
         self.decoder = TransformerWrapper(
             num_tokens = num_tokens,
@@ -126,11 +133,19 @@ class GPTVAE(Module):
         prompts,
         seq_len,
         latents = None,
+        seq_for_latents = None,
         **generate_kwargs
     ):
         assert prompts.ndim in {1, 2}
         batch = prompts.shape[0] if prompts.ndim == 2 else 1
 
+        # if seq_for_latents passed in, derive latents from it
+
+        if exists(seq_for_latents):
+            assert not exists(latents), 'latents should not be passed in if given the seq from which to derive them'
+
+            latents = self.encode_to_latents(seq_for_latents)
+
         # prepend embeds
 
         prepend_embeds = None
@@ -143,9 +158,6 @@ class GPTVAE(Module):
 
             prepend_embeds = self.from_latent_to_prepend_token(latents)
 
-        if exists(prepend_embeds):
-            prepend_embeds = rearrange(prepend_embeds, 'b d -> b 1 d')
-
         # generated
 
         generated = self.ar_wrapped_decoder.generate(
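The other user-facing change in gpt_vae.py is the new `encoder` keyword on the constructor: when a module is passed, the default encoder construction is skipped and the supplied module is stored as `self.encoder`. Below is a sketch of how a custom encoder might be plugged in. The `return_only_embed`/`average_pool_embed` flags and the `max_seq_len + 1` budget are copied from the default construction shown in the diff; the requirement that the encoder pool to a `(batch, dim)` embedding is inferred from the `nn.Linear(dim, dim_latent * 2)` that consumes its output, and is an assumption rather than something this diff states.

```python
import torch
from x_transformers import TransformerWrapper, Encoder
from x_transformers.gpt_vae import GPTVAE

num_tokens, max_seq_len, dim = 256, 1024, 512

# a deeper encoder than the built-in default; like the default, it returns a
# pooled (batch, dim) embedding for the latent projection to consume
custom_encoder = TransformerWrapper(
    num_tokens = num_tokens,
    max_seq_len = max_seq_len + 1,   # same +1 length budget as the default encoder
    return_only_embed = True,
    average_pool_embed = True,
    attn_layers = Encoder(
        dim = dim,
        depth = 6,
        heads = 8
    )
)

model = GPTVAE(
    num_tokens = num_tokens,
    max_seq_len = max_seq_len,
    dim = dim,
    depth = 4,
    enc_depth = 2,               # presumably unused once a custom encoder is supplied
    encoder = custom_encoder
)

seq = torch.randint(0, num_tokens, (1, max_seq_len))
loss = model(seq)
loss.backward()
```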