titans-pytorch 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/memory_models.py

@@ -3,6 +3,13 @@ from torch import nn, cat
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Parameter, ParameterList

+ # functions
+
+ def l2norm(t):
+     return F.normalize(t, dim = -1)
+
+ # memory mlp proposed in TTT
+
  class MemoryMLP(Module):
      def __init__(
          self,
@@ -19,15 +26,17 @@ class MemoryMLP(Module):
          self,
          x
      ):
+         residual = x
+
          for ind, weight in enumerate(self.weights):
              is_first = ind == 0

              if not is_first:
-                 x = F.silu(x)
+                 x = F.gelu(x)

              x = x @ weight

-         return x
+         return x + residual

  # memory mlp, but with gated residual + final projection

@@ -36,7 +45,7 @@ class GatedResidualMemoryMLP(Module):
          self,
          dim,
          depth,
-         expansion_factor = 2.
+         expansion_factor = 4.
      ):
          super().__init__()
          dim_hidden = int(dim * expansion_factor)
@@ -58,11 +67,13 @@ class GatedResidualMemoryMLP(Module):
          self,
          x
      ):
+         residual = x
+
          for weight1, weight2, to_gates in self.weights:
              res = x

              hidden = x @ weight1
-             hidden = F.silu(hidden)
+             hidden = F.gelu(hidden)
              branch_out = hidden @ weight2

              # gated residual
@@ -70,7 +81,9 @@ class GatedResidualMemoryMLP(Module):
              gates = cat((branch_out, res), dim = -1) @ to_gates
              x = res.lerp(branch_out, gates.sigmoid())

-         return x @ self.final_proj
+         out = x @ self.final_proj
+
+         return out + residual

  # memory mlp with factorized weights
  # so can tradeoff capacity for smaller chunk sizes
@@ -98,15 +111,17 @@ class FactorizedMemoryMLP(Module):
          self,
          x
      ):
+         residual = x
+
          for ind, (weight1, weight2) in enumerate(self.weights):
              is_first = ind == 0

              if not is_first:
-                 x = F.silu(x)
+                 x = F.gelu(x)

              x = x @ weight1 @ weight2

-         return x
+         return x + residual

  # improvised attention as memory module

@@ -133,10 +148,12 @@ class MemoryAttention(Module):
              nn.init.xavier_uniform_(weight)

      def forward(self, x):
+         residual = x
+
          wq, wk, wv, ffw1, ffw2 = self.weights

-         q = F.normalize(x @ wq, dim = -1)
-         k = F.normalize(x @ wk, dim = -1)
+         q = l2norm(x @ wq)
+         k = l2norm(x @ wk)
          v = x @ wv

          attn_out = F.scaled_dot_product_attention(
@@ -145,9 +162,10 @@ class MemoryAttention(Module):
              is_causal = True
          )

-         x = x + attn_out
+         # parallel attention + feedforward block
+         # as in PaLM + Gpt-J

-         h = F.silu(x @ ffw1)
-         out = h @ ffw2
+         h = F.gelu(x @ ffw1)
+         ff_out = h @ ffw2

-         return out
+         return attn_out + ff_out + residual
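For orientation, the net effect of the memory_models.py changes is that each memory model now applies GELU instead of SiLU between layers and wraps its whole forward pass in an outer residual. The sketch below is an illustrative reconstruction of that pattern, not the packaged class: the name MemoryMLPSketch and the square dim x dim weight shapes are assumptions made only for this demo.

```python
import torch
import torch.nn.functional as F
from torch.nn import Module, Parameter, ParameterList

class MemoryMLPSketch(Module):
    # hypothetical stand-in for the package's MemoryMLP; square weights assumed for brevity
    def __init__(self, dim, depth):
        super().__init__()
        self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])

    def forward(self, x):
        residual = x                  # outer residual introduced in this diff

        for ind, weight in enumerate(self.weights):
            if ind != 0:
                x = F.gelu(x)         # activation switched from F.silu to F.gelu

            x = x @ weight

        return x + residual           # was `return x` before the change

mlp = MemoryMLPSketch(dim = 64, depth = 2)
out = mlp(torch.randn(2, 16, 64))     # (batch, seq, dim) in -> same shape out
```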
titans_pytorch/neural_memory.py

@@ -51,6 +51,9 @@ def default(*args):
              return arg
      return None

+ def identity(t):
+     return t
+
  def xnor(x, y):
      return not (x ^ y)

@@ -64,9 +67,6 @@ def safe_cat(inputs, dim = -2):

      return cat(inputs, dim = dim)

- def identity(t):
-     return t
-
  def dict_get_shape(td):
      return {k: v.shape for k, v in td.items()}

@@ -454,14 +454,14 @@ class NeuralMemory(Module):

          weights = TensorDict(weights)

-         # allow for neural memory of a previous layer and the past to produce gradients that become the weights of the current one generating the surprise
-         # think this is necessary otherwise the memory model is static (unless if paper is misunderstood)
-         # improvise (or perhaps correcting to) a solution
+         # allow for neural memory of a previous layer to influence surprise of current layer
+
+         weights_for_surprise = weights

          if exists(prev_layer_updates):
              prev_layer_updates = TensorDict(prev_layer_updates)

-             weights = weights + prev_layer_updates
+             weights_for_surprise = weights_for_surprise + prev_layer_updates

          per_sample_grad_fn = self.per_sample_grad_fn_expanded_weights # the weights will now have a batch * chunk dimension

@@ -506,11 +506,11 @@ class NeuralMemory(Module):
          # flatten batch and time if surprise depends on previous layer memory model

          if exists(prev_layer_updates):
-             weights = weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+             weights_for_surprise = weights_for_surprise.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))

          # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

-         grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
+         grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)

          grads = TensorDict(grads)

@@ -536,7 +536,15 @@ class NeuralMemory(Module):

          if not exists(past_state):
              empty_dict = {key: None for key in weights.keys()}
-             past_state = (empty_dict, empty_dict)
+
+             # minibatch_init_weight corresponds to W0 in figure 7 of TTT paper
+
+             minibatch_init_weight = weights
+
+             if dict_get_shape(weights) == self.init_weight_shape:
+                 minibatch_init_weight = weights.apply(lambda t: repeat(t, '... -> b 1 (...)', b = batch * heads))
+
+             past_state = (minibatch_init_weight, empty_dict)

          past_last_update, past_last_momentum = past_state

@@ -734,7 +742,7 @@ class NeuralMemory(Module):

          # retrieve

-         retrieved = self.retrieve_memories(token, updates + weights, chunk_size = 1)
+         retrieved = self.retrieve_memories(token, weights, chunk_size = 1)

          # next state tuple

@@ -785,14 +793,18 @@ class NeuralMemory(Module):

          # retrieve

-         if exists(prev_layer_updates):
-             prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+         retrieve_chunk_size = default(chunk_size, self.retrieve_chunk_size)
+
+         if retrieve_chunk_size != 1:
+             if exists(prev_layer_updates):
+                 prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+
+             updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))

-         updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))

          retrieved = self.retrieve_memories(
              seq,
-             mem_model_weights + updates,
+             mem_model_weights,
              chunk_size = chunk_size,
              prev_layer_updates = prev_layer_updates
          )
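The new minibatch_init_weight logic above tiles the initial memory weights (W0 in figure 7 of the TTT paper) across the flattened batch * heads dimension before they seed the past state. The snippet below is a minimal shape demo of that einops repeat pattern; the sizes (batch = 2, heads = 4, a single 64 x 64 weight) are made up for illustration and stand in for the real TensorDict entries.

```python
import torch
from einops import repeat

batch, heads = 2, 4                  # illustrative sizes, not taken from the package
w0 = torch.randn(64, 64)             # one initial memory weight, analogous to W0

# same pattern as in the diff: add a (batch * heads) axis, a singleton chunk axis,
# and flatten the original weight dimensions
tiled = repeat(w0, '... -> b 1 (...)', b = batch * heads)

print(tiled.shape)                   # torch.Size([8, 1, 4096])
```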
titans_pytorch-0.2.5.dist-info/METADATA → titans_pytorch-0.2.7.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.2.5
+ Version: 0.2.7
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -27,7 +27,7 @@ License: MIT License
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  SOFTWARE.
  License-File: LICENSE
- Keywords: artificial intelligence,deep learning,linear attention,neural memory module,test time training
+ Keywords: artificial intelligence,deep learning,linear attention,memory,test time training
  Classifier: Development Status :: 4 - Beta
  Classifier: Intended Audience :: Developers
  Classifier: License :: OSI Approved :: MIT License
@@ -56,7 +56,7 @@ Description-Content-Type: text/markdown

  <img src="./fig1.png" width="400px"></img>

- ## Titans - Pytorch
+ ## Titans - Pytorch (wip)

  Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.

titans_pytorch-0.2.7.dist-info/RECORD (added)

@@ -0,0 +1,9 @@
+ titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+ titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
+ titans_pytorch/memory_models.py,sha256=CD8pQ-IUfTDvPmekuPTsZHE3Vy265QtbiUn_siJhA78,4064
+ titans_pytorch/neural_memory.py,sha256=WAeR-nOpy1XbBP590By1-tCgirulqPbFGut4H1B77-g,24910
+ titans_pytorch-0.2.7.dist-info/METADATA,sha256=ndFb28pAe8xWmNU6oncV8VJDDPImo3aCuBv0d0JylIs,6811
+ titans_pytorch-0.2.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ titans_pytorch-0.2.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ titans_pytorch-0.2.7.dist-info/RECORD,,
titans_pytorch-0.2.5.dist-info/RECORD (removed)

@@ -1,9 +0,0 @@
- titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
- titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
- titans_pytorch/memory_models.py,sha256=Ew28waD9gf1wn-5Nkdc676u1I92IqzaOAw-tv0JXMwc,3777
- titans_pytorch/neural_memory.py,sha256=YiBsMiqYn-Hva4yhxfaqkGV857vZIASxi5Z0TT0FC10,24606
- titans_pytorch-0.2.5.dist-info/METADATA,sha256=x3RePuTDf3rUT3vtvge1X3Ry18Y3tV_swCgycbtSCjQ,6819
- titans_pytorch-0.2.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- titans_pytorch-0.2.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- titans_pytorch-0.2.5.dist-info/RECORD,,