titans-pytorch 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,6 +3,13 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Parameter, ParameterList
 
+# functions
+
+def l2norm(t):
+    return F.normalize(t, dim = -1)
+
+# memory mlp proposed in TTT
+
 class MemoryMLP(Module):
     def __init__(
         self,
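
The new `l2norm` helper simply unit-normalizes along the last dimension, so later hunks can write `l2norm(x @ wq)` instead of repeating `F.normalize(..., dim = -1)`. A quick standalone check of what it does:

```python
import torch
import torch.nn.functional as F

def l2norm(t):
    return F.normalize(t, dim = -1)

t = torch.randn(2, 3, 8)
normed = l2norm(t)

# every vector along the last dimension now has unit L2 norm
assert torch.allclose(normed.norm(dim = -1), torch.ones(2, 3), atol = 1e-6)
```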
@@ -19,15 +26,17 @@ class MemoryMLP(Module):
         self,
         x
     ):
+        residual = x
+
         for ind, weight in enumerate(self.weights):
             is_first = ind == 0
 
             if not is_first:
-                x = F.silu(x)
+                x = F.gelu(x)
 
             x = x @ weight
 
-        return x
+        return x + residual
 
 # memory mlp, but with gated residual + final projection
 
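
For context, here is a minimal sketch of what `MemoryMLP` looks like after this hunk: GELU between layers and a single residual wrapped around the whole stack. The `forward` mirrors the diff; the `__init__` body is an approximation for illustration, not a verbatim copy of the package's.

```python
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Module, Parameter, ParameterList

class MemoryMLPSketch(Module):
    # approximate reconstruction of MemoryMLP after this change, for illustration only
    def __init__(self, dim, depth):
        super().__init__()
        self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])

        for weight in self.weights:
            nn.init.xavier_uniform_(weight)

    def forward(self, x):
        residual = x               # new: residual around the whole stack

        for ind, weight in enumerate(self.weights):
            if ind != 0:           # no activation before the first matmul
                x = F.gelu(x)      # silu -> gelu in 0.2.6
            x = x @ weight

        return x + residual        # residual added back at the end

mlp = MemoryMLPSketch(dim = 64, depth = 2)
out = mlp(torch.randn(1, 16, 64))  # (batch, seq, dim) -> same shape
```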
@@ -36,7 +45,7 @@ class GatedResidualMemoryMLP(Module):
         self,
         dim,
         depth,
-        expansion_factor = 2.
+        expansion_factor = 4.
     ):
         super().__init__()
         dim_hidden = int(dim * expansion_factor)
@@ -58,11 +67,13 @@ class GatedResidualMemoryMLP(Module):
         self,
         x
     ):
+        residual = x
+
         for weight1, weight2, to_gates in self.weights:
             res = x
 
             hidden = x @ weight1
-            hidden = F.silu(hidden)
+            hidden = F.gelu(hidden)
             branch_out = hidden @ weight2
 
             # gated residual
@@ -70,7 +81,9 @@ class GatedResidualMemoryMLP(Module):
             gates = cat((branch_out, res), dim = -1) @ to_gates
             x = res.lerp(branch_out, gates.sigmoid())
 
-        return x @ self.final_proj
+        out = x @ self.final_proj
+
+        return out + residual
 
 # memory mlp with factorized weights
 # so can tradeoff capacity for smaller chunk sizes
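
The gated residual in the hunk above relies on `Tensor.lerp`: `res.lerp(branch_out, g)` computes `res + g * (branch_out - res)`, i.e. a learned, elementwise blend between the branch output and its input. A tiny standalone check (shapes are illustrative only):

```python
import torch

res        = torch.randn(2, 8, 16)   # residual path
branch_out = torch.randn(2, 8, 16)   # mlp branch output
gates      = torch.randn(2, 8, 16)   # pre-sigmoid gate logits

blended = res.lerp(branch_out, gates.sigmoid())

# same thing written out explicitly
manual = res + gates.sigmoid() * (branch_out - res)
assert torch.allclose(blended, manual, atol = 1e-6)
```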
@@ -98,15 +111,17 @@ class FactorizedMemoryMLP(Module):
         self,
         x
     ):
+        residual = x
+
         for ind, (weight1, weight2) in enumerate(self.weights):
             is_first = ind == 0
 
             if not is_first:
-                x = F.silu(x)
+                x = F.gelu(x)
 
             x = x @ weight1 @ weight2
 
-        return x
+        return x + residual
 
 # improvised attention as memory module
 
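
As the class's own comment says, the factorized variant stores each layer as a low-rank pair, so `x @ weight1 @ weight2` costs `dim * k + k * dim` parameters instead of `dim * dim`. A minimal sketch of the idea follows; the parameter layout (two flat lists, a `k` keyword, the 0.02 init scale) is an assumption for illustration and differs from the package's own `__init__`.

```python
import torch
import torch.nn.functional as F
from torch.nn import Module, Parameter, ParameterList

class FactorizedMemoryMLPSketch(Module):
    # illustrative low-rank variant: each layer is (dim x k) @ (k x dim)
    def __init__(self, dim, depth, k = 32):
        super().__init__()
        self.weights1 = ParameterList([Parameter(torch.randn(dim, k) * 0.02) for _ in range(depth)])
        self.weights2 = ParameterList([Parameter(torch.randn(k, dim) * 0.02) for _ in range(depth)])

    def forward(self, x):
        residual = x

        for ind, (weight1, weight2) in enumerate(zip(self.weights1, self.weights2)):
            if ind != 0:
                x = F.gelu(x)
            x = x @ weight1 @ weight2   # low-rank factorization of one square weight

        return x + residual
```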
@@ -133,10 +148,12 @@ class MemoryAttention(Module):
             nn.init.xavier_uniform_(weight)
 
     def forward(self, x):
+        residual = x
+
         wq, wk, wv, ffw1, ffw2 = self.weights
 
-        q = F.normalize(x @ wq, dim = -1)
-        k = F.normalize(x @ wk, dim = -1)
+        q = l2norm(x @ wq)
+        k = l2norm(x @ wk)
         v = x @ wv
 
         attn_out = F.scaled_dot_product_attention(
@@ -145,9 +162,10 @@ class MemoryAttention(Module):
             is_causal = True
         )
 
-        x = x + attn_out
+        # parallel attention + feedforward block
+        # as in PaLM + Gpt-J
 
-        h = F.silu(x @ ffw1)
-        out = h @ ffw2
+        h = F.gelu(x @ ffw1)
+        ff_out = h @ ffw2
 
-        return out
+        return attn_out + ff_out + residual
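
The comment added in this hunk refers to the parallel formulation used in PaLM and GPT-J: attention and the feedforward both read the same input, and their outputs are summed with the residual, rather than the feedforward being applied after the attention residual. A rough functional sketch of that pattern (weight names mirror the diff; shapes and init are illustrative):

```python
import torch
import torch.nn.functional as F

def parallel_block(x, wq, wk, wv, ffw1, ffw2):
    # both branches see the same (un-updated) input x
    residual = x

    q = F.normalize(x @ wq, dim = -1)   # l2norm(x @ wq) in the diff
    k = F.normalize(x @ wk, dim = -1)
    v = x @ wv

    attn_out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

    ff_out = F.gelu(x @ ffw1) @ ffw2    # feedforward on x, not on (x + attn_out)

    # the sequential (pre-0.2.6) version fed (x + attn_out) into the feedforward instead
    return attn_out + ff_out + residual

dim = 64
x = torch.randn(1, 16, dim)
ws = [torch.randn(dim, dim) * 0.02 for _ in range(5)]
out = parallel_block(x, *ws)            # same shape as x
```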
@@ -785,10 +785,14 @@ class NeuralMemory(Module):
 
         # retrieve
 
-        if exists(prev_layer_updates):
-            prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+        retrieve_chunk_size = default(chunk_size, self.retrieve_chunk_size)
+
+        if retrieve_chunk_size != 1:
+            if exists(prev_layer_updates):
+                prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+
+            updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
 
-        updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
 
         retrieved = self.retrieve_memories(
             seq,
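
Here `pad_at_dim(t, (1, 0), dim = 1)` prepends one zero entry along the chunk dimension, which presumably shifts the stored updates back by one chunk so retrieval for a given chunk only sees memory written before it; the new guard skips that shift when retrieval runs token by token (`retrieve_chunk_size == 1`). A minimal stand-in for the helper, under the assumption that it wraps `F.pad` (the package's own helper may differ in detail):

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # hypothetical reimplementation: pad `pad = (left, right)` entries along `dim`
    dims_from_right = (t.ndim - dim - 1) if dim >= 0 else (-dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

updates = torch.randn(2, 4, 8, 8)               # (batch, chunks, ...) per-chunk weight updates
shifted = pad_at_dim(updates, (1, 0), dim = 1)  # prepend one zero chunk -> (2, 5, 8, 8)

assert shifted.shape[1] == updates.shape[1] + 1
assert torch.equal(shifted[:, 0], torch.zeros_like(updates[:, 0]))
```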
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.5
+Version: 0.2.6
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -56,7 +56,7 @@ Description-Content-Type: text/markdown
 
 <img src="./fig1.png" width="400px"></img>
 
-## Titans - Pytorch
+## Titans - Pytorch (wip)
 
 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
+titans_pytorch/memory_models.py,sha256=CD8pQ-IUfTDvPmekuPTsZHE3Vy265QtbiUn_siJhA78,4064
+titans_pytorch/neural_memory.py,sha256=UNST32JulrDw3_dPljSFU3ZCLofDH-KBoFxx8j1Oii4,24733
+titans_pytorch-0.2.6.dist-info/METADATA,sha256=ZsjIkdKM2zanf74Q9EOqhCreVIQsdRNnOEIRBTISwAI,6825
+titans_pytorch-0.2.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.2.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.2.6.dist-info/RECORD,,
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
-titans_pytorch/memory_models.py,sha256=Ew28waD9gf1wn-5Nkdc676u1I92IqzaOAw-tv0JXMwc,3777
-titans_pytorch/neural_memory.py,sha256=YiBsMiqYn-Hva4yhxfaqkGV857vZIASxi5Z0TT0FC10,24606
-titans_pytorch-0.2.5.dist-info/METADATA,sha256=x3RePuTDf3rUT3vtvge1X3Ry18Y3tV_swCgycbtSCjQ,6819
-titans_pytorch-0.2.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.2.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.2.5.dist-info/RECORD,,