titans-pytorch 0.2.28__tar.gz → 0.3.1__tar.gz
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- {titans_pytorch-0.2.28 → titans_pytorch-0.3.1}/PKG-INFO +1 -1
- {titans_pytorch-0.2.28 → titans_pytorch-0.3.1}/pyproject.toml +1 -1
- {titans_pytorch-0.2.28 → titans_pytorch-0.3.1}/tests/test_titans.py +10 -2
- {titans_pytorch-0.2.28 → titans_pytorch-0.3.1}/titans_pytorch/neural_memory.py +16 -15
- {titans_pytorch-0.2.28 → titans_pytorch-0.3.1}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.2.28 → titans_pytorch-0.3.1}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.2.28 → titans_pytorch-0.3.1}/.gitignore +0 -0
- {titans_pytorch-0.2.28 → titans_pytorch-0.3.1}/LICENSE +0 -0
- {titans_pytorch-0.2.28 → titans_pytorch-0.3.1}/README.md +0 -0
- {titans_pytorch-0.2.28 → titans_pytorch-0.3.1}/data/README.md +0 -0
- {titans_pytorch-0.2.28 → titans_pytorch-0.3.1}/data/enwik8.gz +0 -0
- {titans_pytorch-0.2.28 → titans_pytorch-0.3.1}/fig1.png +0 -0
- {titans_pytorch-0.2.28 → titans_pytorch-0.3.1}/fig2.png +0 -0
- {titans_pytorch-0.2.28 → titans_pytorch-0.3.1}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.2.28 → titans_pytorch-0.3.1}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.2.28 → titans_pytorch-0.3.1}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.2.28 → titans_pytorch-0.3.1}/titans_pytorch/memory_models.py +0 -0
- {titans_pytorch-0.2.28 → titans_pytorch-0.3.1}/train_mac.py +0 -0
@@ -156,6 +156,7 @@ def test_neural_mem_chaining_with_batch_size():
|
|
156
156
|
@pytest.mark.parametrize('neural_mem_segment_len', (8, 16))
|
157
157
|
@pytest.mark.parametrize('neural_mem_weight_residual', (False, True))
|
158
158
|
@pytest.mark.parametrize('neural_mem_batch_size', (None, 64))
|
159
|
+
@pytest.mark.parametrize('neural_mem_momentum', (False, True))
|
159
160
|
def test_mac(
|
160
161
|
seq_len,
|
161
162
|
num_persist_mem_tokens,
|
@@ -164,6 +165,7 @@ def test_mac(
|
|
164
165
|
neural_mem_segment_len,
|
165
166
|
neural_mem_weight_residual,
|
166
167
|
neural_mem_batch_size,
|
168
|
+
neural_mem_momentum
|
167
169
|
):
|
168
170
|
transformer = MemoryAsContextTransformer(
|
169
171
|
num_tokens = 256,
|
@@ -175,7 +177,10 @@ def test_mac(
|
|
175
177
|
neural_mem_gate_attn_output = neural_mem_gate_attn_output,
|
176
178
|
neural_memory_segment_len = neural_mem_segment_len,
|
177
179
|
neural_memory_batch_size = neural_mem_batch_size,
|
178
|
-
neural_mem_weight_residual = neural_mem_weight_residual
|
180
|
+
neural_mem_weight_residual = neural_mem_weight_residual,
|
181
|
+
neural_memory_kwargs = dict(
|
182
|
+
momentum = neural_mem_momentum
|
183
|
+
)
|
179
184
|
)
|
180
185
|
|
181
186
|
x = torch.randint(0, 256, (1, seq_len))
|
@@ -220,16 +225,19 @@ def test_mac_sampling(
|
|
220
225
|
@pytest.mark.parametrize('seq_len', (2, 64, 256))
|
221
226
|
@pytest.mark.parametrize('prompt_len', (0, 65))
|
222
227
|
@pytest.mark.parametrize('mem_chunk_size', (2, 32, 64))
|
228
|
+
@pytest.mark.parametrize('gated_transition', (False, True))
|
223
229
|
@torch_default_dtype(torch.float64)
|
224
230
|
def test_neural_mem_inference(
|
225
231
|
seq_len,
|
226
232
|
prompt_len,
|
227
|
-
mem_chunk_size
|
233
|
+
mem_chunk_size,
|
234
|
+
gated_transition
|
228
235
|
):
|
229
236
|
|
230
237
|
mem = NeuralMemory(
|
231
238
|
dim = 384,
|
232
239
|
chunk_size = mem_chunk_size,
|
240
|
+
gated_transition = gated_transition
|
233
241
|
)
|
234
242
|
|
235
243
|
seq = torch.randn(2, seq_len, 384)
|
@@ -66,12 +66,6 @@ def xnor(x, y):
|
|
66
66
|
def divisible_by(num, den):
|
67
67
|
return (num % den) == 0
|
68
68
|
|
69
|
-
def tuple_index_set(t: tuple, index, value):
|
70
|
-
klass = type(t)
|
71
|
-
t = list(t)
|
72
|
-
t[index] = value
|
73
|
-
return klass(*t)
|
74
|
-
|
75
69
|
def safe_cat(inputs, dim = -2):
|
76
70
|
inputs = tuple(filter(exists, inputs))
|
77
71
|
|
@@ -658,13 +652,14 @@ class NeuralMemory(Module):
|
|
658
652
|
next_last_update = TensorDict()
|
659
653
|
next_last_momentum = TensorDict()
|
660
654
|
|
661
|
-
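for (param_name, surprise), (_, last_update), (_, last_momentum) in zip(surprises.items(), past_last_update.items(), past_last_momentum.items()):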
|
655
|
+
for (param_name, surprise), (_, last_update) in zip(surprises.items(), past_last_update.items()):
|
662
656
|
|
663
657
|
update = surprise
|
664
658
|
|
665
659
|
# derive momentum with associative scan - eq (10)
|
666
660
|
|
667
661
|
if has_momentum:
|
662
|
+
last_momentum = past_last_momentum[param_name]
|
668
663
|
update = self.assoc_scan(adaptive_momentum, surprise, prev = last_momentum) # momentum is S / surprise in the paper
|
669
664
|
momentum = update
|
670
665
|
next_last_momentum[param_name] = momentum[:, -1]
|
@@ -872,15 +867,20 @@ class NeuralMemory(Module):
|
|
872
867
|
last_update, last_momentum = past_state
|
873
868
|
|
874
869
|
if exists(gate):
|
875
|
-
|
876
|
-
|
877
|
-
|
870
|
+
last_update = TensorDict({param_name: one_weight.lerp(one_last_update, gate) for (param_name, one_weight), (_, one_last_update) in zip(weights.items(), last_update.items())})
|
871
|
+
|
872
|
+
past_state = (last_update, last_momentum)
|
873
|
+
|
874
|
+
# set weights to the last updated weights for the last minibatch
|
878
875
|
|
879
|
-
|
880
|
-
next_neural_mem_state = tuple_index_set(next_neural_mem_state, -2, past_state)
|
881
|
-
next_neural_mem_state = tuple_index_set(next_neural_mem_state, 1, weights)
|
876
|
+
weights = last_update
|
882
877
|
|
883
|
-
|
878
|
+
next_neural_mem_state = next_neural_mem_state._replace(
|
879
|
+
weights = weights,
|
880
|
+
states = past_state,
|
881
|
+
)
|
882
|
+
|
883
|
+
next_neural_mem_state = next_neural_mem_state._replace(updates = updates)
|
884
884
|
|
885
885
|
# retrieve
|
886
886
|
|
@@ -891,7 +891,8 @@ class NeuralMemory(Module):
|
|
891
891
|
retrieve_chunk_size = 1
|
892
892
|
need_pad = False
|
893
893
|
|
894
|
-
last_update, _ =
|
894
|
+
last_update, _ = next_neural_mem_state.states
|
895
|
+
|
895
896
|
updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
|
896
897
|
|
897
898
|
retrieved = self.retrieve_memories(
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|