titans-pytorch 0.2.20__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- titans_pytorch/neural_memory.py +1 -1
- {titans_pytorch-0.2.20.dist-info → titans_pytorch-0.2.21.dist-info}/METADATA +1 -1
- {titans_pytorch-0.2.20.dist-info → titans_pytorch-0.2.21.dist-info}/RECORD +5 -5
- {titans_pytorch-0.2.20.dist-info → titans_pytorch-0.2.21.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.2.20.dist-info → titans_pytorch-0.2.21.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/neural_memory.py
CHANGED
@@ -587,7 +587,7 @@ class NeuralMemory(Module):
|
|
587
587
|
|
588
588
|
prev_weights = prev_weights.apply(lambda t: t[:, start_index:end_index])
|
589
589
|
|
590
|
-
if exists(self.to_learned_weight_residual_mix):
|
590
|
+
if exists(self.to_learned_weight_residual_mix) and num_chunks > 0:
|
591
591
|
mix = self.to_learned_weight_residual_mix(chunked_seq)
|
592
592
|
mix = rearrange(mix, 'b h n -> (b h) n')
|
593
593
|
prev_weights = prev_weights.apply(lambda t: einx.multiply('bh n, bh n ... -> bh n ...', mix, t))
|
@@ -2,8 +2,8 @@ titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,27
|
|
2
2
|
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
3
|
titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
|
4
4
|
titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
|
5
|
-
titans_pytorch/neural_memory.py,sha256=
|
6
|
-
titans_pytorch-0.2.
|
7
|
-
titans_pytorch-0.2.
|
8
|
-
titans_pytorch-0.2.
|
9
|
-
titans_pytorch-0.2.
|
5
|
+
titans_pytorch/neural_memory.py,sha256=mpVODrfNhWComrzfkg1d6OgNgcYXZH-HU6Uykw1foI8,26728
|
6
|
+
titans_pytorch-0.2.21.dist-info/METADATA,sha256=QRtuMbSc-WzVNYdY5pxwBKr3aFAqrhgvXW40y-2JZSU,6816
|
7
|
+
titans_pytorch-0.2.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.2.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.2.21.dist-info/RECORD,,
|
File without changes
|
File without changes
|