titans-pytorch 0.1.29__tar.gz → 0.1.31__tar.gz

This diff shows the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.29
+Version: 0.1.31
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.29"
+version = "0.1.31"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -124,18 +124,22 @@ def test_mac(
     assert logits.shape == (1, seq_len, 256)

 @pytest.mark.parametrize('sliding', (False, True))
-@pytest.mark.parametrize('mem_layers', ((), None, (4,)))
+@pytest.mark.parametrize('mem_layers', (()))
+@pytest.mark.parametrize('longterm_mems', (0, 4, 16))
+@pytest.mark.parametrize('prompt_len', (0, 4, 16))
 def test_mac_sampling(
     sliding,
-    mem_layers
+    mem_layers,
+    longterm_mems,
+    prompt_len
 ):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
         dim = 256,
-        depth = 2,
+        depth = 4,
         segment_len = 32,
         num_persist_mem_tokens = 4,
-        num_longterm_mem_tokens = 0,
+        num_longterm_mem_tokens = longterm_mems,
         sliding_window_attn = sliding,
         neural_memory_layers = mem_layers,
         neural_mem_gate_attn_output = False
@@ -145,8 +149,10 @@ def test_mac_sampling(

     # after much training

-    sampled = transformer.sample(ids[:, :4], 53, use_cache = False, temperature = 0.)
-    sampled_with_cache = transformer.sample(ids[:, :4], 53, use_cache = True, temperature = 0.)
+    prompt = ids[:, :prompt_len]
+
+    sampled = transformer.sample(prompt, 53, use_cache = False, temperature = 0.)
+    sampled_with_cache = transformer.sample(prompt, 53, use_cache = True, temperature = 0.)

     assert torch.allclose(sampled, sampled_with_cache)

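
The sampling test now parametrizes the number of long-term memory tokens and the prompt length, and checks that greedy decoding agrees with and without the cache. A standalone version of that check might look like the sketch below (import path and tensor shapes are assumed; optional constructor kwargs are left at their defaults):

    import torch
    from titans_pytorch import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 4,
        segment_len = 32,
        num_persist_mem_tokens = 4,
        num_longterm_mem_tokens = 4
    )

    ids = torch.randint(0, 256, (1, 1023))
    prompt = ids[:, :4]

    # greedy decoding should not depend on whether the kv / neural memory cache is used
    sampled = transformer.sample(prompt, 53, use_cache = False, temperature = 0.)
    sampled_with_cache = transformer.sample(prompt, 53, use_cache = True, temperature = 0.)

    assert torch.allclose(sampled, sampled_with_cache)
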
@@ -582,13 +582,8 @@ class MemoryAsContextTransformer(Module):
         self,
         seq_index
     ):
-        total_segment_len = self.attn_window_size
-
-        seq = seq_index + 1
-        seq -= int((seq % total_segment_len) == 0)
-        last_segment_len = round_down_multiple(seq, total_segment_len)
-        segment_seq = seq - last_segment_len
-        return (segment_seq - self.segment_len) > 0
+        total_segment_len, segment_len = self.attn_window_size, self.segment_len
+        return ((seq_index % total_segment_len + 1) - segment_len) > 0

     def seq_len_with_longterm_mem(
         self,
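
seq_index_is_longterm now reduces to plain modular arithmetic: each attention window consists of segment_len regular positions followed by the long-term memory slots, so an index is a memory slot once its offset within the window reaches segment_len. A minimal sketch with toy sizes (assuming attn_window_size = segment_len + num_longterm_mem_tokens, which is the window layout the hunk implies):

    segment_len = 4
    num_longterm_mem_tokens = 2
    attn_window_size = segment_len + num_longterm_mem_tokens  # 6 positions per window

    def seq_index_is_longterm(seq_index):
        # offset of the index inside its window, compared against the regular-token count
        return ((seq_index % attn_window_size + 1) - segment_len) > 0

    # positions 0-3 are regular tokens, positions 4-5 are long-term memory slots
    assert [seq_index_is_longterm(i) for i in range(6)] == [False, False, False, False, True, True]
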
@@ -597,7 +592,7 @@ class MemoryAsContextTransformer(Module):
         assert seq_len > 0

         segment_len, num_mem = self.segment_len, self.num_longterm_mem_tokens
-        return ceil(seq_len / segment_len) * num_mem + seq_len
+        return ((seq_len - 1) // segment_len) * num_mem + seq_len

     @torch.no_grad()
     def sample(
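
seq_len_with_longterm_mem switches from rounding up to flooring on (seq_len - 1): memory tokens are only counted for segment boundaries that are actually crossed, so a sequence ending exactly on a boundary no longer picks up a trailing memory block. A quick worked comparison (segment_len and num_mem values assumed):

    from math import ceil

    segment_len, num_mem = 32, 4

    def old_len(seq_len): return ceil(seq_len / segment_len) * num_mem + seq_len
    def new_len(seq_len): return ((seq_len - 1) // segment_len) * num_mem + seq_len

    assert (old_len(32), new_len(32)) == (36, 32)  # exact boundary: no trailing memory block anymore
    assert (old_len(33), new_len(33)) == (41, 37)  # one boundary crossed: one block of num_mem tokens
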
@@ -723,9 +718,9 @@ class MemoryAsContextTransformer(Module):
         is_inferencing = exists(cache)

         if not exists(cache):
-            cache = (None, None)
+            cache = (seq_len_with_mem - 1, None, None)

-        kv_caches, neural_mem_caches = cache
+        inference_seq_index, kv_caches, neural_mem_caches = cache

         kv_caches = iter(default(kv_caches, []))
         neural_mem_caches = iter(default(neural_mem_caches, []))
@@ -744,7 +739,8 @@ class MemoryAsContextTransformer(Module):
         # when inferencing, only do one token at a time

         if is_inferencing:
-            x = x[:, -1:]
+            ind = inference_seq_index
+            x = x[:, ind:(ind + 1)]

         # expand and reduce streams for hyper connections

@@ -817,6 +813,17 @@ class MemoryAsContextTransformer(Module):
         if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
             next_kv_caches = next_kv_caches[..., 0:0, :]

+        next_cache = (
+            inference_seq_index + 1,
+            next_kv_caches,
+            next_neural_mem_caches
+        )
+
+        is_longterm_mem = self.seq_index_is_longterm(inference_seq_index)
+
+        if is_inferencing and is_longterm_mem:
+            return None, next_cache
+
         # hyper connection reducing of streams

         x = self.reduce_streams(x)
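
With the cache now carrying an explicit inference_seq_index, an incremental decoding step can land on a long-term memory position, in which case the forward pass returns (None, next_cache) and the caller simply advances. A hedged sketch of a decode loop against this protocol (the real loop lives in MemoryAsContextTransformer.sample; transformer, prompt and target_len are assumed, and greedy selection is used purely for illustration):

    import torch

    out = prompt
    cache = None

    while out.shape[-1] < target_len:
        logits, cache = transformer(out, cache = cache, return_cache = True)

        if logits is None:
            # this step consumed a long-term memory slot; the cache (and its
            # leading inference_seq_index) has advanced, but there is no token to emit
            continue

        token = logits[:, -1].argmax(dim = -1, keepdim = True)
        out = torch.cat((out, token), dim = -1)
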
@@ -829,7 +836,7 @@ class MemoryAsContextTransformer(Module):

         x, _ = inverse_pack_mems(x)

-        x = inverse_segment(x)
+        x = inverse_segment(x, remove_pad = False)

         x = x[:, :seq_len]

@@ -843,7 +850,7 @@ class MemoryAsContextTransformer(Module):
            if not return_cache:
                return logits

-           return logits, (next_kv_caches, next_neural_mem_caches)
+           return logits, next_cache

        ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)

@@ -409,6 +409,7 @@ class NeuralMemory(Module):
     ):
         super().__init__()
         dim_head = default(dim_head, dim)
+        assert not (heads == 1 and dim_head != dim)

         self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)

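
The new assert in NeuralMemory.__init__ rejects a single-head configuration whose dim_head disagrees with dim. Illustrative only; the remaining constructor kwargs are assumed to have workable defaults:

    from titans_pytorch import NeuralMemory

    NeuralMemory(dim = 384, heads = 1, dim_head = 384)   # ok: dim_head matches dim
    # NeuralMemory(dim = 384, heads = 1, dim_head = 64)  # would now trip the assert
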
@@ -566,7 +567,7 @@ class NeuralMemory(Module):
     ):
         assert xnor(exists(value_residual), exists(self.learned_value_residual))

-        seq_len, chunk_size = seq.shape[-2], default(chunk_size, self.store_chunk_size)
+        seq_len, heads, chunk_size = seq.shape[-2], self.heads, default(chunk_size, self.store_chunk_size)

         # handle edge case

@@ -645,7 +646,7 @@ class NeuralMemory(Module):

         # restore batch and sequence dimension

-        grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
+        grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch * heads))

         # maybe per layer modulation

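
The store path (previous hunk) now pulls heads out alongside seq_len, and this reshape un-flattens the per-chunk gradients with b = batch * heads, reflecting that the head dimension has been folded into the batch before the gradients are computed. A minimal einops sketch of just that reshape (shapes assumed):

    import torch
    from einops import rearrange

    batch, heads, chunks, dim = 2, 4, 3, 16

    # gradients arrive with heads folded into the leading, batch-like dimension
    flat = torch.randn(batch * heads * chunks, dim)

    unflat = rearrange(flat, '(b n) ... -> b n ...', b = batch * heads)
    assert unflat.shape == (batch * heads, chunks, dim)
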