titans-pytorch 0.2.15__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +11 -0
- titans_pytorch/neural_memory.py +10 -1
- {titans_pytorch-0.2.15.dist-info → titans_pytorch-0.2.16.dist-info}/METADATA +1 -1
- titans_pytorch-0.2.16.dist-info/RECORD +9 -0
- titans_pytorch-0.2.15.dist-info/RECORD +0 -9
- {titans_pytorch-0.2.15.dist-info → titans_pytorch-0.2.16.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.2.15.dist-info → titans_pytorch-0.2.16.dist-info}/licenses/LICENSE +0 -0
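The functional change in 0.2.16 is a neural memory weight residual: MemoryAsContextTransformer gains a neural_mem_weight_residual flag, and NeuralMemory.forward (and its store path) gain a prev_weights argument, so that one memory layer's updated weights can be added as a residual to the weights the next layer uses for its surprise computation. The remaining changes are the usual METADATA and RECORD churn. The hunks below show the details.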
titans_pytorch/mac_transformer.py
CHANGED
@@ -491,6 +491,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_layers: tuple[int, ...] | None = None,
         use_flex_attn = False,
         sliding_window_attn = False,
+        neural_mem_weight_residual = False
     ):
         super().__init__()
 
@@ -524,6 +525,8 @@ class MemoryAsContextTransformer(Module):
 
         neural_memory_layers = default(neural_memory_layers, layers)
 
+        self.neural_mem_weight_residual = neural_mem_weight_residual
+
         # mem, attn, and feedforward layers
 
         for layer in layers:
@@ -739,6 +742,10 @@ class MemoryAsContextTransformer(Module):
 
         value_residual = None
 
+        # neural mem weight residual
+
+        mem_weight_residual = None
+
         # when inferencing, only do one token at a time
 
         if is_inferencing:
@@ -764,8 +771,12 @@ class MemoryAsContextTransformer(Module):
             retrieved, next_neural_mem_cache = mem.forward(
                 mem_input,
                 state = next(neural_mem_caches, None),
+                prev_weights = mem_weight_residual
             )
 
+            if self.neural_mem_weight_residual:
+                mem_weight_residual = next_neural_mem_cache.updates
+
             if self.gate_attn_output:
                 attn_out_gates = retrieved.sigmoid()
             else:
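For orientation, a minimal usage sketch of the new flag. Only neural_memory_layers, use_flex_attn, sliding_window_attn, and neural_mem_weight_residual are confirmed by this diff; the remaining constructor arguments and the return_loss call follow the project README and should be treated as assumptions rather than the exact 0.2.16 API.

import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,                   # assumed from the README
    dim = 256,                          # assumed from the README
    depth = 4,                          # assumed from the README
    segment_len = 128,                  # assumed from the README
    num_persist_mem_tokens = 4,         # assumed from the README
    num_longterm_mem_tokens = 16,       # assumed from the README
    neural_memory_layers = (3, 4),      # confirmed by this diff
    neural_mem_weight_residual = True   # new in 0.2.16, defaults to False
)

token_ids = torch.randint(0, 256, (1, 1023))
loss = transformer(token_ids, return_loss = True)   # return_loss assumed from the README
loss.backward()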
titans_pytorch/neural_memory.py
CHANGED
@@ -494,7 +494,8 @@ class NeuralMemory(Module):
         seq,
         weights: dict[str, Tensor] | None = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-        seq_index = 0
+        seq_index = 0,
+        prev_weights = None
     ):
         batch, seq_len, heads, chunk_size = *seq.shape[:2], self.heads, self.store_chunk_size
 
@@ -560,6 +561,12 @@ class NeuralMemory(Module):
 
         adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
 
+        # maybe add previous layer weight
+
+        if exists(prev_weights):
+            prev_weights = prev_weights.apply(lambda t: t[:, -1:])
+            weights_for_surprise = weights_for_surprise + prev_weights
+
         # flatten batch and time if surprise depends on previous layer memory model
 
         weights_for_surprise = rearrange_dict_values(weights_for_surprise, 'b n ... -> (b n) ...')
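A note on the hunk above: prev_weights holds a dict-like collection of per-chunk weight tensors from the previous memory layer, and the slice t[:, -1:] keeps only the weights after that layer's final chunk, which are then broadcast-added to this layer's per-chunk weights_for_surprise. A toy sketch with a plain dict of tensors; the real code appears to use a TensorDict-like container with .apply and +, which is an assumption here:

import torch

# toy stand-ins: batch 1, n = 3 chunks, one 8x8 weight matrix (shapes illustrative only)
weights_for_surprise = {'w': torch.randn(1, 3, 8, 8)}   # this layer's per-chunk weights
prev_weights         = {'w': torch.randn(1, 3, 8, 8)}   # previous layer's per-chunk weights

# keep only the previous layer's weights after its final chunk, as t[:, -1:] does above
prev_last = {k: t[:, -1:] for k, t in prev_weights.items()}           # shape (1, 1, 8, 8)

# add as a residual; broadcasting repeats the last-chunk weights across this layer's chunks
with_residual = {k: v + prev_last[k] for k, v in weights_for_surprise.items()}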
@@ -732,6 +739,7 @@ class NeuralMemory(Module):
         seq,
         store_seq = None,
         state: NeuralMemCache | None = None,
+        prev_weights = None
     ):
         if seq.ndim == 2:
             seq = rearrange(seq, 'b d -> b 1 d')
@@ -807,6 +815,7 @@ class NeuralMemory(Module):
             weights,
             seq_index = seq_index,
             past_state = past_state,
+            prev_weights = prev_weights
         )
 
         seq_index = next_neural_mem_state.seq_index
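Taken together, the two files thread the residual end to end: the transformer passes mem_weight_residual as prev_weights into mem.forward, forward hands it on to the store step, the store step adds the previous layer's last-chunk weights into weights_for_surprise, and the transformer then records next_neural_mem_cache.updates as the residual for the following memory layer.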
titans_pytorch-0.2.16.dist-info/RECORD
ADDED
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=QhdSPgWntfWILMJ1t0xLKgvZfPZWu9vhzZWaesftaPg,24724
+titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
+titans_pytorch/neural_memory.py,sha256=pOoDhYrUgQ5EaUyiEvwDu7mfcaH6Jqqod5NwIFLbD9U,25798
+titans_pytorch-0.2.16.dist-info/METADATA,sha256=TxyjTuJmP0o2NhrHmlzJCU3JivOA1rTY-xQp3Ir_igY,6812
+titans_pytorch-0.2.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.2.16.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.2.16.dist-info/RECORD,,
titans_pytorch-0.2.15.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=Udu-9mtPy9sDeDyXKo95YMel3ELv5quJXINW-JG-hdk,24357
-titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
-titans_pytorch/neural_memory.py,sha256=iu9lnRrqWmtFw3QYyJlS7mOP2zI2HJFuhs3TyfkKV3o,25482
-titans_pytorch-0.2.15.dist-info/METADATA,sha256=vOb0Tt6-egnqtNXMfrJVibHwm8VuWQMlPw3C7Y_L4Wg,6812
-titans_pytorch-0.2.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.2.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.2.15.dist-info/RECORD,,
{titans_pytorch-0.2.15.dist-info → titans_pytorch-0.2.16.dist-info}/WHEEL
File without changes
{titans_pytorch-0.2.15.dist-info → titans_pytorch-0.2.16.dist-info}/licenses/LICENSE
File without changes