titans-pytorch 0.3.2__tar.gz → 0.3.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.3}/PKG-INFO +1 -1
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.3}/pyproject.toml +1 -1
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.3}/tests/test_titans.py +25 -25
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.3}/titans_pytorch/memory_models.py +1 -1
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.3}/titans_pytorch/neural_memory.py +26 -20
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.3}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.3}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.3}/.gitignore +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.3}/LICENSE +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.3}/README.md +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.3}/data/README.md +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.3}/data/enwik8.gz +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.3}/fig1.png +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.3}/fig2.png +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.3}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.3}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.3}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.3.2 → titans_pytorch-0.3.3}/train_mac.py +0 -0
@@ -42,7 +42,7 @@ def test_titans(
|
|
42
42
|
per_parameter_lr_modulation
|
43
43
|
):
|
44
44
|
mem = NeuralMemory(
|
45
|
-
dim =
|
45
|
+
dim = 16,
|
46
46
|
chunk_size = chunk_size,
|
47
47
|
activation = nn.SiLU() if silu else None,
|
48
48
|
attn_pool_chunks = attn_pool_chunks,
|
@@ -52,7 +52,7 @@ def test_titans(
|
|
52
52
|
per_parameter_lr_modulation = per_parameter_lr_modulation,
|
53
53
|
)
|
54
54
|
|
55
|
-
seq = torch.randn(2, seq_len,
|
55
|
+
seq = torch.randn(2, seq_len, 16)
|
56
56
|
retrieved, _ = mem(seq)
|
57
57
|
|
58
58
|
assert seq.shape == retrieved.shape
|
@@ -61,14 +61,14 @@ def test_titans_attn_memory():
|
|
61
61
|
from titans_pytorch.memory_models import MemoryAttention
|
62
62
|
|
63
63
|
mem = NeuralMemory(
|
64
|
-
dim =
|
64
|
+
dim = 16,
|
65
65
|
chunk_size = 64,
|
66
66
|
model = MemoryAttention(
|
67
|
-
dim =
|
67
|
+
dim = 16
|
68
68
|
)
|
69
69
|
)
|
70
70
|
|
71
|
-
seq = torch.randn(2, 1024,
|
71
|
+
seq = torch.randn(2, 1024, 16)
|
72
72
|
retrieved, _ = mem(seq)
|
73
73
|
|
74
74
|
assert seq.shape == retrieved.shape
|
@@ -78,14 +78,14 @@ def test_neural_mem_chaining_chunks(
|
|
78
78
|
gated_transition
|
79
79
|
):
|
80
80
|
mem = NeuralMemory(
|
81
|
-
dim =
|
82
|
-
dim_head =
|
81
|
+
dim = 16,
|
82
|
+
dim_head = 16,
|
83
83
|
heads = 2,
|
84
84
|
chunk_size = 16,
|
85
85
|
gated_transition = gated_transition
|
86
86
|
)
|
87
87
|
|
88
|
-
seq = torch.randn(2, 48,
|
88
|
+
seq = torch.randn(2, 48, 16)
|
89
89
|
|
90
90
|
parallel_retrieved, state = mem(seq)
|
91
91
|
|
@@ -99,21 +99,21 @@ def test_neural_mem_chaining_chunks(
|
|
99
99
|
|
100
100
|
def test_neural_mem_chaining_with_weight_residual():
|
101
101
|
mem = NeuralMemory(
|
102
|
-
dim =
|
103
|
-
dim_head =
|
102
|
+
dim = 16,
|
103
|
+
dim_head = 16,
|
104
104
|
heads = 2,
|
105
105
|
chunk_size = 64
|
106
106
|
)
|
107
107
|
|
108
108
|
mem2 = NeuralMemory(
|
109
|
-
dim =
|
110
|
-
dim_head =
|
109
|
+
dim = 16,
|
110
|
+
dim_head = 16,
|
111
111
|
heads = 2,
|
112
112
|
chunk_size = 64,
|
113
113
|
accept_weight_residual = True
|
114
114
|
)
|
115
115
|
|
116
|
-
seq = torch.randn(2, 256,
|
116
|
+
seq = torch.randn(2, 256, 16)
|
117
117
|
|
118
118
|
seq, state = mem(seq)
|
119
119
|
|
@@ -124,18 +124,18 @@ def test_neural_mem_chaining_with_weight_residual():
|
|
124
124
|
first_retrieved, state1 = mem2(seq_first, prev_weights = state.updates)
|
125
125
|
second_retrieved, state2 = mem2(seq_second, state = state1, prev_weights = state.updates)
|
126
126
|
|
127
|
-
assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved), dim = 1), atol = 1e-
|
127
|
+
assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved), dim = 1), atol = 1e-5)
|
128
128
|
|
129
129
|
def test_neural_mem_chaining_with_batch_size():
|
130
130
|
mem = NeuralMemory(
|
131
|
-
dim =
|
132
|
-
dim_head =
|
131
|
+
dim = 16,
|
132
|
+
dim_head = 16,
|
133
133
|
heads = 2,
|
134
134
|
chunk_size = 16,
|
135
135
|
batch_size = 64
|
136
136
|
)
|
137
137
|
|
138
|
-
seq = torch.randn(2, 112,
|
138
|
+
seq = torch.randn(2, 112, 16)
|
139
139
|
|
140
140
|
parallel_retrieved, state = mem(seq)
|
141
141
|
|
@@ -169,7 +169,7 @@ def test_mac(
|
|
169
169
|
):
|
170
170
|
transformer = MemoryAsContextTransformer(
|
171
171
|
num_tokens = 256,
|
172
|
-
dim =
|
172
|
+
dim = 16,
|
173
173
|
depth = 2,
|
174
174
|
num_persist_mem_tokens = num_persist_mem_tokens,
|
175
175
|
num_longterm_mem_tokens = num_longterm_mem_tokens,
|
@@ -201,7 +201,7 @@ def test_mac_sampling(
|
|
201
201
|
):
|
202
202
|
transformer = MemoryAsContextTransformer(
|
203
203
|
num_tokens = 256,
|
204
|
-
dim =
|
204
|
+
dim = 16,
|
205
205
|
depth = 4,
|
206
206
|
segment_len = 32,
|
207
207
|
num_persist_mem_tokens = 4,
|
@@ -235,12 +235,12 @@ def test_neural_mem_inference(
|
|
235
235
|
):
|
236
236
|
|
237
237
|
mem = NeuralMemory(
|
238
|
-
dim =
|
238
|
+
dim = 16,
|
239
239
|
chunk_size = mem_chunk_size,
|
240
240
|
gated_transition = gated_transition
|
241
241
|
)
|
242
242
|
|
243
|
-
seq = torch.randn(2, seq_len,
|
243
|
+
seq = torch.randn(2, seq_len, 16)
|
244
244
|
parallel_retrieved, _ = mem(seq)
|
245
245
|
|
246
246
|
assert seq.shape == parallel_retrieved.shape
|
@@ -282,7 +282,7 @@ def test_flex(
|
|
282
282
|
pytest.skip()
|
283
283
|
|
284
284
|
attn = SegmentedAttention(
|
285
|
-
dim =
|
285
|
+
dim = 16,
|
286
286
|
segment_len = 32,
|
287
287
|
num_persist_mem_tokens = 1,
|
288
288
|
num_longterm_mem_tokens = 1,
|
@@ -290,7 +290,7 @@ def test_flex(
|
|
290
290
|
sliding = sliding
|
291
291
|
).cuda()
|
292
292
|
|
293
|
-
seq = torch.randn(1, seq_len,
|
293
|
+
seq = torch.randn(1, seq_len, 16).cuda()
|
294
294
|
|
295
295
|
out_flex, _ = attn(seq)
|
296
296
|
out_non_flex, _ = attn(seq, disable_flex_attn = True)
|
@@ -307,8 +307,8 @@ def test_assoc_scan():
|
|
307
307
|
seq_len = 128
|
308
308
|
mid_point = seq_len // 2
|
309
309
|
|
310
|
-
gates = torch.randn(2, seq_len,
|
311
|
-
inputs = torch.randn(2, seq_len,
|
310
|
+
gates = torch.randn(2, seq_len, 16).sigmoid()
|
311
|
+
inputs = torch.randn(2, seq_len, 16)
|
312
312
|
|
313
313
|
output = scan(gates, inputs)
|
314
314
|
|
@@ -690,16 +690,27 @@ class NeuralMemory(Module):
|
|
690
690
|
def retrieve_memories(
|
691
691
|
self,
|
692
692
|
seq,
|
693
|
-
|
694
|
-
chunk_size = None,
|
695
|
-
need_pad = True
|
693
|
+
weights: dict[str, Tensor],
|
696
694
|
):
|
697
|
-
chunk_size =
|
695
|
+
chunk_size = self.retrieve_chunk_size
|
696
|
+
|
697
|
+
weights_have_expanded_shape = dict_get_shape(weights) != self.init_weight_shape
|
698
|
+
|
698
699
|
batch, seq_len = seq.shape[:2]
|
699
700
|
|
700
|
-
|
701
|
+
# auto-infer single-token decoding when there is only 1 set of weights and 1 token
|
702
|
+
|
703
|
+
is_one_token = seq_len == 1
|
704
|
+
is_one_weight = (not weights_have_expanded_shape) or next(iter(weights.values())).shape[1] == 1
|
705
|
+
|
706
|
+
is_single_token_decode = is_one_token and is_one_weight
|
707
|
+
|
708
|
+
if is_single_token_decode:
|
709
|
+
chunk_size = 1
|
710
|
+
|
711
|
+
# padding related, for chunked processing
|
701
712
|
|
702
|
-
need_pad =
|
713
|
+
need_pad = chunk_size > 1 or not is_one_weight
|
703
714
|
|
704
715
|
if need_pad:
|
705
716
|
seq = pad_at_dim(seq, (1, 0), dim = 1)
|
@@ -714,7 +725,11 @@ class NeuralMemory(Module):
|
|
714
725
|
# the parameters of the memory model store the memories of the key / values
|
715
726
|
# when the MLP has only 1 weight matrix, it is equivalent to the `kv` fast weight memories from the linear attention literature (recall that fetching memories is q @ (kv)) / Schmidhuber's paper
|
716
727
|
|
717
|
-
|
728
|
+
weights = TensorDict(weights)
|
729
|
+
|
730
|
+
# pre norm
|
731
|
+
|
732
|
+
seq = self.retrieve_norm(seq)
|
718
733
|
|
719
734
|
# sequence Float['b n d'] to queries
|
720
735
|
|
@@ -730,14 +745,14 @@ class NeuralMemory(Module):
|
|
730
745
|
|
731
746
|
# fetch values from memory model
|
732
747
|
|
733
|
-
if
|
734
|
-
|
748
|
+
if weights_have_expanded_shape:
|
749
|
+
weights = rearrange_dict_values(weights, 'b n ... -> (b n) ...')
|
735
750
|
|
736
751
|
queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)
|
737
752
|
|
738
753
|
# forward functional call
|
739
754
|
|
740
|
-
values = functional_call(self.memory_model, dict(
|
755
|
+
values = functional_call(self.memory_model, dict(weights), queries)
|
741
756
|
|
742
757
|
# reconstitute batch dimension
|
743
758
|
|
@@ -885,22 +900,13 @@ class NeuralMemory(Module):
|
|
885
900
|
|
886
901
|
# retrieve
|
887
902
|
|
888
|
-
need_pad = True
|
889
|
-
retrieve_chunk_size = None
|
890
|
-
|
891
903
|
if is_single_token:
|
892
|
-
retrieve_chunk_size = 1
|
893
|
-
need_pad = False
|
894
|
-
|
895
904
|
last_update, _ = next_neural_mem_state.states
|
896
|
-
|
897
905
|
updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
|
898
906
|
|
899
907
|
retrieved = self.retrieve_memories(
|
900
908
|
seq,
|
901
|
-
updates
|
902
|
-
chunk_size = retrieve_chunk_size,
|
903
|
-
need_pad = need_pad,
|
909
|
+
updates
|
904
910
|
)
|
905
911
|
|
906
912
|
return retrieved, next_neural_mem_state
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|