titans-pytorch 0.3.23__py3-none-any.whl → 0.3.25__py3-none-any.whl

--- a/titans_pytorch/neural_memory.py
+++ b/titans_pytorch/neural_memory.py
@@ -289,6 +289,8 @@ class NeuralMemory(Module):
         self.heads = heads
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+        self.split_kv_heads = Rearrange('b n (h u d) -> b h (n u) d', h = heads, u = num_kv_per_token)
+
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
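The new split_kv_heads handles the num_kv_per_token keys/values projected per token, folding the extra u projections into the sequence axis rather than the feature axis. A minimal shape sketch of that einops pattern, with toy values for heads, num_kv_per_token and dim_head (the real module derives these from its constructor arguments):

import torch
from einops.layers.torch import Rearrange

heads, num_kv_per_token, dim_head = 2, 2, 8

# same pattern as the new split_kv_heads above
split_kv_heads = Rearrange('b n (h u d) -> b h (n u) d', h = heads, u = num_kv_per_token)

kv = torch.randn(3, 4, heads * num_kv_per_token * dim_head)  # (batch, seq, h * u * d)
out = split_kv_heads(kv)

print(out.shape)  # torch.Size([3, 2, 8, 8]) - sequence axis grows from n = 4 to n * u = 8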
@@ -522,7 +524,8 @@ class NeuralMemory(Module):
         weights: dict[str, Tensor] | None = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
         seq_index = 0,
-        prev_weights = None
+        prev_weights = None,
+        mask: Tensor | None = None,
     ):
         if self.qkv_receives_diff_views:
             _, batch, seq_len = seq.shape[:3]
@@ -596,22 +599,28 @@ class NeuralMemory(Module):
 
         # maybe multi head
 
-        keys, values = map(self.split_heads, (keys, values))
-
-        batch = keys.shape[0]
+        keys, values = map(self.split_kv_heads, (keys, values))
 
-        # take care of chunking
+        # maybe keys rmsnorm
 
-        keys, values = tuple(rearrange(t, 'b h (n c) (u d) -> (b h n) (c u) d', c = chunk_size, u = num_updates) for t in (keys, values))
+        keys = self.k_norm(keys)
 
-        # maybe qk rmsnorm
+        # take care of chunking
 
-        keys = self.k_norm(keys)
+        keys, values = tuple(rearrange(t, 'b h (n c u) d -> (b h n) (c u) d', c = chunk_size, u = num_updates) for t in (keys, values))
 
         # adaptive lr
 
         adaptive_lr = rearrange(adaptive_lr, 'b (n c u) -> (b n) (c u)', c = chunk_size, u = num_updates)
 
+        # optionally a storing memories mask can be passed in. if False, will set the learning rate to 0. for those positions
+
+        if exists(mask):
+            mask = mask[..., :round_down_seq_len]
+            mask = repeat(mask, 'b (n c) -> (b h n) (c u)', h = heads, u = num_updates, c = chunk_size)
+
+            adaptive_lr = torch.where(mask, adaptive_lr, 0.)
+
         # maybe add previous layer weight
 
         assert xnor(exists(self.to_learned_weight_residual_mix), exists(prev_weights))
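The new mask branch silences storage for masked-out positions by zeroing their adaptive learning rate, so those tokens contribute no weight update to the memory. A standalone sketch of the repeat-and-where step, using toy shapes (one batch, one head, a single chunk; names mirror the diff):

import torch
from einops import repeat

heads, chunk_size, num_updates = 1, 4, 2

adaptive_lr = torch.rand(1, chunk_size * num_updates)   # already rearranged to '(b n) (c u)'
mask = torch.tensor([[True, True, False, True]])        # one bool per token; False = do not store

# broadcast the per-token mask over heads and the u updates per token
mask = repeat(mask, 'b (n c) -> (b h n) (c u)', h = heads, u = num_updates, c = chunk_size)

adaptive_lr = torch.where(mask, adaptive_lr, 0.)        # zero lr => nothing written at that position
print(adaptive_lr)                                      # third token's two update slots are now 0.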
@@ -833,7 +842,8 @@ class NeuralMemory(Module):
         seq,
         store_seq = None,
         state: NeuralMemState | None = None,
-        prev_weights = None
+        prev_weights = None,
+        store_mask: Tensor | None = None
     ):
         is_multi_input = self.qkv_receives_diff_views
 
@@ -910,6 +920,11 @@ class NeuralMemory(Module):
 
         store_seqs = store_seq.split(split_sizes, dim = -2)
 
+        if exists(store_mask):
+            store_masks = store_mask.split(split_sizes, dim = -1)
+        else:
+            store_masks = (None,) * len(split_sizes)
+
         # whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch
 
         gate = None
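store_mask is split with the same split_sizes as store_seq, so each sequence chunk travels with its matching mask slice (or with None when no mask was given, which keeps the zip in the loop below well-formed). A toy illustration of that lockstep split; the sizes and feature dim here are made up:

import torch

split_sizes = (4, 4, 2)

store_seq = torch.randn(1, 10, 16)
store_mask = torch.ones(1, 10, dtype = torch.bool)

store_seqs = store_seq.split(split_sizes, dim = -2)    # shapes (1, 4, 16), (1, 4, 16), (1, 2, 16)
store_masks = store_mask.split(split_sizes, dim = -1)  # shapes (1, 4), (1, 4), (1, 2)

# each chunk of sequence stays aligned with its chunk of mask
assert all(s.shape[-2] == m.shape[-1] for s, m in zip(store_seqs, store_masks))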
@@ -917,7 +932,7 @@ class NeuralMemory(Module):
         if exists(self.transition_gate):
             gate = self.transition_gate.sigmoid()
 
-        for ind, store_seq_chunk in enumerate(store_seqs):
+        for ind, (store_seq_chunk, maybe_store_mask) in enumerate(zip(store_seqs, store_masks)):
             is_last = ind == (len(store_seqs) - 1)
 
             # store
@@ -927,7 +942,8 @@ class NeuralMemory(Module):
                 weights,
                 seq_index = seq_index,
                 past_state = past_state,
-                prev_weights = prev_weights
+                prev_weights = prev_weights,
+                mask = maybe_store_mask
             )
 
             weights = next_neural_mem_state.weights
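End to end, these changes surface a store_mask keyword on NeuralMemory.forward. A hedged usage sketch, assuming constructor arguments along the lines of the project README (dim, chunk_size and the shapes below are illustrative, and the (retrieved, state) return follows the existing API):

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)  # illustrative hyperparameters

seq = torch.randn(1, 1024, 384)
store_mask = torch.randint(0, 2, (1, 1024)).bool()  # False positions are skipped when storing memories

retrieved, state = mem(seq, store_mask = store_mask)
print(retrieved.shape)  # torch.Size([1, 1024, 384])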
--- a/titans_pytorch-0.3.23.dist-info/METADATA
+++ b/titans_pytorch-0.3.25.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.23
+Version: 0.3.25
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
--- a/titans_pytorch-0.3.23.dist-info/RECORD
+++ b/titans_pytorch-0.3.25.dist-info/RECORD
@@ -2,8 +2,8 @@ titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,29
 titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
 titans_pytorch/mac_transformer.py,sha256=grD327B3OCIy7d23jNUWIoUo1bIgXUqD26dXWCjdi28,25565
 titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
-titans_pytorch/neural_memory.py,sha256=idjjW_K5qZS29isy7pI_pzGDq95vONwaPyXNLnDaj1w,31161
-titans_pytorch-0.3.23.dist-info/METADATA,sha256=KfG0t2AGKAWoifI81y2Bugte3EihJs3w4sjxEYompsw,6817
-titans_pytorch-0.3.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.3.23.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.3.23.dist-info/RECORD,,
+titans_pytorch/neural_memory.py,sha256=uh5NbtAAzfPeZPFe7uhgnpUF6qyP0zjP0eXPIgY5pfc,31929
+titans_pytorch-0.3.25.dist-info/METADATA,sha256=SZwazbaNFe1GstoF45zI_aNMpzgXAqv4mOh78gMN5-U,6817
+titans_pytorch-0.3.25.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.3.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.3.25.dist-info/RECORD,,