titans-pytorch 0.1.30__py3-none-any.whl → 0.1.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- from titans_pytorch.titans import (
1
+ from titans_pytorch.neural_memory import (
2
2
  NeuralMemory,
3
3
  MemoryMLP,
4
4
  MemoryAttention,
@@ -65,7 +65,7 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
65
65
 
66
66
  # proposed neural memory
67
67
 
68
- from titans_pytorch.titans import NeuralMemory
68
+ from titans_pytorch.neural_memory import NeuralMemory
69
69
 
70
70
  # constants
71
71
 
@@ -106,7 +106,11 @@ def pad_at_dim(t, pad, dim = -1, value = 0.):
106
106
  zeros = ((0, 0) * dims_from_right)
107
107
  return F.pad(t, (*zeros, *pad), value = value)
108
108
 
109
- def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
109
+ def pad_and_segment_with_inverse(
110
+ seq,
111
+ segment_len,
112
+ fold_into_batch = True,
113
+ ):
110
114
  batch, seq_len = seq.shape[:2]
111
115
  next_seq_len_mult = round_up_multiple(seq_len, segment_len)
112
116
 
@@ -119,11 +123,15 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
119
123
  if fold_into_batch:
120
124
  seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
121
125
 
122
- def inverse(out, remove_pad = True):
126
+ shape = seq.shape
127
+
128
+ def inverse(out):
129
+ unchanged_shape = out.shape == shape
130
+
123
131
  if fold_into_batch:
124
132
  out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)
125
133
 
126
- if needs_pad and remove_pad:
134
+ if needs_pad and unchanged_shape:
127
135
  out = out[..., :-padding, :]
128
136
 
129
137
  return out
@@ -582,13 +590,8 @@ class MemoryAsContextTransformer(Module):
582
590
  self,
583
591
  seq_index
584
592
  ):
585
- total_segment_len = self.attn_window_size
586
-
587
- seq = seq_index + 1
588
- seq -= int((seq % total_segment_len) == 0)
589
- last_segment_len = round_down_multiple(seq, total_segment_len)
590
- segment_seq = seq - last_segment_len
591
- return (segment_seq - self.segment_len) > 0
593
+ total_segment_len, segment_len = self.attn_window_size, self.segment_len
594
+ return ((seq_index % total_segment_len + 1) - segment_len) > 0
592
595
 
