titans-pytorch 0.2.14__tar.gz → 0.2.15__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.14
+Version: 0.2.15
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.2.14"
+version = "0.2.15"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -148,7 +148,7 @@ def test_mac(
     assert logits.shape == (1, seq_len, 256)
 
 @pytest.mark.parametrize('sliding', (False, True))
-@pytest.mark.parametrize('mem_layers', (()))
+@pytest.mark.parametrize('mem_layers', ((), None))
 @pytest.mark.parametrize('longterm_mems', (0, 4, 16))
 @pytest.mark.parametrize('prompt_len', (4, 16))
 @torch_default_dtype(torch.float64)
@@ -190,7 +190,6 @@ def test_neural_mem_inference(
     prompt_len,
     mem_chunk_size
 ):
-    pytest.skip()
 
     mem = NeuralMemory(
         dim = 384,
@@ -218,7 +217,7 @@ def test_neural_mem_inference(
 
     for token in seq.unbind(dim = 1):
 
-        one_retrieved, state = mem.forward_inference(
+        one_retrieved, state = mem.forward(
            token,
            state = state,
        )
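
Note on the hunk above: with the separate `forward_inference` method removed in 0.2.15, the test drives single-token decoding through the ordinary `forward` call, threading the returned state back in each step. A minimal usage sketch of that pattern follows; the constructor arguments besides `dim` and the final shape check are assumptions for illustration, not copied from the test.

    import torch
    from titans_pytorch import NeuralMemory

    # assumed constructor arguments; the test above only shows dim = 384
    mem = NeuralMemory(dim = 384, chunk_size = 4)

    seq = torch.randn(1, 16, 384)

    # parallel path: feed the whole sequence, with no prior state
    parallel_retrieved, _ = mem.forward(seq)

    # incremental path: one token at a time, passing the returned state back in
    state = None
    retrieved_tokens = []

    for token in seq.unbind(dim = 1):
        one_retrieved, state = mem.forward(token, state = state)
        retrieved_tokens.append(one_retrieved)

    incremental_retrieved = torch.cat(retrieved_tokens, dim = 1)

    # the test presumably compares the two paths; here we only check shapes agree
    assert parallel_retrieved.shape == incremental_retrieved.shape
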
@@ -761,16 +761,10 @@ class MemoryAsContextTransformer(Module):
 
            mem_input, add_residual = mem_hyper_conn(x)
 
-            if not is_inferencing:
-                retrieved, next_neural_mem_cache = mem(
-                    mem_input
-                )
-
-            else:
-                (retrieved, next_neural_mem_cache) = mem.forward_inference(
-                    mem_input,
-                    state = next(neural_mem_caches, None),
-                )
+            retrieved, next_neural_mem_cache = mem.forward(
+                mem_input,
+                state = next(neural_mem_caches, None),
+            )
 
            if self.gate_attn_output:
                attn_out_gates = retrieved.sigmoid()
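
The hunk above collapses the separate training and decoding branches into a single call: `state` is simply `None` on a full-sequence pass and the cached neural memory state during incremental decoding. The `next(neural_mem_caches, None)` idiom relies on `next`'s default argument, as in this small standalone sketch (variable contents here are illustrative, not taken from the library):

    # with no stored caches, the iterator is empty and next(...) falls back to None,
    # which the memory module treats as "start from a fresh state"
    neural_mem_caches = iter(())
    state = next(neural_mem_caches, None)
    assert state is None

    # during decoding, the caches returned from the previous step are iterated
    # in layer order, so each memory layer receives its own state
    neural_mem_caches = iter(('cache_layer_1', 'cache_layer_2'))
    assert next(neural_mem_caches, None) == 'cache_layer_1'
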
@@ -655,15 +655,19 @@ class NeuralMemory(Module):
         self,
         seq,
         past_weights: dict[str, Tensor],
+        chunk_size = None,
+        need_pad = True
     ):
-        chunk_size = self.retrieve_chunk_size
+        chunk_size = default(chunk_size, self.retrieve_chunk_size)
         batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
 
-        needs_pad = chunk_size > 1
+        need_pad = need_pad or chunk_size > 1
+
+        if need_pad:
+            seq = pad_at_dim(seq, (1, 0), dim = 1)
 
-        seq = pad_at_dim(seq, (1, 0), dim = 1)
         seq_len_plus_one = seq.shape[-2]
 
         next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)
@@ -718,76 +722,10 @@ class NeuralMemory(Module):
 
         # restore, pad with empty memory embed
 
-        values = values[:, 1:(seq_len + 1)]
-
-        return values
-
-    @torch.no_grad()
-    def forward_inference(
-        self,
-        token: Tensor,
-        state: NeuralMemCache | None = None,
-    ):
-        # unpack previous state
-
-        if not exists(state):
-            state = (0, None, None, None, None)
-
-        seq_index, weights, cache_store_seq, past_states, updates = state
-
-        curr_seq_len = seq_index + 1
-        batch = token.shape[0]
+        if need_pad:
+            values = values[:, 1:]
 
-        if token.ndim == 2:
-            token = rearrange(token, 'b d -> b 1 d')
-
-        assert token.shape[1] == 1
-
-        # increment the sequence cache which is at most the chunk size
-
-        cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)
-
-        # early return empty memory, when no memories are stored for steps < first chunk size
-
-        if curr_seq_len < self.chunk_size:
-            retrieve = self.retrieve_memories(token, weights, chunk_size = 1)
-
-            output = retrieve, NeuralMemCache(curr_seq_len, weights, cache_store_seq, past_states, updates)
-
-            return output
-
-        # store if storage sequence cache hits the chunk size
-
-        next_states = past_states
-        store_seq_cache_len = cache_store_seq.shape[-2]
-
-        if not exists(updates):
-            updates = weights.clone().zero_()
-            updates = repeat_dict_values(updates, '... -> b 1 ...', b = batch)
-        else:
-            updates = updates.apply(lambda t: t[:, -1:])
-
-        if store_seq_cache_len == self.chunk_size:
-
-            next_updates, store_state = self.store_memories(
-                cache_store_seq,
-                weights,
-                past_state = past_states,
-            )
-
-            updates = next_updates
-            cache_store_seq = None
-            next_states = store_state.states
-
-        # retrieve
-
-        retrieved = self.retrieve_memories(token, updates, chunk_size = 1)
-
-        # next state tuple
-
-        next_store_state = NeuralMemCache(curr_seq_len, weights, cache_store_seq, next_states, updates)
-
-        return retrieved, next_store_state
+        return values[:, :seq_len]
 
     def forward(
         self,
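
In `retrieve_memories`, the leading pad shifts retrieval by one step (so a position reads memory formed from earlier positions), and the hunk above makes both the pad and the matching slice conditional on `need_pad`, with the final `values[:, :seq_len]` trimming any chunk-rounding overhang. A toy, index-only illustration of that bookkeeping, separate from the actual module internals:

    import torch
    import torch.nn.functional as F

    seq_len, chunk_size = 5, 4
    values = torch.arange(seq_len).float().reshape(1, seq_len, 1)

    need_pad = True

    if need_pad:
        # prepend one step, mirroring pad_at_dim(seq, (1, 0), dim = 1)
        values = F.pad(values, (0, 0, 1, 0))

    # the padded length is then rounded up to a multiple of chunk_size
    rounded = -(-values.shape[1] // chunk_size) * chunk_size
    values = F.pad(values, (0, 0, 0, rounded - values.shape[1]))

    # ... retrieval runs on the chunk-aligned sequence ...

    if need_pad:
        values = values[:, 1:]       # drop the artificial first step

    values = values[:, :seq_len]     # trim back to the original sequence length
    assert values.shape[1] == seq_len
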
@@ -795,17 +733,25 @@ class NeuralMemory(Module):
         store_seq = None,
         state: NeuralMemCache | None = None,
     ):
+        if seq.ndim == 2:
+            seq = rearrange(seq, 'b d -> b 1 d')
+
+        is_single_token = seq.shape[1] == 1
+
         if not exists(state):
             state = (0, None, None, None, None)
 
         seq_index, weights, cache_store_seq, past_state, updates = state
 
-        assert not exists(cache_store_seq) or is_empty_tensor(cache_store_seq)
-
         # store
 
         store_seq = default(store_seq, seq)
 
+        # take care of cache
+
+        if exists(cache_store_seq):
+            store_seq = safe_cat((cache_store_seq, store_seq))
+
         # functions
         # compute split sizes of sequence
 
@@ -883,9 +829,21 @@ class NeuralMemory(Module):
 
         # retrieve
 
+        need_pad = True
+        retrieve_chunk_size = None
+
+        if is_single_token:
+            retrieve_chunk_size = 1
+            need_pad = False
+
+            last_update, _ = past_state
+            updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
+
         retrieved = self.retrieve_memories(
             seq,
-            updates
+            updates,
+            chunk_size = retrieve_chunk_size,
+            need_pad = need_pad,
         )
 
         return retrieved, next_neural_mem_state
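
For orientation, the `state` threaded through `forward` above is the five-field `NeuralMemCache` that the code unpacks as `seq_index, weights, cache_store_seq, past_state, updates`. A sketch of those fields, with descriptions inferred from how this diff uses them (the real namedtuple in the library may use different field names):

    from collections import namedtuple

    # hypothetical mirror of the cache; field names follow the unpacking in forward
    NeuralMemCacheSketch = namedtuple('NeuralMemCacheSketch', [
        'seq_index',       # number of tokens consumed so far
        'weights',         # current memory parameters being updated online
        'cache_store_seq', # tokens buffered until a full store chunk is available
        'past_state',      # state carried between store_memories calls
        'updates',         # most recent per-parameter updates, reused for retrieval
    ])

    # forward seeds an empty state when none is passed in
    fresh = NeuralMemCacheSketch(0, None, None, None, None)
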
@@ -87,7 +87,6 @@ model = MemoryAsContextTransformer(
     neural_memory_segment_len = NEURAL_MEM_SEGMENT_LEN,
     neural_memory_batch_size = NEURAL_MEM_BATCH_SIZE,
     neural_mem_gate_attn_output = NEURAL_MEM_GATE_ATTN_OUTPUT,
-    default_step_transform_max_lr = NEURAL_MEM_MAX_LR,
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
     neural_memory_model = MemoryMLP(
@@ -100,6 +99,7 @@ model = MemoryAsContextTransformer(
     attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
     qk_rmsnorm = NEURAL_MEM_QK_NORM,
     momentum = NEURAL_MEM_MOMENTUM,
+    default_step_transform_max_lr = NEURAL_MEM_MAX_LR,
     use_accelerated_scan = USE_ACCELERATED_SCAN,
     per_parameter_lr_modulation = MEMORY_MODEL_PER_LAYER_LEARNED_LR
 )