titans-pytorch 0.2.17__tar.gz → 0.2.19__tar.gz
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/PKG-INFO +1 -1
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/pyproject.toml +1 -1
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/tests/test_titans.py +6 -2
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/titans_pytorch/mac_transformer.py +7 -1
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/titans_pytorch/neural_memory.py +25 -3
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/.gitignore +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/LICENSE +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/README.md +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/data/README.md +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/data/enwik8.gz +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/fig1.png +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/fig2.png +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/titans_pytorch/memory_models.py +0 -0
- {titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/train_mac.py +0 -0
{titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/tests/test_titans.py

@@ -105,7 +105,8 @@ def test_neural_mem_chaining_with_weight_residual():
         dim = 384,
         dim_head = 64,
         heads = 2,
-        chunk_size = 64
+        chunk_size = 64,
+        accept_weight_residual = True
     )
 
     seq = torch.randn(2, 256, 384)

@@ -149,6 +150,7 @@ def test_neural_mem_chaining_with_batch_size():
 @pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
 @pytest.mark.parametrize('neural_mem_gate_attn_output', (False, True))
 @pytest.mark.parametrize('neural_mem_segment_len', (8, 16))
+@pytest.mark.parametrize('neural_mem_weight_residual', (False, True))
 @pytest.mark.parametrize('neural_mem_batch_size', (None, 64))
 def test_mac(
     seq_len,

@@ -156,7 +158,8 @@ def test_mac(
     num_longterm_mem_tokens,
     neural_mem_gate_attn_output,
     neural_mem_segment_len,
-    neural_mem_batch_size
+    neural_mem_weight_residual,
+    neural_mem_batch_size,
 ):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,

@@ -168,6 +171,7 @@ def test_mac(
         neural_mem_gate_attn_output = neural_mem_gate_attn_output,
         neural_memory_segment_len = neural_mem_segment_len,
         neural_memory_batch_size = neural_mem_batch_size,
+        neural_mem_weight_residual = neural_mem_weight_residual
     )
 
     x = torch.randint(0, 256, (1, seq_len))
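The first hunk belongs to a chaining test: a second NeuralMemory built with accept_weight_residual = True consumes the weight updates of a first one. A minimal sketch of that usage outside pytest; the titans_pytorch import path, the (retrieved, state) return convention, and the prev_weights keyword and updates state field are assumptions inferred from the hunks, not shown verbatim in this diff:

import torch
from titans_pytorch import NeuralMemory  # import path assumed from the package layout

mem1 = NeuralMemory(dim = 384, dim_head = 64, heads = 2, chunk_size = 64)

mem2 = NeuralMemory(
    dim = 384,
    dim_head = 64,
    heads = 2,
    chunk_size = 64,
    accept_weight_residual = True  # new keyword in this release
)

seq = torch.randn(2, 256, 384)  # (batch, seq_len, dim), as in the test

# chain the two memories: the second mixes in the first's weight updates
retrieved, state = mem1(seq)                            # return convention assumed
out, _ = mem2(retrieved, prev_weights = state.updates)  # keyword and field assumed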
{titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/titans_pytorch/mac_transformer.py

@@ -491,7 +491,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_layers: tuple[int, ...] | None = None,
         use_flex_attn = False,
         sliding_window_attn = False,
-        neural_mem_weight_residual = False
+        neural_mem_weight_residual = False,
     ):
         super().__init__()
 

@@ -525,7 +525,10 @@ class MemoryAsContextTransformer(Module):
 
         neural_memory_layers = default(neural_memory_layers, layers)
 
+        # weight residual related
+
         self.neural_mem_weight_residual = neural_mem_weight_residual
+        is_first_neural_mem = True
 
         # mem, attn, and feedforward layers
 

@@ -557,9 +560,12 @@ class MemoryAsContextTransformer(Module):
                     chunk_size = self.neural_memory_segment_len,
                     batch_size = neural_memory_batch_size,
                     model = deepcopy(neural_memory_model),
+                    learned_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
                     **neural_memory_kwargs
                 )
 
+                is_first_neural_mem = False
+
             ff = FeedForward(dim = dim, mult = ff_mult)
 
             self.layers.append(ModuleList([
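Note the is_first_neural_mem bookkeeping: only memory layers after the first request a learned weight residual, since the first layer has no previous weights to mix in. A self-contained illustration of the flag pattern (hypothetical helper, not the package's code):

def weight_residual_flags(depth: int, enabled: bool) -> list[bool]:
    # one flag per memory layer: the first never receives a residual,
    # later layers do whenever the feature is enabled
    flags, is_first = [], True
    for _ in range(depth):
        flags.append(enabled and not is_first)
        is_first = False
    return flags

assert weight_residual_flags(3, True)  == [False, True, True]
assert weight_residual_flags(3, False) == [False, False, False]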
{titans_pytorch-0.2.17 → titans_pytorch-0.2.19}/titans_pytorch/neural_memory.py

@@ -66,6 +66,12 @@ def xnor(x, y):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def tuple_index_set(t: tuple, index, value):
+    klass = type(t)
+    t = list(t)
+    t[index] = value
+    return klass(*t)
+
 def safe_cat(inputs, dim = -2):
     inputs = tuple(filter(exists, inputs))
 
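tuple_index_set rebuilds the container with klass(*t), which works for namedtuples (their constructors take the fields positionally) but would raise TypeError for a plain tuple. A quick behavioral check (illustrative field names, not from the package's tests):

from collections import namedtuple

def tuple_index_set(t: tuple, index, value):
    klass = type(t)
    t = list(t)
    t[index] = value
    return klass(*t)

State = namedtuple('State', ['seq_index', 'weights', 'updates'])

s = tuple_index_set(State(0, 'w0', 'u0'), -1, 'u1')
assert s == State(0, 'w0', 'u1') and isinstance(s, State)

# tuple_index_set((1, 2, 3), 0, 9)  # would raise: tuple(*t) is not a valid call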
@@ -296,6 +302,7 @@ class NeuralMemory(Module):
         init_adaptive_step_bias = None,
         init_momentum_bias = None,
         init_decay_bias = None,
+        accept_weight_residual = False,
         default_model_kwargs: dict = dict(
             depth = 2
         )
@@ -438,6 +445,14 @@ class NeuralMemory(Module):
 
         self.max_mem_layer_modulation = max_mem_layer_modulation
 
+        # learned weight residual
+
+        self.to_learned_weight_residual_mix = Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n'),
+            nn.Sigmoid()
+        ) if accept_weight_residual else None
+
         # allow for softclamp the gradient norms for storing memories
 
         self.max_grad_norm = max_grad_norm
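The new mix module produces one sigmoid gate per head per token. A shape walk-through in plain torch.nn (the dim and heads values are illustrative; the package uses its own Sequential wrapper):

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

dim, heads = 384, 2

to_mix = nn.Sequential(
    nn.Linear(dim, heads),        # (b, n, dim) -> (b, n, heads): one logit per head
    Rearrange('b n h -> b h n'),  # move heads ahead of the sequence axis
    nn.Sigmoid()                  # squash each logit to a (0, 1) mixing coefficient
)

x = torch.randn(2, 256, dim)      # (batch, seq, dim)
mix = to_mix(x)
assert mix.shape == (2, heads, 256)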
@@ -563,6 +578,8 @@ class NeuralMemory(Module):
 
         # maybe add previous layer weight
 
+        assert xnor(exists(self.to_learned_weight_residual_mix), exists(prev_weights))
+
         if exists(prev_weights):
 
             start_index = math.ceil(seq_index / chunk_size)
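xnor is true exactly when its arguments agree, so this assert requires that a memory configured with accept_weight_residual = True is actually handed prev_weights, and that one configured without it is not. In miniature (one plausible definition of the helper named in the hunk header above):

def xnor(x, y):
    return not (x ^ y)

assert xnor(True, True)       # configured and supplied: ok
assert xnor(False, False)     # neither configured nor supplied: ok
assert not xnor(True, False)  # configured but not supplied: the assert fires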
@@ -570,6 +587,11 @@ class NeuralMemory(Module):
 
             prev_weights = prev_weights.apply(lambda t: t[:, start_index:end_index])
 
+            if exists(self.to_learned_weight_residual_mix):
+                mix = self.to_learned_weight_residual_mix(chunked_seq)
+                mix = rearrange(mix, 'b h n -> (b h) n')
+                prev_weights = prev_weights.apply(lambda t: einx.multiply('bh n, bh n ... -> bh n ...', mix, t))
+
             weights_for_surprise = weights_for_surprise + prev_weights
 
         # flatten batch and time if surprise depends on previous layer memory model
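The einx.multiply('bh n, bh n ... -> bh n ...') call broadcasts the per-chunk gate over every trailing weight dimension. An equivalent in plain torch (shapes illustrative):

import torch

bh, n = 4, 8                    # flattened (batch * heads) and number of chunks
mix = torch.rand(bh, n)         # per-chunk gate in (0, 1)
w = torch.randn(bh, n, 16, 16)  # one chunked weight tensor; trailing dims illustrative

# einx.multiply('bh n, bh n ... -> bh n ...', mix, w) is this broadcast multiply
gated = w * mix[:, :, None, None]
assert gated.shape == w.shape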
@@ -837,9 +859,9 @@ class NeuralMemory(Module):
 
         weights = last_update
 
-        next_neural_mem_state =
-
-
+        next_neural_mem_state = tuple_index_set(next_neural_mem_state, 1, last_update)
+
+        next_neural_mem_state = tuple_index_set(next_neural_mem_state, -1, updates)
 
         # retrieve
 
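When the state is a namedtuple, the standard library's _replace expresses the same two updates by field name (field names below are illustrative, not taken from this diff):

from collections import namedtuple

State = namedtuple('State', ['seq_index', 'weights', 'updates'])

s = State(0, 'w0', 'u0')
s = s._replace(weights = 'w1', updates = 'u1')  # same effect as the two tuple_index_set calls
assert s == State(0, 'w1', 'u1')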