titans-pytorch 0.2.16__tar.gz → 0.2.18__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/PKG-INFO +1 -1
- {titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/pyproject.toml +1 -1
- {titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/tests/test_titans.py +32 -1
- {titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/titans_pytorch/mac_transformer.py +7 -1
- {titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/titans_pytorch/neural_memory.py +23 -1
- {titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/train_mac.py +1 -1
- {titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/.gitignore +0 -0
- {titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/LICENSE +0 -0
- {titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/README.md +0 -0
- {titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/data/README.md +0 -0
- {titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/data/enwik8.gz +0 -0
- {titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/fig1.png +0 -0
- {titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/fig2.png +0 -0
- {titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/titans_pytorch/memory_models.py +0 -0

{titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/tests/test_titans.py

@@ -93,6 +93,34 @@ def test_neural_mem_chaining_chunks():
 
     assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved, third_retrieved), dim = 1), atol = 1e-5)
 
+def test_neural_mem_chaining_with_weight_residual():
+    mem = NeuralMemory(
+        dim = 384,
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 64
+    )
+
+    mem2 = NeuralMemory(
+        dim = 384,
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 64
+    )
+
+    seq = torch.randn(2, 256, 384)
+
+    seq, state = mem(seq)
+
+    parallel_retrieved, _ = mem2(seq, prev_weights = state.updates)
+
+    seq_first, seq_second = seq[:, :128], seq[:, 128:]
+
+    first_retrieved, state1 = mem2(seq_first, prev_weights = state.updates)
+    second_retrieved, state2 = mem2(seq_second, state = state1, prev_weights = state.updates)
+
+    assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved), dim = 1), atol = 1e-6)
+
 def test_neural_mem_chaining_with_batch_size():
     mem = NeuralMemory(
         dim = 384,
@@ -121,6 +149,7 @@ def test_neural_mem_chaining_with_batch_size():
 @pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
 @pytest.mark.parametrize('neural_mem_gate_attn_output', (False, True))
 @pytest.mark.parametrize('neural_mem_segment_len', (8, 16))
+@pytest.mark.parametrize('neural_mem_weight_residual', (False, True))
 @pytest.mark.parametrize('neural_mem_batch_size', (None, 64))
 def test_mac(
     seq_len,
@@ -128,7 +157,8 @@ def test_mac(
     num_longterm_mem_tokens,
     neural_mem_gate_attn_output,
     neural_mem_segment_len,
-    neural_mem_batch_size
+    neural_mem_weight_residual,
+    neural_mem_batch_size,
 ):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
@@ -140,6 +170,7 @@ def test_mac(
         neural_mem_gate_attn_output = neural_mem_gate_attn_output,
         neural_memory_segment_len = neural_mem_segment_len,
         neural_memory_batch_size = neural_mem_batch_size,
+        neural_mem_weight_residual = neural_mem_weight_residual
     )
 
     x = torch.randint(0, 256, (1, seq_len))

{titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/titans_pytorch/mac_transformer.py

@@ -491,7 +491,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_layers: tuple[int, ...] | None = None,
         use_flex_attn = False,
         sliding_window_attn = False,
-        neural_mem_weight_residual = False
+        neural_mem_weight_residual = False,
     ):
         super().__init__()
 
@@ -525,7 +525,10 @@ class MemoryAsContextTransformer(Module):
 
         neural_memory_layers = default(neural_memory_layers, layers)
 
+        # weight residual related
+
         self.neural_mem_weight_residual = neural_mem_weight_residual
+        is_first_neural_mem = True
 
         # mem, attn, and feedforward layers
 
@@ -557,9 +560,12 @@ class MemoryAsContextTransformer(Module):
                     chunk_size = self.neural_memory_segment_len,
                     batch_size = neural_memory_batch_size,
                     model = deepcopy(neural_memory_model),
+                    learned_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
                     **neural_memory_kwargs
                 )
 
+                is_first_neural_mem = False
+
             ff = FeedForward(dim = dim, mult = ff_mult)
 
             self.layers.append(ModuleList([

{titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/titans_pytorch/neural_memory.py

@@ -296,6 +296,7 @@ class NeuralMemory(Module):
         init_adaptive_step_bias = None,
         init_momentum_bias = None,
         init_decay_bias = None,
+        learned_weight_residual = False,
         default_model_kwargs: dict = dict(
             depth = 2
         )
@@ -438,6 +439,14 @@ class NeuralMemory(Module):
 
         self.max_mem_layer_modulation = max_mem_layer_modulation
 
+        # learned weight residual
+
+        self.to_learned_weight_residual_mix = Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n'),
+            nn.Sigmoid()
+        ) if learned_weight_residual else None
+
         # allow for softclamp the gradient norms for storing memories
 
         self.max_grad_norm = max_grad_norm
@@ -563,8 +572,21 @@ class NeuralMemory(Module):
 
         # maybe add previous layer weight
 
+        assert xnor(exists(self.to_learned_weight_residual_mix), exists(prev_weights))
+
         if exists(prev_weights):
-
+
+            start_index = math.ceil(seq_index / chunk_size)
+            end_index = start_index + num_chunks
+
+            prev_weights = prev_weights.apply(lambda t: t[:, start_index:end_index])
+
+            if exists(self.to_learned_weight_residual_mix):
+                mix = self.to_learned_weight_residual_mix(chunked_seq)
+                mix = rearrange(mix, 'b h n -> (b h) n')
+
+                prev_weights = prev_weights.apply(lambda t: einx.multiply('bh n, bh n ... -> bh n ...', mix, t))
+
             weights_for_surprise = weights_for_surprise + prev_weights
 
         # flatten batch and time if surprise depends on previous layer memory model
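
The new to_learned_weight_residual_mix head is simply a per-head gate: a linear projection from the model dimension to the number of heads, transposed and squashed through a sigmoid, so each head gets a factor in (0, 1) per chunk that scales the previous layer's weight updates before they are added as a residual (the einx.multiply line above). A standalone shape sketch follows, using illustrative sizes and plain nn.Sequential in place of the repository's Sequential helper:

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, heads = 384, 2

to_mix = nn.Sequential(
    nn.Linear(dim, heads),         # one logit per head for every chunked position
    Rearrange('b n h -> b h n'),   # move heads next to batch, matching the diff
    nn.Sigmoid()                   # gate values in (0, 1)
)

chunked_seq = torch.randn(2, 4, dim)   # e.g. batch of 2, 4 chunk representations
mix = to_mix(chunked_seq)
print(mix.shape)                       # torch.Size([2, 2, 4]) -> (batch, heads, chunks)
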

{titans_pytorch-0.2.16 → titans_pytorch-0.2.18}/train_mac.py

@@ -38,7 +38,7 @@ NEURAL_MEM_MAX_LR = 1e-1
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = 4               # set smaller for more granularity for learning rate / momentum etc
 NEURAL_MEM_BATCH_SIZE = 128              # set smaller to update the neural memory weights more often as it traverses the sequence
-NEURAL_MEM_WEIGHT_RESIDUAL =
+NEURAL_MEM_WEIGHT_RESIDUAL = False
 SLIDING_WINDOWS = True
 STORE_ATTN_POOL_CHUNKS = True            # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
 MEMORY_MODEL_PER_LAYER_LEARNED_LR = True