x-transformers 2.10.2__tar.gz → 2.11.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (68)
  1. {x_transformers-2.10.2 → x_transformers-2.11.0}/PKG-INFO +10 -1
  2. {x_transformers-2.10.2 → x_transformers-2.11.0}/README.md +9 -0
  3. {x_transformers-2.10.2 → x_transformers-2.11.0}/pyproject.toml +1 -1
  4. {x_transformers-2.10.2 → x_transformers-2.11.0}/train_copy.py +6 -5
  5. x_transformers-2.11.0/train_free.py +134 -0
  6. {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/autoregressive_wrapper.py +4 -0
  7. x_transformers-2.11.0/x_transformers/free_transformer.py +330 -0
  8. {x_transformers-2.10.2 → x_transformers-2.11.0}/.github/FUNDING.yml +0 -0
  9. {x_transformers-2.10.2 → x_transformers-2.11.0}/.github/workflows/python-publish.yml +0 -0
  10. {x_transformers-2.10.2 → x_transformers-2.11.0}/.github/workflows/python-test.yaml +0 -0
  11. {x_transformers-2.10.2 → x_transformers-2.11.0}/.gitignore +0 -0
  12. {x_transformers-2.10.2 → x_transformers-2.11.0}/LICENSE +0 -0
  13. {x_transformers-2.10.2 → x_transformers-2.11.0}/data/README.md +0 -0
  14. {x_transformers-2.10.2 → x_transformers-2.11.0}/data/enwik8.gz +0 -0
  15. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/all-attention.png +0 -0
  16. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/attention-on-attention.png +0 -0
  17. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/cosine-sim-attention.png +0 -0
  18. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/deepnorm.png +0 -0
  19. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/dynamic-pos-bias-linear.png +0 -0
  20. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/dynamic-pos-bias-log.png +0 -0
  21. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  22. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/dynamic-pos-bias.png +0 -0
  23. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/enhanced-recurrence.png +0 -0
  24. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/fcm.png +0 -0
  25. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/ffglu.png +0 -0
  26. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/flash-attention.png +0 -0
  27. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/gate_values.png +0 -0
  28. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/gating.png +0 -0
  29. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/length-extrapolation-scale.png +0 -0
  30. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/macaron-1.png +0 -0
  31. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/macaron-2.png +0 -0
  32. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/memory-transformer.png +0 -0
  33. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/normformer.png +0 -0
  34. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/pia.png +0 -0
  35. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/qknorm-analysis.png +0 -0
  36. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/resi_dual.png +0 -0
  37. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/residual_attn.png +0 -0
  38. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/rezero.png +0 -0
  39. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/rotary.png +0 -0
  40. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/sandwich-2.png +0 -0
  41. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/sandwich.png +0 -0
  42. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/sandwich_norm.png +0 -0
  43. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/scalenorm.png +0 -0
  44. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/talking-heads.png +0 -0
  45. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/topk-attention.png +0 -0
  46. {x_transformers-2.10.2 → x_transformers-2.11.0}/images/xval.png +0 -0
  47. {x_transformers-2.10.2 → x_transformers-2.11.0}/tests/test_x_transformers.py +0 -0
  48. {x_transformers-2.10.2 → x_transformers-2.11.0}/train_belief_state.py +0 -0
  49. {x_transformers-2.10.2 → x_transformers-2.11.0}/train_entropy_tokenizer.py +0 -0
  50. {x_transformers-2.10.2 → x_transformers-2.11.0}/train_enwik8.py +0 -0
  51. {x_transformers-2.10.2 → x_transformers-2.11.0}/train_gpt_vae.py +0 -0
  52. {x_transformers-2.10.2 → x_transformers-2.11.0}/train_length_extrapolate.py +0 -0
  53. {x_transformers-2.10.2 → x_transformers-2.11.0}/train_parity.py +0 -0
  54. {x_transformers-2.10.2 → x_transformers-2.11.0}/train_with_muon.py +0 -0
  55. {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/__init__.py +0 -0
  56. {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/attend.py +0 -0
  57. {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/belief_state_wrapper.py +0 -0
  58. {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/continuous.py +0 -0
  59. {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/dpo.py +0 -0
  60. {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/entropy_based_tokenizer.py +0 -0
  61. {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/gpt_vae.py +0 -0
  62. {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/multi_input.py +0 -0
  63. {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/neo_mlp.py +0 -0
  64. {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
  65. {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/up_wrapper.py +0 -0
  66. {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/x_transformers.py +0 -0
  67. {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  68. {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/xval.py +0 -0
{x_transformers-2.10.2 → x_transformers-2.11.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.10.2
+Version: 2.11.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2598,4 +2598,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{Fleuret2025TheFT,
+    title   = {The Free Transformer},
+    author  = {Franccois Fleuret},
+    year    = {2025},
+    url     = {https://api.semanticscholar.org/CorpusID:282210283}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.10.2 → x_transformers-2.11.0}/README.md

@@ -2549,4 +2549,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{Fleuret2025TheFT,
+    title   = {The Free Transformer},
+    author  = {Franccois Fleuret},
+    year    = {2025},
+    url     = {https://api.semanticscholar.org/CorpusID:282210283}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.10.2 → x_transformers-2.11.0}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.10.2"
+version = "2.11.0"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.10.2 → x_transformers-2.11.0}/train_copy.py

@@ -12,15 +12,16 @@ GENERATE_EVERY = 100
 NUM_TOKENS = 16 + 2
 ENC_SEQ_LEN = 32
 DEC_SEQ_LEN = 64 + 1
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 # helpers
 
 def cycle():
     while True:
-        prefix = torch.ones((BATCH_SIZE, 1)).long()
-        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long()
+        prefix = torch.ones((BATCH_SIZE, 1)).long().to(DEVICE)
+        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().to(DEVICE)
         tgt = torch.cat((prefix, src, src), 1)
-        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool()
+        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().to(DEVICE)
         yield (src, tgt, src_mask)
 
 # instantiate model
@@ -39,7 +40,7 @@ model = XTransformer(
     dec_heads = 8,
     dec_max_seq_len = DEC_SEQ_LEN,
     dec_attn_cog_signed = True
-)
+).to(DEVICE)
 
 # optimizer
 
@@ -63,7 +64,7 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
         model.eval()
         src, _, src_mask = next(cycle())
         src, src_mask = src[:1], src_mask[:1]
-        start_tokens = (torch.ones((1, 1)) * 1).long()
+        start_tokens = (torch.ones((1, 1)) * 1).long().to(DEVICE)
 
         sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
         incorrects = (src != sample).long().abs().sum()
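The train_copy.py hunks above simply make the toy copy task device-agnostic: every tensor the data generator and sampler produce is moved to a `DEVICE` constant that falls back to CPU when CUDA is unavailable, and the model is moved to the same device. A minimal standalone sketch of that pattern (the sizes below are placeholders, not the script's constants):

```python
import torch

# fall back to CPU when no CUDA device is present, as the updated script does
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

prefix = torch.ones((4, 1)).long().to(DEVICE)             # batch of start tokens
src = torch.randint(2, 18, (4, 32)).long().to(DEVICE)     # random source batch
src_mask = torch.ones(4, src.shape[1]).bool().to(DEVICE)  # full-attention mask
```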
x_transformers-2.11.0/train_free.py (new file)

@@ -0,0 +1,134 @@
+# /// script
+# dependencies = [
+#     "tqdm",
+#     "x-transformers>=2.11.0",
+# ]
+# ///
+
+from x_transformers.free_transformer import FreeTransformer
+
+from math import log
+import random
+import tqdm
+import gzip
+import numpy as np
+
+import torch
+import torch.optim as optim
+from torch import tensor
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 4
+GRADIENT_ACCUMULATE_EVERY = 4
+LEARNING_RATE = 1e-4
+VALIDATE_EVERY = 100
+GENERATE_EVERY = 250
+GENERATE_LENGTH = 512
+PRIME_LENGTH = 32
+SEQ_LEN = 512
+
+LATENT_BITS = 8
+NAT = log(2)
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return ''.join(list(map(decode_token, tokens)))
+
+# instantiate GPT-like decoder model
+
+model = FreeTransformer(
+    num_tokens = 256,
+    max_seq_len = SEQ_LEN,
+    dim = 512,
+    heads = 8,
+    rotary_pos_emb = True,
+    dec_head_depth = 4,
+    dec_tail_depth = 4,
+    enc_depth = 3,
+    kl_loss_weight = 1.,
+    kl_loss_threshold = NAT,
+    latent_bits = LATENT_BITS
+).cuda()
+
+rand_index = torch.randint(0, 2 ** LATENT_BITS, ())
+latents = F.one_hot(rand_index, 2 ** LATENT_BITS).float().cuda()
+
+# prepare enwik8 data
+
+with gzip.open('./data/enwik8.gz') as file:
+    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
+    train_x, valid_x = np.split(data, [int(90e6)])
+    data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
+        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
+val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
+
+# optimizer
+
+optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss, (ar_loss, vae_kl_loss) = model(next(train_loader), return_all_losses = True)
+        (loss / GRADIENT_ACCUMULATE_EVERY).backward()
+
+    print(f'training loss: {ar_loss.item():.4f}\t| kl loss: {vae_kl_loss.item():.4f}')
+
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+
+    if i % VALIDATE_EVERY == 0:
+        model.eval()
+        with torch.no_grad():
+            loss, (ar_loss, _) = model(next(val_loader), return_all_losses = True)
+            print(f'validation loss: {ar_loss.item():.4f}')
+
+    if i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset)[:PRIME_LENGTH]
+        prime = decode_tokens(inp)
+        print(f'%s \n\n %s', (prime, '*' * 100))
+
+        sample = model.generate(
+            prompts = inp,
+            seq_len = GENERATE_LENGTH,
+            latents = latents
+        )
+
+        output_str = decode_tokens(sample)
+
+        print(f'\n\nlatent {rand_index.tolist()} - ', output_str)
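The `# /// script` block at the top of train_free.py is PEP 723 inline script metadata, so the example should be runnable directly with a PEP 723-aware runner (for instance `uv run train_free.py`, assuming uv is installed). Note that as written the script expects a CUDA device, since it calls `.cuda()`, and the `./data/enwik8.gz` archive that ships with the repository.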
{x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/autoregressive_wrapper.py

@@ -43,6 +43,10 @@ def log(t, eps = 1e-20):
 def gumbel_noise(t):
     return -log(-log(torch.rand_like(t)))
 
+def gumbel_sample(logits, temperature = 1., eps = 1e-6):
+    noise = gumbel_noise(logits)
+    return ((logits / max(temperature, eps)) + noise).argmax(dim = -1)
+
 # function for modifying all the cached key / values
 
 def modify_cached_kv(cache, fn):
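The `gumbel_sample` helper added above is the Gumbel-max trick: perturb the temperature-scaled logits with Gumbel noise, then take the argmax, which is equivalent to sampling from the softmax over those logits. A minimal usage sketch, assuming the helper is imported from `x_transformers.autoregressive_wrapper` as the new free_transformer.py module does:

```python
import torch
from x_transformers.autoregressive_wrapper import gumbel_sample

logits = torch.randn(2, 256)                      # (batch, vocab) unnormalized scores
tokens = gumbel_sample(logits, temperature = 1.)  # (batch,) sampled token ids
```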
x_transformers-2.11.0/x_transformers/free_transformer.py (new file)

@@ -0,0 +1,330 @@
+from __future__ import annotations
+
+# https://arxiv.org/abs/2510.17558
+# François Fleuret
+# https://www.youtube.com/watch?v=Nao16-6l6dQ
+
+import math
+
+import torch
+from torch import nn, Tensor, is_tensor, tensor, arange
+import torch.nn.functional as F
+from torch.nn import Module, ModuleList
+
+from x_transformers.x_transformers import (
+    Encoder,
+    Decoder,
+    TransformerWrapper
+)
+
+from x_transformers.autoregressive_wrapper import (
+    gumbel_sample,
+    top_p,
+    top_k
+)
+
+from einops.layers.torch import Rearrange, Reduce
+from einops import rearrange, reduce, repeat, einsum, pack, unpack
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def log(t, eps = 1e-20):
+    return t.clamp_min(eps).log()
+
+def pack_with_inverse(t, pattern):
+    packed, ps = pack([t], pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        unpacked, = unpack(out, ps, inv_pattern)
+        return unpacked
+
+    return packed, inverse
+
+# binary mapper
+
+NAT = math.log(2)
+
+def binary_entropy(logits):
+    prob = logits.sigmoid()
+    not_prob = 1. - prob
+    return -(prob * F.logsigmoid(logits) + not_prob * F.logsigmoid(-logits)).sum(dim = -1)
+
+class BinaryMapper(Module):
+    def __init__(
+        self,
+        bits = 1,
+        kl_loss_threshold = NAT # 1 bit
+    ):
+        super().__init__()
+
+        self.bits = bits
+        self.num_codes = 2 ** bits
+        self.kl_loss_threshold = kl_loss_threshold
+
+        power_two = 2 ** arange(bits)
+        codes = (arange(self.num_codes)[:, None].bitwise_and(power_two) != 0).byte().bool()
+
+        self.register_buffer('power_two', power_two, persistent = False)
+        self.register_buffer('codes', codes, persistent = False)
+
+    def forward(
+        self,
+        logits,
+        temperature = 1.,
+        straight_through = None
+    ):
+        straight_through = default(straight_through, self.training)
+
+        assert logits.shape[-1] == self.bits, f'logits must have a last dimension of {self.bits}'
+
+        # temperature and prob for sampling
+
+        prob_for_sample = (logits / temperature).sigmoid()
+
+        # sampling
+
+        sampled_bits = (torch.rand_like(logits) <= prob_for_sample).long()
+        indices = (self.power_two * sampled_bits).sum(dim = -1)
+
+        one_hot = F.one_hot(indices, self.num_codes).float()
+
+        # return hard one hot if not training or overridden
+
+        if not straight_through:
+            return one_hot
+
+        # calculate negative entropy
+
+        kl_div = self.bits * NAT - binary_entropy(logits)
+        aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
+
+        # get the soft G for the gradients and do a straight through
+
+        soft_G = (
+            einsum(F.logsigmoid(logits), self.codes.float(), '... bits, codes bits -> ... codes') +
+            einsum(F.logsigmoid(-logits), (~self.codes).float(), '... bits, codes bits -> ... codes')
+        ).exp()
+
+        # straight through
+
+        one_hot = one_hot + soft_G - soft_G.detach()
+
+        return one_hot, aux_kl_loss
+
+# classes
+
+class FreeTransformer(Module):
+    def __init__(
+        self,
+        *,
+        num_tokens,
+        dim,
+        dec_head_depth,
+        dec_tail_depth,
+        enc_depth,
+        max_seq_len,
+        dim_latent = None,
+        attn_dim_head = 64,
+        heads = 8,
+        latent_bits = 16,
+        kl_loss_threshold = NAT,
+        binary_mapper_kwargs: dict = dict(),
+        enc_kwargs: dict = dict(),
+        dec_kwargs: dict = dict(),
+        kl_loss_weight = 1.,
+        pad_id = -1,
+        encoder: Module | None = None,
+        **kwargs
+    ):
+        super().__init__()
+        dim_latent = default(dim_latent, dim)
+
+        self.token_emb = nn.Embedding(num_tokens, dim)
+
+        self.token_unembed = nn.Linear(dim, num_tokens, bias = False)
+
+        if not exists(encoder):
+            encoder = Encoder(
+                dim = dim,
+                depth = enc_depth,
+                attn_dim_head = attn_dim_head,
+                heads = heads,
+                **kwargs,
+                **enc_kwargs
+            )
+
+        self.encoder = encoder
+
+        self.to_latent_bit_logits = nn.Sequential(
+            Reduce('b n d -> b d', 'mean'),
+            nn.Linear(dim, latent_bits, bias = False),
+        )
+
+        self.binary_mapper = BinaryMapper(
+            latent_bits,
+            kl_loss_threshold,
+            **binary_mapper_kwargs
+        )
+
+        self.from_latent_to_condition = nn.Sequential(
+            nn.Linear(2 ** latent_bits, dim, bias = False),
+            Rearrange('b d -> b 1 d')
+        )
+
+        self.decoder_head = Decoder(
+            dim = dim,
+            depth = dec_head_depth,
+            attn_dim_head = attn_dim_head,
+            heads = heads,
+            pre_norm_has_final_norm = False,
+            **kwargs,
+            **dec_kwargs
+        )
+
+        self.decoder_tail = Decoder(
+            dim = dim,
+            depth = dec_tail_depth,
+            attn_dim_head = attn_dim_head,
+            heads = heads,
+            pre_norm_has_final_norm = True,
+            **kwargs,
+            **dec_kwargs
+        )
+
+        self.pad_id = pad_id
+
+        self.kl_loss_weight = kl_loss_weight
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def encode_to_latents(
+        self,
+        seq,
+        mask = None,
+        return_kl_loss = False
+    ):
+        pooled = self.encoder(seq, mask = mask)
+
+        bit_logits = self.to_latent_bit_logits(pooled)
+
+        one_hot_latents, kl_loss = self.binary_mapper(bit_logits, straight_through = True)
+
+        if not return_kl_loss:
+            return one_hot_latents
+
+        return one_hot_latents, kl_loss
+
+    @torch.no_grad()
+    def generate(
+        self,
+        prompts,
+        seq_len,
+        latents = None,
+        filter_logits_fn = top_p,
+        logit_filter_kwargs: dict = dict(thres = 0.9)
+    ):
+        prompts, inverse_pack = pack_with_inverse(prompts, '* n')
+
+        batch = prompts.shape[0]
+
+        # prepend embeds
+
+        condition = None
+        if exists(latents):
+            if not is_tensor(latents):
+                latents = tensor(latents, device = self.device)
+
+            if latents.ndim == 1: # repeat latents
+                latents = repeat(latents, 'd -> b d', b = batch)
+
+            condition = self.from_latent_to_condition(latents)
+
+        # generated
+
+        prompt_len = prompts.shape[-1]
+
+        generated = prompts
+
+        tokens = self.token_emb(generated)
+
+        for _ in range(max(0, seq_len - prompt_len)):
+
+            head_embed = self.decoder_head(tokens)
+
+            if exists(condition):
+                head_embed = head_embed + condition
+
+            tail_embed = self.decoder_tail(head_embed)
+
+            tail_embed = tail_embed[:, -1]
+
+            logits = self.token_unembed(tail_embed)
+
+            logits = filter_logits_fn(logits, **logit_filter_kwargs)
+
+            sampled = gumbel_sample(logits)
+
+            generated, _ = pack((generated, sampled), 'b *')
+            tokens, _ = pack((tokens, self.token_emb(sampled)), 'b * d')
+
+        return inverse_pack(generated)
+
+    def forward(
+        self,
+        seq,
+        return_all_losses = False
+    ):
+        batch, device = seq.shape[0], seq.device
+
+        seq, labels = seq[:, :-1], seq[:, 1:]
+
+        encoder_mask = seq != self.pad_id
+
+        tokens = self.token_emb(seq)
+
+        # decoder head
+
+        tokens = self.decoder_head(tokens)
+
+        # get latent Z
+
+        latents, kl_loss = self.encode_to_latents(tokens, mask = encoder_mask, return_kl_loss = True)
+
+        condition = self.from_latent_to_condition(latents)
+
+        # decoder tail
+
+        tokens = self.decoder_tail(tokens)
+
+        # cross entropy loss
+
+        logits = self.token_unembed(tokens)
+
+        ar_loss = F.cross_entropy(
+            rearrange(logits, 'b n l -> b l n'),
+            labels,
+            ignore_index = self.pad_id
+        )
+
+        # return losses
+
+        total_loss = (
+            ar_loss +
+            kl_loss * self.kl_loss_weight
+        )
+
+        if not return_all_losses:
+            return total_loss
+
+        losses = (ar_loss, kl_loss)
+
+        return total_loss, losses
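Putting the new module together, here is a condensed, hedged sketch of exercising `FreeTransformer` end to end, mirroring the constructor, forward, and `generate` signatures shown above (train_free.py does the same at full scale); the tiny dimensions are illustrative only:

```python
import torch
import torch.nn.functional as F
from x_transformers.free_transformer import FreeTransformer

LATENT_BITS = 4

model = FreeTransformer(
    num_tokens = 256,
    max_seq_len = 128,
    dim = 64,
    heads = 4,
    rotary_pos_emb = True,
    dec_head_depth = 1,
    dec_tail_depth = 1,
    enc_depth = 1,
    latent_bits = LATENT_BITS
)

seq = torch.randint(0, 256, (2, 128))

# training step: total loss = autoregressive loss + weighted KL on the binary latent
loss, (ar_loss, kl_loss) = model(seq, return_all_losses = True)
loss.backward()

# generation conditioned on a fixed one-hot latent code (2 ** LATENT_BITS codes)
latent = F.one_hot(torch.randint(0, 2 ** LATENT_BITS, ()), 2 ** LATENT_BITS).float()
sampled = model.generate(prompts = seq[:, :16], seq_len = 64, latents = latent)
```

Per the forward pass above, the latent is inferred by the small encoder from the decoder-head hidden states during training; at generation time you pick (or sample) a code yourself, which is what the `latents` argument controls.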