titans-pytorch 0.0.41__tar.gz → 0.0.43__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.0.41
+ Version: 0.0.43
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -45,8 +45,6 @@ Requires-Dist: tensordict
  Requires-Dist: torch>=2.2
  Requires-Dist: x-transformers
  Provides-Extra: examples
- Requires-Dist: local-attention>=1.10.1; extra == 'examples'
- Requires-Dist: taylor-series-linear-attention; extra == 'examples'
  Requires-Dist: tqdm; extra == 'examples'
  Requires-Dist: wandb; extra == 'examples'
  Provides-Extra: test
@@ -85,22 +83,36 @@ retrieved = mem(seq)
  assert seq.shape == retrieved.shape
  ```

- ## Experiments
+ A transformer with the `MAC` configuration can be used as

- ```bash
- $ pip install .[examples]
+ ```python
+ import torch
+ from titans_pytorch import MemoryAsContextTransformer
+
+ transformer = MemoryAsContextTransformer(
+     num_tokens = 256,
+     dim = 256,
+     depth = 2,
+     segment_len = 128, # local attention window size
+     num_persist_mem_tokens = 4,
+     num_longterm_mem_tokens = 16,
+ )
+
+ token_ids = torch.randint(0, 256, (1, 1023))
+
+ logits = transformer(token_ids) # (1, 1023, 256)
  ```

- For the SOTA linear attention, you will also need to run
+ ## Experiments

  ```bash
- $ pip install -r requirements.txt
+ $ pip install .[examples]
  ```

- Then modify `train.py` and run it to query nature
+ Then modify `train_mac.py` and run it to query nature

  ```bash
- $ python train.py
+ $ python train_mac.py
  ```

  ## Citations
@@ -30,22 +30,36 @@ retrieved = mem(seq)
  assert seq.shape == retrieved.shape
  ```

- ## Experiments
+ A transformer with the `MAC` configuration can be used as

- ```bash
- $ pip install .[examples]
+ ```python
+ import torch
+ from titans_pytorch import MemoryAsContextTransformer
+
+ transformer = MemoryAsContextTransformer(
+     num_tokens = 256,
+     dim = 256,
+     depth = 2,
+     segment_len = 128, # local attention window size
+     num_persist_mem_tokens = 4,
+     num_longterm_mem_tokens = 16,
+ )
+
+ token_ids = torch.randint(0, 256, (1, 1023))
+
+ logits = transformer(token_ids) # (1, 1023, 256)
  ```

- For the SOTA linear attention, you will also need to run
+ ## Experiments

  ```bash
- $ pip install -r requirements.txt
+ $ pip install .[examples]
  ```

- Then modify `train.py` and run it to query nature
+ Then modify `train_mac.py` and run it to query nature

  ```bash
- $ python train.py
+ $ python train_mac.py
  ```

  ## Citations
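The README hunks above add a `MemoryAsContextTransformer` example that ends with logits of shape `(1, 1023, 256)`. As a minimal sketch, not taken from the package, of how those logits could be trained with next-token prediction: only the constructor arguments come from the diff; the shift-by-one targets and the cross-entropy call are assumptions for illustration.

```python
import torch
import torch.nn.functional as F
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
)

seq = torch.randint(0, 256, (1, 1024))
inputs, targets = seq[:, :-1], seq[:, 1:]  # shift by one: predict the next token

logits = transformer(inputs)               # (1, 1023, 256), as in the README example
loss = F.cross_entropy(logits.transpose(1, 2), targets)
loss.backward()
```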
@@ -1,6 +1,6 @@
  [project]
  name = "titans-pytorch"
- version = "0.0.41"
+ version = "0.0.43"
  description = "Titans"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -44,8 +44,6 @@ Repository = "https://github.com/lucidrains/titans-pytorch"
  [project.optional-dependencies]

  examples = [
-     "local-attention>=1.10.1",
-     "taylor-series-linear-attention",
      "tqdm",
      "wandb"
  ]
@@ -1,5 +1,6 @@
  import torch
  import pytest
+ from titans_pytorch import NeuralMemory

  @pytest.mark.parametrize('seq_len', (32, 1024, 77))
  @pytest.mark.parametrize('max_grad_norm', (None, 2.))
@@ -7,9 +8,6 @@ def test_titans(
      seq_len,
      max_grad_norm
  ):
-
-     from titans_pytorch import NeuralMemory
-
      mem = NeuralMemory(
          dim = 384,
          chunk_size = 64,
@@ -22,11 +20,14 @@ def test_titans(
      assert seq.shape == retrieved.shape

  def test_titans_attn_memory():
-     from titans_pytorch.titans_attn_memory import NeuralMemory
+     from titans_pytorch.titans import MemoryAttention

      mem = NeuralMemory(
          dim = 384,
          chunk_size = 64,
+         model = MemoryAttention(
+             dim = 384
+         )
      )

      seq = torch.randn(2, 1024, 384)
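The updated test above swaps the removed `titans_attn_memory` import for the new `MemoryAttention` module in `titans_pytorch.titans` and passes it to `NeuralMemory` through the `model` argument. A standalone sketch of that configuration, assuming only what the tests and README in this diff show:

```python
import torch
from titans_pytorch import NeuralMemory
from titans_pytorch.titans import MemoryAttention

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    model = MemoryAttention(dim = 384)   # replaces the default MemoryMLP
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)

assert seq.shape == retrieved.shape
```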
@@ -0,0 +1,8 @@
+ from titans_pytorch.titans import (
+     NeuralMemory,
+     MemoryMLP,
+ )
+
+ from titans_pytorch.mac_transformer import (
+     MemoryAsContextTransformer
+ )
@@ -324,7 +324,6 @@ class MemoryAsContextTransformer(Module):
            if exists(maybe_neural_mem):
                x = maybe_neural_mem(x)

-
            x, values = attn(x, value_residual = value_residual)

            value_residual = default(value_residual, values)
@@ -27,7 +27,7 @@ n - sequence
  d - feature dimension
  c - intra-chunk
  """
- 7
+
  LinearNoBias = partial(Linear, bias = False)

  # functions
@@ -107,6 +107,44 @@ class MemoryMLP(Module):

          return x

+ # improvised attention as memory module
+
+ class MemoryAttention(Module):
+     def __init__(
+         self,
+         dim
+     ):
+         super().__init__()
+         self.weights = nn.ParameterList([
+             nn.Parameter(torch.randn(dim, dim)), # queries
+             nn.Parameter(torch.randn(dim, dim)), # keys
+             nn.Parameter(torch.randn(dim, dim)), # values
+             nn.Parameter(torch.randn(dim, dim * 2)), # ff w1
+             nn.Parameter(torch.randn(dim * 2, dim)), # ff w2
+         ])
+
+     def forward(self, x):
+
+         assert x.shape[-2] > 1, 'chunk size needs to be greater than 1 for using attention as memory'
+
+         wq, wk, wv, ffw1, ffw2 = self.weights
+
+         q = F.normalize(x @ wq, dim = -1)
+         k = F.normalize(x @ wk, dim = -1)
+         v = x @ wv
+
+         attn_out = F.scaled_dot_product_attention(
+             q, k, v,
+             is_causal = True
+         )
+
+         x = x + attn_out
+
+         h = F.silu(x @ ffw1)
+         out = h @ ffw2
+
+         return out
+
  # main neural memory

  def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
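The hunk above adds `MemoryAttention` as an alternative memory module: causal attention over the chunk, a residual, then a SiLU feedforward. An illustrative shape check, not part of the package:

```python
import torch
from titans_pytorch.titans import MemoryAttention

attn_mem = MemoryAttention(dim = 64)

chunk = torch.randn(2, 16, 64)   # chunk length must be > 1, per the assert in forward
out = attn_mem(chunk)

assert out.shape == chunk.shape  # attention + feedforward preserve (batch, chunk, dim)
```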
@@ -129,7 +167,7 @@ class NeuralMemory(Module):
          post_rmsnorm = True,
          max_grad_norm: float | None = None,
          use_accelerated_scan = False,
-         default_mlp_kwargs: dict = dict(
+         default_model_kwargs: dict = dict(
              depth = 2
          )
      ):
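This hunk, together with the one that follows, renames `default_mlp_kwargs` to `default_model_kwargs`, reflecting that these kwargs build the default memory model only when no `model` is passed. A hedged sketch of the two configuration paths, assuming the `NeuralMemory` signature shown in the tests of this diff:

```python
from titans_pytorch import NeuralMemory, MemoryMLP

# default path: NeuralMemory builds a MemoryMLP internally from default_model_kwargs
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    default_model_kwargs = dict(depth = 2)
)

# explicit path: pass the memory model yourself (here the default MemoryMLP)
mem_explicit = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    model = MemoryMLP(384, depth = 2)
)
```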
@@ -162,7 +200,7 @@
          # memory mlp

          if not exists(model):
-             model = MemoryMLP(dim_head, **default_mlp_kwargs)
+             model = MemoryMLP(dim_head, **default_model_kwargs)

          assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'

@@ -387,11 +425,7 @@
          next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)

          padding = next_seq_len - curtailed_seq_len
-
-         needs_pad = padding > 0
-
-         if needs_pad:
-             seq = pad_at_dim(seq, (0, padding), dim = 1)
+         seq = pad_at_dim(seq, (0, padding), dim = 1)

          # the parameters of the memory model stores the memories of the key / values
          # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
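The change above drops the `needs_pad` branch and always pads the curtailed sequence up to the next multiple of the chunk size. A toy illustration of why the unconditional pad is safe: when the length is already a multiple, the computed padding is zero and the pad is a no-op. Here `round_up_multiple` follows the definition shown elsewhere in this diff, and `F.pad` stands in for the package's `pad_at_dim` helper.

```python
import math
import torch
import torch.nn.functional as F

def round_up_multiple(seq, mult):
    return math.ceil(seq / mult) * mult

chunk_size = 64
seq = torch.randn(1, 128, 384)             # length already a multiple of chunk_size

next_seq_len = round_up_multiple(seq.shape[1], chunk_size)
padding = next_seq_len - seq.shape[1]       # 0 here

padded = F.pad(seq, (0, 0, 0, padding))     # pad the sequence dimension on the right
assert padded.shape == seq.shape            # padding by 0 leaves the tensor unchanged
```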
@@ -443,10 +477,7 @@
          empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
          values = torch.cat((empty_memory_embeds, values), dim = -2)

-         if needs_pad:
-             values = values[:, :-padding]
-
-         return values
+         return values[:, :seq_len]

      def forward(
          self,
@@ -28,9 +28,9 @@ WANDB_ONLINE = False # turn this on to pipe experiment to cloud
  NEURAL_MEMORY_DEPTH = 2
  NUM_PERSIST_MEM = 4
  NUM_LONGTERM_MEM = 4
- NEURAL_MEM_LAYERS = (4,)
+ NEURAL_MEM_LAYERS = (2, 4)
  WINDOW_SIZE = 32
- RUN_NAME = 'mac - 4 longterm mems, layers (4,)'
+ RUN_NAME = f'mac - {NUM_LONGTERM_MEM} longterm mems, layers {NEURAL_MEM_LAYERS}'

  # wandb experiment tracker

@@ -65,7 +65,7 @@ model = MemoryAsContextTransformer(
      neural_memory_kwargs = dict(
          dim_head = 64,
          heads = 4,
-         default_mlp_kwargs = dict(
+         default_model_kwargs = dict(
              depth = NEURAL_MEMORY_DEPTH,
          )
      )
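The `train_mac.py` hunks above change `NEURAL_MEM_LAYERS` to `(2, 4)` and route the renamed `default_model_kwargs` through `neural_memory_kwargs`. A hedged sketch of how these settings might be wired together; the `neural_memory_layers` keyword and the token/width/depth values reused from the README example are assumptions for illustration, not taken from this hunk.

```python
import torch
from titans_pytorch import MemoryAsContextTransformer

NEURAL_MEMORY_DEPTH = 2
NUM_PERSIST_MEM = 4
NUM_LONGTERM_MEM = 4
NEURAL_MEM_LAYERS = (2, 4)
WINDOW_SIZE = 32

model = MemoryAsContextTransformer(
    num_tokens = 256,                          # reused from the README example
    dim = 256,
    depth = 2,
    segment_len = WINDOW_SIZE,
    num_persist_mem_tokens = NUM_PERSIST_MEM,
    num_longterm_mem_tokens = NUM_LONGTERM_MEM,
    neural_memory_layers = NEURAL_MEM_LAYERS,  # assumed keyword, not shown in this diff
    neural_memory_kwargs = dict(
        dim_head = 64,
        heads = 4,
        default_model_kwargs = dict(
            depth = NEURAL_MEMORY_DEPTH,
        )
    )
)

token_ids = torch.randint(0, 256, (1, 1023))
logits = model(token_ids)
```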
@@ -1 +0,0 @@
- pytorch-fast-transformers>=0.4.0
@@ -1,6 +0,0 @@
- from titans_pytorch.titans import (
-     NeuralMemory,
-     MemoryMLP,
- )
-
- from titans_pytorch.mac_transformer import MemoryAsContextTransformer
@@ -1,419 +0,0 @@
- from __future__ import annotations
- import math
- from functools import partial
-
- import torch
- from torch import nn, Tensor
- import torch.nn.functional as F
- from torch.nn import Linear, Module
- from torch.func import functional_call, vmap, grad
-
- from tensordict import TensorDict
-
- from titans_pytorch.associative_scan import (
-     associative_scan,
-     binary_operator,
-     pad_at_dim
- )
-
- import einx
- from einops import rearrange, pack, unpack
- from einops.layers.torch import Rearrange, Reduce
-
- """
- ein notation:
- b - batch
- n - sequence
- d - feature dimension
- c - intra-chunk
- """
-
- # constants
-
- LinearNoBias = partial(Linear, bias = False)
-
- # functions
-
- def exists(v):
-     return v is not None
-
- def default(v, d):
-     return v if exists(v) else d
-
- def round_down_multiple(seq, mult):
-     return seq // mult * mult
-
- def round_up_multiple(seq, mult):
-     return math.ceil(seq / mult) * mult
-
- def pack_one_with_inverse(t, pattern):
-     packed, packed_shape = pack([t], pattern)
-
-     def inverse(out, inv_pattern = None):
-         inv_pattern = default(inv_pattern, pattern)
-         return unpack(out, packed_shape, inv_pattern)[0]
-
-     return packed, inverse
-
- # classes
-
- # improvised attention as memory module
- # todo - expand if see signal in experiments (update: not seeing it)
-
- class MemoryAttention(Module):
-     def __init__(
-         self,
-         dim
-     ):
-         super().__init__()
-         self.weights = nn.ParameterList([
-             nn.Parameter(torch.randn(dim, dim)), # queries
-             nn.Parameter(torch.randn(dim, dim)), # keys
-             nn.Parameter(torch.randn(dim, dim)), # values weight 1
-             nn.Parameter(torch.randn(dim, dim)), # values weight 2
-         ])
-
-     def forward(self, x):
-
-         assert x.shape[-2] > 1, 'chunk size needs to be greater than 1 for using attention as memory'
-
-         wq, wk, wv1, wv2 = self.weights
-
-         q = x @ wq
-         k = x @ wk
-         v = x @ wv1
-
-         hidden = F.scaled_dot_product_attention(
-             q, k, v,
-             is_causal = True
-         )
-
-         return F.silu(hidden) @ wv2
-
- # main neural memory
-
- def default_loss_fn(pred, target):
-     return (pred - target).pow(2).mean(dim = -1).sum()
-
- class NeuralMemory(Module):
-     def __init__(
-         self,
-         dim,
-         chunk_size = 1,
-         dim_head = None,
-         heads = 1,
-         model: MemoryAttention | None = None,
-         store_memory_loss_fn: Callable = default_loss_fn,
-         pre_rmsnorm = True,
-         post_rmsnorm = True,
-         use_accelerated_scan = False,
-         default_model_kwargs: dict = dict()
-     ):
-         super().__init__()
-
-         # norms
-
-         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
-         self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
-
-         self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()
-
-         # maybe multi-headed
-
-         dim_head = default(dim_head, dim)
-         dim_inner = dim_head * heads
-
-         self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
-         self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
-         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
-
-         # memory mlp
-
-         if not exists(model):
-             model = MemoryAttention(dim_head, **default_model_kwargs)
-
-         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
-
-         # the memory is the weights of the model
-
-         self.memory_model = model
-
-         # the chunk size within the paper where adaptive step, momentum, weight decay are shared
-
-         self.chunk_size = chunk_size
-
-         # prepare function for per sample gradients from model above, using torch.func
-
-         def forward_and_loss(params, inputs, target):
-             pred = functional_call(self.memory_model, params, inputs)
-             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
-             return loss
-
-         self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
-
-         # queries for retrieving from the model
-
-         self.to_queries = LinearNoBias(dim, dim_inner)
-
-         # keys and values for storing to the model
-
-         self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
-         self.store_memory_loss_fn = store_memory_loss_fn
-
-         # learned adaptive learning rate and momentum
-         # todo - explore mlp layerwise learned lr / momentum
-
-         self.to_momentum = nn.Sequential(
-             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
-             LinearNoBias(dim, heads),
-             Rearrange('b n h -> (b h) n 1')
-         )
-
-         self.to_adaptive_step = nn.Sequential(
-             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
-             LinearNoBias(dim, heads),
-             Rearrange('b n h -> (b h) n')
-         )
-
-         # weight decay factor
-
-         self.to_decay_factor = nn.Sequential(
-             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
-             LinearNoBias(dim, heads),
-             Rearrange('b n h -> (b h) n 1')
-         )
-
-         # maybe use accelerated scan
-
-         self.use_accelerated_scan = use_accelerated_scan
-
-     def init_weights_and_momentum(self):
-         params = TensorDict(dict(self.memory_model.named_parameters()))
-
-         init_weights = params.clone().zero_()
-         init_momentum = params.clone().zero_()
-
-         return init_weights, init_momentum
-
-     def store_memories(
-         self,
-         seq,
-         past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
-     ):
-
-         seq = self.store_norm(seq)
-
-         # curtail sequence by multiple of the chunk size
-         # only a complete chunk of the sequence provides the memory for the next chunk
-
-         seq_len, chunk_size = seq.shape[-2], self.chunk_size
-         round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
-
-         seq = seq[:, :round_down_seq_len]
-
-         # curr weights + past weights, in the case that the initial weights are learned
-
-         curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
-
-         past_state = tuple(TensorDict(d) for d in past_state)
-         past_weights, past_momentum = past_state
-
-         curr_weights = curr_weights + past_weights
-
-         # pack batch and sequence dimension
-
-         adaptive_lr = (self.to_adaptive_step(seq).sigmoid() * -15).exp() # from 1. - 1e-7
-
-         adaptive_momentum = self.to_momentum(seq).sigmoid()
-         decay_factor = self.to_decay_factor(seq).sigmoid()
-
-         # keys and values
-
-         keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
-
-         # maybe multi head
-
-         keys, values = map(self.split_heads, (keys, values))
-
-         batch = keys.shape[0]
-
-         # take care of chunking
-
-         keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
-
-         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
-
-         grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
-
-         grads = TensorDict(grads)
-
-         # restore batch and sequence dimension
-
-         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
-
-         # multiply gradients with learned adaptive step size
-
-         surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))
-
-         # determine scan function
-
-         def default_associative_scan(gates, inputs):
-             _, outputs = associative_scan(binary_operator, (gates, inputs))
-             return outputs
-
-         if self.use_accelerated_scan:
-             from accelerated_scan.triton import scan as triton_scan
-             from accelerated_scan.warp import scan as warp_scan
-
-             scan = triton_scan if seq.is_cuda else warp_scan
-
-             def accelerate_scan_fn(gates, inputs):
-                 gates = gates.expand_as(inputs)
-                 gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
-
-                 seq_len = gates.shape[-1]
-                 next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
-
-                 gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
-                 inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
-
-                 outputs = scan(gates, inputs)
-
-                 outputs = outputs[..., :seq_len]
-                 outputs = rearrange(outputs, 'b d n -> b n d')
-                 return outputs
-
-             scan_fn = accelerate_scan_fn
-         else:
-             scan_fn = default_associative_scan
-
-         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
-
-         next_momentum = TensorDict()
-         updates = TensorDict()
-
-         for param_name, surprise in surprises.items():
-
-             surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
-
-             # derive momentum with associative scan - eq (10)
-
-             momentum = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
-
-             # use associative scan again for learned forgetting (weight decay) - eq (13)
-
-             update = scan_fn(1. - decay_factor, momentum) # momentum is S / surprise in the paper
-
-             updates[param_name] = inverse_pack(update)
-             next_momentum[param_name] = inverse_pack(momentum)
-
-         # compute the next weight per batch
-
-         last_update = updates.apply(lambda t: t[:, -1])
-
-         next_state = (curr_weights + last_update, next_momentum)
-
-         return updates, next_state
-
-     def retrieve_memories(
-         self,
-         seq,
-         past_weights: dict[str, Tensor] | None = None,
-     ):
-         chunk_size = self.chunk_size
-         seq_len = seq.shape[1]
-
-         seq = self.retrieve_norm(seq)
-
-         assert seq_len > chunk_size
-
-         seq = seq[:, chunk_size:]
-         curtailed_seq_len = seq.shape[-2]
-
-         next_seq_len = round_up_multiple(curtailed_seq_len + 1, chunk_size)
-
-         padding = next_seq_len - curtailed_seq_len
-
-         seq = pad_at_dim(seq, (0, padding), dim = 1)
-
-         # the parameters of the memory model stores the memories of the key / values
-         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
-
-         curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
-
-         if exists(past_weights):
-             past_weights = TensorDict(past_weights)
-             assert past_weights.keys() == curr_weights.keys()
-
-             curr_weights = curr_weights + past_weights
-
-         # sequence Float['b n d'] to queries
-
-         queries = self.to_queries(seq)
-
-         # maybe multihead
-
-         queries = self.split_heads(queries)
-
-         batch = queries.shape[0]
-
-         # fetch values from memory model
-
-         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
-         queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
-
-         # forward functional call
-
-         values = functional_call(self.memory_model, dict(curr_weights), queries)
-
-         # reconstitute batch dimension
-
-         values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
-
-         # maybe merge heads and combine
-
-         values = self.merge_heads(values)
-
-         values = self.combine_heads(values)
-
-         # post norm, somehow could not stabilize this without it, not in paper
-
-         values = self.post_rmsnorm(values)
-
-         # restore
-
-         values = pad_at_dim(values, (chunk_size, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
-         values = values[:, :-padding]
-
-         return values
-
-     def forward(
-         self,
-         seq,
-         store_seq = None,
-         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-         return_next_memories = False
-     ):
-         batch, seq_len = seq.shape[:2]
-
-         if seq_len <= self.chunk_size:
-             return torch.zeros_like(seq)
-
-         if exists(past_state):
-             past_state = tuple(TensorDict(d) for d in past_state)
-
-         if not exists(past_state):
-             past_state = self.init_weights_and_momentum()
-
-         store_seq = default(store_seq, seq)
-
-         updates, next_memories = self.store_memories(store_seq, past_state)
-
-         past_weights, _ = past_state
-
-         retrieved = self.retrieve_memories(seq, past_weights + updates)
-
-         if not return_next_memories:
-             return retrieved
-
-         return retrieved, next_memories
@@ -1,152 +0,0 @@
- import random
- import tqdm
- import gzip
- import numpy as np
-
- import torch
- from torch import nn
- from torch.optim import Adam
- from torch.nn import functional as F
- from torch.utils.data import DataLoader, Dataset
-
- from local_attention import LocalTransformer
-
- from taylor_series_linear_attention import TaylorSeriesLinearAttn
-
- from titans_pytorch.titans import (
-     NeuralMemory,
-     MemoryMLP
- )
-
- # constants
-
- NUM_BATCHES = int(1e5)
- BATCH_SIZE = 4
- GRADIENT_ACCUMULATE_EVERY = 4
- LEARNING_RATE = 2e-4
- VALIDATE_EVERY = 100
- GENERATE_EVERY = 500
- GENERATE_LENGTH = 512
- SHOULD_GENERATE = True
- SEQ_LEN = 512
-
- PROJECT_NAME = 'titans-neural-memory'
- WANDB_ONLINE = False # turn this on to pipe experiment to cloud
- GLOBAL_LAYERS = (2, 4)
- USE_TITANS_MEMORY = True
- NEURAL_MEMORY_DEPTH = 2
- WINDOW_SIZE = 64
- RUN_NAME = 'neural memory'
-
- # wandb experiment tracker
-
- import wandb
- wandb.init(project = PROJECT_NAME, mode = 'disabled' if not WANDB_ONLINE else 'online')
- wandb.run.name = RUN_NAME
- wandb.run.save()
-
- # helpers
-
- def cycle(loader):
-     while True:
-         for data in loader:
-             yield data
-
- def decode_token(token):
-     return str(chr(max(32, token)))
-
- def decode_tokens(tokens):
-     return ''.join(list(map(decode_token, tokens)))
-
- # instantiate GPT-like decoder model
-
- titans_neural_memory = NeuralMemory(
-     dim = 384,
-     chunk_size = 4,
-     dim_head = 64,
-     heads = 4,
-     use_accelerated_scan = True,
-     default_mlp_kwargs = dict(
-         depth = NEURAL_MEMORY_DEPTH
-     )
- )
-
- linear_attn = TaylorSeriesLinearAttn(
-     dim = 384,
-     dim_head = 16,
-     heads = 16,
-     causal = True,
-     prenorm = True
- )
-
- model = LocalTransformer(
-     num_tokens = 256,
-     dim = 384,
-     depth = 8,
-     causal = True,
-     local_attn_window_size = WINDOW_SIZE,
-     max_seq_len = SEQ_LEN,
-     global_attn_layer = linear_attn if not USE_TITANS_MEMORY else titans_neural_memory,
-     layers_insert_global_attn = GLOBAL_LAYERS
- ).cuda()
-
- # prepare enwik8 data
-
- with gzip.open('./data/enwik8.gz') as file:
-     data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
-     data_train, data_val = np.split(data, [int(90e6)])
-     data_train, data_val = map(torch.from_numpy, (data_train, data_val))
-
- class TextSamplerDataset(Dataset):
-     def __init__(self, data, seq_len):
-         super().__init__()
-         self.data = data
-         self.seq_len = seq_len
-
-     def __getitem__(self, index):
-         rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
-         full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
-         return full_seq.cuda()
-
-     def __len__(self):
-         return self.data.size(0) // self.seq_len
-
- train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
- val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
- train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
- val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
-
- # optimizer
-
- optim = Adam(model.parameters(), lr=LEARNING_RATE)
-
- # training
-
- for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
-     model.train()
-
-     for __ in range(GRADIENT_ACCUMULATE_EVERY):
-         loss = model(next(train_loader), return_loss = True)
-         loss.backward()
-
-     print(f'training loss: {loss.item()}')
-     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
-     optim.step()
-     optim.zero_grad()
-     wandb.log(dict(loss = loss.item()))
-
-     if i % VALIDATE_EVERY == 0:
-         model.eval()
-         with torch.no_grad():
-             loss = model(next(val_loader), return_loss = True)
-             print(f'validation loss: {loss.item()}')
-
-     if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
-         model.eval()
-         inp = random.choice(val_dataset)[:-1]
-         prime = decode_tokens(inp)
-         print(f'%s \n\n %s', (prime, '*' * 100))
-
-         sample = model.generate(inp[None, ...], GENERATE_LENGTH, use_kv_cache = False)
-         output_str = decode_tokens(sample[0])
-         print(output_str)