titans-pytorch 0.1.31__py3-none-any.whl → 0.1.33__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- from titans_pytorch.titans import (
1
+ from titans_pytorch.neural_memory import (
2
2
  NeuralMemory,
3
3
  MemoryMLP,
4
4
  MemoryAttention,
@@ -65,7 +65,7 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
65
65
 
66
66
  # proposed neural memory
67
67
 
68
- from titans_pytorch.titans import NeuralMemory
68
+ from titans_pytorch.neural_memory import NeuralMemory
69
69
 
70
70
  # constants
71
71
 
@@ -106,7 +106,11 @@ def pad_at_dim(t, pad, dim = -1, value = 0.):
106
106
  zeros = ((0, 0) * dims_from_right)
107
107
  return F.pad(t, (*zeros, *pad), value = value)
108
108
 
109
- def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
109
+ def pad_and_segment_with_inverse(
110
+ seq,
111
+ segment_len,
112
+ fold_into_batch = True,
113
+ ):
110
114
  batch, seq_len = seq.shape[:2]
111
115
  next_seq_len_mult = round_up_multiple(seq_len, segment_len)
112
116
 
@@ -119,11 +123,15 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
119
123
  if fold_into_batch:
120
124
  seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
121
125
 
122
- def inverse(out, remove_pad = True):
126
+ shape = seq.shape
127
+
128
+ def inverse(out):
129
+ unchanged_shape = out.shape == shape
130
+
123
131
  if fold_into_batch:
124
132
  out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)
125
133
 
126
- if needs_pad and remove_pad:
134
+ if needs_pad and unchanged_shape:
127
135
  out = out[..., :-padding, :]
128
136
 
129
137
  return out
@@ -690,7 +698,7 @@ class MemoryAsContextTransformer(Module):
690
698
  mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
691
699
  x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')
692
700
 
693
- x = inverse_segment(x, remove_pad = False)
701
+ x = inverse_segment(x)
694
702
 
695
703
  # splice out unneeded tokens from padding for longterm mems
696
704
 
@@ -759,14 +767,13 @@ class MemoryAsContextTransformer(Module):
759
767
  mem_input, add_residual = mem_hyper_conn(x)
760
768
 
761
769
  if not is_inferencing:
762
- retrieved, mem_kv_aux_loss = mem(
770
+ (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
763
771
  mem_input,
764
772
  return_aux_kv_loss = True
765
773
  )
766
774
 
767
775
  kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
768
776
 
769
- next_neural_mem_cache = (seq_len, None, None, None)
770
777
  else:
771
778
  retrieved, next_neural_mem_cache = mem.forward_inference(
772
779
  mem_input,
@@ -836,7 +843,7 @@ class MemoryAsContextTransformer(Module):
836
843
 
837
844
  x, _ = inverse_pack_mems(x)
838
845
 
839
- x = inverse_segment(x, remove_pad = False)
846
+ x = inverse_segment(x)
840
847
 
841
848
  x = x[:, :seq_len]
842
849
 
@@ -3,6 +3,7 @@ from typing import Callable
3
3
 
4
4
  import math
5
5
  from functools import partial
6
+ from collections import namedtuple
6
7
 
7
8
  import torch
8
9
  from torch import nn, cat, Tensor
@@ -33,13 +34,18 @@ w - num memory network weight parameters
33
34
 
34
35
  LinearNoBias = partial(Linear, bias = False)
35
36
 
37
+ NeuralMemCache = namedtuple('NeuralMemCache', ['seq', 'cache_store_segment', 'states', 'updates'])
38
+
36
39
  # functions
37
40
 
38
41
  def exists(v):
39
42
  return v is not None
40
43
 
41
- def default(v, d):
42
- return v if exists(v) else d
44
+ def default(*args):
45
+ for arg in args:
46
+ if exists(arg):
47
+ return arg
48
+ return None
43
49
 
44
50
  def xnor(x, y):
45
51
  return not (x ^ y)
@@ -468,7 +474,12 @@ class NeuralMemory(Module):
468
474
  weighted_loss = loss * loss_weights
469
475
  return weighted_loss.sum(), weighted_loss.mean()
470
476
 
471
- self.per_sample_grad_fn = vmap(grad(forward_and_loss, has_aux = True), in_dims = (None, 0, 0, 0))
477
+ # two functions
478
+
479
+ grad_fn = grad(forward_and_loss, has_aux = True)
480
+
481
+ self.per_sample_grad_fn = vmap(grad_fn, in_dims = (None, 0, 0, 0))
482
+ self.per_sample_grad_fn_expanded_weights = vmap(grad_fn, in_dims = (0,) * 4)
472
483
 
473
484
  # queries for retrieving from the model
474
485
 
@@ -561,6 +572,7 @@ class NeuralMemory(Module):
561
572
  seq,
562
573
  weights: dict[str, Tensor],
563
574
  past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
575
+ prev_layer_updates: dict[str, Tensor] | None = None,
564
576
  return_aux_kv_loss = False,
565
577
  chunk_size = None,
566
578
  value_residual = None
@@ -583,10 +595,25 @@ class NeuralMemory(Module):
583
595
 
584
596
  seq = seq[:, :round_down_seq_len]
585
597
 
598
+ # per sample grad function
599
+
600
+ per_sample_grad_fn = self.per_sample_grad_fn
601
+
586
602
  # weights of the memory network
587
603
 
588
604
  weights = TensorDict(weights)
589
605
 
606
+ # allow for neural memory of a previous layer and the past to produce gradients that become the weights of the current one generating the surprise
607
+ # think this is necessary otherwise the memory model is static (unless if paper is misunderstood)
608
+ # improvise (or perhaps correcting to) a solution
609
+
610
+ if exists(prev_layer_updates):
611
+ prev_layer_updates = TensorDict(prev_layer_updates)
612
+
613
+ weights = weights + prev_layer_updates
614
+
615
+ per_sample_grad_fn = self.per_sample_grad_fn_expanded_weights # the weights will now have a batch * chunk dimension
616
+
590
617
  # derive learned hparams for optimization of memory network
591
618
 
592
619
  adaptive_lr = self.to_adaptive_step(seq)
@@ -633,9 +660,14 @@ class NeuralMemory(Module):
633
660
 
634
661
  adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
635
662
 
663
+ # flatten batch and time if surprise depends on previous layer memory model
664
+
665
+ if exists(prev_layer_updates):
666
+ weights = weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
667
+
636
668
  # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
637
669
 
638
- grads, aux_kv_recon_loss = self.per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
670
+ grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
639
671
 
640
672
  grads = TensorDict(grads)
641
673
 
@@ -713,7 +745,8 @@ class NeuralMemory(Module):
713
745
  self,
714
746
  seq,
715
747
  past_weights: dict[str, Tensor],
716
- chunk_size = None
748
+ chunk_size = None,
749
+ prev_layer_updates: dict[str, Tensor] | None = None
717
750
  ):
718
751
  chunk_size = default(chunk_size, self.retrieve_chunk_size)
719
752
  batch, seq_len = seq.shape[:2]
@@ -736,6 +769,9 @@ class NeuralMemory(Module):
736
769
 
737
770
  curr_weights = TensorDict(past_weights)
738
771
 
772
+ if exists(prev_layer_updates):
773
+ curr_weights = curr_weights + TensorDict(prev_layer_updates)
774
+
739
775
  # sequence Float['b n d'] to queries
740
776
 
741
777
  queries = self.to_queries(seq)
@@ -781,6 +817,7 @@ class NeuralMemory(Module):
781
817
 
782
818
  return values[:, :seq_len]
783
819
 
820
+ @torch.no_grad()
784
821
  def forward_inference(
785
822
  self,
786
823
  token: Tensor,
@@ -813,7 +850,7 @@ class NeuralMemory(Module):
813
850
  if curr_seq_len < self.chunk_size:
814
851
  empty_mem = self.init_empty_memory_embed(batch, 1)
815
852
 
816
- return empty_mem, (curr_seq_len, cache_store_seq, past_states, updates)
853
+ return empty_mem, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)
817
854
 
818
855
  # store if storage sequence cache hits the chunk size
819
856
 
@@ -823,6 +860,8 @@ class NeuralMemory(Module):
823
860
  if not exists(updates):
824
861
  updates = weights.clone().zero_()
825
862
  updates = updates.apply(lambda t: repeat(t, '... -> b 1 ...', b = batch))
863
+ else:
864
+ updates = updates.apply(lambda t: t[:, -1:])
826
865
 
827
866
  if store_seq_cache_len == self.chunk_size:
828
867
 
@@ -841,7 +880,7 @@ class NeuralMemory(Module):
841
880
 
842
881
  # next state tuple
843
882
 
844
- next_state = (curr_seq_len, cache_store_seq, next_states, updates)
883
+ next_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)
845
884
 
846
885
  return retrieved, next_state
847
886
 
@@ -854,13 +893,19 @@ class NeuralMemory(Module):
854
893
  return_aux_kv_loss = False,
855
894
  chunk_size = None,
856
895
  store_chunk_size = None,
857
- return_values = False
896
+ return_values = False,
897
+ return_next_state = False,
898
+ prev_layer_updates: dict[str, Tensor] | None = None
858
899
  ):
