x-transformers 2.11.4__py3-none-any.whl → 2.11.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/free_transformer.py +32 -16
- {x_transformers-2.11.4.dist-info → x_transformers-2.11.7.dist-info}/METADATA +1 -1
- {x_transformers-2.11.4.dist-info → x_transformers-2.11.7.dist-info}/RECORD +5 -5
- {x_transformers-2.11.4.dist-info → x_transformers-2.11.7.dist-info}/WHEEL +0 -0
- {x_transformers-2.11.4.dist-info → x_transformers-2.11.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -66,7 +66,6 @@ class BinaryMapper(Module):
|
|
|
66
66
|
|
|
67
67
|
self.bits = bits
|
|
68
68
|
self.num_codes = 2 ** bits
|
|
69
|
-
self.kl_loss_threshold = kl_loss_threshold
|
|
70
69
|
|
|
71
70
|
power_two = 2 ** arange(bits)
|
|
72
71
|
codes = (arange(self.num_codes)[:, None].bitwise_and(power_two) != 0).byte().bool()
|
|
@@ -74,13 +73,20 @@ class BinaryMapper(Module):
|
|
|
74
73
|
self.register_buffer('power_two', power_two, persistent = False)
|
|
75
74
|
self.register_buffer('codes', codes, persistent = False)
|
|
76
75
|
|
|
76
|
+
# aux loss
|
|
77
|
+
|
|
78
|
+
self.kl_loss_threshold = kl_loss_threshold
|
|
79
|
+
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
80
|
+
|
|
77
81
|
def forward(
|
|
78
82
|
self,
|
|
79
83
|
logits,
|
|
80
84
|
temperature = 1.,
|
|
81
|
-
straight_through = None
|
|
85
|
+
straight_through = None,
|
|
86
|
+
calc_aux_loss = None
|
|
82
87
|
):
|
|
83
88
|
straight_through = default(straight_through, self.training)
|
|
89
|
+
calc_aux_loss = default(calc_aux_loss, self.training)
|
|
84
90
|
|
|
85
91
|
assert logits.shape[-1] == self.bits, f'logits must have a last dimension of {self.bits}'
|
|
86
92
|
|
|
@@ -95,26 +101,29 @@ class BinaryMapper(Module):
|
|
|
95
101
|
|
|
96
102
|
one_hot = F.one_hot(indices, self.num_codes).float()
|
|
97
103
|
|
|
98
|
-
#
|
|
104
|
+
# maybe calculate aux loss
|
|
105
|
+
|
|
106
|
+
aux_kl_loss = self.zero
|
|
99
107
|
|
|
100
|
-
if
|
|
101
|
-
|
|
108
|
+
if calc_aux_loss:
|
|
109
|
+
# calculate negative entropy
|
|
102
110
|
|
|
103
|
-
|
|
111
|
+
kl_div = self.bits * NAT - binary_entropy(logits)
|
|
112
|
+
aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
|
|
104
113
|
|
|
105
|
-
|
|
106
|
-
aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
|
|
114
|
+
# maybe straight through
|
|
107
115
|
|
|
108
|
-
|
|
116
|
+
if straight_through:
|
|
117
|
+
# get the soft G for the gradients and do a straight through
|
|
109
118
|
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
119
|
+
soft_G = (
|
|
120
|
+
einsum(F.logsigmoid(logits), self.codes.float(), '... bits, codes bits -> ... codes') +
|
|
121
|
+
einsum(F.logsigmoid(-logits), (~self.codes).float(), '... bits, codes bits -> ... codes')
|
|
122
|
+
).exp()
|
|
114
123
|
|
|
115
|
-
|
|
124
|
+
# straight through
|
|
116
125
|
|
|
117
|
-
|
|
126
|
+
one_hot = one_hot + soft_G - soft_G.detach()
|
|
118
127
|
|
|
119
128
|
return one_hot, aux_kl_loss
|
|
120
129
|
|
|
@@ -161,6 +170,9 @@ class FreeTransformer(Module):
|
|
|
161
170
|
heads = heads,
|
|
162
171
|
only_cross = True,
|
|
163
172
|
cross_attend = True,
|
|
173
|
+
use_rmsnorm = True,
|
|
174
|
+
rotary_pos_emb = True,
|
|
175
|
+
pre_norm_has_final_norm = True,
|
|
164
176
|
**kwargs,
|
|
165
177
|
**enc_kwargs
|
|
166
178
|
)
|
|
@@ -180,6 +192,8 @@ class FreeTransformer(Module):
|
|
|
180
192
|
depth = dec_head_depth,
|
|
181
193
|
attn_dim_head = attn_dim_head,
|
|
182
194
|
heads = heads,
|
|
195
|
+
rotary_pos_emb = True,
|
|
196
|
+
use_rmsnorm = True,
|
|
183
197
|
pre_norm_has_final_norm = False,
|
|
184
198
|
**kwargs,
|
|
185
199
|
**dec_kwargs
|
|
@@ -190,6 +204,8 @@ class FreeTransformer(Module):
|
|
|
190
204
|
depth = dec_tail_depth,
|
|
191
205
|
attn_dim_head = attn_dim_head,
|
|
192
206
|
heads = heads,
|
|
207
|
+
rotary_pos_emb = True,
|
|
208
|
+
use_rmsnorm = True,
|
|
193
209
|
pre_norm_has_final_norm = True,
|
|
194
210
|
**kwargs,
|
|
195
211
|
**dec_kwargs
|
|
@@ -236,7 +252,7 @@ class FreeTransformer(Module):
|
|
|
236
252
|
|
|
237
253
|
bit_logits = self.to_latent_bit_logits(pooled)
|
|
238
254
|
|
|
239
|
-
one_hot_latents, kl_loss = self.binary_mapper(bit_logits,
|
|
255
|
+
one_hot_latents, kl_loss = self.binary_mapper(bit_logits, calc_aux_loss = return_kl_loss)
|
|
240
256
|
|
|
241
257
|
if not return_kl_loss:
|
|
242
258
|
return one_hot_latents
|
|
@@ -5,7 +5,7 @@ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTN
|
|
|
5
5
|
x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
|
|
6
6
|
x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
|
7
7
|
x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
|
|
8
|
-
x_transformers/free_transformer.py,sha256=
|
|
8
|
+
x_transformers/free_transformer.py,sha256=RVEMvmiSVrTJeLbXVoI4YeR26kP2mlqmwFHHLw7IJ_Q,9582
|
|
9
9
|
x_transformers/gpt_vae.py,sha256=4QdznXZcU7pmMXUeEocAOKpcTkREYS-zDHktN5ADtNk,5981
|
|
10
10
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
|
11
11
|
x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
|
|
@@ -14,7 +14,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,
|
|
|
14
14
|
x_transformers/x_transformers.py,sha256=ADr83Fz2cehj_F7N1bMwxhAg-r48fGhlaZqw3hxoxMQ,125765
|
|
15
15
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
|
16
16
|
x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
|
|
17
|
-
x_transformers-2.11.
|
|
18
|
-
x_transformers-2.11.
|
|
19
|
-
x_transformers-2.11.
|
|
20
|
-
x_transformers-2.11.
|
|
17
|
+
x_transformers-2.11.7.dist-info/METADATA,sha256=1rStN6vwvwU-0hm17zWx8yXL99ZrZ-_KGMsFwvRIPVg,96011
|
|
18
|
+
x_transformers-2.11.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
19
|
+
x_transformers-2.11.7.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
|
20
|
+
x_transformers-2.11.7.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|