titans-pytorch 0.2.28__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/neural_memory.py +16 -15
- {titans_pytorch-0.2.28.dist-info → titans_pytorch-0.3.1.dist-info}/METADATA +1 -1
- titans_pytorch-0.3.1.dist-info/RECORD +9 -0
- titans_pytorch-0.2.28.dist-info/RECORD +0 -9
- {titans_pytorch-0.2.28.dist-info → titans_pytorch-0.3.1.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.2.28.dist-info → titans_pytorch-0.3.1.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/neural_memory.py
CHANGED
@@ -66,12 +66,6 @@ def xnor(x, y):
|
|
66
66
|
def divisible_by(num, den):
|
67
67
|
return (num % den) == 0
|
68
68
|
|
69
|
-
def tuple_index_set(t: tuple, index, value):
|
70
|
-
klass = type(t)
|
71
|
-
t = list(t)
|
72
|
-
t[index] = value
|
73
|
-
return klass(*t)
|
74
|
-
|
75
69
|
def safe_cat(inputs, dim = -2):
|
76
70
|
inputs = tuple(filter(exists, inputs))
|
77
71
|
|
@@ -658,13 +652,14 @@ class NeuralMemory(Module):
|
|
658
652
|
next_last_update = TensorDict()
|
659
653
|
next_last_momentum = TensorDict()
|
660
654
|
|
661
|
-
for (param_name, surprise), (_, last_update)
|
655
|
+
for (param_name, surprise), (_, last_update) in zip(surprises.items(), past_last_update.items()):
|
662
656
|
|
663
657
|
update = surprise
|
664
658
|
|
665
659
|
# derive momentum with associative scan - eq (10)
|
666
660
|
|
667
661
|
if has_momentum:
|
662
|
+
last_momentum = past_last_momentum[param_name]
|
668
663
|
update = self.assoc_scan(adaptive_momentum, surprise, prev = last_momentum) # momentum is S / surprise in the paper
|
669
664
|
momentum = update
|
670
665
|
next_last_momentum[param_name] = momentum[:, -1]
|
@@ -872,15 +867,20 @@ class NeuralMemory(Module):
|
|
872
867
|
last_update, last_momentum = past_state
|
873
868
|
|
874
869
|
if exists(gate):
|
875
|
-
|
876
|
-
|
877
|
-
|
870
|
+
last_update = TensorDict({param_name: one_weight.lerp(one_last_update, gate) for (param_name, one_weight), (_, one_last_update) in zip(weights.items(), last_update.items())})
|
871
|
+
|
872
|
+
past_state = (last_update, last_momentum)
|
873
|
+
|
874
|
+
# set weights to the last updated weights for the last minibatch
|
878
875
|
|
879
|
-
|
880
|
-
next_neural_mem_state = tuple_index_set(next_neural_mem_state, -2, past_state)
|
881
|
-
next_neural_mem_state = tuple_index_set(next_neural_mem_state, 1, weights)
|
876
|
+
weights = last_update
|
882
877
|
|
883
|
-
|
878
|
+
next_neural_mem_state = next_neural_mem_state._replace(
|
879
|
+
weights = weights,
|
880
|
+
states = past_state,
|
881
|
+
)
|
882
|
+
|
883
|
+
next_neural_mem_state = next_neural_mem_state._replace(updates = updates)
|
884
884
|
|
885
885
|
# retrieve
|
886
886
|
|
@@ -891,7 +891,8 @@ class NeuralMemory(Module):
|
|
891
891
|
retrieve_chunk_size = 1
|
892
892
|
need_pad = False
|
893
893
|
|
894
|
-
last_update, _ =
|
894
|
+
last_update, _ = next_neural_mem_state.states
|
895
|
+
|
895
896
|
updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
|
896
897
|
|
897
898
|
retrieved = self.retrieve_memories(
|
@@ -0,0 +1,9 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
|
4
|
+
titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
|
5
|
+
titans_pytorch/neural_memory.py,sha256=K1z7wtv366Y6-eEyXMFZ_j7D2frPl5RxfSgxzFYoFMc,27704
|
6
|
+
titans_pytorch-0.3.1.dist-info/METADATA,sha256=ZAxucKq2DZBtW-BI_O2sUQ5RXy11a7eu48yPpwnanpw,6815
|
7
|
+
titans_pytorch-0.3.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.3.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.3.1.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
|
4
|
-
titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
|
5
|
-
titans_pytorch/neural_memory.py,sha256=vFkyxO2WK-mEnf-KP3Za4oJ7_c1KrLHonCuelofJtFs,27754
|
6
|
-
titans_pytorch-0.2.28.dist-info/METADATA,sha256=HxSE1da0QnLMbJLpsO2-2wJYm_kFUmdezy2JCUEx6aU,6816
|
7
|
-
titans_pytorch-0.2.28.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
-
titans_pytorch-0.2.28.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
-
titans_pytorch-0.2.28.dist-info/RECORD,,
|
File without changes
|
File without changes
|