titans-pytorch 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +1 -1
- titans_pytorch/memory_models.py +1 -1
- titans_pytorch/neural_memory.py +14 -1
- {titans_pytorch-0.2.4.dist-info → titans_pytorch-0.2.5.dist-info}/METADATA +1 -1
- titans_pytorch-0.2.5.dist-info/RECORD +9 -0
- titans_pytorch-0.2.4.dist-info/RECORD +0 -9
- {titans_pytorch-0.2.4.dist-info → titans_pytorch-0.2.5.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.2.4.dist-info → titans_pytorch-0.2.5.dist-info}/licenses/LICENSE +0 -0
@@ -488,7 +488,7 @@ class MemoryAsContextTransformer(Module):
|
|
488
488
|
neural_memory_model: Module | None = None,
|
489
489
|
neural_memory_kwargs: dict = dict(),
|
490
490
|
neural_memory_layers: tuple[int, ...] | None = None,
|
491
|
-
aux_kv_recon_loss_weight =
|
491
|
+
aux_kv_recon_loss_weight = 1.,
|
492
492
|
use_flex_attn = False,
|
493
493
|
sliding_window_attn = False,
|
494
494
|
weight_tie_memory_model = False,
|
titans_pytorch/memory_models.py
CHANGED
titans_pytorch/neural_memory.py
CHANGED
@@ -304,13 +304,26 @@ class NeuralMemory(Module):
|
|
304
304
|
nn.Sigmoid()
|
305
305
|
) if heads > 1 else None
|
306
306
|
|
307
|
-
# memory
|
307
|
+
# memory model
|
308
308
|
|
309
309
|
if not exists(model):
|
310
310
|
model = MemoryMLP(dim_head, **default_model_kwargs)
|
311
311
|
|
312
|
+
# validate memory model
|
313
|
+
|
312
314
|
assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
|
313
315
|
|
316
|
+
test_shape = (3, 2, dim_head)
|
317
|
+
|
318
|
+
with torch.no_grad():
|
319
|
+
try:
|
320
|
+
test_input = torch.randn(test_shape)
|
321
|
+
mem_model_output = model(test_input)
|
322
|
+
except:
|
323
|
+
raise RuntimeError(f'memory model unable to accept a tensor of shape {test_shape}')
|
324
|
+
|
325
|
+
assert mem_model_output.shape == test_shape, 'output of memory model needs to be same shape as input'
|
326
|
+
|
314
327
|
# the memory is the weights of the model
|
315
328
|
|
316
329
|
self.memory_model = model
|
@@ -0,0 +1,9 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
|
4
|
+
titans_pytorch/memory_models.py,sha256=Ew28waD9gf1wn-5Nkdc676u1I92IqzaOAw-tv0JXMwc,3777
|
5
|
+
titans_pytorch/neural_memory.py,sha256=YiBsMiqYn-Hva4yhxfaqkGV857vZIASxi5Z0TT0FC10,24606
|
6
|
+
titans_pytorch-0.2.5.dist-info/METADATA,sha256=x3RePuTDf3rUT3vtvge1X3Ry18Y3tV_swCgycbtSCjQ,6819
|
7
|
+
titans_pytorch-0.2.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.2.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.2.5.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=g-Rx8zwTUbMv-XBYWPe9abFVVSUFLxOn_yVQ-wWvG5M,26039
|
4
|
-
titans_pytorch/memory_models.py,sha256=LI9T36XB6YXIvvGWRw0ZMDlGpRC6KIv03OPzME2VAaU,3772
|
5
|
-
titans_pytorch/neural_memory.py,sha256=3ykFukUDp3dW1QwDmS3jZ2wFysiZE2ippcOoMFall34,24143
|
6
|
-
titans_pytorch-0.2.4.dist-info/METADATA,sha256=2yY3d58zPQ1uyvnTX4Dml7a2dd2jRu3TR5NhBpPNmdY,6819
|
7
|
-
titans_pytorch-0.2.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
-
titans_pytorch-0.2.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
-
titans_pytorch-0.2.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|