titans-pytorch 0.0.14__tar.gz → 0.0.15__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- titans_pytorch-0.0.15/.github/workflows/test.yaml +24 -0
- {titans_pytorch-0.0.14 → titans_pytorch-0.0.15}/PKG-INFO +11 -1
- {titans_pytorch-0.0.14 → titans_pytorch-0.0.15}/README.md +10 -0
- {titans_pytorch-0.0.14 → titans_pytorch-0.0.15}/pyproject.toml +1 -1
- titans_pytorch-0.0.15/tests/test_titans.py +15 -0
- titans_pytorch-0.0.15/titans_pytorch/__init__.py +5 -0
- {titans_pytorch-0.0.14 → titans_pytorch-0.0.15}/titans_pytorch/titans.py +40 -8
- titans_pytorch-0.0.14/titans_pytorch/__init__.py +0 -3
- {titans_pytorch-0.0.14 → titans_pytorch-0.0.15}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.14 → titans_pytorch-0.0.15}/.gitignore +0 -0
- {titans_pytorch-0.0.14 → titans_pytorch-0.0.15}/LICENSE +0 -0
- {titans_pytorch-0.0.14 → titans_pytorch-0.0.15}/data/README.md +0 -0
- {titans_pytorch-0.0.14 → titans_pytorch-0.0.15}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.14 → titans_pytorch-0.0.15}/fig1.png +0 -0
- {titans_pytorch-0.0.14 → titans_pytorch-0.0.15}/fig2.png +0 -0
- {titans_pytorch-0.0.14 → titans_pytorch-0.0.15}/requirements.txt +0 -0
- {titans_pytorch-0.0.14 → titans_pytorch-0.0.15}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.14 → titans_pytorch-0.0.15}/train.py +0 -0
@@ -0,0 +1,24 @@
|
|
1
|
+
name: Tests the examples in README
|
2
|
+
on: [push, pull_request]
|
3
|
+
|
4
|
+
env:
|
5
|
+
TYPECHECK: True
|
6
|
+
|
7
|
+
jobs:
|
8
|
+
test:
|
9
|
+
runs-on: ubuntu-latest
|
10
|
+
steps:
|
11
|
+
- uses: actions/checkout@v4
|
12
|
+
- name: Install Python
|
13
|
+
uses: actions/setup-python@v5
|
14
|
+
with:
|
15
|
+
python-version: "3.11"
|
16
|
+
- name: Install dependencies
|
17
|
+
run: |
|
18
|
+
python -m pip install uv
|
19
|
+
python -m uv pip install --upgrade pip
|
20
|
+
python -m uv pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
|
21
|
+
python -m uv pip install -e .[test]
|
22
|
+
- name: Test with pytest
|
23
|
+
run: |
|
24
|
+
python -m pytest tests/
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: titans-pytorch
|
3
|
-
Version: 0.0.14
|
3
|
+
Version: 0.0.15
|
4
4
|
Summary: Titans
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
|
@@ -109,3 +109,13 @@ $ python train.py
|
|
109
109
|
url = {https://api.semanticscholar.org/CorpusID:275212078}
|
110
110
|
}
|
111
111
|
```
|
112
|
+
|
113
|
+
```bibtex
|
114
|
+
@software{Kyrylov_Accelerated_Scan_2024,
|
115
|
+
author = {Kyrylov, Volodymyr},
|
116
|
+
doi = {10.5281/zenodo.10600962},
|
117
|
+
title = {Accelerated Scan},
|
118
|
+
version = {0.1.2},
|
119
|
+
year = {2024}
|
120
|
+
}
|
121
|
+
```
|
@@ -58,3 +58,13 @@ $ python train.py
|
|
58
58
|
url = {https://api.semanticscholar.org/CorpusID:275212078}
|
59
59
|
}
|
60
60
|
```
|
61
|
+
|
62
|
+
```bibtex
|
63
|
+
@software{Kyrylov_Accelerated_Scan_2024,
|
64
|
+
author = {Kyrylov, Volodymyr},
|
65
|
+
doi = {10.5281/zenodo.10600962},
|
66
|
+
title = {Accelerated Scan},
|
67
|
+
version = {0.1.2},
|
68
|
+
year = {2024}
|
69
|
+
}
|
70
|
+
```
|
@@ -0,0 +1,15 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from titans_pytorch import NeuralMemory
|
5
|
+
|
6
|
+
def test_titans():
|
7
|
+
mem = NeuralMemory(
|
8
|
+
dim = 384,
|
9
|
+
chunk_size = 64,
|
10
|
+
)
|
11
|
+
|
12
|
+
seq = torch.randn(2, 1024, 384)
|
13
|
+
retrieved = mem(seq)
|
14
|
+
|
15
|
+
assert seq.shape == retrieved.shape
|
@@ -6,7 +6,7 @@ import torch
|
|
6
6
|
from torch import nn, Tensor
|
7
7
|
import torch.nn.functional as F
|
8
8
|
from torch.nn import Linear, Module
|
9
|
-
from torch.func import functional_call, vmap,
|
9
|
+
from torch.func import functional_call, vmap, grad
|
10
10
|
|
11
11
|
from tensordict import TensorDict
|
12
12
|
|
@@ -57,7 +57,7 @@ def pack_one_with_inverse(t, pattern):
|
|
57
57
|
|
58
58
|
# classes
|
59
59
|
|
60
|
-
class MLP(Module):
|
60
|
+
class MemoryMLP(Module):
|
61
61
|
def __init__(
|
62
62
|
self,
|
63
63
|
dim,
|
@@ -80,6 +80,35 @@ class MLP(Module):
|
|
80
80
|
|
81
81
|
return x
|
82
82
|
|
83
|
+
# improvised attention as memory module
|
84
|
+
# todo - expand if see signal in experiments
|
85
|
+
|
86
|
+
class MemoryAttention(Module):
|
87
|
+
def __init__(
|
88
|
+
self,
|
89
|
+
dim
|
90
|
+
):
|
91
|
+
super().__init__()
|
92
|
+
self.weights = nn.ParameterList([
|
93
|
+
nn.Parameter(torch.randn(dim, dim)), # queries
|
94
|
+
nn.Parameter(torch.randn(dim, dim)), # keys
|
95
|
+
nn.Parameter(torch.randn(dim, dim)), # values
|
96
|
+
])
|
97
|
+
|
98
|
+
def forward(self, x):
|
99
|
+
wq, wk, wv = self.weights
|
100
|
+
|
101
|
+
q = x @ wq
|
102
|
+
k = x @ wk
|
103
|
+
v = x @ wv
|
104
|
+
|
105
|
+
sim = q @ k.transpose(-1, -2)
|
106
|
+
|
107
|
+
attn = sim.softmax(dim = -1)
|
108
|
+
|
109
|
+
out = attn @ v
|
110
|
+
return out
|
111
|
+
|
83
112
|
# main neural memory
|
84
113
|
|
85
114
|
def default_loss_fn(pred, target):
|
@@ -122,7 +151,7 @@ class NeuralMemory(Module):
|
|
122
151
|
# memory mlp
|
123
152
|
|
124
153
|
if not exists(model):
|
125
|
-
model = MLP(dim_head, **default_mlp_kwargs)
|
154
|
+
model = MemoryMLP(dim_head, **default_mlp_kwargs)
|
126
155
|
|
127
156
|
assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
|
128
157
|
|
@@ -141,7 +170,7 @@ class NeuralMemory(Module):
|
|
141
170
|
loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
|
142
171
|
return loss
|
143
172
|
|
144
|
-
self.
|
173
|
+
self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
|
145
174
|
|
146
175
|
# queries for retrieving from the model
|
147
176
|
|
@@ -235,7 +264,7 @@ class NeuralMemory(Module):
|
|
235
264
|
|
236
265
|
# get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
|
237
266
|
|
238
|
-
grads
|
267
|
+
grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
|
239
268
|
|
240
269
|
grads = TensorDict(grads)
|
241
270
|
|
@@ -305,7 +334,7 @@ class NeuralMemory(Module):
|
|
305
334
|
|
306
335
|
next_state = (curr_weights + last_update, next_momentum)
|
307
336
|
|
308
|
-
return updates, next_state
|
337
|
+
return updates, next_state
|
309
338
|
|
310
339
|
def retrieve_memories(
|
311
340
|
self,
|
@@ -382,6 +411,7 @@ class NeuralMemory(Module):
|
|
382
411
|
def forward(
|
383
412
|
self,
|
384
413
|
seq,
|
414
|
+
store_seq = None,
|
385
415
|
past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
|
386
416
|
return_next_memories = False
|
387
417
|
):
|
@@ -396,7 +426,9 @@ class NeuralMemory(Module):
|
|
396
426
|
if not exists(past_state):
|
397
427
|
past_state = self.init_weights_and_momentum()
|
398
428
|
|
399
|
-
|
429
|
+
store_seq = default(store_seq, seq)
|
430
|
+
|
431
|
+
updates, next_memories = self.store_memories(store_seq, past_state)
|
400
432
|
|
401
433
|
past_weights, _ = past_state
|
402
434
|
|
@@ -405,4 +437,4 @@ class NeuralMemory(Module):
|
|
405
437
|
if not return_next_memories:
|
406
438
|
return retrieved
|
407
439
|
|
408
|
-
return retrieved, next_memories
|
440
|
+
return retrieved, next_memories
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|