titans-pytorch 0.2.28__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/neural_memory.py +14 -14
- {titans_pytorch-0.2.28.dist-info → titans_pytorch-0.3.0.dist-info}/METADATA +1 -1
- titans_pytorch-0.3.0.dist-info/RECORD +9 -0
- titans_pytorch-0.2.28.dist-info/RECORD +0 -9
- {titans_pytorch-0.2.28.dist-info → titans_pytorch-0.3.0.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.2.28.dist-info → titans_pytorch-0.3.0.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/neural_memory.py
CHANGED
@@ -66,12 +66,6 @@ def xnor(x, y):
|
|
66
66
|
def divisible_by(num, den):
|
67
67
|
return (num % den) == 0
|
68
68
|
|
69
|
-
def tuple_index_set(t: tuple, index, value):
|
70
|
-
klass = type(t)
|
71
|
-
t = list(t)
|
72
|
-
t[index] = value
|
73
|
-
return klass(*t)
|
74
|
-
|
75
69
|
def safe_cat(inputs, dim = -2):
|
76
70
|
inputs = tuple(filter(exists, inputs))
|
77
71
|
|
@@ -872,15 +866,20 @@ class NeuralMemory(Module):
|
|
872
866
|
last_update, last_momentum = past_state
|
873
867
|
|
874
868
|
if exists(gate):
|
875
|
-
|
876
|
-
|
877
|
-
|
869
|
+
last_update = TensorDict({param_name: one_weight.lerp(one_last_update, gate) for (param_name, one_weight), (_, one_last_update) in zip(weights.items(), last_update.items())})
|
870
|
+
|
871
|
+
past_state = (last_update, last_momentum)
|
872
|
+
|
873
|
+
# set weights to the last updated weights for the last minibatch
|
878
874
|
|
879
|
-
|
880
|
-
next_neural_mem_state = tuple_index_set(next_neural_mem_state, -2, past_state)
|
881
|
-
next_neural_mem_state = tuple_index_set(next_neural_mem_state, 1, weights)
|
875
|
+
weights = last_update
|
882
876
|
|
883
|
-
|
877
|
+
next_neural_mem_state = next_neural_mem_state._replace(
|
878
|
+
weights = weights,
|
879
|
+
states = past_state,
|
880
|
+
)
|
881
|
+
|
882
|
+
next_neural_mem_state = next_neural_mem_state._replace(updates = updates)
|
884
883
|
|
885
884
|
# retrieve
|
886
885
|
|
@@ -891,7 +890,8 @@ class NeuralMemory(Module):
|
|
891
890
|
retrieve_chunk_size = 1
|
892
891
|
need_pad = False
|
893
892
|
|
894
|
-
last_update, _ =
|
893
|
+
last_update, _ = next_neural_mem_state.states
|
894
|
+
|
895
895
|
updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
|
896
896
|
|
897
897
|
retrieved = self.retrieve_memories(
|
@@ -0,0 +1,9 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
|
4
|
+
titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
|
5
|
+
titans_pytorch/neural_memory.py,sha256=NICWfFySGeduH4-Wgn6UavWl18pddIL1JrtpT_dhXDw,27689
|
6
|
+
titans_pytorch-0.3.0.dist-info/METADATA,sha256=i-tiJAahxkvzYxHOWLiwP2z8DyEPq6mFgSQ_ThvpN9A,6815
|
7
|
+
titans_pytorch-0.3.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.3.0.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.3.0.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
|
4
|
-
titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
|
5
|
-
titans_pytorch/neural_memory.py,sha256=vFkyxO2WK-mEnf-KP3Za4oJ7_c1KrLHonCuelofJtFs,27754
|
6
|
-
titans_pytorch-0.2.28.dist-info/METADATA,sha256=HxSE1da0QnLMbJLpsO2-2wJYm_kFUmdezy2JCUEx6aU,6816
|
7
|
-
titans_pytorch-0.2.28.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
-
titans_pytorch-0.2.28.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
-
titans_pytorch-0.2.28.dist-info/RECORD,,
|
File without changes
|
File without changes
|