titans-pytorch 0.0.34__py3-none-any.whl → 0.0.36__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -2,3 +2,5 @@ from titans_pytorch.titans import (
2
2
  NeuralMemory,
3
3
  MemoryMLP,
4
4
  )
5
+
6
+ from titans_pytorch.mac_transformer import MemoryAsContextTransformer
@@ -185,7 +185,9 @@ class MemoryAsContextTransformer(Module):
185
185
  # long term mem tokens
186
186
 
187
187
  self.segment_len = segment_len
188
+
188
189
  self.num_longterm_mem_tokens = num_longterm_mem_tokens
190
+ has_longterm_mems = num_longterm_mem_tokens > 0
189
191
 
190
192
  self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
191
193
 
@@ -197,7 +199,11 @@ class MemoryAsContextTransformer(Module):
197
199
  self.neural_mem_layers = ModuleList([])
198
200
 
199
201
  layers = tuple(range(1, depth + 1))
200
- neural_memory_layers = set(default(neural_memory_layers, layers))
202
+
203
+ if not exists(neural_memory_layers):
204
+ neural_memory_layers = layers if has_longterm_mems else ()
205
+
206
+ assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
201
207
 
202
208
  for layer in layers:
203
209
 
@@ -205,7 +211,9 @@ class MemoryAsContextTransformer(Module):
205
211
 
206
212
  mem = None
207
213
 
208
- if num_longterm_mem_tokens > 0 and layer in neural_memory_layers:
214
+ if layer in neural_memory_layers:
215
+ assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'
216
+
209
217
  mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
210
218
  mem = init_hyper_conn(dim = dim, branch = mem)
211
219
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.34
3
+ Version: 0.0.36
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,9 @@
1
+ titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=xXQ9GvtvNArYidV1OOhUhCJ0pIxniElTLnL0_eIZtEE,8821
4
+ titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
5
+ titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
6
+ titans_pytorch-0.0.36.dist-info/METADATA,sha256=7Bum0wO6e6BsB7TShLBZWALyStcs1LLaEv5vvnVlQ9c,3938
7
+ titans_pytorch-0.0.36.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.0.36.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.0.36.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=FGShQHD-dQQdQKKzvNS_jTC_FcikdqO_s3ZKOKfr_9E,8502
4
- titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
5
- titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
6
- titans_pytorch-0.0.34.dist-info/METADATA,sha256=CNqv_jMqk7yj15IpDn2O3jBdVe4wtrSVkht7mk0wW_E,3938
7
- titans_pytorch-0.0.34.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.0.34.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.0.34.dist-info/RECORD,,