titans-pytorch 0.2.12.tar.gz → 0.2.15.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.12
+Version: 0.2.15
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.2.12"
+version = "0.2.15"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -148,7 +148,7 @@ def test_mac(
     assert logits.shape == (1, seq_len, 256)
 
 @pytest.mark.parametrize('sliding', (False, True))
-@pytest.mark.parametrize('mem_layers', (()))
+@pytest.mark.parametrize('mem_layers', ((), None))
 @pytest.mark.parametrize('longterm_mems', (0, 4, 16))
 @pytest.mark.parametrize('prompt_len', (4, 16))
 @torch_default_dtype(torch.float64)
@@ -190,7 +190,6 @@ def test_neural_mem_inference(
     prompt_len,
     mem_chunk_size
 ):
-    pytest.skip()
 
     mem = NeuralMemory(
         dim = 384,
@@ -218,7 +217,7 @@ def test_neural_mem_inference(
 
     for token in seq.unbind(dim = 1):
 
-        one_retrieved, state = mem.forward_inference(
+        one_retrieved, state = mem.forward(
             token,
             state = state,
         )
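
As the test change above shows, 0.2.15 folds the separate `forward_inference` method into `NeuralMemory.forward`, which now accepts either a full sequence or a single token together with an optional recurrent `state`. A minimal sketch of the unified call; the dim, chunk size, and sequence shape below are illustrative placeholders, not package defaults:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 2)  # illustrative sizes

seq = torch.randn(1, 16, 384)

# parallel path: process the whole sequence at once
retrieved, state = mem.forward(seq)

# recurrent path: same entry point, one token at a time
state = None
for token in seq.unbind(dim = 1):
    one_retrieved, state = mem.forward(token, state = state)
```
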
@@ -761,16 +761,10 @@ class MemoryAsContextTransformer(Module):
 
             mem_input, add_residual = mem_hyper_conn(x)
 
-            if not is_inferencing:
-                retrieved, next_neural_mem_cache = mem(
-                    mem_input
-                )
-
-            else:
-                (retrieved, next_neural_mem_cache) = mem.forward_inference(
-                    mem_input,
-                    state = next(neural_mem_caches, None),
-                )
+            retrieved, next_neural_mem_cache = mem.forward(
+                mem_input,
+                state = next(neural_mem_caches, None),
+            )
 
             if self.gate_attn_output:
                 attn_out_gates = retrieved.sigmoid()
@@ -284,7 +284,7 @@ class NeuralMemory(Module):
         adaptive_step_transform: Callable | None = None,
         default_step_transform_max_lr = 1.,
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
-        max_mem_layer_modulation = 1e1, # max of 10.
+        max_mem_layer_modulation = 1., # max of 10.
         attn_pool_chunks = False,
         momentum = True,
         pre_rmsnorm = True,
@@ -293,6 +293,9 @@ class NeuralMemory(Module):
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         activation: Module | None = None,
+        init_adaptive_step_bias = None,
+        init_momentum_bias = None,
+        init_decay_bias = None,
         default_model_kwargs: dict = dict(
             depth = 2
         )
@@ -411,12 +414,12 @@ class NeuralMemory(Module):
         # learned adaptive learning rate and momentum
 
         self.to_momentum = Sequential(
-            LinearNoBias(dim, heads),
+            nn.Linear(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         ) if momentum else None
 
         self.to_adaptive_step = Sequential(
-            LinearNoBias(dim, heads),
+            nn.Linear(dim, heads),
             Rearrange('b n h -> (b h) n')
         )
 
@@ -428,7 +431,7 @@ class NeuralMemory(Module):
         # per layer learning rate modulation
 
         self.to_layer_modulation = Sequential(
-            LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
+            nn.Linear(dim, heads * self.num_memory_parameter_tensors),
             Rearrange('b n (h w) -> w (b h) n', h = heads),
             nn.Sigmoid()
         ) if per_parameter_lr_modulation else None
@@ -442,10 +445,27 @@ class NeuralMemory(Module):
         # weight decay factor
 
         self.to_decay_factor = Sequential(
-            LinearNoBias(dim, heads),
+            nn.Linear(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
 
+        # inits
+
+        if exists(init_adaptive_step_bias):
+            linear = self.to_adaptive_step[0]
+            nn.init.zeros_(linear.weight)
+            nn.init.constant_(linear.bias, init_adaptive_step_bias)
+
+        if exists(init_momentum_bias):
+            linear = self.to_momentum[0]
+            nn.init.zeros_(linear.weight)
+            nn.init.constant_(linear.bias, init_momentum_bias)
+
+        if exists(init_decay_bias):
+            linear = self.to_decay_factor[0]
+            nn.init.zeros_(linear.weight)
+            nn.init.constant_(linear.bias, init_decay_bias)
+
         # maybe use accelerated scan
 
         self.use_accelerated_scan = use_accelerated_scan
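
These hunks swap the bias-free projections for `nn.Linear` so the new `init_*_bias` options can zero the weights and set a constant bias, giving the memory a chosen starting adaptive step, momentum, and weight decay before any training. A hedged usage sketch; the bias values are arbitrary, and how a bias maps to an effective rate depends on the downstream squashing (a sigmoid, for instance, turns a negative bias into a value near zero):

```python
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,                       # illustrative size
    chunk_size = 64,
    init_adaptive_step_bias = -4.,   # start with a near-zero learned step
    init_momentum_bias = 2.,         # start with high momentum
    init_decay_bias = -4.,           # start with almost no weight decay
)
```
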
@@ -635,15 +655,19 @@ class NeuralMemory(Module):
         self,
         seq,
         past_weights: dict[str, Tensor],
+        chunk_size = None,
+        need_pad = True
     ):
-        chunk_size = self.retrieve_chunk_size
+        chunk_size = default(chunk_size, self.retrieve_chunk_size)
         batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
 
-        needs_pad = chunk_size > 1
+        need_pad = need_pad or chunk_size > 1
+
+        if need_pad:
+            seq = pad_at_dim(seq, (1, 0), dim = 1)
 
-        seq = pad_at_dim(seq, (1, 0), dim = 1)
         seq_len_plus_one = seq.shape[-2]
 
         next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)
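
`retrieve_memories` now takes explicit `chunk_size` and `need_pad` arguments so the single-token decode path can skip the one-step left pad. The pad offsets retrieval by one position, so a token queries memories formed from earlier positions, and the slicing in the next hunk undoes it. For reference, a `pad_at_dim` helper consistent with the call above — an assumed implementation for illustration, not copied from the package:

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # pad = (left, right) amounts at an arbitrary dimension; here
    # (1, 0) at dim = 1 prepends one zeroed timestep to (b, n, d)
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

x = torch.randn(1, 4, 8)
assert pad_at_dim(x, (1, 0), dim = 1).shape == (1, 5, 8)
```
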
@@ -698,76 +722,10 @@ class NeuralMemory(Module):
 
         # restore, pad with empty memory embed
 
-        values = values[:, 1:(seq_len + 1)]
-
-        return values
-
-    @torch.no_grad()
-    def forward_inference(
-        self,
-        token: Tensor,
-        state: NeuralMemCache | None = None,
-    ):
-        # unpack previous state
-
-        if not exists(state):
-            state = (0, None, None, None, None)
-
-        seq_index, weights, cache_store_seq, past_states, updates = state
-
-        curr_seq_len = seq_index + 1
-        batch = token.shape[0]
-
-        if token.ndim == 2:
-            token = rearrange(token, 'b d -> b 1 d')
-
-        assert token.shape[1] == 1
-
-        # increment the sequence cache which is at most the chunk size
-
-        cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)
-
-        # early return empty memory, when no memories are stored for steps < first chunk size
-
-        if curr_seq_len < self.chunk_size:
-            retrieve = self.retrieve_memories(token, weights, chunk_size = 1)
-
-            output = retrieve, NeuralMemCache(curr_seq_len, weights, cache_store_seq, past_states, updates)
+        if need_pad:
+            values = values[:, 1:]
 
-            return output
-
-        # store if storage sequence cache hits the chunk size
-
-        next_states = past_states
-        store_seq_cache_len = cache_store_seq.shape[-2]
-
-        if not exists(updates):
-            updates = weights.clone().zero_()
-            updates = repeat_dict_values(updates, '... -> b 1 ...', b = batch)
-        else:
-            updates = updates.apply(lambda t: t[:, -1:])
-
-        if store_seq_cache_len == self.chunk_size:
-
-            next_updates, store_state = self.store_memories(
-                cache_store_seq,
-                weights,
-                past_state = past_states,
-            )
-
-            updates = next_updates
-            cache_store_seq = None
-            next_states = store_state.states
-
-        # retrieve
-
-        retrieved = self.retrieve_memories(token, updates, chunk_size = 1)
-
-        # next state tuple
-
-        next_store_state = NeuralMemCache(curr_seq_len, weights, cache_store_seq, next_states, updates)
-
-        return retrieved, next_store_state
+        return values[:, :seq_len]
 
     def forward(
         self,
@@ -775,17 +733,25 @@ class NeuralMemory(Module):
         store_seq = None,
         state: NeuralMemCache | None = None,
     ):
+        if seq.ndim == 2:
+            seq = rearrange(seq, 'b d -> b 1 d')
+
+        is_single_token = seq.shape[1] == 1
+
         if not exists(state):
             state = (0, None, None, None, None)
 
         seq_index, weights, cache_store_seq, past_state, updates = state
 
-        assert not exists(cache_store_seq) or is_empty_tensor(cache_store_seq)
-
         # store
 
         store_seq = default(store_seq, seq)
 
+        # take care of cache
+
+        if exists(cache_store_seq):
+            store_seq = safe_cat((cache_store_seq, store_seq))
+
         # functions
 
         # compute split sizes of sequence
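
Rather than asserting the cache is empty, `forward` now prepends any cached tokens to the store sequence, which is what lets single-token decoding accumulate a full chunk before it is committed to memory. A hedged sketch of a `safe_cat` helper matching the call above (assumed implementation):

```python
import torch

def safe_cat(tensors, dim = -2):
    # concatenate along dim, tolerating None entries, e.g. an
    # empty cache on the very first decoding step
    tensors = [t for t in tensors if t is not None]
    if len(tensors) == 0:
        return None
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim = dim)
```
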
@@ -863,9 +829,21 @@ class NeuralMemory(Module):
 
         # retrieve
 
+        need_pad = True
+        retrieve_chunk_size = None
+
+        if is_single_token:
+            retrieve_chunk_size = 1
+            need_pad = False
+
+            last_update, _ = past_state
+            updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
+
         retrieved = self.retrieve_memories(
             seq,
-            updates
+            updates,
+            chunk_size = retrieve_chunk_size,
+            need_pad = need_pad,
         )
 
         return retrieved, next_neural_mem_state
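
In the single-token case, retrieval switches to `chunk_size = 1` with no pad and reads from the most recently committed weights: `past_state` carries the last update per parameter, and `rearrange_dict_values` adds a length-one sequence axis to every tensor so the shapes match the chunked retrieval path. A sketch of such a helper over a plain dict of tensors (the package likely applies this to a tensor dictionary; this shows only the idea):

```python
from einops import rearrange

def rearrange_dict_values(d, pattern, **kwargs):
    # apply the same einops rearrange to every tensor value in a dict
    return {k: rearrange(v, pattern, **kwargs) for k, v in d.items()}
```
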
@@ -34,6 +34,7 @@ NEURAL_MEM_LAYERS = (2, 4, 6) # layers 2, 4, 6 have neural mem
 NEURAL_MEM_GATE_ATTN_OUTPUT = False
 NEURAL_MEM_MOMENTUM = True
 NEURAL_MEM_QK_NORM = True
+NEURAL_MEM_MAX_LR = 1e-1
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = 2 # set smaller for more granularity for learning rate / momentum etc
 NEURAL_MEM_BATCH_SIZE = 128 # set smaller to update the neural memory weights more often as it traverses the sequence
@@ -98,6 +99,7 @@ model = MemoryAsContextTransformer(
     attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
     qk_rmsnorm = NEURAL_MEM_QK_NORM,
     momentum = NEURAL_MEM_MOMENTUM,
+    default_step_transform_max_lr = NEURAL_MEM_MAX_LR,
     use_accelerated_scan = USE_ACCELERATED_SCAN,
     per_parameter_lr_modulation = MEMORY_MODEL_PER_LAYER_LEARNED_LR
 )
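
The training script wires the new `NEURAL_MEM_MAX_LR` constant into `default_step_transform_max_lr`, capping the learned per-token step size. Judging by the parameter name, the default step transform squashes an unconstrained projection into `(0, max_lr)`; a sketch of that assumed form (the function name and exact shape are assumptions, not package source):

```python
import torch

def default_adaptive_step_transform(adaptive_step, max_lr = 1e-1):
    # squash an unconstrained learned step into the range (0, max_lr)
    return adaptive_step.sigmoid() * max_lr

steps = torch.tensor([-4., 0., 4.])
print(default_adaptive_step_transform(steps))  # ≈ tensor([0.0018, 0.0500, 0.0982])
```
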