titans-pytorch 0.2.15__tar.gz → 0.2.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.15
+Version: 0.2.16
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.2.15"
+version = "0.2.16"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -491,6 +491,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_layers: tuple[int, ...] | None = None,
         use_flex_attn = False,
         sliding_window_attn = False,
+        neural_mem_weight_residual = False
     ):
         super().__init__()

@@ -524,6 +525,8 @@ class MemoryAsContextTransformer(Module):

         neural_memory_layers = default(neural_memory_layers, layers)

+        self.neural_mem_weight_residual = neural_mem_weight_residual
+
         # mem, attn, and feedforward layers

         for layer in layers:
@@ -739,6 +742,10 @@ class MemoryAsContextTransformer(Module):

         value_residual = None

+        # neural mem weight residual
+
+        mem_weight_residual = None
+
         # when inferencing, only do one token at a time

         if is_inferencing:
@@ -764,8 +771,12 @@ class MemoryAsContextTransformer(Module):
                 retrieved, next_neural_mem_cache = mem.forward(
                     mem_input,
                     state = next(neural_mem_caches, None),
+                    prev_weights = mem_weight_residual
                 )

+                if self.neural_mem_weight_residual:
+                    mem_weight_residual = next_neural_mem_cache.updates
+
                 if self.gate_attn_output:
                     attn_out_gates = retrieved.sigmoid()
                 else:
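Taken together, the MemoryAsContextTransformer hunks above thread a new weight residual through the forward pass: mem_weight_residual starts as None, is handed to each neural memory layer as prev_weights, and, when neural_mem_weight_residual is enabled, is refreshed with that layer's weight updates for the next memory layer to reuse. A minimal sketch of that control flow, with a stand-in StubMemory in place of NeuralMemory (the stub and its Cache are illustrative, not the library's API):

from collections import namedtuple
import torch

Cache = namedtuple('Cache', ['updates'])

class StubMemory:
    # stand-in for NeuralMemory: returns the input as "retrieved" plus a cache
    # whose .updates field plays the role of the per-chunk weight updates
    def forward(self, x, state = None, prev_weights = None):
        updates = {'w1': torch.zeros(x.shape[0], 1, 4, 4)}
        return x, Cache(updates = updates)

def run_memory_layers(x, memories, neural_mem_weight_residual = True):
    mem_weight_residual = None

    for mem in memories:
        retrieved, next_cache = mem.forward(
            x,
            state = None,
            prev_weights = mem_weight_residual  # previous memory layer's updates, or None
        )

        if neural_mem_weight_residual:
            # carry this layer's weight updates forward as a residual
            mem_weight_residual = next_cache.updates

    return retrieved

out = run_memory_layers(torch.randn(1, 8, 16), [StubMemory(), StubMemory()])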
@@ -494,7 +494,8 @@ class NeuralMemory(Module):
         seq,
         weights: dict[str, Tensor] | None = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-        seq_index = 0
+        seq_index = 0,
+        prev_weights = None
     ):
         batch, seq_len, heads, chunk_size = *seq.shape[:2], self.heads, self.store_chunk_size

@@ -560,6 +561,12 @@ class NeuralMemory(Module):

         adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)

+        # maybe add previous layer weight
+
+        if exists(prev_weights):
+            prev_weights = prev_weights.apply(lambda t: t[:, -1:])
+            weights_for_surprise = weights_for_surprise + prev_weights
+
         # flatten batch and time if surprise depends on previous layer memory model

         weights_for_surprise = rearrange_dict_values(weights_for_surprise, 'b n ... -> (b n) ...')
@@ -732,6 +739,7 @@ class NeuralMemory(Module):
         seq,
         store_seq = None,
         state: NeuralMemCache | None = None,
+        prev_weights = None
     ):
         if seq.ndim == 2:
             seq = rearrange(seq, 'b d -> b 1 d')
@@ -807,6 +815,7 @@ class NeuralMemory(Module):
             weights,
             seq_index = seq_index,
             past_state = past_state,
+            prev_weights = prev_weights
        )

         seq_index = next_neural_mem_state.seq_index
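On the NeuralMemory side, store_memories now accepts prev_weights and, when given, slices it to the last chunk and adds it onto weights_for_surprise before the surprise gradients are computed. A rough sketch of that residual step using plain dictionaries (the repo uses a tensor-dict with an .apply method; the (batch, chunks, *param_shape) layout is an assumption for illustration):

import torch

# helpers standing in for the repo's tensor-dict operations
def apply_to_dict(d, fn):
    return {k: fn(v) for k, v in d.items()}

def add_dicts(a, b):
    # broadcasts the single remaining chunk of b over all chunks of a
    return {k: a[k] + b[k] for k in a}

# per-chunk memory weights: (batch, chunks, *param_shape)
weights_for_surprise = {'w1': torch.randn(2, 4, 8, 8)}
prev_layer_updates   = {'w1': torch.randn(2, 4, 8, 8)}  # from the previous memory layer

# keep only the most recent chunk of the previous layer's updates ...
prev_weights = apply_to_dict(prev_layer_updates, lambda t: t[:, -1:])

# ... and add it as a residual before surprise is computed
weights_for_surprise = add_dicts(weights_for_surprise, prev_weights)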
@@ -36,8 +36,9 @@ NEURAL_MEM_MOMENTUM = True
 NEURAL_MEM_QK_NORM = True
 NEURAL_MEM_MAX_LR = 1e-1
 WINDOW_SIZE = 32
-NEURAL_MEM_SEGMENT_LEN = 2 # set smaller for more granularity for learning rate / momentum etc
+NEURAL_MEM_SEGMENT_LEN = 4 # set smaller for more granularity for learning rate / momentum etc
 NEURAL_MEM_BATCH_SIZE = 128 # set smaller to update the neural memory weights more often as it traverses the sequence
+NEURAL_MEM_WEIGHT_RESIDUAL = True
 SLIDING_WINDOWS = True
 STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
 MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
@@ -87,6 +88,7 @@ model = MemoryAsContextTransformer(
     neural_memory_segment_len = NEURAL_MEM_SEGMENT_LEN,
     neural_memory_batch_size = NEURAL_MEM_BATCH_SIZE,
     neural_mem_gate_attn_output = NEURAL_MEM_GATE_ATTN_OUTPUT,
+    neural_mem_weight_residual = NEURAL_MEM_WEIGHT_RESIDUAL,
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
     neural_memory_model = MemoryMLP(
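The training script simply switches the feature on through the new constructor argument. A hedged end-to-end sketch follows; apart from neural_mem_weight_residual and sliding_window_attn, which appear in this diff, the constructor arguments and the return_loss call are taken from the repository's README and may differ between versions:

import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,                 # local attention window size
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
    sliding_window_attn = True,
    neural_mem_weight_residual = True  # new in 0.2.16
)

token_ids = torch.randint(0, 256, (1, 1023))

loss = transformer(token_ids, return_loss = True)
loss.backward()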