titans-pytorch 0.1.38__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.38
+Version: 0.2.1
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -56,7 +56,7 @@ Description-Content-Type: text/markdown
 
 <img src="./fig1.png" width="400px"></img>
 
-## Titans - Pytorch (wip)
+## Titans - Pytorch
 
 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
@@ -2,7 +2,7 @@
 
 <img src="./fig1.png" width="400px"></img>
 
-## Titans - Pytorch (wip)
+## Titans - Pytorch
 
 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
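For orientation on the package this diff modifies, a minimal usage sketch in the spirit of the repository README follows. It assumes `NeuralMemory` is importable from `titans_pytorch` and that its forward call returns the retrieved sequence together with a memory state; the exact constructor arguments and return values may differ between 0.1.38 and 0.2.1.

    import torch
    from titans_pytorch import NeuralMemory

    # neural long-term memory module (assumed constructor arguments)
    mem = NeuralMemory(
        dim = 384,
        chunk_size = 64
    )

    seq = torch.randn(2, 1024, 384)

    # retrieve from memory while updating it on the incoming sequence
    # (some versions may return only the retrieved sequence)
    retrieved, mem_state = mem(seq)

    assert retrieved.shape == seq.shape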
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.38"
+version = "0.2.1"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -184,9 +184,10 @@ def test_mac(
     assert logits.shape == (1, seq_len, 256)
 
 @pytest.mark.parametrize('sliding', (False, True))
-@pytest.mark.parametrize('mem_layers', (()))
+@pytest.mark.parametrize('mem_layers', ((), None))
 @pytest.mark.parametrize('longterm_mems', (0, 4, 16))
-@pytest.mark.parametrize('prompt_len', (0, 4, 16))
+@pytest.mark.parametrize('prompt_len', (4, 16))
+@torch_default_dtype(torch.float64)
 def test_mac_sampling(
     sliding,
     mem_layers,
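The newly applied `@torch_default_dtype(torch.float64)` decorator pins the global default dtype for the sampling test, presumably so the cached and uncached decoding paths agree numerically. A hypothetical implementation of such a decorator is sketched below; the package's actual test helper may be written differently.

    from functools import wraps
    import torch

    def torch_default_dtype(dtype):
        # decorator factory: run the wrapped test under `dtype`, then restore the previous default
        def decorator(fn):
            @wraps(fn)
            def inner(*args, **kwargs):
                prev_dtype = torch.get_default_dtype()
                torch.set_default_dtype(dtype)
                try:
                    return fn(*args, **kwargs)
                finally:
                    torch.set_default_dtype(prev_dtype)
            return inner
        return decorator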
@@ -111,6 +111,7 @@ def pad_and_segment_with_inverse(
     seq,
     segment_len,
     fold_into_batch = True,
+    inverse_remove_pad = True
 ):
     batch, seq_len = seq.shape[:2]
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)
@@ -124,15 +125,12 @@ def pad_and_segment_with_inverse(
     if fold_into_batch:
         seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
 
-    shape = seq.shape
-
     def inverse(out):
-        unchanged_shape = out.shape == shape
 
         if fold_into_batch:
             out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)
 
-        if needs_pad and unchanged_shape:
+        if needs_pad and inverse_remove_pad:
             out = out[..., :-padding, :]
 
         return out
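Pieced together from the two hunks above, a standalone sketch of the revised helper follows. The change replaces the implicit shape comparison with an explicit `inverse_remove_pad` flag, so callers can keep the padded length when inverting. The padding step and `round_up_multiple` sit outside the diff context and are reconstructed here, so treat those details as assumptions.

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    def round_up_multiple(n, mult):
        # assumed helper: round n up to the nearest multiple of mult
        return ((n + mult - 1) // mult) * mult

    def pad_and_segment_with_inverse(
        seq,
        segment_len,
        fold_into_batch = True,
        inverse_remove_pad = True
    ):
        batch, seq_len = seq.shape[:2]
        next_seq_len_mult = round_up_multiple(seq_len, segment_len)

        padding = next_seq_len_mult - seq_len
        needs_pad = padding > 0

        if needs_pad:
            # right-pad the sequence dimension up to a multiple of segment_len (reconstructed step)
            seq = F.pad(seq, (0, 0, 0, padding))

        if fold_into_batch:
            # fold each segment into the batch dimension
            seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)

        def inverse(out):
            if fold_into_batch:
                out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)

            if needs_pad and inverse_remove_pad:
                # drop the padded positions only when the caller asks for it
                out = out[..., :-padding, :]

            return out

        return seq, inverse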
@@ -493,7 +491,8 @@ class MemoryAsContextTransformer(Module):
         aux_kv_recon_loss_weight = 0.,
         use_flex_attn = False,
         sliding_window_attn = False,
-        weight_tie_memory_model = False
+        weight_tie_memory_model = False,
+        prev_neural_mem_update_for_weights = None
     ):
         super().__init__()
 
@@ -535,6 +534,7 @@ class MemoryAsContextTransformer(Module):
         assert exists(neural_memory_model), '`neural_memory_model` must be explicitly set'
 
         self.weight_tie_memory_model = weight_tie_memory_model
+        self.prev_neural_mem_update_for_weights = default(prev_neural_mem_update_for_weights, weight_tie_memory_model)
 
         # value residual learning for neural memory
 
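The new flag falls back to `weight_tie_memory_model` through the `default` helper, which in this codebase returns its first argument unless that argument is None. A small illustration of the intended behavior, with the helper definitions assumed:

    def exists(v):
        return v is not None

    def default(v, d):
        # return v when provided, otherwise fall back to d
        return v if exists(v) else d

    # leaving the option unset keeps the previous coupling to weight tying ...
    assert default(None, True) is True
    # ... while an explicit value decouples the two switches
    assert default(False, True) is False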
@@ -704,7 +704,7 @@ class MemoryAsContextTransformer(Module):
 
         # math
 
-        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size, weight_tie_memory_model = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size, self.weight_tie_memory_model
+        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size, prev_neural_mem_update_for_weights = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size, self.prev_neural_mem_update_for_weights
 
         seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)
 
@@ -714,7 +714,7 @@ class MemoryAsContextTransformer(Module):
 
         # intersperse longterm memory
 
-        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
+        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len, inverse_remove_pad = False)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
         x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')
@@ -816,7 +816,7 @@ class MemoryAsContextTransformer(Module):
                 if self.mem_add_value_residual:
                     mem_value_residual = next_mem_value_residual
 
-                if weight_tie_memory_model:
+                if prev_neural_mem_update_for_weights:
                     neural_memory_updates = next_neural_mem_cache.updates
 
                 if self.gate_attn_output:
@@ -856,7 +856,9 @@ class MemoryAsContextTransformer(Module):
 
             next_kv_caches = next_kv_caches[..., -attn_window_size:, :]
 
-            if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
+            kv_cache_length = next_kv_caches.shape[-2]
+
+            if not self.sliding_window_attn and divisible_by(kv_cache_length, attn_window_size):
                 next_kv_caches = next_kv_caches[..., 0:0, :]
 
             next_cache = (
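The second hunk swaps `seq_len_with_mem` for the length of the already-truncated kv cache when deciding whether a full attention window has been cached and the window cache can be reset. A small sketch of that guard, with `divisible_by` assumed to be the usual modulo check:

    def divisible_by(num, den):
        return (num % den) == 0

    attn_window_size = 64

    # the cache was just truncated to at most one window, so its own length,
    # not the overall sequence length, signals a completed window
    kv_cache_length = 64
    assert divisible_by(kv_cache_length, attn_window_size)       # window complete: cache is cleared

    kv_cache_length = 40
    assert not divisible_by(kv_cache_length, attn_window_size)   # keep caching within the window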
@@ -878,7 +880,7 @@ class MemoryAsContextTransformer(Module):
 
         if not is_inferencing:
 
-            x, inverse_segment = pad_and_segment_with_inverse(x, attn_window_size)
+            x, inverse_segment = pad_and_segment_with_inverse(x, attn_window_size, inverse_remove_pad = False)
 
             x, _ = inverse_pack_mems(x)
 
@@ -53,6 +53,7 @@ WANDB_ONLINE = False # turn this on to pipe experiment to cloud
 
 USE_ACCELERATED_SCAN = True
 USE_FLEX_ATTN = True
+USE_FAST_INFERENCE = False
 
 # wandb experiment tracker
 
@@ -163,6 +164,6 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
         prime = decode_tokens(inp)
         print(f'%s \n\n %s', (prime, '*' * 100))
 
-        sample = model.sample(inp[None, ...], GENERATE_LENGTH)
+        sample = model.sample(inp[None, ...], GENERATE_LENGTH, use_cache = USE_FAST_INFERENCE)
         output_str = decode_tokens(sample[0])
         print(output_str)
3 files without changes