titans-pytorch 0.0.42__tar.gz → 0.0.43__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/PKG-INFO +21 -1
- {titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/README.md +20 -0
- {titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/pyproject.toml +1 -1
- {titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/tests/test_titans.py +5 -4
- titans_pytorch-0.0.43/titans_pytorch/__init__.py +8 -0
- {titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/titans_pytorch/titans.py +2 -9
- titans_pytorch-0.0.42/titans_pytorch/__init__.py +0 -6
- {titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/.gitignore +0 -0
- {titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/LICENSE +0 -0
- {titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/data/README.md +0 -0
- {titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/fig1.png +0 -0
- {titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/fig2.png +0 -0
- {titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/train_mac.py +0 -0
{titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.42
+Version: 0.0.43
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -83,6 +83,26 @@ retrieved = mem(seq)
 assert seq.shape == retrieved.shape
 ```

+A transformer with the `MAC` configuration can be used as
+
+```python
+import torch
+from titans_pytorch import MemoryAsContextTransformer
+
+transformer = MemoryAsContextTransformer(
+    num_tokens = 256,
+    dim = 256,
+    depth = 2,
+    segment_len = 128, # local attention window size
+    num_persist_mem_tokens = 4,
+    num_longterm_mem_tokens = 16,
+)
+
+token_ids = torch.randint(0, 256, (1, 1023))
+
+logits = transformer(token_ids) # (1, 1023, 256)
+```
+
 ## Experiments

 ```bash
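Note: the hunk context above (`retrieved = mem(seq)` / `assert seq.shape == retrieved.shape`) is the tail of the existing `NeuralMemory` example in the README, which is not shown in full in this diff. A minimal sketch of that usage, borrowing the `dim`, `chunk_size`, and sequence shape from `tests/test_titans.py` in this same diff (the README's actual values may differ):

```python
import torch
from titans_pytorch import NeuralMemory

# hyperparameters mirror tests/test_titans.py; the README example may use others
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64    # the sequence is processed in chunks of this size
)

seq = torch.randn(2, 1024, 384)

# store the sequence into the neural memory and retrieve from it
retrieved = mem(seq)

assert seq.shape == retrieved.shape
```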
{titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/README.md

@@ -30,6 +30,26 @@ retrieved = mem(seq)
 assert seq.shape == retrieved.shape
 ```

+A transformer with the `MAC` configuration can be used as
+
+```python
+import torch
+from titans_pytorch import MemoryAsContextTransformer
+
+transformer = MemoryAsContextTransformer(
+    num_tokens = 256,
+    dim = 256,
+    depth = 2,
+    segment_len = 128, # local attention window size
+    num_persist_mem_tokens = 4,
+    num_longterm_mem_tokens = 16,
+)
+
+token_ids = torch.randint(0, 256, (1, 1023))
+
+logits = transformer(token_ids) # (1, 1023, 256)
+```
+
 ## Experiments

 ```bash
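Note: the new example stops at the logits. Purely as an illustration (the loss wiring below is not part of this diff, and `MemoryAsContextTransformer` may expose its own training helpers), those logits can be fed into a standard next-token cross-entropy objective:

```python
import torch
import torch.nn.functional as F
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
)

token_ids = torch.randint(0, 256, (1, 1024))

# predict token t + 1 from tokens <= t
inp, target = token_ids[:, :-1], token_ids[:, 1:]

logits = transformer(inp)             # (1, 1023, 256), as in the README example

loss = F.cross_entropy(
    logits.transpose(1, 2),           # cross_entropy expects (batch, classes, seq)
    target
)
loss.backward()
```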
{titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/tests/test_titans.py

@@ -1,5 +1,6 @@
 import torch
 import pytest
+from titans_pytorch import NeuralMemory

 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
@@ -7,9 +8,6 @@ def test_titans(
     seq_len,
     max_grad_norm
 ):
-
-    from titans_pytorch import NeuralMemory
-
     mem = NeuralMemory(
         dim = 384,
         chunk_size = 64,
@@ -22,11 +20,14 @@ def test_titans(
     assert seq.shape == retrieved.shape

 def test_titans_attn_memory():
-    from titans_pytorch.
+    from titans_pytorch.titans import MemoryAttention

     mem = NeuralMemory(
         dim = 384,
         chunk_size = 64,
+        model = MemoryAttention(
+            dim = 384
+        )
     )

     seq = torch.randn(2, 1024, 384)
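Note: assembling the additions from the test hunks above, `test_titans_attn_memory` should read roughly as below after the patch; the trailing retrieve-and-assert lines are assumed to match `test_titans`, since the hunk cuts off after the `seq` line:

```python
import torch
from titans_pytorch import NeuralMemory

def test_titans_attn_memory():
    from titans_pytorch.titans import MemoryAttention

    # the neural memory's inner model is now an attention-based module
    mem = NeuralMemory(
        dim = 384,
        chunk_size = 64,
        model = MemoryAttention(
            dim = 384
        )
    )

    seq = torch.randn(2, 1024, 384)
    retrieved = mem(seq)              # assumed tail, by analogy with test_titans

    assert seq.shape == retrieved.shape
```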
{titans_pytorch-0.0.42 → titans_pytorch-0.0.43}/titans_pytorch/titans.py

@@ -425,11 +425,7 @@ class NeuralMemory(Module):
         next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)

         padding = next_seq_len - curtailed_seq_len
-
-        needs_pad = padding > 0
-
-        if needs_pad:
-            seq = pad_at_dim(seq, (0, padding), dim = 1)
+        seq = pad_at_dim(seq, (0, padding), dim = 1)

         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
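Note: this hunk drops the `needs_pad` guard and pads unconditionally. The sketch below shows why that is safe when `padding == 0`; the `pad_at_dim` helper here is an assumed stand-in (its real definition is not part of this diff), written as the usual thin wrapper around `F.pad`:

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # assumed helper: pad one dimension of a tensor on the left / right
    dims_from_right = (t.ndim - dim - 1) if dim >= 0 else (-dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

seq = torch.randn(2, 100, 384)

# padding of zero is a no-op, so the removed `if needs_pad:` guard was redundant
assert torch.equal(pad_at_dim(seq, (0, 0), dim = 1), seq)

# positive padding extends the sequence up to the next chunk multiple
assert pad_at_dim(seq, (0, 28), dim = 1).shape == (2, 128, 384)
```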
@@ -481,10 +477,7 @@ class NeuralMemory(Module):
         empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
         values = torch.cat((empty_memory_embeds, values), dim = -2)

-
-        values = values[:, :-padding]
-
-        return values
+        return values[:, :seq_len]

     def forward(
         self,
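Note: the second `titans.py` hunk swaps `values[:, :-padding]` for `values[:, :seq_len]`. The slice to `seq_len` is correct whether or not padding was added, while a negative-end slice misbehaves at zero (`t[:, :-0]` is `t[:, :0]`, an empty slice). A small standalone illustration, with arbitrary shapes rather than the method's real ones:

```python
import torch

seq_len = 100
values = torch.randn(2, 128, 384)        # sequence padded out to a chunk multiple
padding = values.shape[1] - seq_len      # 28 here, but can be 0

# old style needed a guard, because t[:, :-0] drops everything
trimmed_old = values[:, :-padding] if padding > 0 else values

# new style is a single expression that handles both cases
trimmed_new = values[:, :seq_len]

assert torch.equal(trimmed_old, trimmed_new)

no_pad = torch.randn(2, 100, 384)
assert no_pad[:, :-0].shape[1] == 0      # the failure mode the old guard avoided
```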