titans-pytorch 0.0.1__tar.gz → 0.0.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.0.1
+ Version: 0.0.8
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -39,6 +39,8 @@ Requires-Dist: einx>=0.3.0
  Requires-Dist: tensordict>=0.6.2
  Requires-Dist: torch>=2.3
  Provides-Extra: examples
+ Requires-Dist: local-attention>=1.10.0; extra == 'examples'
+ Requires-Dist: taylor-series-linear-attention; extra == 'examples'
  Provides-Extra: test
  Requires-Dist: pytest; extra == 'test'
  Description-Content-Type: text/markdown
@@ -49,7 +51,7 @@ Description-Content-Type: text/markdown

  ## Titans - Pytorch (wip)

- Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module.
+ Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.

  ## Install

@@ -63,13 +65,16 @@ $ pip install titans-pytorch
  import torch
  from titans_pytorch import NeuralMemory

- x = torch.randn(2, 64, 32)
+ mem = NeuralMemory(
+     dim = 384,
+     chunk_size = 64,
+     pre_rmsnorm = True
+ ).cuda()

- mem = NeuralMemory(32)
+ seq = torch.randn(2, 1024, 384).cuda()
+ retrieved = mem(seq)

- out = mem(x)
-
- assert x.shape == out.shape
+ assert seq.shape == retrieved.shape
  ```

  ## Citations
@@ -4,7 +4,7 @@

  ## Titans - Pytorch (wip)

- Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module.
+ Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.

  ## Install

@@ -18,13 +18,16 @@ $ pip install titans-pytorch
  import torch
  from titans_pytorch import NeuralMemory

- x = torch.randn(2, 64, 32)
+ mem = NeuralMemory(
+     dim = 384,
+     chunk_size = 64,
+     pre_rmsnorm = True
+ ).cuda()

- mem = NeuralMemory(32)
+ seq = torch.randn(2, 1024, 384).cuda()
+ retrieved = mem(seq)

- out = mem(x)
-
- assert x.shape == out.shape
+ assert seq.shape == retrieved.shape
  ```

  ## Citations
@@ -0,0 +1,3 @@
+ # Data source
+
+ The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
Binary file
@@ -1,6 +1,6 @@
  [project]
  name = "titans-pytorch"
- version = "0.0.1"
+ version = "0.0.8"
  description = "Titans"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -36,7 +36,10 @@ Homepage = "https://pypi.org/project/titans-pytorch/"
  Repository = "https://github.com/lucidrains/titans-pytorch"

  [project.optional-dependencies]
- examples = []
+ examples = [
+     "local-attention>=1.10.0",
+     "taylor-series-linear-attention"
+ ]
  test = [
      "pytest"
  ]
@@ -1,4 +1,5 @@
  from __future__ import annotations
+ import math
  from functools import partial

  import torch
@@ -11,12 +12,13 @@ from tensordict import TensorDict

  from titans_pytorch.associative_scan import (
      associative_scan,
-     binary_operator
+     binary_operator,
+     pad_at_dim
  )

  import einx
  from einops import rearrange, pack, unpack
- from einops.layers.torch import Rearrange
+ from einops.layers.torch import Rearrange, Reduce

  """
  ein notation:
@@ -41,6 +43,9 @@ def default(v, d):
  def round_down_multiple(seq, mult):
      return seq // mult * mult

+ def round_up_multiple(seq, mult):
+     return math.ceil(seq / mult) * mult
+
  def pack_one_with_inverse(t, pattern):
      packed, packed_shape = pack([t], pattern)

@@ -50,6 +55,10 @@ def pack_one_with_inverse(t, pattern):

      return packed, inverse

+ def softclamp_max(t, max_value):
+     range_value = max_value / 2
+     return ((t / range_value).tanh() * range_value) + range_value
+
  # classes

  class MLP(Module):
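The helper added just above, `softclamp_max`, is a shifted and scaled tanh: it squashes any real-valued input smoothly into the range `[0, max_value]`. Later hunks use it to bound the learned adaptive step size (`max_adaptive_step_size`, default `1e-5`). A standalone sketch of the same formula, for illustration only (not part of the package):

```python
# Illustrative sketch: softclamp_max is a tanh centered at max_value / 2,
# so its outputs never leave [0, max_value].
import torch

def softclamp_max(t, max_value):
    range_value = max_value / 2
    return ((t / range_value).tanh() * range_value) + range_value

x = torch.tensor([-100., -1., 0., 1., 100.])
print(softclamp_max(x, 1e-5))  # every value squashed into [0, 1e-5]
```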
@@ -84,11 +93,17 @@ class NeuralMemory(Module):
      def __init__(
          self,
          dim,
+         chunk_size = 1,
          model: Module | None = None,
-         store_memory_loss_fn: Callable = default_loss_fn
+         store_memory_loss_fn: Callable = default_loss_fn,
+         pre_rmsnorm = False,
+         max_adaptive_step_size = 1e-5
      ):
          super().__init__()

+         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
+         self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
+
          if not exists(model):
              model = MLP(dim, depth = 4)

@@ -98,11 +113,15 @@ class NeuralMemory(Module):

          self.memory_model = model

+         # the chunk size within the paper where adaptive step, momentum, weight decay are shared
+
+         self.chunk_size = chunk_size
+
          # prepare function for per sample gradients from model above, using torch.func

          def forward_and_loss(params, inputs, target):
              pred = functional_call(self.memory_model, params, inputs)
-             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) == v|²
+             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
              return loss

          self.per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))
@@ -119,9 +138,25 @@ class NeuralMemory(Module):
          # learned adaptive learning rate and momentum
          # todo - explore mlp layerwise learned lr / momentum

-         self.to_momentum = LinearNoBias(dim, 1)
-         self.to_adaptive_step = nn.Sequential(LinearNoBias(dim, 1), Rearrange('... 1 -> ...'))
-         self.to_decay_factor = nn.Sequential(LinearNoBias(dim, 1), nn.Sigmoid()) # weight decay factor
+         self.to_momentum = nn.Sequential(
+             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+             LinearNoBias(dim, 1)
+         )
+
+         self.to_adaptive_step = nn.Sequential(
+             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+             LinearNoBias(dim, 1),
+             Rearrange('... 1 -> ...')
+         )
+
+         self.max_adaptive_step_size = max_adaptive_step_size
+
+         # weight decay factor
+
+         self.to_decay_factor = nn.Sequential(
+             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+             LinearNoBias(dim, 1)
+         )

      def init_weights_and_momentum(self):
          params = TensorDict(dict(self.memory_model.named_parameters()))
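The change above prepends `Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)` to each projection, so the learned step size, momentum and weight decay are each predicted once per chunk of `chunk_size` mean-pooled tokens rather than once per token, matching the per-chunk sharing of those quantities described in the paper. A quick shape check of that pooling, written with a concrete three-axis pattern (illustrative only):

```python
# Per-chunk pooling before the linear projections: one vector per chunk of 64 tokens.
import torch
from einops.layers.torch import Reduce

pool = Reduce('b (n c) d -> b n d', 'mean', c = 64)
tokens = torch.randn(2, 1024, 384)   # batch 2, 1024 tokens, dim 384
print(pool(tokens).shape)            # torch.Size([2, 16, 384]) -> 16 chunks per sequence
```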
@@ -137,6 +172,18 @@ class NeuralMemory(Module):
          past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
      ):

+         seq = self.store_norm(seq)
+
+         # curtail sequence by multiple of the chunk size
+         # only a complete chunk of the sequence provides the memory for the next chunk
+
+         seq_len, chunk_size = seq.shape[-2], self.chunk_size
+         round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
+
+         seq = seq[:, :round_down_seq_len]
+
+         # curr weights + past weights, in the case that the initial weights are learned
+
          curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

          past_state = tuple(TensorDict(d) for d in past_state)
@@ -148,16 +195,19 @@ class NeuralMemory(Module):

          batch = seq.shape[0]

-         adaptive_lr = self.to_adaptive_step(seq)
-         adaptive_momentum = self.to_momentum(seq)
+         adaptive_lr = softclamp_max(self.to_adaptive_step(seq), self.max_adaptive_step_size)

-         decay_factor = self.to_decay_factor(seq)
+         adaptive_momentum = self.to_momentum(seq).sigmoid()
+         decay_factor = self.to_decay_factor(seq).sigmoid()

          # keys and values

-         seq = rearrange(seq, 'b n d -> (b n) d')
          keys, values = self.to_keys_values(seq).chunk(2, dim = -1)

+         # take care of chunking
+
+         keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
+
          # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

          grads, aux_store_loss = self.per_sample_grad_and_value_fn(dict(curr_weights), keys, values)
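After the `rearrange` above, each chunk of `chunk_size` keys/values becomes its own sample along the leading `(b n)` dimension, which is exactly what `per_sample_grad_and_value_fn` (a `vmap` of `grad_and_value` with `in_dims = (None, 0, 0)`) maps over while broadcasting the memory weights, yielding one gradient, i.e. one "surprise", per chunk. A minimal sketch of that pattern with a stand-in model (illustrative, not the package's MLP):

```python
# Per-chunk gradients via torch.func: weights broadcast (None), chunks mapped over dim 0.
import torch
from torch import nn
from torch.func import functional_call, grad_and_value, vmap

model = nn.Linear(16, 16)                  # stand-in for the memory model
params = dict(model.named_parameters())

def forward_and_loss(params, inputs, target):
    pred = functional_call(model, params, inputs)
    return ((pred - target) ** 2).mean()   # mse per chunk, as in eq (12)

per_chunk_grad = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))

keys = torch.randn(8, 64, 16)              # (b n) = 8 chunks of 64 tokens each
values = torch.randn(8, 64, 16)
grads, losses = per_chunk_grad(params, keys, values)
print(grads['weight'].shape, losses.shape) # torch.Size([8, 16, 16]) torch.Size([8])
```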
@@ -172,31 +222,24 @@ class NeuralMemory(Module):

          surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))

-         # derive momentum with associative scan - eq (10)
+         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

          next_momentum = TensorDict()
+         updates = TensorDict()

          for param_name, surprise in surprises.items():
              surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')

-             _, momentum = associative_scan(binary_operator, (adaptive_momentum, surprise)) # momentum is S / surprise in the paper
-
-             momentum = inverse_pack(momentum)
-
-             next_momentum[param_name] = momentum
-
-         # use associative scan again for learned forgetting (weight decay) - eq (13)
+             # derive momentum with associative scan - eq (10)

-         updates = TensorDict()
+             _, momentum = associative_scan(binary_operator, (adaptive_momentum, surprise)) # momentum is S / surprise in the paper

-         for param_name, momentum in next_momentum.items():
-             momentum, inverse_pack = pack_one_with_inverse(momentum, 'b n *')
+             # use associative scan again for learned forgetting (weight decay) - eq (13)

              _, update = associative_scan(binary_operator, (1. - decay_factor, momentum)) # momentum is S / surprise in the paper

-             update = inverse_pack(update)
-
-             updates[param_name] = update
+             updates[param_name] = inverse_pack(update)
+             next_momentum[param_name] = inverse_pack(momentum)

          # compute the next weight per batch

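The two `associative_scan` calls in the hunk above evaluate first-order linear recurrences over the chunk dimension: eq (10) accumulates each chunk's surprise into a momentum term gated by `adaptive_momentum`, and eq (13) applies learned forgetting through `1 - decay_factor` to produce the running weight update. A plain-loop reference for the same semantics (names and shapes here are illustrative; the scan computes this in parallel rather than sequentially):

```python
# Reference loop for the gated accumulations performed by the two associative scans.
import torch

def reference_scan(gates, inputs):
    # out[t] = gates[t] * out[t - 1] + inputs[t], starting from zero
    out, acc = [], torch.zeros_like(inputs[0])
    for g, x in zip(gates, inputs):
        acc = g * acc + x
        out.append(acc)
    return torch.stack(out)

num_chunks = 16
surprise = torch.randn(num_chunks, 4)            # -adaptive_lr * grad, per chunk
adaptive_momentum = torch.rand(num_chunks, 1)    # sigmoid-ed momentum gate
decay_factor = torch.rand(num_chunks, 1)         # sigmoid-ed weight decay

momentum = reference_scan(adaptive_momentum, surprise)    # eq (10)
updates = reference_scan(1. - decay_factor, momentum)     # eq (13)
print(momentum.shape, updates.shape)
```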
@@ -204,14 +247,28 @@ class NeuralMemory(Module):

          next_state = (curr_weights + last_update, next_momentum)

-         return updates, next_state, aux_store_loss.mean()
+         return updates, next_state, aux_store_loss.mean() / chunk_size

      def retrieve_memories(
          self,
          seq,
          past_weights: dict[str, Tensor] | None = None,
      ):
-         batch = seq.shape[0]
+         chunk_size = self.chunk_size
+         batch, seq_len = seq.shape[:2]
+
+         seq = self.retrieve_norm(seq)
+
+         assert seq_len >= chunk_size
+
+         seq = seq[:, (chunk_size - 1):]
+         curtailed_seq_len = seq.shape[-2]
+
+         next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)
+
+         padding = next_seq_len - curtailed_seq_len
+
+         seq = pad_at_dim(seq, (0, padding), dim = 1)

          # the parameters of the memory model stores the memories of the key / values
          # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
@@ -231,7 +288,7 @@ class NeuralMemory(Module):
          # fetch values from memory model

          curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
-         queries = rearrange(queries, 'b n d -> (b n) 1 d')
+         queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)

          # forward functional call

@@ -239,7 +296,12 @@ class NeuralMemory(Module):

          # reconstitute batch dimension

-         values = rearrange(values, '(b n) 1 d -> b n d', b = batch)
+         values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+
+         # restore
+
+         values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
+         values = values[:, :-padding]

          return values

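The retrieval changes above keep queries causally aligned with what has already been stored: the sequence is shifted left by `chunk_size - 1`, right-padded up to a chunk multiple for the chunked forward pass, and the retrieved values are then left-padded with `chunk_size - 1` zeros and trimmed back to the original length, so positions in the first incomplete chunk read from an empty (zero) memory. The index bookkeeping, with illustrative numbers:

```python
# Index bookkeeping of retrieve_memories for chunk_size = 64 and a 1024-token sequence.
import math

chunk_size, seq_len = 64, 1024
curtailed = seq_len - (chunk_size - 1)                    # 961 positions after the shift
padded = math.ceil(curtailed / chunk_size) * chunk_size   # 1024 (round_up_multiple)
padding = padded - curtailed                              # 63 tokens of right padding

# after the memory model runs: pad chunk_size - 1 zeros on the left, slice off the
# right padding, and the original 1024 positions are restored
restored = (chunk_size - 1) + padded - padding
assert restored == seq_len
```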
@@ -0,0 +1,132 @@
+ import random
+ import tqdm
+ import gzip
+ import numpy as np
+
+ import torch
+ from torch import nn
+ from torch.optim import Adam
+ from torch.nn import functional as F
+ from torch.utils.data import DataLoader, Dataset
+
+ from local_attention import LocalTransformer
+
+ from taylor_series_linear_attention import TaylorSeriesLinearAttn
+
+ from titans_pytorch.titans import NeuralMemory
+
+ # constants
+
+ NUM_BATCHES = int(1e5)
+ BATCH_SIZE = 4
+ GRADIENT_ACCUMULATE_EVERY = 4
+ LEARNING_RATE = 2e-4
+ VALIDATE_EVERY = 100
+ GENERATE_EVERY = 500
+ GENERATE_LENGTH = 512
+ SHOULD_GENERATE = False
+ SEQ_LEN = 512
+
+ # helpers
+
+ def cycle(loader):
+     while True:
+         for data in loader:
+             yield data
+
+ def decode_token(token):
+     return str(chr(max(32, token)))
+
+ def decode_tokens(tokens):
+     return ''.join(list(map(decode_token, tokens)))
+
+ # instantiate GPT-like decoder model
+
+ titans_neural_memory = NeuralMemory(
+     dim = 384,
+     chunk_size = 64,
+     pre_rmsnorm = True
+ )
+
+ titans_neural_memory = nn.Sequential(
+     titans_neural_memory,
+     nn.RMSNorm(384)
+ )
+
+ linear_attn = TaylorSeriesLinearAttn(
+     dim = 384,
+     dim_head = 16,
+     heads = 16,
+     causal = True
+ )
+
+ model = LocalTransformer(
+     num_tokens = 256,
+     dim = 384,
+     depth = 8,
+     causal = True,
+     local_attn_window_size = 64,
+     max_seq_len = SEQ_LEN,
+     global_attn_layer = titans_neural_memory,
+     layers_insert_global_attn = (4,)
+ ).cuda()
+
+ # prepare enwik8 data
+
+ with gzip.open('./data/enwik8.gz') as file:
+     data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
+     data_train, data_val = np.split(data, [int(90e6)])
+     data_train, data_val = map(torch.from_numpy, (data_train, data_val))
+
+ class TextSamplerDataset(Dataset):
+     def __init__(self, data, seq_len):
+         super().__init__()
+         self.data = data
+         self.seq_len = seq_len
+
+     def __getitem__(self, index):
+         rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
+         full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+         return full_seq.cuda()
+
+     def __len__(self):
+         return self.data.size(0) // self.seq_len
+
+ train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+ val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+ train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
+ val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
+
+ # optimizer
+
+ optim = Adam(model.parameters(), lr=LEARNING_RATE)
+
+ # training
+
+ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+     model.train()
+
+     for __ in range(GRADIENT_ACCUMULATE_EVERY):
+         loss = model(next(train_loader), return_loss = True)
+         loss.backward()
+
+     print(f'training loss: {loss.item()}')
+     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+     optim.step()
+     optim.zero_grad()
+
+     if i % VALIDATE_EVERY == 0:
+         model.eval()
+         with torch.no_grad():
+             loss = model(next(val_loader), return_loss = True)
+             print(f'validation loss: {loss.item()}')
+
+     if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
+         model.eval()
+         inp = random.choice(val_dataset)[:-1]
+         prime = decode_tokens(inp)
+         print(f'%s \n\n %s', (prime, '*' * 100))
+
+         sample = model.generate(inp[None, ...], GENERATE_LENGTH, use_kv_cache = False)
+         output_str = decode_tokens(sample[0])
+         print(output_str)