x-transformers 2.10.2__py3-none-any.whl → 2.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of x-transformers might be problematic. Click here for more details.
- x_transformers/autoregressive_wrapper.py +4 -0
- x_transformers/free_transformer.py +330 -0
- {x_transformers-2.10.2.dist-info → x_transformers-2.11.1.dist-info}/METADATA +10 -1
- {x_transformers-2.10.2.dist-info → x_transformers-2.11.1.dist-info}/RECORD +6 -5
- {x_transformers-2.10.2.dist-info → x_transformers-2.11.1.dist-info}/WHEEL +0 -0
- {x_transformers-2.10.2.dist-info → x_transformers-2.11.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -43,6 +43,10 @@ def log(t, eps = 1e-20):
|
|
|
43
43
|
def gumbel_noise(t):
    """Sample standard Gumbel noise with the same shape, dtype and device as `t`.

    Uses the inverse-CDF form -log(-log(U)) with U drawn by `torch.rand_like`,
    applying this module's `log` helper for both logarithms.
    """
    uniform = torch.rand_like(t)
    inner = -log(uniform)
    return -log(inner)
|
|
45
45
|
|
|
46
|
+
def gumbel_sample(logits, temperature = 1., eps = 1e-6):
    """Pick one index per row of `logits` via the Gumbel-max trick.

    The argmax is taken over the last dimension. `temperature` is floored at
    `eps`, so a temperature of 0 divides by a tiny number instead of zero and
    the result is effectively greedy.
    """
    perturbation = gumbel_noise(logits)
    safe_temperature = max(temperature, eps)
    scaled_logits = logits / safe_temperature
    return (scaled_logits + perturbation).argmax(dim = -1)
|
|
49
|
+
|
|
46
50
|
# function for modifying all the cached key / values
|
|
47
51
|
|
|
48
52
|
def modify_cached_kv(cache, fn):
|
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
# https://arxiv.org/abs/2510.17558
|
|
4
|
+
# François Fleuret
|
|
5
|
+
# https://www.youtube.com/watch?v=Nao16-6l6dQ
|
|
6
|
+
|
|
7
|
+
import math
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from torch import nn, Tensor, is_tensor, tensor, arange
|
|
11
|
+
import torch.nn.functional as F
|
|
12
|
+
from torch.nn import Module, ModuleList
|
|
13
|
+
|
|
14
|
+
from x_transformers.x_transformers import (
|
|
15
|
+
Encoder,
|
|
16
|
+
Decoder,
|
|
17
|
+
TransformerWrapper
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
from x_transformers.autoregressive_wrapper import (
|
|
21
|
+
gumbel_sample,
|
|
22
|
+
top_p,
|
|
23
|
+
top_k
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
from einops.layers.torch import Rearrange, Reduce
|
|
27
|
+
from einops import rearrange, reduce, repeat, einsum, pack, unpack
|
|
28
|
+
|
|
29
|
+
# helper functions
|
|
30
|
+
|
|
31
|
+
def exists(v):
    """Return True if `v` holds a value, i.e. is not None (falsy values such as 0 count)."""
    if v is None:
        return False
    return True
|
|
33
|
+
|
|
34
|
+
def default(v, d):
    """Return `v` when it is not None, otherwise the fallback `d`."""
    if not exists(v):
        return d
    return v
|
|
36
|
+
|
|
37
|
+
def log(t, eps = 1e-20):
    """Natural log that first raises every element below `eps` up to `eps`, avoiding log(0)."""
    return torch.log(torch.clamp(t, min = eps))
|
|
39
|
+
|
|
40
|
+
def pack_with_inverse(t, pattern):
    """Pack a single tensor with einops `pack` and return it with a matching unpacker.

    Returns `(packed, inverse)`. `inverse(out, inv_pattern = None)` unpacks `out`
    using the shapes recorded during packing; `inv_pattern` defaults to `pattern`.
    """
    packed, packed_shapes = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        unpack_pattern = default(inv_pattern, pattern)
        (restored,) = unpack(out, packed_shapes, unpack_pattern)
        return restored

    return packed, inverse
|
|
49
|
+
|
|
50
|
+
# binary mapper
|
|
51
|
+
|
|
52
|
+
NAT = math.log(2.)  # ln(2) ~= 0.693: the size of one bit when measured in nats
|
|
53
|
+
|
|
54
|
+
def binary_entropy(logits):
    """Total entropy, in nats, of independent Bernoulli bits given by `logits`.

    Each element of the last dimension is one bit with P(bit = 1) = sigmoid(logit).
    Per-bit entropies are summed over that last dimension. Log-probabilities come
    from `logsigmoid` for numerical stability.
    """
    p_one = torch.sigmoid(logits)
    p_zero = 1. - p_one
    log_p_one = F.logsigmoid(logits)
    log_p_zero = F.logsigmoid(-logits)
    per_bit = p_one * log_p_one + p_zero * log_p_zero
    return -per_bit.sum(dim = -1)
|
|
58
|
+
|
|
59
|
+
class BinaryMapper(Module):
    """Map per-bit logits to a one-hot code over all `2 ** bits` bit combinations.

    Each of the `bits` logits parameterizes an independent Bernoulli variable.
    Sampled bits are combined (bit i contributes `2 ** i`) into an index over
    `2 ** bits` codes, which is returned as a one-hot vector.

    With straight-through enabled, the hard one-hot sample carries the gradient of
    the soft code distribution (product of per-bit probabilities). It is returned
    together with an auxiliary KL loss against the uniform prior. That loss only
    penalizes KL above `kl_loss_threshold` (in nats).
    """

    def __init__(
        self,
        bits = 1,
        kl_loss_threshold = NAT # 1 bit
    ):
        super().__init__()

        self.bits = bits
        self.num_codes = 2 ** bits
        self.kl_loss_threshold = kl_loss_threshold

        # power_two[i] = 2 ** i, used to turn a sampled bit vector into a code index
        power_two = 2 ** arange(bits)

        # codes[c, i] is True when bit i of code index c is set (already bool; no extra casts needed)
        codes = arange(self.num_codes)[:, None].bitwise_and(power_two) != 0

        self.register_buffer('power_two', power_two, persistent = False)
        self.register_buffer('codes', codes, persistent = False)

    def forward(
        self,
        logits,
        temperature = 1.,
        straight_through = None
    ):
        """Sample a one-hot code from per-bit logits.

        Args:
            logits: per-bit logits; the last dimension must equal `self.bits`.
            temperature: divides the logits for sampling only. The KL loss and the
                straight-through soft distribution use the untempered logits.
            straight_through: defaults to `self.training`.

        Returns:
            The float `one_hot` tensor (last dim `num_codes`) alone when
            straight-through is off, otherwise a tuple `(one_hot, aux_kl_loss)`.
        """
        straight_through = default(straight_through, self.training)

        assert logits.shape[-1] == self.bits, f'logits must have a last dimension of {self.bits}'

        # temperature and prob for sampling

        prob_for_sample = (logits / temperature).sigmoid()

        # Bernoulli sampling with a strict `<`, so P(bit = 1) is exactly `prob` for U in [0, 1).
        # With `<=`, a probability that underflowed to 0 could still produce a 1-bit whenever
        # rand_like returned exactly 0.0, which happens often on the coarse half-precision grid.

        sampled_bits = (torch.rand_like(logits) < prob_for_sample).long()
        indices = (self.power_two * sampled_bits).sum(dim = -1)

        one_hot = F.one_hot(indices, self.num_codes).float()

        # return hard one hot if not training or overridden

        if not straight_through:
            return one_hot

        # KL(q || uniform prior) summed over bits = bits * ln 2 - H(q), in nats

        kl_div = self.bits * NAT - binary_entropy(logits)
        aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()

        # soft distribution over codes: prod_i p(bit_i = codes[c, i]), computed in log space

        codes = self.codes.float()

        soft_G = (
            einsum(F.logsigmoid(logits), codes, '... bits, codes bits -> ... codes') +
            einsum(F.logsigmoid(-logits), 1. - codes, '... bits, codes bits -> ... codes')
        ).exp()

        # straight through: forward value is the hard one-hot, gradient flows through soft_G

        one_hot = one_hot + soft_G - soft_G.detach()

        return one_hot, aux_kl_loss
|
|
120
|
+
|
|
121
|
+
# classes
|
|
122
|
+
|
|
123
|
+
class FreeTransformer(Module):
    """Decoder-only transformer conditioned on a learned binary latent code.

    The decoder is split into a head and a tail. In `forward`:
    - An encoder reads the head's hidden states.
    - The encoder output is mean-pooled over non-padding positions and projected
      to `latent_bits` logits.
    - `BinaryMapper` turns those logits into a one-hot code over `2 ** latent_bits` values.
    - The code is projected to a `dim`-sized vector, added to every position before the tail.

    The loss is next-token cross entropy plus the thresholded KL term, weighted by
    `kl_loss_weight`.

    Reference: https://arxiv.org/abs/2510.17558 (François Fleuret)
    """

    def __init__(
        self,
        *,
        num_tokens,
        dim,
        dec_head_depth,
        dec_tail_depth,
        enc_depth,
        max_seq_len,
        dim_latent = None,
        attn_dim_head = 64,
        heads = 8,
        latent_bits = 16,
        kl_loss_threshold = NAT,
        binary_mapper_kwargs: dict = dict(),
        enc_kwargs: dict = dict(),
        dec_kwargs: dict = dict(),
        kl_loss_weight = 1.,
        pad_id = -1,
        encoder: Module | None = None,
        **kwargs
    ):
        super().__init__()

        # NOTE(review): `max_seq_len` and `dim_latent` are accepted for API compatibility
        # but are not used anywhere in this class

        self.token_emb = nn.Embedding(num_tokens, dim)

        self.token_unembed = nn.Linear(dim, num_tokens, bias = False)

        if not exists(encoder):
            encoder = Encoder(
                dim = dim,
                depth = enc_depth,
                attn_dim_head = attn_dim_head,
                heads = heads,
                **kwargs,
                **enc_kwargs
            )

        self.encoder = encoder

        # index 0 pools over the sequence, index 1 projects to per-bit logits;
        # `encode_to_latents` calls index -1 directly when doing a masked mean

        self.to_latent_bit_logits = nn.Sequential(
            Reduce('b n d -> b d', 'mean'),
            nn.Linear(dim, latent_bits, bias = False),
        )

        self.binary_mapper = BinaryMapper(
            latent_bits,
            kl_loss_threshold,
            **binary_mapper_kwargs
        )

        # one-hot code (2 ** latent_bits) -> a single conditioning vector broadcast over positions

        self.from_latent_to_condition = nn.Sequential(
            nn.Linear(2 ** latent_bits, dim, bias = False),
            Rearrange('b d -> b 1 d')
        )

        self.decoder_head = Decoder(
            dim = dim,
            depth = dec_head_depth,
            attn_dim_head = attn_dim_head,
            heads = heads,
            pre_norm_has_final_norm = False,
            **kwargs,
            **dec_kwargs
        )

        self.decoder_tail = Decoder(
            dim = dim,
            depth = dec_tail_depth,
            attn_dim_head = attn_dim_head,
            heads = heads,
            pre_norm_has_final_norm = True,
            **kwargs,
            **dec_kwargs
        )

        self.pad_id = pad_id

        self.kl_loss_weight = kl_loss_weight

    @property
    def device(self):
        return next(self.parameters()).device

    def encode_to_latents(
        self,
        seq,
        mask = None,
        return_kl_loss = False
    ):
        """Encode hidden states into a straight-through one-hot latent code.

        Args:
            seq: embeddings fed to the encoder (in `forward`, the decoder head output);
                presumably (batch, seq_len, dim) - TODO confirm for custom encoders.
            mask: optional boolean (batch, seq_len) mask, True at positions to keep.
                It is passed to the encoder, and masked positions are excluded from
                the mean pooling so padding does not leak into the latent.
            return_kl_loss: also return the BinaryMapper auxiliary KL loss.

        Returns:
            `one_hot_latents`, or `(one_hot_latents, kl_loss)`.
        """
        encoded = self.encoder(seq, mask = mask)

        if exists(mask):
            # masked mean over the sequence, then the projection layer of `to_latent_bit_logits`
            weights = rearrange(mask, 'b n -> b n 1').to(encoded.dtype)
            pooled = (encoded * weights).sum(dim = 1) / weights.sum(dim = 1).clamp_min(1.)
            bit_logits = self.to_latent_bit_logits[-1](pooled)
        else:
            bit_logits = self.to_latent_bit_logits(encoded)

        one_hot_latents, kl_loss = self.binary_mapper(bit_logits, straight_through = True)

        if not return_kl_loss:
            return one_hot_latents

        return one_hot_latents, kl_loss

    @torch.no_grad()
    def generate(
        self,
        prompts,
        seq_len,
        latents = None,
        filter_logits_fn = top_p,
        logit_filter_kwargs: dict = dict(thres = 0.9),
        temperature = 1.
    ):
        """Autoregressively extend `prompts` to `seq_len` total tokens.

        Args:
            prompts: token ids of shape (n,) or (..., n). Leading dims are packed into
                a batch and restored on return.
            seq_len: total length (prompt included) of the returned sequences.
            latents: one-hot (or soft) code vector(s) over `2 ** latent_bits` values.
                Pass either one shared vector of shape (num_codes,) or a
                (batch, num_codes) tensor. Int/bool inputs are cast to the projection
                dtype. When None, one code per sequence is drawn uniformly at random
                (the prior the KL term regularizes toward). This way the decoder tail
                always gets a condition, as it did in training.
            filter_logits_fn: filter applied to the last-position logits before sampling.
            logit_filter_kwargs: keyword arguments forwarded to `filter_logits_fn`.
            temperature: sampling temperature forwarded to `gumbel_sample`.

        Returns:
            The prompts followed by the sampled tokens, with the original leading dims.
        """
        prompts, inverse_pack = pack_with_inverse(prompts, '* n')

        batch = prompts.shape[0]
        num_codes = self.binary_mapper.num_codes

        # conditioning latent

        if not exists(latents):
            sampled_codes = torch.randint(0, num_codes, (batch,), device = self.device)
            latents = F.one_hot(sampled_codes, num_codes)

        if not is_tensor(latents):
            latents = tensor(latents, device = self.device)

        projection_weight = self.from_latent_to_condition[0].weight
        latents = latents.to(device = projection_weight.device, dtype = projection_weight.dtype)

        if latents.ndim == 1: # repeat latents
            latents = repeat(latents, 'd -> b d', b = batch)

        condition = self.from_latent_to_condition(latents)

        # generated

        prompt_len = prompts.shape[-1]

        generated = prompts

        tokens = self.token_emb(generated)

        for _ in range(max(0, seq_len - prompt_len)):

            head_embed = self.decoder_head(tokens)

            tail_embed = self.decoder_tail(head_embed + condition)

            logits = self.token_unembed(tail_embed[:, -1])

            logits = filter_logits_fn(logits, **logit_filter_kwargs)

            sampled = gumbel_sample(logits, temperature = temperature)

            generated, _ = pack((generated, sampled), 'b *')
            tokens, _ = pack((tokens, self.token_emb(sampled)), 'b * d')

        return inverse_pack(generated)

    def forward(
        self,
        seq,
        return_all_losses = False
    ):
        """Compute the training loss for a batch of token sequences.

        Args:
            seq: token ids (batch, seq_len). Positions equal to `pad_id` are padding.
            return_all_losses: also return the individual `(ar_loss, kl_loss)`.

        Returns:
            `total_loss`, or `(total_loss, (ar_loss, kl_loss))`.
        """
        seq, labels = seq[:, :-1], seq[:, 1:]

        encoder_mask = seq != self.pad_id

        # `pad_id` may lie outside the embedding table (the default -1 does). In that
        # case pad positions are embedded as id 0 instead of crashing nn.Embedding.
        # They stay excluded from the encoder via `encoder_mask`, and pad labels are
        # ignored by the loss.

        token_ids = seq

        if not (0 <= self.pad_id < self.token_emb.num_embeddings):
            token_ids = seq.masked_fill(~encoder_mask, 0)

        tokens = self.token_emb(token_ids)

        # decoder head

        tokens = self.decoder_head(tokens)

        # get latent Z

        latents, kl_loss = self.encode_to_latents(tokens, mask = encoder_mask, return_kl_loss = True)

        condition = self.from_latent_to_condition(latents)

        # decoder tail

        tokens = self.decoder_tail(tokens + condition)

        # cross entropy loss

        logits = self.token_unembed(tokens)

        ar_loss = F.cross_entropy(
            rearrange(logits, 'b n l -> b l n'),
            labels,
            ignore_index = self.pad_id
        )

        # return losses

        total_loss = ar_loss + kl_loss * self.kl_loss_weight

        if not return_all_losses:
            return total_loss

        return total_loss, (ar_loss, kl_loss)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: x-transformers
|
|
3
|
-
Version: 2.10.2
|
|
3
|
+
Version: 2.11.1
|
|
4
4
|
Summary: X-Transformers
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/x-transformers/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/x-transformers
|
|
@@ -2598,4 +2598,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
|
|
|
2598
2598
|
}
|
|
2599
2599
|
```
|
|
2600
2600
|
|
|
2601
|
+
```bibtex
|
|
2602
|
+
@inproceedings{Fleuret2025TheFT,
|
|
2603
|
+
title = {The Free Transformer},
|
|
2604
|
+
author = {François Fleuret},
|
|
2605
|
+
year = {2025},
|
|
2606
|
+
url = {https://api.semanticscholar.org/CorpusID:282210283}
|
|
2607
|
+
}
|
|
2608
|
+
```
|
|
2609
|
+
|
|
2601
2610
|
*solve intelligence... then use that to solve everything else.* - Demis Hassabis
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
x_transformers/__init__.py,sha256=aVuhUU0572TJHW88BVc4yA2tla0Zb8l3NH7W4RZ1AEs,1005
|
|
2
2
|
x_transformers/attend.py,sha256=vrFPCfr3WwsyMZJxn1Pox_8VHZVLVSMuXThW3eZmd5Q,19388
|
|
3
|
-
x_transformers/autoregressive_wrapper.py,sha256=
|
|
3
|
+
x_transformers/autoregressive_wrapper.py,sha256=T4PUpOndC_67pxp6-rL6g3CWNq2DWdpn3UViu9rlk7Y,18563
|
|
4
4
|
x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
|
|
5
5
|
x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
|
|
6
6
|
x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
|
7
7
|
x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
|
|
8
|
+
x_transformers/free_transformer.py,sha256=K9dX9xj0pJgiOi9jlOLCD9Nn9eYNNgTWj9YvQLhexHw,8295
|
|
8
9
|
x_transformers/gpt_vae.py,sha256=myYSgcx66V0M4zeEGKyhY1P2HlPDHcezhaZEoo_uMdo,5715
|
|
9
10
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
|
10
11
|
x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
|
|
@@ -13,7 +14,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,
|
|
|
13
14
|
x_transformers/x_transformers.py,sha256=ADr83Fz2cehj_F7N1bMwxhAg-r48fGhlaZqw3hxoxMQ,125765
|
|
14
15
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
|
15
16
|
x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
|
|
16
|
-
x_transformers-2.
|
|
17
|
-
x_transformers-2.
|
|
18
|
-
x_transformers-2.
|
|
19
|
-
x_transformers-2.
|
|
17
|
+
x_transformers-2.11.1.dist-info/METADATA,sha256=Rj4l6-wbfFsC7wbSWfUFQyTNWQE5EXu952aSE3B8uas,96011
|
|
18
|
+
x_transformers-2.11.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
19
|
+
x_transformers-2.11.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
|
20
|
+
x_transformers-2.11.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|