titans-pytorch 0.2.27__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.2.27 → titans_pytorch-0.3.0}/PKG-INFO +1 -1
- {titans_pytorch-0.2.27 → titans_pytorch-0.3.0}/pyproject.toml +1 -1
- {titans_pytorch-0.2.27 → titans_pytorch-0.3.0}/tests/test_titans.py +4 -1
- {titans_pytorch-0.2.27 → titans_pytorch-0.3.0}/titans_pytorch/neural_memory.py +15 -13
- {titans_pytorch-0.2.27 → titans_pytorch-0.3.0}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.2.27 → titans_pytorch-0.3.0}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.2.27 → titans_pytorch-0.3.0}/.gitignore +0 -0
- {titans_pytorch-0.2.27 → titans_pytorch-0.3.0}/LICENSE +0 -0
- {titans_pytorch-0.2.27 → titans_pytorch-0.3.0}/README.md +0 -0
- {titans_pytorch-0.2.27 → titans_pytorch-0.3.0}/data/README.md +0 -0
- {titans_pytorch-0.2.27 → titans_pytorch-0.3.0}/data/enwik8.gz +0 -0
- {titans_pytorch-0.2.27 → titans_pytorch-0.3.0}/fig1.png +0 -0
- {titans_pytorch-0.2.27 → titans_pytorch-0.3.0}/fig2.png +0 -0
- {titans_pytorch-0.2.27 → titans_pytorch-0.3.0}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.2.27 → titans_pytorch-0.3.0}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.2.27 → titans_pytorch-0.3.0}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.2.27 → titans_pytorch-0.3.0}/titans_pytorch/memory_models.py +0 -0
- {titans_pytorch-0.2.27 → titans_pytorch-0.3.0}/train_mac.py +0 -0
tests/test_titans.py

@@ -220,16 +220,19 @@ def test_mac_sampling(
 @pytest.mark.parametrize('seq_len', (2, 64, 256))
 @pytest.mark.parametrize('prompt_len', (0, 65))
 @pytest.mark.parametrize('mem_chunk_size', (2, 32, 64))
+@pytest.mark.parametrize('gated_transition', (False, True))
 @torch_default_dtype(torch.float64)
 def test_neural_mem_inference(
     seq_len,
     prompt_len,
-    mem_chunk_size
+    mem_chunk_size,
+    gated_transition
 ):

     mem = NeuralMemory(
         dim = 384,
         chunk_size = mem_chunk_size,
+        gated_transition = gated_transition
     )

     seq = torch.randn(2, seq_len, 384)
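The test change threads a new `gated_transition` flag from `pytest.mark.parametrize` into the `NeuralMemory` constructor. A minimal sketch of one parametrized case, assuming only the constructor arguments and tensor shape visible in the hunk; the forward call and its return values are a hedged guess, not taken from this diff:

import torch
from titans_pytorch import NeuralMemory

# one case from the new 0.3.0 test matrix: gated_transition toggled on
mem = NeuralMemory(
    dim = 384,
    chunk_size = 32,            # one of the parametrized mem_chunk_size values
    gated_transition = True     # the flag added in this release
)

seq = torch.randn(2, 256, 384)  # (batch, seq_len, dim), matching the test

# hedged guess at the forward signature: retrieved memories plus an updated state
retrieved, next_state = mem(seq)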
titans_pytorch/neural_memory.py

@@ -66,12 +66,6 @@ def xnor(x, y):
 def divisible_by(num, den):
     return (num % den) == 0

-def tuple_index_set(t: tuple, index, value):
-    klass = type(t)
-    t = list(t)
-    t[index] = value
-    return klass(*t)
-
 def safe_cat(inputs, dim = -2):
     inputs = tuple(filter(exists, inputs))

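The deleted `tuple_index_set` helper rebuilt a (named) tuple with one slot swapped out by positional index; the hunks below instead add calls to the standard `namedtuple._replace`, which covers the same need by field name. A self-contained comparison on a hypothetical state tuple (the field names here are illustrative, not the library's):

from collections import namedtuple

# hypothetical state container, for illustration only
MemState = namedtuple('MemState', ['weights', 'states', 'updates'])

def tuple_index_set(t, index, value):
    # the removed helper: rebuild the tuple with one slot replaced by position
    klass = type(t)
    items = list(t)
    items[index] = value
    return klass(*items)

state = MemState(weights = 'w0', states = 's0', updates = 'u0')

# positional update via the removed helper ...
by_index = tuple_index_set(state, 1, 's1')

# ... versus the equivalent, name-based namedtuple._replace used in 0.3.0
by_name = state._replace(states = 's1')

assert by_index == by_name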
@@ -869,16 +863,23 @@ class NeuralMemory(Module):

         # update weights once batch size is fulfilled

-        last_update,
+        last_update, last_momentum = past_state

         if exists(gate):
-
-
-
+            last_update = TensorDict({param_name: one_weight.lerp(one_last_update, gate) for (param_name, one_weight), (_, one_last_update) in zip(weights.items(), last_update.items())})
+
+            past_state = (last_update, last_momentum)
+
+        # set weights to the last updated weights for the last minibatch

-
+        weights = last_update

-
+        next_neural_mem_state = next_neural_mem_state._replace(
+            weights = weights,
+            states = past_state,
+        )
+
+        next_neural_mem_state = next_neural_mem_state._replace(updates = updates)

         # retrieve

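In the gated branch above, every stored parameter is linearly interpolated toward its last update via `one_weight.lerp(one_last_update, gate)`, and the result is written back into the state through `_replace`. A minimal plain-dict sketch of that per-parameter transition (the real code keeps the parameters in a `TensorDict`; the dict, parameter name, and gate value below are illustrative):

import torch

# per-parameter gated transition: new_w = w + gate * (last_update_w - w),
# i.e. torch.lerp(w, last_update_w, gate)
weights     = {'mlp.0.weight': torch.zeros(4, 4)}
last_update = {'mlp.0.weight': torch.ones(4, 4)}
gate        = torch.tensor(0.25)   # assumed to broadcast over each parameter

gated = {
    name: w.lerp(last_update[name], gate)
    for name, w in weights.items()
}

assert torch.allclose(gated['mlp.0.weight'], torch.full((4, 4), 0.25))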
@@ -889,7 +890,8 @@ class NeuralMemory(Module):
             retrieve_chunk_size = 1
             need_pad = False

-            last_update, _ =
+            last_update, _ = next_neural_mem_state.states
+
             updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')

         retrieved = self.retrieve_memories(
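For the single-step retrieval path, 0.3.0 now reads the last updated weights out of `next_neural_mem_state.states` rather than a local variable, then gives every tensor in that dict a singleton chunk dimension via `rearrange_dict_values(last_update, 'b ... -> b 1 ...')`. A sketch of what such a per-value rearrange plausibly does, using `einops` and an ordinary dict; the helper's actual implementation is not shown in this diff, so treat this as an assumption:

import torch
from einops import rearrange

def rearrange_dict_values(d, pattern):
    # assumed behavior of the helper used above: apply the same einops
    # pattern to every tensor value in the dict
    return {key: rearrange(value, pattern) for key, value in d.items()}

last_update = {'mlp.0.weight': torch.randn(2, 8, 8)}   # batch of per-sample weights

updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
print(updates['mlp.0.weight'].shape)   # torch.Size([2, 1, 8, 8]), singleton chunk dim added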