titans-pytorch 0.0.35__py3-none-any.whl → 0.0.36__py3-none-any.whl
- titans_pytorch/mac_transformer.py +10 -2
- {titans_pytorch-0.0.35.dist-info → titans_pytorch-0.0.36.dist-info}/METADATA +1 -1
- {titans_pytorch-0.0.35.dist-info → titans_pytorch-0.0.36.dist-info}/RECORD +5 -5
- {titans_pytorch-0.0.35.dist-info → titans_pytorch-0.0.36.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.35.dist-info → titans_pytorch-0.0.36.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py
@@ -185,7 +185,9 @@ class MemoryAsContextTransformer(Module):
         # long term mem tokens
 
         self.segment_len = segment_len
+
         self.num_longterm_mem_tokens = num_longterm_mem_tokens
+        has_longterm_mems = num_longterm_mem_tokens > 0
 
         self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
 
@@ -197,7 +199,11 @@ class MemoryAsContextTransformer(Module):
         self.neural_mem_layers = ModuleList([])
 
         layers = tuple(range(1, depth + 1))
-
+
+        if not exists(neural_memory_layers):
+            neural_memory_layers = layers if has_longterm_mems else ()
+
+        assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
 
         for layer in layers:
 
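Taken together, the two hunks above change how `neural_memory_layers` is defaulted: when the argument is not supplied, every layer gets a neural memory as long as long-term memory tokens are enabled, and the new assert rejects the combination of memory tokens with an explicitly empty layer list. Below is a minimal standalone sketch of that rule; the helper name `default_neural_memory_layers` and the example values are illustrative, not part of the library.

    def default_neural_memory_layers(depth, num_longterm_mem_tokens, neural_memory_layers = None):
        # layer indices are 1-based, matching `layers = tuple(range(1, depth + 1))` in the diff
        layers = tuple(range(1, depth + 1))
        has_longterm_mems = num_longterm_mem_tokens > 0

        # mirrors `if not exists(neural_memory_layers): ...`
        if neural_memory_layers is None:
            neural_memory_layers = layers if has_longterm_mems else ()

        # mirrors the added assert
        assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'

        return neural_memory_layers

    assert default_neural_memory_layers(depth = 4, num_longterm_mem_tokens = 4) == (1, 2, 3, 4)
    assert default_neural_memory_layers(depth = 4, num_longterm_mem_tokens = 0) == ()
    assert default_neural_memory_layers(depth = 4, num_longterm_mem_tokens = 4, neural_memory_layers = (2, 4)) == (2, 4)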
@@ -205,7 +211,9 @@ class MemoryAsContextTransformer(Module):
 
             mem = None
 
-            if
+            if layer in neural_memory_layers:
+                assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'
+
                 mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
                 mem = init_hyper_conn(dim = dim, branch = mem)
 
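The third hunk gates the per-layer memory construction: a `NeuralMemory` is only instantiated for layers listed in `neural_memory_layers`, its chunk size is tied to `num_longterm_mem_tokens`, and the guard insists that memory tokens are present. The sketch below shows that gating in isolation; `DummyMemory` and `build_layer_memories` are hypothetical stand-ins, not the library's `NeuralMemory` or its hyper-connection wrapper.

    class DummyMemory:
        # stand-in for NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
        def __init__(self, dim, chunk_size):
            self.dim = dim
            self.chunk_size = chunk_size

    def build_layer_memories(depth, dim, num_longterm_mem_tokens, neural_memory_layers):
        memories = {}

        for layer in range(1, depth + 1):
            mem = None

            if layer in neural_memory_layers:
                # same guard as the added assert in the diff
                assert num_longterm_mem_tokens > 0, '`num_longterm_mem_tokens` must be greater than 0'
                mem = DummyMemory(dim = dim, chunk_size = num_longterm_mem_tokens)

            memories[layer] = mem

        return memories

    mems = build_layer_memories(depth = 4, dim = 384, num_longterm_mem_tokens = 4, neural_memory_layers = (2, 4))
    assert mems[1] is None and mems[2].chunk_size == 4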
{titans_pytorch-0.0.35.dist-info → titans_pytorch-0.0.36.dist-info}/RECORD
@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=
+titans_pytorch/mac_transformer.py,sha256=xXQ9GvtvNArYidV1OOhUhCJ0pIxniElTLnL0_eIZtEE,8821
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
+titans_pytorch-0.0.36.dist-info/METADATA,sha256=7Bum0wO6e6BsB7TShLBZWALyStcs1LLaEv5vvnVlQ9c,3938
+titans_pytorch-0.0.36.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.36.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.36.dist-info/RECORD,,