593
596
  def seq_len_with_longterm_mem(
594
597
  self,
@@ -597,7 +600,7 @@ class MemoryAsContextTransformer(Module):
597
600
  assert seq_len > 0
598
601
 
599
602
  segment_len, num_mem = self.segment_len, self.num_longterm_mem_tokens
600
- return ceil(seq_len / segment_len) * num_mem + seq_len
603
+ return ((seq_len - 1) // segment_len) * num_mem + seq_len
601
604
 
602
605
  @torch.no_grad()
603
606
  def sample(
@@ -695,7 +698,7 @@ class MemoryAsContextTransformer(Module):
695
698
  mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
696
699
  x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')
697
700
 
698
- x = inverse_segment(x, remove_pad = False)
701
+ x = inverse_segment(x)
699
702
 
700
703
  # splice out unneeded tokens from padding for longterm mems
701
704
 
@@ -723,9 +726,9 @@ class MemoryAsContextTransformer(Module):
723
726
  is_inferencing = exists(cache)
724
727
 
725
728
  if not exists(cache):
726
- cache = (None, None)
729
+ cache = (seq_len_with_mem - 1, None, None)
727
730
 
728
- kv_caches, neural_mem_caches = cache
731
+ inference_seq_index, kv_caches, neural_mem_caches = cache
729
732
 
730
733
  kv_caches = iter(default(kv_caches, []))
731
734
  neural_mem_caches = iter(default(neural_mem_caches, []))
@@ -744,7 +747,8 @@ class MemoryAsContextTransformer(Module):
744
747
  # when inferencing, only do one token at a time
745
748
 
746
749
  if is_inferencing:
747
- x = x[:, -1:]
750
+ ind = inference_seq_index
751
+ x = x[:, ind:(ind + 1)]
748
752
 
749
753
  # expand and reduce streams for hyper connections
750
754
 
@@ -763,14 +767,13 @@ class MemoryAsContextTransformer(Module):
763
767
  mem_input, add_residual = mem_hyper_conn(x)
764
768
 
765
769
  if not is_inferencing:
766
- retrieved, mem_kv_aux_loss = mem(
770
+ (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
767
771
  mem_input,
768
772
  return_aux_kv_loss = True
769
773
  )
770
774
 
771
775
  kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
772
776
 
773
- next_neural_mem_cache = (seq_len, None, None, None)
774
777
  else:
775
778
  retrieved, next_neural_mem_cache = mem.forward_inference(
776
779
  mem_input,
@@ -817,6 +820,17 @@ class MemoryAsContextTransformer(Module):
817
820
  if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
818
821
  next_kv_caches = next_kv_caches[..., 0:0, :]
819
822
 
823
+ next_cache = (
824
+ inference_seq_index + 1,
825
+ next_kv_caches,
826
+ next_neural_mem_caches
827
+ )
828
+
829
+ is_longterm_mem = self.seq_index_is_longterm(inference_seq_index)
830
+
831
+ if is_inferencing and is_longterm_mem:
832
+ return None, next_cache
833
+
820
834
  # hyper connection reducing of streams
821
835
 
822
836
  x = self.reduce_streams(x)
@@ -843,7 +857,7 @@ class MemoryAsContextTransformer(Module):
843
857
  if not return_cache:
844
858
  return logits
845
859
 
846
- return logits, (next_kv_caches, next_neural_mem_caches)
860
+ return logits, next_cache
847
861
 
848
862
  ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
849
863
 
@@ -38,8 +38,11 @@ LinearNoBias = partial(Linear, bias = False)
38
38
  def exists(v):
39
39
  return v is not None
40
40
 
41
- def default(v, d):
42
- return v if exists(v) else d
41
+ def default(*args):
42
+ for arg in args:
43
+ if exists(arg):
44
+ return arg
45
+ return None
43
46
 
44
47
  def xnor(x, y):
45
48
  return not (x ^ y)
@@ -468,7 +471,12 @@ class NeuralMemory(Module):
468
471
  weighted_loss = loss * loss_weights
469
472
  return weighted_loss.sum(), weighted_loss.mean()
470
473
 
471
- self.per_sample_grad_fn = vmap(grad(forward_and_loss, has_aux = True), in_dims = (None, 0, 0, 0))
474
+ # two functions
475
+
476
+ grad_fn = grad(forward_and_loss, has_aux = True)
477
+
478
+ self.per_sample_grad_fn = vmap(grad_fn, in_dims = (None, 0, 0, 0))
479
+ self.per_sample_grad_fn_expanded_weights = vmap(grad_fn, in_dims = (0,) * 4)
472
480
 
473
481
  # queries for retrieving from the model
474
482
 
@@ -561,6 +569,7 @@ class NeuralMemory(Module):
561
569
  seq,
562
570
  weights: dict[str, Tensor],
563
571
  past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
572
+ prev_layer_updates: dict[str, Tensor] | None = None,
564
573
  return_aux_kv_loss = False,
565
574
  chunk_size = None,
566
575
  value_residual = None
@@ -583,10 +592,25 @@ class NeuralMemory(Module):
583
592
 
584
593
  seq = seq[:, :round_down_seq_len]
585
594
 
595
+ # per sample grad function
596
+
597
+ per_sample_grad_fn = self.per_sample_grad_fn
598
+
586
599
  # weights of the memory network
587
600
 
588
601
  weights = TensorDict(weights)
589
602
 
603
+ # allow for neural memory of a previous layer and the past to produce gradients that become the weights of the current one generating the surprise
604
+ # this seems necessary, otherwise the memory model is static (unless the paper has been misunderstood)
605
+ # improvising (or perhaps correcting to) a solution
606
+
607
+ if exists(prev_layer_updates):
608
+ prev_layer_updates = TensorDict(weights)
609
+
610
+ weights = weights + prev_layer_updates
611
+
612
+ per_sample_grad_fn = self.per_sample_grad_fn_expanded_weights # the weights will now have a batch * chunk dimension
613
+
590
614
  # derive learned hparams for optimization of memory network
591
615
 
592
616
  adaptive_lr = self.to_adaptive_step(seq)
@@ -635,7 +659,7 @@ class NeuralMemory(Module):
635
659
 
636
660
  # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
637
661
 
638
- grads, aux_kv_recon_loss = self.per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
662
+ grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
639
663
 
640
664
  grads = TensorDict(grads)
641
665
 
@@ -781,6 +805,7 @@ class NeuralMemory(Module):
781
805
 
782
806
  return values[:, :seq_len]
783
807
 
808
+ @torch.no_grad()
784
809
  def forward_inference(
785
810
  self,
786
811
  token: Tensor,
@@ -854,13 +879,18 @@ class NeuralMemory(Module):
854
879
  return_aux_kv_loss = False,
855
880
  chunk_size = None,
856
881
  store_chunk_size = None,
857
- return_values = False
882
+ return_values = False,
883
+ return_next_state = False
858
884
  ):
859
885
  batch, seq_len = seq.shape[:2]
860
886
 
861
887
  if seq_len < self.retrieve_chunk_size:
862
888
  out = self.init_empty_memory_embed(batch, seq_len)
863
889
 
890
+ next_store_state = (seq_len, seq, None, None)
891
+
892
+ out = (out, next_store_state)
893
+
864
894
  if not return_aux_kv_loss:
865
895
  return out
866
896
 
@@ -870,16 +900,31 @@ class NeuralMemory(Module):
870
900
  mem_model_weights = self.init_weights()
871
901
 
872
902
  store_seq = default(store_seq, seq)
873
- store_chunk_size = default(store_chunk_size, chunk_size)
903
+
904
+ store_seq_len = store_seq.shape[-2]
905
+ store_chunk_size = default(store_chunk_size, chunk_size, self.store_chunk_size)
906
+ remainder = store_seq_len % store_chunk_size
874
907
 
875
908
  (updates, next_state, values), aux_kv_recon_loss = self.store_memories(store_seq, mem_model_weights, chunk_size = store_chunk_size, return_aux_kv_loss = True)
876
909
 
877
910
  retrieved = self.retrieve_memories(seq, mem_model_weights + updates, chunk_size = chunk_size)
878
911
 
879
- output = retrieved
912
+ # determine state for the storing of memories
913
+ # for transformer-xl like training with neural memory as well as inferencing with initial prompt
914
+
915
+ cache_store_seq = None
916
+
917
+ if remainder > 0:
918
+ cache_store_seq = store_seq[:, -remainder:]
919
+
920
+ updates = updates.apply(lambda t: t[:, -1:])
921
+
922
+ next_store_state = (seq_len, cache_store_seq, next_state, updates)
923
+
924
+ output = (retrieved, next_store_state)
880
925
 
881
926
  if return_values:
882
- output = (retrieved, values)
927
+ output = (*output, values)
883
928
 
884
929
  if not return_aux_kv_loss:
885
930
  return output
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.1.30
3
+ Version: 0.1.32
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -82,7 +82,7 @@ mem = NeuralMemory(
82
82
  ).cuda()
83
83
 
84
84
  seq = torch.randn(2, 1024, 384).cuda()
85
- retrieved = mem(seq)
85
+ retrieved, mem_state = mem(seq)
86
86
 
87
87
  assert seq.shape == retrieved.shape
88
88
  ```
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=rMT99CPQFH4Gudp0FmVPWGKfhBf6xksGEaFEcOVdqjs,230
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=Cui-hCl6X4UVGmuyoKCSKWbag9Yrc-a2MrfVkHM-z0A,24828
4
+ titans_pytorch/neural_memory.py,sha256=Vfo1z1VztPDDXgFxjkiyOP29daDE7KTdnZeWXifvCJI,27456
5
+ titans_pytorch-0.1.32.dist-info/METADATA,sha256=_HPPht8nhLwH9GzLyZI-fh8JBSEoSxkENCSU2xuU_6A,6826
6
+ titans_pytorch-0.1.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.1.32.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.1.32.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=RRLdVa8z-2IWbhhmRGfoNBycwaL32aMbpqutzmSQqpc,24575
4
- titans_pytorch/titans.py,sha256=6B8ioP26RTja5kVFMsorAnM9CcxIUySJS9RZBlDPI2s,25825
5
- titans_pytorch-0.1.30.dist-info/METADATA,sha256=o5flkZ0hNhZE06bSKVEFpbrkhuWB9putcaL_MZ0sJHA,6815
6
- titans_pytorch-0.1.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.1.30.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.1.30.dist-info/RECORD,,