x-transformers 2.8.0__py3-none-any.whl → 2.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/gpt_vae.py CHANGED
@@ -68,10 +68,13 @@ class GPTVAE(Module):
 
         self.to_latent_mean_log_variance = nn.Sequential(
             nn.Linear(dim, dim_latent * 2),
-            Rearrange('b (two d) -> two b 1 d', two = 2)
+            Rearrange('b (two d) -> two b d', two = 2)
         )
 
-        self.from_latent_to_prepend_token = nn.Linear(dim_latent, dim)
+        self.from_latent_to_prepend_token = nn.Sequential(
+            nn.Linear(dim_latent, dim),
+            Rearrange('b d -> b 1 d')
+        )
 
         self.decoder = TransformerWrapper(
             num_tokens = num_tokens,
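This hunk stops the latent head from emitting a spurious sequence axis (`'b (two d) -> two b d'` instead of `-> two b 1 d`) and moves the `'b d -> b 1 d'` unsqueeze into `from_latent_to_prepend_token` itself. A minimal shape sketch of the two new modules, assuming einops is installed and using illustrative sizes in place of the real `dim` / `dim_latent` constructor arguments:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, dim_latent, batch = 32, 16, 4  # illustrative sizes only

# 2.8.1 head: projection reshaped to (2, batch, dim_latent), so the
# mean and log variance unbind cleanly along the leading axis
to_stats = nn.Sequential(
    nn.Linear(dim, dim_latent * 2),
    Rearrange('b (two d) -> two b d', two = 2),
)
mean, log_var = to_stats(torch.randn(batch, dim))
assert mean.shape == log_var.shape == (batch, dim_latent)

# 2.8.1 prepend projection: the length-1 sequence axis is now added here
to_prepend = nn.Sequential(
    nn.Linear(dim_latent, dim),
    Rearrange('b d -> b 1 d'),
)
assert to_prepend(mean).shape == (batch, 1, dim)
```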
@@ -126,11 +129,19 @@ class GPTVAE(Module):
         prompts,
         seq_len,
         latents = None,
+        seq_for_latents = None,
         **generate_kwargs
     ):
         assert prompts.ndim in {1, 2}
         batch = prompts.shape[0] if prompts.ndim == 2 else 1
 
+        # if seq_for_latents passed in, derive latents from it
+
+        if exists(seq_for_latents):
+            assert not exists(latents), 'latents should not be passed in if given the seq from which to derive them'
+
+            latents = self.encode_to_latents(seq_for_latents)
+
         # prepend embeds
 
         prepend_embeds = None
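This hunk adds a `seq_for_latents` keyword to `generate`: rather than passing precomputed `latents`, a reference sequence can be handed in and the latents are derived internally via `encode_to_latents`. A hypothetical call, where `model`, `prompts`, and `reference_seq` are placeholder names for a constructed GPTVAE and token-id tensors:

```python
# hypothetical usage sketch; names and sizes are assumptions
generated = model.generate(
    prompts,                          # (batch, prompt_len) token ids
    seq_len = 256,                    # number of tokens to decode
    seq_for_latents = reference_seq,  # latents = model.encode_to_latents(reference_seq)
)
```

Supplying both `latents` and `seq_for_latents` trips the new assertion, so callers pick exactly one source of conditioning.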
@@ -143,9 +154,6 @@ class GPTVAE(Module):
 
             prepend_embeds = self.from_latent_to_prepend_token(latents)
 
-        if exists(prepend_embeds):
-            prepend_embeds = rearrange(prepend_embeds, 'b d -> b 1 d')
-
         # generated
 
         generated = self.ar_wrapped_decoder.generate(
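This hunk drops the manual `rearrange` from `generate`; it became redundant once the unsqueeze moved inside `from_latent_to_prepend_token` in the first hunk. A sketch of the equivalence, again with illustrative sizes:

```python
import torch
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange

dim_latent, dim, batch = 16, 32, 4  # illustrative sizes only
latents = torch.randn(batch, dim_latent)
proj = nn.Linear(dim_latent, dim)

# 2.8.0 path: project, then unsqueeze by hand inside generate()
old = rearrange(proj(latents), 'b d -> b 1 d')

# 2.8.1 path: the same unsqueeze lives inside the module
new = nn.Sequential(proj, Rearrange('b d -> b 1 d'))(latents)

assert torch.equal(old, new) and old.shape == (batch, 1, dim)
```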
x_transformers-2.8.0.dist-info/METADATA → x_transformers-2.8.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.8.0
+Version: 2.8.1
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-2.8.0.dist-info/RECORD → x_transformers-2.8.1.dist-info/RECORD CHANGED
@@ -5,7 +5,7 @@ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTN
 x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
-x_transformers/gpt_vae.py,sha256=yqL1K2yJ6RSP_MC6XSHI3hjiUnaptddg6CUnbEX4Bsk,5281
+x_transformers/gpt_vae.py,sha256=Q2pzQ6iXRnP2Bfa6g-fs4US-JTouXB5-MfKw3sTwWmU,5561
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=hMQqNimGtchNIe13cR5LZule1V7I1qM5LmY8VQfVdnA,11698
@@ -13,7 +13,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,
 x_transformers/x_transformers.py,sha256=odnCZAKZKrQLXmpaWhiPVB5elGjt8kerDbO3-yeC-60,124764
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.8.0.dist-info/METADATA,sha256=jPo0ZPhD1d_aocaDqFYWXA7EXPAcxWeUYNDzKpY1yi8,94136
-x_transformers-2.8.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.8.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.8.0.dist-info/RECORD,,
+x_transformers-2.8.1.dist-info/METADATA,sha256=_PnvoOSFJAgrpEfpNNljxdeYQ3BhDYJvVOp7yjaF-iM,94136
+x_transformers-2.8.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.8.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.8.1.dist-info/RECORD,,