titans-pytorch 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -488,7 +488,7 @@ class MemoryAsContextTransformer(Module):
488
488
  neural_memory_model: Module | None = None,
489
489
  neural_memory_kwargs: dict = dict(),
490
490
  neural_memory_layers: tuple[int, ...] | None = None,
491
- aux_kv_recon_loss_weight = 0.,
491
+ aux_kv_recon_loss_weight = 1.,
492
492
  use_flex_attn = False,
493
493
  sliding_window_attn = False,
494
494
  weight_tie_memory_model = False,
@@ -1,5 +1,5 @@
1
1
  import torch
2
- from torch import nn
2
+ from torch import nn, cat
3
3
  import torch.nn.functional as F
4
4
  from torch.nn import Module, ModuleList, Parameter, ParameterList
5
5
 
@@ -304,13 +304,26 @@ class NeuralMemory(Module):
304
304
  nn.Sigmoid()
305
305
  ) if heads > 1 else None
306
306
 
307
- # memory mlp
307
+ # memory model
308
308
 
309
309
  if not exists(model):
310
310
  model = MemoryMLP(dim_head, **default_model_kwargs)
311
311
 
312
+ # validate memory model
313
+
312
314
  assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
313
315
 
316
+ test_shape = (3, 2, dim_head)
317
+
318
+ with torch.no_grad():
319
+ try:
320
+ test_input = torch.randn(test_shape)
321
+ mem_model_output = model(test_input)
322
+ except:
323
+ raise RuntimeError(f'memory model unable to accept a tensor of shape {test_shape}')
324
+
325
+ assert mem_model_output.shape == test_shape, 'output of memory model needs to be same shape as input'
326
+
314
327
  # the memory is the weights of the model
315
328
 
316
329
  self.memory_model = model
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.2.4
3
+ Version: 0.2.5
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,9 @@
1
+ titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
4
+ titans_pytorch/memory_models.py,sha256=Ew28waD9gf1wn-5Nkdc676u1I92IqzaOAw-tv0JXMwc,3777
5
+ titans_pytorch/neural_memory.py,sha256=YiBsMiqYn-Hva4yhxfaqkGV857vZIASxi5Z0TT0FC10,24606
6
+ titans_pytorch-0.2.5.dist-info/METADATA,sha256=x3RePuTDf3rUT3vtvge1X3Ry18Y3tV_swCgycbtSCjQ,6819
7
+ titans_pytorch-0.2.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.2.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.2.5.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=g-Rx8zwTUbMv-XBYWPe9abFVVSUFLxOn_yVQ-wWvG5M,26039
4
- titans_pytorch/memory_models.py,sha256=LI9T36XB6YXIvvGWRw0ZMDlGpRC6KIv03OPzME2VAaU,3772
5
- titans_pytorch/neural_memory.py,sha256=3ykFukUDp3dW1QwDmS3jZ2wFysiZE2ippcOoMFall34,24143
6
- titans_pytorch-0.2.4.dist-info/METADATA,sha256=2yY3d58zPQ1uyvnTX4Dml7a2dd2jRu3TR5NhBpPNmdY,6819
7
- titans_pytorch-0.2.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.2.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.2.4.dist-info/RECORD,,