titans-pytorch 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/titans.py CHANGED
@@ -91,10 +91,14 @@ class NeuralMemory(Module):
91
91
  dim,
92
92
  chunk_size = 1,
93
93
  model: Module | None = None,
94
- store_memory_loss_fn: Callable = default_loss_fn
94
+ store_memory_loss_fn: Callable = default_loss_fn,
95
+ pre_rmsnorm = False
95
96
  ):
96
97
  super().__init__()
97
98
 
99
+ self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
100
+ self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
101
+
98
102
  if not exists(model):
99
103
  model = MLP(dim, depth = 4)
100
104
 
@@ -161,6 +165,8 @@ class NeuralMemory(Module):
161
165
  past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
162
166
  ):
163
167
 
168
+ seq = self.store_norm(seq)
169
+
164
170
  # curtail sequence by multiple of the chunk size
165
171
  # only a complete chunk of the sequence provides the memory for the next chunk
166
172
 
@@ -244,6 +250,8 @@ class NeuralMemory(Module):
244
250
  chunk_size = self.chunk_size
245
251
  batch, seq_len = seq.shape[:2]
246
252
 
253
+ seq = self.retrieve_norm(seq)
254
+
247
255
  assert seq_len >= chunk_size
248
256
 
249
257
  seq = seq[:, (chunk_size - 1):]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.6
3
+ Version: 0.0.7
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,7 @@
1
+ titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/titans.py,sha256=IuAhvpeEGP_XQfhpWGEwNcmlSUKmtmtxWjjgwdEy0oI,9730
4
+ titans_pytorch-0.0.7.dist-info/METADATA,sha256=NAIbruJrJLaafP5SjcM3a_6qe6yCgiohpC09CzOKsMg,3092
5
+ titans_pytorch-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
+ titans_pytorch-0.0.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
+ titans_pytorch-0.0.7.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/titans.py,sha256=S8J8B9o7Rlnj2hU3FZgpn28GTmis3ZbenLqjB_uny54,9470
4
- titans_pytorch-0.0.6.dist-info/METADATA,sha256=t4HXD6sZT7_pgcwD8TBY6ojYHUHiZ05J6t19wRKtHNc,3092
5
- titans_pytorch-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
- titans_pytorch-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
- titans_pytorch-0.0.6.dist-info/RECORD,,