titans-pytorch 0.3.24__py3-none-any.whl → 0.3.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/neural_memory.py +20 -4
- {titans_pytorch-0.3.24.dist-info → titans_pytorch-0.3.25.dist-info}/METADATA +1 -1
- {titans_pytorch-0.3.24.dist-info → titans_pytorch-0.3.25.dist-info}/RECORD +5 -5
- {titans_pytorch-0.3.24.dist-info → titans_pytorch-0.3.25.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.3.24.dist-info → titans_pytorch-0.3.25.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/neural_memory.py
CHANGED
@@ -524,7 +524,8 @@ class NeuralMemory(Module):
|
|
524
524
|
weights: dict[str, Tensor] | None = None,
|
525
525
|
past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
|
526
526
|
seq_index = 0,
|
527
|
-
prev_weights = None
|
527
|
+
prev_weights = None,
|
528
|
+
mask: Tensor | None = None,
|
528
529
|
):
|
529
530
|
if self.qkv_receives_diff_views:
|
530
531
|
_, batch, seq_len = seq.shape[:3]
|
@@ -612,6 +613,14 @@ class NeuralMemory(Module):
|
|
612
613
|
|
613
614
|
adaptive_lr = rearrange(adaptive_lr, 'b (n c u) -> (b n) (c u)', c = chunk_size, u = num_updates)
|
614
615
|
|
616
|
+
# optionally, a mask for storing memories can be passed in. where the mask is False, the learning rate is set to 0 for those positions
|
617
|
+
|
618
|
+
if exists(mask):
|
619
|
+
mask = mask[..., :round_down_seq_len]
|
620
|
+
mask = repeat(mask, 'b (n c) -> (b h n) (c u)', h = heads, u = num_updates, c = chunk_size)
|
621
|
+
|
622
|
+
adaptive_lr = torch.where(mask, adaptive_lr, 0.)
|
623
|
+
|
615
624
|
# maybe add previous layer weight
|
616
625
|
|
617
626
|
assert xnor(exists(self.to_learned_weight_residual_mix), exists(prev_weights))
|
@@ -833,7 +842,8 @@ class NeuralMemory(Module):
|
|
833
842
|
seq,
|
834
843
|
store_seq = None,
|
835
844
|
state: NeuralMemState | None = None,
|
836
|
-
prev_weights = None
|
845
|
+
prev_weights = None,
|
846
|
+
store_mask: Tensor | None = None
|
837
847
|
):
|
838
848
|
is_multi_input = self.qkv_receives_diff_views
|
839
849
|
|
@@ -910,6 +920,11 @@ class NeuralMemory(Module):
|
|
910
920
|
|
911
921
|
store_seqs = store_seq.split(split_sizes, dim = -2)
|
912
922
|
|
923
|
+
if exists(store_mask):
|
924
|
+
store_masks = store_mask.split(split_sizes, dim = -1)
|
925
|
+
else:
|
926
|
+
store_masks = (None,) * len(split_sizes)
|
927
|
+
|
913
928
|
# whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch
|
914
929
|
|
915
930
|
gate = None
|
@@ -917,7 +932,7 @@ class NeuralMemory(Module):
|
|
917
932
|
if exists(self.transition_gate):
|
918
933
|
gate = self.transition_gate.sigmoid()
|
919
934
|
|
920
|
-
for ind, store_seq_chunk in enumerate(store_seqs):
|
935
|
+
for ind, (store_seq_chunk, maybe_store_mask) in enumerate(zip(store_seqs, store_masks)):
|
921
936
|
is_last = ind == (len(store_seqs) - 1)
|
922
937
|
|
923
938
|
# store
|
@@ -927,7 +942,8 @@ class NeuralMemory(Module):
|
|
927
942
|
weights,
|
928
943
|
seq_index = seq_index,
|
929
944
|
past_state = past_state,
|
930
|
-
prev_weights = prev_weights
|
945
|
+
prev_weights = prev_weights,
|
946
|
+
mask = maybe_store_mask
|
931
947
|
)
|
932
948
|
|
933
949
|
weights = next_neural_mem_state.weights
|
@@ -2,8 +2,8 @@ titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,29
|
|
2
2
|
titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
|
3
3
|
titans_pytorch/mac_transformer.py,sha256=grD327B3OCIy7d23jNUWIoUo1bIgXUqD26dXWCjdi28,25565
|
4
4
|
titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
|
5
|
-
titans_pytorch/neural_memory.py,sha256=
|
6
|
-
titans_pytorch-0.3.
|
7
|
-
titans_pytorch-0.3.
|
8
|
-
titans_pytorch-0.3.
|
9
|
-
titans_pytorch-0.3.
|
5
|
+
titans_pytorch/neural_memory.py,sha256=uh5NbtAAzfPeZPFe7uhgnpUF6qyP0zjP0eXPIgY5pfc,31929
|
6
|
+
titans_pytorch-0.3.25.dist-info/METADATA,sha256=SZwazbaNFe1GstoF45zI_aNMpzgXAqv4mOh78gMN5-U,6817
|
7
|
+
titans_pytorch-0.3.25.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.3.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.3.25.dist-info/RECORD,,
|
File without changes
|
File without changes
|