titans-pytorch 0.2.17__py3-none-any.whl → 0.2.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +7 -1
- titans_pytorch/neural_memory.py +17 -0
- {titans_pytorch-0.2.17.dist-info → titans_pytorch-0.2.18.dist-info}/METADATA +1 -1
- titans_pytorch-0.2.18.dist-info/RECORD +9 -0
- titans_pytorch-0.2.17.dist-info/RECORD +0 -9
- {titans_pytorch-0.2.17.dist-info → titans_pytorch-0.2.18.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.2.17.dist-info → titans_pytorch-0.2.18.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py CHANGED

@@ -491,7 +491,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_layers: tuple[int, ...] | None = None,
         use_flex_attn = False,
         sliding_window_attn = False,
-        neural_mem_weight_residual = False
+        neural_mem_weight_residual = False,
     ):
         super().__init__()

@@ -525,7 +525,10 @@ class MemoryAsContextTransformer(Module):

         neural_memory_layers = default(neural_memory_layers, layers)

+        # weight residual related
+
         self.neural_mem_weight_residual = neural_mem_weight_residual
+        is_first_neural_mem = True

         # mem, attn, and feedforward layers

@@ -557,9 +560,12 @@ class MemoryAsContextTransformer(Module):
                     chunk_size = self.neural_memory_segment_len,
                     batch_size = neural_memory_batch_size,
                     model = deepcopy(neural_memory_model),
+                    learned_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
                     **neural_memory_kwargs
                 )

+                is_first_neural_mem = False
+
             ff = FeedForward(dim = dim, mult = ff_mult)

             self.layers.append(ModuleList([
titans_pytorch/neural_memory.py CHANGED
@@ -296,6 +296,7 @@ class NeuralMemory(Module):
         init_adaptive_step_bias = None,
         init_momentum_bias = None,
         init_decay_bias = None,
+        learned_weight_residual = False,
         default_model_kwargs: dict = dict(
             depth = 2
         )
@@ -438,6 +439,14 @@ class NeuralMemory(Module):

         self.max_mem_layer_modulation = max_mem_layer_modulation

+        # learned weight residual
+
+        self.to_learned_weight_residual_mix = Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n'),
+            nn.Sigmoid()
+        ) if learned_weight_residual else None
+
         # allow for softclamp the gradient norms for storing memories

         self.max_grad_norm = max_grad_norm
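Shape-wise, the new projection produces one sigmoid gate per head and position: (batch, seq, dim) -> (batch, heads, seq), with values in (0, 1). A self-contained sketch with made-up sizes (the diff's Sequential is the repo's own helper; plain nn.Sequential behaves the same here):

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    dim, heads = 384, 4

    to_mix = nn.Sequential(
        nn.Linear(dim, heads),          # one logit per head and position
        Rearrange('b n h -> b h n'),    # heads in front of the sequence axis
        nn.Sigmoid()                    # squash to (0, 1) gates
    )

    x = torch.randn(2, 64, dim)         # (batch, seq, dim)
    mix = to_mix(x)
    assert mix.shape == (2, 4, 64)      # (batch, heads, seq)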
@@ -563,6 +572,8 @@ class NeuralMemory(Module):

         # maybe add previous layer weight

+        assert xnor(exists(self.to_learned_weight_residual_mix), exists(prev_weights))
+
         if exists(prev_weights):

             start_index = math.ceil(seq_index / chunk_size)
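The assert enforces an iff-contract: the learned mix exists exactly when previous-layer weights are passed in, so the flag and the call site cannot drift apart silently. The exists/xnor helpers are not shown in this diff; a sketch of how such helpers are conventionally defined in this repo (an assumption):

    def exists(v):
        return v is not None

    def xnor(x, y):
        return not (x ^ y)   # True when both or neither hold

    assert xnor(True, True) and xnor(False, False) and not xnor(True, False)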
@@ -570,6 +581,12 @@ class NeuralMemory(Module):

             prev_weights = prev_weights.apply(lambda t: t[:, start_index:end_index])

+            if exists(self.to_learned_weight_residual_mix):
+                mix = self.to_learned_weight_residual_mix(chunked_seq)
+                mix = rearrange(mix, 'b h n -> (b h) n')
+
+                prev_weights = prev_weights.apply(lambda t: einx.multiply('bh n, bh n ... -> bh n ...', mix, t))
+
             weights_for_surprise = weights_for_surprise + prev_weights

         # flatten batch and time if surprise depends on previous layer memory model
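The einx.multiply pattern broadcasts the (batch*heads, chunks) gate over all trailing dimensions of each weight tensor. An equivalence check in plain PyTorch, assuming (as the pattern implies) that each tensor in prev_weights has shape (batch*heads, chunks, ...):

    import torch

    bh, n = 8, 16
    t = torch.randn(bh, n, 32, 32)   # one memory-weight tensor: (batch*heads, chunks, ...)
    mix = torch.rand(bh, n)          # sigmoid gate per head and chunk

    # einx.multiply('bh n, bh n ... -> bh n ...', mix, t) is ordinary broadcasting
    # once the gate gains trailing singleton dims
    gated = mix.view(bh, n, *([1] * (t.ndim - 2))) * t

    assert gated.shape == t.shape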
titans_pytorch-0.2.18.dist-info/RECORD ADDED

@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=J9MHFViWWehBOFOKx7ry_X2k8nAXaAZFFeCwGtudZyk,24942
+titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
+titans_pytorch/neural_memory.py,sha256=A74v2g4OUXJIqXLu1Rk7FDsTlhveJIHsiamsWN9e7cI,26602
+titans_pytorch-0.2.18.dist-info/METADATA,sha256=QFJI4OoZZSdgZd6FJInAG5-z2ubpvO5I0aCRNKwhDnQ,6812
+titans_pytorch-0.2.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.2.18.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.2.18.dist-info/RECORD,,
titans_pytorch-0.2.17.dist-info/RECORD DELETED

@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=QhdSPgWntfWILMJ1t0xLKgvZfPZWu9vhzZWaesftaPg,24724
-titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
-titans_pytorch/neural_memory.py,sha256=M7vB6AgrphY3_cdELyWXF8qS9cjYUvv1NVCUqQOlA8M,25928
-titans_pytorch-0.2.17.dist-info/METADATA,sha256=8oBe_7SkPmHvHVGXpoRvGWMTldfl2pzXIztOaM_qGrI,6812
-titans_pytorch-0.2.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.2.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.2.17.dist-info/RECORD,,
{titans_pytorch-0.2.17.dist-info → titans_pytorch-0.2.18.dist-info}/WHEEL: file without changes
{titans_pytorch-0.2.17.dist-info → titans_pytorch-0.2.18.dist-info}/licenses/LICENSE: file without changes