x-transformers 2.10.2__tar.gz → 2.11.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of x-transformers might be problematic.
- {x_transformers-2.10.2 → x_transformers-2.11.0}/PKG-INFO +10 -1
- {x_transformers-2.10.2 → x_transformers-2.11.0}/README.md +9 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/pyproject.toml +1 -1
- {x_transformers-2.10.2 → x_transformers-2.11.0}/train_copy.py +6 -5
- x_transformers-2.11.0/train_free.py +134 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/autoregressive_wrapper.py +4 -0
- x_transformers-2.11.0/x_transformers/free_transformer.py +330 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/.github/FUNDING.yml +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/.gitignore +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/LICENSE +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/data/README.md +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/data/enwik8.gz +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/all-attention.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/attention-on-attention.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/deepnorm.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/fcm.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/ffglu.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/flash-attention.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/gate_values.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/gating.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/macaron-1.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/macaron-2.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/memory-transformer.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/normformer.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/pia.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/resi_dual.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/residual_attn.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/rezero.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/rotary.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/sandwich-2.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/sandwich.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/sandwich_norm.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/scalenorm.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/talking-heads.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/topk-attention.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/images/xval.png +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/tests/test_x_transformers.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/train_belief_state.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/train_enwik8.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/train_gpt_vae.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/train_length_extrapolate.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/train_parity.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/train_with_muon.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/__init__.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/attend.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/continuous.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/dpo.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/gpt_vae.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/x_transformers.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/xval.py +0 -0
{x_transformers-2.10.2 → x_transformers-2.11.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.10.2
+Version: 2.11.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers

@@ -2598,4 +2598,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{Fleuret2025TheFT,
+    title  = {The Free Transformer},
+    author = {Franccois Fleuret},
+    year   = {2025},
+    url    = {https://api.semanticscholar.org/CorpusID:282210283}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.10.2 → x_transformers-2.11.0}/README.md

@@ -2549,4 +2549,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{Fleuret2025TheFT,
+    title  = {The Free Transformer},
+    author = {Franccois Fleuret},
+    year   = {2025},
+    url    = {https://api.semanticscholar.org/CorpusID:282210283}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.10.2 → x_transformers-2.11.0}/train_copy.py

@@ -12,15 +12,16 @@ GENERATE_EVERY = 100
 NUM_TOKENS = 16 + 2
 ENC_SEQ_LEN = 32
 DEC_SEQ_LEN = 64 + 1
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 # helpers
 
 def cycle():
     while True:
-        prefix = torch.ones((BATCH_SIZE, 1)).long()
-        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long()
+        prefix = torch.ones((BATCH_SIZE, 1)).long().to(DEVICE)
+        src = torch.randint(2, NUM_TOKENS, (BATCH_SIZE, ENC_SEQ_LEN)).long().to(DEVICE)
         tgt = torch.cat((prefix, src, src), 1)
-        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool()
+        src_mask = torch.ones(BATCH_SIZE, src.shape[1]).bool().to(DEVICE)
         yield (src, tgt, src_mask)
 
 # instantiate model

@@ -39,7 +40,7 @@ model = XTransformer(
     dec_heads = 8,
     dec_max_seq_len = DEC_SEQ_LEN,
     dec_attn_cog_signed = True
-)
+).to(DEVICE)
 
 # optimizer
 

@@ -63,7 +64,7 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
         model.eval()
         src, _, src_mask = next(cycle())
         src, src_mask = src[:1], src_mask[:1]
-        start_tokens = (torch.ones((1, 1)) * 1).long()
+        start_tokens = (torch.ones((1, 1)) * 1).long().to(DEVICE)
 
         sample = model.generate(src, start_tokens, ENC_SEQ_LEN, mask = src_mask)
         incorrects = (src != sample).long().abs().sum()
x_transformers-2.11.0/train_free.py (new file)

@@ -0,0 +1,134 @@
+# /// script
+# dependencies = [
+#   "tqdm",
+#   "x-transformers>=2.11.0",
+# ]
+# ///
+
+from x_transformers.free_transformer import FreeTransformer
+
+from math import log
+import random
+import tqdm
+import gzip
+import numpy as np
+
+import torch
+import torch.optim as optim
+from torch import tensor
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 4
+GRADIENT_ACCUMULATE_EVERY = 4
+LEARNING_RATE = 1e-4
+VALIDATE_EVERY = 100
+GENERATE_EVERY = 250
+GENERATE_LENGTH = 512
+PRIME_LENGTH = 32
+SEQ_LEN = 512
+
+LATENT_BITS = 8
+NAT = log(2)
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return ''.join(list(map(decode_token, tokens)))
+
+# instantiate GPT-like decoder model
+
+model = FreeTransformer(
+    num_tokens = 256,
+    max_seq_len = SEQ_LEN,
+    dim = 512,
+    heads = 8,
+    rotary_pos_emb = True,
+    dec_head_depth = 4,
+    dec_tail_depth = 4,
+    enc_depth = 3,
+    kl_loss_weight = 1.,
+    kl_loss_threshold = NAT,
+    latent_bits = LATENT_BITS
+).cuda()
+
+rand_index = torch.randint(0, 2 ** LATENT_BITS, ())
+latents = F.one_hot(rand_index, 2 ** LATENT_BITS).float().cuda()
+
+# prepare enwik8 data
+
+with gzip.open('./data/enwik8.gz') as file:
+    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
+    train_x, valid_x = np.split(data, [int(90e6)])
+    data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
+        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
+val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
+
+# optimizer
+
+optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss, (ar_loss, vae_kl_loss) = model(next(train_loader), return_all_losses = True)
+        (loss / GRADIENT_ACCUMULATE_EVERY).backward()
+
+    print(f'training loss: {ar_loss.item():.4f}\t| kl loss: {vae_kl_loss.item():.4f}')
+
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+
+    if i % VALIDATE_EVERY == 0:
+        model.eval()
+        with torch.no_grad():
+            loss, (ar_loss, _) = model(next(val_loader), return_all_losses = True)
+            print(f'validation loss: {ar_loss.item():.4f}')
+
+    if i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset)[:PRIME_LENGTH]
+        prime = decode_tokens(inp)
+        print(f'%s \n\n %s', (prime, '*' * 100))
+
+        sample = model.generate(
+            prompts = inp,
+            seq_len = GENERATE_LENGTH,
+            latents = latents
+        )
+
+        output_str = decode_tokens(sample)
+
+        print(f'\n\nlatent {rand_index.tolist()} - ', output_str)
{x_transformers-2.10.2 → x_transformers-2.11.0}/x_transformers/autoregressive_wrapper.py

@@ -43,6 +43,10 @@ def log(t, eps = 1e-20):
 def gumbel_noise(t):
     return -log(-log(torch.rand_like(t)))
 
+def gumbel_sample(logits, temperature = 1., eps = 1e-6):
+    noise = gumbel_noise(logits)
+    return ((logits / max(temperature, eps)) + noise).argmax(dim = -1)
+
 # function for modifying all the cached key / values
 
 def modify_cached_kv(cache, fn):
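
The `gumbel_sample` helper added above is the Gumbel-max trick: perturb the temperature-scaled logits with Gumbel noise and take the argmax, which draws an index from the corresponding softmax distribution. Below is a small self-contained sketch that mirrors the helper (the body of `log` is an assumption here, since only its signature appears in the hunk) and checks the empirical sample frequencies against the softmax probabilities:

```python
import torch

def log(t, eps = 1e-20):
    # assumed clamp-then-log, matching the style used elsewhere in the library
    return t.clamp_min(eps).log()

def gumbel_noise(t):
    return -log(-log(torch.rand_like(t)))

def gumbel_sample(logits, temperature = 1., eps = 1e-6):
    # Gumbel-max: argmax over noised, temperature-scaled logits ~ Categorical(softmax(logits / T))
    noise = gumbel_noise(logits)
    return ((logits / max(temperature, eps)) + noise).argmax(dim = -1)

logits = torch.tensor([2.0, 0.5, 0.0, -1.0])

samples = gumbel_sample(logits.expand(100_000, -1))
freqs = torch.bincount(samples, minlength = logits.numel()) / samples.numel()

print(freqs)                     # empirical frequencies from Gumbel-max sampling
print(logits.softmax(dim = -1))  # target categorical probabilities
```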
x_transformers-2.11.0/x_transformers/free_transformer.py (new file)

@@ -0,0 +1,330 @@
+from __future__ import annotations
+
+# https://arxiv.org/abs/2510.17558
+# François Fleuret
+# https://www.youtube.com/watch?v=Nao16-6l6dQ
+
+import math
+
+import torch
+from torch import nn, Tensor, is_tensor, tensor, arange
+import torch.nn.functional as F
+from torch.nn import Module, ModuleList
+
+from x_transformers.x_transformers import (
+    Encoder,
+    Decoder,
+    TransformerWrapper
+)
+
+from x_transformers.autoregressive_wrapper import (
+    gumbel_sample,
+    top_p,
+    top_k
+)
+
+from einops.layers.torch import Rearrange, Reduce
+from einops import rearrange, reduce, repeat, einsum, pack, unpack
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def log(t, eps = 1e-20):
+    return t.clamp_min(eps).log()
+
+def pack_with_inverse(t, pattern):
+    packed, ps = pack([t], pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        unpacked, = unpack(out, ps, inv_pattern)
+        return unpacked
+
+    return packed, inverse
+
+# binary mapper
+
+NAT = math.log(2)
+
+def binary_entropy(logits):
+    prob = logits.sigmoid()
+    not_prob = 1. - prob
+    return -(prob * F.logsigmoid(logits) + not_prob * F.logsigmoid(-logits)).sum(dim = -1)
+
+class BinaryMapper(Module):
+    def __init__(
+        self,
+        bits = 1,
+        kl_loss_threshold = NAT # 1 bit
+    ):
+        super().__init__()
+
+        self.bits = bits
+        self.num_codes = 2 ** bits
+        self.kl_loss_threshold = kl_loss_threshold
+
+        power_two = 2 ** arange(bits)
+        codes = (arange(self.num_codes)[:, None].bitwise_and(power_two) != 0).byte().bool()
+
+        self.register_buffer('power_two', power_two, persistent = False)
+        self.register_buffer('codes', codes, persistent = False)
+
+    def forward(
+        self,
+        logits,
+        temperature = 1.,
+        straight_through = None
+    ):
+        straight_through = default(straight_through, self.training)
+
+        assert logits.shape[-1] == self.bits, f'logits must have a last dimension of {self.bits}'
+
+        # temperature and prob for sampling
+
+        prob_for_sample = (logits / temperature).sigmoid()
+
+        # sampling
+
+        sampled_bits = (torch.rand_like(logits) <= prob_for_sample).long()
+        indices = (self.power_two * sampled_bits).sum(dim = -1)
+
+        one_hot = F.one_hot(indices, self.num_codes).float()
+
+        # return hard one hot if not training or overridden
+
+        if not straight_through:
+            return one_hot
+
+        # calculate negative entropy
+
+        kl_div = self.bits * NAT - binary_entropy(logits)
+        aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
+
+        # get the soft G for the gradients and do a straight through
+
+        soft_G = (
+            einsum(F.logsigmoid(logits), self.codes.float(), '... bits, codes bits -> ... codes') +
+            einsum(F.logsigmoid(-logits), (~self.codes).float(), '... bits, codes bits -> ... codes')
+        ).exp()
+
+        # straight through
+
+        one_hot = one_hot + soft_G - soft_G.detach()
+
+        return one_hot, aux_kl_loss
+
+# classes
+
+class FreeTransformer(Module):
+    def __init__(
+        self,
+        *,
+        num_tokens,
+        dim,
+        dec_head_depth,
+        dec_tail_depth,
+        enc_depth,
+        max_seq_len,
+        dim_latent = None,
+        attn_dim_head = 64,
+        heads = 8,
+        latent_bits = 16,
+        kl_loss_threshold = NAT,
+        binary_mapper_kwargs: dict = dict(),
+        enc_kwargs: dict = dict(),
+        dec_kwargs: dict = dict(),
+        kl_loss_weight = 1.,
+        pad_id = -1,
+        encoder: Module | None = None,
+        **kwargs
+    ):
+        super().__init__()
+        dim_latent = default(dim_latent, dim)
+
+        self.token_emb = nn.Embedding(num_tokens, dim)
+
+        self.token_unembed = nn.Linear(dim, num_tokens, bias = False)
+
+        if not exists(encoder):
+            encoder = Encoder(
+                dim = dim,
+                depth = enc_depth,
+                attn_dim_head = attn_dim_head,
+                heads = heads,
+                **kwargs,
+                **enc_kwargs
+            )
+
+        self.encoder = encoder
+
+        self.to_latent_bit_logits = nn.Sequential(
+            Reduce('b n d -> b d', 'mean'),
+            nn.Linear(dim, latent_bits, bias = False),
+        )
+
+        self.binary_mapper = BinaryMapper(
+            latent_bits,
+            kl_loss_threshold,
+            **binary_mapper_kwargs
+        )
+
+        self.from_latent_to_condition = nn.Sequential(
+            nn.Linear(2 ** latent_bits, dim, bias = False),
+            Rearrange('b d -> b 1 d')
+        )
+
+        self.decoder_head = Decoder(
+            dim = dim,
+            depth = dec_head_depth,
+            attn_dim_head = attn_dim_head,
+            heads = heads,
+            pre_norm_has_final_norm = False,
+            **kwargs,
+            **dec_kwargs
+        )
+
+        self.decoder_tail = Decoder(
+            dim = dim,
+            depth = dec_tail_depth,
+            attn_dim_head = attn_dim_head,
+            heads = heads,
+            pre_norm_has_final_norm = True,
+            **kwargs,
+            **dec_kwargs
+        )
+
+        self.pad_id = pad_id
+
+        self.kl_loss_weight = kl_loss_weight
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def encode_to_latents(
+        self,
+        seq,
+        mask = None,
+        return_kl_loss = False
+    ):
+        pooled = self.encoder(seq, mask = mask)
+
+        bit_logits = self.to_latent_bit_logits(pooled)
+
+        one_hot_latents, kl_loss = self.binary_mapper(bit_logits, straight_through = True)
+
+        if not return_kl_loss:
+            return one_hot_latents
+
+        return one_hot_latents, kl_loss
+
+    @torch.no_grad()
+    def generate(
+        self,
+        prompts,
+        seq_len,
+        latents = None,
+        filter_logits_fn = top_p,
+        logit_filter_kwargs: dict = dict(thres = 0.9)
+    ):
+        prompts, inverse_pack = pack_with_inverse(prompts, '* n')
+
+        batch = prompts.shape[0]
+
+        # prepend embeds
+
+        condition = None
+        if exists(latents):
+            if not is_tensor(latents):
+                latents = tensor(latents, device = self.device)
+
+            if latents.ndim == 1: # repeat latents
+                latents = repeat(latents, 'd -> b d', b = batch)
+
+            condition = self.from_latent_to_condition(latents)
+
+        # generated
+
+        prompt_len = prompts.shape[-1]
+
+        generated = prompts
+
+        tokens = self.token_emb(generated)
+
+        for _ in range(max(0, seq_len - prompt_len)):
+
+            head_embed = self.decoder_head(tokens)
+
+            if exists(condition):
+                head_embed = head_embed + condition
+
+            tail_embed = self.decoder_tail(head_embed)
+
+            tail_embed = tail_embed[:, -1]
+
+            logits = self.token_unembed(tail_embed)
+
+            logits = filter_logits_fn(logits, **logit_filter_kwargs)
+
+            sampled = gumbel_sample(logits)
+
+            generated, _ = pack((generated, sampled), 'b *')
+            tokens, _ = pack((tokens, self.token_emb(sampled)), 'b * d')
+
+        return inverse_pack(generated)
+
+    def forward(
+        self,
+        seq,
+        return_all_losses = False
+    ):
+        batch, device = seq.shape[0], seq.device
+
+        seq, labels = seq[:, :-1], seq[:, 1:]
+
+        encoder_mask = seq != self.pad_id
+
+        tokens = self.token_emb(seq)
+
+        # decoder head
+
+        tokens = self.decoder_head(tokens)
+
+        # get latent Z
+
+        latents, kl_loss = self.encode_to_latents(tokens, mask = encoder_mask, return_kl_loss = True)
+
+        condition = self.from_latent_to_condition(latents)
+
+        # decoder tail
+
+        tokens = self.decoder_tail(tokens)
+
+        # cross entropy loss
+
+        logits = self.token_unembed(tokens)
+
+        ar_loss = F.cross_entropy(
+            rearrange(logits, 'b n l -> b l n'),
+            labels,
+            ignore_index = self.pad_id
+        )
+
+        # return losses
+
+        total_loss = (
+            ar_loss +
+            kl_loss * self.kl_loss_weight
+        )
+
+        if not return_all_losses:
+            return total_loss
+
+        losses = (ar_loss, kl_loss)
+
+        return total_loss, losses
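
For orientation, the snippet below condenses `train_free.py` and the `FreeTransformer` module above into a minimal usage sketch: one training step (autoregressive cross entropy plus the KL auxiliary loss from the binary latent mapper) and one generation call conditioned on a fixed one-hot latent. It assumes `x-transformers>=2.11.0`; the dimensions, depths, and latent size here are shrunk for illustration and are not the settings used in the training script.

```python
import torch
import torch.nn.functional as F

from x_transformers.free_transformer import FreeTransformer

LATENT_BITS = 4  # illustrative; train_free.py uses 8

model = FreeTransformer(
    num_tokens = 256,
    max_seq_len = 128,
    dim = 64,
    heads = 4,
    enc_depth = 1,
    dec_head_depth = 1,
    dec_tail_depth = 1,
    latent_bits = LATENT_BITS
)

# one training step: total loss = autoregressive loss + weighted KL loss

seq = torch.randint(0, 256, (2, 128))

loss, (ar_loss, kl_loss) = model(seq, return_all_losses = True)
loss.backward()

# generation, conditioned on a fixed one-hot latent code as in train_free.py

latent_index = torch.randint(0, 2 ** LATENT_BITS, ())
latents = F.one_hot(latent_index, 2 ** LATENT_BITS).float()

prompts = seq[:, :16]

sampled = model.generate(
    prompts = prompts,
    seq_len = 64,
    latents = latents
)

print(sampled.shape)  # (2, 64) - prompts followed by the sampled continuation
```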