titans-pytorch 0.0.14__py3-none-any.whl → 0.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/__init__.py +3 -1
- titans_pytorch/titans.py +40 -8
- {titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.15.dist-info}/METADATA +11 -1
- titans_pytorch-0.0.15.dist-info/RECORD +7 -0
- titans_pytorch-0.0.14.dist-info/RECORD +0 -7
- {titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.15.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.15.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/__init__.py
CHANGED
titans_pytorch/titans.py
CHANGED
@@ -6,7 +6,7 @@ import torch
 from torch import nn, Tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module
-from torch.func import functional_call, vmap,
+from torch.func import functional_call, vmap, grad
 
 from tensordict import TensorDict
 
@@ -57,7 +57,7 @@ def pack_one_with_inverse(t, pattern):
 
 # classes
 
-class MLP(Module):
+class MemoryMLP(Module):
     def __init__(
         self,
         dim,
@@ -80,6 +80,35 @@ class MLP(Module):
 
         return x
 
+# improvised attention as memory module
+# todo - expand if see signal in experiments
+
+class MemoryAttention(Module):
+    def __init__(
+        self,
+        dim
+    ):
+        super().__init__()
+        self.weights = nn.ParameterList([
+            nn.Parameter(torch.randn(dim, dim)), # queries
+            nn.Parameter(torch.randn(dim, dim)), # keys
+            nn.Parameter(torch.randn(dim, dim)), # values
+        ])
+
+    def forward(self, x):
+        wq, wk, wv = self.weights
+
+        q = x @ wq
+        k = x @ wk
+        v = x @ wv
+
+        sim = q @ k.transpose(-1, -2)
+
+        attn = sim.softmax(dim = -1)
+
+        out = attn @ v
+        return out
+
 # main neural memory
 
 def default_loss_fn(pred, target):
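Annotation (not part of the diff): the new `MemoryAttention` is a bare attention layer whose only parameters are the three `(dim, dim)` projection matrices - there is no `1/√d` scaling or causal masking yet, consistent with the "todo" comment. A minimal sketch of exercising it directly (the dims chosen here are arbitrary):

```python
import torch
from titans_pytorch.titans import MemoryAttention

attn_memory = MemoryAttention(dim = 64)

x = torch.randn(2, 128, 64)   # (batch, seq, dim)
out = attn_memory(x)          # shape preserved: (2, 128, 64)
assert out.shape == x.shape
```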
@@ -122,7 +151,7 @@ class NeuralMemory(Module):
         # memory mlp
 
         if not exists(model):
-            model =
+            model = MemoryMLP(dim_head, **default_mlp_kwargs)
 
         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
 
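The hunk above shows that `NeuralMemory` only builds the default `MemoryMLP` when no `model` is passed, so any buffer-free module operating over the head dimension can be swapped in. A hedged sketch (constructor arguments other than `dim` and `model` are assumptions, and `dim_head` is assumed to equal `dim` here):

```python
import torch
from titans_pytorch.titans import NeuralMemory, MemoryAttention

# plug the new MemoryAttention in as the inner memory model,
# in place of the default MemoryMLP(dim_head, **default_mlp_kwargs)
mem = NeuralMemory(
    dim = 64,
    model = MemoryAttention(dim = 64)
)

seq = torch.randn(2, 128, 64)
retrieved = mem(seq)   # retrieved memories for the sequence
```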
@@ -141,7 +170,7 @@ class NeuralMemory(Module):
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
             return loss
 
-        self.
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
 
         # queries for retrieving from the model
 
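Annotation on the added line: `grad(forward_and_loss)` differentiates the store loss with respect to the weights dict (the function's first argument), and `vmap(..., in_dims = (None, 0, 0))` maps that gradient over a batch of keys and values while sharing the weights, yielding one gradient per sample. A self-contained sketch of the same `torch.func` pattern on a toy module (names here are illustrative, not from the package):

```python
import torch
import torch.nn.functional as F
from torch.func import functional_call, vmap, grad

net = torch.nn.Linear(4, 4)
params = {k: v.detach() for k, v in net.named_parameters()}

# loss for ONE (key, value) pair, with weights passed in functionally -
# mirrors forward_and_loss in the diff above
def forward_and_loss(params, key, value):
    pred = functional_call(net, params, key)
    return F.mse_loss(pred, value)

# grad(...) -> d(loss)/d(params); in_dims = (None, 0, 0) shares params
# while mapping over the batch of keys and values
per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))

keys, values = torch.randn(8, 4), torch.randn(8, 4)
grads = per_sample_grad_fn(params, keys, values)

print(grads['weight'].shape)  # torch.Size([8, 4, 4]) - one gradient per sample
```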
@@ -235,7 +264,7 @@ class NeuralMemory(Module):
 
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads
+        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
 
         grads = TensorDict(grads)
 
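Annotation: `per_sample_grad_fn` returns an ordinary dict of per-sample gradient tensors keyed like the weights; wrapping it in `TensorDict`, as the hunk does, lets the whole set be manipulated as one object. A toy sketch of that wrapping (the keys and shapes below are hypothetical):

```python
import torch
from tensordict import TensorDict

# hypothetical per-sample gradients: a plain dict of tensors,
# each with a leading batch dimension of 8
grads = {'w1': torch.randn(8, 4, 4), 'w2': torch.randn(8, 4, 4)}

# wrapping in TensorDict (as the hunk above does) lets every
# gradient be transformed with a single call
grads = TensorDict(grads)
surprise = grads.apply(lambda t: -t)  # e.g. sign-flip into a "surprise" update
```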
@@ -305,7 +334,7 @@ class NeuralMemory(Module):
 
         next_state = (curr_weights + last_update, next_momentum)
 
-        return updates, next_state
+        return updates, next_state
 
     def retrieve_memories(
         self,
@@ -382,6 +411,7 @@ class NeuralMemory(Module):
     def forward(
         self,
         seq,
+        store_seq = None,
        past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
        return_next_memories = False
     ):
@@ -396,7 +426,9 @@ class NeuralMemory(Module):
         if not exists(past_state):
             past_state = self.init_weights_and_momentum()
 
-
+        store_seq = default(store_seq, seq)
+
+        updates, next_memories = self.store_memories(store_seq, past_state)
 
         past_weights, _ = past_state
 
@@ -405,4 +437,4 @@ class NeuralMemory(Module):
         if not return_next_memories:
             return retrieved
 
-        return retrieved, next_memories
+        return retrieved, next_memories
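Net effect of the `forward` changes above: a new optional `store_seq` decouples what is written into memory from what queries it, falling back to `seq` itself via `store_seq = default(store_seq, seq)`. A hedged usage sketch (constructor arguments beyond `dim` are assumptions):

```python
import torch
from titans_pytorch.titans import NeuralMemory

mem = NeuralMemory(dim = 64)

seq       = torch.randn(2, 128, 64)  # used for retrieval queries
store_seq = torch.randn(2, 128, 64)  # written into the neural memory

retrieved = mem(seq, store_seq = store_seq)

# omitting store_seq reproduces the old behavior: store from seq itself
retrieved = mem(seq)
```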
{titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.15.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.14
+Version: 0.0.15
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -109,3 +109,13 @@ $ python train.py
     url = {https://api.semanticscholar.org/CorpusID:275212078}
 }
 ```
+
+```bibtex
+@software{Kyrylov_Accelerated_Scan_2024,
+    author = {Kyrylov, Volodymyr},
+    doi = {10.5281/zenodo.10600962},
+    title = {Accelerated Scan},
+    version = {0.1.2},
+    year = {2024}
+}
+```
titans_pytorch-0.0.15.dist-info/RECORD
ADDED

@@ -0,0 +1,7 @@
+titans_pytorch/__init__.py,sha256=bu8p8kUA24EVrTz-ojixHTwV-6KTY9Y0cNJkaMW4Whw,91
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=ArFYKgI0p7N3mmv8b4ncxkl3gkKAXWrFrnl2quh2RqE,12930
+titans_pytorch-0.0.15.dist-info/METADATA,sha256=3IC7BT7J3BYx23wUOYeuGgtz769dsIczBnWbm6oi0Tw,3811
+titans_pytorch-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.15.dist-info/RECORD,,
titans_pytorch-0.0.14.dist-info/RECORD
DELETED

@@ -1,7 +0,0 @@
-titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=O4WO2I6GhxyZobqbgfFx01saFEKUxhA0BBwt70m2yeQ,12306
-titans_pytorch-0.0.14.dist-info/METADATA,sha256=HKDSJ3sWc54sN1_fOEYU7i5TjiQff49vcZG9G8EU6z4,3598
-titans_pytorch-0.0.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.14.dist-info/RECORD,,
{titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.15.dist-info}/WHEEL
File without changes

{titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.15.dist-info}/licenses/LICENSE
File without changes