titans-pytorch 0.3.1__tar.gz → 0.3.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.3.1 → titans_pytorch-0.3.3}/PKG-INFO +1 -1
- {titans_pytorch-0.3.1 → titans_pytorch-0.3.3}/pyproject.toml +1 -1
- {titans_pytorch-0.3.1 → titans_pytorch-0.3.3}/tests/test_titans.py +25 -25
- {titans_pytorch-0.3.1 → titans_pytorch-0.3.3}/titans_pytorch/memory_models.py +6 -2
- {titans_pytorch-0.3.1 → titans_pytorch-0.3.3}/titans_pytorch/neural_memory.py +28 -21
- {titans_pytorch-0.3.1 → titans_pytorch-0.3.3}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.3.1 → titans_pytorch-0.3.3}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.3.1 → titans_pytorch-0.3.3}/.gitignore +0 -0
- {titans_pytorch-0.3.1 → titans_pytorch-0.3.3}/LICENSE +0 -0
- {titans_pytorch-0.3.1 → titans_pytorch-0.3.3}/README.md +0 -0
- {titans_pytorch-0.3.1 → titans_pytorch-0.3.3}/data/README.md +0 -0
- {titans_pytorch-0.3.1 → titans_pytorch-0.3.3}/data/enwik8.gz +0 -0
- {titans_pytorch-0.3.1 → titans_pytorch-0.3.3}/fig1.png +0 -0
- {titans_pytorch-0.3.1 → titans_pytorch-0.3.3}/fig2.png +0 -0
- {titans_pytorch-0.3.1 → titans_pytorch-0.3.3}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.3.1 → titans_pytorch-0.3.3}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.3.1 → titans_pytorch-0.3.3}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.3.1 → titans_pytorch-0.3.3}/train_mac.py +0 -0
tests/test_titans.py

@@ -42,7 +42,7 @@ def test_titans(
     per_parameter_lr_modulation
 ):
     mem = NeuralMemory(
-        dim =
+        dim = 16,
         chunk_size = chunk_size,
         activation = nn.SiLU() if silu else None,
         attn_pool_chunks = attn_pool_chunks,
@@ -52,7 +52,7 @@ def test_titans(
         per_parameter_lr_modulation = per_parameter_lr_modulation,
     )

-    seq = torch.randn(2, seq_len,
+    seq = torch.randn(2, seq_len, 16)
     retrieved, _ = mem(seq)

     assert seq.shape == retrieved.shape
@@ -61,14 +61,14 @@ def test_titans_attn_memory():
     from titans_pytorch.memory_models import MemoryAttention

     mem = NeuralMemory(
-        dim =
+        dim = 16,
         chunk_size = 64,
         model = MemoryAttention(
-            dim =
+            dim = 16
         )
     )

-    seq = torch.randn(2, 1024,
+    seq = torch.randn(2, 1024, 16)
     retrieved, _ = mem(seq)

     assert seq.shape == retrieved.shape
@@ -78,14 +78,14 @@ def test_neural_mem_chaining_chunks(
     gated_transition
 ):
     mem = NeuralMemory(
-        dim =
-        dim_head =
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 16,
         gated_transition = gated_transition
     )

-    seq = torch.randn(2, 48,
+    seq = torch.randn(2, 48, 16)

     parallel_retrieved, state = mem(seq)

@@ -99,21 +99,21 @@ def test_neural_mem_chaining_chunks(

 def test_neural_mem_chaining_with_weight_residual():
     mem = NeuralMemory(
-        dim =
-        dim_head =
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 64
     )

     mem2 = NeuralMemory(
-        dim =
-        dim_head =
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 64,
         accept_weight_residual = True
     )

-    seq = torch.randn(2, 256,
+    seq = torch.randn(2, 256, 16)

     seq, state = mem(seq)

@@ -124,18 +124,18 @@ def test_neural_mem_chaining_with_weight_residual():
     first_retrieved, state1 = mem2(seq_first, prev_weights = state.updates)
     second_retrieved, state2 = mem2(seq_second, state = state1, prev_weights = state.updates)

-    assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved), dim = 1), atol = 1e-
+    assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved), dim = 1), atol = 1e-5)

 def test_neural_mem_chaining_with_batch_size():
     mem = NeuralMemory(
-        dim =
-        dim_head =
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 16,
         batch_size = 64
     )

-    seq = torch.randn(2, 112,
+    seq = torch.randn(2, 112, 16)

     parallel_retrieved, state = mem(seq)

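The chaining tests above all assert the same invariant: feeding a sequence in pieces, while threading the returned state through, must reproduce the parallel result. A minimal sketch of that pattern, using only constructor arguments and keyword names visible in the diff (the split point at a chunk boundary is my choice for illustration):

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 16, dim_head = 16, heads = 2, chunk_size = 16)

seq = torch.randn(2, 48, 16)

# whole sequence in one parallel call
parallel_retrieved, _ = mem(seq)

# same sequence in two pieces, threading the returned state through
first, second = seq[:, :32], seq[:, 32:]

first_retrieved, state1 = mem(first)
second_retrieved, state2 = mem(second, state = state1)

chained = torch.cat((first_retrieved, second_retrieved), dim = 1)
assert torch.allclose(parallel_retrieved, chained, atol = 1e-5)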
@@ -169,7 +169,7 @@ def test_mac(
 ):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
-        dim =
+        dim = 16,
         depth = 2,
         num_persist_mem_tokens = num_persist_mem_tokens,
         num_longterm_mem_tokens = num_longterm_mem_tokens,
@@ -201,7 +201,7 @@ def test_mac_sampling(
 ):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
-        dim =
+        dim = 16,
         depth = 4,
         segment_len = 32,
         num_persist_mem_tokens = 4,
@@ -235,12 +235,12 @@ def test_neural_mem_inference(
 ):

     mem = NeuralMemory(
-        dim =
+        dim = 16,
         chunk_size = mem_chunk_size,
         gated_transition = gated_transition
     )

-    seq = torch.randn(2, seq_len,
+    seq = torch.randn(2, seq_len, 16)
     parallel_retrieved, _ = mem(seq)

     assert seq.shape == parallel_retrieved.shape
@@ -282,7 +282,7 @@ def test_flex(
         pytest.skip()

     attn = SegmentedAttention(
-        dim =
+        dim = 16,
         segment_len = 32,
         num_persist_mem_tokens = 1,
         num_longterm_mem_tokens = 1,
@@ -290,7 +290,7 @@ def test_flex(
         sliding = sliding
     ).cuda()

-    seq = torch.randn(1, seq_len,
+    seq = torch.randn(1, seq_len, 16).cuda()

     out_flex, _ = attn(seq)
     out_non_flex, _ = attn(seq, disable_flex_attn = True)
@@ -307,8 +307,8 @@ def test_assoc_scan():
     seq_len = 128
     mid_point = seq_len // 2

-    gates = torch.randn(2, seq_len,
-    inputs = torch.randn(2, seq_len,
+    gates = torch.randn(2, seq_len, 16).sigmoid()
+    inputs = torch.randn(2, seq_len, 16)

     output = scan(gates, inputs)

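For reference, `scan` here comes from titans_pytorch/associative_scan.py. The shapes in the updated test fit the standard first-order gated recurrence (an assumption on my part, stated in the comment below):

import torch

def reference_scan(gates, inputs):
    # naive O(n) loop for out[t] = gates[t] * out[t - 1] + inputs[t],
    # assumed to match what the parallel associative scan computes
    out, prev = [], torch.zeros_like(inputs[:, 0])
    for t in range(inputs.shape[1]):
        prev = gates[:, t] * prev + inputs[:, t]
        out.append(prev)
    return torch.stack(out, dim = 1)

gates  = torch.randn(2, 128, 16).sigmoid()  # gates in (0, 1), as in the test
inputs = torch.randn(2, 128, 16)

assert reference_scan(gates, inputs).shape == inputs.shape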
titans_pytorch/memory_models.py

@@ -36,10 +36,14 @@ class MemoryMLP(Module):
     def __init__(
         self,
         dim,
-        depth
+        depth,
+        expansion_factor = 2.
     ):
         super().__init__()
-
+        dim_hidden = int(dim * expansion_factor)
+        dims = (dim, *((dim_hidden,) * (depth - 1)), dim)
+
+        self.weights = ParameterList([Parameter(torch.randn(dim_in, dim_out)) for dim_in, dim_out in zip(dims[:-1], dims[1:])])

         self.ln = LayerNorm(dim)

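The new `dims` construction keeps the input and output widths at `dim` and expands only the hidden layers. Working through the expression for a small example:

dim, depth, expansion_factor = 16, 3, 2.

dim_hidden = int(dim * expansion_factor)           # 32
dims = (dim, *((dim_hidden,) * (depth - 1)), dim)  # (16, 32, 32, 16)

# zip(dims[:-1], dims[1:]) pairs consecutive widths, so the ParameterList
# holds `depth` matrices: (16, 32), (32, 32), (32, 16)
assert list(zip(dims[:-1], dims[1:])) == [(16, 32), (32, 32), (32, 16)]

# depth = 1 degenerates to a single (dim, dim) matrix - the fast-weight case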
titans_pytorch/neural_memory.py

@@ -299,7 +299,8 @@ class NeuralMemory(Module):
         accept_weight_residual = False,
         gated_transition = False,
         default_model_kwargs: dict = dict(
-            depth = 2
+            depth = 2,
+            expansion_factor = 4.
         )
     ):
         super().__init__()
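The matching default bump to `expansion_factor = 4.` means a NeuralMemory left on its default memory model gets the familiar 4x-expanded MLP. A quick check, assuming the default model is the MemoryMLP above (built here directly for illustration):

from titans_pytorch.memory_models import MemoryMLP

mlp = MemoryMLP(dim = 16, **dict(depth = 2, expansion_factor = 4.))

# dims = (16, 64, 16) -> weight shapes (16, 64) and (64, 16)
print([tuple(w.shape) for w in mlp.weights])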
@@ -689,16 +690,27 @@ class NeuralMemory(Module):
     def retrieve_memories(
         self,
         seq,
-
-        chunk_size = None,
-        need_pad = True
+        weights: dict[str, Tensor],
     ):
-        chunk_size =
+        chunk_size = self.retrieve_chunk_size
+
+        weights_have_expanded_shape = dict_get_shape(weights) != self.init_weight_shape
+
         batch, seq_len = seq.shape[:2]

-
+        # auto infer single token decoding, if there are only 1 set of weights and 1 token
+
+        is_one_token = seq_len == 1
+        is_one_weight = (not weights_have_expanded_shape) or next(iter(weights.values())).shape[1] == 1
+
+        is_single_token_decode = is_one_token and is_one_weight
+
+        if is_single_token_decode:
+            chunk_size = 1
+
+        # padding related, for chunked processing

-        need_pad =
+        need_pad = chunk_size > 1 or not is_one_weight

         if need_pad:
             seq = pad_at_dim(seq, (1, 0), dim = 1)
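The `is_single_token_decode` branch exists so that autoregressive inference can call retrieval one token at a time without chunk padding. A sketch of that calling pattern, mirroring `test_neural_mem_inference` (the state-threading keyword follows the tests above; the per-step token shape is my assumption):

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 16, chunk_size = 16)

seq = torch.randn(2, 8, 16)

parallel_retrieved, _ = mem(seq)

state = None
outs = []

for token in seq.unbind(dim = 1):
    # one (batch, 1, dim) token per step triggers the chunk_size = 1 path
    out, state = mem(token[:, None, :], state = state)
    outs.append(out)

sequential_retrieved = torch.cat(outs, dim = 1)
# expected to match parallel_retrieved to within numerical tolerance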
@@ -713,7 +725,11 @@ class NeuralMemory(Module):
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper

-
+        weights = TensorDict(weights)
+
+        # pre norm
+
+        seq = self.retrieve_norm(seq)

         # sequence Float['b n d'] to queries

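The comment on the `kv` fast-weight equivalence is worth unpacking: with a single weight matrix, storing a key/value pair is an outer-product write, and retrieval is one matrix multiply, exactly as in linear attention:

import torch

d = 16
k, v, q = torch.randn(3, d)

kv = torch.outer(k, v)  # the "memory": a single d x d fast-weight matrix
fetched = q @ kv        # retrieval is q @ (kv); equals (q . k) * v
assert torch.allclose(fetched, (q @ k) * v)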
@@ -729,14 +745,14 @@ class NeuralMemory(Module):

         # fetch values from memory model

-        if
-
+        if weights_have_expanded_shape:
+            weights = rearrange_dict_values(weights, 'b n ... -> (b n) ...')

         queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)

         # forward functional call

-        values = functional_call(self.memory_model, dict(
+        values = functional_call(self.memory_model, dict(weights), queries)

         # reconstitute batch dimension

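The `(b n)` fold works because `functional_call` substitutes parameter tensors without shape-checking, and the memory model's matmuls broadcast over the extra leading dimension. A minimal standalone illustration with a toy one-matrix memory (not the library's module):

import torch
from torch.func import functional_call
from torch.nn import Module, Parameter

class OneMatrixMemory(Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = Parameter(torch.randn(dim, dim))

    def forward(self, x):
        return x @ self.weight  # broadcasts over leading batch dims of `weight`

mem_model = OneMatrixMemory(16)

bn, c = 6, 4                                   # (batch * num_chunks), chunk size
weights = {'weight': torch.randn(bn, 16, 16)}  # one weight matrix per chunk
queries = torch.randn(bn, c, 16)

values = functional_call(mem_model, weights, queries)  # shape (6, 4, 16)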
@@ -884,22 +900,13 @@ class NeuralMemory(Module):

         # retrieve

-        need_pad = True
-        retrieve_chunk_size = None
-
         if is_single_token:
-            retrieve_chunk_size = 1
-            need_pad = False
-
             last_update, _ = next_neural_mem_state.states
-
             updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')

         retrieved = self.retrieve_memories(
             seq,
-            updates
-            chunk_size = retrieve_chunk_size,
-            need_pad = need_pad,
+            updates
         )

         return retrieved, next_neural_mem_state