titans-pytorch 0.1.28__py3-none-any.whl → 0.1.30__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
titans_pytorch/titans.py CHANGED
@@ -409,6 +409,7 @@ class NeuralMemory(Module):
409
409
  ):
410
410
  super().__init__()
411
411
  dim_head = default(dim_head, dim)
412
+ assert not (heads == 1 and dim_head != dim)
412
413
 
413
414
  self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)
414
415
 
@@ -566,7 +567,7 @@ class NeuralMemory(Module):
566
567
  ):
567
568
  assert xnor(exists(value_residual), exists(self.learned_value_residual))
568
569
 
569
- seq_len, chunk_size = seq.shape[-2], default(chunk_size, self.store_chunk_size)
570
+ seq_len, heads, chunk_size = seq.shape[-2], self.heads, default(chunk_size, self.store_chunk_size)
570
571
 
571
572
  # handle edge case
572
573
 
@@ -645,7 +646,7 @@ class NeuralMemory(Module):
645
646
 
646
647
  # restore batch and sequence dimension
647
648
 
648
- grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
649
+ grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch * heads))
649
650
 
650
651
  # maybe per layer modulation
651
652
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.1.28
3
+ Version: 0.1.30
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=RRLdVa8z-2IWbhhmRGfoNBycwaL32aMbpqutzmSQqpc,24575
4
+ titans_pytorch/titans.py,sha256=6B8ioP26RTja5kVFMsorAnM9CcxIUySJS9RZBlDPI2s,25825
5
+ titans_pytorch-0.1.30.dist-info/METADATA,sha256=o5flkZ0hNhZE06bSKVEFpbrkhuWB9putcaL_MZ0sJHA,6815
6
+ titans_pytorch-0.1.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.1.30.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.1.30.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=RRLdVa8z-2IWbhhmRGfoNBycwaL32aMbpqutzmSQqpc,24575
4
- titans_pytorch/titans.py,sha256=gjoDcTsvw5X2d1I2xq4cM45YJIBqtLFuws8_jVylW_4,25746
5
- titans_pytorch-0.1.28.dist-info/METADATA,sha256=8AJX9oaut11GeFcyBmVsmbnY7oWhsal13yv75DtPeno,6815
6
- titans_pytorch-0.1.28.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.1.28.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.1.28.dist-info/RECORD,,