titans-pytorch 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/__init__.py +2 -1
- titans_pytorch/titans.py +11 -8
- {titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.16.dist-info}/METADATA +11 -1
- titans_pytorch-0.0.16.dist-info/RECORD +7 -0
- titans_pytorch-0.0.14.dist-info/RECORD +0 -7
- {titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.16.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.16.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/__init__.py
CHANGED
titans_pytorch/titans.py
CHANGED
@@ -6,7 +6,7 @@ import torch
 from torch import nn, Tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module
-from torch.func import functional_call, vmap,
+from torch.func import functional_call, vmap, grad

 from tensordict import TensorDict

@@ -57,7 +57,7 @@ def pack_one_with_inverse(t, pattern):

 # classes

-class
+class MemoryMLP(Module):
     def __init__(
         self,
         dim,
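The hunk above names MemoryMLP as the memory model. For orientation, here is a rough, hypothetical sketch of what a parameter-only memory MLP along these lines could look like; the depth, initialization, and activation are illustrative assumptions, not the package's actual implementation (note the assertion further down that the model may carry parameters but no buffers).

```python
# Hypothetical sketch only: a parameter-only MLP memory. Depth, init scale and
# SiLU activation are assumptions for illustration, not the package source.
import torch
import torch.nn.functional as F
from torch.nn import Module, Parameter, ParameterList

class MemoryMLP(Module):
    def __init__(self, dim, depth = 2):
        super().__init__()
        # plain weight matrices, no buffers, so the memory state is just parameters
        self.weights = ParameterList([
            Parameter(torch.randn(dim, dim) * dim ** -0.5)
            for _ in range(depth)
        ])

    def forward(self, x):
        for ind, weight in enumerate(self.weights):
            if ind != 0:
                x = F.silu(x)  # nonlinearity between the weight matrices
            x = x @ weight
        return x
```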
@@ -122,7 +122,7 @@ class NeuralMemory(Module):
         # memory mlp

         if not exists(model):
-            model =
+            model = MemoryMLP(dim_head, **default_mlp_kwargs)

         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'

@@ -141,7 +141,7 @@ class NeuralMemory(Module):
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
             return loss

-        self.
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))

         # queries for retrieving from the model

@@ -235,7 +235,7 @@ class NeuralMemory(Module):

         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

-        grads
+        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)

         grads = TensorDict(grads)

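The two hunks above wire the memory update through torch.func: grad(forward_and_loss) differentiates the stored-memory loss with respect to a parameter dict, and vmap maps that gradient over the batch of key/value pairs while sharing the parameters (in_dims = (None, 0, 0)). Below is a minimal standalone sketch of the same per-sample-gradient pattern, using a plain nn.Linear and made-up shapes rather than the package's model.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.func import functional_call, vmap, grad

model = nn.Linear(16, 16)

def forward_and_loss(params, key, value):
    # run the module functionally, with parameters passed in explicitly
    pred = functional_call(model, params, key)
    return F.mse_loss(pred, value)  # |M(k) - v|² style objective

# grad() -> d(loss)/d(params); vmap maps over the batch of (key, value) pairs
# while the parameter dict is shared (not mapped), i.e. in_dims = (None, 0, 0)
per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))

params = {name: p.detach() for name, p in model.named_parameters()}
keys = torch.randn(8, 16)
values = torch.randn(8, 16)

grads = per_sample_grad_fn(params, keys, values)
# one gradient per sample: each entry has a leading batch dimension of 8
print({name: tuple(g.shape) for name, g in grads.items()})
```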
@@ -305,7 +305,7 @@ class NeuralMemory(Module):

         next_state = (curr_weights + last_update, next_momentum)

-        return updates, next_state
+        return updates, next_state

     def retrieve_memories(
         self,
@@ -382,6 +382,7 @@ class NeuralMemory(Module):
     def forward(
         self,
         seq,
+        store_seq = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
         return_next_memories = False
     ):
@@ -396,7 +397,9 @@ class NeuralMemory(Module):
         if not exists(past_state):
             past_state = self.init_weights_and_momentum()

-        updates, next_memories = self.store_memories(seq, past_state)
+        store_seq = default(store_seq, seq)
+
+        updates, next_memories = self.store_memories(store_seq, past_state)

         past_weights, _ = past_state

@@ -405,4 +408,4 @@ class NeuralMemory(Module):
         if not return_next_memories:
             return retrieved

-        return retrieved, next_memories
+        return retrieved, next_memories
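The forward() hunks above add an optional store_seq argument, so the sequence written into memory can differ from the sequence used for retrieval (store_seq defaults to seq). A hypothetical usage sketch follows; the constructor arguments mirror the project README and the shapes are made up.

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64
)

seq = torch.randn(2, 1024, 384)        # sequence the memory is queried with
store_seq = torch.randn(2, 1024, 384)  # separate sequence written into memory

retrieved = mem(seq, store_seq = store_seq)
assert retrieved.shape == seq.shape
```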
{titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.16.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.14
+Version: 0.0.16
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -109,3 +109,13 @@ $ python train.py
     url = {https://api.semanticscholar.org/CorpusID:275212078}
 }
 ```
+
+```bibtex
+@software{Kyrylov_Accelerated_Scan_2024,
+    author = {Kyrylov, Volodymyr},
+    doi = {10.5281/zenodo.10600962},
+    title = {Accelerated Scan},
+    version = {0.1.2},
+    year = {2024}
+}
+```
titans_pytorch-0.0.16.dist-info/RECORD
ADDED
@@ -0,0 +1,7 @@
+titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=HYm0R_1w3s8MNPsyE2qAVpHGqTBX_AoWtjzxRfF1Ams,12269
+titans_pytorch-0.0.16.dist-info/METADATA,sha256=mzSgA4okWWSh97fncN2UKJaTVa3PWOHYVaFkQ1Ker0w,3811
+titans_pytorch-0.0.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.16.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.16.dist-info/RECORD,,
titans_pytorch-0.0.14.dist-info/RECORD
DELETED
@@ -1,7 +0,0 @@
-titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=O4WO2I6GhxyZobqbgfFx01saFEKUxhA0BBwt70m2yeQ,12306
-titans_pytorch-0.0.14.dist-info/METADATA,sha256=HKDSJ3sWc54sN1_fOEYU7i5TjiQff49vcZG9G8EU6z4,3598
-titans_pytorch-0.0.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.14.dist-info/RECORD,,
{titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.16.dist-info}/WHEEL
File without changes
{titans_pytorch-0.0.14.dist-info → titans_pytorch-0.0.16.dist-info}/licenses/LICENSE
File without changes