titans-pytorch 0.2.14__tar.gz → 0.2.16__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.16}/PKG-INFO +1 -1
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.16}/pyproject.toml +1 -1
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.16}/tests/test_titans.py +2 -3
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.16}/titans_pytorch/mac_transformer.py +14 -9
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.16}/titans_pytorch/neural_memory.py +43 -76
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.16}/train_mac.py +4 -2
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.16}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.16}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.16}/.gitignore +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.16}/LICENSE +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.16}/README.md +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.16}/data/README.md +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.16}/data/enwik8.gz +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.16}/fig1.png +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.16}/fig2.png +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.16}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.16}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.16}/titans_pytorch/memory_models.py +0 -0
```diff
--- titans_pytorch-0.2.14/tests/test_titans.py
+++ titans_pytorch-0.2.16/tests/test_titans.py
@@ -148,7 +148,7 @@ def test_mac(
     assert logits.shape == (1, seq_len, 256)
 
 @pytest.mark.parametrize('sliding', (False, True))
-@pytest.mark.parametrize('mem_layers', (()))
+@pytest.mark.parametrize('mem_layers', ((), None))
 @pytest.mark.parametrize('longterm_mems', (0, 4, 16))
 @pytest.mark.parametrize('prompt_len', (4, 16))
 @torch_default_dtype(torch.float64)
@@ -190,7 +190,6 @@ def test_neural_mem_inference(
     prompt_len,
     mem_chunk_size
 ):
-    pytest.skip()
 
     mem = NeuralMemory(
         dim = 384,
```
```diff
@@ -218,7 +217,7 @@ def test_neural_mem_inference(
 
     for token in seq.unbind(dim = 1):
 
-        one_retrieved, state = mem.forward_inference(
+        one_retrieved, state = mem.forward(
             token,
             state = state,
         )
```
```diff
--- titans_pytorch-0.2.14/titans_pytorch/mac_transformer.py
+++ titans_pytorch-0.2.16/titans_pytorch/mac_transformer.py
@@ -491,6 +491,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_layers: tuple[int, ...] | None = None,
         use_flex_attn = False,
         sliding_window_attn = False,
+        neural_mem_weight_residual = False
     ):
         super().__init__()
 
@@ -524,6 +525,8 @@ class MemoryAsContextTransformer(Module):
 
         neural_memory_layers = default(neural_memory_layers, layers)
 
+        self.neural_mem_weight_residual = neural_mem_weight_residual
+
         # mem, attn, and feedforward layers
 
         for layer in layers:
@@ -739,6 +742,10 @@ class MemoryAsContextTransformer(Module):
 
         value_residual = None
 
+        # neural mem weight residual
+
+        mem_weight_residual = None
+
         # when inferencing, only do one token at a time
 
         if is_inferencing:
```
```diff
@@ -761,16 +768,14 @@ class MemoryAsContextTransformer(Module):
 
                 mem_input, add_residual = mem_hyper_conn(x)
 
-
-
-
-
+                retrieved, next_neural_mem_cache = mem.forward(
+                    mem_input,
+                    state = next(neural_mem_caches, None),
+                    prev_weights = mem_weight_residual
+                )
 
-
-
-                    mem_input,
-                    state = next(neural_mem_caches, None),
-                )
+                if self.neural_mem_weight_residual:
+                    mem_weight_residual = next_neural_mem_cache.updates
 
                 if self.gate_attn_output:
                     attn_out_gates = retrieved.sigmoid()
```
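Together, these mac_transformer.py hunks thread a "weight residual" across layers: when `neural_mem_weight_residual` is enabled, the weight updates produced by one layer's neural memory are handed to the next layer's memory as `prev_weights`. The toy below mimics that threading pattern with a stub module so it runs standalone; `StubMemory` and its update dict are illustrative stand-ins, not the titans_pytorch API.

```python
import torch
from torch import nn

# Toy stand-in for a neural memory layer: returns an output plus a dict of
# "weight updates", optionally offset by the previous layer's updates.
class StubMemory(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, prev_weights = None):
        updates = {'proj.weight': torch.randn_like(self.proj.weight)}

        if prev_weights is not None:
            # residual connection on the memory weights, as in the diff above
            updates = {k: v + prev_weights[k] for k, v in updates.items()}

        return self.proj(x), updates

layers = nn.ModuleList([StubMemory(16) for _ in range(3)])
x = torch.randn(2, 8, 16)

mem_weight_residual = None  # reset once per forward pass, as in the added lines

for mem in layers:
    x, updates = mem(x, prev_weights = mem_weight_residual)
    mem_weight_residual = updates  # this layer's updates seed the next layer
```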
```diff
--- titans_pytorch-0.2.14/titans_pytorch/neural_memory.py
+++ titans_pytorch-0.2.16/titans_pytorch/neural_memory.py
@@ -494,7 +494,8 @@ class NeuralMemory(Module):
         seq,
         weights: dict[str, Tensor] | None = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-        seq_index = 0
+        seq_index = 0,
+        prev_weights = None
     ):
         batch, seq_len, heads, chunk_size = *seq.shape[:2], self.heads, self.store_chunk_size
 
@@ -560,6 +561,12 @@ class NeuralMemory(Module):
 
         adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
 
+        # maybe add previous layer weight
+
+        if exists(prev_weights):
+            prev_weights = prev_weights.apply(lambda t: t[:, -1:])
+            weights_for_surprise = weights_for_surprise + prev_weights
+
         # flatten batch and time if surprise depends on previous layer memory model
 
         weights_for_surprise = rearrange_dict_values(weights_for_surprise, 'b n ... -> (b n) ...')
```
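In store_memories, the residual is applied before the surprise computation: only the last chunk of the previous layer's updates is kept (`t[:, -1:]`) and added onto the weights the surprise is measured against. The library holds these weights in a TensorDict-style container with `.apply` and elementwise addition; the standalone sketch below reproduces the same arithmetic with plain dicts and made-up shapes.

```python
import torch

# Minimal sketch of the prev_weights residual, using plain dicts in place of the
# TensorDict-like container used by titans_pytorch. Shapes are illustrative:
# (batch, num chunks, *weight shape).

def add_prev_weight_residual(weights_for_surprise, prev_weights):
    # keep only the most recent update chunk from the previous layer ...
    prev_last = {k: t[:, -1:] for k, t in prev_weights.items()}
    # ... and add it as a residual (broadcast over this layer's chunks)
    return {k: w + prev_last[k] for k, w in weights_for_surprise.items()}

weights_for_surprise = {'w1': torch.zeros(2, 3, 8, 8)}
prev_weights         = {'w1': torch.randn(2, 5, 8, 8)}

out = add_prev_weight_residual(weights_for_surprise, prev_weights)
print(out['w1'].shape)  # torch.Size([2, 3, 8, 8]) via broadcasting
```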
```diff
@@ -655,15 +662,19 @@ class NeuralMemory(Module):
         self,
         seq,
         past_weights: dict[str, Tensor],
+        chunk_size = None,
+        need_pad = True
     ):
-        chunk_size = self.retrieve_chunk_size
+        chunk_size = default(chunk_size, self.retrieve_chunk_size)
         batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
 
-
+        need_pad = need_pad or chunk_size > 1
+
+        if need_pad:
+            seq = pad_at_dim(seq, (1, 0), dim = 1)
 
-        seq = pad_at_dim(seq, (1, 0), dim = 1)
         seq_len_plus_one = seq.shape[-2]
 
         next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)
```
```diff
@@ -718,94 +729,37 @@ class NeuralMemory(Module):
 
         # restore, pad with empty memory embed
 
-
-
-        return values
-
-    @torch.no_grad()
-    def forward_inference(
-        self,
-        token: Tensor,
-        state: NeuralMemCache | None = None,
-    ):
-        # unpack previous state
-
-        if not exists(state):
-            state = (0, None, None, None, None)
-
-        seq_index, weights, cache_store_seq, past_states, updates = state
-
-        curr_seq_len = seq_index + 1
-        batch = token.shape[0]
-
-        if token.ndim == 2:
-            token = rearrange(token, 'b d -> b 1 d')
-
-        assert token.shape[1] == 1
-
-        # increment the sequence cache which is at most the chunk size
-
-        cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)
-
-        # early return empty memory, when no memories are stored for steps < first chunk size
-
-        if curr_seq_len < self.chunk_size:
-            retrieve = self.retrieve_memories(token, weights, chunk_size = 1)
-
-            output = retrieve, NeuralMemCache(curr_seq_len, weights, cache_store_seq, past_states, updates)
-
-            return output
-
-        # store if storage sequence cache hits the chunk size
+        if need_pad:
+            values = values[:, 1:]
 
-
-        store_seq_cache_len = cache_store_seq.shape[-2]
-
-        if not exists(updates):
-            updates = weights.clone().zero_()
-            updates = repeat_dict_values(updates, '... -> b 1 ...', b = batch)
-        else:
-            updates = updates.apply(lambda t: t[:, -1:])
-
-        if store_seq_cache_len == self.chunk_size:
-
-            next_updates, store_state = self.store_memories(
-                cache_store_seq,
-                weights,
-                past_state = past_states,
-            )
-
-            updates = next_updates
-            cache_store_seq = None
-            next_states = store_state.states
-
-        # retrieve
-
-        retrieved = self.retrieve_memories(token, updates, chunk_size = 1)
-
-        # next state tuple
-
-        next_store_state = NeuralMemCache(curr_seq_len, weights, cache_store_seq, next_states, updates)
-
-        return retrieved, next_store_state
+        return values[:, :seq_len]
 
     def forward(
         self,
         seq,
         store_seq = None,
         state: NeuralMemCache | None = None,
+        prev_weights = None
     ):
+        if seq.ndim == 2:
+            seq = rearrange(seq, 'b d -> b 1 d')
+
+        is_single_token = seq.shape[1] == 1
+
         if not exists(state):
             state = (0, None, None, None, None)
 
         seq_index, weights, cache_store_seq, past_state, updates = state
 
-        assert not exists(cache_store_seq) or is_empty_tensor(cache_store_seq)
-
         # store
 
         store_seq = default(store_seq, seq)
 
+        # take care of cache
+
+        if exists(cache_store_seq):
+            store_seq = safe_cat((cache_store_seq, store_seq))
+
         # functions
 
         # compute split sizes of sequence
```
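With forward_inference gone, forward now owns the whole state machine: the state tuple starts as `(0, None, None, None, None)` for `(seq_index, weights, cache_store_seq, past_state, updates)`, and instead of asserting that the stored-sequence cache is empty, any cached partial chunk is concatenated in front of the incoming store sequence. The snippet below only illustrates that concatenation step; the helper is hypothetical (the library uses its own `safe_cat`), and the chunk bookkeeping after the split happens elsewhere in forward.

```python
import torch

# Hypothetical helper mirroring the cache handling added to forward: a None cache
# passes the new tokens through, otherwise cached tokens are prepended.
def cat_with_cache(cache_store_seq, store_seq, dim = -2):
    if cache_store_seq is None:
        return store_seq
    return torch.cat((cache_store_seq, store_seq), dim = dim)

cache_store_seq = torch.randn(1, 3, 384)  # partial chunk left over from a previous call
store_seq       = torch.randn(1, 1, 384)  # one new token

store_seq = cat_with_cache(cache_store_seq, store_seq)
print(store_seq.shape)  # torch.Size([1, 4, 384])
```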
```diff
@@ -861,6 +815,7 @@ class NeuralMemory(Module):
             weights,
             seq_index = seq_index,
             past_state = past_state,
+            prev_weights = prev_weights
         )
 
         seq_index = next_neural_mem_state.seq_index
@@ -883,9 +838,21 @@ class NeuralMemory(Module):
 
         # retrieve
 
+        need_pad = True
+        retrieve_chunk_size = None
+
+        if is_single_token:
+            retrieve_chunk_size = 1
+            need_pad = False
+
+            last_update, _ = past_state
+            updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
+
         retrieved = self.retrieve_memories(
             seq,
-            updates
+            updates,
+            chunk_size = retrieve_chunk_size,
+            need_pad = need_pad,
         )
 
         return retrieved, next_neural_mem_state
```
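The updated test shows the intended usage after this consolidation: token-by-token inference goes through `NeuralMemory.forward` with a rolling state, exactly where `forward_inference` was used before. A small sketch of that loop follows; `dim = 384` matches the test, while the `chunk_size` value and the tensor shapes are illustrative assumptions.

```python
import torch
from titans_pytorch import NeuralMemory

# Usage sketch based on tests/test_titans.py above: the same forward handles both
# full sequences (training) and one token at a time (inference) via the state tuple.
mem = NeuralMemory(dim = 384, chunk_size = 4)  # chunk_size chosen for illustration

seq = torch.randn(1, 16, 384)

state = None
retrieved_per_token = []

for token in seq.unbind(dim = 1):
    # token has shape (1, 384); forward reshapes 2D inputs to (batch, 1, dim)
    one_retrieved, state = mem.forward(token, state = state)
    retrieved_per_token.append(one_retrieved)

retrieved = torch.cat(retrieved_per_token, dim = 1)
```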
```diff
--- titans_pytorch-0.2.14/train_mac.py
+++ titans_pytorch-0.2.16/train_mac.py
@@ -36,8 +36,9 @@ NEURAL_MEM_MOMENTUM = True
 NEURAL_MEM_QK_NORM = True
 NEURAL_MEM_MAX_LR = 1e-1
 WINDOW_SIZE = 32
-NEURAL_MEM_SEGMENT_LEN =
+NEURAL_MEM_SEGMENT_LEN = 4 # set smaller for more granularity for learning rate / momentum etc
 NEURAL_MEM_BATCH_SIZE = 128 # set smaller to update the neural memory weights more often as it traverses the sequence
+NEURAL_MEM_WEIGHT_RESIDUAL = True
 SLIDING_WINDOWS = True
 STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
 MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
@@ -87,7 +88,7 @@ model = MemoryAsContextTransformer(
     neural_memory_segment_len = NEURAL_MEM_SEGMENT_LEN,
     neural_memory_batch_size = NEURAL_MEM_BATCH_SIZE,
     neural_mem_gate_attn_output = NEURAL_MEM_GATE_ATTN_OUTPUT,
-
+    neural_mem_weight_residual = NEURAL_MEM_WEIGHT_RESIDUAL,
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
     neural_memory_model = MemoryMLP(
@@ -100,6 +101,7 @@ model = MemoryAsContextTransformer(
         attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
         qk_rmsnorm = NEURAL_MEM_QK_NORM,
         momentum = NEURAL_MEM_MOMENTUM,
+        default_step_transform_max_lr = NEURAL_MEM_MAX_LR,
         use_accelerated_scan = USE_ACCELERATED_SCAN,
         per_parameter_lr_modulation = MEMORY_MODEL_PER_LAYER_LEARNED_LR
     )
```