titans-pytorch 0.2.10__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/titans_pytorch/mac_transformer.py
+++ b/titans_pytorch/mac_transformer.py
@@ -481,6 +481,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_add_value_residual = False,
         num_longterm_mem_tokens = 0,
         num_persist_mem_tokens = 0,
+        neural_memory_batch_size = None,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -551,6 +552,7 @@ class MemoryAsContextTransformer(Module):
             mem = NeuralMemory(
                 dim = dim,
                 chunk_size = self.neural_memory_segment_len,
+                batch_size = neural_memory_batch_size,
                 model = deepcopy(neural_memory_model),
                 **neural_memory_kwargs
             )
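The new neural_memory_batch_size kwarg is forwarded into each NeuralMemory block as batch_size. A minimal sketch of constructing the transformer with it; the other argument values are illustrative (taken from the project README, not this diff), and the relationship between the neural memory chunk size and segment_len + num_longterm_mem_tokens is an assumption:

    import torch
    from titans_pytorch import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 128,
        num_persist_mem_tokens = 4,
        num_longterm_mem_tokens = 16,
        neural_memory_batch_size = 288,  # new in 0.2.11; None preserves the old behavior.
                                         # assumed to need to be a multiple of the neural memory
                                         # store chunk size (here segment_len + num_longterm_mem_tokens = 144)
    )

    logits = transformer(torch.randint(0, 256, (1, 1023)))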
--- a/titans_pytorch/neural_memory.py
+++ b/titans_pytorch/neural_memory.py
@@ -6,7 +6,7 @@ from functools import partial
 from collections import namedtuple
 
 import torch
-from torch import nn, cat, Tensor
+from torch import nn, cat, tensor, Tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, Parameter, ParameterList
 from torch.func import functional_call, vmap, grad
@@ -39,7 +39,7 @@ w - num memory network weight parameters
 LinearNoBias = partial(Linear, bias = False)
 
 NeuralMemCache = namedtuple('NeuralMemCache', [
-    'seq',
+    'seq_index',
     'weights',
     'cache_store_segment',
    'states',
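For reference, the NeuralMemCache(...) call sites later in this diff show the cache carrying five fields ('updates' being the fifth, beyond the hunk context above). A hypothetical unpacking under that assumption:

    # field order inferred from the constructor calls below in this diff
    seq_index, weights, cache_store_segment, states, updates = state

    state.seq_index  # previously state.seq -- downstream code reading .seq must be updated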
@@ -63,6 +63,9 @@ def identity(t):
 def xnor(x, y):
     return not (x ^ y)
 
+def divisible_by(num, den):
+    return (num % den) == 0
+
 def safe_cat(inputs, dim = -2):
     inputs = tuple(filter(exists, inputs))
 
@@ -73,6 +76,9 @@ def safe_cat(inputs, dim = -2):
 
     return cat(inputs, dim = dim)
 
+def is_empty_tensor(t):
+    return t.numel() == 0
+
 def dict_get_shape(td):
     return {k: v.shape for k, v in td.items()}
 
@@ -118,7 +124,7 @@ def softclamp_max(t, max_value):
     return ((t / half_max_value).tanh() * half_max_value) + half_max_value
 
 def softclamp_grad_norm(t, max_value):
-    if t.numel() == 0:
+    if is_empty_tensor(t):
         return t
 
     t, inverse = pack_one_with_inverse(t, 'bn *')
@@ -270,6 +276,7 @@ class NeuralMemory(Module):
         self,
         dim,
         chunk_size: int | tuple[int, int] = 1,
+        batch_size = None,
         dim_head = None,
         heads = 1,
         model: Module | None = None,
@@ -296,6 +303,13 @@ class NeuralMemory(Module):
 
         self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)
 
+        # batch size
+
+        if exists(batch_size):
+            assert divisible_by(batch_size, self.store_chunk_size)
+
+        self.batch_size = batch_size
+
         # associative scan
 
         self.assoc_scan = AssocScan(use_accelerated = use_accelerated_scan)
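A sketch of constructing NeuralMemory directly with the new argument; the dim and chunk_size values are illustrative. Per the assert above, batch_size must be a multiple of the store chunk size:

    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(
        dim = 384,
        chunk_size = 64,   # store (and retrieve) chunk size
        batch_size = 128,  # ok: divisible_by(128, 64); batch_size = 100 would trip the assert
    )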
@@ -460,9 +474,9 @@ class NeuralMemory(Module):
         seq,
         weights: dict[str, Tensor] | None = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-        chunk_size = None,
+        seq_index = 0
     ):
-        batch, seq_len, heads, chunk_size = *seq.shape[:2], self.heads, default(chunk_size, self.store_chunk_size)
+        batch, seq_len, heads, chunk_size = *seq.shape[:2], self.heads, self.store_chunk_size
 
         # curtail sequence by multiple of the chunk size
         # only a complete chunk of the sequence provides the memory for the next chunk
@@ -472,6 +486,8 @@ class NeuralMemory(Module):
 
         seq, remainder = seq[:, :round_down_seq_len], seq[:, round_down_seq_len:]
 
+        next_seq_len_index = seq_index + round_down_seq_len
+
         # init weights if needed
         # weights of the memory network
 
@@ -568,7 +584,7 @@ class NeuralMemory(Module):
 
         if num_chunks == 0:
             updates = rearrange_dict_values(weights, 'bh ... -> bh 1 ...')
-            next_store_state = NeuralMemCache(seq_len, weights, remainder, past_state, updates)
+            next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, past_state, updates)
 
             output = (updates, next_store_state)
 
@@ -607,7 +623,7 @@ class NeuralMemory(Module):
 
         next_state = (next_last_update, next_last_momentum)
 
-        next_store_state = NeuralMemCache(seq_len, weights, remainder, next_state, updates)
+        next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, next_state, updates)
 
         # returns
 
@@ -619,9 +635,8 @@ class NeuralMemory(Module):
         self,
         seq,
         past_weights: dict[str, Tensor],
-        chunk_size = None,
     ):
-        chunk_size = default(chunk_size, self.retrieve_chunk_size)
+        chunk_size = self.retrieve_chunk_size
         batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
@@ -691,9 +706,8 @@ class NeuralMemory(Module):
     def forward_inference(
         self,
         token: Tensor,
-        state = None,
+        state: NeuralMemCache | None = None,
     ):
-
         # unpack previous state
 
         if not exists(state):
@@ -707,6 +721,8 @@ class NeuralMemory(Module):
         if token.ndim == 2:
             token = rearrange(token, 'b d -> b 1 d')
 
+        assert token.shape[1] == 1
+
         # increment the sequence cache which is at most the chunk size
 
         cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)
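With the stricter signature, decoding must feed exactly one token per call. A hypothetical loop, assuming forward_inference returns (retrieved, next_state) the way the new forward(...) at the end of this diff does:

    state = None

    for token in tokens.unbind(dim = 1):  # each token is (batch, dim); the new assert rejects multi-token inputs
        retrieved, state = mem.forward_inference(token, state = state)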
@@ -757,32 +773,99 @@ class NeuralMemory(Module):
         self,
         seq,
         store_seq = None,
-        mem_model_weights: dict[str, Tensor] | None = None,
-        past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-        chunk_size = None,
-        store_chunk_size = None,
-        return_next_state = False,
+        state: NeuralMemCache | None = None,
     ):
-        batch, seq_len = seq.shape[:2]
+        if not exists(state):
+            state = (0, None, None, None, None)
+
+        seq_index, weights, cache_store_seq, past_state, updates = state
+
+        assert not exists(cache_store_seq) or is_empty_tensor(cache_store_seq)
 
         # store
 
         store_seq = default(store_seq, seq)
 
-        updates, next_store_state = self.store_memories(
-            store_seq,
-            mem_model_weights,
-            chunk_size = store_chunk_size,
-        )
+        # functions
+
+        # compute split sizes of sequence
+        # for now manually update weights to last update at the correct boundaries
+
+        store_seq_len, chunk_size, batch_size = store_seq.shape[-2], self.chunk_size, self.batch_size
+
+        need_update_weights = exists(batch_size)
+
+        # determine split sizes and when to update
+
+        if need_update_weights:
+            update_after_final_store = divisible_by(seq_index + store_seq_len, batch_size)
+
+            seq_range = torch.arange(store_seq_len) + seq_index + 1
+            batch_boundary = divisible_by(seq_range, batch_size)
+
+            indices = seq_range[batch_boundary] - seq_index
+
+            indices = F.pad(indices, (1, 0), value = 0)
+
+            if indices[-1] != store_seq_len:
+                indices = F.pad(indices, (0, 1), value = store_seq_len)
+
+            split_sizes = (indices[1:] - indices[:-1]).tolist()
+
+            assert sum(split_sizes) == store_seq_len
+        else:
+            split_sizes = (store_seq_len,)
+            update_after_final_store = False
+
+        # accumulate updates
+
+        updates = None
+
+        def accum_updates(past_updates, future_updates):
+            if not exists(past_updates):
+                return future_updates
+
+            return TensorDict({param_name: cat((past_update[:, :-1], future_update), dim = 1) for (param_name, past_update), (_, future_update) in zip(past_updates.items(), future_updates.items())})
+
+        # loop through chunks of store sequences
+
+        store_seqs = store_seq.split(split_sizes, dim = -2)
+
+        for ind, store_seq_chunk in enumerate(store_seqs):
+            is_last = ind == (len(store_seqs) - 1)
+
+            # store
+
+            next_updates, next_neural_mem_state = self.store_memories(
+                store_seq_chunk,
+                weights,
+                seq_index = seq_index,
+                past_state = past_state,
+            )
+
+            seq_index = next_neural_mem_state.seq_index
+            past_state = next_neural_mem_state.states
+
+            updates = accum_updates(updates, next_updates)
+
+            if is_last and not update_after_final_store:
+                continue
+
+            # update weights once batch size is fulfilled
+
+            last_update, _ = past_state
+
+            weights = last_update
+
+            next_neural_mem_state = list(next_neural_mem_state)
+            next_neural_mem_state[1] = last_update
+            next_neural_mem_state = NeuralMemCache(*next_neural_mem_state)
 
         # retrieve
 
         retrieved = self.retrieve_memories(
             seq,
-            updates,
-            chunk_size = chunk_size,
+            updates
         )
 
-        output = (retrieved, next_store_state)
-
-        return output
+        return retrieved, next_neural_mem_state
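The rewritten forward collapses the old keyword arguments into a single NeuralMemCache threaded between calls, splitting the store sequence at batch_size boundaries and committing the memory weights to the last update at each boundary. A minimal sketch of the new calling convention, with illustrative values, assuming sequence lengths divisible by the chunk size so the cached remainder stays empty (as the new assert requires):

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 384, chunk_size = 64, batch_size = 128)

    seq = torch.randn(2, 1024, 384)

    retrieved, state = mem(seq)  # state starts as (0, None, None, None, None)

    # seq_index advances by the stored length; weights are committed to the
    # last update at each batch_size boundary inside the split loop
    retrieved, state = mem(torch.randn(2, 1024, 384), state = state)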
--- titans_pytorch-0.2.10.dist-info/METADATA
+++ titans_pytorch-0.2.11.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.10
+Version: 0.2.11
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
--- /dev/null
+++ titans_pytorch-0.2.11.dist-info/RECORD
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=RfJ1SvQH5_4PmlB7g-13wPAqYtCCUJxfmtaL0oBrRCU,24563
+titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
+titans_pytorch/neural_memory.py,sha256=1wX8dbGENHWk7sfz7IFF1G8KY4U5tsNh3cqSDxTUf2U,26150
+titans_pytorch-0.2.11.dist-info/METADATA,sha256=CMLW5FSamLp0cPhIohOD_yXjCXoxqCPzwJrA0e83vQE,6812
+titans_pytorch-0.2.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.2.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.2.11.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=dmS37yBN0j9OqoMCsojuIPfT1EXLN8ackRdZwPb8xDY,24463
4
- titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
5
- titans_pytorch/neural_memory.py,sha256=kc-cV7dK3WhdqRfOCrPW91nA0F56jUK94TE1irckQ34,23487
6
- titans_pytorch-0.2.10.dist-info/METADATA,sha256=k7u9eQDNAWG3QqzGqhcdN21D6LYWWRWdd5wZFb560q0,6812
7
- titans_pytorch-0.2.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.2.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.2.10.dist-info/RECORD,,