titans-pytorch 0.3.23__tar.gz → 0.3.25__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.3.23 → titans_pytorch-0.3.25}/PKG-INFO +1 -1
- {titans_pytorch-0.3.23 → titans_pytorch-0.3.25}/pyproject.toml +1 -1
- {titans_pytorch-0.3.23 → titans_pytorch-0.3.25}/tests/test_titans.py +13 -2
- {titans_pytorch-0.3.23 → titans_pytorch-0.3.25}/titans_pytorch/neural_memory.py +27 -11
- {titans_pytorch-0.3.23 → titans_pytorch-0.3.25}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.3.23 → titans_pytorch-0.3.25}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.3.23 → titans_pytorch-0.3.25}/.gitignore +0 -0
- {titans_pytorch-0.3.23 → titans_pytorch-0.3.25}/LICENSE +0 -0
- {titans_pytorch-0.3.23 → titans_pytorch-0.3.25}/README.md +0 -0
- {titans_pytorch-0.3.23 → titans_pytorch-0.3.25}/data/README.md +0 -0
- {titans_pytorch-0.3.23 → titans_pytorch-0.3.25}/data/enwik8.gz +0 -0
- {titans_pytorch-0.3.23 → titans_pytorch-0.3.25}/fig1.png +0 -0
- {titans_pytorch-0.3.23 → titans_pytorch-0.3.25}/fig2.png +0 -0
- {titans_pytorch-0.3.23 → titans_pytorch-0.3.25}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.3.23 → titans_pytorch-0.3.25}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.3.23 → titans_pytorch-0.3.25}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.3.23 → titans_pytorch-0.3.25}/titans_pytorch/memory_models.py +0 -0
- {titans_pytorch-0.3.23 → titans_pytorch-0.3.25}/train_mac.py +0 -0
tests/test_titans.py

@@ -29,10 +29,12 @@ def torch_default_dtype(dtype):
 @pytest.mark.parametrize('chunk_size, attn_pool_chunks', ((64, True), (64, False), (1, False)))
 @pytest.mark.parametrize('momentum', (False, True))
 @pytest.mark.parametrize('qk_rmsnorm', (False, True))
+@pytest.mark.parametrize('heads', (1, 4))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 @pytest.mark.parametrize('num_kv_per_token', (1, 2))
 @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
 @pytest.mark.parametrize('per_head_learned_parameters', (False, True))
+@pytest.mark.parametrize('test_store_mask', (False, True))
 def test_titans(
     seq_len,
     silu,
@@ -40,10 +42,12 @@ def test_titans(
     chunk_size,
     momentum,
     qk_rmsnorm,
+    heads,
     max_grad_norm,
     num_kv_per_token,
     per_parameter_lr_modulation,
-    per_head_learned_parameters
+    per_head_learned_parameters,
+    test_store_mask
 ):
     mem = NeuralMemory(
         dim = 16,
@@ -54,12 +58,19 @@
         num_kv_per_token = num_kv_per_token,
         momentum = momentum,
         qk_rmsnorm = qk_rmsnorm,
+        heads = heads,
         per_parameter_lr_modulation = per_parameter_lr_modulation,
         per_head_learned_parameters = per_head_learned_parameters
     )
 
     seq = torch.randn(2, seq_len, 16)
-    retrieved, _ = mem(seq)
+
+    store_mask = None
+
+    if test_store_mask:
+        store_mask = torch.randint(0, 2, (2, seq_len)).bool()
+
+    retrieved, _ = mem(seq, store_mask = store_mask)
 
     assert seq.shape == retrieved.shape
 
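For readers of the diff, here is a minimal sketch of the call pattern the updated test exercises: a boolean `store_mask` is passed alongside the sequence, and positions marked `False` are not stored into the neural memory. The `dim`, `chunk_size`, and sequence length below are arbitrary illustration values, not taken from the diff.

```python
import torch
from titans_pytorch import NeuralMemory

# mirror the updated test: one boolean flag per token controls whether
# that position is written into the neural memory
mem = NeuralMemory(dim = 16, chunk_size = 64)

seq = torch.randn(2, 1024, 16)

store_mask = torch.randint(0, 2, (2, 1024)).bool()

retrieved, state = mem(seq, store_mask = store_mask)
assert retrieved.shape == seq.shape
```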
titans_pytorch/neural_memory.py

@@ -289,6 +289,8 @@ class NeuralMemory(Module):
         self.heads = heads
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+        self.split_kv_heads = Rearrange('b n (h u d) -> b h (n u) d', h = heads, u = num_kv_per_token)
+
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
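The new `split_kv_heads` rearrangement differs from `split_heads` in that it also factors `num_kv_per_token` out of the channel dimension and unrolls those extra key/values along the sequence axis. A small standalone shape check, with illustrative sizes (not values from the diff):

```python
import torch
from einops.layers.torch import Rearrange

h, u, d = 4, 2, 8  # heads, num_kv_per_token, head dimension (illustrative)

split_kv_heads = Rearrange('b n (h u d) -> b h (n u) d', h = h, u = u)

kv = torch.randn(2, 16, h * u * d)     # (batch, seq, heads * kv_per_token * dim_head)
out = split_kv_heads(kv)

assert out.shape == (2, h, 16 * u, d)  # each token now contributes u entries per head
```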
@@ -522,7 +524,8 @@ class NeuralMemory(Module):
         weights: dict[str, Tensor] | None = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
         seq_index = 0,
-        prev_weights = None
+        prev_weights = None,
+        mask: Tensor | None = None,
     ):
         if self.qkv_receives_diff_views:
             _, batch, seq_len = seq.shape[:3]
@@ -596,22 +599,28 @@
 
         # maybe multi head
 
-        keys, values = map(self.split_heads, (keys, values))
-
-        batch = keys.shape[0]
+        keys, values = map(self.split_kv_heads, (keys, values))
 
-        #
+        # maybe keys rmsnorm
 
-        keys
+        keys = self.k_norm(keys)
 
-        #
+        # take care of chunking
 
-        keys =
+        keys, values = tuple(rearrange(t, 'b h (n c u) d -> (b h n) (c u) d', c = chunk_size, u = num_updates) for t in (keys, values))
 
         # adaptive lr
 
         adaptive_lr = rearrange(adaptive_lr, 'b (n c u) -> (b n) (c u)', c = chunk_size, u = num_updates)
 
+        # optionally a storing memories mask can be passed in. if False, will set the learning rate to 0. for those positions
+
+        if exists(mask):
+            mask = mask[..., :round_down_seq_len]
+            mask = repeat(mask, 'b (n c) -> (b h n) (c u)', h = heads, u = num_updates, c = chunk_size)
+
+            adaptive_lr = torch.where(mask, adaptive_lr, 0.)
+
         # maybe add previous layer weight
 
         assert xnor(exists(self.to_learned_weight_residual_mix), exists(prev_weights))
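The storing mask added above boils down to two operations: an einops `repeat` that broadcasts the per-token mask across heads, chunks, and the `num_updates` factor so it lines up with the flattened adaptive learning rates, and a `torch.where` that zeroes the learning rate wherever storage is masked out. A self-contained sketch of just that step, with made-up shapes:

```python
import torch
from einops import repeat

b, h, n, chunk_size, num_updates = 2, 4, 3, 8, 2  # illustrative sizes
seq_len = n * chunk_size

# stand-in for the flattened per-position learning rates
adaptive_lr = torch.rand(b * h * n, chunk_size * num_updates)

# boolean storing mask over the sequence
mask = torch.randint(0, 2, (b, seq_len)).bool()

# broadcast the mask to match adaptive_lr, then zero masked-out positions
mask = repeat(mask, 'b (n c) -> (b h n) (c u)', h = h, u = num_updates, c = chunk_size)
adaptive_lr = torch.where(mask, adaptive_lr, 0.)
```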
@@ -833,7 +842,8 @@
         seq,
         store_seq = None,
         state: NeuralMemState | None = None,
-        prev_weights = None
+        prev_weights = None,
+        store_mask: Tensor | None = None
     ):
         is_multi_input = self.qkv_receives_diff_views
 
@@ -910,6 +920,11 @@
 
         store_seqs = store_seq.split(split_sizes, dim = -2)
 
+        if exists(store_mask):
+            store_masks = store_mask.split(split_sizes, dim = -1)
+        else:
+            store_masks = (None,) * len(split_sizes)
+
         # whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch
 
         gate = None
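In `forward`, the optional `store_mask` is split with the same `split_sizes` used for the stored sequence, so each chunk in the loop shown in the next hunk receives its own mask slice (or `None` when no mask was passed). A small sketch of that bookkeeping, with illustrative sizes:

```python
import torch

split_sizes = (64, 64, 32)  # illustrative chunk sizes
total = sum(split_sizes)

store_seq = torch.randn(2, total, 16)
store_mask = torch.randint(0, 2, (2, total)).bool()  # or None

store_seqs = store_seq.split(split_sizes, dim = -2)

if store_mask is not None:
    store_masks = store_mask.split(split_sizes, dim = -1)
else:
    store_masks = (None,) * len(split_sizes)

# each sequence chunk is paired with the matching slice of the mask
for store_seq_chunk, maybe_store_mask in zip(store_seqs, store_masks):
    assert store_seq_chunk.shape[-2] == maybe_store_mask.shape[-1]
```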
@@ -917,7 +932,7 @@
         if exists(self.transition_gate):
             gate = self.transition_gate.sigmoid()
 
-        for ind, store_seq_chunk in enumerate(store_seqs):
+        for ind, (store_seq_chunk, maybe_store_mask) in enumerate(zip(store_seqs, store_masks)):
             is_last = ind == (len(store_seqs) - 1)
 
             # store
@@ -927,7 +942,8 @@
                 weights,
                 seq_index = seq_index,
                 past_state = past_state,
-                prev_weights = prev_weights
+                prev_weights = prev_weights,
+                mask = maybe_store_mask
             )
 
             weights = next_neural_mem_state.weights