titans-pytorch 0.0.6__tar.gz → 0.0.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.6
+Version: 0.0.7
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,3 @@
+# Data source
+
+The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
Binary file
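
The binary file added above is the compressed enwik8 data that the new training script reads from ./data/enwik8.gz. Below is a minimal sketch for producing that file; the zip URL is an assumption (a commonly used mirror), since the data/README.md added above only points to the Hutter prize page.

# a minimal sketch, not part of the package: fetch enwik8 and write ./data/enwik8.gz
# the zip URL is an assumed mirror; the package README only gives http://prize.hutter1.net/

import gzip
import os
import urllib.request
import zipfile

ENWIK8_ZIP_URL = 'http://mattmahoney.net/dc/enwik8.zip'  # assumed mirror of the Hutter prize data

os.makedirs('./data', exist_ok = True)
urllib.request.urlretrieve(ENWIK8_ZIP_URL, './data/enwik8.zip')

# re-compress as gzip, since the training script opens ./data/enwik8.gz
with zipfile.ZipFile('./data/enwik8.zip') as zf, zf.open('enwik8') as raw, gzip.open('./data/enwik8.gz', 'wb') as out:
    out.write(raw.read())
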
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.6"
+version = "0.0.7"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -91,10 +91,14 @@ class NeuralMemory(Module):
         dim,
         chunk_size = 1,
         model: Module | None = None,
-        store_memory_loss_fn: Callable = default_loss_fn
+        store_memory_loss_fn: Callable = default_loss_fn,
+        pre_rmsnorm = False
     ):
         super().__init__()
 
+        self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
+        self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
+
         if not exists(model):
             model = MLP(dim, depth = 4)
 
@@ -161,6 +165,8 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
     ):
 
+        seq = self.store_norm(seq)
+
         # curtail sequence by multiple of the chunk size
         # only a complete chunk of the sequence provides the memory for the next chunk
 
@@ -244,6 +250,8 @@ class NeuralMemory(Module):
         chunk_size = self.chunk_size
         batch, seq_len = seq.shape[:2]
 
+        seq = self.retrieve_norm(seq)
+
         assert seq_len >= chunk_size
 
         seq = seq[:, (chunk_size - 1):]
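
The three hunks above add an optional pre_rmsnorm flag to NeuralMemory: when enabled, the input sequence passes through nn.RMSNorm(dim) before both the store path and the retrieve path. A minimal usage sketch follows; the direct call on a (batch, seq, dim) tensor is an assumed forward signature not confirmed by this diff, and the sequence length must be at least chunk_size, as the retrieve path asserts.

# a minimal sketch of the new pre_rmsnorm option; mem(seq) is an assumed call signature

import torch
from titans_pytorch.titans import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    pre_rmsnorm = True   # RMSNorm applied to the input of both the store and retrieve paths
)

seq = torch.randn(2, 1024, 384)   # (batch, seq, dim), with seq_len >= chunk_size
retrieved = mem(seq)              # assumed to return the retrieved memories for the sequence
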
@@ -0,0 +1,108 @@
+import random
+import tqdm
+import gzip
+import numpy as np
+
+import torch
+from torch.optim import Adam
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+from local_attention import LocalTransformer
+
+from titans_pytorch.titans import NeuralMemory
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 4
+GRADIENT_ACCUMULATE_EVERY = 4
+LEARNING_RATE = 2e-4
+VALIDATE_EVERY = 100
+GENERATE_EVERY = 500
+GENERATE_LENGTH = 512
+SEQ_LEN = 512
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return ''.join(list(map(decode_token, tokens)))
+
+# instantiate GPT-like decoder model
+
+model = LocalTransformer(
+    num_tokens = 256,
+    dim = 512,
+    depth = 8,
+    causal = True,
+    local_attn_window_size = 64,
+    max_seq_len = SEQ_LEN
+).cuda()
+
+# prepare enwik8 data
+
+with gzip.open('./data/enwik8.gz') as file:
+    data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
+    data_train, data_val = np.split(data, [int(90e6)])
+    data_train, data_val = map(torch.from_numpy, (data_train, data_val))
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
+        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
+val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
+
+# optimizer
+
+optim = Adam(model.parameters(), lr=LEARNING_RATE)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss = model(next(train_loader), return_loss = True)
+        loss.backward()
+
+    print(f'training loss: {loss.item()}')
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+
+    if i % VALIDATE_EVERY == 0:
+        model.eval()
+        with torch.no_grad():
+            loss = model(next(val_loader), return_loss = True)
+            print(f'validation loss: {loss.item()}')
+
+    if i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset)[:-1]
+        prime = decode_tokens(inp)
+        print(f'%s \n\n %s', (prime, '*' * 100))
+
+        sample = model.generate(inp[None, ...], GENERATE_LENGTH, use_kv_cache = False)
+        output_str = decode_tokens(sample[0])
+        print(output_str)
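
As a follow-on sketch to the training script above (assuming model, GENERATE_LENGTH and decode_tokens from that script are already in scope), generation can also be primed from an arbitrary string, since the model operates directly on byte values (num_tokens = 256):

# illustrative only; reuses the same model.generate call shown in the generation step above

prompt = 'The quick brown fox'
inp = torch.tensor(list(prompt.encode('utf-8')), dtype = torch.long).cuda()

sample = model.generate(inp[None, ...], GENERATE_LENGTH, use_kv_cache = False)
print(decode_tokens(sample[0]))
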
File without changes
File without changes
File without changes
File without changes