x-transformers 2.10.1__py3-none-any.whl → 2.11.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.


x_transformers/attend.py CHANGED
@@ -549,6 +549,11 @@ class Attend(Module):
         if self.head_learned_sink:
             # add learned attention sink
             attn_sink = repeat(self.head_attn_sink, 'h -> b h i 1', b = sim.shape[0], i = sim.shape[2])
+
+            if self.cog_signed:
+                attn_sink, attn_sink_sign = attn_sink.abs(), attn_sink.sign()
+                sim_sign = cat((attn_sink_sign, sim_sign), dim = -1)
+
             sim = cat((attn_sink, sim), dim = -1)
 
         pre_softmax_attn = sim
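For reference, a small standalone sketch of what the new `cog_signed` branch does to the learned attention sink: the sink logit's sign is split off and tracked alongside the signs of the other attention logits, while only its magnitude is concatenated into `sim`. The shapes below are illustrative assumptions, not values from the library.

```python
import torch
from einops import repeat

b, h, i, j = 2, 8, 4, 4                          # assumed toy shapes: batch, heads, query len, key len
sim = torch.randn(b, h, i, j)                    # pre-softmax attention logits
sim_sign = sim.sign()                            # per-logit signs carried by the signed-attention path
head_attn_sink = torch.randn(h)                  # one learned sink logit per head

# broadcast the per-head sink to (b, h, i, 1), as in the diff
attn_sink = repeat(head_attn_sink, 'h -> b h i 1', b = b, i = i)

# cog_signed: keep the magnitude for the softmax, carry the sign separately
attn_sink, attn_sink_sign = attn_sink.abs(), attn_sink.sign()
sim_sign = torch.cat((attn_sink_sign, sim_sign), dim = -1)

sim = torch.cat((attn_sink, sim), dim = -1)      # sink column prepended along the key dimension
print(sim.shape, sim_sign.shape)                 # both torch.Size([2, 8, 4, 5])
```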
x_transformers/autoregressive_wrapper.py CHANGED
@@ -43,6 +43,10 @@ def log(t, eps = 1e-20):
 def gumbel_noise(t):
     return -log(-log(torch.rand_like(t)))
 
+def gumbel_sample(logits, temperature = 1., eps = 1e-6):
+    noise = gumbel_noise(logits)
+    return ((logits / max(temperature, eps)) + noise).argmax(dim = -1)
+
 # function for modifying all the cached key / values
 
 def modify_cached_kv(cache, fn):
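The newly exposed `gumbel_sample` is a plain Gumbel-max sampler: add Gumbel noise to temperature-scaled logits and take the argmax. A self-contained sketch of the same helpers, redefined locally so it runs without the package:

```python
import torch

def log(t, eps = 1e-20):
    return t.clamp_min(eps).log()

def gumbel_noise(t):
    # -log(-log(U)) with U ~ Uniform(0, 1) gives standard Gumbel noise
    return -log(-log(torch.rand_like(t)))

def gumbel_sample(logits, temperature = 1., eps = 1e-6):
    noise = gumbel_noise(logits)
    return ((logits / max(temperature, eps)) + noise).argmax(dim = -1)

logits = torch.randn(2, 10)                        # (batch, vocab)
tokens = gumbel_sample(logits, temperature = 0.8)  # (batch,) sampled token indices
print(tokens.shape)                                # torch.Size([2])
```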
x_transformers/free_transformer.py ADDED
@@ -0,0 +1,330 @@
+from __future__ import annotations
+
+# https://arxiv.org/abs/2510.17558
+# François Fleuret
+# https://www.youtube.com/watch?v=Nao16-6l6dQ
+
+import math
+
+import torch
+from torch import nn, Tensor, is_tensor, tensor, arange
+import torch.nn.functional as F
+from torch.nn import Module, ModuleList
+
+from x_transformers.x_transformers import (
+    Encoder,
+    Decoder,
+    TransformerWrapper
+)
+
+from x_transformers.autoregressive_wrapper import (
+    gumbel_sample,
+    top_p,
+    top_k
+)
+
+from einops.layers.torch import Rearrange, Reduce
+from einops import rearrange, reduce, repeat, einsum, pack, unpack
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def log(t, eps = 1e-20):
+    return t.clamp_min(eps).log()
+
+def pack_with_inverse(t, pattern):
+    packed, ps = pack([t], pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        unpacked, = unpack(out, ps, inv_pattern)
+        return unpacked
+
+    return packed, inverse
+
+# binary mapper
+
+NAT = math.log(2)
+
+def binary_entropy(logits):
+    prob = logits.sigmoid()
+    not_prob = 1. - prob
+    return -(prob * F.logsigmoid(logits) + not_prob * F.logsigmoid(-logits)).sum(dim = -1)
+
+class BinaryMapper(Module):
+    def __init__(
+        self,
+        bits = 1,
+        kl_loss_threshold = NAT # 1 bit
+    ):
+        super().__init__()
+
+        self.bits = bits
+        self.num_codes = 2 ** bits
+        self.kl_loss_threshold = kl_loss_threshold
+
+        power_two = 2 ** arange(bits)
+        codes = (arange(self.num_codes)[:, None].bitwise_and(power_two) != 0).byte().bool()
+
+        self.register_buffer('power_two', power_two, persistent = False)
+        self.register_buffer('codes', codes, persistent = False)
+
+    def forward(
+        self,
+        logits,
+        temperature = 1.,
+        straight_through = None
+    ):
+        straight_through = default(straight_through, self.training)
+
+        assert logits.shape[-1] == self.bits, f'logits must have a last dimension of {self.bits}'
+
+        # temperature and prob for sampling
+
+        prob_for_sample = (logits / temperature).sigmoid()
+
+        # sampling
+
+        sampled_bits = (torch.rand_like(logits) <= prob_for_sample).long()
+        indices = (self.power_two * sampled_bits).sum(dim = -1)
+
+        one_hot = F.one_hot(indices, self.num_codes).float()
+
+        # return hard one hot if not training or overridden
+
+        if not straight_through:
+            return one_hot
+
+        # calculate negative entropy
+
+        kl_div = self.bits * NAT - binary_entropy(logits)
+        aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
+
+        # get the soft G for the gradients and do a straight through
+
+        soft_G = (
+            einsum(F.logsigmoid(logits), self.codes.float(), '... bits, codes bits -> ... codes') +
+            einsum(F.logsigmoid(-logits), (~self.codes).float(), '... bits, codes bits -> ... codes')
+        ).exp()
+
+        # straight through
+
+        one_hot = one_hot + soft_G - soft_G.detach()
+
+        return one_hot, aux_kl_loss
+
+# classes
+
+class FreeTransformer(Module):
+    def __init__(
+        self,
+        *,
+        num_tokens,
+        dim,
+        dec_head_depth,
+        dec_tail_depth,
+        enc_depth,
+        max_seq_len,
+        dim_latent = None,
+        attn_dim_head = 64,
+        heads = 8,
+        latent_bits = 16,
+        kl_loss_threshold = NAT,
+        binary_mapper_kwargs: dict = dict(),
+        enc_kwargs: dict = dict(),
+        dec_kwargs: dict = dict(),
+        kl_loss_weight = 1.,
+        pad_id = -1,
+        encoder: Module | None = None,
+        **kwargs
+    ):
+        super().__init__()
+        dim_latent = default(dim_latent, dim)
+
+        self.token_emb = nn.Embedding(num_tokens, dim)
+
+        self.token_unembed = nn.Linear(dim, num_tokens, bias = False)
+
+        if not exists(encoder):
+            encoder = Encoder(
+                dim = dim,
+                depth = enc_depth,
+                attn_dim_head = attn_dim_head,
+                heads = heads,
+                **kwargs,
+                **enc_kwargs
+            )
+
+        self.encoder = encoder
+
+        self.to_latent_bit_logits = nn.Sequential(
+            Reduce('b n d -> b d', 'mean'),
+            nn.Linear(dim, latent_bits, bias = False),
+        )
+
+        self.binary_mapper = BinaryMapper(
+            latent_bits,
+            kl_loss_threshold,
+            **binary_mapper_kwargs
+        )
+
+        self.from_latent_to_condition = nn.Sequential(
+            nn.Linear(2 ** latent_bits, dim, bias = False),
+            Rearrange('b d -> b 1 d')
+        )
+
+        self.decoder_head = Decoder(
+            dim = dim,
+            depth = dec_head_depth,
+            attn_dim_head = attn_dim_head,
+            heads = heads,
+            pre_norm_has_final_norm = False,
+            **kwargs,
+            **dec_kwargs
+        )
+
+        self.decoder_tail = Decoder(
+            dim = dim,
+            depth = dec_tail_depth,
+            attn_dim_head = attn_dim_head,
+            heads = heads,
+            pre_norm_has_final_norm = True,
+            **kwargs,
+            **dec_kwargs
+        )
+
+        self.pad_id = pad_id
+
+        self.kl_loss_weight = kl_loss_weight
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def encode_to_latents(
+        self,
+        seq,
+        mask = None,
+        return_kl_loss = False
+    ):
+        pooled = self.encoder(seq, mask = mask)
+
+        bit_logits = self.to_latent_bit_logits(pooled)
+
+        one_hot_latents, kl_loss = self.binary_mapper(bit_logits, straight_through = True)
+
+        if not return_kl_loss:
+            return one_hot_latents
+
+        return one_hot_latents, kl_loss
+
+    @torch.no_grad()
+    def generate(
+        self,
+        prompts,
+        seq_len,
+        latents = None,
+        filter_logits_fn = top_p,
+        logit_filter_kwargs: dict = dict(thres = 0.9)
+    ):
+        prompts, inverse_pack = pack_with_inverse(prompts, '* n')
+
+        batch = prompts.shape[0]
+
+        # prepend embeds
+
+        condition = None
+        if exists(latents):
+            if not is_tensor(latents):
+                latents = tensor(latents, device = self.device)
+
+            if latents.ndim == 1: # repeat latents
+                latents = repeat(latents, 'd -> b d', b = batch)
+
+            condition = self.from_latent_to_condition(latents)
+
+        # generated
+
+        prompt_len = prompts.shape[-1]
+
+        generated = prompts
+
+        tokens = self.token_emb(generated)
+
+        for _ in range(max(0, seq_len - prompt_len)):
+
+            head_embed = self.decoder_head(tokens)
+
+            if exists(condition):
+                head_embed = head_embed + condition
+
+            tail_embed = self.decoder_tail(head_embed)
+
+            tail_embed = tail_embed[:, -1]
+
+            logits = self.token_unembed(tail_embed)
+
+            logits = filter_logits_fn(logits, **logit_filter_kwargs)
+
+            sampled = gumbel_sample(logits)
+
+            generated, _ = pack((generated, sampled), 'b *')
+            tokens, _ = pack((tokens, self.token_emb(sampled)), 'b * d')
+
+        return inverse_pack(generated)
+
+    def forward(
+        self,
+        seq,
+        return_all_losses = False
+    ):
+        batch, device = seq.shape[0], seq.device
+
+        seq, labels = seq[:, :-1], seq[:, 1:]
+
+        encoder_mask = seq != self.pad_id
+
+        tokens = self.token_emb(seq)
+
+        # decoder head
+
+        tokens = self.decoder_head(tokens)
+
+        # get latent Z
+
+        latents, kl_loss = self.encode_to_latents(tokens, mask = encoder_mask, return_kl_loss = True)
+
+        condition = self.from_latent_to_condition(latents)
+
+        # decoder tail
+
+        tokens = self.decoder_tail(tokens)
+
+        # cross entropy loss
+
+        logits = self.token_unembed(tokens)
+
+        ar_loss = F.cross_entropy(
+            rearrange(logits, 'b n l -> b l n'),
+            labels,
+            ignore_index = self.pad_id
+        )
+
+        # return losses
+
+        total_loss = (
+            ar_loss +
+            kl_loss * self.kl_loss_weight
+        )
+
+        if not return_all_losses:
+            return total_loss
+
+        losses = (ar_loss, kl_loss)
+
+        return total_loss, losses
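Putting the new module together: the encoder pools the decoder-head states into `latent_bits` Bernoulli logits, `BinaryMapper` samples a one-hot code over `2 ** latent_bits` codes with a straight-through estimator and a thresholded KL penalty, and the sampled code is projected back to model dimension as a conditioning signal (added to the decoder-head output during generation). Below is a minimal usage sketch based on the signatures above; the hyperparameters are illustrative assumptions, and `latent_bits` is kept small because the conditioning projection has `2 ** latent_bits` input features.

```python
import torch
import torch.nn.functional as F
from x_transformers.free_transformer import FreeTransformer

model = FreeTransformer(
    num_tokens = 256,
    dim = 512,
    enc_depth = 2,
    dec_head_depth = 4,
    dec_tail_depth = 4,
    max_seq_len = 1024,
    latent_bits = 8,                       # conditioning weight is (2 ** latent_bits, dim)
    kl_loss_weight = 1.
)

seq = torch.randint(0, 256, (2, 512))      # (batch, seq); the last token is used only as a label

# training step: autoregressive cross entropy plus the KL penalty from the binary latents
total_loss, (ar_loss, kl_loss) = model(seq, return_all_losses = True)
total_loss.backward()

# generation, optionally conditioned on a chosen latent code (a one-hot over 2 ** latent_bits codes)
code = torch.randint(0, 2 ** 8, ())
latents = F.one_hot(code, 2 ** 8).float()  # a 1-d latent is broadcast across the batch

generated = model.generate(seq[:, :16], seq_len = 64, latents = latents)   # (2, 64)
```

Generation can also be run with `latents = None`, in which case the decoder tail sees no latent conditioning.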
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.10.1
+Version: 2.11.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2598,4 +2598,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{Fleuret2025TheFT,
+    title = {The Free Transformer},
+    author = {Franccois Fleuret},
+    year = {2025},
+    url = {https://api.semanticscholar.org/CorpusID:282210283}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
@@ -1,10 +1,11 @@
 x_transformers/__init__.py,sha256=aVuhUU0572TJHW88BVc4yA2tla0Zb8l3NH7W4RZ1AEs,1005
-x_transformers/attend.py,sha256=uu4lIEfiwzZLeuBY2dJLG9709DZbWK8-on4ds8SCCJ0,19207
-x_transformers/autoregressive_wrapper.py,sha256=BsGO9xfVYkvynqbU1__tu_S_cxl7gss0YwnkhIa2baY,18401
+x_transformers/attend.py,sha256=vrFPCfr3WwsyMZJxn1Pox_8VHZVLVSMuXThW3eZmd5Q,19388
+x_transformers/autoregressive_wrapper.py,sha256=T4PUpOndC_67pxp6-rL6g3CWNq2DWdpn3UViu9rlk7Y,18563
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
+x_transformers/free_transformer.py,sha256=SvltfYD3K4kgy0hbdsy0S8NQ24x1wXVF1GvkF5O9GGU,8283
 x_transformers/gpt_vae.py,sha256=myYSgcx66V0M4zeEGKyhY1P2HlPDHcezhaZEoo_uMdo,5715
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
@@ -13,7 +14,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,
 x_transformers/x_transformers.py,sha256=ADr83Fz2cehj_F7N1bMwxhAg-r48fGhlaZqw3hxoxMQ,125765
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.10.1.dist-info/METADATA,sha256=sEfcxJr3l0W4Yga0NLHq1sMk90Zr5-Lpr-9fIlmG9H4,95799
-x_transformers-2.10.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.10.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.10.1.dist-info/RECORD,,
+x_transformers-2.11.0.dist-info/METADATA,sha256=22nan4aatJ8_zRwYeORRCCA733rUqInhp3oATAABLMw,96011
+x_transformers-2.11.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.11.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.11.0.dist-info/RECORD,,