x-transformers 2.11.2__py3-none-any.whl → 2.11.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/free_transformer.py +47 -21
- {x_transformers-2.11.2.dist-info → x_transformers-2.11.5.dist-info}/METADATA +1 -1
- {x_transformers-2.11.2.dist-info → x_transformers-2.11.5.dist-info}/RECORD +5 -5
- {x_transformers-2.11.2.dist-info → x_transformers-2.11.5.dist-info}/WHEEL +0 -0
- {x_transformers-2.11.2.dist-info → x_transformers-2.11.5.dist-info}/licenses/LICENSE +0 -0
x_transformers/free_transformer.py

@@ -128,19 +128,19 @@ class FreeTransformer(Module):
         dim,
         dec_head_depth,
         dec_tail_depth,
-        enc_depth,
         max_seq_len,
+        enc_depth = 1,
         dim_latent = None,
         attn_dim_head = 64,
         heads = 8,
         latent_bits = 16,
+        per_token_latents = True, # they use a latent per token in the sequence, instead of one for entire sequence, iiuc
         kl_loss_threshold = NAT,
         binary_mapper_kwargs: dict = dict(),
         enc_kwargs: dict = dict(),
         dec_kwargs: dict = dict(),
         kl_loss_weight = 1.,
         pad_id = -1,
-        encoder: Module | None = None,
         **kwargs
     ):
         super().__init__()
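The hunk above reworks the constructor signature: enc_depth moves after max_seq_len and gains a default of 1, a per_token_latents flag is added, and the keyword for supplying an external encoder module is removed. A minimal, hypothetical instantiation sketch follows; the keyword names come from the diff, while num_tokens and all concrete values are illustrative assumptions.

# hypothetical usage sketch; only the keyword names shown in the hunk above are grounded
from x_transformers.free_transformer import FreeTransformer

model = FreeTransformer(
    num_tokens = 256,             # assumed leading argument, not visible in this hunk
    dim = 512,
    dec_head_depth = 4,
    dec_tail_depth = 4,
    max_seq_len = 1024,
    enc_depth = 1,                # now optional, defaulting to a single encoder layer
    per_token_latents = True,     # new relative to 2.11.2: one latent per sequence position
    latent_bits = 16
)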
@@ -150,39 +150,40 @@ class FreeTransformer(Module):
 
         self.token_unembed = nn.Linear(dim, num_tokens, bias = False)
 
-
-        encoder = Encoder(
-            dim = dim,
-            depth = enc_depth,
-            attn_dim_head = attn_dim_head,
-            heads = heads,
-            **kwargs,
-            **enc_kwargs
-        )
+        self.query_token_for_latents = nn.Parameter(torch.randn(dim) * 1e-2)
 
-        self.
+        self.per_token_latents = per_token_latents
 
-        self.
-
-
+        self.encoder = Encoder(
+            dim = dim,
+            depth = enc_depth,
+            attn_dim_head = attn_dim_head,
+            heads = heads,
+            only_cross = True,
+            cross_attend = True,
+            use_rmsnorm = True,
+            rotary_pos_emb = True,
+            **kwargs,
+            **enc_kwargs
         )
 
+        self.to_latent_bit_logits = nn.Linear(dim, latent_bits, bias = False)
+
         self.binary_mapper = BinaryMapper(
             latent_bits,
             kl_loss_threshold,
             **binary_mapper_kwargs
         )
 
-        self.from_latent_to_condition = nn.
-            nn.Linear(2 ** latent_bits, dim, bias = False),
-            Rearrange('b d -> b 1 d')
-        )
+        self.from_latent_to_condition = nn.Linear(self.binary_mapper.num_codes, dim, bias = False)
 
         self.decoder_head = Decoder(
             dim = dim,
             depth = dec_head_depth,
             attn_dim_head = attn_dim_head,
             heads = heads,
+            rotary_pos_emb = True,
+            use_rmsnorm = True,
             pre_norm_has_final_norm = False,
             **kwargs,
             **dec_kwargs
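In this hunk the encoder is no longer injectable and is instead built internally as a cross-attention-only stack (only_cross = True, cross_attend = True) with RMSNorm and rotary embeddings, the bit logits get a dedicated projection (to_latent_bit_logits), and the conditioning projection's input width is taken from the binary mapper rather than hard-coded. A minimal sketch of that width relation, assuming BinaryMapper.num_codes equals 2 ** latent_bits as the removed hard-coded line implies:

import torch.nn as nn

dim = 512
latent_bits = 4                # the default in the diff is 16, i.e. 2 ** 16 = 65,536 codes
num_codes = 2 ** latent_bits   # assumed to equal BinaryMapper.num_codes

# same input width as the removed hard-coded nn.Linear(2 ** latent_bits, dim, bias = False)
from_latent_to_condition = nn.Linear(num_codes, dim, bias = False)
print(from_latent_to_condition.weight.shape)   # torch.Size([512, 16])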
@@ -193,6 +194,8 @@ class FreeTransformer(Module):
             depth = dec_tail_depth,
             attn_dim_head = attn_dim_head,
             heads = heads,
+            rotary_pos_emb = True,
+            use_rmsnorm = True,
             pre_norm_has_final_norm = True,
             **kwargs,
             **dec_kwargs
@@ -208,11 +211,34 @@ class FreeTransformer(Module):
 
     def encode_to_latents(
         self,
-
+        decoder_head_embeds,
         mask = None,
         return_kl_loss = False
     ):
-
+        batch, seq_len, device = *decoder_head_embeds.shape[:2], decoder_head_embeds.device
+
+        query_tokens = repeat(self.query_token_for_latents, 'd -> b 1 d', b = batch)
+
+        encoder_kwargs = dict()
+
+        # handle the interesting per query token latents, as in the paper
+
+        if self.per_token_latents:
+            query_tokens = repeat(query_tokens, 'b 1 d -> b n d', n = seq_len)
+
+            rotary_pos = torch.arange(seq_len, device = device)
+
+            encoder_kwargs.update(
+                pos = rotary_pos,
+                context_pos = rotary_pos
+            )
+
+        pooled = self.encoder(
+            query_tokens,
+            context = decoder_head_embeds,
+            context_mask = mask,
+            **encoder_kwargs
+        )
 
         bit_logits = self.to_latent_bit_logits(pooled)
 
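The rewritten encode_to_latents pools latents by cross-attending learned query tokens into the decoder-head embeddings, and the per_token_latents path expands the single learned query to one per position while sharing rotary positions between queries and context. Below is a standalone illustration of that pooling pattern using plain PyTorch attention rather than the library's Encoder; the shapes, hyperparameters, and attention module are assumptions, not the release's code.

import torch
from torch import nn
from einops import repeat

dim, latent_bits, heads = 512, 16, 8
batch, seq_len = 2, 8

query_token = nn.Parameter(torch.randn(dim) * 1e-2)                 # mirrors query_token_for_latents in spirit
cross_attn = nn.MultiheadAttention(dim, heads, batch_first = True)  # stand-in for the cross-only Encoder
to_latent_bit_logits = nn.Linear(dim, latent_bits, bias = False)

decoder_head_embeds = torch.randn(batch, seq_len, dim)

# per-token latents: one learned query per sequence position
queries = repeat(query_token, 'd -> b n d', b = batch, n = seq_len)

pooled, _ = cross_attn(queries, decoder_head_embeds, decoder_head_embeds)
bit_logits = to_latent_bit_logits(pooled)

print(bit_logits.shape)   # torch.Size([2, 8, 16])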
{x_transformers-2.11.2.dist-info → x_transformers-2.11.5.dist-info}/RECORD

@@ -5,7 +5,7 @@ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTN
 x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
-x_transformers/free_transformer.py,sha256=
+x_transformers/free_transformer.py,sha256=oUrTl3cNbal9-c2wU7Q3B8vruWg_FIbFNBhsgkBryc4,9273
 x_transformers/gpt_vae.py,sha256=4QdznXZcU7pmMXUeEocAOKpcTkREYS-zDHktN5ADtNk,5981
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452

@@ -14,7 +14,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,
 x_transformers/x_transformers.py,sha256=ADr83Fz2cehj_F7N1bMwxhAg-r48fGhlaZqw3hxoxMQ,125765
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.11.
-x_transformers-2.11.
-x_transformers-2.11.
-x_transformers-2.11.
+x_transformers-2.11.5.dist-info/METADATA,sha256=6lQ4z5dCmDbPAc1Vhl1ObzSqPUneWTqwYKJKaHhbiNY,96011
+x_transformers-2.11.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.11.5.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.11.5.dist-info/RECORD,,
{x_transformers-2.11.2.dist-info → x_transformers-2.11.5.dist-info}/WHEEL: file without changes
{x_transformers-2.11.2.dist-info → x_transformers-2.11.5.dist-info}/licenses/LICENSE: file without changes