859
900
  batch, seq_len = seq.shape[:2]
860
901
 
861
902
  if seq_len < self.retrieve_chunk_size:
862
903
  out = self.init_empty_memory_embed(batch, seq_len)
863
904
 
905
+ next_store_state = (seq_len, seq, None, None)
906
+
907
+ out = (out, next_store_state)
908
+
864
909
  if not return_aux_kv_loss:
865
910
  return out
866
911
 
@@ -869,17 +914,45 @@ class NeuralMemory(Module):
869
914
  if not exists(mem_model_weights):
870
915
  mem_model_weights = self.init_weights()
871
916
 
917
+ # store
918
+
872
919
  store_seq = default(store_seq, seq)
873
- store_chunk_size = default(store_chunk_size, chunk_size)
874
920
 
875
- (updates, next_state, values), aux_kv_recon_loss = self.store_memories(store_seq, mem_model_weights, chunk_size = store_chunk_size, return_aux_kv_loss = True)
921
+ store_seq_len = store_seq.shape[-2]
922
+ store_chunk_size = default(store_chunk_size, chunk_size, self.store_chunk_size)
923
+ remainder = store_seq_len % store_chunk_size
924
+
925
+ (updates, next_state, values), aux_kv_recon_loss = self.store_memories(
926
+ store_seq,
927
+ mem_model_weights,
928
+ chunk_size = store_chunk_size,
929
+ prev_layer_updates = prev_layer_updates,
930
+ return_aux_kv_loss = True
931
+ )
932
+
933
+ # retrieve
934
+
935
+ retrieved = self.retrieve_memories(
936
+ seq,
937
+ mem_model_weights + updates,
938
+ chunk_size = chunk_size,
939
+ prev_layer_updates = prev_layer_updates
940
+ )
941
+
942
+ # determine state for the storing of memories
943
+ # for transformer-xl like training with neural memory as well as inferencing with initial prompt
944
+
945
+ cache_store_seq = None
946
+
947
+ if remainder > 0:
948
+ cache_store_seq = store_seq[:, -remainder:]
876
949
 
