titans-pytorch 0.1.31__py3-none-any.whl → 0.1.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
-from titans_pytorch.titans import (
+from titans_pytorch.neural_memory import (
     NeuralMemory,
     MemoryMLP,
     MemoryAttention,
@@ -65,7 +65,7 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions

 # proposed neural memory

-from titans_pytorch.titans import NeuralMemory
+from titans_pytorch.neural_memory import NeuralMemory

 # constants

@@ -106,7 +106,11 @@ def pad_at_dim(t, pad, dim = -1, value = 0.):
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)

-def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
+def pad_and_segment_with_inverse(
+    seq,
+    segment_len,
+    fold_into_batch = True,
+):
     batch, seq_len = seq.shape[:2]
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)

@@ -119,11 +123,15 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
     if fold_into_batch:
         seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)

-    def inverse(out, remove_pad = True):
+    shape = seq.shape
+
+    def inverse(out):
+        unchanged_shape = out.shape == shape
+
         if fold_into_batch:
             out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)

-        if needs_pad and remove_pad:
+        if needs_pad and unchanged_shape:
             out = out[..., :-padding, :]

         return out
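The `remove_pad` flag is gone: `inverse` now decides for itself whether to strip the right-padding, by checking whether the tensor it is given still has the shape of the segmented sequence. A self-contained toy version of the idea (not the library function, which also handles extra leading dimensions) is sketched below; it is also why the `inverse_segment(x, remove_pad = False)` call sites later in this diff become plain `inverse_segment(x)`.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def toy_pad_and_segment(seq, segment_len):
    batch, seq_len, _ = seq.shape
    padding = -seq_len % segment_len
    needs_pad = padding > 0

    if needs_pad:
        seq = F.pad(seq, (0, 0, 0, padding))  # right-pad the sequence dimension

    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
    shape = seq.shape

    def inverse(out):
        # strip the padding only if `out` still matches the segmented shape
        unchanged_shape = out.shape == shape
        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)

        if needs_pad and unchanged_shape:
            out = out[:, :-padding, :]

        return out

    return seq, inverse

seq = torch.randn(2, 10, 4)
segments, inverse = toy_pad_and_segment(seq, segment_len = 8)  # 10 tokens padded to 16

assert segments.shape == (4, 8, 4)
assert inverse(segments).shape == seq.shape    # same shape in, padding removed

wider = torch.randn(4, 12, 4)                  # e.g. extra memory tokens packed in
assert inverse(wider).shape == (2, 24, 4)      # shape changed, padding kept
```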
@@ -690,7 +698,7 @@ class MemoryAsContextTransformer(Module):
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
         x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')

-        x = inverse_segment(x, remove_pad = False)
+        x = inverse_segment(x)

         # splice out unneeded tokens from padding for longterm mems

@@ -759,14 +767,13 @@ class MemoryAsContextTransformer(Module):
             mem_input, add_residual = mem_hyper_conn(x)

             if not is_inferencing:
-                retrieved, mem_kv_aux_loss = mem(
+                (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
                     mem_input,
                     return_aux_kv_loss = True
                 )

                 kv_recon_losses = kv_recon_losses + mem_kv_aux_loss

-                next_neural_mem_cache = (seq_len, None, None, None)

             else:
                 retrieved, next_neural_mem_cache = mem.forward_inference(
                     mem_input,
@@ -836,7 +843,7 @@ class MemoryAsContextTransformer(Module):

         x, _ = inverse_pack_mems(x)

-        x = inverse_segment(x, remove_pad = False)
+        x = inverse_segment(x)

         x = x[:, :seq_len]

@@ -38,8 +38,11 @@ LinearNoBias = partial(Linear, bias = False)
 def exists(v):
     return v is not None

-def default(v, d):
-    return v if exists(v) else d
+def default(*args):
+    for arg in args:
+        if exists(arg):
+            return arg
+    return None

 def xnor(x, y):
     return not (x ^ y)
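For reference, a minimal standalone sketch of the new helper's behavior: `default` now scans any number of candidates and returns the first that is not `None`, which is what allows fallback chains such as `default(store_chunk_size, chunk_size, self.store_chunk_size)` further down in this diff.

```python
def exists(v):
    return v is not None

def default(*args):
    # return the first argument that exists (is not None), else None
    for arg in args:
        if exists(arg):
            return arg
    return None

assert default(None, 4) == 4
assert default(None, None, 8) == 8
assert default(None, None) is None
```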
@@ -468,7 +471,12 @@ class NeuralMemory(Module):
             weighted_loss = loss * loss_weights
             return weighted_loss.sum(), weighted_loss.mean()

-        self.per_sample_grad_fn = vmap(grad(forward_and_loss, has_aux = True), in_dims = (None, 0, 0, 0))
+        # two per-sample grad functions - one sharing a single set of weights across the batch, one expecting the weights themselves to carry a batch dimension
+
+        grad_fn = grad(forward_and_loss, has_aux = True)
+
+        self.per_sample_grad_fn = vmap(grad_fn, in_dims = (None, 0, 0, 0))
+        self.per_sample_grad_fn_expanded_weights = vmap(grad_fn, in_dims = (0,) * 4)

         # queries for retrieving from the model

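A toy sketch of what the two `vmap` signatures mean, using a stand-in objective rather than the library's actual memory model: with `in_dims = (None, 0, 0, 0)` a single weight dict is broadcast across the batch, while `(0,) * 4` expects the weights themselves to carry a leading batch dimension, as they do once per-layer updates are folded in (see the `store_memories` change below).

```python
import torch
from torch.func import grad, vmap

def forward_and_loss_toy(weights, keys, lr, values):
    # toy stand-in for the memory model's per-chunk objective
    pred = keys @ weights['w']
    weighted_loss = ((pred - values) ** 2) * lr
    return weighted_loss.sum(), weighted_loss.mean()

grad_fn = grad(forward_and_loss_toy, has_aux = True)

batch, n, dim = 4, 16, 8
keys = torch.randn(batch, n, dim)
values = torch.randn(batch, n, dim)
lr = torch.rand(batch)

# shared weights: one dict, broadcast across the batch
shared_weights = dict(w = torch.randn(dim, dim))
grads, aux = vmap(grad_fn, in_dims = (None, 0, 0, 0))(shared_weights, keys, lr, values)
assert grads['w'].shape == (batch, dim, dim)

# expanded weights: every batch element carries its own copy of the weights
expanded_weights = dict(w = torch.randn(batch, dim, dim))
grads, aux = vmap(grad_fn, in_dims = (0,) * 4)(expanded_weights, keys, lr, values)
assert grads['w'].shape == (batch, dim, dim)
```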
@@ -561,6 +569,7 @@ class NeuralMemory(Module):
         seq,
         weights: dict[str, Tensor],
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
+        prev_layer_updates: dict[str, Tensor] | None = None,
         return_aux_kv_loss = False,
         chunk_size = None,
         value_residual = None
@@ -583,10 +592,25 @@ class NeuralMemory(Module):

         seq = seq[:, :round_down_seq_len]

+        # per sample grad function
+
+        per_sample_grad_fn = self.per_sample_grad_fn
+
         # weights of the memory network

         weights = TensorDict(weights)

+        # allow the neural memory of a previous layer (and of the past) to produce the gradients that become the weights of the current one generating the surprise
+        # think this is necessary, otherwise the memory model is static (unless the paper is being misunderstood)
+        # improvise (or perhaps correct to) a solution
+
+        if exists(prev_layer_updates):
+            prev_layer_updates = TensorDict(prev_layer_updates)
+
+            weights = weights + prev_layer_updates
+
+            per_sample_grad_fn = self.per_sample_grad_fn_expanded_weights # the weights will now have a leading (batch * chunk) dimension
+
         # derive learned hparams for optimization of memory network

         adaptive_lr = self.to_adaptive_step(seq)
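A small standalone sketch of why the expanded-weights grad function becomes necessary at this point (plain dict-of-tensor arithmetic standing in for the `TensorDict` addition above; shapes are illustrative): once the previous layer's updates, which carry a leading batch-and-chunk dimension, are added to the base weights, every sample effectively has its own weights.

```python
import torch

dim, batch_chunks = 8, 6

# base memory weights: no leading batch dimension
weights = dict(w1 = torch.zeros(dim, dim))

# updates arriving from a previous neural memory layer: one set per (batch * chunk)
prev_layer_updates = dict(w1 = torch.randn(batch_chunks, dim, dim))

# broadcasting the sum gives every (batch * chunk) element its own effective weights,
# so the per-sample grad function must now also map over the weight dimension
effective = {name: weights[name] + prev_layer_updates[name] for name in weights}
assert effective['w1'].shape == (batch_chunks, dim, dim)
```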
@@ -635,7 +659,7 @@ class NeuralMemory(Module):

         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

-        grads, aux_kv_recon_loss = self.per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
+        grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)

         grads = TensorDict(grads)

@@ -781,6 +805,7 @@ class NeuralMemory(Module):

         return values[:, :seq_len]

+    @torch.no_grad()
     def forward_inference(
         self,
         token: Tensor,
@@ -854,13 +879,18 @@ class NeuralMemory(Module):
         return_aux_kv_loss = False,
         chunk_size = None,
         store_chunk_size = None,
-        return_values = False
+        return_values = False,
+        return_next_state = False
     ):
         batch, seq_len = seq.shape[:2]

         if seq_len < self.retrieve_chunk_size:
             out = self.init_empty_memory_embed(batch, seq_len)

+            next_store_state = (seq_len, seq, None, None)
+
+            out = (out, next_store_state)
+
             if not return_aux_kv_loss:
                 return out

@@ -870,16 +900,31 @@ class NeuralMemory(Module):
         mem_model_weights = self.init_weights()

         store_seq = default(store_seq, seq)
-        store_chunk_size = default(store_chunk_size, chunk_size)
+
+        store_seq_len = store_seq.shape[-2]
+        store_chunk_size = default(store_chunk_size, chunk_size, self.store_chunk_size)
+        remainder = store_seq_len % store_chunk_size

         (updates, next_state, values), aux_kv_recon_loss = self.store_memories(store_seq, mem_model_weights, chunk_size = store_chunk_size, return_aux_kv_loss = True)

         retrieved = self.retrieve_memories(seq, mem_model_weights + updates, chunk_size = chunk_size)

-        output = retrieved
+        # determine state for the storing of memories
+        # for transformer-xl like training with neural memory as well as inferencing with initial prompt
+
+        cache_store_seq = None
+
+        if remainder > 0:
+            cache_store_seq = store_seq[:, -remainder:]
+
+        updates = updates.apply(lambda t: t[:, -1:])
+
+        next_store_state = (seq_len, cache_store_seq, next_state, updates)
+
+        output = (retrieved, next_store_state)

         if return_values:
-            output = (retrieved, values)
+            output = (*output, values)

         if not return_aux_kv_loss:
             return output
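A standalone illustration of the caching arithmetic introduced here, with toy tensors (the update shape below assumes a per-chunk dimension at dim 1, which is what the `[:, -1:]` slice implies): tokens beyond the last full store chunk are kept as `cache_store_seq` so a later call can finish storing them, and only the final per-chunk update travels with the store state.

```python
import torch

store_chunk_size = 64
store_seq = torch.randn(2, 1024 + 13, 384)      # 13 tokens past the last full chunk

remainder = store_seq.shape[-2] % store_chunk_size
cache_store_seq = store_seq[:, -remainder:] if remainder > 0 else None

assert remainder == 13
assert cache_store_seq.shape == (2, 13, 384)

# per-parameter updates carry one entry per chunk; only the last is kept,
# mirroring `updates.apply(lambda t: t[:, -1:])`
update_for_one_param = torch.randn(2, 16, 64, 64)  # (batch, chunks, ...)
carried = update_for_one_param[:, -1:]
assert carried.shape == (2, 1, 64, 64)
```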
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.31
+Version: 0.1.32
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -82,7 +82,7 @@ mem = NeuralMemory(
 ).cuda()

 seq = torch.randn(2, 1024, 384).cuda()
-retrieved = mem(seq)
+retrieved, mem_state = mem(seq)

 assert seq.shape == retrieved.shape
 ```
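A hedged usage sketch of the new two-value return shown in this README hunk. The constructor arguments below are assumptions standing in for the README's earlier setup, which this hunk does not show, and `.cuda()` is dropped so the snippet runs on CPU.

```python
import torch
from titans_pytorch import NeuralMemory

# assumed constructor arguments, for illustration only
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64
)

seq = torch.randn(2, 1024, 384)

# as of 0.1.32, forward also returns the neural memory state
retrieved, mem_state = mem(seq)

assert seq.shape == retrieved.shape
```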
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=rMT99CPQFH4Gudp0FmVPWGKfhBf6xksGEaFEcOVdqjs,230
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=Cui-hCl6X4UVGmuyoKCSKWbag9Yrc-a2MrfVkHM-z0A,24828
+titans_pytorch/neural_memory.py,sha256=Vfo1z1VztPDDXgFxjkiyOP29daDE7KTdnZeWXifvCJI,27456
+titans_pytorch-0.1.32.dist-info/METADATA,sha256=_HPPht8nhLwH9GzLyZI-fh8JBSEoSxkENCSU2xuU_6A,6826
+titans_pytorch-0.1.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.32.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.32.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=pKFRL_ISoHEKUyfssKwfBfwFO2eQN9objJmxLrNsYrU,24838
4
- titans_pytorch/titans.py,sha256=6B8ioP26RTja5kVFMsorAnM9CcxIUySJS9RZBlDPI2s,25825
5
- titans_pytorch-0.1.31.dist-info/METADATA,sha256=9ejOFuH2B2-yCRFK4x_C1DONPxecW8VcjEUeRh9OzXg,6815
6
- titans_pytorch-0.1.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.1.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.1.31.dist-info/RECORD,,