titans-pytorch 0.3.19__tar.gz → 0.3.21__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.3.19 → titans_pytorch-0.3.21}/PKG-INFO +1 -1
- {titans_pytorch-0.3.19 → titans_pytorch-0.3.21}/pyproject.toml +1 -1
- {titans_pytorch-0.3.19 → titans_pytorch-0.3.21}/tests/test_titans.py +3 -3
- {titans_pytorch-0.3.19 → titans_pytorch-0.3.21}/titans_pytorch/mac_transformer.py +5 -3
- {titans_pytorch-0.3.19 → titans_pytorch-0.3.21}/titans_pytorch/neural_memory.py +21 -13
- {titans_pytorch-0.3.19 → titans_pytorch-0.3.21}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.3.19 → titans_pytorch-0.3.21}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.3.19 → titans_pytorch-0.3.21}/.gitignore +0 -0
- {titans_pytorch-0.3.19 → titans_pytorch-0.3.21}/LICENSE +0 -0
- {titans_pytorch-0.3.19 → titans_pytorch-0.3.21}/README.md +0 -0
- {titans_pytorch-0.3.19 → titans_pytorch-0.3.21}/data/README.md +0 -0
- {titans_pytorch-0.3.19 → titans_pytorch-0.3.21}/data/enwik8.gz +0 -0
- {titans_pytorch-0.3.19 → titans_pytorch-0.3.21}/fig1.png +0 -0
- {titans_pytorch-0.3.19 → titans_pytorch-0.3.21}/fig2.png +0 -0
- {titans_pytorch-0.3.19 → titans_pytorch-0.3.21}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.3.19 → titans_pytorch-0.3.21}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.3.19 → titans_pytorch-0.3.21}/titans_pytorch/memory_models.py +0 -0
- {titans_pytorch-0.3.19 → titans_pytorch-0.3.21}/train_mac.py +0 -0
tests/test_titans.py

@@ -200,7 +200,7 @@ def test_neural_mem_chaining_with_batch_size():
 @pytest.mark.parametrize('neural_mem_segment_len', (8, 16))
 @pytest.mark.parametrize('neural_mem_weight_residual', (False, True))
 @pytest.mark.parametrize('neural_mem_batch_size', (None, 64))
-@pytest.mark.parametrize('
+@pytest.mark.parametrize('neural_mem_qkv_receives_diff_views', (False, True))
 @pytest.mark.parametrize('neural_mem_momentum', (False, True))
 def test_mac(
     seq_len,
@@ -210,7 +210,7 @@ def test_mac(
     neural_mem_segment_len,
     neural_mem_weight_residual,
     neural_mem_batch_size,
-
+    neural_mem_qkv_receives_diff_views,
     neural_mem_momentum
 ):
     transformer = MemoryAsContextTransformer(
@@ -223,7 +223,7 @@ def test_mac(
         neural_mem_gate_attn_output = neural_mem_gate_attn_output,
         neural_memory_segment_len = neural_mem_segment_len,
         neural_memory_batch_size = neural_mem_batch_size,
-
+        neural_memory_qkv_receives_diff_views = neural_mem_qkv_receives_diff_views,
         neural_mem_weight_residual = neural_mem_weight_residual,
         neural_memory_kwargs = dict(
             momentum = neural_mem_momentum
titans_pytorch/mac_transformer.py

@@ -483,7 +483,7 @@ class MemoryAsContextTransformer(Module):
         num_longterm_mem_tokens = 0,
         num_persist_mem_tokens = 0,
         neural_memory_batch_size = None,
-
+        neural_memory_qkv_receives_diff_views = False,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -523,6 +523,8 @@ class MemoryAsContextTransformer(Module):
 
         # hyper connection
 
+        assert not (num_residual_streams <= 1 and neural_memory_qkv_receives_diff_views), 'allowing neural memory queries, keys, values to be derived from different combinations of the residual streams can only work if hyper connections has greater than 1 residual stream'
+
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
@@ -561,14 +563,14 @@ class MemoryAsContextTransformer(Module):
             mem_hyper_conn = None
 
             if layer in neural_memory_layers:
-                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output, num_input_views =
+                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output, num_input_views = 3 if neural_memory_qkv_receives_diff_views else 1)
 
                 mem = NeuralMemory(
                     dim = dim,
                     chunk_size = self.neural_memory_segment_len,
                     batch_size = neural_memory_batch_size,
                     model = deepcopy(neural_memory_model),
-
+                    qkv_receives_diff_views = neural_memory_qkv_receives_diff_views,
                     accept_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
                     **neural_memory_kwargs
                 )
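Taken together, these changes thread the new flag from the MemoryAsContextTransformer constructor down into each per-layer NeuralMemory, and require more than one hyper-connection residual stream when it is enabled. A minimal usage sketch follows; the num_tokens, dim, depth and segment_len arguments and the return_loss call come from the package README rather than this diff, and num_residual_streams is assumed to be exposed as a constructor argument.

    import torch
    from titans_pytorch import MemoryAsContextTransformer

    # sketch only: constructor arguments not shown in this diff are assumptions
    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 128,
        num_persist_mem_tokens = 4,
        num_longterm_mem_tokens = 4,
        num_residual_streams = 4,                      # must be > 1 per the new assertion
        neural_memory_qkv_receives_diff_views = True,  # new in 0.3.21
        neural_memory_kwargs = dict(
            momentum = True
        )
    )

    ids = torch.randint(0, 256, (1, 1023))
    loss = transformer(ids, return_loss = True)
    loss.backward()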
titans_pytorch/neural_memory.py

@@ -231,7 +231,7 @@ class NeuralMemory(Module):
         momentum_order = 1,
         learned_momentum_combine = False,
         learned_combine_include_zeroth = False,
-
+        qkv_receives_diff_views = False, # to address an issue raised by a phd student (who will be credited if experiments are green). basically the issue raised is that the memory MLP is only learning Wk @ Wv linear mapping and that may not be expressive enough. we will use hyper connections to allow the network to choose different previous layer inputs as keys / values and see if that does anything
         pre_rmsnorm = True,
         post_rmsnorm = False,
         qk_rmsnorm = False,
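The comment above is the motivation for this release: if the keys and values handed to the neural memory are both linear projections of the same input, then whenever the key projection is invertible the key-to-value association the memory has to learn is itself a single linear map, so a deeper memory MLP buys no extra expressivity for that association. A small self-contained illustration of the point, using hypothetical names that are not part of the package:

    import torch

    # hypothetical illustration, not package code
    d = 8
    x  = torch.randn(16, d, dtype = torch.float64)   # shared input to the memory branch
    Wk = torch.randn(d, d, dtype = torch.float64)    # key projection
    Wv = torch.randn(d, d, dtype = torch.float64)    # value projection

    k = x @ Wk   # keys fed to the memory
    v = x @ Wv   # values the memory is trained to associate with those keys

    # with Wk invertible, v is an exactly linear function of k
    M = torch.linalg.inv(Wk) @ Wv
    assert torch.allclose(k @ M, v)

With qkv_receives_diff_views = True, hyper connections instead feed the memory three different combinations of the residual streams (the num_input_views = 3 wiring above), so keys and values are no longer forced to be linear functions of one another.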
@@ -268,7 +268,7 @@ class NeuralMemory(Module):
 
         # key values receiving different views
 
-        self.
+        self.qkv_receives_diff_views = qkv_receives_diff_views
 
         # norms
 
@@ -511,7 +511,7 @@ class NeuralMemory(Module):
         seq_index = 0,
         prev_weights = None
     ):
-        if self.
+        if self.qkv_receives_diff_views:
             _, batch, seq_len = seq.shape[:3]
         else:
             batch, seq_len = seq.shape[:2]
@@ -550,7 +550,7 @@ class NeuralMemory(Module):
 
         values_seq = seq
 
-        if self.
+        if self.qkv_receives_diff_views:
             seq, values_seq = seq
 
         # derive learned hparams for optimization of memory network
@@ -820,10 +820,23 @@ class NeuralMemory(Module):
         state: NeuralMemState | None = None,
         prev_weights = None
     ):
-
-        seq = rearrange(seq, 'b d -> b 1 d')
+        is_multi_input = self.qkv_receives_diff_views
 
-
+        # handle single token
+
+        if seq.ndim == 2 or (is_multi_input and seq.ndim == 3):
+            seq = rearrange(seq, '... b d -> ... b 1 d')
+
+        is_single_token = seq.shape[-2] == 1
+
+        # if different views for qkv, then the first view is used for retrieval
+
+        if is_multi_input:
+            retrieve_seq, seq = seq[0], seq[1:]
+        else:
+            retrieve_seq = seq
+
+        # handle previous state init
 
         if not exists(state):
             state = (0, None, None, None, None)
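The shape handling above implies that a standalone NeuralMemory with qkv_receives_diff_views = True expects a leading view dimension of size 3, with the first view used for retrieval and the remaining two handed to store_memories as the key and value sources (the ordering is inferred from the seq[0] / seq[1:] split). A hedged sketch, assuming the (retrieved, next_state) return shown in the package README:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(
        dim = 384,
        chunk_size = 64,
        qkv_receives_diff_views = True   # new flag from this release
    )

    # three views of the same underlying sequence stacked on a leading dim:
    # view 0 -> retrieval (queries), view 1 -> keys, view 2 -> values
    views = torch.randn(3, 2, 1024, 384)

    retrieved, state = mem(views)

    # single-token step: (3, batch, dim) is expanded to (3, batch, 1, dim)
    # by the new ndim == 3 branch above
    step_views = torch.randn(3, 2, 384)
    retrieved_step, state = mem(step_views, state = state)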
@@ -839,8 +852,6 @@ class NeuralMemory(Module):
         if exists(cache_store_seq):
             store_seq = safe_cat((cache_store_seq, store_seq))
 
-        # functions
-
         # compute split sizes of sequence
         # for now manually update weights to last update at the correct boundaries
 
@@ -939,11 +950,8 @@ class NeuralMemory(Module):
         last_update, _ = next_neural_mem_state.states
         updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
 
-        if self.kv_receives_diff_views:
-            seq = seq[0]
-
         retrieved = self.retrieve_memories(
-
+            retrieve_seq,
             updates
         )
 