titans-pytorch 0.0.32__tar.gz → 0.0.35__tar.gz


Potentially problematic release: this version of titans-pytorch might be problematic.

Files changed (20)
  1. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/PKG-INFO +1 -1
  2. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/pyproject.toml +1 -1
  3. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/titans_pytorch/__init__.py +2 -0
  4. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/titans_pytorch/mac_transformer.py +25 -4
  5. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/train.py +0 -3
  6. titans_pytorch-0.0.35/train_mac.py +129 -0
  7. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/.github/workflows/python-publish.yml +0 -0
  8. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/.github/workflows/test.yaml +0 -0
  9. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/.gitignore +0 -0
  10. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/LICENSE +0 -0
  11. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/README.md +0 -0
  12. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/data/README.md +0 -0
  13. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/data/enwik8.gz +0 -0
  14. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/fig1.png +0 -0
  15. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/fig2.png +0 -0
  16. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/requirements.txt +0 -0
  17. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/tests/test_titans.py +0 -0
  18. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/titans_pytorch/associative_scan.py +0 -0
  19. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/titans_pytorch/titans.py +0 -0
  20. {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/titans_pytorch/titans_attn_memory.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.32
+Version: 0.0.35
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.32"
+version = "0.0.35"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

titans_pytorch/__init__.py
@@ -2,3 +2,5 @@ from titans_pytorch.titans import (
     NeuralMemory,
     MemoryMLP,
 )
+
+from titans_pytorch.mac_transformer import MemoryAsContextTransformer
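
With this re-export, the new transformer can be imported from the package root as well as from its module. A quick sketch:

    # both imports resolve to the same class after the re-export above
    from titans_pytorch import MemoryAsContextTransformer
    from titans_pytorch.mac_transformer import MemoryAsContextTransformer as MAC

    assert MemoryAsContextTransformer is MAC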

titans_pytorch/mac_transformer.py
@@ -50,13 +50,20 @@ def pad_and_segment_with_inverse(seq, segment_len):
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)

     padding = next_seq_len_mult - seq_len
-    seq = F.pad(seq, (0, 0, 0, padding))
+    needs_pad = padding > 0
+
+    if needs_pad:
+        seq = F.pad(seq, (0, 0, 0, padding))

     seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)

     def inverse(out):
         out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
-        return out[:, :-padding]
+
+        if needs_pad:
+            out = out[:, :-padding]
+
+        return out

     return seq, inverse
@@ -226,7 +233,14 @@ class MemoryAsContextTransformer(Module):

         self.to_logits = LinearNoBias(dim, num_tokens)

-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        return_loss = False
+    ):
+
+        if return_loss:
+            x, labels = x[:, :-1], x[:, 1:]

         # math

@@ -262,6 +276,7 @@ class MemoryAsContextTransformer(Module):

             if exists(maybe_neural_mem):
                 batch_streams = x.shape[0]
+
                 x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)

                 longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
@@ -277,6 +292,7 @@ class MemoryAsContextTransformer(Module):
                 x = inverse_segment(x)

             x = attn(x)
+
             x = ff(x)

         x = self.reduce_streams(x)
@@ -293,4 +309,9 @@ class MemoryAsContextTransformer(Module):

         x = self.norm(x)

-        return self.to_logits(x)
+        logits = self.to_logits(x)
+
+        if not return_loss:
+            return logits
+
+        return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
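
Taken together with the new train_mac.py below, a minimal sketch of how the return_loss flag is meant to be used; the constructor arguments are copied from that script, and the comments on shapes and behaviour follow from the diff above rather than from documented API:

    import torch
    from titans_pytorch import MemoryAsContextTransformer

    # configuration copied from train_mac.py below
    model = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 384,
        depth = 8,
        segment_len = 64,
        num_persist_mem_tokens = 16,
        num_longterm_mem_tokens = 16,
        neural_memory_layers = (3, 4),
        neural_memory_kwargs = dict(
            default_mlp_kwargs = dict(depth = 2)
        )
    ).cuda()

    seq = torch.randint(0, 256, (4, 512)).cuda()   # batch of byte-level token ids

    logits = model(seq)                     # per-position token logits
    loss = model(seq, return_loss = True)   # shifts inputs/labels internally, returns cross entropy
    loss.backward()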

train.py
@@ -63,11 +63,8 @@ def decode_tokens(tokens):
 titans_neural_memory = NeuralMemory(
     dim = 384,
     chunk_size = 4,
-    pre_rmsnorm = True,
-    post_rmsnorm = True,
     dim_head = 64,
     heads = 4,
-    max_grad_norm = 1.,
     use_accelerated_scan = True,
     default_mlp_kwargs = dict(
         depth = NEURAL_MEMORY_DEPTH

train_mac.py (new file)
@@ -0,0 +1,129 @@
+import random
+import tqdm
+import gzip
+import numpy as np
+
+import torch
+from torch import nn
+from torch.optim import Adam
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+from titans_pytorch.mac_transformer import MemoryAsContextTransformer
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 4
+GRADIENT_ACCUMULATE_EVERY = 4
+LEARNING_RATE = 2e-4
+VALIDATE_EVERY = 100
+GENERATE_EVERY = 500
+GENERATE_LENGTH = 512
+SHOULD_GENERATE = False
+SEQ_LEN = 512
+
+PROJECT_NAME = 'titans-mac-transformer'
+WANDB_ONLINE = False # turn this on to pipe experiment to cloud
+GLOBAL_LAYERS = (2, 4)
+NEURAL_MEMORY_DEPTH = 2
+WINDOW_SIZE = 64
+RUN_NAME = 'mac'
+
+# wandb experiment tracker
+
+import wandb
+wandb.init(project = PROJECT_NAME, mode = 'disabled' if not WANDB_ONLINE else 'online')
+wandb.run.name = RUN_NAME
+wandb.run.save()
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return ''.join(list(map(decode_token, tokens)))
+
+# instantiate memory-as-context transformer
+
+model = MemoryAsContextTransformer(
+    num_tokens = 256,
+    dim = 384,
+    depth = 8,
+    segment_len = WINDOW_SIZE,
+    num_persist_mem_tokens = 16,
+    num_longterm_mem_tokens = 16,
+    neural_memory_layers = (3, 4),
+    neural_memory_kwargs = dict(
+        default_mlp_kwargs = dict(
+            depth = NEURAL_MEMORY_DEPTH
+        )
+    )
+).cuda()
+
+# prepare enwik8 data
+
+with gzip.open('./data/enwik8.gz') as file:
+    data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
+    data_train, data_val = np.split(data, [int(90e6)])
+    data_train, data_val = map(torch.from_numpy, (data_train, data_val))
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
+        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
+val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
+
+# optimizer
+
+optim = Adam(model.parameters(), lr=LEARNING_RATE)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss = model(next(train_loader), return_loss = True)
+        loss.backward()
+
+    print(f'training loss: {loss.item()}')
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+    wandb.log(dict(loss = loss.item()))
+
+    if i % VALIDATE_EVERY == 0:
+        model.eval()
+        with torch.no_grad():
+            loss = model(next(val_loader), return_loss = True)
+            print(f'validation loss: {loss.item()}')
+
+    if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset)[:-1]
+        prime = decode_tokens(inp)
+        print(f'%s \n\n %s', (prime, '*' * 100))
+
+        sample = model.generate(inp[None, ...], GENERATE_LENGTH, use_kv_cache = False)
+        output_str = decode_tokens(sample[0])
+        print(output_str)