titans-pytorch 0.2.12__py3-none-any.whl → 0.2.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/mac_transformer.py

@@ -761,16 +761,10 @@ class MemoryAsContextTransformer(Module):
 
                 mem_input, add_residual = mem_hyper_conn(x)
 
-                if not is_inferencing:
-                    retrieved, next_neural_mem_cache = mem(
-                        mem_input
-                    )
-
-                else:
-                    (retrieved, next_neural_mem_cache) = mem.forward_inference(
-                        mem_input,
-                        state = next(neural_mem_caches, None),
-                    )
+                retrieved, next_neural_mem_cache = mem.forward(
+                    mem_input,
+                    state = next(neural_mem_caches, None),
+                )
 
                 if self.gate_attn_output:
                     attn_out_gates = retrieved.sigmoid()
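With the branch on is_inferencing gone, full-sequence training and step-wise decoding both go through mem.forward, which now takes the optional cached state. A minimal sketch of the two call patterns at the NeuralMemory level (the dim/chunk_size values and random tensors are illustrative assumptions, not values from the package):

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 64, chunk_size = 4)

    # full-sequence call, no prior state
    seq = torch.randn(2, 16, 64)                     # (batch, seq, dim)
    retrieved, cache = mem.forward(seq)

    # single-token call, threading the cache from the previous step
    token = torch.randn(2, 1, 64)
    retrieved_step, cache = mem.forward(token, state = cache)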
titans_pytorch/neural_memory.py

@@ -284,7 +284,7 @@ class NeuralMemory(Module):
         adaptive_step_transform: Callable | None = None,
         default_step_transform_max_lr = 1.,
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
-        max_mem_layer_modulation = 1e1, # max of 10.
+        max_mem_layer_modulation = 1., # max of 10.
         attn_pool_chunks = False,
         momentum = True,
         pre_rmsnorm = True,
@@ -293,6 +293,9 @@ class NeuralMemory(Module):
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         activation: Module | None = None,
+        init_adaptive_step_bias = None,
+        init_momentum_bias = None,
+        init_decay_bias = None,
         default_model_kwargs: dict = dict(
             depth = 2
         )
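The three new keyword arguments let the caller pin the starting value of the adaptive step, momentum and weight-decay gates (see the init block further down in this diff). A hedged constructor sketch, assuming the sigmoid-style gating these projections feed into; the concrete numbers are illustrative, not recommendations:

    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(
        dim = 384,
        chunk_size = 64,
        init_adaptive_step_bias = -4.,   # start with a small effective inner learning rate
        init_momentum_bias = 2.,         # start with momentum close to 1
        init_decay_bias = -4.            # start with almost no forgetting / weight decay
    )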
@@ -411,12 +414,12 @@ class NeuralMemory(Module):
         # learned adaptive learning rate and momentum
 
         self.to_momentum = Sequential(
-            LinearNoBias(dim, heads),
+            nn.Linear(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         ) if momentum else None
 
         self.to_adaptive_step = Sequential(
-            LinearNoBias(dim, heads),
+            nn.Linear(dim, heads),
             Rearrange('b n h -> (b h) n')
         )
 
@@ -428,7 +431,7 @@ class NeuralMemory(Module):
         # per layer learning rate modulation
 
         self.to_layer_modulation = Sequential(
-            LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
+            nn.Linear(dim, heads * self.num_memory_parameter_tensors),
             Rearrange('b n (h w) -> w (b h) n', h = heads),
             nn.Sigmoid()
         ) if per_parameter_lr_modulation else None
@@ -442,10 +445,27 @@ class NeuralMemory(Module):
         # weight decay factor
 
         self.to_decay_factor = Sequential(
-            LinearNoBias(dim, heads),
+            nn.Linear(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
 
+        # inits
+
+        if exists(init_adaptive_step_bias):
+            linear = self.to_adaptive_step[0]
+            nn.init.zeros_(linear.weight)
+            nn.init.constant_(linear.bias, init_adaptive_step_bias)
+
+        if exists(init_momentum_bias):
+            linear = self.to_momentum[0]
+            nn.init.zeros_(linear.weight)
+            nn.init.constant_(linear.bias, init_momentum_bias)
+
+        if exists(init_decay_bias):
+            linear = self.to_decay_factor[0]
+            nn.init.zeros_(linear.weight)
+            nn.init.constant_(linear.bias, init_decay_bias)
+
         # maybe use accelerated scan
 
         self.use_accelerated_scan = use_accelerated_scan
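This also explains the switch from LinearNoBias to nn.Linear in the projections above: the bias term is what the new init hooks set. With the weight zeroed and the bias held constant, each projection initially emits that constant for every token, so the starting gate is just the squashing function applied to the bias. A small, self-contained illustration of that arithmetic (plain torch; the sigmoid here stands in for whatever transform the gate feeds into):

    import torch
    import torch.nn as nn

    linear = nn.Linear(512, 8)             # same shape idea as to_momentum: dim -> heads
    nn.init.zeros_(linear.weight)
    nn.init.constant_(linear.bias, 2.0)

    x = torch.randn(3, 16, 512)            # arbitrary input
    out = linear(x)                        # every entry equals the bias, 2.0

    print(torch.allclose(out, torch.full_like(out, 2.0)))   # True
    print(torch.sigmoid(out[0, 0, 0]).item())               # ~0.88, the initial gate value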
@@ -635,15 +655,19 @@ class NeuralMemory(Module):
         self,
         seq,
         past_weights: dict[str, Tensor],
+        chunk_size = None,
+        need_pad = True
     ):
-        chunk_size = self.retrieve_chunk_size
+        chunk_size = default(chunk_size, self.retrieve_chunk_size)
         batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
 
-        needs_pad = chunk_size > 1
+        need_pad = need_pad or chunk_size > 1
+
+        if need_pad:
+            seq = pad_at_dim(seq, (1, 0), dim = 1)
 
-        seq = pad_at_dim(seq, (1, 0), dim = 1)
         seq_len_plus_one = seq.shape[-2]
 
         next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)
@@ -698,76 +722,10 @@ class NeuralMemory(Module):
 
         # restore, pad with empty memory embed
 
-        values = values[:, 1:(seq_len + 1)]
-
-        return values
-
-    @torch.no_grad()
-    def forward_inference(
-        self,
-        token: Tensor,
-        state: NeuralMemCache | None = None,
-    ):
-        # unpack previous state
-
-        if not exists(state):
-            state = (0, None, None, None, None)
-
-        seq_index, weights, cache_store_seq, past_states, updates = state
-
-        curr_seq_len = seq_index + 1
-        batch = token.shape[0]
-
-        if token.ndim == 2:
-            token = rearrange(token, 'b d -> b 1 d')
-
-        assert token.shape[1] == 1
-
-        # increment the sequence cache which is at most the chunk size
-
-        cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)
-
-        # early return empty memory, when no memories are stored for steps < first chunk size
-
-        if curr_seq_len < self.chunk_size:
-            retrieve = self.retrieve_memories(token, weights, chunk_size = 1)
-
-            output = retrieve, NeuralMemCache(curr_seq_len, weights, cache_store_seq, past_states, updates)
+        if need_pad:
+            values = values[:, 1:]
 
-            return output
-
-        # store if storage sequence cache hits the chunk size
-
-        next_states = past_states
-        store_seq_cache_len = cache_store_seq.shape[-2]
-
-        if not exists(updates):
-            updates = weights.clone().zero_()
-            updates = repeat_dict_values(updates, '... -> b 1 ...', b = batch)
-        else:
-            updates = updates.apply(lambda t: t[:, -1:])
-
-        if store_seq_cache_len == self.chunk_size:
-
-            next_updates, store_state = self.store_memories(
-                cache_store_seq,
-                weights,
-                past_state = past_states,
-            )
-
-            updates = next_updates
-            cache_store_seq = None
-            next_states = store_state.states
-
-        # retrieve
-
-        retrieved = self.retrieve_memories(token, updates, chunk_size = 1)
-
-        # next state tuple
-
-        next_store_state = NeuralMemCache(curr_seq_len, weights, cache_store_seq, next_states, updates)
-
-        return retrieved, next_store_state
+        return values[:, :seq_len]
 
     def forward(
         self,
@@ -775,17 +733,25 @@ class NeuralMemory(Module):
         store_seq = None,
         state: NeuralMemCache | None = None,
     ):
+        if seq.ndim == 2:
+            seq = rearrange(seq, 'b d -> b 1 d')
+
+        is_single_token = seq.shape[1] == 1
+
         if not exists(state):
             state = (0, None, None, None, None)
 
         seq_index, weights, cache_store_seq, past_state, updates = state
 
-        assert not exists(cache_store_seq) or is_empty_tensor(cache_store_seq)
-
         # store
 
         store_seq = default(store_seq, seq)
 
+        # take care of cache
+
+        if exists(cache_store_seq):
+            store_seq = safe_cat((cache_store_seq, store_seq))
+
         # functions
 
         # compute split sizes of sequence
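Instead of asserting the cache is empty, forward now folds any tokens still waiting in cache_store_seq into the store sequence. safe_cat is the repo's small helper for concatenating while tolerating a missing operand; a minimal stand-in with the assumed behavior, purely for illustration:

    import torch

    def safe_cat(tensors, dim = -2):
        # assumed behavior: drop None entries, concatenate the rest along dim
        tensors = [t for t in tensors if t is not None]
        if len(tensors) == 0:
            return None
        if len(tensors) == 1:
            return tensors[0]
        return torch.cat(tensors, dim = dim)

    cache_store_seq = torch.randn(2, 3, 64)    # tokens not yet committed to the memory weights
    store_seq = torch.randn(2, 1, 64)          # the incoming token(s)
    print(safe_cat((cache_store_seq, store_seq)).shape)   # torch.Size([2, 4, 64])
    print(safe_cat((None, store_seq)).shape)              # torch.Size([2, 1, 64])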
@@ -863,9 +829,21 @@ class NeuralMemory(Module):
 
         # retrieve
 
+        need_pad = True
+        retrieve_chunk_size = None
+
+        if is_single_token:
+            retrieve_chunk_size = 1
+            need_pad = False
+
+            last_update, _ = past_state
+            updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
+
         retrieved = self.retrieve_memories(
             seq,
-            updates
+            updates,
+            chunk_size = retrieve_chunk_size,
+            need_pad = need_pad,
         )
 
         return retrieved, next_neural_mem_state
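The single-token branch retrieves against the most recent memory update with chunk_size = 1 and no front padding, which is what makes the step-wise decoding pattern in the transformer hunk above work. A hedged end-to-end sketch of threading the returned cache through one token at a time (toy sizes and random tensors are illustrative assumptions):

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 64, chunk_size = 4)

    # ingest a prompt in one call and keep the returned cache
    prompt = torch.randn(1, 12, 64)
    retrieved, cache = mem.forward(prompt)

    # then decode step by step, passing the cache back in each time
    for _ in range(8):
        token = torch.randn(1, 1, 64)              # stand-in for the next token's embedding
        retrieved, cache = mem.forward(token, state = cache)
        # retrieved is (1, 1, 64); the memory weights themselves update once per chunk_size tokens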
titans_pytorch-0.2.12.dist-info/METADATA → titans_pytorch-0.2.15.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.12
+Version: 0.2.15
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.2.15.dist-info/RECORD (added)

@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=Udu-9mtPy9sDeDyXKo95YMel3ELv5quJXINW-JG-hdk,24357
+titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
+titans_pytorch/neural_memory.py,sha256=iu9lnRrqWmtFw3QYyJlS7mOP2zI2HJFuhs3TyfkKV3o,25482
+titans_pytorch-0.2.15.dist-info/METADATA,sha256=vOb0Tt6-egnqtNXMfrJVibHwm8VuWQMlPw3C7Y_L4Wg,6812
+titans_pytorch-0.2.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.2.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.2.15.dist-info/RECORD,,
titans_pytorch-0.2.12.dist-info/RECORD (removed)

@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=RfJ1SvQH5_4PmlB7g-13wPAqYtCCUJxfmtaL0oBrRCU,24563
-titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
-titans_pytorch/neural_memory.py,sha256=PbPjpm7BZIkl1z4Zyt3OcWW6MxWKAnTlmDvJthcmlp4,26151
-titans_pytorch-0.2.12.dist-info/METADATA,sha256=xZvhE0thLfasfF7DHO2FcwNo-bD0T5_R01pvL4sWd4A,6812
-titans_pytorch-0.2.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.2.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.2.12.dist-info/RECORD,,