titans-pytorch 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
- titans_pytorch/__init__.py +2 -1
- titans_pytorch/titans.py +11 -8
- {titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.16.dist-info}/METADATA +11 -1
- titans_pytorch-0.0.16.dist-info/RECORD +7 -0
- titans_pytorch-0.0.14.dist-info/RECORD +0 -7
- {titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.16.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.16.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/__init__.py
CHANGED
titans_pytorch/titans.py
CHANGED
@@ -6,7 +6,7 @@ import torch
 from torch import nn, Tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module
-from torch.func import functional_call, vmap,
+from torch.func import functional_call, vmap, grad
 
 from tensordict import TensorDict
 
@@ -57,7 +57,7 @@ def pack_one_with_inverse(t, pattern):
 
 # classes
 
-class
+class MemoryMLP(Module):
     def __init__(
         self,
         dim,
@@ -122,7 +122,7 @@ class NeuralMemory(Module):
         # memory mlp
 
         if not exists(model):
-            model =
+            model = MemoryMLP(dim_head, **default_mlp_kwargs)
 
         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
 
@@ -141,7 +141,7 @@ class NeuralMemory(Module):
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
             return loss
 
-        self.
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
 
         # queries for retrieving from the model
 
@@ -235,7 +235,7 @@ class NeuralMemory(Module):
 
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads
+        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
 
         grads = TensorDict(grads)
 
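The `vmap(grad(...), in_dims = (None, 0, 0))` change above follows the standard torch.func per-sample-gradient recipe: differentiate a functional loss with respect to an explicit parameter dict, then vmap over the batched inputs while sharing the parameters. The sketch below is a minimal standalone illustration of that pattern; the toy linear model, tensor shapes, and names are assumptions for the example, not code from this package.

```python
# Minimal sketch of per-sample gradients with torch.func, mirroring the
# vmap(grad(...), in_dims = (None, 0, 0)) pattern in the diff above.
# The toy model and shapes here are illustrative assumptions.
import torch
from torch import nn
import torch.nn.functional as F
from torch.func import functional_call, vmap, grad

model = nn.Linear(16, 16)
params = dict(model.named_parameters())

def forward_and_loss(params, key, value):
    # run the module functionally with an explicit parameter dict
    pred = functional_call(model, params, key)
    return F.mse_loss(pred, value)

# grad(...) differentiates w.r.t. the first argument (the parameter dict);
# vmap maps over the batch dim of key/value while broadcasting the parameters
per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))

keys = torch.randn(8, 16)    # batch of 8 "keys"
values = torch.randn(8, 16)  # batch of 8 "values"

grads = per_sample_grad_fn(params, keys, values)
# each entry in `grads` now carries a leading batch dimension of 8
print({name: g.shape for name, g in grads.items()})
```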
@@ -305,7 +305,7 @@ class NeuralMemory(Module):
 
         next_state = (curr_weights + last_update, next_momentum)
 
-        return updates, next_state
+        return updates, next_state
 
     def retrieve_memories(
         self,
@@ -382,6 +382,7 @@ class NeuralMemory(Module):
     def forward(
         self,
         seq,
+        store_seq = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
         return_next_memories = False
     ):
@@ -396,7 +397,9 @@ class NeuralMemory(Module):
         if not exists(past_state):
             past_state = self.init_weights_and_momentum()
 
-
+        store_seq = default(store_seq, seq)
+
+        updates, next_memories = self.store_memories(store_seq, past_state)
 
         past_weights, _ = past_state
 
@@ -405,4 +408,4 @@ class NeuralMemory(Module):
         if not return_next_memories:
             return retrieved
 
-        return retrieved, next_memories
+        return retrieved, next_memories
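The new `store_seq` argument lets the sequence written into memory differ from the sequence used for retrieval; per the diff it falls back to `seq` via `default(store_seq, seq)`. Below is a hedged usage sketch: the constructor argument (`dim = 384`) and tensor shapes are assumptions for illustration and may not match this exact version's full configuration.

```python
import torch
from titans_pytorch import NeuralMemory

# constructor configuration here is an assumption for illustration;
# the actual module exposes more options than shown
mem = NeuralMemory(dim = 384)

seq = torch.randn(2, 1024, 384)        # sequence to retrieve against
store_seq = torch.randn(2, 1024, 384)  # separate sequence to store into memory

# store_seq is optional - it falls back to seq via default(store_seq, seq)
retrieved = mem(seq, store_seq = store_seq)
assert retrieved.shape == seq.shape
```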
{titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.16.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.14
+Version: 0.0.16
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -109,3 +109,13 @@ $ python train.py
     url = {https://api.semanticscholar.org/CorpusID:275212078}
 }
 ```
+
+```bibtex
+@software{Kyrylov_Accelerated_Scan_2024,
+    author = {Kyrylov, Volodymyr},
+    doi = {10.5281/zenodo.10600962},
+    title = {Accelerated Scan},
+    version = {0.1.2},
+    year = {2024}
+}
+```
titans_pytorch-0.0.16.dist-info/RECORD
ADDED
@@ -0,0 +1,7 @@
+titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=HYm0R_1w3s8MNPsyE2qAVpHGqTBX_AoWtjzxRfF1Ams,12269
+titans_pytorch-0.0.16.dist-info/METADATA,sha256=mzSgA4okWWSh97fncN2UKJaTVa3PWOHYVaFkQ1Ker0w,3811
+titans_pytorch-0.0.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.16.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.16.dist-info/RECORD,,
titans_pytorch-0.0.14.dist-info/RECORD
REMOVED
@@ -1,7 +0,0 @@
-titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=O4WO2I6GhxyZobqbgfFx01saFEKUxhA0BBwt70m2yeQ,12306
-titans_pytorch-0.0.14.dist-info/METADATA,sha256=HKDSJ3sWc54sN1_fOEYU7i5TjiQff49vcZG9G8EU6z4,3598
-titans_pytorch-0.0.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.14.dist-info/RECORD,,
{titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.16.dist-info}/WHEEL
File without changes
{titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.16.dist-info}/licenses/LICENSE
File without changes