titans-pytorch 0.1.29__py3-none-any.whl → 0.1.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/titans.py +3 -2
- {titans_pytorch-0.1.29.dist-info → titans_pytorch-0.1.30.dist-info}/METADATA +1 -1
- titans_pytorch-0.1.30.dist-info/RECORD +8 -0
- titans_pytorch-0.1.29.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.29.dist-info → titans_pytorch-0.1.30.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.29.dist-info → titans_pytorch-0.1.30.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
@@ -409,6 +409,7 @@ class NeuralMemory(Module):
|
|
409
409
|
):
|
410
410
|
super().__init__()
|
411
411
|
dim_head = default(dim_head, dim)
|
412
|
+
assert not (heads == 1 and dim_head != dim)
|
412
413
|
|
413
414
|
self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)
|
414
415
|
|
@@ -566,7 +567,7 @@ class NeuralMemory(Module):
|
|
566
567
|
):
|
567
568
|
assert xnor(exists(value_residual), exists(self.learned_value_residual))
|
568
569
|
|
569
|
-
seq_len, chunk_size = seq.shape[-2], default(chunk_size, self.store_chunk_size)
|
570
|
+
seq_len, heads, chunk_size = seq.shape[-2], self.heads, default(chunk_size, self.store_chunk_size)
|
570
571
|
|
571
572
|
# handle edge case
|
572
573
|
|
@@ -645,7 +646,7 @@ class NeuralMemory(Module):
|
|
645
646
|
|
646
647
|
# restore batch and sequence dimension
|
647
648
|
|
648
|
-
grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
|
649
|
+
grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch * heads))
|
649
650
|
|
650
651
|
# maybe per layer modulation
|
651
652
|
|
@@ -0,0 +1,8 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=RRLdVa8z-2IWbhhmRGfoNBycwaL32aMbpqutzmSQqpc,24575
|
4
|
+
titans_pytorch/titans.py,sha256=6B8ioP26RTja5kVFMsorAnM9CcxIUySJS9RZBlDPI2s,25825
|
5
|
+
titans_pytorch-0.1.30.dist-info/METADATA,sha256=o5flkZ0hNhZE06bSKVEFpbrkhuWB9putcaL_MZ0sJHA,6815
|
6
|
+
titans_pytorch-0.1.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
+
titans_pytorch-0.1.30.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
+
titans_pytorch-0.1.30.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=RRLdVa8z-2IWbhhmRGfoNBycwaL32aMbpqutzmSQqpc,24575
|
4
|
-
titans_pytorch/titans.py,sha256=gjoDcTsvw5X2d1I2xq4cM45YJIBqtLFuws8_jVylW_4,25746
|
5
|
-
titans_pytorch-0.1.29.dist-info/METADATA,sha256=9Na2UlBJ4mECXXY5GIyuokgN0oxs38rps24TIM6CNFY,6815
|
6
|
-
titans_pytorch-0.1.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
-
titans_pytorch-0.1.29.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
-
titans_pytorch-0.1.29.dist-info/RECORD,,
|
File without changes
|
File without changes
|