titans-pytorch 0.0.35__py3-none-any.whl → 0.0.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +10 -2
- {titans_pytorch-0.0.35.dist-info → titans_pytorch-0.0.36.dist-info}/METADATA +1 -1
- {titans_pytorch-0.0.35.dist-info → titans_pytorch-0.0.36.dist-info}/RECORD +5 -5
- {titans_pytorch-0.0.35.dist-info → titans_pytorch-0.0.36.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.35.dist-info → titans_pytorch-0.0.36.dist-info}/licenses/LICENSE +0 -0
@@ -185,7 +185,9 @@ class MemoryAsContextTransformer(Module):
|
|
185
185
|
# long term mem tokens
|
186
186
|
|
187
187
|
self.segment_len = segment_len
|
188
|
+
|
188
189
|
self.num_longterm_mem_tokens = num_longterm_mem_tokens
|
190
|
+
has_longterm_mems = num_longterm_mem_tokens > 0
|
189
191
|
|
190
192
|
self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
|
191
193
|
|
@@ -197,7 +199,11 @@ class MemoryAsContextTransformer(Module):
|
|
197
199
|
self.neural_mem_layers = ModuleList([])
|
198
200
|
|
199
201
|
layers = tuple(range(1, depth + 1))
|
200
|
-
|
202
|
+
|
203
|
+
if not exists(neural_memory_layers):
|
204
|
+
neural_memory_layers = layers if has_longterm_mems else ()
|
205
|
+
|
206
|
+
assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
|
201
207
|
|
202
208
|
for layer in layers:
|
203
209
|
|
@@ -205,7 +211,9 @@ class MemoryAsContextTransformer(Module):
|
|
205
211
|
|
206
212
|
mem = None
|
207
213
|
|
208
|
-
if
|
214
|
+
if layer in neural_memory_layers:
|
215
|
+
assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'
|
216
|
+
|
209
217
|
mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
|
210
218
|
mem = init_hyper_conn(dim = dim, branch = mem)
|
211
219
|
|
@@ -1,9 +1,9 @@
|
|
1
1
|
titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
|
2
2
|
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=xXQ9GvtvNArYidV1OOhUhCJ0pIxniElTLnL0_eIZtEE,8821
|
4
4
|
titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
|
5
5
|
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
|
6
|
-
titans_pytorch-0.0.
|
7
|
-
titans_pytorch-0.0.
|
8
|
-
titans_pytorch-0.0.
|
9
|
-
titans_pytorch-0.0.
|
6
|
+
titans_pytorch-0.0.36.dist-info/METADATA,sha256=7Bum0wO6e6BsB7TShLBZWALyStcs1LLaEv5vvnVlQ9c,3938
|
7
|
+
titans_pytorch-0.0.36.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.0.36.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.0.36.dist-info/RECORD,,
|
File without changes
|
File without changes
|