titans-pytorch 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/titans.py +17 -2
- {titans_pytorch-0.0.6.dist-info → titans_pytorch-0.0.8.dist-info}/METADATA +10 -9
- titans_pytorch-0.0.8.dist-info/RECORD +7 -0
- titans_pytorch-0.0.6.dist-info/RECORD +0 -7
- {titans_pytorch-0.0.6.dist-info → titans_pytorch-0.0.8.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.6.dist-info → titans_pytorch-0.0.8.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
|
@@ -55,6 +55,10 @@ def pack_one_with_inverse(t, pattern):
|
|
|
55
55
|
|
|
56
56
|
return packed, inverse
|
|
57
57
|
|
|
58
|
+
def softclamp_max(t, max_value):
    """Softly squash ``t`` elementwise into the range ``(0, max_value)``.

    The input is divided by half of ``max_value`` and passed through ``tanh``,
    whose range is ``(-1, 1)``. The result is then rescaled by that same
    half-width and shifted up by it. So the output is centred on
    ``max_value / 2``, and near ``t == 0`` it behaves like
    ``t + max_value / 2`` (the slope there is exactly 1).

    Because the input is divided by ``max_value / 2``, a very small
    ``max_value`` also makes the near-linear region very narrow. Moderately
    sized inputs then saturate towards the endpoints.

    Args:
        t: elementwise input exposing a ``.tanh()`` method (e.g. a torch
            tensor). Plain Python floats are not supported.
        max_value: upper bound of the output range. Presumably expected to be
            positive; this function does not check it.

    Returns:
        A value of the same shape as ``t``. Mathematically it lies strictly
        inside ``(0, max_value)``; in floating point, fully saturated
        elements can land exactly on ``0`` or ``max_value``.
    """
    half_width = max_value / 2
    squashed = (t / half_width).tanh()  # in (-1, 1)
    return squashed * half_width + half_width
|
|
61
|
+
|
|
58
62
|
# classes
|
|
59
63
|
|
|
60
64
|
class MLP(Module):
|
|
@@ -91,10 +95,15 @@ class NeuralMemory(Module):
|
|
|
91
95
|
dim,
|
|
92
96
|
chunk_size = 1,
|
|
93
97
|
model: Module | None = None,
|
|
94
|
-
store_memory_loss_fn: Callable = default_loss_fn
|
|
98
|
+
store_memory_loss_fn: Callable = default_loss_fn,
|
|
99
|
+
pre_rmsnorm = False,
|
|
100
|
+
max_adaptive_step_size = 1e-5
|
|
95
101
|
):
|
|
96
102
|
super().__init__()
|
|
97
103
|
|
|
104
|
+
self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
|
|
105
|
+
self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
|
|
106
|
+
|
|
98
107
|
if not exists(model):
|
|
99
108
|
model = MLP(dim, depth = 4)
|
|
100
109
|
|
|
@@ -140,6 +149,8 @@ class NeuralMemory(Module):
|
|
|
140
149
|
Rearrange('... 1 -> ...')
|
|
141
150
|
)
|
|
142
151
|
|
|
152
|
+
self.max_adaptive_step_size = max_adaptive_step_size
|
|
153
|
+
|
|
143
154
|
# weight decay factor
|
|
144
155
|
|
|
145
156
|
self.to_decay_factor = nn.Sequential(
|
|
@@ -161,6 +172,8 @@ class NeuralMemory(Module):
|
|
|
161
172
|
past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
|
|
162
173
|
):
|
|
163
174
|
|
|
175
|
+
seq = self.store_norm(seq)
|
|
176
|
+
|
|
164
177
|
# curtail sequence by multiple of the chunk size
|
|
165
178
|
# only a complete chunk of the sequence provides the memory for the next chunk
|
|
166
179
|
|
|
@@ -182,7 +195,7 @@ class NeuralMemory(Module):
|
|
|
182
195
|
|
|
183
196
|
batch = seq.shape[0]
|
|
184
197
|
|
|
185
|
-
adaptive_lr = self.to_adaptive_step(seq).
|
|
198
|
+
adaptive_lr = softclamp_max(self.to_adaptive_step(seq), self.max_adaptive_step_size)
|
|
186
199
|
|
|
187
200
|
adaptive_momentum = self.to_momentum(seq).sigmoid()
|
|
188
201
|
decay_factor = self.to_decay_factor(seq).sigmoid()
|
|
@@ -244,6 +257,8 @@ class NeuralMemory(Module):
|
|
|
244
257
|
chunk_size = self.chunk_size
|
|
245
258
|
batch, seq_len = seq.shape[:2]
|
|
246
259
|
|
|
260
|
+
seq = self.retrieve_norm(seq)
|
|
261
|
+
|
|
247
262
|
assert seq_len >= chunk_size
|
|
248
263
|
|
|
249
264
|
seq = seq[:, (chunk_size - 1):]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: titans-pytorch
|
|
3
|
-
Version: 0.0.6
|
|
3
|
+
Version: 0.0.8
|
|
4
4
|
Summary: Titans
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
|
|
@@ -39,7 +39,8 @@ Requires-Dist: einx>=0.3.0
|
|
|
39
39
|
Requires-Dist: tensordict>=0.6.2
|
|
40
40
|
Requires-Dist: torch>=2.3
|
|
41
41
|
Provides-Extra: examples
|
|
42
|
-
Requires-Dist: local-attention>=1.
|
|
42
|
+
Requires-Dist: local-attention>=1.10.0; extra == 'examples'
|
|
43
|
+
Requires-Dist: taylor-series-linear-attention; extra == 'examples'
|
|
43
44
|
Provides-Extra: test
|
|
44
45
|
Requires-Dist: pytest; extra == 'test'
|
|
45
46
|
Description-Content-Type: text/markdown
|
|
@@ -64,16 +65,16 @@ $ pip install titans-pytorch
|
|
|
64
65
|
import torch
|
|
65
66
|
from titans_pytorch import NeuralMemory
|
|
66
67
|
|
|
67
|
-
x = torch.randn(2, 64, 32)
|
|
68
|
-
|
|
69
68
|
mem = NeuralMemory(
|
|
70
|
-
dim =
|
|
71
|
-
chunk_size =
|
|
72
|
-
|
|
69
|
+
dim = 384,
|
|
70
|
+
chunk_size = 64,
|
|
71
|
+
pre_rmsnorm = True
|
|
72
|
+
).cuda()
|
|
73
73
|
|
|
74
|
-
|
|
74
|
+
seq = torch.randn(2, 1024, 384).cuda()
|
|
75
|
+
retrieved = mem(seq)
|
|
75
76
|
|
|
76
|
-
assert
|
|
77
|
+
assert seq.shape == retrieved.shape
|
|
77
78
|
```
|
|
78
79
|
|
|
79
80
|
## Citations
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
|
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
+
titans_pytorch/titans.py,sha256=I8o_wauKQre87wl645Gy1nchPn5ckfZq4Q6pf5IpToU,9988
|
|
4
|
+
titans_pytorch-0.0.8.dist-info/METADATA,sha256=70ENsVn58zYyJM0UKanESujf0xSQWTcWXBh5HY1frKk,3219
|
|
5
|
+
titans_pytorch-0.0.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
+
titans_pytorch-0.0.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
7
|
+
titans_pytorch-0.0.8.dist-info/RECORD,,
|
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
|
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
-
titans_pytorch/titans.py,sha256=S8J8B9o7Rlnj2hU3FZgpn28GTmis3ZbenLqjB_uny54,9470
|
|
4
|
-
titans_pytorch-0.0.6.dist-info/METADATA,sha256=t4HXD6sZT7_pgcwD8TBY6ojYHUHiZ05J6t19wRKtHNc,3092
|
|
5
|
-
titans_pytorch-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
-
titans_pytorch-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
7
|
-
titans_pytorch-0.0.6.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|