x-transformers 2.0.3.tar.gz → 2.0.5.tar.gz

Files changed (58)
  1. {x_transformers-2.0.3 → x_transformers-2.0.5}/PKG-INFO +1 -1
  2. x_transformers-2.0.5/data/README.md +3 -0
  3. x_transformers-2.0.5/data/enwik8.gz +0 -0
  4. {x_transformers-2.0.3 → x_transformers-2.0.5}/pyproject.toml +1 -1
  5. x_transformers-2.0.5/train_length_extrapolate.py +137 -0
  6. {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/x_transformers.py +5 -6
  7. {x_transformers-2.0.3 → x_transformers-2.0.5}/.github/FUNDING.yml +0 -0
  8. {x_transformers-2.0.3 → x_transformers-2.0.5}/.github/workflows/python-publish.yml +0 -0
  9. {x_transformers-2.0.3 → x_transformers-2.0.5}/.github/workflows/python-test.yaml +0 -0
  10. {x_transformers-2.0.3 → x_transformers-2.0.5}/.gitignore +0 -0
  11. {x_transformers-2.0.3 → x_transformers-2.0.5}/LICENSE +0 -0
  12. {x_transformers-2.0.3 → x_transformers-2.0.5}/README.md +0 -0
  13. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/all-attention.png +0 -0
  14. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/deepnorm.png +0 -0
  17. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/fcm.png +0 -0
  23. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/ffglu.png +0 -0
  24. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/flash-attention.png +0 -0
  25. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/gate_values.png +0 -0
  26. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/gating.png +0 -0
  27. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/macaron-1.png +0 -0
  29. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/macaron-2.png +0 -0
  30. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/normformer.png +0 -0
  32. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/pia.png +0 -0
  33. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/resi_dual.png +0 -0
  35. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/residual_attn.png +0 -0
  36. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/rezero.png +0 -0
  37. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/rotary.png +0 -0
  38. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/sandwich.png +0 -0
  40. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/scalenorm.png +0 -0
  42. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/talking-heads.png +0 -0
  43. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/topk-attention.png +0 -0
  44. {x_transformers-2.0.3 → x_transformers-2.0.5}/images/xval.png +0 -0
  45. {x_transformers-2.0.3 → x_transformers-2.0.5}/tests/test_x_transformers.py +0 -0
  46. {x_transformers-2.0.3 → x_transformers-2.0.5}/train_copy.py +0 -0
  47. {x_transformers-2.0.3 → x_transformers-2.0.5}/train_enwik8.py +0 -0
  48. {x_transformers-2.0.3 → x_transformers-2.0.5}/train_parity.py +0 -0
  49. {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/__init__.py +0 -0
  50. {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/attend.py +0 -0
  51. {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/autoregressive_wrapper.py +0 -0
  52. {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/continuous.py +0 -0
  53. {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/dpo.py +0 -0
  54. {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/multi_input.py +0 -0
  55. {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/neo_mlp.py +0 -0
  56. {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/nonautoregressive_wrapper.py +0 -0
  57. {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  58. {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/xval.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.0.3
+ Version: 2.0.5
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -0,0 +1,3 @@
+ # Data source
+
+ The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
data/enwik8.gz: binary file (diff not shown)
@@ -1,6 +1,6 @@
  [project]
  name = "x-transformers"
- version = "2.0.3"
+ version = "2.0.5"
  description = "X-Transformers"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -0,0 +1,137 @@
+ from x_transformers import TransformerWrapper, Decoder
+ from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+
+ import random
+ import tqdm
+ import gzip
+ import numpy as np
+ import torch
+ import torch.optim as optim
+ from torch.nn import functional as F
+ from torch.utils.data import DataLoader, Dataset
+
+ # constants
+
+ NUM_BATCHES = int(1e5)
+ BATCH_SIZE = 4
+ GRADIENT_ACCUMULATE_EVERY = 4
+ LEARNING_RATE = 1e-4
+ GENERATE_EVERY = 500
+ GENERATE_LENGTH = 256
+ SEQ_LEN = 256
+
+ VALIDATE_EVERY = 100
+ VALIDATE_SEQ_LENS = (256, 512, 1024, 2048, 4096)
+
+ # helpers
+
+ def cycle(loader):
+     while True:
+         for data in loader:
+             yield data
+
+ def decode_token(token):
+     return str(chr(max(32, token)))
+
+ def decode_tokens(tokens):
+     return ''.join(list(map(decode_token, tokens)))
+
+ # instantiate GPT-like decoder model
+
+ model = TransformerWrapper(
+     num_tokens = 256,
+     max_seq_len = SEQ_LEN,
+     use_abs_pos_emb = False,
+     attn_layers = Decoder(
+         dim = 512,
+         depth = 6,
+         heads = 8,
+         dynamic_pos_bias = True,
+     )
+ )
+
+ model = AutoregressiveWrapper(model)
+ model.cuda()
+
+ # prepare enwik8 data
+
+ with gzip.open('./data/enwik8.gz') as file:
+     data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
+     train_x, valid_x = np.split(data, [int(90e6)])
+     data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
+
+ class TextSamplerDataset(Dataset):
+     def __init__(self, data, seq_len):
+         super().__init__()
+         self.data = data
+         self.seq_len = seq_len
+
+     def __getitem__(self, index):
+         rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
+         full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+         return full_seq.cuda()
+
+     def __len__(self):
+         return self.data.size(0) // self.seq_len
+
+ train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+ train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
+
+ val_dataset_generate = TextSamplerDataset(data_val, SEQ_LEN)
+
+ # validation loaders with different sequence lengths
+
+ val_loaders = dict()
+
+ for valid_seq_len in VALIDATE_SEQ_LENS:
+     val_dataset = TextSamplerDataset(data_val, valid_seq_len)
+     val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
+
+     val_loaders[valid_seq_len] = val_loader
+
+ # optimizer
+
+ optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+
+ # training
+
+ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+     model.train()
+
+     for __ in range(GRADIENT_ACCUMULATE_EVERY):
+         loss = model(next(train_loader))
+         (loss / GRADIENT_ACCUMULATE_EVERY).backward()
+
+     print(f'training loss: {loss.item()}')
+
+     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+     optim.step()
+     optim.zero_grad()
+
+     if i % VALIDATE_EVERY == 0:
+         print(f'validation losses:\n')
+
+         model.eval()
+         with torch.no_grad():
+             for valid_seq_len in VALIDATE_SEQ_LENS:
+                 val_loader = val_loaders[valid_seq_len]
+
+                 loss = model(next(val_loader))
+                 print(f'[{valid_seq_len}]:\t {loss.item()}')
+
+         print('\n')
+
+     if i % GENERATE_EVERY == 0:
+         model.eval()
+         inp = random.choice(val_dataset_generate)[:-1]
+         prime = decode_tokens(inp)
+         print(f'%s \n\n %s', (prime, '*' * 100))
+
+         sample = model.generate(
+             prompts = inp,
+             seq_len = GENERATE_LENGTH,
+             cache_kv = True
+         )
+
+         output_str = decode_tokens(sample)
+         print(f'{output_str}\n\n')
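The point of the new train_length_extrapolate.py script is that the decoder is trained at SEQ_LEN = 256 but validated at lengths up to 4096, which only works because it uses dynamic position bias and disables the absolute positional embedding. Below is a minimal, hedged sketch of that evaluation idea outside the full training loop; the model arguments mirror the script above, but running on CPU with random tokens is purely illustrative and not part of the release.

# Sketch: a decoder configured for length 256 can still be scored at longer lengths,
# because dynamic_pos_bias builds relative position biases on the fly and
# use_abs_pos_emb = False removes the fixed-size positional embedding table.
# Illustrative only: random tokens on CPU stand in for the enwik8 pipeline above.

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

model = AutoregressiveWrapper(TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 256,            # training length
    use_abs_pos_emb = False,      # no absolute positions to outgrow
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        dynamic_pos_bias = True   # relative bias generalizes to unseen lengths
    )
))

for seq_len in (256, 512, 1024, 2048, 4096):
    tokens = torch.randint(0, 256, (1, seq_len + 1))   # +1 for the shifted target
    loss = model(tokens)                               # autoregressive cross-entropy
    print(seq_len, loss.item())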
@@ -449,17 +449,16 @@ class DynamicPositionBias(Module):
          return next(self.parameters()).device
 
      def forward(self, i, j):
-         assert i == j
          n, device = j, self.device
 
          # get the (n x n) matrix of distances
-         seq_arange = arange(n, device = device)
-         context_arange = arange(n, device = device)
+         seq_arange = arange(j - i, j, device = device)
+         context_arange = arange(j, device = device)
          indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
-         indices += (n - 1)
+         indices += (j - 1)
 
          # input to continuous positions MLP
-         pos = arange(-n + 1, n, device = device).float()
+         pos = arange(-j + 1, j, device = device).float()
          pos = rearrange(pos, '... -> ... 1')
 
          if self.log_distance:
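Removing the `assert i == j` lets DynamicPositionBias produce a bias for i queries attending to j keys, for example a single new query against j cached keys during incremental decoding. A small standalone sketch of the new index arithmetic, written in plain torch just for illustration (the helper name `relative_indices` is not part of the library):

# Sketch of the relative-distance indexing after the change, assuming the i queries
# are the *last* i of the j positions (KV-cache style decoding).
import torch

def relative_indices(i, j):
    seq_arange = torch.arange(j - i, j)          # positions of the i query tokens
    context_arange = torch.arange(j)             # positions of all j key tokens
    indices = seq_arange[:, None] - context_arange[None, :]   # (i, j) signed distances
    indices += (j - 1)                           # shift into [0, 2j - 2] for the bias lookup
    return indices

print(relative_indices(4, 4))   # old behaviour: full (n, n) matrix when i == j
print(relative_indices(1, 4))   # new: one query row, distances 3, 2, 1, 0 before shifting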
@@ -1282,7 +1281,7 @@ class Attention(Module):
              dim_kv_input = dim_latent_kv
 
          if exists(latent_rope_subheads):
-             assert not exists(rotate_num_heads)
+             assert not exists(rotate_num_heads), '`rotate_num_heads` cannot be set when multi-latent attention is being used'
              rotate_num_heads = latent_rope_subheads
 
              k_dim = dim_head * (kv_heads - latent_rope_subheads)