titans-pytorch 0.3.10-py3-none-any.whl → 0.3.11-py3-none-any.whl

This diff compares the contents of two package versions publicly released to a supported registry, as they appear in that registry; it is provided for informational purposes only.
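Taken together, 0.3.11 is a small refactor: `dim` and a stream embedding are now passed once to the `get_init_and_expand_reduce_stream_functions` factory instead of to each `init_hyper_conn` call, the `NeuralMemCache` namedtuple is renamed to `NeuralMemState`, the storing path returns its tuple directly, and the `hyper-connections` dependency floor moves from 0.1.9 to 0.1.10.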
titans_pytorch/mac_transformer.py
@@ -62,6 +62,7 @@ from rotary_embedding_torch import RotaryEmbedding
 # hyper connections / attend from x-transformers, which handles different queries and key lengths better
 
 from x_transformers.attend import Attend
+
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # proposed neural memory
@@ -515,7 +516,7 @@ class MemoryAsContextTransformer(Module):
 
         # hyper conection
 
-        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
 
@@ -553,7 +554,7 @@ class MemoryAsContextTransformer(Module):
            mem_hyper_conn = None
 
            if layer in neural_memory_layers:
-                mem_hyper_conn = init_hyper_conn(dim = dim, add_branch_out_to_residual = not neural_mem_gate_attn_output)
+                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)
 
                mem = NeuralMemory(
                    dim = dim,
@@ -571,8 +572,8 @@ class MemoryAsContextTransformer(Module):
            self.layers.append(ModuleList([
                mem_hyper_conn,
                mem,
-                init_hyper_conn(dim = dim, branch = attn),
-                init_hyper_conn(dim = dim, branch = ff)
+                init_hyper_conn(branch = attn),
+                init_hyper_conn(branch = ff)
            ]))
 
        self.norm = nn.RMSNorm(dim)
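The three mac_transformer.py hunks above are one coordinated change: the hyper-connections factory is now told `dim` and to add stream embeddings once, so every later `init_hyper_conn` call drops its repeated `dim` kwarg. A minimal sketch of the new calling convention, assuming hyper-connections >= 0.1.10; the concrete `dim`, the stream count, and the `nn.Linear` stand-in branch are illustrative, not from the package:

    from torch import nn
    from hyper_connections import get_init_and_expand_reduce_stream_functions

    dim = 512
    num_residual_streams = 4

    # 0.3.11 style: dim and the stream embedding are configured once, on the factory
    init_hyper_conn, expand_streams, reduce_streams = get_init_and_expand_reduce_stream_functions(
        num_residual_streams,
        dim = dim,
        add_stream_embed = True,
        disable = num_residual_streams == 1
    )

    # per-layer wrappers therefore no longer repeat dim
    # (the 0.3.10 equivalent was init_hyper_conn(dim = dim, branch = branch))
    branch = nn.Linear(dim, dim)  # stand-in for the attention / feedforward blocks
    wrapped = init_hyper_conn(branch = branch)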
titans_pytorch/neural_memory.py
@@ -39,7 +39,7 @@ o - momentum orders
 
 LinearNoBias = partial(Linear, bias = False)
 
-NeuralMemCache = namedtuple('NeuralMemCache', [
+NeuralMemState = namedtuple('NeuralMemState', [
     'seq_index',
     'weights',
     'cache_store_segment',
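The rename does not touch the visible fields. For downstream code that still imports the old name, a hypothetical compatibility alias (not shipped with the package) would be:

    # hypothetical shim for downstream code, not part of titans-pytorch 0.3.11
    from titans_pytorch.neural_memory import NeuralMemState

    NeuralMemCache = NeuralMemState  # old 0.3.10 name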
@@ -629,7 +629,7 @@ class NeuralMemory(Module):
 
        if num_chunks == 0:
            updates = rearrange_dict_values(weights, 'bh ... -> bh 1 ...')
-            next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, past_state, updates)
+            next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, past_state, updates)
 
            output = (updates, next_store_state)
 
@@ -682,13 +682,11 @@ class NeuralMemory(Module):
 
        next_state = (next_last_update, next_last_momentum)
 
-        next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, next_state, updates)
+        next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, next_state, updates)
 
-        # returns
+        # return updates to neural memory at all chunked timesteps + neural mem cache / state to be fed back
 
-        output = (updates, next_store_state)
-
-        return output
+        return updates, next_store_state
 
    def retrieve_memories(
        self,
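Behaviorally the return is unchanged: the same `(updates, next_store_state)` pair comes back, now returned directly rather than through the intermediate `output` binding, with the new comment spelling out what the two values are.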
@@ -785,7 +783,7 @@ class NeuralMemory(Module):
        self,
        seq,
        store_seq = None,
-        state: NeuralMemCache | None = None,
+        state: NeuralMemState | None = None,
        prev_weights = None
    ):
        if seq.ndim == 2:
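The only change to the forward signature is the annotation following the rename. A minimal usage sketch of threading that state across calls; `dim` is the only constructor kwarg confirmed by this diff, and the two-tuple return follows the project README, so treat the shapes and defaults as assumptions:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 384)

    retrieved, state = mem(torch.randn(2, 1024, 384))                # state is a NeuralMemState in 0.3.11
    retrieved, state = mem(torch.randn(2, 512, 384), state = state)  # feed the state back to continue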
titans_pytorch-0.3.11.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.10
+Version: 0.3.11
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -38,7 +38,7 @@ Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.10
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
-Requires-Dist: hyper-connections>=0.1.9
+Requires-Dist: hyper-connections>=0.1.10
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
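The raised floor on hyper-connections presumably corresponds to the release that introduced the factory-level `dim` / `add_stream_embed` arguments used in mac_transformer.py above.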
titans_pytorch-0.3.11.dist-info/RECORD (added)
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
+titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
+titans_pytorch/mac_transformer.py,sha256=EyqA53HBqvAr4UNZUs37LR6IltyEfA7FKEV54YzVYlg,24945
+titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
+titans_pytorch/neural_memory.py,sha256=7YglrQaDpKS2hbpBBwx7PmqhJdjyvFEPZDt_QXmnUMM,28878
+titans_pytorch-0.3.11.dist-info/METADATA,sha256=xAEvavDiCj__5Bl_5UXaG__BycdUB2DzHOud-nwsn1c,6817
+titans_pytorch-0.3.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.3.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.3.11.dist-info/RECORD,,
titans_pytorch-0.3.10.dist-info/RECORD (removed)
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
-titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
-titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
-titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
-titans_pytorch/neural_memory.py,sha256=BeOnq41gjZeq-XJFjkHE44F9dLzsg9mm36EBYZ4wHMA,28814
-titans_pytorch-0.3.10.dist-info/METADATA,sha256=sA_Dx_x5RMcpz5-vUPDHuz__tHYfKzs4W_BgY4CHPdk,6816
-titans_pytorch-0.3.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.3.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.3.10.dist-info/RECORD,,