x-transformers 2.11.1-py3-none-any.whl → 2.11.4-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release.



x_transformers/free_transformer.py CHANGED
@@ -128,19 +128,19 @@ class FreeTransformer(Module):
         dim,
         dec_head_depth,
         dec_tail_depth,
-        enc_depth,
         max_seq_len,
+        enc_depth = 1,
         dim_latent = None,
         attn_dim_head = 64,
         heads = 8,
         latent_bits = 16,
+        per_token_latents = True, # they use a latent per token in the sequence, instead of one for entire sequence, iiuc
         kl_loss_threshold = NAT,
         binary_mapper_kwargs: dict = dict(),
         enc_kwargs: dict = dict(),
         dec_kwargs: dict = dict(),
         kl_loss_weight = 1.,
         pad_id = -1,
-        encoder: Module | None = None,
         **kwargs
     ):
         super().__init__()
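
For orientation, a hypothetical instantiation of the updated constructor. Only the keyword arguments visible in the hunk above are taken from the diff; `num_tokens` and the concrete sizes are assumptions, since the start of the signature is not shown here.

    from x_transformers.free_transformer import FreeTransformer

    # hypothetical sizes; enc_depth now defaults to 1, and the external
    # `encoder` argument has been removed in favor of an internal encoder
    model = FreeTransformer(
        num_tokens = 256,         # assumed parameter, not shown in this hunk
        dim = 512,
        dec_head_depth = 6,
        dec_tail_depth = 6,
        max_seq_len = 1024,
        enc_depth = 1,
        latent_bits = 16,
        per_token_latents = True  # one latent per sequence position
    )
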
@@ -150,33 +150,30 @@ class FreeTransformer(Module):
 
         self.token_unembed = nn.Linear(dim, num_tokens, bias = False)
 
-        if not exists(encoder):
-            encoder = Encoder(
-                dim = dim,
-                depth = enc_depth,
-                attn_dim_head = attn_dim_head,
-                heads = heads,
-                **kwargs,
-                **enc_kwargs
-            )
+        self.query_token_for_latents = nn.Parameter(torch.randn(dim) * 1e-2)
 
-        self.encoder = encoder
+        self.per_token_latents = per_token_latents
 
-        self.to_latent_bit_logits = nn.Sequential(
-            Reduce('b n d -> b d', 'mean'),
-            nn.Linear(dim, latent_bits, bias = False),
+        self.encoder = Encoder(
+            dim = dim,
+            depth = enc_depth,
+            attn_dim_head = attn_dim_head,
+            heads = heads,
+            only_cross = True,
+            cross_attend = True,
+            **kwargs,
+            **enc_kwargs
         )
 
+        self.to_latent_bit_logits = nn.Linear(dim, latent_bits, bias = False)
+
         self.binary_mapper = BinaryMapper(
             latent_bits,
             kl_loss_threshold,
             **binary_mapper_kwargs
         )
 
-        self.from_latent_to_condition = nn.Sequential(
-            nn.Linear(2 ** latent_bits, dim, bias = False),
-            Rearrange('b d -> b 1 d')
-        )
+        self.from_latent_to_condition = nn.Linear(self.binary_mapper.num_codes, dim, bias = False)
 
         self.decoder_head = Decoder(
             dim = dim,
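
Two details in this hunk are worth noting: the conditioning projection now derives its input width from the binary mapper instead of hard-coding `2 ** latent_bits`, and the `Rearrange('b d -> b 1 d')` disappears because with per-token latents the condition already carries a sequence dimension. A minimal sketch of the width equivalence, assuming `BinaryMapper.num_codes == 2 ** latent_bits` as the replaced hard-coded value implies:

    import torch.nn as nn

    dim, latent_bits = 512, 16
    num_codes = 2 ** latent_bits   # one code per combination of the latent bits

    old_proj = nn.Linear(2 ** latent_bits, dim, bias = False)  # width hard-coded
    new_proj = nn.Linear(num_codes, dim, bias = False)         # width derived from the mapper's count

    assert old_proj.in_features == new_proj.in_features == 65536
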
@@ -208,11 +205,34 @@ class FreeTransformer(Module):
 
     def encode_to_latents(
         self,
-        seq,
+        decoder_head_embeds,
         mask = None,
         return_kl_loss = False
     ):
-        pooled = self.encoder(seq, mask = mask)
+        batch, seq_len, device = *decoder_head_embeds.shape[:2], decoder_head_embeds.device
+
+        query_tokens = repeat(self.query_token_for_latents, 'd -> b 1 d', b = batch)
+
+        encoder_kwargs = dict()
+
+        # handle the interesting per query token latents, as in the paper
+
+        if self.per_token_latents:
+            query_tokens = repeat(query_tokens, 'b 1 d -> b n d', n = seq_len)
+
+            rotary_pos = torch.arange(seq_len, device = device)
+
+            encoder_kwargs.update(
+                pos = rotary_pos,
+                context_pos = rotary_pos
+            )
+
+        pooled = self.encoder(
+            query_tokens,
+            context = decoder_head_embeds,
+            context_mask = mask,
+            **encoder_kwargs
+        )
 
         bit_logits = self.to_latent_bit_logits(pooled)
 
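A standalone sketch of the new latent pooling: a cross-attention-only `Encoder` reads the decoder-head embeddings through learned query tokens, one query per position, so each position gets its own latent bits. It uses only the constructor flags and forward kwargs visible in the diff; the shared rotary positions (`pos` / `context_pos`) passed by the real module are omitted, and the concrete sizes are assumptions.

    import torch
    from einops import repeat
    from x_transformers import Encoder

    dim, latent_bits, batch, seq_len = 512, 16, 2, 128

    encoder = Encoder(
        dim = dim,
        depth = 1,
        heads = 8,
        cross_attend = True,
        only_cross = True         # queries only cross-attend into the context
    )

    to_latent_bit_logits = torch.nn.Linear(dim, latent_bits, bias = False)

    decoder_head_embeds = torch.randn(batch, seq_len, dim)
    query_token = torch.randn(dim) * 1e-2

    # one query per position, so each position gets its own latent
    query_tokens = repeat(query_token, 'd -> b n d', b = batch, n = seq_len)

    pooled = encoder(query_tokens, context = decoder_head_embeds)

    bit_logits = to_latent_bit_logits(pooled)   # (batch, seq_len, latent_bits)
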
x_transformers/gpt_vae.py CHANGED
@@ -44,6 +44,7 @@ class GPTVAE(Module):
         enc_kwargs: dict = dict(),
         dec_kwargs: dict = dict(),
         vae_kl_loss_weight = 1.,
+        vae_kl_div_floor = 0., # what was done in free transformer, which in turn came from Kingma 2016
         latents_dropout_prob = 0.5, # what percentage of the time to dropout the latents completely
         pad_id = -1,
         encoder: Module | None = None,
@@ -99,6 +100,7 @@ class GPTVAE(Module):
 
         # loss weights - vae kl loss
 
+        self.vae_kl_div_floor = vae_kl_div_floor
         self.vae_kl_loss_weight = vae_kl_loss_weight
 
         self.latents_dropout = nn.Dropout(latents_dropout_prob)
@@ -190,12 +192,16 @@ class GPTVAE(Module):
 
         # vae kl loss
 
-        vae_kl_loss = (
+        vae_kl_loss = 0.5 * (
             latents_log_var.exp()
             + latents_mean.square()
             - latents_log_var
             - 1.
-        ).sum(dim = -1).mean()
+        )
+
+        vae_kl_loss = F.relu(vae_kl_loss - self.vae_kl_div_floor)
+
+        vae_kl_loss = vae_kl_loss.sum(dim = -1).mean()
 
         # return losses
 
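The reworked loss is the standard diagonal-Gaussian KL (now with the correct 0.5 factor) plus a per-dimension floor, the "free bits" trick the diff comment attributes to Kingma 2016: dimensions whose KL already sits below `vae_kl_div_floor` stop contributing gradient, which discourages posterior collapse. A minimal self-contained sketch of the same computation, with tensor shapes assumed to be (batch, latent_dim):

    import torch
    import torch.nn.functional as F

    def floored_vae_kl_loss(latents_mean, latents_log_var, kl_div_floor = 0.):
        # per-dimension KL( N(mean, var) || N(0, 1) )
        kl = 0.5 * (latents_log_var.exp() + latents_mean.square() - latents_log_var - 1.)

        # dimensions whose KL is already below the floor contribute nothing,
        # so they are not pushed further toward the prior
        kl = F.relu(kl - kl_div_floor)

        # sum over latent dimensions, average over the batch
        return kl.sum(dim = -1).mean()

    mean = torch.randn(4, 32) * 0.1
    log_var = torch.zeros(4, 32)
    print(floored_vae_kl_loss(mean, log_var, kl_div_floor = 0.05))
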
x_transformers-2.11.4.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.1
+Version: 2.11.4
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-2.11.4.dist-info/RECORD CHANGED
@@ -5,8 +5,8 @@ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTN
 x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
-x_transformers/free_transformer.py,sha256=K9dX9xj0pJgiOi9jlOLCD9Nn9eYNNgTWj9YvQLhexHw,8295
-x_transformers/gpt_vae.py,sha256=myYSgcx66V0M4zeEGKyhY1P2HlPDHcezhaZEoo_uMdo,5715
+x_transformers/free_transformer.py,sha256=t-PyeiqNq-B8BZUGICECyWCeU5XLPSGbFsQ3ICxTtsM,9072
+x_transformers/gpt_vae.py,sha256=4QdznXZcU7pmMXUeEocAOKpcTkREYS-zDHktN5ADtNk,5981
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=hMQqNimGtchNIe13cR5LZule1V7I1qM5LmY8VQfVdnA,11698
@@ -14,7 +14,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,
 x_transformers/x_transformers.py,sha256=ADr83Fz2cehj_F7N1bMwxhAg-r48fGhlaZqw3hxoxMQ,125765
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.11.1.dist-info/METADATA,sha256=Rj4l6-wbfFsC7wbSWfUFQyTNWQE5EXu952aSE3B8uas,96011
-x_transformers-2.11.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.11.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.11.1.dist-info/RECORD,,
+x_transformers-2.11.4.dist-info/METADATA,sha256=18u2dW5d3aJeH3LaVsRRA_vf1gZzbSjS0Fbd5XM4dJc,96011
+x_transformers-2.11.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.11.4.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.11.4.dist-info/RECORD,,