titans-pytorch 0.2.24__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/neural_memory.py +2 -2
- {titans_pytorch-0.2.24.dist-info → titans_pytorch-0.2.25.dist-info}/METADATA +1 -1
- {titans_pytorch-0.2.24.dist-info → titans_pytorch-0.2.25.dist-info}/RECORD +5 -5
- {titans_pytorch-0.2.24.dist-info → titans_pytorch-0.2.25.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.2.24.dist-info → titans_pytorch-0.2.25.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/neural_memory.py
CHANGED
@@ -858,6 +858,7 @@ class NeuralMemory(Module):
|
|
858
858
|
prev_weights = prev_weights
|
859
859
|
)
|
860
860
|
|
861
|
+
weights = next_neural_mem_state.weights
|
861
862
|
seq_index = next_neural_mem_state.seq_index
|
862
863
|
past_state = next_neural_mem_state.states
|
863
864
|
|
@@ -871,8 +872,7 @@ class NeuralMemory(Module):
|
|
871
872
|
last_update, _ = past_state
|
872
873
|
|
873
874
|
if exists(gate):
|
874
|
-
|
875
|
-
weights = TensorDict({param_name: v1.lerp(v2, gate) for (param_name, v1), (_, v2) in zip(curr_weights.items(), last_update.items())})
|
875
|
+
weights = TensorDict({param_name: v1.lerp(v2, gate) for (param_name, v1), (_, v2) in zip(weights.items(), last_update.items())})
|
876
876
|
else:
|
877
877
|
weights = last_update
|
878
878
|
|
@@ -2,8 +2,8 @@ titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,27
|
|
2
2
|
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
3
|
titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
|
4
4
|
titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
|
5
|
-
titans_pytorch/neural_memory.py,sha256=
|
6
|
-
titans_pytorch-0.2.
|
7
|
-
titans_pytorch-0.2.
|
8
|
-
titans_pytorch-0.2.
|
9
|
-
titans_pytorch-0.2.
|
5
|
+
titans_pytorch/neural_memory.py,sha256=D-xP4Qh4A1RlCyB8PPzhE48TugutFVyR8IXENWaGfr8,27563
|
6
|
+
titans_pytorch-0.2.25.dist-info/METADATA,sha256=3D6otenDwMzkguATdtZBRq-Z6imPWWv-Th3AeGSfpzU,6816
|
7
|
+
titans_pytorch-0.2.25.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.2.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.2.25.dist-info/RECORD,,
|
File without changes
|
File without changes
|