titans-pytorch 0.3.15__tar.gz → 0.3.20__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/PKG-INFO +2 -2
- {titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/pyproject.toml +2 -2
- {titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/tests/test_titans.py +3 -0
- {titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/titans_pytorch/mac_transformer.py +3 -1
- {titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/titans_pytorch/neural_memory.py +46 -11
- {titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/.gitignore +0 -0
- {titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/LICENSE +0 -0
- {titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/README.md +0 -0
- {titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/data/README.md +0 -0
- {titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/data/enwik8.gz +0 -0
- {titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/fig1.png +0 -0
- {titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/fig2.png +0 -0
- {titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/titans_pytorch/memory_models.py +0 -0
- {titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/train_mac.py +0 -0
{titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.15
+Version: 0.3.20
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

@@ -38,7 +38,7 @@ Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.10
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
-Requires-Dist: hyper-connections>=0.1.
+Requires-Dist: hyper-connections>=0.1.11
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
{titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.3.15"
+version = "0.3.20"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

@@ -29,7 +29,7 @@ dependencies = [
     "axial_positional_embedding>=0.3.10",
     "einops>=0.8.0",
     "einx>=0.3.0",
-    "hyper-connections>=0.1.
+    "hyper-connections>=0.1.11",
     "Ninja",
     "rotary-embedding-torch",
     "tensordict",
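The packaging changes amount to the release bump and a higher floor on hyper-connections, which the `num_input_views` argument used in mac_transformer.py below depends on. A minimal sketch for checking a local environment against the new floors, assuming `importlib.metadata` and `packaging` are available (the latter ships with pip); the old hyper-connections bound is truncated in this diff, so only the new one is asserted:

```python
# Minimal environment check for the bumped requirements (sketch, not part of the package).
from importlib.metadata import version
from packaging.version import Version

assert Version(version("titans-pytorch")) >= Version("0.3.20")
assert Version(version("hyper-connections")) >= Version("0.1.11")
```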
{titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/tests/test_titans.py

@@ -200,6 +200,7 @@ def test_neural_mem_chaining_with_batch_size():
 @pytest.mark.parametrize('neural_mem_segment_len', (8, 16))
 @pytest.mark.parametrize('neural_mem_weight_residual', (False, True))
 @pytest.mark.parametrize('neural_mem_batch_size', (None, 64))
+@pytest.mark.parametrize('neural_mem_qkv_receives_diff_views', (False, True))
 @pytest.mark.parametrize('neural_mem_momentum', (False, True))
 def test_mac(
     seq_len,

@@ -209,6 +210,7 @@ def test_mac(
     neural_mem_segment_len,
     neural_mem_weight_residual,
     neural_mem_batch_size,
+    neural_mem_qkv_receives_diff_views,
     neural_mem_momentum
 ):
     transformer = MemoryAsContextTransformer(

@@ -221,6 +223,7 @@ def test_mac(
         neural_mem_gate_attn_output = neural_mem_gate_attn_output,
         neural_memory_segment_len = neural_mem_segment_len,
         neural_memory_batch_size = neural_mem_batch_size,
+        neural_memory_qkv_receives_diff_views = neural_mem_qkv_receives_diff_views,
         neural_mem_weight_residual = neural_mem_weight_residual,
         neural_memory_kwargs = dict(
             momentum = neural_mem_momentum
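Outside of the parametrized test, the new flag is just one more keyword argument on `MemoryAsContextTransformer`. A hedged sketch of what the `neural_mem_qkv_receives_diff_views = True` case exercises; every constructor argument other than `neural_memory_qkv_receives_diff_views` is an assumption taken from the project README rather than from this diff:

```python
# Sketch of the configuration the new test parameter exercises.
# Arguments besides neural_memory_qkv_receives_diff_views are assumed from the README.
import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 32,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
    neural_memory_qkv_receives_diff_views = True   # new flag introduced in this release range
)

ids = torch.randint(0, 256, (1, 512))

loss = transformer(ids, return_loss = True)   # README-style training call
loss.backward()
```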
{titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/titans_pytorch/mac_transformer.py

@@ -483,6 +483,7 @@ class MemoryAsContextTransformer(Module):
         num_longterm_mem_tokens = 0,
         num_persist_mem_tokens = 0,
         neural_memory_batch_size = None,
+        neural_memory_qkv_receives_diff_views = False,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,

@@ -560,13 +561,14 @@ class MemoryAsContextTransformer(Module):
             mem_hyper_conn = None

             if layer in neural_memory_layers:
-                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)
+                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output, num_input_views = 3 if neural_memory_qkv_receives_diff_views else 1)

                 mem = NeuralMemory(
                     dim = dim,
                     chunk_size = self.neural_memory_segment_len,
                     batch_size = neural_memory_batch_size,
                     model = deepcopy(neural_memory_model),
+                    qkv_receives_diff_views = neural_memory_qkv_receives_diff_views,
                     accept_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
                     **neural_memory_kwargs
                 )
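The only structural change above is that the memory branch's hyper connection is initialized with `num_input_views = 3` when the flag is set, so the neural memory receives a stack of three learned views of the residual stream (one each for retrieval queries, keys and values) instead of a single view. A shape-level sketch of that contract, with plain random tensors standing in for the views the hyper connection would actually produce:

```python
# Shape-level sketch only: random tensors stand in for the three views the
# hyper connection produces when num_input_views = 3.
import torch

batch, seq_len, dim = 2, 128, 384

q_view, k_view, v_view = (torch.randn(batch, seq_len, dim) for _ in range(3))

# view 0 drives retrieval (queries), view 1 the keys, view 2 the values,
# matching how NeuralMemory.forward unpacks its input in the hunks below
stacked = torch.stack((q_view, k_view, v_view))   # (3, batch, seq_len, dim)
assert stacked.shape == (3, batch, seq_len, dim)
```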
{titans_pytorch-0.3.15 → titans_pytorch-0.3.20}/titans_pytorch/neural_memory.py

@@ -231,6 +231,7 @@ class NeuralMemory(Module):
         momentum_order = 1,
         learned_momentum_combine = False,
         learned_combine_include_zeroth = False,
+        qkv_receives_diff_views = False, # to address an issue raised by a phd student (who will be credited if experiments are green). basically the issue raised is that the memory MLP is only learning Wk @ Wv linear mapping and that may not be expressive enough. we will use hyper connections to allow the network to choose different previous layer inputs as keys / values and see if that does anything
         pre_rmsnorm = True,
         post_rmsnorm = False,
         qk_rmsnorm = False,

@@ -265,6 +266,10 @@ class NeuralMemory(Module):

         self.assoc_scan = AssocScan(use_accelerated = use_accelerated_scan)

+        # key values receiving different views
+
+        self.qkv_receives_diff_views = qkv_receives_diff_views
+
         # norms

         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()

@@ -358,7 +363,9 @@ class NeuralMemory(Module):

         # keys and values for storing to the model

-        self.
+        self.to_keys = Sequential(LinearNoBias(dim, dim_inner), activation)
+        self.to_values = Sequential(LinearNoBias(dim, dim_inner), activation)
+
         self.store_memory_loss_fn = store_memory_loss_fn

         # `chunk_size` refers to chunk size used for storing to memory model weights

@@ -504,7 +511,14 @@ class NeuralMemory(Module):
         seq_index = 0,
         prev_weights = None
     ):
-
+        if self.qkv_receives_diff_views:
+            _, batch, seq_len = seq.shape[:3]
+        else:
+            batch, seq_len = seq.shape[:2]
+
+        # shapes and variables
+
+        heads, chunk_size = self.heads, self.store_chunk_size

         # curtail sequence by multiple of the chunk size
         # only a complete chunk of the sequence provides the memory for the next chunk

@@ -512,7 +526,7 @@ class NeuralMemory(Module):
         round_down_seq_len = round_down_multiple(seq_len, chunk_size)
         num_chunks = round_down_seq_len // chunk_size

-        seq, remainder = seq[
+        seq, remainder = seq[..., :round_down_seq_len, :], seq[..., round_down_seq_len:, :]

         next_seq_len_index = seq_index + round_down_seq_len

@@ -528,10 +542,19 @@ class NeuralMemory(Module):

         weights_for_surprise = repeat_dict_values(weights, 'b ... -> b n ...', n = num_chunks)

-        #
+        # initial norm

         seq = self.store_norm(seq)

+        # handle keys and values coming from different sequences from hyper connection
+
+        values_seq = seq
+
+        if self.qkv_receives_diff_views:
+            seq, values_seq = seq
+
+        # derive learned hparams for optimization of memory network
+
         adaptive_lr = self.to_adaptive_step(seq)
         adaptive_lr = self.adaptive_step_transform(adaptive_lr)

@@ -555,7 +578,8 @@ class NeuralMemory(Module):

         # keys and values

-        keys
+        keys = self.to_keys(seq)
+        values = self.to_values(values_seq)

         # maybe multi head

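The reason for splitting `to_keys` and `to_values` shows up in the store step: the memory model is fit online to map keys to values, so when both projections read the same hidden state the mapping it has to learn can reduce to little more than a key-to-value linear relation, which is the `Wk @ Wv` concern noted in the constructor comment above. A toy sketch of that store objective with the chunking, adaptive step sizes and momentum stripped away; the module shapes here are illustrative, not the package's defaults:

```python
# Toy illustration of the store objective only; the real implementation updates
# the memory weights per chunk with learned step sizes, momentum and decay.
import torch
from torch import nn

dim = 384

to_keys = nn.Sequential(nn.Linear(dim, dim, bias = False), nn.SiLU())
to_values = nn.Sequential(nn.Linear(dim, dim, bias = False), nn.SiLU())
memory_mlp = nn.Sequential(nn.Linear(dim, dim * 2), nn.GELU(), nn.Linear(dim * 2, dim))

key_seq = torch.randn(2, 64, dim)     # view used for keys
value_seq = torch.randn(2, 64, dim)   # a different view when qkv_receives_diff_views = True

keys, values = to_keys(key_seq), to_values(value_seq)

# the "surprise" that drives the memory weight update
loss = ((memory_mlp(keys) - values) ** 2).mean()
loss.backward()
```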
@@ -796,10 +820,23 @@ class NeuralMemory(Module):
         state: NeuralMemState | None = None,
         prev_weights = None
     ):
-
-
+        is_multi_input = self.qkv_receives_diff_views
+
+        # handle single token
+
+        if seq.ndim == 2 or (is_multi_input and seq.ndim == 3):
+            seq = rearrange(seq, '... b d -> ... b 1 d')
+
+        is_single_token = seq.shape[-2] == 1

-
+        # if different views for qkv, then
+
+        if is_multi_input:
+            retrieve_seq, seq = seq[0], seq[1:]
+        else:
+            retrieve_seq = seq
+
+        # handle previous state init

         if not exists(state):
             state = (0, None, None, None, None)

@@ -815,8 +852,6 @@ class NeuralMemory(Module):
         if exists(cache_store_seq):
             store_seq = safe_cat((cache_store_seq, store_seq))

-        # functions
-
         # compute split sizes of sequence
         # for now manually update weights to last update at the correct boundaries

@@ -916,7 +951,7 @@ class NeuralMemory(Module):
         updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')

         retrieved = self.retrieve_memories(
-
+            retrieve_seq,
             updates
         )

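Taken together, the forward-path changes mean that with `qkv_receives_diff_views = True` the module expects a leading views dimension, uses view 0 as the retrieval sequence, and hands the remaining views to the store path as key and value sources. A hedged usage sketch; the `dim` and `chunk_size` values and the `(retrieved, state)` return convention follow the project README rather than this diff:

```python
# Hedged usage sketch of the multi-view input path.
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    qkv_receives_diff_views = True   # new in this release range
)

# three views of the same underlying sequence:
# index 0 -> retrieval / queries, index 1 -> keys, index 2 -> values
views = torch.randn(3, 2, 1024, 384)

retrieved, state = mem(views)   # retrieved is expected to be (2, 1024, 384)
```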