titans-pytorch 0.3.10.tar.gz → 0.3.11.tar.gz

This diff shows the contents of publicly released package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.10
+Version: 0.3.11
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -38,7 +38,7 @@ Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.10
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
-Requires-Dist: hyper-connections>=0.1.9
+Requires-Dist: hyper-connections>=0.1.10
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.3.10"
+version = "0.3.11"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -29,7 +29,7 @@ dependencies = [
     "axial_positional_embedding>=0.3.10",
     "einops>=0.8.0",
     "einx>=0.3.0",
-    "hyper-connections>=0.1.9",
+    "hyper-connections>=0.1.10",
     "Ninja",
     "rotary-embedding-torch",
     "tensordict",
@@ -62,6 +62,7 @@ from rotary_embedding_torch import RotaryEmbedding
 # hyper connections / attend from x-transformers, which handles different queries and key lengths better
 
 from x_transformers.attend import Attend
+
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # proposed neural memory
@@ -515,7 +516,7 @@ class MemoryAsContextTransformer(Module):
 
         # hyper conection
 
-        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
 
@@ -553,7 +554,7 @@ class MemoryAsContextTransformer(Module):
             mem_hyper_conn = None
 
             if layer in neural_memory_layers:
-                mem_hyper_conn = init_hyper_conn(dim = dim, add_branch_out_to_residual = not neural_mem_gate_attn_output)
+                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)
 
                 mem = NeuralMemory(
                     dim = dim,
@@ -571,8 +572,8 @@ class MemoryAsContextTransformer(Module):
             self.layers.append(ModuleList([
                 mem_hyper_conn,
                 mem,
-                init_hyper_conn(dim = dim, branch = attn),
-                init_hyper_conn(dim = dim, branch = ff)
+                init_hyper_conn(branch = attn),
+                init_hyper_conn(branch = ff)
             ]))
 
         self.norm = nn.RMSNorm(dim)
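
The three hunks above track the hyper-connections >= 0.1.10 API: dim (together with add_stream_embed) is now given once to get_init_and_expand_reduce_stream_functions, so the per-branch init_hyper_conn calls drop their dim = dim argument. Below is a minimal usage sketch of that pattern outside the transformer; the toy branch module and the dimensions are illustrative, not taken from the package.

    import torch
    from torch import nn
    from hyper_connections import get_init_and_expand_reduce_stream_functions

    dim = 512
    num_residual_streams = 4

    # dim (and add_stream_embed) now go to the factory once, matching the diff above
    init_hyper_conn, expand_streams, reduce_streams = get_init_and_expand_reduce_stream_functions(
        num_residual_streams,
        dim = dim,
        add_stream_embed = True,
        disable = num_residual_streams == 1
    )

    # a toy branch standing in for the attention / feedforward blocks
    branch = init_hyper_conn(branch = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim)))

    x = torch.randn(2, 128, dim)

    x = expand_streams(x)   # fan the residual out into multiple streams
    x = branch(x)           # branch output is mixed back into the streams by the wrapper
    x = reduce_streams(x)   # collapse the streams back to a single residual
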
@@ -39,7 +39,7 @@ o - momentum orders
 
 LinearNoBias = partial(Linear, bias = False)
 
-NeuralMemCache = namedtuple('NeuralMemCache', [
+NeuralMemState = namedtuple('NeuralMemState', [
     'seq_index',
     'weights',
     'cache_store_segment',
@@ -629,7 +629,7 @@ class NeuralMemory(Module):
 
         if num_chunks == 0:
             updates = rearrange_dict_values(weights, 'bh ... -> bh 1 ...')
-            next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, past_state, updates)
+            next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, past_state, updates)
 
             output = (updates, next_store_state)
 
@@ -682,13 +682,11 @@ class NeuralMemory(Module):
 
         next_state = (next_last_update, next_last_momentum)
 
-        next_store_state = NeuralMemCache(next_seq_len_index, weights, remainder, next_state, updates)
+        next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, next_state, updates)
 
-        # returns
+        # return updates to neural memory at all chunked timesteps + neural mem cache / state to be fed back
 
-        output = (updates, next_store_state)
-
-        return output
+        return updates, next_store_state
 
     def retrieve_memories(
         self,
@@ -785,7 +783,7 @@ class NeuralMemory(Module):
         self,
         seq,
         store_seq = None,
-        state: NeuralMemCache | None = None,
+        state: NeuralMemState | None = None,
         prev_weights = None
     ):
         if seq.ndim == 2:
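
The NeuralMemCache -> NeuralMemState rename and the simplified (updates, next_store_state) return surface in NeuralMemory.forward, whose state keyword carries a NeuralMemState between calls. A minimal sketch of threading that state across two segments, assuming the README-style constructor and return convention; dim, chunk_size and sequence lengths are illustrative.

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 384, chunk_size = 64)

    seq = torch.randn(2, 1024, 384)

    # forward returns the retrieved memories plus a NeuralMemState namedtuple
    retrieved, state = mem(seq)

    # the state (seq_index, weights, cache_store_segment, ...) can be fed back
    # so the next segment continues from where the previous one left off
    next_seq = torch.randn(2, 512, 384)
    retrieved_next, state = mem(next_seq, state = state)
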
@@ -37,6 +37,7 @@ NUM_LONGTERM_MEM = 4
 NEURAL_MEM_LAYERS = (2, 4, 6) # layers 2, 4, 6 have neural memory, can add more
 NEURAL_MEM_GATE_ATTN_OUTPUT = False
 NEURAL_MEM_MOMENTUM = True
+NEURAL_MEM_MOMENTUM_ORDER = 1
 NEURAL_MEM_QK_NORM = True
 NEURAL_MEM_MAX_LR = 1e-1
 USE_MEM_ATTENTION_MODEL = False
@@ -115,6 +116,7 @@ model = MemoryAsContextTransformer(
         attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
         qk_rmsnorm = NEURAL_MEM_QK_NORM,
         momentum = NEURAL_MEM_MOMENTUM,
+        momentum_order = NEURAL_MEM_MOMENTUM_ORDER,
         default_step_transform_max_lr = NEURAL_MEM_MAX_LR,
         use_accelerated_scan = USE_ACCELERATED_SCAN,
         per_parameter_lr_modulation = MEMORY_MODEL_PER_LAYER_LEARNED_LR
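
The training script now exposes the momentum order of the neural memory update. A hedged sketch of passing the same knob through MemoryAsContextTransformer, following the README-style constructor; every hyperparameter below is illustrative, and only momentum / momentum_order mirror this diff.

    import torch
    from titans_pytorch import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 128,
        num_persist_mem_tokens = 4,
        num_longterm_mem_tokens = 4,
        neural_memory_layers = (1, 2),
        neural_memory_kwargs = dict(
            momentum = True,
            momentum_order = 1   # the new NEURAL_MEM_MOMENTUM_ORDER setting; presumably values > 1 enable higher-order momentum
        )
    )

    ids = torch.randint(0, 256, (1, 1024))

    loss = transformer(ids, return_loss = True)
    loss.backward()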