titans-pytorch 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff compares the contents of two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
titans_pytorch/titans.py CHANGED
@@ -91,10 +91,14 @@ class NeuralMemory(Module):
91
91
  dim,
92
92
  chunk_size = 1,
93
93
  model: Module | None = None,
94
- store_memory_loss_fn: Callable = default_loss_fn
94
+ store_memory_loss_fn: Callable = default_loss_fn,
95
+ pre_rmsnorm = False
95
96
  ):
96
97
  super().__init__()
97
98
 
99
+ self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
100
+ self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
101
+
98
102
  if not exists(model):
99
103
  model = MLP(dim, depth = 4)
100
104
 
@@ -161,10 +165,12 @@ class NeuralMemory(Module):
161
165
  past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
162
166
  ):
163
167
 
168
+ seq = self.store_norm(seq)
169
+
164
170
  # curtail sequence by multiple of the chunk size
165
171
  # only a complete chunk of the sequence provides the memory for the next chunk
166
172
 
167
- seq_len = seq.shape[-2]
173
+ seq_len, chunk_size = seq.shape[-2], self.chunk_size
168
174
  round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
169
175
 
170
176
  seq = seq[:, :round_down_seq_len]
@@ -234,7 +240,7 @@ class NeuralMemory(Module):
234
240
 
235
241
  next_state = (curr_weights + last_update, next_momentum)
236
242
 
237
- return updates, next_state, aux_store_loss.mean()
243
+ return updates, next_state, aux_store_loss.mean() / chunk_size
238
244
 
239
245
  def retrieve_memories(
240
246
  self,
@@ -244,6 +250,8 @@ class NeuralMemory(Module):
244
250
  chunk_size = self.chunk_size
245
251
  batch, seq_len = seq.shape[:2]
246
252
 
253
+ seq = self.retrieve_norm(seq)
254
+
247
255
  assert seq_len >= chunk_size
248
256
 
249
257
  seq = seq[:, (chunk_size - 1):]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.5
3
+ Version: 0.0.7
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -39,6 +39,7 @@ Requires-Dist: einx>=0.3.0
39
39
  Requires-Dist: tensordict>=0.6.2
40
40
  Requires-Dist: torch>=2.3
41
41
  Provides-Extra: examples
42
+ Requires-Dist: local-attention>=1.9.15; extra == 'examples'
42
43
  Provides-Extra: test
43
44
  Requires-Dist: pytest; extra == 'test'
44
45
  Description-Content-Type: text/markdown
@@ -0,0 +1,7 @@
1
+ titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/titans.py,sha256=IuAhvpeEGP_XQfhpWGEwNcmlSUKmtmtxWjjgwdEy0oI,9730
4
+ titans_pytorch-0.0.7.dist-info/METADATA,sha256=NAIbruJrJLaafP5SjcM3a_6qe6yCgiohpC09CzOKsMg,3092
5
+ titans_pytorch-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
+ titans_pytorch-0.0.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
+ titans_pytorch-0.0.7.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/titans.py,sha256=3Mewuysj0g7iAlfjdqMlJhn9-pKJuOerB1frQmQYXuc,9428
4
- titans_pytorch-0.0.5.dist-info/METADATA,sha256=f1DgCKZz9nqNfZOrqbOpyn-yEx2v5M5zgGIW0Zeu84I,3032
5
- titans_pytorch-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
- titans_pytorch-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
- titans_pytorch-0.0.5.dist-info/RECORD,,