titans-pytorch 0.2.1__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
titans_pytorch/mac_transformer.py

@@ -536,11 +536,6 @@ class MemoryAsContextTransformer(Module):
         self.weight_tie_memory_model = weight_tie_memory_model
         self.prev_neural_mem_update_for_weights = default(prev_neural_mem_update_for_weights, weight_tie_memory_model)

-        # value residual learning for neural memory
-
-        is_first_mem = True
-        self.mem_add_value_residual = neural_memory_add_value_residual
-
         # mem, attn, and feedforward layers

         for layer in layers:
@@ -570,12 +565,9 @@ class MemoryAsContextTransformer(Module):
                     dim = dim,
                     chunk_size = self.neural_memory_segment_len,
                     model = maybe_copy(neural_memory_model),
-                    accept_value_residual = not is_first_mem and neural_memory_add_value_residual,
                     **neural_memory_kwargs
                 )

-                is_first_mem = False
-
             ff = FeedForward(dim = dim, mult = ff_mult)

             self.layers.append(ModuleList([
@@ -765,8 +757,6 @@ class MemoryAsContextTransformer(Module):

         value_residual = None

-        mem_value_residual = None
-
         # aux losses

         kv_recon_losses = self.zero
@@ -794,28 +784,21 @@ class MemoryAsContextTransformer(Module):
                 mem_input, add_residual = mem_hyper_conn(x)

                 if not is_inferencing:
-                    (retrieved, next_neural_mem_cache, next_mem_value_residual), mem_kv_aux_loss = mem(
+                    (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
                         mem_input,
                         return_aux_kv_loss = True,
-                        return_values = True,
-                        value_residual = mem_value_residual,
                         prev_layer_updates = neural_memory_updates
                     )

                     kv_recon_losses = kv_recon_losses + mem_kv_aux_loss

                 else:
-                    (retrieved, next_neural_mem_cache, next_mem_value_residual) = mem.forward_inference(
+                    (retrieved, next_neural_mem_cache) = mem.forward_inference(
                         mem_input,
                         state = next(neural_mem_caches, None),
-                        return_values = True,
-                        value_residual = mem_value_residual,
                         prev_layer_updates = neural_memory_updates
                     )

-                if self.mem_add_value_residual:
-                    mem_value_residual = next_mem_value_residual
-
                 if prev_neural_mem_update_for_weights:
                     neural_memory_updates = next_neural_mem_cache.updates


titans_pytorch/neural_memory.py

@@ -67,6 +67,9 @@ def safe_cat(inputs, dim = -2):
 def identity(t):
     return t

+def dict_get_shape(td):
+    return {k: v.shape for k, v in td.items()}
+
 def pair(v):
     return (v, v) if not isinstance(v, tuple) else v

@@ -258,7 +261,6 @@ class NeuralMemory(Module):
         pre_rmsnorm = True,
         post_rmsnorm = True,
         qk_rmsnorm = False,
-        accept_value_residual = False,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         activation: Module | None = None,
@@ -315,6 +317,8 @@ class NeuralMemory(Module):

         self.num_memory_parameter_tensors = len(set(model.parameters()))

+        self.init_weight_shape = dict_get_shape(dict(model.named_parameters()))
+
         # the chunk size within the paper where adaptive step, momentum, weight decay are shared

         self.chunk_size = chunk_size
@@ -343,19 +347,6 @@ class NeuralMemory(Module):
         self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
         self.store_memory_loss_fn = store_memory_loss_fn

-        # value residual learning
-
-        self.learned_value_residual = Sequential(
-            LinearNoBias(dim, heads),
-            Rearrange('b n h -> b h n 1'),
-            nn.Sigmoid()
-        ) if accept_value_residual else None
-
-        # empty memory embed
-
-        self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
-        nn.init.normal_(self.empty_memory_embed, std = 0.02)
-
         # `chunk_size` refers to chunk size used for storing to memory model weights

         chunk_size = self.store_chunk_size
@@ -417,9 +408,6 @@ class NeuralMemory(Module):
         weights = TensorDict(dict(self.memory_model.named_parameters()))
         return weights

-    def init_empty_memory_embed(self, batch, seq_len):
-        return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
-
     def store_memories(
         self,
         seq,
@@ -428,10 +416,7 @@ class NeuralMemory(Module):
         prev_layer_updates: dict[str, Tensor] | None = None,
         return_aux_kv_loss = False,
         chunk_size = None,
-        value_residual = None
     ):
-        assert xnor(exists(value_residual), exists(self.learned_value_residual))
-
         seq_len, heads, chunk_size = seq.shape[-2], self.heads, default(chunk_size, self.store_chunk_size)

         # handle edge case
@@ -446,7 +431,7 @@ class NeuralMemory(Module):

         round_down_seq_len = round_down_multiple(seq_len, chunk_size)

-        seq = seq[:, :round_down_seq_len]
+        seq, remainder = seq[:, :round_down_seq_len], seq[:, round_down_seq_len:]

         # per sample grad function

@@ -499,14 +484,6 @@ class NeuralMemory(Module):

         keys = self.k_norm(keys)

-        # maybe value residual learning
-
-        orig_values = values
-
-        if exists(self.learned_value_residual):
-            mix = self.learned_value_residual(seq)
-            values = values.lerp(value_residual, mix)
-
         # take care of chunking

         keys, values = tuple(rearrange(t, 'b h (n c) d -> (b h n) c d', c = chunk_size) for t in (keys, values))
@@ -581,13 +558,15 @@ class NeuralMemory(Module):
             if has_momentum:
                 next_momentum[param_name] = inverse_pack(momentum)

-        # compute next states for inference, or titans-xl like training
+        # determine next state for the storing of memories

         next_state = (next_last_update, next_last_momentum)

+        next_store_state = NeuralMemCache(seq_len, remainder, next_state, updates)
+
         # returns

-        output = (updates, next_state, orig_values)
+        output = (updates, next_store_state)

         if not return_aux_kv_loss:
             return output
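
The `NeuralMemCache` returned above now carries the leftover tokens that did not fill a complete chunk. A minimal sketch of that split, assuming a stand-in helper and a plain tuple in place of the package's `NeuralMemCache` namedtuple:

```python
# illustrative sketch, not part of the package
import torch

def round_down_multiple(n, mult):
    return n // mult * mult

chunk_size = 4
seq = torch.randn(1, 10, 8)                                                 # 10 tokens, chunk size 4

round_down_seq_len = round_down_multiple(seq.shape[-2], chunk_size)         # 8
seq, remainder = seq[:, :round_down_seq_len], seq[:, round_down_seq_len:]   # (1, 8, 8) and (1, 2, 8)

# the remainder rides along in the returned cache (stand-in for NeuralMemCache)
# until enough further tokens arrive to complete the next chunk
next_store_state = (seq.shape[-2], remainder)
```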
@@ -606,16 +585,18 @@ class NeuralMemory(Module):

         seq = self.retrieve_norm(seq)

-        if seq_len < chunk_size:
-            return self.init_empty_memory_embed(batch, seq_len)
+        assert seq_len >= chunk_size, 'must be handled outside of retrieve'
+
+        needs_pad = chunk_size > 1

-        seq = seq[:, (chunk_size - 1):]
-        curtailed_seq_len = seq.shape[-2]
+        if needs_pad:
+            seq = pad_at_dim(seq, (1, 0), dim = 1)
+            seq_len_plus_one = seq.shape[-2]

-        next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)
+            next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)

-        padding = next_seq_len - curtailed_seq_len
-        seq = pad_at_dim(seq, (0, padding), dim = 1)
+            padding = next_seq_len - seq_len_plus_one
+            seq = pad_at_dim(seq, (0, padding), dim = 1)

         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
@@ -639,7 +620,9 @@ class NeuralMemory(Module):

         # fetch values from memory model

-        curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+        if dict_get_shape(curr_weights) != self.init_weight_shape:
+            curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+
         queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)

         # forward functional call
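
A minimal sketch of the shape check above, assuming plain dicts of tensors in place of a `TensorDict`: comparing against the initial parameter shapes tells the retrieval path whether the weights have already gained a per-chunk leading dimension and need flattening.

```python
# illustrative sketch, not part of the package
import torch

def dict_get_shape(td):
    return {k: v.shape for k, v in td.items()}

init_weights = {'w1': torch.zeros(64, 64), 'w2': torch.zeros(64, 64)}
init_weight_shape = dict_get_shape(init_weights)

# hypothetical per-chunk weights with an extra leading (batch * num_chunks) dimension
chunked_weights = {k: v.unsqueeze(0).expand(8, -1, -1) for k, v in init_weights.items()}

assert dict_get_shape(init_weights) == init_weight_shape        # no rearrange needed
assert dict_get_shape(chunked_weights) != init_weight_shape     # triggers the '(b n) ...' flatten
```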
@@ -665,10 +648,10 @@ class NeuralMemory(Module):

         # restore, pad with empty memory embed

-        empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
-        values = torch.cat((empty_memory_embeds, values), dim = -2)
+        if needs_pad:
+            values = values[:, 1:(seq_len + 1)]

-        return values[:, :seq_len]
+        return values

     @torch.no_grad()
     def forward_inference(
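
A minimal sketch of the pad-and-slice arithmetic used in the two hunks above when `chunk_size > 1`, assuming torch's built-in padding in place of the library's `pad_at_dim` helper: shift the sequence right by one position, round its length up to a multiple of `chunk_size`, then drop the shifted-in position after retrieval.

```python
# illustrative sketch, not part of the package
import torch
import torch.nn.functional as F

def round_up_multiple(n, mult):
    return -(-n // mult) * mult

chunk_size, seq_len = 4, 10
seq = torch.randn(1, seq_len, 8)

seq = F.pad(seq, (0, 0, 1, 0))                               # one position of left padding on the sequence dim
seq_len_plus_one = seq.shape[-2]                             # 11
padding = round_up_multiple(seq_len_plus_one, chunk_size) - seq_len_plus_one
seq = F.pad(seq, (0, 0, 0, padding))                         # pad the tail up to 12, a multiple of chunk_size

values = seq                                                 # stand-in for the retrieved values
values = values[:, 1:(seq_len + 1)]                          # drop the shifted-in position, keep seq_len tokens
assert values.shape[-2] == seq_len
```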
@@ -676,8 +659,6 @@ class NeuralMemory(Module):
         token: Tensor,
         state = None,
         prev_layer_updates: dict[str, Tensor] | None = None,
-        return_values = False,
-        value_residual = None,
     ):

         # unpack previous state
@@ -704,12 +685,9 @@ class NeuralMemory(Module):
         # early return empty memory, when no memories are stored for steps < first chunk size

         if curr_seq_len < self.chunk_size:
-            empty_mem = self.init_empty_memory_embed(batch, 1)
-
-            output = empty_mem, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)
+            retrieve = self.retrieve_memories(token, weights, chunk_size = 1)

-            if return_values:
-                output = (*output, self.zero)
+            output = retrieve, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)

             return output

@@ -728,20 +706,18 @@ class NeuralMemory(Module):
             prev_layer_updates = TensorDict(prev_layer_updates)
             prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])

-        values = None
-
         if store_seq_cache_len == self.chunk_size:

-            next_updates, next_states, values = self.store_memories(
+            next_updates, store_state = self.store_memories(
                 cache_store_seq,
                 weights,
                 past_state = past_states,
                 prev_layer_updates = prev_layer_updates,
-                value_residual = value_residual
             )

             updates = next_updates
             cache_store_seq = None
+            next_states = store_state.states

         # retrieve

@@ -749,14 +725,9 @@ class NeuralMemory(Module):

         # next state tuple

-        next_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)
-
-        output = (retrieved, next_state)
-
-        if return_values:
-            output = (*output, values)
+        next_store_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)

-        return output
+        return retrieved, next_store_state

     def forward(
         self,
@@ -767,50 +738,45 @@ class NeuralMemory(Module):
         return_aux_kv_loss = False,
         chunk_size = None,
         store_chunk_size = None,
-        return_values = False,
-        value_residual = None,
         return_next_state = False,
         prev_layer_updates: dict[str, Tensor] | None = None
     ):
         batch, seq_len = seq.shape[:2]

+        if not exists(mem_model_weights):
+            mem_model_weights = self.init_weights()
+
         if seq_len < self.retrieve_chunk_size:
-            out = self.init_empty_memory_embed(batch, seq_len)
+            retrieved = self.retrieve_memories(seq, mem_model_weights, chunk_size = 1)

             next_store_state = NeuralMemCache(seq_len, seq, None, None)

-            out = (out, next_store_state)
-
-            if return_values:
-                out = (*out, self.zero)
+            out = (retrieved, next_store_state)

             if not return_aux_kv_loss:
                 return out

             return out, self.zero

-        if not exists(mem_model_weights):
-            mem_model_weights = self.init_weights()
-
         # store

         store_seq = default(store_seq, seq)

-        store_seq_len = store_seq.shape[-2]
-        store_chunk_size = default(store_chunk_size, chunk_size, self.store_chunk_size)
-        remainder = store_seq_len % store_chunk_size
-
-        (updates, next_state, values), aux_kv_recon_loss = self.store_memories(
+        (updates, next_store_state), aux_kv_recon_loss = self.store_memories(
             store_seq,
             mem_model_weights,
             chunk_size = store_chunk_size,
             prev_layer_updates = prev_layer_updates,
-            value_residual = value_residual,
             return_aux_kv_loss = True
         )

         # retrieve

+        if exists(prev_layer_updates):
+            prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+
+        updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+
         retrieved = self.retrieve_memories(
             seq,
             mem_model_weights + updates,
@@ -818,21 +784,8 @@ class NeuralMemory(Module):
             prev_layer_updates = prev_layer_updates
         )

-        # determine state for the storing of memories
-        # for transformer-xl like training with neural memory as well as inferencing with initial prompt
-
-        cache_store_seq = None
-
-        if remainder > 0:
-            cache_store_seq = store_seq[:, -remainder:]
-
-        next_store_state = NeuralMemCache(seq_len, cache_store_seq, next_state, updates)
-
         output = (retrieved, next_store_state)

-        if return_values:
-            output = (*output, values)
-
         if not return_aux_kv_loss:
             return output


titans_pytorch-0.2.1.dist-info/METADATA → titans_pytorch-0.2.4.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.1
+Version: 0.2.4
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

titans_pytorch-0.2.4.dist-info/RECORD

@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=g-Rx8zwTUbMv-XBYWPe9abFVVSUFLxOn_yVQ-wWvG5M,26039
+titans_pytorch/memory_models.py,sha256=LI9T36XB6YXIvvGWRw0ZMDlGpRC6KIv03OPzME2VAaU,3772
+titans_pytorch/neural_memory.py,sha256=3ykFukUDp3dW1QwDmS3jZ2wFysiZE2ippcOoMFall34,24143
+titans_pytorch-0.2.4.dist-info/METADATA,sha256=2yY3d58zPQ1uyvnTX4Dml7a2dd2jRu3TR5NhBpPNmdY,6819
+titans_pytorch-0.2.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.2.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.2.4.dist-info/RECORD,,

titans_pytorch-0.2.1.dist-info/RECORD

@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=kqW90mpbFf1ZJ_mMkd6v9EQ5J__TwKMPy5cjHJF_26A,26742
-titans_pytorch/memory_models.py,sha256=LI9T36XB6YXIvvGWRw0ZMDlGpRC6KIv03OPzME2VAaU,3772
-titans_pytorch/neural_memory.py,sha256=vmKPOAlXBPXBnYPODrg_reWaIcr1xwtfQmuptGS6e5A,25559
-titans_pytorch-0.2.1.dist-info/METADATA,sha256=HPdcQb4SlT-eLFzOYLMwGInEKegL4M4yIpKWt1a6DTs,6819
-titans_pytorch-0.2.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.2.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.2.1.dist-info/RECORD,,