titans-pytorch 0.0.42__tar.gz → 0.0.44__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.42
+Version: 0.0.44
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -83,6 +83,26 @@ retrieved = mem(seq)
 assert seq.shape == retrieved.shape
 ```
 
+A transformer with the `MAC` configuration can be used as
+
+```python
+import torch
+from titans_pytorch import MemoryAsContextTransformer
+
+transformer = MemoryAsContextTransformer(
+    num_tokens = 256,
+    dim = 256,
+    depth = 2,
+    segment_len = 128, # local attention window size
+    num_persist_mem_tokens = 4,
+    num_longterm_mem_tokens = 16,
+)
+
+token_ids = torch.randint(0, 256, (1, 1023))
+
+logits = transformer(token_ids) # (1, 1023, 256)
+```
+
 ## Experiments
 
 ```bash
@@ -30,6 +30,26 @@ retrieved = mem(seq)
 assert seq.shape == retrieved.shape
 ```
 
+A transformer with the `MAC` configuration can be used as
+
+```python
+import torch
+from titans_pytorch import MemoryAsContextTransformer
+
+transformer = MemoryAsContextTransformer(
+    num_tokens = 256,
+    dim = 256,
+    depth = 2,
+    segment_len = 128, # local attention window size
+    num_persist_mem_tokens = 4,
+    num_longterm_mem_tokens = 16,
+)
+
+token_ids = torch.randint(0, 256, (1, 1023))
+
+logits = transformer(token_ids) # (1, 1023, 256)
+```
+
 ## Experiments
 
 ```bash
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.42"
+version = "0.0.44"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1,5 +1,6 @@
 import torch
 import pytest
+from titans_pytorch import NeuralMemory
 
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
@@ -7,9 +8,6 @@ def test_titans(
     seq_len,
     max_grad_norm
 ):
-
-    from titans_pytorch import NeuralMemory
-
     mem = NeuralMemory(
         dim = 384,
         chunk_size = 64,
@@ -22,11 +20,14 @@ def test_titans(
     assert seq.shape == retrieved.shape
 
 def test_titans_attn_memory():
-    from titans_pytorch.titans_attn_memory import NeuralMemory
+    from titans_pytorch.titans import MemoryAttention
 
     mem = NeuralMemory(
         dim = 384,
         chunk_size = 64,
+        model = MemoryAttention(
+            dim = 384
+        )
     )
 
     seq = torch.randn(2, 1024, 384)
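For context, the change above amounts to swapping the neural memory's inner model for the new attention-based one. A minimal runnable sketch mirroring the updated test (assuming `NeuralMemory` and `MemoryAttention` are importable from `titans_pytorch.titans`, as the test does):

```python
import torch
from titans_pytorch.titans import NeuralMemory, MemoryAttention

# neural memory whose inner model is attention-based instead of the default MLP
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    model = MemoryAttention(
        dim = 384
    )
)

seq = torch.randn(2, 1024, 384)   # (batch, sequence, dim)
retrieved = mem(seq)              # retrieval has the same shape as the input

assert seq.shape == retrieved.shape
```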
@@ -0,0 +1,8 @@
+from titans_pytorch.titans import (
+    NeuralMemory,
+    MemoryMLP,
+)
+
+from titans_pytorch.mac_transformer import (
+    MemoryAsContextTransformer
+)
@@ -224,7 +224,7 @@ class MemoryAsContextTransformer(Module):
         self.layers = ModuleList([])
 
         self.neural_mem_layers = ModuleList([])
-        neural_memory_segment_len = default(neural_memory_segment_len, num_longterm_mem_tokens + segment_len)
+        self.neural_memory_segment_len = default(neural_memory_segment_len, num_longterm_mem_tokens + segment_len)
 
         layers = tuple(range(1, depth + 1))
 
@@ -245,7 +245,7 @@ class MemoryAsContextTransformer(Module):
 
                 mem = NeuralMemory(
                     dim = dim,
-                    chunk_size = neural_memory_segment_len,
+                    chunk_size = self.neural_memory_segment_len,
                     **neural_memory_kwargs
                 )
 
@@ -287,10 +287,7 @@
 
         # math
 
-        batch, seq_len, segment_len, num_longterm_mem_tokens= *x.shape, self.segment_len, self.num_longterm_mem_tokens
-
-        windows = ceil(seq_len / segment_len)
-        total_segment_len = segment_len + num_longterm_mem_tokens
+        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens
 
         # token embedding
 
@@ -305,11 +302,16 @@
 
         x = inverse_segment(x)
 
+        seq_len_with_mem = x.shape[-2]
+
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
-        pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
-        x = x + pos_emb[:x.shape[-2]]
+        neural_mem_windows = ceil(seq_len_with_mem / neural_mem_segment_len)
+
+        pos_emb = self.axial_pos_emb((neural_mem_windows, neural_mem_segment_len), flatten = True)
+
+        x = x + pos_emb[:seq_len_with_mem]
 
         # value residual
 
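The positional-embedding change above ties the axial grid to the neural memory segment length rather than the attention window. As a rough pure-torch illustration of the idea only (not the `axial_positional_embedding` library's actual API): one embedding axis indexes the window, another the position within the window, and their flattened sum is trimmed to the sequence length.

```python
import torch
from math import ceil

dim = 256
neural_mem_segment_len = 144      # e.g. segment_len + num_longterm_mem_tokens
seq_len_with_mem = 1000

neural_mem_windows = ceil(seq_len_with_mem / neural_mem_segment_len)

# hypothetical stand-in for self.axial_pos_emb((neural_mem_windows, neural_mem_segment_len), flatten = True)
window_emb = torch.randn(neural_mem_windows, 1, dim)        # one embedding per window
within_emb = torch.randn(1, neural_mem_segment_len, dim)    # one embedding per intra-window position

pos_emb = (window_emb + within_emb).reshape(-1, dim)        # (windows * segment_len, dim)

x = torch.randn(1, seq_len_with_mem, dim)
x = x + pos_emb[:seq_len_with_mem]                          # trim to the actual length, as in the diff
```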
@@ -334,7 +336,7 @@
 
         # excise out the memories
 
-        x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
+        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len + num_longterm_mem_tokens)
 
         x, _ = unpack(x, mem_ps, 'b * d')
 
@@ -425,11 +425,7 @@ class NeuralMemory(Module):
         next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)
 
         padding = next_seq_len - curtailed_seq_len
-
-        needs_pad = padding > 0
-
-        if needs_pad:
-            seq = pad_at_dim(seq, (0, padding), dim = 1)
+        seq = pad_at_dim(seq, (0, padding), dim = 1)
 
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
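The comment retained above points at the linear-attention view of the memory: with a single weight matrix, storing is accumulating outer products of keys and values into `W`, and retrieval is just `q @ W`. A toy check of that equivalence (illustrative shapes only, not the library's code):

```python
import torch

dim = 4
k = torch.randn(8, dim)   # keys
v = torch.randn(8, dim)   # values
q = torch.randn(1, dim)   # a query

# "store": the sum of outer products k_i^T v_i forms the fast-weight matrix
W = k.t() @ v             # (dim, dim)

# "retrieve": applying the one-matrix memory model to the query
retrieved = q @ W

# same result as unnormalized linear attention over all stored pairs, since q @ (k^T v) == (q k^T) v
assert torch.allclose(retrieved, (q @ k.t()) @ v, atol = 1e-5)
```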
@@ -481,10 +477,7 @@ class NeuralMemory(Module):
         empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
         values = torch.cat((empty_memory_embeds, values), dim = -2)
 
-        if needs_pad:
-            values = values[:, :-padding]
-
-        return values
+        return values[:, :seq_len]
 
     def forward(
         self,
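The simplification above leans on two facts: padding by zero elements is a no-op, and slicing back to the original `seq_len` trims any padding whether or not it was added. A small self-contained check of that reasoning, using a hypothetical `pad_at_dim` helper shaped after the call `pad_at_dim(seq, (0, padding), dim = 1)` in the diff:

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # hypothetical helper: pad tensor `t` along `dim` by `pad = (left, right)`
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

seq = torch.randn(2, 1000, 384)

# padding == 0: the pad is a no-op, so the removed `needs_pad` guard was unnecessary
assert torch.equal(pad_at_dim(seq, (0, 0), dim = 1), seq)

# padding > 0: slicing back to the original length undoes it, matching `values[:, :seq_len]`
padded = pad_at_dim(seq, (0, 24), dim = 1)
assert torch.equal(padded[:, :seq.shape[1]], seq)
```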
@@ -62,6 +62,7 @@ model = MemoryAsContextTransformer(
     num_persist_mem_tokens = NUM_PERSIST_MEM,
     num_longterm_mem_tokens = NUM_LONGTERM_MEM,
     neural_memory_layers = NEURAL_MEM_LAYERS,
+    neural_memory_segment_len = WINDOW_SIZE // 2,
     neural_memory_kwargs = dict(
         dim_head = 64,
         heads = 4,
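Putting the training-script change together with the README example, a hedged sketch of passing the new `neural_memory_segment_len` keyword at construction (values are illustrative; when left unset it defaults to `segment_len + num_longterm_mem_tokens`, per the constructor change above):

```python
import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,                  # local attention window size
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
    neural_memory_segment_len = 64,     # e.g. half the window, echoing WINDOW_SIZE // 2 above
)

token_ids = torch.randint(0, 256, (1, 1023))

logits = transformer(token_ids)         # (1, 1023, 256)
```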
@@ -1,6 +0,0 @@
-from titans_pytorch.titans import (
-    NeuralMemory,
-    MemoryMLP,
-)
-
-from titans_pytorch.mac_transformer import MemoryAsContextTransformer