titans-pytorch 0.2.17__tar.gz → 0.2.18__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.18}/PKG-INFO +1 -1
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.18}/pyproject.toml +1 -1
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.18}/tests/test_titans.py +4 -1
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.18}/titans_pytorch/mac_transformer.py +7 -1
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.18}/titans_pytorch/neural_memory.py +17 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.18}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.18}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.18}/.gitignore +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.18}/LICENSE +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.18}/README.md +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.18}/data/README.md +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.18}/data/enwik8.gz +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.18}/fig1.png +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.18}/fig2.png +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.18}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.18}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.18}/titans_pytorch/memory_models.py +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.18}/train_mac.py +0 -0
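The only functional change in this release is a new `neural_mem_weight_residual` option: when enabled, each neural memory layer after the first learns a per-head mix that gates a residual of the previous memory layer's weights. The hunks below thread the flag from `MemoryAsContextTransformer` down to `NeuralMemory`'s new `learned_weight_residual` argument; the remaining files are unchanged.

A minimal sketch of how the new flag is passed, pieced together from the test changes below. The `dim`, `depth`, `segment_len`, and `num_persist_mem_tokens` values are illustrative assumptions, not taken from this diff:

```python
import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,                           # assumed
    depth = 2,                           # assumed
    segment_len = 128,                   # assumed
    num_persist_mem_tokens = 4,          # assumed
    num_longterm_mem_tokens = 16,        # value from the test matrix below
    neural_memory_segment_len = 16,      # value from the test matrix below
    neural_memory_batch_size = 64,       # value from the test matrix below
    neural_mem_weight_residual = True    # new in 0.2.18
)

x = torch.randint(0, 256, (1, 1024))
logits = transformer(x)                  # assumed forward signature
```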
tests/test_titans.py

@@ -149,6 +149,7 @@ def test_neural_mem_chaining_with_batch_size():
 @pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
 @pytest.mark.parametrize('neural_mem_gate_attn_output', (False, True))
 @pytest.mark.parametrize('neural_mem_segment_len', (8, 16))
+@pytest.mark.parametrize('neural_mem_weight_residual', (False, True))
 @pytest.mark.parametrize('neural_mem_batch_size', (None, 64))
 def test_mac(
     seq_len,
@@ -156,7 +157,8 @@ def test_mac(
     num_longterm_mem_tokens,
     neural_mem_gate_attn_output,
     neural_mem_segment_len,
-    neural_mem_batch_size,
+    neural_mem_weight_residual,
+    neural_mem_batch_size,
 ):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
@@ -168,6 +170,7 @@ def test_mac(
         neural_mem_gate_attn_output = neural_mem_gate_attn_output,
         neural_memory_segment_len = neural_mem_segment_len,
         neural_memory_batch_size = neural_mem_batch_size,
+        neural_mem_weight_residual = neural_mem_weight_residual
     )

     x = torch.randint(0, 256, (1, seq_len))
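Since `neural_mem_weight_residual` is parametrized over both `False` and `True`, the existing `test_mac` matrix now exercises the weight residual path end to end, e.g. via `pytest tests/test_titans.py -k test_mac` (standard pytest selection, not part of this diff).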
titans_pytorch/mac_transformer.py

@@ -491,7 +491,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_layers: tuple[int, ...] | None = None,
         use_flex_attn = False,
         sliding_window_attn = False,
-        neural_mem_weight_residual = False
+        neural_mem_weight_residual = False,
     ):
         super().__init__()

@@ -525,7 +525,10 @@ class MemoryAsContextTransformer(Module):

         neural_memory_layers = default(neural_memory_layers, layers)

+        # weight residual related
+
         self.neural_mem_weight_residual = neural_mem_weight_residual
+        is_first_neural_mem = True

         # mem, attn, and feedforward layers

@@ -557,9 +560,12 @@ class MemoryAsContextTransformer(Module):
                     chunk_size = self.neural_memory_segment_len,
                     batch_size = neural_memory_batch_size,
                     model = deepcopy(neural_memory_model),
+                    learned_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
                     **neural_memory_kwargs
                 )

+                is_first_neural_mem = False
+
             ff = FeedForward(dim = dim, mult = ff_mult)

             self.layers.append(ModuleList([
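The `is_first_neural_mem` bookkeeping above means only memory layers after the first receive a learned residual mix, since the first layer has no previous weights to draw on. An illustrative trace of how the flag resolves, assuming three neural-memory layers with `neural_mem_weight_residual = True`:

```python
# hypothetical values; the real constructor iterates over its configured layers
neural_mem_weight_residual = True
is_first_neural_mem = True

for layer_index in range(3):
    learned_weight_residual = neural_mem_weight_residual and not is_first_neural_mem
    print(layer_index, learned_weight_residual)   # 0 False, 1 True, 2 True
    is_first_neural_mem = False
```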
titans_pytorch/neural_memory.py

@@ -296,6 +296,7 @@ class NeuralMemory(Module):
         init_adaptive_step_bias = None,
         init_momentum_bias = None,
         init_decay_bias = None,
+        learned_weight_residual = False,
         default_model_kwargs: dict = dict(
             depth = 2
         )
@@ -438,6 +439,14 @@ class NeuralMemory(Module):

         self.max_mem_layer_modulation = max_mem_layer_modulation

+        # learned weight residual
+
+        self.to_learned_weight_residual_mix = Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n'),
+            nn.Sigmoid()
+        ) if learned_weight_residual else None
+
         # allow for softclamp the gradient norms for storing memories

         self.max_grad_norm = max_grad_norm
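The new gating head maps tokens of shape `(batch, seq, dim)` to per-head mixing weights of shape `(batch, heads, seq)`, squashed into (0, 1) by the sigmoid. A standalone sketch (`Sequential` in the hunk above is presumably the repo's own wrapper; `nn.Sequential` behaves the same here, and the `dim` and `heads` values are assumptions):

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, heads = 384, 4   # assumed values, not taken from this diff

to_learned_weight_residual_mix = nn.Sequential(
    nn.Linear(dim, heads),        # one gate per head
    Rearrange('b n h -> b h n'),  # move heads before the sequence axis
    nn.Sigmoid()                  # mix weights in (0, 1)
)

tokens = torch.randn(2, 64, dim)
mix = to_learned_weight_residual_mix(tokens)
assert mix.shape == (2, heads, 64)
```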
@@ -563,6 +572,8 @@ class NeuralMemory(Module):

         # maybe add previous layer weight

+        assert xnor(exists(self.to_learned_weight_residual_mix), exists(prev_weights))
+
         if exists(prev_weights):

             start_index = math.ceil(seq_index / chunk_size)
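The assert ties the two halves of the feature together: the mix module and the incoming `prev_weights` must either both be present or both be absent. `exists` and `xnor` are the repo's helpers; their assumed semantics:

```python
# assumed semantics of the helpers used in the assert above
def exists(v):
    return v is not None

def xnor(x, y):
    return not (x ^ y)   # True when both or neither hold
```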
@@ -570,6 +581,12 @@ class NeuralMemory(Module):

             prev_weights = prev_weights.apply(lambda t: t[:, start_index:end_index])

+            if exists(self.to_learned_weight_residual_mix):
+                mix = self.to_learned_weight_residual_mix(chunked_seq)
+                mix = rearrange(mix, 'b h n -> (b h) n')
+
+                prev_weights = prev_weights.apply(lambda t: einx.multiply('bh n, bh n ... -> bh n ...', mix, t))
+
             weights_for_surprise = weights_for_surprise + prev_weights

         # flatten batch and time if surprise depends on previous layer memory model
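Here the learned mix, flattened from `(batch, heads, chunks)` to `(batch * heads, chunks)`, scales every previous-layer weight tensor before it is added as a residual. A standalone sketch of that `einx.multiply` broadcast, with assumed shapes:

```python
import torch
import einx
from einops import rearrange

b, h, n = 2, 4, 8                  # assumed batch, heads, chunk count
mix = torch.rand(b, h, n)
mix = rearrange(mix, 'b h n -> (b h) n')

# previous-layer weights laid out as (batch * heads, chunks, *weight_shape)
prev_weight = torch.randn(b * h, n, 16, 16)   # assumed weight shape

gated = einx.multiply('bh n, bh n ... -> bh n ...', mix, prev_weight)
assert gated.shape == prev_weight.shape
```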