877
- retrieved = self.retrieve_memories(seq, mem_model_weights + updates, chunk_size = chunk_size)
950
+ next_store_state = NeuralMemCache(seq_len, cache_store_seq, next_state, updates)
878
951
 
879
- output = retrieved
952
+ output = (retrieved, next_store_state)
880
953
 
881
954
  if return_values:
882
- output = (retrieved, values)
955
+ output = (*output, values)
883
956
 
884
957
  if not return_aux_kv_loss:
885
958
  return output
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.1.31
3
+ Version: 0.1.33
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -82,7 +82,7 @@ mem = NeuralMemory(
82
82
  ).cuda()
83
83
 
84
84
  seq = torch.randn(2, 1024, 384).cuda()
85
- retrieved = mem(seq)
85
+ retrieved, mem_state = mem(seq)
86
86
 
87
87
  assert seq.shape == retrieved.shape
88
88
  ```
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=rMT99CPQFH4Gudp0FmVPWGKfhBf6xksGEaFEcOVdqjs,230
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=Cui-hCl6X4UVGmuyoKCSKWbag9Yrc-a2MrfVkHM-z0A,24828
4
+ titans_pytorch/neural_memory.py,sha256=9dXpSaQYomc-ur-nEwej1nG9M5NqS0c3LBBP9jUIMPU,28352
5
+ titans_pytorch-0.1.33.dist-info/METADATA,sha256=A9BBoe0Sas2kxUcUi7w_RFl8-SIF1TLzPIRGuZlauFM,6826
6
+ titans_pytorch-0.1.33.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.1.33.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.1.33.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=pKFRL_ISoHEKUyfssKwfBfwFO2eQN9objJmxLrNsYrU,24838
4
- titans_pytorch/titans.py,sha256=6B8ioP26RTja5kVFMsorAnM9CcxIUySJS9RZBlDPI2s,25825
5
- titans_pytorch-0.1.31.dist-info/METADATA,sha256=9ejOFuH2B2-yCRFK4x_C1DONPxecW8VcjEUeRh9OzXg,6815
6
- titans_pytorch-0.1.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.1.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.1.31.dist-info/RECORD,,