titans-pytorch 0.2.4__tar.gz → 0.2.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.4
+Version: 0.2.6
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -56,7 +56,7 @@ Description-Content-Type: text/markdown
 
 <img src="./fig1.png" width="400px"></img>
 
-## Titans - Pytorch
+## Titans - Pytorch (wip)
 
 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
@@ -2,7 +2,7 @@
 
 <img src="./fig1.png" width="400px"></img>
 
-## Titans - Pytorch
+## Titans - Pytorch (wip)
 
 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.2.4"
+version = "0.2.6"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -488,7 +488,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_model: Module | None = None,
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
-        aux_kv_recon_loss_weight = 0.,
+        aux_kv_recon_loss_weight = 1.,
         use_flex_attn = False,
         sliding_window_attn = False,
         weight_tie_memory_model = False,
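The functional change in the hunk above (apparently the `MemoryAsContextTransformer` constructor) is the default of `aux_kv_recon_loss_weight` moving from `0.` to `1.`, so the auxiliary key/value reconstruction loss now contributes by default. The loss computation itself is not part of this diff; as a minimal, hypothetical sketch of what such a weight conventionally does (every name below other than `aux_kv_recon_loss_weight` is made up):

```python
import torch
import torch.nn.functional as F

# hypothetical illustration only: fold a weighted auxiliary reconstruction
# term into a primary loss; nothing here is the package's actual loss code

def total_loss(main_loss, recon_pred, recon_target, aux_kv_recon_loss_weight = 1.):
    aux_loss = F.mse_loss(recon_pred, recon_target)
    return main_loss + aux_kv_recon_loss_weight * aux_loss

main = torch.tensor(2.0)
pred, target = torch.randn(4, 8), torch.randn(4, 8)

print(total_loss(main, pred, target))        # new default: auxiliary term included
print(total_loss(main, pred, target, 0.))    # old default effectively disabled it
```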
@@ -1,8 +1,15 @@
 import torch
-from torch import nn
+from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Parameter, ParameterList
 
+# functions
+
+def l2norm(t):
+    return F.normalize(t, dim = -1)
+
+# memory mlp proposed in TTT
+
 class MemoryMLP(Module):
     def __init__(
         self,
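The new `l2norm` helper added above is plain L2 normalization over the last dimension. A quick standalone check of what it does:

```python
import torch
import torch.nn.functional as F

def l2norm(t):
    return F.normalize(t, dim = -1)

t = torch.randn(2, 5)
print(l2norm(t).norm(dim = -1))   # every row now has unit L2 norm
```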
@@ -19,15 +26,17 @@ class MemoryMLP(Module):
         self,
         x
     ):
+        residual = x
+
         for ind, weight in enumerate(self.weights):
             is_first = ind == 0
 
             if not is_first:
-                x = F.silu(x)
+                x = F.gelu(x)
 
             x = x @ weight
 
-        return x
+        return x + residual
 
 # memory mlp, but with gated residual + final projection
 
@@ -36,7 +45,7 @@ class GatedResidualMemoryMLP(Module):
         self,
         dim,
         depth,
-        expansion_factor = 2.
+        expansion_factor = 4.
     ):
         super().__init__()
         dim_hidden = int(dim * expansion_factor)
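Each memory MLP variant in this file picks up the same two changes seen in the `MemoryMLP` hunk above: GELU replaces SiLU between weight matrices, and the input is added back as an outer residual (plus, here, a larger default `expansion_factor`). A minimal self-contained sketch of the residual pattern, with made-up sizes and plain parameters standing in for the package's weight lists:

```python
import torch
import torch.nn.functional as F
from torch import nn

class TinyResidualMLP(nn.Module):
    # sketch of the pattern in the diff: GELU between weights, input added back at the end
    def __init__(self, dim = 16, depth = 2):
        super().__init__()
        self.weights = nn.ParameterList([
            nn.Parameter(torch.randn(dim, dim) * dim ** -0.5) for _ in range(depth)
        ])

    def forward(self, x):
        residual = x

        for ind, weight in enumerate(self.weights):
            if ind != 0:
                x = F.gelu(x)
            x = x @ weight

        return x + residual

x = torch.randn(3, 2, 16)
print(TinyResidualMLP()(x).shape)   # torch.Size([3, 2, 16])
```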
@@ -58,11 +67,13 @@ class GatedResidualMemoryMLP(Module):
         self,
         x
     ):
+        residual = x
+
         for weight1, weight2, to_gates in self.weights:
             res = x
 
             hidden = x @ weight1
-            hidden = F.silu(hidden)
+            hidden = F.gelu(hidden)
             branch_out = hidden @ weight2
 
             # gated residual
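The `# gated residual` marker above refers to the gating in the next hunk, where `res.lerp(branch_out, gates.sigmoid())` interpolates element-wise between the per-layer residual `res` and the branch output, with the sigmoid gate choosing how much of each to keep. A toy demonstration of that interpolation with hand-picked values:

```python
import torch

# res.lerp(branch_out, g) == res + g * (branch_out - res), with g = gates.sigmoid() in (0, 1)
res = torch.zeros(4)
branch_out = torch.ones(4)
gates = torch.tensor([-10., 0., 2., 10.])

print(res.lerp(branch_out, gates.sigmoid()))   # ~[0.0000, 0.5000, 0.8808, 1.0000]
```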
@@ -70,7 +81,9 @@ class GatedResidualMemoryMLP(Module):
             gates = cat((branch_out, res), dim = -1) @ to_gates
             x = res.lerp(branch_out, gates.sigmoid())
 
-        return x @ self.final_proj
+        out = x @ self.final_proj
+
+        return out + residual
 
 # memory mlp with factorized weights
 # so can tradeoff capacity for smaller chunk sizes
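The trailing context above introduces `FactorizedMemoryMLP`, whose `x @ weight1 @ weight2` product is pre-existing; the next hunk only adds the outer residual and swaps the activation. For reference, a toy comparison (arbitrary example sizes, not taken from the package) of how a factorized weight pair trades capacity for far fewer parameters while keeping the same input/output interface:

```python
import torch

dim, rank = 512, 32                         # example sizes, not the package defaults

w1 = torch.randn(dim, rank)
w2 = torch.randn(rank, dim)

print(dim * dim, dim * rank + rank * dim)   # 262144 vs 32768 parameters per layer

x = torch.randn(3, 2, dim)
print((x @ w1 @ w2).shape)                  # torch.Size([3, 2, 512]), same interface as a full weight
```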
@@ -98,15 +111,17 @@ class FactorizedMemoryMLP(Module):
         self,
         x
     ):
+        residual = x
+
         for ind, (weight1, weight2) in enumerate(self.weights):
             is_first = ind == 0
 
             if not is_first:
-                x = F.silu(x)
+                x = F.gelu(x)
 
             x = x @ weight1 @ weight2
 
-        return x
+        return x + residual
 
 # improvised attention as memory module
 
@@ -133,10 +148,12 @@ class MemoryAttention(Module):
             nn.init.xavier_uniform_(weight)
 
     def forward(self, x):
+        residual = x
+
         wq, wk, wv, ffw1, ffw2 = self.weights
 
-        q = F.normalize(x @ wq, dim = -1)
-        k = F.normalize(x @ wk, dim = -1)
+        q = l2norm(x @ wq)
+        k = l2norm(x @ wk)
         v = x @ wv
 
         attn_out = F.scaled_dot_product_attention(
@@ -145,9 +162,10 @@ class MemoryAttention(Module):
             is_causal = True
         )
 
-        x = x + attn_out
+        # parallel attention + feedforward block
+        # as in PaLM + Gpt-J
 
-        h = F.silu(x @ ffw1)
-        out = h @ ffw2
+        h = F.gelu(x @ ffw1)
+        ff_out = h @ ffw2
 
-        return out
+        return attn_out + ff_out + residual
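`MemoryAttention.forward` now computes the attention and feedforward branches from the same input and sums them together with the residual, the parallel formulation the new comments attribute to PaLM and GPT-J, replacing the previous sequential `x = x + attn_out` followed by the feedforward. A reduced single-head sketch of that data flow with made-up dimensions (the real module stores its weights in a `ParameterList` and may pass attention arguments not visible in this hunk):

```python
import torch
import torch.nn.functional as F

def l2norm(t):
    return F.normalize(t, dim = -1)

dim, dim_ff = 16, 32
wq, wk, wv = (torch.randn(dim, dim) for _ in range(3))
ffw1, ffw2 = torch.randn(dim, dim_ff), torch.randn(dim_ff, dim)

x = torch.randn(2, 8, dim)
residual = x

# attention branch, with l2-normalized queries / keys as in the diff
q, k, v = l2norm(x @ wq), l2norm(x @ wk), x @ wv
attn_out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

# feedforward branch reads the same x, not the attention output
ff_out = F.gelu(x @ ffw1) @ ffw2

out = attn_out + ff_out + residual
print(out.shape)   # torch.Size([2, 8, 16])
```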
@@ -304,13 +304,26 @@ class NeuralMemory(Module):
             nn.Sigmoid()
         ) if heads > 1 else None
 
-        # memory mlp
+        # memory model
 
         if not exists(model):
             model = MemoryMLP(dim_head, **default_model_kwargs)
 
+        # validate memory model
+
         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
 
+        test_shape = (3, 2, dim_head)
+
+        with torch.no_grad():
+            try:
+                test_input = torch.randn(test_shape)
+                mem_model_output = model(test_input)
+            except:
+                raise RuntimeError(f'memory model unable to accept a tensor of shape {test_shape}')
+
+            assert mem_model_output.shape == test_shape, 'output of memory model needs to be same shape as input'
+
         # the memory is the weights of the model
 
         self.memory_model = model
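The `NeuralMemory` constructor now dry-runs the supplied memory model on a dummy `(3, 2, dim_head)` tensor and asserts the output shape matches, so an incompatible custom model fails at construction time rather than deep inside training. A standalone illustration of the same check against a toy model (the helper name below is hypothetical, not part of the package):

```python
import torch
from torch import nn

def validate_memory_model(model, dim_head):
    # mirrors the check added in the diff: no buffers, and
    # (batch, seq, dim_head) must map to the same shape
    assert next(model.buffers(), None) is None, 'model cannot have buffers for now'

    test_shape = (3, 2, dim_head)

    with torch.no_grad():
        try:
            out = model(torch.randn(test_shape))
        except Exception:
            raise RuntimeError(f'memory model unable to accept a tensor of shape {test_shape}')

        assert out.shape == test_shape, 'output of memory model needs to be same shape as input'

validate_memory_model(nn.Linear(64, 64), dim_head = 64)      # passes
# validate_memory_model(nn.Linear(64, 32), dim_head = 64)    # would fail the shape assert
```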
@@ -772,10 +785,14 @@ class NeuralMemory(Module):
 
         # retrieve
 
-        if exists(prev_layer_updates):
-            prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+        retrieve_chunk_size = default(chunk_size, self.retrieve_chunk_size)
+
+        if retrieve_chunk_size != 1:
+            if exists(prev_layer_updates):
+                prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+
+            updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
 
-        updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
 
         retrieved = self.retrieve_memories(
             seq,
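In the retrieval path above, the one-chunk left pad of the update tensors is now skipped when the effective retrieve chunk size is 1. The `pad_at_dim` helper is not shown in this diff; assuming it pads the chosen dimension by the given (left, right) amounts (the definition below is a stand-in, not the package's), the pad makes each chunk's update visible one chunk later along the sequence:

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # assumed stand-in for the helper referenced in the diff
    dims_from_right = (t.ndim - dim - 1) if dim >= 0 else (-dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

updates = torch.arange(1., 4.).view(1, 3, 1)              # three per-chunk updates
print(pad_at_dim(updates, (1, 0), dim = 1).squeeze(-1))   # tensor([[0., 1., 2., 3.]]) -> each update lands one chunk later
```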
@@ -30,17 +30,18 @@ SEQ_LEN = 512
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
-NEURAL_MEM_LAYERS = (2, 4, 6)                   # layers 2, 4, 6 have neural memory, can add more
+NEURAL_MEM_LAYERS = (2, 4, 6)                     # layers 2, 4, 6 have neural memory, can add more
 NEURAL_MEM_GATE_ATTN_OUTPUT = False
 NEURAL_MEM_MOMENTUM = True
-NEURAL_MEM_QK_NORM = False
+NEURAL_MEM_QK_NORM = True
 WINDOW_SIZE = 32
-NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2       # set smaller for more granularity for learning rate / momentum etc
+NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2         # set smaller for more granularity for learning rate / momentum etc
 SLIDING_WINDOWS = True
-WEIGHT_TIE_MEMORY_MODEL = True                  # set to have memory MLP shared across layers
-STORE_ATTN_POOL_CHUNKS = True                   # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
+WEIGHT_TIE_MEMORY_MODEL = False                   # set to have memory MLP shared across layers
+PREV_MEM_UPDATE_FOR_WEIGHTS = True,
+STORE_ATTN_POOL_CHUNKS = True                     # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
 MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
-KV_RECON_LOSS_WEIGHT = 0.
+KV_RECON_LOSS_WEIGHT = 1.
 
 # experiment related
 
@@ -90,6 +91,7 @@ model = MemoryAsContextTransformer(
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
     weight_tie_memory_model = WEIGHT_TIE_MEMORY_MODEL,
+    prev_neural_mem_update_for_weights = PREV_MEM_UPDATE_FOR_WEIGHTS,
     neural_memory_model = MemoryMLP(
         dim = 64,
         depth = NEURAL_MEMORY_DEPTH
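In the training script, the defaults now enable neural-memory QK normalization and the KV reconstruction loss weight, disable memory-model weight tying, and introduce a `PREV_MEM_UPDATE_FOR_WEIGHTS` flag that is threaded into the transformer as `prev_neural_mem_update_for_weights`. Note that the trailing comma on the flag's line makes its value a one-element tuple rather than a bare boolean, though it still evaluates as truthy:

```python
PREV_MEM_UPDATE_FOR_WEIGHTS = True,   # trailing comma -> this is the tuple (True,)

print(type(PREV_MEM_UPDATE_FOR_WEIGHTS))    # <class 'tuple'>
print(bool(PREV_MEM_UPDATE_FOR_WEIGHTS))    # True, so truth tests still pass
```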
Three additional files are unchanged between 0.2.4 and 0.2.6.