titans-pytorch 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/titans.py +11 -3
- {titans_pytorch-0.0.5.dist-info → titans_pytorch-0.0.7.dist-info}/METADATA +2 -1
- titans_pytorch-0.0.7.dist-info/RECORD +7 -0
- titans_pytorch-0.0.5.dist-info/RECORD +0 -7
- {titans_pytorch-0.0.5.dist-info → titans_pytorch-0.0.7.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.5.dist-info → titans_pytorch-0.0.7.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
|
@@ -91,10 +91,14 @@ class NeuralMemory(Module):
|
|
|
91
91
|
dim,
|
|
92
92
|
chunk_size = 1,
|
|
93
93
|
model: Module | None = None,
|
|
94
|
-
store_memory_loss_fn: Callable = default_loss_fn
|
|
94
|
+
store_memory_loss_fn: Callable = default_loss_fn,
|
|
95
|
+
pre_rmsnorm = False
|
|
95
96
|
):
|
|
96
97
|
super().__init__()
|
|
97
98
|
|
|
99
|
+
self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
|
|
100
|
+
self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
|
|
101
|
+
|
|
98
102
|
if not exists(model):
|
|
99
103
|
model = MLP(dim, depth = 4)
|
|
100
104
|
|
|
@@ -161,10 +165,12 @@ class NeuralMemory(Module):
|
|
|
161
165
|
past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
|
|
162
166
|
):
|
|
163
167
|
|
|
168
|
+
seq = self.store_norm(seq)
|
|
169
|
+
|
|
164
170
|
# curtail sequence by multiple of the chunk size
|
|
165
171
|
# only a complete chunk of the sequence provides the memory for the next chunk
|
|
166
172
|
|
|
167
|
-
seq_len = seq.shape[-2]
|
|
173
|
+
seq_len, chunk_size = seq.shape[-2], self.chunk_size
|
|
168
174
|
round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
|
|
169
175
|
|
|
170
176
|
seq = seq[:, :round_down_seq_len]
|
|
@@ -234,7 +240,7 @@ class NeuralMemory(Module):
|
|
|
234
240
|
|
|
235
241
|
next_state = (curr_weights + last_update, next_momentum)
|
|
236
242
|
|
|
237
|
-
return updates, next_state, aux_store_loss.mean()
|
|
243
|
+
return updates, next_state, aux_store_loss.mean() / chunk_size
|
|
238
244
|
|
|
239
245
|
def retrieve_memories(
|
|
240
246
|
self,
|
|
@@ -244,6 +250,8 @@ class NeuralMemory(Module):
|
|
|
244
250
|
chunk_size = self.chunk_size
|
|
245
251
|
batch, seq_len = seq.shape[:2]
|
|
246
252
|
|
|
253
|
+
seq = self.retrieve_norm(seq)
|
|
254
|
+
|
|
247
255
|
assert seq_len >= chunk_size
|
|
248
256
|
|
|
249
257
|
seq = seq[:, (chunk_size - 1):]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: titans-pytorch
|
|
3
|
-
Version: 0.0.5
|
|
3
|
+
Version: 0.0.7
|
|
4
4
|
Summary: Titans
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
|
|
@@ -39,6 +39,7 @@ Requires-Dist: einx>=0.3.0
|
|
|
39
39
|
Requires-Dist: tensordict>=0.6.2
|
|
40
40
|
Requires-Dist: torch>=2.3
|
|
41
41
|
Provides-Extra: examples
|
|
42
|
+
Requires-Dist: local-attention>=1.9.15; extra == 'examples'
|
|
42
43
|
Provides-Extra: test
|
|
43
44
|
Requires-Dist: pytest; extra == 'test'
|
|
44
45
|
Description-Content-Type: text/markdown
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
|
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
+
titans_pytorch/titans.py,sha256=IuAhvpeEGP_XQfhpWGEwNcmlSUKmtmtxWjjgwdEy0oI,9730
|
|
4
|
+
titans_pytorch-0.0.7.dist-info/METADATA,sha256=NAIbruJrJLaafP5SjcM3a_6qe6yCgiohpC09CzOKsMg,3092
|
|
5
|
+
titans_pytorch-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
+
titans_pytorch-0.0.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
7
|
+
titans_pytorch-0.0.7.dist-info/RECORD,,
|
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
|
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
-
titans_pytorch/titans.py,sha256=3Mewuysj0g7iAlfjdqMlJhn9-pKJuOerB1frQmQYXuc,9428
|
|
4
|
-
titans_pytorch-0.0.5.dist-info/METADATA,sha256=f1DgCKZz9nqNfZOrqbOpyn-yEx2v5M5zgGIW0Zeu84I,3032
|
|
5
|
-
titans_pytorch-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
-
titans_pytorch-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
7
|
-
titans_pytorch-0.0.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|