x-transformers 2.11.4__py3-none-any.whl → 2.11.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -66,7 +66,6 @@ class BinaryMapper(Module):
66
66
 
67
67
  self.bits = bits
68
68
  self.num_codes = 2 ** bits
69
- self.kl_loss_threshold = kl_loss_threshold
70
69
 
71
70
  power_two = 2 ** arange(bits)
72
71
  codes = (arange(self.num_codes)[:, None].bitwise_and(power_two) != 0).byte().bool()
@@ -74,13 +73,20 @@ class BinaryMapper(Module):
74
73
  self.register_buffer('power_two', power_two, persistent = False)
75
74
  self.register_buffer('codes', codes, persistent = False)
76
75
 
76
+ # aux loss
77
+
78
+ self.kl_loss_threshold = kl_loss_threshold
79
+ self.register_buffer('zero', tensor(0.), persistent = False)
80
+
77
81
  def forward(
78
82
  self,
79
83
  logits,
80
84
  temperature = 1.,
81
- straight_through = None
85
+ straight_through = None,
86
+ calc_aux_loss = None
82
87
  ):
83
88
  straight_through = default(straight_through, self.training)
89
+ calc_aux_loss = default(calc_aux_loss, self.training)
84
90
 
85
91
  assert logits.shape[-1] == self.bits, f'logits must have a last dimension of {self.bits}'
86
92
 
@@ -95,26 +101,29 @@ class BinaryMapper(Module):
95
101
 
96
102
  one_hot = F.one_hot(indices, self.num_codes).float()
97
103
 
98
- # return hard one hot if not training or overridden
104
+ # maybe calculate aux loss
105
+
106
+ aux_kl_loss = self.zero
99
107
 
100
- if not straight_through:
101
- return one_hot
108
+ if calc_aux_loss:
109
+ # calculate negative entropy
102
110
 
103
- # calculate negative entropy
111
+ kl_div = self.bits * NAT - binary_entropy(logits)
112
+ aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
104
113
 
105
- kl_div = self.bits * NAT - binary_entropy(logits)
106
- aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
114
+ # maybe straight through
107
115
 
108
- # get the soft G for the gradients and do a straight through
116
+ if straight_through:
117
+ # get the soft G for the gradients and do a straight through
109
118
 
110
- soft_G = (
111
- einsum(F.logsigmoid(logits), self.codes.float(), '... bits, codes bits -> ... codes') +
112
- einsum(F.logsigmoid(-logits), (~self.codes).float(), '... bits, codes bits -> ... codes')
113
- ).exp()
119
+ soft_G = (
120
+ einsum(F.logsigmoid(logits), self.codes.float(), '... bits, codes bits -> ... codes') +
121
+ einsum(F.logsigmoid(-logits), (~self.codes).float(), '... bits, codes bits -> ... codes')
122
+ ).exp()
114
123
 
115
- # straight through
124
+ # straight through
116
125
 
117
- one_hot = one_hot + soft_G - soft_G.detach()
126
+ one_hot = one_hot + soft_G - soft_G.detach()
118
127
 
119
128
  return one_hot, aux_kl_loss
120
129
 
@@ -161,6 +170,9 @@ class FreeTransformer(Module):
161
170
  heads = heads,
162
171
  only_cross = True,
163
172
  cross_attend = True,
173
+ use_rmsnorm = True,
174
+ rotary_pos_emb = True,
175
+ pre_norm_has_final_norm = True,
164
176
  **kwargs,
165
177
  **enc_kwargs
166
178
  )
@@ -180,6 +192,8 @@ class FreeTransformer(Module):
180
192
  depth = dec_head_depth,
181
193
  attn_dim_head = attn_dim_head,
182
194
  heads = heads,
195
+ rotary_pos_emb = True,
196
+ use_rmsnorm = True,
183
197
  pre_norm_has_final_norm = False,
184
198
  **kwargs,
185
199
  **dec_kwargs
@@ -190,6 +204,8 @@ class FreeTransformer(Module):
190
204
  depth = dec_tail_depth,
191
205
  attn_dim_head = attn_dim_head,
192
206
  heads = heads,
207
+ rotary_pos_emb = True,
208
+ use_rmsnorm = True,
193
209
  pre_norm_has_final_norm = True,
194
210
  **kwargs,
195
211
  **dec_kwargs
@@ -236,7 +252,7 @@ class FreeTransformer(Module):
236
252
 
237
253
  bit_logits = self.to_latent_bit_logits(pooled)
238
254
 
239
- one_hot_latents, kl_loss = self.binary_mapper(bit_logits, straight_through = True)
255
+ one_hot_latents, kl_loss = self.binary_mapper(bit_logits, calc_aux_loss = return_kl_loss)
240
256
 
241
257
  if not return_kl_loss:
242
258
  return one_hot_latents
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.11.4
3
+ Version: 2.11.7
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -5,7 +5,7 @@ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTN
5
5
  x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
6
6
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
7
7
  x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
8
- x_transformers/free_transformer.py,sha256=t-PyeiqNq-B8BZUGICECyWCeU5XLPSGbFsQ3ICxTtsM,9072
8
+ x_transformers/free_transformer.py,sha256=RVEMvmiSVrTJeLbXVoI4YeR26kP2mlqmwFHHLw7IJ_Q,9582
9
9
  x_transformers/gpt_vae.py,sha256=4QdznXZcU7pmMXUeEocAOKpcTkREYS-zDHktN5ADtNk,5981
10
10
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
11
11
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
@@ -14,7 +14,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,
14
14
  x_transformers/x_transformers.py,sha256=ADr83Fz2cehj_F7N1bMwxhAg-r48fGhlaZqw3hxoxMQ,125765
15
15
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
16
16
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
17
- x_transformers-2.11.4.dist-info/METADATA,sha256=18u2dW5d3aJeH3LaVsRRA_vf1gZzbSjS0Fbd5XM4dJc,96011
18
- x_transformers-2.11.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
19
- x_transformers-2.11.4.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
20
- x_transformers-2.11.4.dist-info/RECORD,,
17
+ x_transformers-2.11.7.dist-info/METADATA,sha256=1rStN6vwvwU-0hm17zWx8yXL99ZrZ-_KGMsFwvRIPVg,96011
18
+ x_transformers-2.11.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
19
+ x_transformers-2.11.7.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
20
+ x_transformers-2.11.7.dist-info/RECORD,,