titans-pytorch 0.3.19__py3-none-any.whl → 0.3.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +3 -3
- titans_pytorch/neural_memory.py +21 -13
- {titans_pytorch-0.3.19.dist-info → titans_pytorch-0.3.20.dist-info}/METADATA +1 -1
- titans_pytorch-0.3.20.dist-info/RECORD +9 -0
- titans_pytorch-0.3.19.dist-info/RECORD +0 -9
- {titans_pytorch-0.3.19.dist-info → titans_pytorch-0.3.20.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.3.19.dist-info → titans_pytorch-0.3.20.dist-info}/licenses/LICENSE +0 -0
@@ -483,7 +483,7 @@ class MemoryAsContextTransformer(Module):
|
|
483
483
|
num_longterm_mem_tokens = 0,
|
484
484
|
num_persist_mem_tokens = 0,
|
485
485
|
neural_memory_batch_size = None,
|
486
|
-
|
486
|
+
neural_memory_qkv_receives_diff_views = False,
|
487
487
|
dim_head = 64,
|
488
488
|
heads = 8,
|
489
489
|
ff_mult = 4,
|
@@ -561,14 +561,14 @@ class MemoryAsContextTransformer(Module):
|
|
561
561
|
mem_hyper_conn = None
|
562
562
|
|
563
563
|
if layer in neural_memory_layers:
|
564
|
-
mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output, num_input_views =
|
564
|
+
mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output, num_input_views = 3 if neural_memory_qkv_receives_diff_views else 1)
|
565
565
|
|
566
566
|
mem = NeuralMemory(
|
567
567
|
dim = dim,
|
568
568
|
chunk_size = self.neural_memory_segment_len,
|
569
569
|
batch_size = neural_memory_batch_size,
|
570
570
|
model = deepcopy(neural_memory_model),
|
571
|
-
|
571
|
+
qkv_receives_diff_views = neural_memory_qkv_receives_diff_views,
|
572
572
|
accept_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
|
573
573
|
**neural_memory_kwargs
|
574
574
|
)
|
titans_pytorch/neural_memory.py
CHANGED
@@ -231,7 +231,7 @@ class NeuralMemory(Module):
|
|
231
231
|
momentum_order = 1,
|
232
232
|
learned_momentum_combine = False,
|
233
233
|
learned_combine_include_zeroth = False,
|
234
|
-
|
234
|
+
qkv_receives_diff_views = False, # to address an issue raised by a phd student (who will be credited if experiments are green). basically the issue raised is that the memory MLP is only learning Wk @ Wv linear mapping and that may not be expressive enough. we will use hyper connections to allow the network to choose different previous layer inputs as keys / values and see if that does anything
|
235
235
|
pre_rmsnorm = True,
|
236
236
|
post_rmsnorm = False,
|
237
237
|
qk_rmsnorm = False,
|
@@ -268,7 +268,7 @@ class NeuralMemory(Module):
|
|
268
268
|
|
269
269
|
# key values receiving different views
|
270
270
|
|
271
|
-
self.
|
271
|
+
self.qkv_receives_diff_views = qkv_receives_diff_views
|
272
272
|
|
273
273
|
# norms
|
274
274
|
|
@@ -511,7 +511,7 @@ class NeuralMemory(Module):
|
|
511
511
|
seq_index = 0,
|
512
512
|
prev_weights = None
|
513
513
|
):
|
514
|
-
if self.
|
514
|
+
if self.qkv_receives_diff_views:
|
515
515
|
_, batch, seq_len = seq.shape[:3]
|
516
516
|
else:
|
517
517
|
batch, seq_len = seq.shape[:2]
|
@@ -550,7 +550,7 @@ class NeuralMemory(Module):
|
|
550
550
|
|
551
551
|
values_seq = seq
|
552
552
|
|
553
|
-
if self.
|
553
|
+
if self.qkv_receives_diff_views:
|
554
554
|
seq, values_seq = seq
|
555
555
|
|
556
556
|
# derive learned hparams for optimization of memory network
|
@@ -820,10 +820,23 @@ class NeuralMemory(Module):
|
|
820
820
|
state: NeuralMemState | None = None,
|
821
821
|
prev_weights = None
|
822
822
|
):
|
823
|
-
|
824
|
-
seq = rearrange(seq, 'b d -> b 1 d')
|
823
|
+
is_multi_input = self.qkv_receives_diff_views
|
825
824
|
|
826
|
-
|
825
|
+
# handle single token
|
826
|
+
|
827
|
+
if seq.ndim == 2 or (is_multi_input and seq.ndim == 3):
|
828
|
+
seq = rearrange(seq, '... b d -> ... b 1 d')
|
829
|
+
|
830
|
+
is_single_token = seq.shape[-2] == 1
|
831
|
+
|
832
|
+
# if different views for qkv, then
|
833
|
+
|
834
|
+
if is_multi_input:
|
835
|
+
retrieve_seq, seq = seq[0], seq[1:]
|
836
|
+
else:
|
837
|
+
retrieve_seq = seq
|
838
|
+
|
839
|
+
# handle previous state init
|
827
840
|
|
828
841
|
if not exists(state):
|
829
842
|
state = (0, None, None, None, None)
|
@@ -839,8 +852,6 @@ class NeuralMemory(Module):
|
|
839
852
|
if exists(cache_store_seq):
|
840
853
|
store_seq = safe_cat((cache_store_seq, store_seq))
|
841
854
|
|
842
|
-
# functions
|
843
|
-
|
844
855
|
# compute split sizes of sequence
|
845
856
|
# for now manually update weights to last update at the correct boundaries
|
846
857
|
|
@@ -939,11 +950,8 @@ class NeuralMemory(Module):
|
|
939
950
|
last_update, _ = next_neural_mem_state.states
|
940
951
|
updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
|
941
952
|
|
942
|
-
if self.kv_receives_diff_views:
|
943
|
-
seq = seq[0]
|
944
|
-
|
945
953
|
retrieved = self.retrieve_memories(
|
946
|
-
|
954
|
+
retrieve_seq,
|
947
955
|
updates
|
948
956
|
)
|
949
957
|
|
@@ -0,0 +1,9 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
|
2
|
+
titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=M45tLp_CzYVm9R33zkUHmWdzIuaNxNYQLCTAtyMechg,25294
|
4
|
+
titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
|
5
|
+
titans_pytorch/neural_memory.py,sha256=hLMDbLosxgQZvvfuJM3X7w23npcoL9AWBQfbtQHOSiA,30637
|
6
|
+
titans_pytorch-0.3.20.dist-info/METADATA,sha256=JILf0r0bT4KXAZD4V2nmYRIr5Y7bGGjjNlpfIuT0UNs,6817
|
7
|
+
titans_pytorch-0.3.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.3.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.3.20.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
|
2
|
-
titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=4N8WcoqPYNhMGGQAZjpm-djVsLnU7VADH_l06qFPuOk,25290
|
4
|
-
titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
|
5
|
-
titans_pytorch/neural_memory.py,sha256=KKj8r3oRoDYf-vFQFnI4Rd4DfMH1f5QFs4vdcS35og8,30374
|
6
|
-
titans_pytorch-0.3.19.dist-info/METADATA,sha256=5h4f5gsO1emX5LEwe8cgpH35Rtjydz2UTlpK4DKSntI,6817
|
7
|
-
titans_pytorch-0.3.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
-
titans_pytorch-0.3.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
-
titans_pytorch-0.3.19.dist-info/RECORD,,
|
File without changes
|
File without changes
|