titans-pytorch 0.2.27__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/neural_memory.py +15 -13
- {titans_pytorch-0.2.27.dist-info → titans_pytorch-0.3.0.dist-info}/METADATA +1 -1
- titans_pytorch-0.3.0.dist-info/RECORD +9 -0
- titans_pytorch-0.2.27.dist-info/RECORD +0 -9
- {titans_pytorch-0.2.27.dist-info → titans_pytorch-0.3.0.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.2.27.dist-info → titans_pytorch-0.3.0.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/neural_memory.py
CHANGED
@@ -66,12 +66,6 @@ def xnor(x, y):
|
|
66
66
|
def divisible_by(num, den):
|
67
67
|
return (num % den) == 0
|
68
68
|
|
69
|
-
def tuple_index_set(t: tuple, index, value):
|
70
|
-
klass = type(t)
|
71
|
-
t = list(t)
|
72
|
-
t[index] = value
|
73
|
-
return klass(*t)
|
74
|
-
|
75
69
|
def safe_cat(inputs, dim = -2):
|
76
70
|
inputs = tuple(filter(exists, inputs))
|
77
71
|
|
@@ -869,16 +863,23 @@ class NeuralMemory(Module):
|
|
869
863
|
|
870
864
|
# update weights once batch size is fulfilled
|
871
865
|
|
872
|
-
last_update,
|
866
|
+
last_update, last_momentum = past_state
|
873
867
|
|
874
868
|
if exists(gate):
|
875
|
-
|
876
|
-
|
877
|
-
|
869
|
+
last_update = TensorDict({param_name: one_weight.lerp(one_last_update, gate) for (param_name, one_weight), (_, one_last_update) in zip(weights.items(), last_update.items())})
|
870
|
+
|
871
|
+
past_state = (last_update, last_momentum)
|
872
|
+
|
873
|
+
# set weights to the last updated weights for the last minibatch
|
878
874
|
|
879
|
-
|
875
|
+
weights = last_update
|
880
876
|
|
881
|
-
|
877
|
+
next_neural_mem_state = next_neural_mem_state._replace(
|
878
|
+
weights = weights,
|
879
|
+
states = past_state,
|
880
|
+
)
|
881
|
+
|
882
|
+
next_neural_mem_state = next_neural_mem_state._replace(updates = updates)
|
882
883
|
|
883
884
|
# retrieve
|
884
885
|
|
@@ -889,7 +890,8 @@ class NeuralMemory(Module):
|
|
889
890
|
retrieve_chunk_size = 1
|
890
891
|
need_pad = False
|
891
892
|
|
892
|
-
last_update, _ =
|
893
|
+
last_update, _ = next_neural_mem_state.states
|
894
|
+
|
893
895
|
updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
|
894
896
|
|
895
897
|
retrieved = self.retrieve_memories(
|
@@ -0,0 +1,9 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
|
4
|
+
titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
|
5
|
+
titans_pytorch/neural_memory.py,sha256=NICWfFySGeduH4-Wgn6UavWl18pddIL1JrtpT_dhXDw,27689
|
6
|
+
titans_pytorch-0.3.0.dist-info/METADATA,sha256=i-tiJAahxkvzYxHOWLiwP2z8DyEPq6mFgSQ_ThvpN9A,6815
|
7
|
+
titans_pytorch-0.3.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.3.0.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.3.0.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
|
4
|
-
titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
|
5
|
-
titans_pytorch/neural_memory.py,sha256=nAO54CwSd2m0NlrRM0zWbWmpaX9dqn2Xyh_qfVsTQXI,27601
|
6
|
-
titans_pytorch-0.2.27.dist-info/METADATA,sha256=ok3KssBVk0w-cKRw7JZOGJowVa8A88CRVBo0MVEV3Sc,6816
|
7
|
-
titans_pytorch-0.2.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
-
titans_pytorch-0.2.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
-
titans_pytorch-0.2.27.dist-info/RECORD,,
|
File without changes
|
File without changes
|