titans-pytorch 0.1.33__py3-none-any.whl → 0.1.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +26 -4
- titans_pytorch/neural_memory.py +7 -1
- {titans_pytorch-0.1.33.dist-info → titans_pytorch-0.1.34.dist-info}/METADATA +1 -1
- titans_pytorch-0.1.34.dist-info/RECORD +8 -0
- titans_pytorch-0.1.33.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.33.dist-info → titans_pytorch-0.1.34.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.33.dist-info → titans_pytorch-0.1.34.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py
CHANGED

```diff
@@ -2,6 +2,7 @@ from __future__ import annotations
 from typing import Callable
 
 from math import ceil
+from copy import deepcopy
 from functools import partial
 from collections import namedtuple
 
```
```diff
@@ -485,11 +486,13 @@ class MemoryAsContextTransformer(Module):
         heads = 8,
         ff_mult = 4,
         num_residual_streams = 4,
+        neural_memory_model: Module | None = None,
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
         aux_kv_recon_loss_weight = 0.,
         use_flex_attn = False,
-        sliding_window_attn = False
+        sliding_window_attn = False,
+        weight_tie_memory_model = False
     ):
         super().__init__()
 
```
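Taken together, the two new arguments let callers pass in a single memory module and have every memory layer share it. A minimal usage sketch, assuming the README-style constructor arguments; the hyperparameter values are illustrative, and only `neural_memory_model` and `weight_tie_memory_model` are new in this release:

```python
import torch
from titans_pytorch import MemoryAsContextTransformer, MemoryMLP

# one memory model instance, shared across every neural memory layer
memory_model = MemoryMLP(dim = 64, depth = 2)

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
    neural_memory_model = memory_model,    # new: explicit memory model
    weight_tie_memory_model = True,        # new: share it across layers
    neural_memory_kwargs = dict(dim_head = 64, heads = 4)
)

ids = torch.randint(0, 256, (1, 1024))
loss = transformer(ids, return_loss = True)
```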
```diff
@@ -523,6 +526,15 @@ class MemoryAsContextTransformer(Module):
 
         neural_memory_layers = default(neural_memory_layers, layers)
 
+        # weight tying neural memory model
+
+        maybe_copy = deepcopy if not weight_tie_memory_model else identity
+
+        if weight_tie_memory_model:
+            assert exists(neural_memory_model), '`neural_memory_model` must be explicitly set'
+
+        self.weight_tie_memory_model = weight_tie_memory_model
+
         # mem, attn, and feedforward layers
 
         for layer in layers:
```
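The `maybe_copy` line is the crux of the feature: with tying disabled, each memory layer receives an independent `deepcopy` of the template model; with tying enabled, every layer holds a reference to the very same module and therefore the same weights. A standalone sketch of the pattern, using an `nn.Linear` as a stand-in for the memory model:

```python
from copy import deepcopy
from torch import nn

def identity(t):
    return t

template = nn.Linear(8, 8)  # stand-in for the neural memory model

for weight_tie_memory_model in (False, True):
    maybe_copy = deepcopy if not weight_tie_memory_model else identity
    layer_one, layer_two = maybe_copy(template), maybe_copy(template)
    # tied layers are literally the same module object
    print(weight_tie_memory_model, layer_one is layer_two)
# prints: False False, then True True
```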
```diff
@@ -551,6 +563,7 @@ class MemoryAsContextTransformer(Module):
                 mem = NeuralMemory(
                     dim = dim,
                     chunk_size = self.neural_memory_segment_len,
+                    model = maybe_copy(neural_memory_model),
                     **neural_memory_kwargs
                 )
 
```
```diff
@@ -683,7 +696,7 @@ class MemoryAsContextTransformer(Module):
 
         # math
 
-        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size
+        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size, weight_tie_memory_model = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size, self.weight_tie_memory_model
 
         seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)
 
```
```diff
@@ -736,6 +749,10 @@ class MemoryAsContextTransformer(Module):
         next_kv_caches = []
         next_neural_mem_caches = []
 
+        # weight tied neural memory
+
+        neural_memory_updates = None
+
         # value residual
 
         value_residual = None
```
```diff
@@ -769,7 +786,8 @@ class MemoryAsContextTransformer(Module):
             if not is_inferencing:
                 (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
                     mem_input,
-                    return_aux_kv_loss = True
+                    return_aux_kv_loss = True,
+                    prev_layer_updates = neural_memory_updates
                 )
 
                 kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
```
```diff
@@ -777,9 +795,13 @@ class MemoryAsContextTransformer(Module):
             else:
                 retrieved, next_neural_mem_cache = mem.forward_inference(
                     mem_input,
-                    state = next(neural_mem_caches, None)
+                    state = next(neural_mem_caches, None),
+                    prev_layer_updates = neural_memory_updates
                 )
 
+            if weight_tie_memory_model:
+                neural_memory_updates = next_neural_mem_cache.updates
+
             if self.gate_attn_output:
                 attn_out_gates = retrieved.sigmoid()
             else:
```
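The forward-pass changes all serve one dataflow: each memory layer hands its weight updates to the next layer via `prev_layer_updates`, so the shared model is edited cumulatively down the stack rather than each layer starting from the same base weights. A schematic of that loop, with hypothetical stand-ins (`memory_layer`, `Cache`) in place of the real modules:

```python
from collections import namedtuple

Cache = namedtuple('Cache', ['updates'])  # stand-in for the neural memory cache

def memory_layer(tokens, prev_layer_updates = None):
    # stand-in: pretend each layer adds one more edit on top of what it received
    updates = {'w': (prev_layer_updates or {'w': 0})['w'] + 1}
    return tokens, Cache(updates)

weight_tie_memory_model = True
neural_memory_updates = None  # reset once per forward pass

for _ in range(3):  # three memory layers
    tokens, cache = memory_layer('tokens', prev_layer_updates = neural_memory_updates)
    if weight_tie_memory_model:
        neural_memory_updates = cache.updates  # thread updates to the next layer

print(neural_memory_updates)  # {'w': 3}
```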
titans_pytorch/neural_memory.py
CHANGED
```diff
@@ -822,6 +822,7 @@ class NeuralMemory(Module):
         self,
         token: Tensor,
         state = None,
+        prev_layer_updates: dict[str, Tensor] | None = None
     ):
 
         # unpack previous state
```
```diff
@@ -863,12 +864,17 @@ class NeuralMemory(Module):
         else:
             updates = updates.apply(lambda t: t[:, -1:])
 
+        if exists(prev_layer_updates):
+            prev_layer_updates = TensorDict(prev_layer_updates)
+            prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])
+
         if store_seq_cache_len == self.chunk_size:
 
             next_updates, next_states, _ = self.store_memories(
                 cache_store_seq,
                 weights,
-                past_state = past_states
+                past_state = past_states,
+                prev_layer_updates = prev_layer_updates,
             )
 
             updates = next_updates
```
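On the inference path, only the latest timestep of the previous layer's updates matters, which is what the added `.apply(lambda t: t[:, -1:])` keeps. A plain-dict equivalent of that `TensorDict` operation, with illustrative tensor names and shapes:

```python
import torch

# illustrative per-parameter update tensors, shaped (batch, time, ...)
prev_layer_updates = {
    'weights.0': torch.randn(2, 4, 8, 8),
    'weights.1': torch.randn(2, 4, 8),
}

# keep only the most recent timestep, as TensorDict.apply(lambda t: t[:, -1:]) does
latest = {name: t[:, -1:] for name, t in prev_layer_updates.items()}

print({name: tuple(t.shape) for name, t in latest.items()})
# {'weights.0': (2, 1, 8, 8), 'weights.1': (2, 1, 8)}
```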
titans_pytorch-0.1.34.dist-info/RECORD
ADDED

```diff
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=rMT99CPQFH4Gudp0FmVPWGKfhBf6xksGEaFEcOVdqjs,230
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=JvA4mhQaW9LD4j6boRUfLfjyzDCtjqybIr4Ajeio8n8,25708
+titans_pytorch/neural_memory.py,sha256=nNAxhkubuHCGs3bty_eA_yBhWqepPZJgKKvkWXO6IK4,28653
+titans_pytorch-0.1.34.dist-info/METADATA,sha256=pVgjCX_YTT9_5WPcFfXpoaBvzrg1-esvwS0kPpeJAYU,6826
+titans_pytorch-0.1.34.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.34.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.34.dist-info/RECORD,,
```
@@ -1,8 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=rMT99CPQFH4Gudp0FmVPWGKfhBf6xksGEaFEcOVdqjs,230
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=Cui-hCl6X4UVGmuyoKCSKWbag9Yrc-a2MrfVkHM-z0A,24828
|
4
|
-
titans_pytorch/neural_memory.py,sha256=9dXpSaQYomc-ur-nEwej1nG9M5NqS0c3LBBP9jUIMPU,28352
|
5
|
-
titans_pytorch-0.1.33.dist-info/METADATA,sha256=A9BBoe0Sas2kxUcUi7w_RFl8-SIF1TLzPIRGuZlauFM,6826
|
6
|
-
titans_pytorch-0.1.33.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
-
titans_pytorch-0.1.33.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
-
titans_pytorch-0.1.33.dist-info/RECORD,,
|
{titans_pytorch-0.1.33.dist-info → titans_pytorch-0.1.34.dist-info}/WHEEL
File without changes

{titans_pytorch-0.1.33.dist-info → titans_pytorch-0.1.34.dist-info}/licenses/LICENSE
File without changes