x_transformers-2.11.16-py3-none-any.whl → x_transformers-2.11.18-py3-none-any.whl

This diff compares the contents of two publicly released versions of the x-transformers package as published to PyPI. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.

x_transformers/free_transformer.py

@@ -149,6 +149,7 @@ class FreeTransformer(Module):
         enc_kwargs: dict = dict(),
         dec_kwargs: dict = dict(),
         kl_loss_weight = 1.,
+        latent_dropout_prob = 0.,
         pad_id = -1,
         **kwargs
     ):

@@ -187,6 +188,8 @@ class FreeTransformer(Module):
 
         self.from_latent_to_condition = nn.Linear(self.binary_mapper.num_codes, dim, bias = False)
 
+        self.latent_dropout = nn.Dropout(latent_dropout_prob)
+
         self.decoder_head = Decoder(
             dim = dim,
             depth = dec_head_depth,

@@ -380,6 +383,8 @@ class FreeTransformer(Module):
 
         latents, kl_loss = self.encode_to_latents(tokens_for_latents, mask = encoder_mask, per_token_latents = per_token_latents, return_kl_loss = True)
 
+        latents = self.latent_dropout(latents)
+
         condition = self.from_latent_to_condition(latents)
 
         # decoder tail

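Taken together, the three free_transformer.py hunks add one knob: the latents produced by the encoder are passed through `nn.Dropout(latent_dropout_prob)` before being projected into the decoder's conditioning vector. Below is a minimal, self-contained sketch of that pattern; the sizes (`num_codes = 32`, `dim = 512`) and the dropout probability are illustrative assumptions, not the library's defaults.

```python
import torch
from torch import nn

# illustrative sizes only, not the library's defaults
num_codes = 32   # width of the latent code
dim = 512        # model dimension

latent_dropout = nn.Dropout(p = 0.1)  # corresponds to latent_dropout_prob = 0.1
from_latent_to_condition = nn.Linear(num_codes, dim, bias = False)

latents = torch.randn(2, num_codes)            # stand-in for the encode_to_latents(...) output
latents = latent_dropout(latents)              # randomly zero latent units during training
condition = from_latent_to_condition(latents)  # conditioning vector fed to the decoder head

print(condition.shape)  # torch.Size([2, 512])
```
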
x_transformers/x_transformers.py

@@ -275,6 +275,10 @@ class ReluSquared(Module):
     def forward(self, x):
         return F.relu(x) ** 2
 
+class SoLU(Module):
+    def forward(self, x):
+        return x.softmax(dim = -1) * x
+
 # embedding
 
 class TokenEmbedding(Module):

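The new activation is the Softmax Linear Unit cited in the README addition at the end of this diff: the input multiplied elementwise by its own softmax over the feature dimension. A standalone restatement (using `nn.Module` directly in place of the library's `Module` alias) with a quick numerical check:

```python
import torch
from torch import nn
import torch.nn.functional as F

class SoLU(nn.Module):
    """Softmax Linear Unit (Elhage et al., 2022): x * softmax(x) over the last dimension."""
    def forward(self, x):
        return x.softmax(dim = -1) * x

x = torch.randn(2, 8)
out = SoLU()(x)

# the same computation written out explicitly
assert torch.allclose(out, F.softmax(x, dim = -1) * x)
print(out.shape)  # torch.Size([2, 8])
```
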
@@ -1239,6 +1243,7 @@ class FeedForward(Module):
         glu_mult_bias = False,
         swish = False,
         relu_squared = False,
+        solu = False,
         custom_activation = None,
         post_act_ln = False,
         dropout = 0.,

@@ -1250,10 +1255,14 @@ class FeedForward(Module):
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
 
+        assert at_most_one_of(relu_squared, solu)
+
         if exists(custom_activation):
             activation = deepcopy(custom_activation)
         elif relu_squared:
             activation = ReluSquared()
+        elif solu:
+            activation = SoLU()
         elif swish:
             activation = nn.SiLU()
         else:

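FeedForward gains a `solu = False` flag that is mutually exclusive with `relu_squared` and selects `SoLU()` in place of the default GELU. Assuming the usual `ff_`-prefixed keyword routing into the feedforward blocks (as with the existing `ff_relu_squared` and `ff_glu` flags), enabling it from a full model would presumably look like the sketch below; this is inferred from the diff, not verified against the released wheel.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        ff_solu = True  # assumed ff_-prefixed routing; uses SoLU() instead of GELU in each feedforward
    )
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x)  # (1, 256, 20000)
```
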
x_transformers-2.11.16.dist-info/METADATA → x_transformers-2.11.18.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.11.16
+Version: 2.11.18
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers

@@ -2607,4 +2607,14 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@article{elhage2022solu,
+    title   = {Softmax Linear Units},
+    author  = {Elhage, Nelson and Hume, Tristan and Olsson, Catherine and Nanda, Neel and Henighan, Tom and Johnston, Scott and ElShowk, Sheer and Joseph, Nicholas and DasSarma, Nova and Mann, Ben and Hernandez, Danny and Askell, Amanda and Ndousse, Kamal and Jones, Andy and Drain, Dawn and Chen, Anna and Bai, Yuntao and Ganguli, Deep and Lovitt, Liane and Hatfield-Dodds, Zac and Kernion, Jackson and Conerly, Tom and Kravec, Shauna and Fort, Stanislav and Kadavath, Saurav and Jacobson, Josh and Tran-Johnson, Eli and Kaplan, Jared and Clark, Jack and Brown, Tom and McCandlish, Sam and Amodei, Dario and Olah, Christopher},
+    year    = {2022},
+    journal = {Transformer Circuits Thread},
+    note    = {https://transformer-circuits.pub/2022/solu/index.html}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis

x_transformers-2.11.16.dist-info/RECORD → x_transformers-2.11.18.dist-info/RECORD

@@ -5,16 +5,16 @@ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTN
 x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
-x_transformers/free_transformer.py,sha256=_hYYkaro3xei3MC3rwtuCWi9gSnciXyAT91_7SrA0nw,11396
+x_transformers/free_transformer.py,sha256=F0H_rfb_8_nO4oRbaVDLdfOa8EP4YcUNCOaI2rhkLV0,11541
 x_transformers/gpt_vae.py,sha256=1zyjwgfZr6CRDsh5VMCPSdoCPg-sdX5mXmZ_mn4VyYQ,6082
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=hMQqNimGtchNIe13cR5LZule1V7I1qM5LmY8VQfVdnA,11698
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=5ctPu8tvlbUMrtW360e_LPnoGv6xcgQFsyWdbvLo6Tk,127002
+x_transformers/x_transformers.py,sha256=pIUxQmj_wLHIMOxyqjy4hKww6NdYtzxtRMWROovHoDA,127212
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.11.16.dist-info/METADATA,sha256=cvhm5LnIRCdqLuv25iSU4vj0a6Np9j2lv2O9W-V48-k,96012
-x_transformers-2.11.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.11.16.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.11.16.dist-info/RECORD,,
+x_transformers-2.11.18.dist-info/METADATA,sha256=9VPaNWK5WVVltDqRqkb_4OtPEmJFzfkYkln3aJpKdfQ,96858
+x_transformers-2.11.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.11.18.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.11.18.dist-info/RECORD,,