titans-pytorch 0.2.12__tar.gz → 0.2.15__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.2.12 → titans_pytorch-0.2.15}/PKG-INFO +1 -1
- {titans_pytorch-0.2.12 → titans_pytorch-0.2.15}/pyproject.toml +1 -1
- {titans_pytorch-0.2.12 → titans_pytorch-0.2.15}/tests/test_titans.py +2 -3
- {titans_pytorch-0.2.12 → titans_pytorch-0.2.15}/titans_pytorch/mac_transformer.py +4 -10
- {titans_pytorch-0.2.12 → titans_pytorch-0.2.15}/titans_pytorch/neural_memory.py +58 -80
- {titans_pytorch-0.2.12 → titans_pytorch-0.2.15}/train_mac.py +2 -0
- {titans_pytorch-0.2.12 → titans_pytorch-0.2.15}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.2.12 → titans_pytorch-0.2.15}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.2.12 → titans_pytorch-0.2.15}/.gitignore +0 -0
- {titans_pytorch-0.2.12 → titans_pytorch-0.2.15}/LICENSE +0 -0
- {titans_pytorch-0.2.12 → titans_pytorch-0.2.15}/README.md +0 -0
- {titans_pytorch-0.2.12 → titans_pytorch-0.2.15}/data/README.md +0 -0
- {titans_pytorch-0.2.12 → titans_pytorch-0.2.15}/data/enwik8.gz +0 -0
- {titans_pytorch-0.2.12 → titans_pytorch-0.2.15}/fig1.png +0 -0
- {titans_pytorch-0.2.12 → titans_pytorch-0.2.15}/fig2.png +0 -0
- {titans_pytorch-0.2.12 → titans_pytorch-0.2.15}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.2.12 → titans_pytorch-0.2.15}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.2.12 → titans_pytorch-0.2.15}/titans_pytorch/memory_models.py +0 -0
```diff
--- titans_pytorch-0.2.12/tests/test_titans.py
+++ titans_pytorch-0.2.15/tests/test_titans.py
@@ -148,7 +148,7 @@ def test_mac(
     assert logits.shape == (1, seq_len, 256)
 
 @pytest.mark.parametrize('sliding', (False, True))
-@pytest.mark.parametrize('mem_layers', (()))
+@pytest.mark.parametrize('mem_layers', ((), None))
 @pytest.mark.parametrize('longterm_mems', (0, 4, 16))
 @pytest.mark.parametrize('prompt_len', (4, 16))
 @torch_default_dtype(torch.float64)
@@ -190,7 +190,6 @@ def test_neural_mem_inference(
     prompt_len,
     mem_chunk_size
 ):
-    pytest.skip()
 
     mem = NeuralMemory(
         dim = 384,
@@ -218,7 +217,7 @@ def test_neural_mem_inference(
 
     for token in seq.unbind(dim = 1):
 
-        one_retrieved, state = mem.forward_inference(
+        one_retrieved, state = mem.forward(
             token,
             state = state,
         )
```
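The inference test is no longer skipped and now decodes through the unified `forward` instead of the removed `forward_inference`. A minimal sketch of the parity it exercises, with sizes and tolerances that are illustrative rather than the test's exact parametrization:

```python
# Sketch of the equivalence the revived test checks: retrieving over a full
# sequence in one call should match feeding tokens one at a time while
# threading the returned state. Hyperparameters here are illustrative.
import torch
from titans_pytorch import NeuralMemory

torch.set_default_dtype(torch.float64)  # float64 tightens the parity check

mem = NeuralMemory(dim = 384, chunk_size = 4)

seq = torch.randn(2, 16, 384)

parallel_retrieved, _ = mem(seq)

state = None
retrieved_per_token = []

for token in seq.unbind(dim = 1):
    one_retrieved, state = mem.forward(token, state = state)
    retrieved_per_token.append(one_retrieved)

sequential_retrieved = torch.cat(retrieved_per_token, dim = 1)

assert torch.allclose(parallel_retrieved, sequential_retrieved, atol = 1e-5)
```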
```diff
--- titans_pytorch-0.2.12/titans_pytorch/mac_transformer.py
+++ titans_pytorch-0.2.15/titans_pytorch/mac_transformer.py
@@ -761,16 +761,10 @@ class MemoryAsContextTransformer(Module):
 
             mem_input, add_residual = mem_hyper_conn(x)
 
-            if not is_inferencing:
-                (retrieved, next_neural_mem_cache) = mem(
-                    mem_input,
-                    state = next(neural_mem_caches, None),
-                )
-            else:
-                (retrieved, next_neural_mem_cache) = mem.forward_inference(
-                    mem_input,
-                    state = next(neural_mem_caches, None),
-                )
+            retrieved, next_neural_mem_cache = mem.forward(
+                mem_input,
+                state = next(neural_mem_caches, None),
+            )
 
             if self.gate_attn_output:
                 attn_out_gates = retrieved.sigmoid()
```
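The context line `attn_out_gates = retrieved.sigmoid()` shows how the retrieved memories can gate the attention branch. A standalone illustration, with stand-in tensors and an assumed elementwise multiply for the gating step itself:

```python
# Minimal illustration of the gating path that follows the unified call:
# when gate_attn_output is set, the retrieved memories are squashed to
# (0, 1) with a sigmoid and scale the attention branch elementwise.
# Shapes and the attn_out tensor are stand-ins, not the module's real ones.
import torch

batch, seq_len, dim = 2, 8, 16

retrieved = torch.randn(batch, seq_len, dim)  # output of mem.forward(...)
attn_out  = torch.randn(batch, seq_len, dim)  # output of the attention block

attn_out_gates = retrieved.sigmoid()          # per-element gates in (0, 1)
gated_attn_out = attn_out * attn_out_gates    # memory decides what passes

print(gated_attn_out.shape)  # torch.Size([2, 8, 16])
```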
```diff
--- titans_pytorch-0.2.12/titans_pytorch/neural_memory.py
+++ titans_pytorch-0.2.15/titans_pytorch/neural_memory.py
@@ -284,7 +284,7 @@ class NeuralMemory(Module):
         adaptive_step_transform: Callable | None = None,
         default_step_transform_max_lr = 1.,
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
-        max_mem_layer_modulation = 1e1, # max of 10.
+        max_mem_layer_modulation = 1., # max of 10.
         attn_pool_chunks = False,
         momentum = True,
         pre_rmsnorm = True,
@@ -293,6 +293,9 @@ class NeuralMemory(Module):
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         activation: Module | None = None,
+        init_adaptive_step_bias = None,
+        init_momentum_bias = None,
+        init_decay_bias = None,
         default_model_kwargs: dict = dict(
             depth = 2
         )
@@ -411,12 +414,12 @@ class NeuralMemory(Module):
         # learned adaptive learning rate and momentum
 
         self.to_momentum = Sequential(
-            LinearNoBias(dim, heads),
+            nn.Linear(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         ) if momentum else None
 
         self.to_adaptive_step = Sequential(
-            LinearNoBias(dim, heads),
+            nn.Linear(dim, heads),
             Rearrange('b n h -> (b h) n')
         )
 
@@ -428,7 +431,7 @@ class NeuralMemory(Module):
         # per layer learning rate modulation
 
         self.to_layer_modulation = Sequential(
-            LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
+            nn.Linear(dim, heads * self.num_memory_parameter_tensors),
             Rearrange('b n (h w) -> w (b h) n', h = heads),
             nn.Sigmoid()
         ) if per_parameter_lr_modulation else None
@@ -442,10 +445,27 @@ class NeuralMemory(Module):
         # weight decay factor
 
         self.to_decay_factor = Sequential(
-            LinearNoBias(dim, heads),
+            nn.Linear(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
 
+        # inits
+
+        if exists(init_adaptive_step_bias):
+            linear = self.to_adaptive_step[0]
+            nn.init.zeros_(linear.weight)
+            nn.init.constant_(linear.bias, init_adaptive_step_bias)
+
+        if exists(init_momentum_bias):
+            linear = self.to_momentum[0]
+            nn.init.zeros_(linear.weight)
+            nn.init.constant_(linear.bias, init_momentum_bias)
+
+        if exists(init_decay_bias):
+            linear = self.to_decay_factor[0]
+            nn.init.zeros_(linear.weight)
+            nn.init.constant_(linear.bias, init_decay_bias)
+
         # maybe use accelerated scan
 
         self.use_accelerated_scan = use_accelerated_scan
```
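The three new `init_*` kwargs zero the projection weight and set a constant bias, so each data-dependent gate starts at a chosen operating point; the projections now use `nn.Linear` so there is a bias to initialize. In isolation the effect looks like this (the `-4.` is an arbitrary example value, not a recommended setting):

```python
# What the new init_* flags do, in isolation: zeroing the weight and fixing
# the bias makes the projection output a constant at initialization, so the
# sigmoid that typically follows starts at a chosen operating point.
import torch
from torch import nn

dim, heads = 384, 4

to_decay_factor = nn.Linear(dim, heads)

init_decay_bias = -4.
nn.init.zeros_(to_decay_factor.weight)
nn.init.constant_(to_decay_factor.bias, init_decay_bias)

x = torch.randn(2, 16, dim)
out = to_decay_factor(x)

# every position starts with the same pre-activation, hence the same gate
print(out.unique())                 # tensor([-4.])
print(torch.sigmoid(out).unique())  # ~0.018, i.e. almost no decay at first
```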
```diff
--- titans_pytorch-0.2.12/titans_pytorch/neural_memory.py
+++ titans_pytorch-0.2.15/titans_pytorch/neural_memory.py
@@ -635,15 +655,19 @@ class NeuralMemory(Module):
         self,
         seq,
         past_weights: dict[str, Tensor],
+        chunk_size = None,
+        need_pad = True
     ):
-        chunk_size = self.retrieve_chunk_size
+        chunk_size = default(chunk_size, self.retrieve_chunk_size)
         batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
 
-
+        need_pad = need_pad or chunk_size > 1
+
+        if need_pad:
+            seq = pad_at_dim(seq, (1, 0), dim = 1)
 
-        seq = pad_at_dim(seq, (1, 0), dim = 1)
         seq_len_plus_one = seq.shape[-2]
 
         next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)
```
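`retrieve_memories` front-pads the sequence by one position, so each token retrieves with weights formed from earlier chunks, then rounds the length up to a chunk multiple; the new `chunk_size` and `need_pad` arguments let a single-token call skip that work. The arithmetic, with local stand-ins for the repo's `pad_at_dim` and `round_up_multiple` helpers:

```python
# Standalone sketch of the padding arithmetic in retrieve_memories.
# round_up_multiple and pad_at_dim mimic the repo's helpers; the real
# implementations live in titans_pytorch and are only illustrated here.
import torch
import torch.nn.functional as F

def round_up_multiple(n, mult):
    return ((n + mult - 1) // mult) * mult

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # pad = (left, right) applied at `dim`; F.pad counts dims from the end
    dims_from_right = (t.ndim - dim - 1) if dim >= 0 else (-dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

chunk_size = 4
seq = torch.randn(1, 10, 8)              # (batch, seq_len, dim)
seq_len = seq.shape[1]

seq = pad_at_dim(seq, (1, 0), dim = 1)   # shift by one: retrieve with past weights
seq_len_plus_one = seq.shape[-2]

next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)
padding = next_seq_len - seq_len_plus_one
seq = pad_at_dim(seq, (0, padding), dim = 1)

print(seq.shape)  # torch.Size([1, 12, 8]) -> 11 rounded up to 12
# after retrieval the code drops the shifted position and trims back:
# values = values[:, 1:][:, :seq_len]
```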
```diff
--- titans_pytorch-0.2.12/titans_pytorch/neural_memory.py
+++ titans_pytorch-0.2.15/titans_pytorch/neural_memory.py
@@ -698,76 +722,10 @@ class NeuralMemory(Module):
 
         # restore, pad with empty memory embed
 
-        values = values[:, 1:]
-
-        return values
-
-    @torch.no_grad()
-    def forward_inference(
-        self,
-        token: Tensor,
-        state: NeuralMemCache | None = None,
-    ):
-        # unpack previous state
-
-        if not exists(state):
-            state = (0, None, None, None, None)
-
-        seq_index, weights, cache_store_seq, past_states, updates = state
-
-        curr_seq_len = seq_index + 1
-        batch = token.shape[0]
-
-        if token.ndim == 2:
-            token = rearrange(token, 'b d -> b 1 d')
-
-        assert token.shape[1] == 1
-
-        # increment the sequence cache which is at most the chunk size
-
-        cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)
-
-        # early return empty memory, when no memories are stored for steps < first chunk size
-
-        if curr_seq_len < self.chunk_size:
-            retrieve = self.retrieve_memories(token, weights, chunk_size = 1)
-
-            output = retrieve, NeuralMemCache(curr_seq_len, weights, cache_store_seq, past_states, updates)
+        if need_pad:
+            values = values[:, 1:]
 
-            return output
-
-        # store if storage sequence cache hits the chunk size
-
-        next_states = past_states
-        store_seq_cache_len = cache_store_seq.shape[-2]
-
-        if not exists(updates):
-            updates = weights.clone().zero_()
-            updates = repeat_dict_values(updates, '... -> b 1 ...', b = batch)
-        else:
-            updates = updates.apply(lambda t: t[:, -1:])
-
-        if store_seq_cache_len == self.chunk_size:
-
-            next_updates, store_state = self.store_memories(
-                cache_store_seq,
-                weights,
-                past_state = past_states,
-            )
-
-            updates = next_updates
-            cache_store_seq = None
-            next_states = store_state.states
-
-        # retrieve
-
-        retrieved = self.retrieve_memories(token, updates, chunk_size = 1)
-
-        # next state tuple
-
-        next_store_state = NeuralMemCache(curr_seq_len, weights, cache_store_seq, next_states, updates)
-
-        return retrieved, next_store_state
+        return values[:, :seq_len]
 
     def forward(
         self,
@@ -775,17 +733,25 @@ class NeuralMemory(Module):
         store_seq = None,
         state: NeuralMemCache | None = None,
     ):
+        if seq.ndim == 2:
+            seq = rearrange(seq, 'b d -> b 1 d')
+
+        is_single_token = seq.shape[1] == 1
+
         if not exists(state):
             state = (0, None, None, None, None)
 
         seq_index, weights, cache_store_seq, past_state, updates = state
 
-        assert not exists(cache_store_seq) or is_empty_tensor(cache_store_seq)
-
         # store
 
         store_seq = default(store_seq, seq)
 
+        # take care of cache
+
+        if exists(cache_store_seq):
+            store_seq = safe_cat((cache_store_seq, store_seq))
+
         # functions
 
         # compute split sizes of sequence
```
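With `forward_inference` gone, `forward` itself now absorbs any leftover partial chunk by concatenating the cached tokens in front of the incoming store sequence. `safe_cat` below is a None-tolerant concat standing in for the repo's helper of the same name:

```python
# Sketch of the cache merge now done inside forward: safe_cat here is a
# local stand-in, not the library's implementation.
import torch

def safe_cat(tensors, dim = -2):
    tensors = [t for t in tensors if t is not None]
    if len(tensors) == 0:
        return None
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim = dim)

dim = 8

cache_store_seq = torch.randn(1, 3, dim)  # partial chunk left over from prior steps
store_seq = torch.randn(1, 1, dim)        # the newly arriving token(s)

store_seq = safe_cat((cache_store_seq, store_seq))
print(store_seq.shape)  # torch.Size([1, 4, 8]) -> a full chunk, ready to store
```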
```diff
--- titans_pytorch-0.2.12/titans_pytorch/neural_memory.py
+++ titans_pytorch-0.2.15/titans_pytorch/neural_memory.py
@@ -863,9 +829,21 @@ class NeuralMemory(Module):
 
         # retrieve
 
+        need_pad = True
+        retrieve_chunk_size = None
+
+        if is_single_token:
+            retrieve_chunk_size = 1
+            need_pad = False
+
+            last_update, _ = past_state
+            updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
+
         retrieved = self.retrieve_memories(
             seq,
-            updates
+            updates,
+            chunk_size = retrieve_chunk_size,
+            need_pad = need_pad,
         )
 
         return retrieved, next_neural_mem_state
```
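On the single-token path, the updates are rebuilt from the last stored state, with a singleton chunk axis added to every tensor in the dict. `rearrange_dict_values` is presumably an einops pattern mapped over dict values; an equivalent stand-in (the key names and shapes are illustrative):

```python
# The single-token branch reshapes every tensor in the updates dict with an
# einops pattern. A stand-in for rearrange_dict_values:
import torch
from einops import rearrange

def rearrange_dict_values(d, pattern, **kwargs):
    return {k: rearrange(v, pattern, **kwargs) for k, v in d.items()}

last_update = {
    'weights.0': torch.randn(2, 4, 16, 16),  # illustrative per-parameter update
    'weights.1': torch.randn(2, 4, 16, 16),
}

updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')

print(updates['weights.0'].shape)  # torch.Size([2, 1, 4, 16, 16])
```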
```diff
--- titans_pytorch-0.2.12/train_mac.py
+++ titans_pytorch-0.2.15/train_mac.py
@@ -34,6 +34,7 @@ NEURAL_MEM_LAYERS = (2, 4, 6) # layers 2, 4, 6 have neural mem
 NEURAL_MEM_GATE_ATTN_OUTPUT = False
 NEURAL_MEM_MOMENTUM = True
 NEURAL_MEM_QK_NORM = True
+NEURAL_MEM_MAX_LR = 1e-1
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = 2 # set smaller for more granularity for learning rate / momentum etc
 NEURAL_MEM_BATCH_SIZE = 128 # set smaller to update the neural memory weights more often as it traverses the sequence
@@ -98,6 +99,7 @@ model = MemoryAsContextTransformer(
         attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
         qk_rmsnorm = NEURAL_MEM_QK_NORM,
         momentum = NEURAL_MEM_MOMENTUM,
+        default_step_transform_max_lr = NEURAL_MEM_MAX_LR,
         use_accelerated_scan = USE_ACCELERATED_SCAN,
         per_parameter_lr_modulation = MEMORY_MODEL_PER_LAYER_LEARNED_LR
 )
```
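The training script now routes `NEURAL_MEM_MAX_LR` into `default_step_transform_max_lr`, capping the learned per-token step size of the neural memory. If the default transform is a sigmoid scaled by this cap (an assumption based on the kwarg's name, not a verbatim copy of the library code), lowering it to `1e-1` bounds every online update:

```python
# How the new NEURAL_MEM_MAX_LR knob plausibly acts: the constructor's
# default_step_transform_max_lr caps the learned per-token step size.
# The sigmoid-times-cap form below is an assumed sketch of the default.
import torch

NEURAL_MEM_MAX_LR = 1e-1

def default_adaptive_step_transform(adaptive_step, max_lr = NEURAL_MEM_MAX_LR):
    return adaptive_step.sigmoid() * max_lr

raw_step = torch.randn(4, 32)            # unconstrained network output
lr = default_adaptive_step_transform(raw_step)

print(float(lr.min()), float(lr.max()))  # always within (0, 0.1)
```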