titans-pytorch 0.1.23__py3-none-any.whl → 0.1.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -510,10 +510,7 @@ class MemoryAsContextTransformer(Module):
 
         layers = tuple(range(1, depth + 1))
 
-        if not exists(neural_memory_layers):
-            neural_memory_layers = layers if has_longterm_mems else ()
-
-        assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
+        neural_memory_layers = default(neural_memory_layers, layers)
 
         # mem, attn, and feedforward layers
 
@@ -535,9 +532,10 @@ class MemoryAsContextTransformer(Module):
             )
 
             mem = None
+            mem_hyper_conn = None
 
             if layer in neural_memory_layers:
-                assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'
+                mem_hyper_conn = init_hyper_conn(dim = dim, add_branch_out_to_residual = not neural_mem_gate_attn_output)
 
                 mem = NeuralMemory(
                     dim = dim,
@@ -545,10 +543,12 @@ class MemoryAsContextTransformer(Module):
                     **neural_memory_kwargs
                 )
 
+
             ff = FeedForward(dim = dim, mult = ff_mult)
 
             self.layers.append(ModuleList([
-                init_hyper_conn(dim = dim, branch = mem, add_branch_out_to_residual = not neural_mem_gate_attn_output) if exists(mem) else None,
+                mem_hyper_conn,
+                mem,
                 init_hyper_conn(dim = dim, branch = attn),
                 init_hyper_conn(dim = dim, branch = ff)
             ]))
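Each layer's `ModuleList` now stores the memory's hyper connection and the `NeuralMemory` module as separate entries (instead of wrapping the memory in `init_hyper_conn` once), and `neural_memory_layers` simply defaults to every layer. A minimal construction sketch, assuming README-style constructor keywords (`num_tokens`, `segment_len`, `num_persist_mem_tokens`, `num_longterm_mem_tokens`, `return_loss`), which may differ slightly between versions:

    import torch
    from titans_pytorch import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 64,
        depth = 4,
        segment_len = 32,
        num_persist_mem_tokens = 4,
        num_longterm_mem_tokens = 4,
        # neural_memory_layers now defaults to all layers (1..depth),
        # no longer tied to num_longterm_mem_tokens being positive
    )

    ids = torch.randint(0, 256, (1, 512))
    loss = transformer(ids, return_loss = True)
    loss.backward()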
@@ -691,8 +691,18 @@ class MemoryAsContextTransformer(Module):
         # kv caching
 
         is_inferencing = exists(cache)
-        cache = iter(default(cache, []))
+        assert not (is_inferencing and self.num_longterm_mem_tokens > 0)
+
+        if not exists(cache):
+            cache = (None, None)
+
+        kv_caches, neural_mem_caches = cache
+
+        kv_caches = iter(default(kv_caches, []))
+        neural_mem_caches = iter(default(neural_mem_caches, []))
+
         next_kv_caches = []
+        next_neural_mem_caches = []
 
         # value residual
 
@@ -711,21 +721,37 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for mem, attn, ff in self.layers:
+        for mem_hyper_conn, mem, attn, ff in self.layers:
 
             retrieved = None
             attn_out_gates = None
+            next_neural_mem_cache = None
 
             # maybe neural memory
 
             if exists(mem):
-                retrieved, mem_kv_aux_loss = mem(x, return_aux_kv_loss = True)
-                kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
+
+                mem_input, add_residual = mem_hyper_conn(x)
+
+                if not is_inferencing:
+                    retrieved, mem_kv_aux_loss = mem(
+                        mem_input,
+                        return_aux_kv_loss = True
+                    )
+
+                    kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
+
+                    next_neural_mem_cache = (seq_len, None, None, None)
+                else:
+                    retrieved, next_neural_mem_cache = mem.forward_inference(
+                        mem_input,
+                        state = next(neural_mem_caches, None)
+                    )
 
                 if self.gate_attn_output:
                     attn_out_gates = retrieved.sigmoid()
                 else:
-                    seq = retrieved
+                    x = add_residual(retrieved)
 
             # attention
 
@@ -735,12 +761,15 @@ class MemoryAsContextTransformer(Module):
                 disable_flex_attn = disable_flex_attn,
                 flex_attn_fn = flex_attn_fn,
                 output_gating = attn_out_gates,
-                cache = next(cache, None)
+                cache = next(kv_caches, None)
             )
 
             value_residual = default(value_residual, values)
 
+            # caches
+
             next_kv_caches.append(next_kv_cache)
+            next_neural_mem_caches.append(next_neural_mem_cache)
 
             # feedforward
 
@@ -775,7 +804,7 @@ class MemoryAsContextTransformer(Module):
             if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
                 next_kv_caches = next_kv_caches[..., 0:0, :]
 
-            return logits, next_kv_caches
+            return logits, (next_kv_caches, next_neural_mem_caches)
 
         ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
 
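With the two-part cache, `forward` now returns `logits` together with `(next_kv_caches, next_neural_mem_caches)`, and the new assert means the cached inference path currently requires `num_longterm_mem_tokens = 0`. A rough greedy-decoding sketch, assuming `forward` accepts a `cache` keyword and returns `(logits, cache)` whenever a loss is not requested (the exact guard is outside this hunk); the decoding loop itself is hypothetical, not the library's sampling helper:

    import torch

    @torch.no_grad()
    def greedy_decode(transformer, prompt_ids, steps = 32):
        # prompt_ids: (batch, seq) token ids; transformer built with num_longterm_mem_tokens = 0
        out = prompt_ids
        cache = None

        for _ in range(steps):
            inp = out if cache is None else out[:, -1:]        # full prompt once, then one token at a time
            logits, cache = transformer(inp, cache = cache)    # cache = (kv caches, neural memory states)

            next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
            out = torch.cat((out, next_token), dim = -1)

        return out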
titans_pytorch/titans.py CHANGED
@@ -324,15 +324,26 @@ class AssocScan(Module):
         super().__init__()
         self.use_accelerated = use_accelerated
 
-    def forward(self, gates, inputs, prev = None):
+    def forward(
+        self,
+        gates,
+        inputs,
+        prev = None,
+        remove_prev = None
+    ):
+        remove_prev = default(remove_prev, exists(prev))
 
         if exists(prev):
             inputs, _ = pack([prev, inputs], 'b * d')
             gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
 
         if not self.use_accelerated:
-            _, outputs = associative_scan(binary_operator, (gates, inputs))
-            return outputs
+            _, out = associative_scan(binary_operator, (gates, inputs))
+
+            if remove_prev:
+                out = out[:, 1:]
+
+            return out
 
         from accelerated_scan.triton import scan as triton_scan
         from accelerated_scan.warp import scan as warp_scan
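`AssocScan.forward` gains a carried state: `prev` is packed in as an extra leading position under a gate of 1, and `remove_prev` (defaulting to `exists(prev)`) strips that seed position so the output length matches `inputs` again. A naive reference of the recurrence being computed, for intuition only (the module itself uses `associative_scan` or `accelerated_scan` for the same first-order recurrence):

    import torch
    import torch.nn.functional as F

    def naive_assoc_scan(gates, inputs, prev = None, remove_prev = None):
        # gates, inputs: (batch, seq, dim); prev: (batch, dim) or None
        # computes out[t] = gates[t] * out[t - 1] + inputs[t] along the sequence
        remove_prev = (prev is not None) if remove_prev is None else remove_prev

        if prev is not None:
            # seed the scan with prev under a gate of 1 (mirrors pack + pad_at_dim above)
            inputs = torch.cat((prev.unsqueeze(1), inputs), dim = 1)
            gates = F.pad(gates, (0, 0, 1, 0), value = 1.)

        carry = torch.zeros_like(inputs[:, 0])
        outs = []

        for t in range(inputs.shape[1]):
            carry = gates[:, t] * carry + inputs[:, t]
            outs.append(carry)

        out = torch.stack(outs, dim = 1)

        if remove_prev:
            out = out[:, 1:]    # drop the seeded position, as in the diff

        return out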
@@ -355,7 +366,12 @@ class AssocScan(Module):
             outputs = rearrange(outputs, 'b d n -> b n d')
             return outputs
 
-        return accelerate_scan_fn(gates, inputs)
+        out = accelerate_scan_fn(gates, inputs)
+
+        if remove_prev:
+            out = out[:, 1:]
+
+        return out
 
 # main neural memory
 
@@ -384,7 +400,6 @@ class NeuralMemory(Module):
         post_rmsnorm = True,
         qk_rmsnorm = False,
         accept_value_residual = False,
-        learned_mem_model_weights = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         activation: Module | None = None,
@@ -432,9 +447,6 @@ class NeuralMemory(Module):
         if not exists(model):
             model = MemoryMLP(dim_head, **default_model_kwargs)
 
-        if not learned_mem_model_weights:
-            model.requires_grad_(False)
-
         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
 
         # the memory is the weights of the model
@@ -536,16 +548,9 @@ class NeuralMemory(Module):
 
         self.register_buffer('zero', torch.tensor(0.), persistent = False)
 
-    def init_weights_and_momentum(self, zero_weights = False):
-        params = TensorDict(dict(self.memory_model.named_parameters()))
-
-        init_weights = params
-        init_momentum = params.clone().zero_()
-
-        if zero_weights:
-            init_weights = params.clone().zero_()
-
-        return init_weights, init_momentum
+    def init_weights(self):
+        weights = TensorDict(dict(self.memory_model.named_parameters()))
+        return weights
 
     def init_empty_memory_embed(self, batch, seq_len):
         return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
@@ -553,7 +558,8 @@ class NeuralMemory(Module):
     def store_memories(
         self,
         seq,
-        past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
+        weights: dict[str, Tensor],
+        past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
         return_aux_kv_loss = False,
         chunk_size = None,
         value_residual = None
@@ -565,8 +571,7 @@ class NeuralMemory(Module):
         # handle edge case
 
         if seq_len < chunk_size:
-            past_weight, _ = past_state
-            return TensorDict(past_weight).clone().zero_(), self.zero
+            return TensorDict(weights).clone().zero_(), self.zero
 
         seq = self.store_norm(seq)
 
@@ -577,10 +582,9 @@ class NeuralMemory(Module):
 
         seq = seq[:, :round_down_seq_len]
 
-        # get the weights of the memory network
+        # weights of the memory network
 
-        past_state = tuple(TensorDict(d) for d in past_state)
-        curr_weights, past_momentum = past_state
+        weights = TensorDict(weights)
 
         # derive learned hparams for optimization of memory network
 
@@ -630,7 +634,7 @@ class NeuralMemory(Module):
 
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads, aux_kv_recon_loss = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
+        grads, aux_kv_recon_loss = self.per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
 
         grads = TensorDict(grads)
 
@@ -652,12 +656,23 @@ class NeuralMemory(Module):
 
         surprises = grads.apply(lambda t: -t)
 
+        # past states
+
+        if not exists(past_state):
+            empty_dict = {key: None for key in weights.keys()}
+            past_state = (empty_dict, empty_dict)
+
+        past_last_update, past_last_momentum = past_state
+
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
 
         next_momentum = TensorDict() if has_momentum else None
         updates = TensorDict()
 
-        for param_name, surprise in surprises.items():
+        next_last_update = TensorDict()
+        next_last_momentum = TensorDict()
+
+        for (param_name, surprise), (_, last_update), (_, last_momentum) in zip(surprises.items(), past_last_update.items(), past_last_momentum.items()):
 
             surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
 
@@ -666,23 +681,27 @@ class NeuralMemory(Module):
             # derive momentum with associative scan - eq (10)
 
             if has_momentum:
-                update = self.assoc_scan(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+                update = self.assoc_scan(adaptive_momentum, surprise, prev = last_momentum) # momentum is S / surprise in the paper
                 momentum = update
+                next_last_momentum[param_name] = momentum[:, -1]
 
             # use associative scan again for learned forgetting (weight decay) - eq (13)
 
-            update = self.assoc_scan(1. - decay_factor, update)
+            update = self.assoc_scan(1. - decay_factor, update, prev = last_update)
+            next_last_update[param_name] = update[:, -1]
 
             updates[param_name] = inverse_pack(update)
 
             if has_momentum:
                 next_momentum[param_name] = inverse_pack(momentum)
 
-        # compute the next weight per batch
+        # compute next states for inference, or titans-xl like training
 
-        last_update = updates.apply(lambda t: t[:, -1])
+        next_state = (next_last_update, next_last_momentum)
 
-        output = (updates, orig_values)
+        # returns
+
+        output = (updates, next_state, orig_values)
 
         if not return_aux_kv_loss:
             return output
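`store_memories` now takes the memory-model `weights` explicitly plus an optional `past_state`, and returns `(updates, next_state, values)`, where `next_state = (next_last_update, next_last_momentum)` seeds the associative scans of the next call. A rough sketch of threading that state across chunks (the "titans-xl like training" the comment refers to), assuming a small hypothetical `NeuralMemory(dim = 64, chunk_size = 16)`; the methods used are those shown in this diff:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 64, chunk_size = 16)       # hypothetical sizes

    seq = torch.randn(2, 64, 64)                        # (batch, seq, dim), seq divisible by chunk_size
    weights = mem.init_weights()                        # TensorDict over the memory model parameters

    state = None

    for chunk in seq.split(16, dim = 1):
        # each call consumes one chunk and hands (last update, last momentum) to the next
        updates, state, values = mem.store_memories(chunk, weights, past_state = state)

    # retrieve for the final chunk against the composed weights, mirroring `weights + updates` above
    retrieved = mem.retrieve_memories(seq[:, -16:], weights + updates, chunk_size = 16)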
@@ -764,21 +783,25 @@ class NeuralMemory(Module):
     def forward_inference(
         self,
         token: Tensor,
-        seq_index = None, # the index of the token in the sequence, starts at 0
-        mem_model_state = None,
-        cache_store_seq = None
+        state = None,
     ):
-        seq_index = default(seq_index, 0)
+
+        # unpack previous state
+
+        if not exists(state):
+            state = (0, None, None, None)
+
+        seq_index, cache_store_seq, past_states, updates = state
+
         curr_seq_len = seq_index + 1
         batch = token.shape[0]
 
         if token.ndim == 2:
             token = rearrange(token, 'b d -> b 1 d')
 
-        # init memory model if needed
+        # get memory model weights
 
-        if not exists(mem_model_state):
-            mem_model_state = self.init_weights_and_momentum()
+        weights = self.init_weights()
 
         # increment the sequence cache which is at most the chunk size
 
@@ -789,32 +812,43 @@ class NeuralMemory(Module):
         if curr_seq_len < self.chunk_size:
             empty_mem = self.init_empty_memory_embed(batch, 1)
 
-            return empty_mem, cache_store_seq, mem_model_state
+            return empty_mem, (curr_seq_len, cache_store_seq, past_states, updates)
 
         # store if storage sequence cache hits the chunk size
 
+        next_states = past_states
         store_seq_cache_len = cache_store_seq.shape[-2]
 
+        if not exists(updates):
+            updates = weights.clone().zero_()
+            updates = updates.apply(lambda t: repeat(t, '... -> b 1 ...', b = batch))
+
         if store_seq_cache_len == self.chunk_size:
-            updates, _ = self.store_memories(cache_store_seq, mem_model_state)
 
-            past_weights, past_momentum = mem_model_state
-            mem_model_state = (past_weights + updates, past_momentum)
+            next_updates, next_states, _ = self.store_memories(
+                cache_store_seq,
+                weights,
+                past_state = past_states
+            )
 
+            updates = next_updates
             cache_store_seq = None
 
         # retrieve
 
-        past_weights, _ = mem_model_state
+        retrieved = self.retrieve_memories(token, updates + weights, chunk_size = 1)
+
+        # next state tuple
 
-        retrieved = self.retrieve_memories(token, past_weights, chunk_size = 1)
+        next_state = (curr_seq_len, cache_store_seq, next_states, updates)
 
-        return retrieved, cache_store_seq, mem_model_state
+        return retrieved, next_state
 
     def forward(
         self,
         seq,
         store_seq = None,
+        mem_model_weights: dict[str, Tensor] | None = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
         return_aux_kv_loss = False,
         chunk_size = None,
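`forward_inference` now threads one opaque state tuple `(seq_index, cache_store_seq, past_states, updates)` in place of three separate arguments, and fetches the memory weights itself via `init_weights()`. A minimal token-by-token sketch with the same hypothetical sizes as above:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 64, chunk_size = 16)       # hypothetical sizes

    tokens = torch.randn(2, 64, 64)                     # (batch, seq, dim) of already-embedded tokens
    state = None                                        # unpacked as (seq_index, cache_store_seq, past_states, updates)

    for t in range(tokens.shape[1]):
        retrieved, state = mem.forward_inference(tokens[:, t], state = state)
        # `retrieved` stays the empty memory embed until a full chunk has been cached and stored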
@@ -831,20 +865,15 @@ class NeuralMemory(Module):
 
             return out, self.zero
 
-        if exists(past_state):
-            past_state = tuple(TensorDict(d) for d in past_state)
-
-        if not exists(past_state):
-            past_state = self.init_weights_and_momentum()
+        if not exists(mem_model_weights):
+            mem_model_weights = self.init_weights()
 
         store_seq = default(store_seq, seq)
         store_chunk_size = default(store_chunk_size, chunk_size)
 
-        (updates, values), aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)
-
-        past_weights, _ = past_state
+        (updates, next_state, values), aux_kv_recon_loss = self.store_memories(store_seq, mem_model_weights, chunk_size = store_chunk_size, return_aux_kv_loss = True)
 
-        retrieved = self.retrieve_memories(seq, past_weights + updates, chunk_size = chunk_size)
+        retrieved = self.retrieve_memories(seq, mem_model_weights + updates, chunk_size = chunk_size)
 
         output = retrieved
 
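At the top level, `forward` now accepts an optional `mem_model_weights` (defaulting to `init_weights()`) in place of the old `past_state` initialization, and retrieval runs against `mem_model_weights + updates`. A short usage sketch in the README style, with hypothetical sizes:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 384, chunk_size = 64)

    seq = torch.randn(2, 1024, 384)

    retrieved = mem(seq)                                                   # weights default to init_weights() internally

    retrieved, aux_kv_recon_loss = mem(seq, return_aux_kv_loss = True)     # auxiliary kv reconstruction loss still available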
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.23
+Version: 0.1.27
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,7 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
-Requires-Dist: axial-positional-embedding>=0.3.9
+Requires-Dist: axial-positional-embedding>=0.3.10
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.9
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=Staf9hRQ44QAL23bSGh4VSB8NeGtMri-JdiZdgirJiU,23587
+titans_pytorch/titans.py,sha256=gjoDcTsvw5X2d1I2xq4cM45YJIBqtLFuws8_jVylW_4,25746
+titans_pytorch-0.1.27.dist-info/METADATA,sha256=AZ5-_d9o_khm6jaky1zoKyXB1hDQNifbS061v_b4McQ,6815
+titans_pytorch-0.1.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.27.dist-info/RECORD,,
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=Ejq1r3GQQnlT1Fo4McaOOie19t1HjwVlYbD90GLQCYI,22859
-titans_pytorch/titans.py,sha256=WbagKMYDs-3NoW2j_pAyHEnvR9QzH3A9WntHuV_FKOo,25109
-titans_pytorch-0.1.23.dist-info/METADATA,sha256=H7QbLscawNObHGeoTbnKbf-NOqkMqWCu4yWeZJ0yKMA,6814
-titans_pytorch-0.1.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.23.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.23.dist-info/RECORD,,