titans-pytorch 0.2.10__tar.gz → 0.2.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.2.10 → titans_pytorch-0.2.11}/PKG-INFO +1 -1
- {titans_pytorch-0.2.10 → titans_pytorch-0.2.11}/pyproject.toml +1 -1
- {titans_pytorch-0.2.10 → titans_pytorch-0.2.11}/tests/test_titans.py +39 -16
- {titans_pytorch-0.2.10 → titans_pytorch-0.2.11}/titans_pytorch/mac_transformer.py +2 -0
- {titans_pytorch-0.2.10 → titans_pytorch-0.2.11}/titans_pytorch/neural_memory.py +110 -27
- {titans_pytorch-0.2.10 → titans_pytorch-0.2.11}/train_mac.py +3 -1
- {titans_pytorch-0.2.10 → titans_pytorch-0.2.11}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.2.10 → titans_pytorch-0.2.11}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.2.10 → titans_pytorch-0.2.11}/.gitignore +0 -0
- {titans_pytorch-0.2.10 → titans_pytorch-0.2.11}/LICENSE +0 -0
- {titans_pytorch-0.2.10 → titans_pytorch-0.2.11}/README.md +0 -0
- {titans_pytorch-0.2.10 → titans_pytorch-0.2.11}/data/README.md +0 -0
- {titans_pytorch-0.2.10 → titans_pytorch-0.2.11}/data/enwik8.gz +0 -0
- {titans_pytorch-0.2.10 → titans_pytorch-0.2.11}/fig1.png +0 -0
- {titans_pytorch-0.2.10 → titans_pytorch-0.2.11}/fig2.png +0 -0
- {titans_pytorch-0.2.10 → titans_pytorch-0.2.11}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.2.10 → titans_pytorch-0.2.11}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.2.10 → titans_pytorch-0.2.11}/titans_pytorch/memory_models.py +0 -0
tests/test_titans.py

@@ -73,41 +73,62 @@ def test_titans_attn_memory():

     assert seq.shape == retrieved.shape

-def
-    mem
+def test_neural_mem_chaining_chunks():
+    mem = NeuralMemory(
         dim = 384,
-
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 16
     )

-
-    store_seq = torch.randn(2, 64 * 32, 384)
+    seq = torch.randn(2, 48, 384)

-
+    parallel_retrieved, state = mem(seq)

-
+    seq_first, seq_second, seq_third = seq.split(16, dim = 1)

-
-    mem =
+    first_retrieved, state = mem(seq_first)
+    second_retrieved, state = mem(seq_second, state = state)
+    third_retrieved, state = mem(seq_third, state = state)
+
+    assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved, third_retrieved), dim = 1), atol = 1e-5)
+
+def test_neural_mem_chaining_with_batch_size():
+    mem = NeuralMemory(
         dim = 384,
-
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 16,
+        batch_size = 64
     )

-    seq = torch.randn(2,
-    store_seq = torch.randn(2, 128 * 8, 384)
+    seq = torch.randn(2, 112, 384)

-
+    parallel_retrieved, state = mem(seq)

-
+    seq_first, seq_second, seq_third = seq[:, :16], seq[:, 16:64], seq[:, 64:]
+
+    first_retrieved, state = mem(seq_first)
+    second_retrieved, state = mem(seq_second, state = state)
+    third_retrieved, state = mem(seq_third, state = state)
+
+    parallel_part_retrieved = torch.cat((first_retrieved, second_retrieved, third_retrieved), dim = 1)
+
+    assert torch.allclose(parallel_retrieved, parallel_part_retrieved, atol = 1e-5)

 @pytest.mark.parametrize('seq_len', (1023, 17))
 @pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
 @pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
 @pytest.mark.parametrize('neural_mem_gate_attn_output', (False, True))
+@pytest.mark.parametrize('neural_mem_segment_len', (8, 16))
+@pytest.mark.parametrize('neural_mem_batch_size', (None, 64))
 def test_mac(
     seq_len,
     num_persist_mem_tokens,
     num_longterm_mem_tokens,
-    neural_mem_gate_attn_output
+    neural_mem_gate_attn_output,
+    neural_mem_segment_len,
+    neural_mem_batch_size
 ):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,

@@ -116,7 +137,9 @@ def test_mac(
         num_persist_mem_tokens = num_persist_mem_tokens,
         num_longterm_mem_tokens = num_longterm_mem_tokens,
         segment_len = 128,
-        neural_mem_gate_attn_output = neural_mem_gate_attn_output
+        neural_mem_gate_attn_output = neural_mem_gate_attn_output,
+        neural_memory_segment_len = neural_mem_segment_len,
+        neural_memory_batch_size = neural_mem_batch_size,
     )

     x = torch.randint(0, 256, (1, seq_len))

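The two new chaining tests pin down the reworked NeuralMemory.forward contract: one parallel pass over the full sequence must match the concatenation of several shorter calls that thread the returned NeuralMemCache back in via `state`. A minimal sketch of that calling pattern, with illustrative dimensions (the import path and the 1e-5 tolerance mirror the tests above; everything else is an assumed example, not library-mandated values):

import torch
from titans_pytorch import NeuralMemory

# memory whose store / retrieve chunk size is 16 tokens
mem = NeuralMemory(
    dim = 384,
    chunk_size = 16
)

seq = torch.randn(2, 48, 384)

# one parallel pass over the whole sequence
parallel_retrieved, _ = mem(seq)

# the same sequence fed 16 tokens at a time, carrying the returned state forward
state = None
chunks = []

for chunk in seq.split(16, dim = 1):
    retrieved, state = mem(chunk, state = state)
    chunks.append(retrieved)

chained_retrieved = torch.cat(chunks, dim = 1)

assert torch.allclose(parallel_retrieved, chained_retrieved, atol = 1e-5)
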
titans_pytorch/mac_transformer.py

@@ -481,6 +481,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_add_value_residual = False,
         num_longterm_mem_tokens = 0,
         num_persist_mem_tokens = 0,
+        neural_memory_batch_size = None,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,

@@ -551,6 +552,7 @@ class MemoryAsContextTransformer(Module):
             mem = NeuralMemory(
                 dim = dim,
                 chunk_size = self.neural_memory_segment_len,
+                batch_size = neural_memory_batch_size,
                 model = deepcopy(neural_memory_model),
                 **neural_memory_kwargs
             )

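The new neural_memory_batch_size kwarg is simply forwarded to every per-layer NeuralMemory as batch_size. A hedged construction sketch (dim and depth are not shown in this diff and are assumed from the package's existing constructor; the other values mirror the updated test_mac / train_mac settings):

import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 384,                        # assumed existing kwarg
    depth = 2,                        # assumed existing kwarg
    segment_len = 128,
    num_persist_mem_tokens = 16,
    num_longterm_mem_tokens = 16,
    neural_memory_segment_len = 16,   # becomes each NeuralMemory's chunk_size
    neural_memory_batch_size = 64     # new in 0.2.11, forwarded as batch_size
)

x = torch.randint(0, 256, (1, 1023))
logits = transformer(x)               # (batch, seq_len, num_tokens) logits
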
titans_pytorch/neural_memory.py

@@ -6,7 +6,7 @@ from functools import partial
 from collections import namedtuple

 import torch
-from torch import nn, cat, Tensor
+from torch import nn, cat, tensor, Tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, Parameter, ParameterList
 from torch.func import functional_call, vmap, grad

@@ -39,7 +39,7 @@ w - num memory network weight parameters
 LinearNoBias = partial(Linear, bias = False)

 NeuralMemCache = namedtuple('NeuralMemCache', [
-    '
+    'seq_index',
     'weights',
     'cache_store_segment',
     'states',

@@ -63,6 +63,9 @@ def identity(t):
 def xnor(x, y):
     return not (x ^ y)

+def divisible_by(num, den):
+    return (num % den) == 0
+
 def safe_cat(inputs, dim = -2):
     inputs = tuple(filter(exists, inputs))

@@ -73,6 +76,9 @@ def safe_cat(inputs, dim = -2):

     return cat(inputs, dim = dim)

+def is_empty_tensor(t):
+    return t.numel() == 0
+
 def dict_get_shape(td):
     return {k: v.shape for k, v in td.items()}

@@ -118,7 +124,7 @@ def softclamp_max(t, max_value):
     return ((t / half_max_value).tanh() * half_max_value) + half_max_value

 def softclamp_grad_norm(t, max_value):
-    if t
+    if is_empty_tensor(t):
         return t

     t, inverse = pack_one_with_inverse(t, 'bn *')

@@ -270,6 +276,7 @@ class NeuralMemory(Module):
         self,
         dim,
         chunk_size: int | tuple[int, int] = 1,
+        batch_size = None,
         dim_head = None,
         heads = 1,
         model: Module | None = None,

@@ -296,6 +303,13 @@ class NeuralMemory(Module):

         self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)

+        # batch size
+
+        if exists(batch_size):
+            assert divisible_by(batch_size, self.store_chunk_size)
+
+        self.batch_size = batch_size
+
         # associative scan

         self.assoc_scan = AssocScan(use_accelerated = use_accelerated_scan)

@@ -460,9 +474,9 @@ class NeuralMemory(Module):
         seq,
         weights: dict[str, Tensor] | None = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-
+        seq_index = 0
     ):
-        batch, seq_len, heads, chunk_size = *seq.shape[:2], self.heads,
+        batch, seq_len, heads, chunk_size = *seq.shape[:2], self.heads, self.store_chunk_size

         # curtail sequence by multiple of the chunk size
         # only a complete chunk of the sequence provides the memory for the next chunk

@@ -472,6 +486,8 @@ class NeuralMemory(Module):

         seq, remainder = seq[:, :round_down_seq_len], seq[:, round_down_seq_len:]

+        next_seq_len_index = seq_index + round_down_seq_len
+
         # init weights if needed
         # weights of the memory network

@@ -568,7 +584,7 @@ class NeuralMemory(Module):

         if num_chunks == 0:
             updates = rearrange_dict_values(weights, 'bh ... -> bh 1 ...')
-            next_store_state = NeuralMemCache(
+            next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, past_state, updates)

             output = (updates, next_store_state)

@@ -607,7 +623,7 @@ class NeuralMemory(Module):

         next_state = (next_last_update, next_last_momentum)

-        next_store_state = NeuralMemCache(
+        next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, next_state, updates)

         # returns

@@ -619,9 +635,8 @@ class NeuralMemory(Module):
         self,
         seq,
         past_weights: dict[str, Tensor],
-        chunk_size = None,
     ):
-        chunk_size =
+        chunk_size = self.retrieve_chunk_size
         batch, seq_len = seq.shape[:2]

         seq = self.retrieve_norm(seq)

@@ -691,9 +706,8 @@ class NeuralMemory(Module):
     def forward_inference(
         self,
         token: Tensor,
-        state = None,
+        state: NeuralMemCache | None = None,
     ):
-
         # unpack previous state

         if not exists(state):

@@ -707,6 +721,8 @@ class NeuralMemory(Module):
         if token.ndim == 2:
             token = rearrange(token, 'b d -> b 1 d')

+        assert token.shape[1] == 1
+
         # increment the sequence cache which is at most the chunk size

         cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)

@@ -757,32 +773,99 @@ class NeuralMemory(Module):
         self,
         seq,
         store_seq = None,
-
-        past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-        chunk_size = None,
-        store_chunk_size = None,
-        return_next_state = False,
+        state: NeuralMemCache | None = None,
     ):
-
+        if not exists(state):
+            state = (0, None, None, None, None)
+
+        seq_index, weights, cache_store_seq, past_state, updates = state
+
+        assert not exists(cache_store_seq) or is_empty_tensor(cache_store_seq)

         # store

         store_seq = default(store_seq, seq)

-
-
-
-
-
+        # functions
+
+        # compute split sizes of sequence
+        # for now manually update weights to last update at the correct boundaries
+
+        store_seq_len, chunk_size, batch_size = store_seq.shape[-2], self.chunk_size, self.batch_size
+
+        need_update_weights = exists(batch_size)
+
+        # determine split sizes and when to update
+
+        if need_update_weights:
+            update_after_final_store = divisible_by(seq_index + store_seq_len, batch_size)
+
+            seq_range = torch.arange(store_seq_len) + seq_index + 1
+            batch_boundary = divisible_by(seq_range, batch_size)
+
+            indices = seq_range[batch_boundary] - seq_index
+
+            indices = F.pad(indices, (1, 0), value = 0)
+
+            if indices[-1] != store_seq_len:
+                indices = F.pad(indices, (0, 1), value = store_seq_len)
+
+            split_sizes = (indices[1:] - indices[:-1]).tolist()
+
+            assert sum(split_sizes) == store_seq_len
+        else:
+            split_sizes = (store_seq_len,)
+            update_after_final_store = False
+
+        # accumulate updates
+
+        updates = None
+
+        def accum_updates(past_updates, future_updates):
+            if not exists(past_updates):
+                return future_updates
+
+            return TensorDict({param_name: cat((past_update[:, :-1], future_update), dim = 1) for (param_name, past_update), (_, future_update) in zip(past_updates.items(), future_updates.items())})
+
+        # loop through chunks of store sequences
+
+        store_seqs = store_seq.split(split_sizes, dim = -2)
+
+        for ind, store_seq_chunk in enumerate(store_seqs):
+            is_last = ind == (len(store_seqs) - 1)
+
+            # store
+
+            next_updates, next_neural_mem_state = self.store_memories(
+                store_seq_chunk,
+                weights,
+                seq_index = seq_index,
+                past_state = past_state,
+            )
+
+            seq_index = next_neural_mem_state.seq_index
+            past_state = next_neural_mem_state.states
+
+            updates = accum_updates(updates, next_updates)
+
+            if is_last and not update_after_final_store:
+                continue
+
+            # update weights once batch size is fulfilled
+
+            last_update, _ = past_state
+
+            weights = last_update
+
+            next_neural_mem_state = list(next_neural_mem_state)
+            next_neural_mem_state[1] = last_update
+            next_neural_mem_state = NeuralMemCache(*next_neural_mem_state)

         # retrieve

         retrieved = self.retrieve_memories(
             seq,
-            updates
-            chunk_size = chunk_size,
+            updates
         )

-
-
-        return output
+        return retrieved, next_neural_mem_state

|
|
35
35
|
NEURAL_MEM_MOMENTUM = True
|
36
36
|
NEURAL_MEM_QK_NORM = True
|
37
37
|
WINDOW_SIZE = 32
|
38
|
-
NEURAL_MEM_SEGMENT_LEN =
|
38
|
+
NEURAL_MEM_SEGMENT_LEN = 2 # set smaller for more granularity for learning rate / momentum etc
|
39
|
+
NEURAL_MEM_BATCH_SIZE = 128 # set smaller to update the neural memory weights more often as it traverses the sequence
|
39
40
|
SLIDING_WINDOWS = True
|
40
41
|
STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
|
41
42
|
MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
|
@@ -83,6 +84,7 @@ model = MemoryAsContextTransformer(
|
|
83
84
|
num_longterm_mem_tokens = NUM_LONGTERM_MEM,
|
84
85
|
neural_memory_layers = NEURAL_MEM_LAYERS,
|
85
86
|
neural_memory_segment_len = NEURAL_MEM_SEGMENT_LEN,
|
87
|
+
neural_memory_batch_size = NEURAL_MEM_BATCH_SIZE,
|
86
88
|
neural_mem_gate_attn_output = NEURAL_MEM_GATE_ATTN_OUTPUT,
|
87
89
|
use_flex_attn = USE_FLEX_ATTN,
|
88
90
|
sliding_window_attn = SLIDING_WINDOWS,
|