titans-pytorch 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/titans.py +13 -3
- {titans_pytorch-0.0.7.dist-info → titans_pytorch-0.0.9.dist-info}/METADATA +22 -9
- titans_pytorch-0.0.9.dist-info/RECORD +7 -0
- titans_pytorch-0.0.7.dist-info/RECORD +0 -7
- {titans_pytorch-0.0.7.dist-info → titans_pytorch-0.0.9.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.7.dist-info → titans_pytorch-0.0.9.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
|
@@ -55,6 +55,10 @@ def pack_one_with_inverse(t, pattern):
|
|
|
55
55
|
|
|
56
56
|
return packed, inverse
|
|
57
57
|
|
|
58
|
+
def softclamp_max(t, max_value):
|
|
59
|
+
range_value = max_value / 2
|
|
60
|
+
return ((t / range_value).tanh() * range_value) + range_value
|
|
61
|
+
|
|
58
62
|
# classes
|
|
59
63
|
|
|
60
64
|
class MLP(Module):
|
|
@@ -92,7 +96,8 @@ class NeuralMemory(Module):
|
|
|
92
96
|
chunk_size = 1,
|
|
93
97
|
model: Module | None = None,
|
|
94
98
|
store_memory_loss_fn: Callable = default_loss_fn,
|
|
95
|
-
pre_rmsnorm = False
|
|
99
|
+
pre_rmsnorm = False,
|
|
100
|
+
max_adaptive_step_size = 1e-5
|
|
96
101
|
):
|
|
97
102
|
super().__init__()
|
|
98
103
|
|
|
@@ -144,6 +149,8 @@ class NeuralMemory(Module):
|
|
|
144
149
|
Rearrange('... 1 -> ...')
|
|
145
150
|
)
|
|
146
151
|
|
|
152
|
+
self.max_adaptive_step_size = max_adaptive_step_size
|
|
153
|
+
|
|
147
154
|
# weight decay factor
|
|
148
155
|
|
|
149
156
|
self.to_decay_factor = nn.Sequential(
|
|
@@ -188,7 +195,7 @@ class NeuralMemory(Module):
|
|
|
188
195
|
|
|
189
196
|
batch = seq.shape[0]
|
|
190
197
|
|
|
191
|
-
adaptive_lr = self.to_adaptive_step(seq).
|
|
198
|
+
adaptive_lr = softclamp_max(self.to_adaptive_step(seq), self.max_adaptive_step_size)
|
|
192
199
|
|
|
193
200
|
adaptive_momentum = self.to_momentum(seq).sigmoid()
|
|
194
201
|
decay_factor = self.to_decay_factor(seq).sigmoid()
|
|
@@ -304,7 +311,10 @@ class NeuralMemory(Module):
|
|
|
304
311
|
past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
|
|
305
312
|
return_next_memories = False
|
|
306
313
|
):
|
|
307
|
-
batch = seq.shape[
|
|
314
|
+
batch, seq_len = seq.shape[:2]
|
|
315
|
+
|
|
316
|
+
if seq_len < self.chunk_size:
|
|
317
|
+
return torch.zeros_like(seq)
|
|
308
318
|
|
|
309
319
|
if exists(past_state):
|
|
310
320
|
past_state = tuple(TensorDict(d) for d in past_state)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: titans-pytorch
|
|
3
|
-
Version: 0.0.7
|
|
3
|
+
Version: 0.0.9
|
|
4
4
|
Summary: Titans
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
|
|
@@ -39,7 +39,8 @@ Requires-Dist: einx>=0.3.0
|
|
|
39
39
|
Requires-Dist: tensordict>=0.6.2
|
|
40
40
|
Requires-Dist: torch>=2.3
|
|
41
41
|
Provides-Extra: examples
|
|
42
|
-
Requires-Dist: local-attention>=1.
|
|
42
|
+
Requires-Dist: local-attention>=1.10.0; extra == 'examples'
|
|
43
|
+
Requires-Dist: taylor-series-linear-attention; extra == 'examples'
|
|
43
44
|
Provides-Extra: test
|
|
44
45
|
Requires-Dist: pytest; extra == 'test'
|
|
45
46
|
Description-Content-Type: text/markdown
|
|
@@ -64,16 +65,28 @@ $ pip install titans-pytorch
|
|
|
64
65
|
import torch
|
|
65
66
|
from titans_pytorch import NeuralMemory
|
|
66
67
|
|
|
67
|
-
x = torch.randn(2, 64, 32)
|
|
68
|
-
|
|
69
68
|
mem = NeuralMemory(
|
|
70
|
-
dim =
|
|
71
|
-
chunk_size =
|
|
72
|
-
|
|
69
|
+
dim = 384,
|
|
70
|
+
chunk_size = 64,
|
|
71
|
+
pre_rmsnorm = True
|
|
72
|
+
).cuda()
|
|
73
|
+
|
|
74
|
+
seq = torch.randn(2, 1024, 384).cuda()
|
|
75
|
+
retrieved = mem(seq)
|
|
73
76
|
|
|
74
|
-
|
|
77
|
+
assert seq.shape == retrieved.shape
|
|
78
|
+
```
|
|
75
79
|
|
|
76
|
-
|
|
80
|
+
## Experiments
|
|
81
|
+
|
|
82
|
+
```bash
|
|
83
|
+
$ pip install .[examples]
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
Then
|
|
87
|
+
|
|
88
|
+
```bash
|
|
89
|
+
$ python train.py
|
|
77
90
|
```
|
|
78
91
|
|
|
79
92
|
## Citations
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
|
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
+
titans_pytorch/titans.py,sha256=0wnvhjMIe0Asjf7RVlGm1Ax8xmnDSgEi0rXoIqyoLo0,10078
|
|
4
|
+
titans_pytorch-0.0.9.dist-info/METADATA,sha256=IY6bB54p6mkO8R6j3nJ-bXBlMiukmJt85fa7Lp-HRWw,3311
|
|
5
|
+
titans_pytorch-0.0.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
+
titans_pytorch-0.0.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
7
|
+
titans_pytorch-0.0.9.dist-info/RECORD,,
|
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
|
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
-
titans_pytorch/titans.py,sha256=IuAhvpeEGP_XQfhpWGEwNcmlSUKmtmtxWjjgwdEy0oI,9730
|
|
4
|
-
titans_pytorch-0.0.7.dist-info/METADATA,sha256=NAIbruJrJLaafP5SjcM3a_6qe6yCgiohpC09CzOKsMg,3092
|
|
5
|
-
titans_pytorch-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
-
titans_pytorch-0.0.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
7
|
-
titans_pytorch-0.0.7.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|