titans-pytorch 0.2.14__tar.gz → 0.2.15__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/PKG-INFO +1 -1
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/pyproject.toml +1 -1
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/tests/test_titans.py +2 -3
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/titans_pytorch/mac_transformer.py +4 -10
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/titans_pytorch/neural_memory.py +33 -75
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/train_mac.py +1 -1
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/.gitignore +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/LICENSE +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/README.md +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/data/README.md +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/data/enwik8.gz +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/fig1.png +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/fig2.png +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/titans_pytorch/memory_models.py +0 -0
{titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/tests/test_titans.py

@@ -148,7 +148,7 @@ def test_mac(
     assert logits.shape == (1, seq_len, 256)
 
 @pytest.mark.parametrize('sliding', (False, True))
-@pytest.mark.parametrize('mem_layers', (()))
+@pytest.mark.parametrize('mem_layers', ((), None))
 @pytest.mark.parametrize('longterm_mems', (0, 4, 16))
 @pytest.mark.parametrize('prompt_len', (4, 16))
 @torch_default_dtype(torch.float64)
@@ -190,7 +190,6 @@ def test_neural_mem_inference(
     prompt_len,
     mem_chunk_size
 ):
-    pytest.skip()
 
     mem = NeuralMemory(
         dim = 384,
@@ -218,7 +217,7 @@ def test_neural_mem_inference(
 
     for token in seq.unbind(dim = 1):
 
-        one_retrieved, state = mem.forward_inference(
+        one_retrieved, state = mem.forward(
             token,
             state = state,
         )
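Read together, the three test hunks show the API change driving this release: the previously skipped inference test now runs, and token-by-token decoding goes through the same `forward` used for full sequences, threading the returned state back in. A minimal sketch of that loop, modeled on the test (the `chunk_size` value is illustrative):

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 2)  # chunk_size chosen for illustration

seq = torch.randn(1, 8, 384)

state = None
retrieved = []

# decode one token at a time; each token from unbind has shape (1, 384)
for token in seq.unbind(dim = 1):
    one_retrieved, state = mem.forward(token, state = state)
    retrieved.append(one_retrieved)

retrieved = torch.cat(retrieved, dim = 1)
assert retrieved.shape == seq.shape

The test suite applies `torch_default_dtype(torch.float64)`, presumably so this sequential path can be compared against a single full-sequence call within tight tolerance.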
{titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/titans_pytorch/mac_transformer.py

@@ -761,16 +761,10 @@ class MemoryAsContextTransformer(Module):
 
                 mem_input, add_residual = mem_hyper_conn(x)
 
-                if not is_inferencing:
-                    (retrieved, next_neural_mem_cache) = mem(
-                        mem_input,
-                        state = next(neural_mem_caches, None),
-                    )
-                else:
-                    (retrieved, next_neural_mem_cache) = mem.forward_inference(
-                        mem_input,
-                        state = next(neural_mem_caches, None),
-                    )
+                retrieved, next_neural_mem_cache = mem.forward(
+                    mem_input,
+                    state = next(neural_mem_caches, None),
+                )
 
                 if self.gate_attn_output:
                     attn_out_gates = retrieved.sigmoid()
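On the transformer side, the deleted branch means there is no longer a separate inference path: the layer loop makes the same `mem.forward` call in both regimes, pulling each layer's previous cache from an iterator so that `next(neural_mem_caches, None)` yields `None` on the first pass. A schematic of that threading pattern, not the library's actual layer loop (layer structure and shapes simplified):

import torch
from titans_pytorch import NeuralMemory

def run_memory_layers(mem_layers, x, caches = None):
    # caches: per-layer NeuralMemCache states returned by the previous call,
    # or None on the very first call
    neural_mem_caches = iter(caches) if caches is not None else iter(())
    next_caches = []

    for mem in mem_layers:
        retrieved, next_cache = mem.forward(
            x,
            state = next(neural_mem_caches, None),  # None when no cache exists yet
        )
        next_caches.append(next_cache)
        # in the real model, `retrieved` then gates or augments the attention branch

    return next_caches

mem_layers = [NeuralMemory(dim = 384, chunk_size = 2) for _ in range(2)]

caches = run_memory_layers(mem_layers, torch.randn(1, 8, 384))       # prefill
caches = run_memory_layers(mem_layers, torch.randn(1, 384), caches)  # one decode step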
{titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/titans_pytorch/neural_memory.py

@@ -655,15 +655,19 @@ class NeuralMemory(Module):
         self,
         seq,
         past_weights: dict[str, Tensor],
+        chunk_size = None,
+        need_pad = True
     ):
-        chunk_size = self.retrieve_chunk_size
+        chunk_size = default(chunk_size, self.retrieve_chunk_size)
         batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
 
-
+        need_pad = need_pad or chunk_size > 1
+
+        if need_pad:
+            seq = pad_at_dim(seq, (1, 0), dim = 1)
 
-        seq = pad_at_dim(seq, (1, 0), dim = 1)
         seq_len_plus_one = seq.shape[-2]
 
         next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)
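The new `need_pad` flag makes explicit what `retrieve_memories` was already doing: one empty position is prepended along the sequence axis so that each position retrieves with memory formed only from earlier tokens, and the output is later shifted and trimmed back. The bookkeeping in isolation, with `torch.nn.functional.pad` standing in for the library's `pad_at_dim` helper:

import torch
import torch.nn.functional as F

seq = torch.randn(1, 6, 4)  # (batch, seq_len, dim)
seq_len = seq.shape[1]

# prepend one empty position on the sequence axis,
# mirroring pad_at_dim(seq, (1, 0), dim = 1)
padded = F.pad(seq, (0, 0, 1, 0))
assert padded.shape[1] == seq_len + 1

# retrieval runs on the padded sequence; use it as a stand-in for `values`
values = padded

# undo the shift and trim to the original length,
# mirroring values[:, 1:] followed by values[:, :seq_len]
values = values[:, 1:][:, :seq_len]
assert values.shape == seq.shape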
@@ -718,76 +722,10 @@ class NeuralMemory(Module):
 
         # restore, pad with empty memory embed
 
-
-
-        return values
-
-    @torch.no_grad()
-    def forward_inference(
-        self,
-        token: Tensor,
-        state: NeuralMemCache | None = None,
-    ):
-        # unpack previous state
-
-        if not exists(state):
-            state = (0, None, None, None, None)
-
-        seq_index, weights, cache_store_seq, past_states, updates = state
-
-        curr_seq_len = seq_index + 1
-        batch = token.shape[0]
+        if need_pad:
+            values = values[:, 1:]
 
-
-        token = rearrange(token, 'b d -> b 1 d')
-
-        assert token.shape[1] == 1
-
-        # increment the sequence cache which is at most the chunk size
-
-        cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)
-
-        # early return empty memory, when no memories are stored for steps < first chunk size
-
-        if curr_seq_len < self.chunk_size:
-            retrieve = self.retrieve_memories(token, weights, chunk_size = 1)
-
-            output = retrieve, NeuralMemCache(curr_seq_len, weights, cache_store_seq, past_states, updates)
-
-            return output
-
-        # store if storage sequence cache hits the chunk size
-
-        next_states = past_states
-        store_seq_cache_len = cache_store_seq.shape[-2]
-
-        if not exists(updates):
-            updates = weights.clone().zero_()
-            updates = repeat_dict_values(updates, '... -> b 1 ...', b = batch)
-        else:
-            updates = updates.apply(lambda t: t[:, -1:])
-
-        if store_seq_cache_len == self.chunk_size:
-
-            next_updates, store_state = self.store_memories(
-                cache_store_seq,
-                weights,
-                past_state = past_states,
-            )
-
-            updates = next_updates
-            cache_store_seq = None
-            next_states = store_state.states
-
-        # retrieve
-
-        retrieved = self.retrieve_memories(token, updates, chunk_size = 1)
-
-        # next state tuple
-
-        next_store_state = NeuralMemCache(curr_seq_len, weights, cache_store_seq, next_states, updates)
-
-        return retrieved, next_store_state
+        return values[:, :seq_len]
 
     def forward(
         self,
@@ -795,17 +733,25 @@ class NeuralMemory(Module):
         store_seq = None,
         state: NeuralMemCache | None = None,
     ):
+        if seq.ndim == 2:
+            seq = rearrange(seq, 'b d -> b 1 d')
+
+        is_single_token = seq.shape[1] == 1
+
         if not exists(state):
             state = (0, None, None, None, None)
 
         seq_index, weights, cache_store_seq, past_state, updates = state
 
-        assert not exists(cache_store_seq) or is_empty_tensor(cache_store_seq)
-
         # store
 
         store_seq = default(store_seq, seq)
 
+        # take care of cache
+
+        if exists(cache_store_seq):
+            store_seq = safe_cat((cache_store_seq, store_seq))
+
         # functions
 
         # compute split sizes of sequence
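The five-tuple unpacked above is the cache carried between calls — `(seq_index, weights, cache_store_seq, past_state, updates)`, initialized to `(0, None, None, None, None)` — and the added `safe_cat` line splices any cached, not-yet-stored tokens onto the front of the incoming store sequence. A toy illustration of that accumulation, with a stand-in for the library's `safe_cat` (None-skipping concatenation is its assumed behavior):

import torch

def safe_cat(tensors, dim = -2):
    # concatenate along `dim`, skipping None entries (assumed behavior of the helper)
    tensors = [t for t in tensors if t is not None]
    if len(tensors) == 0:
        return None
    return torch.cat(tensors, dim = dim)

cache_store_seq = None

# tokens arrive one decode step at a time, shape (batch, 1, dim)
for _ in range(3):
    token = torch.randn(1, 1, 4)
    cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)

# three tokens now sit in the cache, waiting until a full chunk
# can be written into the memory weights
print(cache_store_seq.shape)  # torch.Size([1, 3, 4])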
@@ -883,9 +829,21 @@ class NeuralMemory(Module):
 
         # retrieve
 
+        need_pad = True
+        retrieve_chunk_size = None
+
+        if is_single_token:
+            retrieve_chunk_size = 1
+            need_pad = False
+
+            last_update, _ = past_state
+            updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
+
         retrieved = self.retrieve_memories(
             seq,
-            updates
+            updates,
+            chunk_size = retrieve_chunk_size,
+            need_pad = need_pad,
         )
 
         return retrieved, next_neural_mem_state
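For a single decoding token, the retrieval path skips the causal padding (`need_pad = False`, `chunk_size = 1`) and reads from the most recent weight update, with `rearrange_dict_values` apparently mapping an einops rearrange over every tensor in the weights dict. A stand-in with that assumed behavior:

import torch
from einops import rearrange

def rearrange_dict_values(d, pattern, **kwargs):
    # apply the same einops rearrange to every tensor value of a dict
    # (assumed behavior of the library helper of the same name)
    return {k: rearrange(v, pattern, **kwargs) for k, v in d.items()}

last_update = {'weights.0': torch.randn(2, 8, 8), 'weights.1': torch.randn(2, 8, 8)}

# give each per-batch weight a singleton chunk dimension, as in the diff
updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
print(updates['weights.0'].shape)  # torch.Size([2, 1, 8, 8])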
{titans_pytorch-0.2.14 → titans_pytorch-0.2.15}/train_mac.py

@@ -87,7 +87,6 @@ model = MemoryAsContextTransformer(
    neural_memory_segment_len = NEURAL_MEM_SEGMENT_LEN,
    neural_memory_batch_size = NEURAL_MEM_BATCH_SIZE,
    neural_mem_gate_attn_output = NEURAL_MEM_GATE_ATTN_OUTPUT,
-   default_step_transform_max_lr = NEURAL_MEM_MAX_LR,
    use_flex_attn = USE_FLEX_ATTN,
    sliding_window_attn = SLIDING_WINDOWS,
    neural_memory_model = MemoryMLP(
@@ -100,6 +99,7 @@ model = MemoryAsContextTransformer(
        attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
        qk_rmsnorm = NEURAL_MEM_QK_NORM,
        momentum = NEURAL_MEM_MOMENTUM,
+       default_step_transform_max_lr = NEURAL_MEM_MAX_LR,
        use_accelerated_scan = USE_ACCELERATED_SCAN,
        per_parameter_lr_modulation = MEMORY_MODEL_PER_LAYER_LEARNED_LR
    )
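The two train_mac.py hunks are a single move: `default_step_transform_max_lr` is no longer passed at the top level of `MemoryAsContextTransformer` and instead rides inside the nested per-memory kwargs dict alongside options like `momentum`. A sketch of the resulting call shape, assuming README-style constructor arguments and that the nested dict is `neural_memory_kwargs` (values are placeholders):

import torch
from titans_pytorch import MemoryAsContextTransformer

NEURAL_MEM_MAX_LR = 1e-1  # placeholder value

model = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
    neural_memory_kwargs = dict(
        default_step_transform_max_lr = NEURAL_MEM_MAX_LR,  # now lives here, not at the top level
    ),
)

ids = torch.randint(0, 256, (1, 1024))
loss = model(ids, return_loss = True)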