titans-pytorch 0.0.35__py3-none-any.whl → 0.0.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -185,7 +185,9 @@ class MemoryAsContextTransformer(Module):
185
185
  # long term mem tokens
186
186
 
187
187
  self.segment_len = segment_len
188
+
188
189
  self.num_longterm_mem_tokens = num_longterm_mem_tokens
190
+ has_longterm_mems = num_longterm_mem_tokens > 0
189
191
 
190
192
  self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
191
193
 
@@ -197,7 +199,11 @@ class MemoryAsContextTransformer(Module):
197
199
  self.neural_mem_layers = ModuleList([])
198
200
 
199
201
  layers = tuple(range(1, depth + 1))
200
- neural_memory_layers = set(default(neural_memory_layers, layers))
202
+
203
+ if not exists(neural_memory_layers):
204
+ neural_memory_layers = layers if has_longterm_mems else ()
205
+
206
+ assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
201
207
 
202
208
  for layer in layers:
203
209
 
@@ -205,7 +211,9 @@ class MemoryAsContextTransformer(Module):
205
211
 
206
212
  mem = None
207
213
 
208
- if num_longterm_mem_tokens > 0 and layer in neural_memory_layers:
214
+ if layer in neural_memory_layers:
215
+ assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'
216
+
209
217
  mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
210
218
  mem = init_hyper_conn(dim = dim, branch = mem)
211
219
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.35
3
+ Version: 0.0.36
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,9 +1,9 @@
1
1
  titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
2
2
  titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=FGShQHD-dQQdQKKzvNS_jTC_FcikdqO_s3ZKOKfr_9E,8502
3
+ titans_pytorch/mac_transformer.py,sha256=xXQ9GvtvNArYidV1OOhUhCJ0pIxniElTLnL0_eIZtEE,8821
4
4
  titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
5
5
  titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
6
- titans_pytorch-0.0.35.dist-info/METADATA,sha256=jrhx-Bp1LqlOAV3jl4M70WTqq29ciz5lYWJvo2aoPE4,3938
7
- titans_pytorch-0.0.35.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.0.35.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.0.35.dist-info/RECORD,,
6
+ titans_pytorch-0.0.36.dist-info/METADATA,sha256=7Bum0wO6e6BsB7TShLBZWALyStcs1LLaEv5vvnVlQ9c,3938
7
+ titans_pytorch-0.0.36.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.0.36.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.0.36.dist-info/RECORD,,