titans-pytorch 0.2.15__tar.gz → 0.2.17__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/PKG-INFO +1 -1
- {titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/pyproject.toml +1 -1
- {titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/tests/test_titans.py +28 -0
- {titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/titans_pytorch/mac_transformer.py +11 -0
- {titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/titans_pytorch/neural_memory.py +15 -1
- {titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/train_mac.py +3 -1
- {titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/.gitignore +0 -0
- {titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/LICENSE +0 -0
- {titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/README.md +0 -0
- {titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/data/README.md +0 -0
- {titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/data/enwik8.gz +0 -0
- {titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/fig1.png +0 -0
- {titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/fig2.png +0 -0
- {titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/titans_pytorch/memory_models.py +0 -0
{titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/tests/test_titans.py

@@ -93,6 +93,34 @@ def test_neural_mem_chaining_chunks():
 
     assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved, third_retrieved), dim = 1), atol = 1e-5)
 
+def test_neural_mem_chaining_with_weight_residual():
+    mem = NeuralMemory(
+        dim = 384,
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 64
+    )
+
+    mem2 = NeuralMemory(
+        dim = 384,
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 64
+    )
+
+    seq = torch.randn(2, 256, 384)
+
+    seq, state = mem(seq)
+
+    parallel_retrieved, _ = mem2(seq, prev_weights = state.updates)
+
+    seq_first, seq_second = seq[:, :128], seq[:, 128:]
+
+    first_retrieved, state1 = mem2(seq_first, prev_weights = state.updates)
+    second_retrieved, state2 = mem2(seq_second, state = state1, prev_weights = state.updates)
+
+    assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved), dim = 1), atol = 1e-6)
+
 def test_neural_mem_chaining_with_batch_size():
     mem = NeuralMemory(
         dim = 384,
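The new test checks that passing a previous layer's weight updates (`state.updates`) as `prev_weights` gives the same retrieved output whether the sequence is processed in one parallel pass or chained across two chunked calls. A minimal sketch of the same chaining pattern across a small stack of memories (layer count and dimensions here are illustrative; `NeuralMemory` is assumed importable from `titans_pytorch` as in the tests):

```python
import torch
from titans_pytorch import NeuralMemory

# a small stack of neural memory layers (sizes are illustrative)
mems = [NeuralMemory(dim = 384, dim_head = 64, heads = 2, chunk_size = 64) for _ in range(3)]

seq = torch.randn(1, 256, 384)
prev_updates = None

for mem in mems:
    # each layer receives the previous layer's weight updates as a residual
    seq, state = mem(seq, prev_weights = prev_updates)
    prev_updates = state.updates
```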
{titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/titans_pytorch/mac_transformer.py

@@ -491,6 +491,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_layers: tuple[int, ...] | None = None,
         use_flex_attn = False,
         sliding_window_attn = False,
+        neural_mem_weight_residual = False
     ):
         super().__init__()
 
@@ -524,6 +525,8 @@ class MemoryAsContextTransformer(Module):
 
         neural_memory_layers = default(neural_memory_layers, layers)
 
+        self.neural_mem_weight_residual = neural_mem_weight_residual
+
         # mem, attn, and feedforward layers
 
         for layer in layers:
@@ -739,6 +742,10 @@ class MemoryAsContextTransformer(Module):
 
         value_residual = None
 
+        # neural mem weight residual
+
+        mem_weight_residual = None
+
         # when inferencing, only do one token at a time
 
         if is_inferencing:
@@ -764,8 +771,12 @@ class MemoryAsContextTransformer(Module):
                 retrieved, next_neural_mem_cache = mem.forward(
                     mem_input,
                     state = next(neural_mem_caches, None),
+                    prev_weights = mem_weight_residual
                 )
 
+                if self.neural_mem_weight_residual:
+                    mem_weight_residual = next_neural_mem_cache.updates
+
                 if self.gate_attn_output:
                     attn_out_gates = retrieved.sigmoid()
                 else:
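With `neural_mem_weight_residual = True`, the weight updates produced by the first neural memory layer are carried along the layer loop and added as a residual to each subsequent memory layer's update. A hedged usage sketch: the constructor arguments other than `neural_mem_weight_residual` mirror the train_mac.py call shown further down, the `num_tokens` / `dim` / `depth` / `segment_len` values are illustrative, and the exact signature may differ in your version of the package.

```python
import torch
from titans_pytorch import MemoryAsContextTransformer

model = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 384,
    depth = 8,
    segment_len = 32,                    # local attention window size
    neural_memory_segment_len = 4,
    neural_mem_weight_residual = True,   # new flag: chain weight updates across memory layers
    sliding_window_attn = True
)

ids = torch.randint(0, 256, (1, 1024))

loss = model(ids, return_loss = True)
loss.backward()
```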
{titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/titans_pytorch/neural_memory.py

@@ -494,7 +494,8 @@ class NeuralMemory(Module):
         seq,
         weights: dict[str, Tensor] | None = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-        seq_index = 0
+        seq_index = 0,
+        prev_weights = None
     ):
         batch, seq_len, heads, chunk_size = *seq.shape[:2], self.heads, self.store_chunk_size
 
@@ -560,6 +561,17 @@ class NeuralMemory(Module):
 
         adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
 
+        # maybe add previous layer weight
+
+        if exists(prev_weights):
+
+            start_index = math.ceil(seq_index / chunk_size)
+            end_index = start_index + num_chunks
+
+            prev_weights = prev_weights.apply(lambda t: t[:, start_index:end_index])
+
+            weights_for_surprise = weights_for_surprise + prev_weights
+
         # flatten batch and time if surprise depends on previous layer memory model
 
         weights_for_surprise = rearrange_dict_values(weights_for_surprise, 'b n ... -> (b n) ...')
@@ -732,6 +744,7 @@ class NeuralMemory(Module):
         seq,
         store_seq = None,
         state: NeuralMemCache | None = None,
+        prev_weights = None
     ):
         if seq.ndim == 2:
             seq = rearrange(seq, 'b d -> b 1 d')
@@ -807,6 +820,7 @@ class NeuralMemory(Module):
             weights,
             seq_index = seq_index,
             past_state = past_state,
+            prev_weights = prev_weights
         )
 
         seq_index = next_neural_mem_state.seq_index
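The slicing added to store_memories keeps the residual aligned with the chunks currently being stored: `prev_weights` holds one set of weight updates per chunk, so a call starting at `seq_index` selects only the matching window of chunks before adding them to `weights_for_surprise`. A small worked example of that index arithmetic, using the shapes from the new test (chunk_size = 64, two back-to-back 128-token calls); `num_chunks` here is simply the chunk count of the current call, which the real code derives from the stored sequence:

```python
import math

chunk_size = 64

# (seq_index, num_chunks) for the two chained 128-token calls in the test
for seq_index, num_chunks in ((0, 2), (128, 2)):
    start_index = math.ceil(seq_index / chunk_size)
    end_index = start_index + num_chunks
    print(f'seq_index {seq_index}: add prev_weights chunks [{start_index}:{end_index}]')

# seq_index 0: add prev_weights chunks [0:2]
# seq_index 128: add prev_weights chunks [2:4]
```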
{titans_pytorch-0.2.15 → titans_pytorch-0.2.17}/train_mac.py

@@ -36,8 +36,9 @@ NEURAL_MEM_MOMENTUM = True
 NEURAL_MEM_QK_NORM = True
 NEURAL_MEM_MAX_LR = 1e-1
 WINDOW_SIZE = 32
-NEURAL_MEM_SEGMENT_LEN =
+NEURAL_MEM_SEGMENT_LEN = 4 # set smaller for more granularity for learning rate / momentum etc
 NEURAL_MEM_BATCH_SIZE = 128 # set smaller to update the neural memory weights more often as it traverses the sequence
+NEURAL_MEM_WEIGHT_RESIDUAL = False
 SLIDING_WINDOWS = True
 STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
 MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
@@ -87,6 +88,7 @@ model = MemoryAsContextTransformer(
     neural_memory_segment_len = NEURAL_MEM_SEGMENT_LEN,
     neural_memory_batch_size = NEURAL_MEM_BATCH_SIZE,
     neural_mem_gate_attn_output = NEURAL_MEM_GATE_ATTN_OUTPUT,
+    neural_mem_weight_residual = NEURAL_MEM_WEIGHT_RESIDUAL,
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
     neural_memory_model = MemoryMLP(