titans-pytorch 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/memory_models.py +34 -4
- titans_pytorch/neural_memory.py +2 -3
- {titans_pytorch-0.2.7.dist-info → titans_pytorch-0.2.9.dist-info}/METADATA +1 -1
- titans_pytorch-0.2.9.dist-info/RECORD +9 -0
- titans_pytorch-0.2.7.dist-info/RECORD +0 -9
- {titans_pytorch-0.2.7.dist-info → titans_pytorch-0.2.9.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.2.7.dist-info → titans_pytorch-0.2.9.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/memory_models.py
CHANGED
@@ -3,11 +3,33 @@ from torch import nn, cat
|
|
3
3
|
import torch.nn.functional as F
|
4
4
|
from torch.nn import Module, ModuleList, Parameter, ParameterList
|
5
5
|
|
6
|
+
from einops import rearrange
|
7
|
+
|
6
8
|
# functions
|
7
9
|
|
8
10
|
def l2norm(t):
|
9
11
|
return F.normalize(t, dim = -1)
|
10
12
|
|
13
|
+
# norms
|
14
|
+
|
15
|
+
class LayerNorm(Module):
|
16
|
+
def __init__(
|
17
|
+
self,
|
18
|
+
dim
|
19
|
+
):
|
20
|
+
super().__init__()
|
21
|
+
|
22
|
+
self.ln = nn.LayerNorm(dim, elementwise_affine = False)
|
23
|
+
self.gamma = Parameter(torch.zeros(dim))
|
24
|
+
|
25
|
+
def forward(self, x):
|
26
|
+
gamma = self.gamma
|
27
|
+
|
28
|
+
if gamma.ndim == 2:
|
29
|
+
gamma = rearrange(gamma, 'b d -> b 1 d')
|
30
|
+
|
31
|
+
return self.ln(x) * (gamma + 1.)
|
32
|
+
|
11
33
|
# memory mlp proposed in TTT
|
12
34
|
|
13
35
|
class MemoryMLP(Module):
|
@@ -19,6 +41,8 @@ class MemoryMLP(Module):
|
|
19
41
|
super().__init__()
|
20
42
|
self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])
|
21
43
|
|
44
|
+
self.ln = LayerNorm(dim)
|
45
|
+
|
22
46
|
for weight in self.weights:
|
23
47
|
nn.init.xavier_uniform_(weight)
|
24
48
|
|
@@ -36,7 +60,7 @@ class MemoryMLP(Module):
|
|
36
60
|
|
37
61
|
x = x @ weight
|
38
62
|
|
39
|
-
return x + residual
|
63
|
+
return self.ln(x) + residual
|
40
64
|
|
41
65
|
# memory mlp, but with gated residual + final projection
|
42
66
|
|
@@ -60,6 +84,8 @@ class GatedResidualMemoryMLP(Module):
|
|
60
84
|
|
61
85
|
self.final_proj = Parameter(torch.randn(dim, dim))
|
62
86
|
|
87
|
+
self.ln = LayerNorm(dim)
|
88
|
+
|
63
89
|
for param in self.parameters():
|
64
90
|
nn.init.xavier_uniform_(param)
|
65
91
|
|
@@ -83,7 +109,7 @@ class GatedResidualMemoryMLP(Module):
|
|
83
109
|
|
84
110
|
out = x @ self.final_proj
|
85
111
|
|
86
|
-
return out + residual
|
112
|
+
return self.ln(out) + residual
|
87
113
|
|
88
114
|
# memory mlp with factorized weights
|
89
115
|
# so can tradeoff capacity for smaller chunk sizes
|
@@ -103,6 +129,8 @@ class FactorizedMemoryMLP(Module):
|
|
103
129
|
]) for _ in range(depth)
|
104
130
|
])
|
105
131
|
|
132
|
+
self.ln = LayerNorm(dim)
|
133
|
+
|
106
134
|
for weight1, weight2 in self.weights:
|
107
135
|
nn.init.xavier_uniform_(weight1)
|
108
136
|
nn.init.xavier_uniform_(weight2)
|
@@ -121,7 +149,7 @@ class FactorizedMemoryMLP(Module):
|
|
121
149
|
|
122
150
|
x = x @ weight1 @ weight2
|
123
151
|
|
124
|
-
return x + residual
|
152
|
+
return self.ln(x) + residual
|
125
153
|
|
126
154
|
# improvised attention as memory module
|
127
155
|
|
@@ -144,6 +172,8 @@ class MemoryAttention(Module):
|
|
144
172
|
nn.Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
|
145
173
|
])
|
146
174
|
|
175
|
+
self.ln = LayerNorm(dim)
|
176
|
+
|
147
177
|
for weight in self.weights:
|
148
178
|
nn.init.xavier_uniform_(weight)
|
149
179
|
|
@@ -168,4 +198,4 @@ class MemoryAttention(Module):
|
|
168
198
|
h = F.gelu(x @ ffw1)
|
169
199
|
ff_out = h @ ffw2
|
170
200
|
|
171
|
-
return attn_out + ff_out + residual
|
201
|
+
return self.ln(attn_out + ff_out) + residual
|
titans_pytorch/neural_memory.py
CHANGED
@@ -742,7 +742,7 @@ class NeuralMemory(Module):
|
|
742
742
|
|
743
743
|
# retrieve
|
744
744
|
|
745
|
-
retrieved = self.retrieve_memories(token,
|
745
|
+
retrieved = self.retrieve_memories(token, updates, chunk_size = 1)
|
746
746
|
|
747
747
|
# next state tuple
|
748
748
|
|
@@ -801,10 +801,9 @@ class NeuralMemory(Module):
|
|
801
801
|
|
802
802
|
updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
|
803
803
|
|
804
|
-
|
805
804
|
retrieved = self.retrieve_memories(
|
806
805
|
seq,
|
807
|
-
|
806
|
+
updates,
|
808
807
|
chunk_size = chunk_size,
|
809
808
|
prev_layer_updates = prev_layer_updates
|
810
809
|
)
|
@@ -0,0 +1,9 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
|
4
|
+
titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
|
5
|
+
titans_pytorch/neural_memory.py,sha256=YVbKl7DYKFWUgCawDTxXIEgJAcl7nq5OaZytmovIl8Q,24899
|
6
|
+
titans_pytorch-0.2.9.dist-info/METADATA,sha256=fSFt54zXLKB5gRhLTJd9551O0pF2qcYNlR7039yJiD0,6811
|
7
|
+
titans_pytorch-0.2.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.2.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.2.9.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
|
4
|
-
titans_pytorch/memory_models.py,sha256=CD8pQ-IUfTDvPmekuPTsZHE3Vy265QtbiUn_siJhA78,4064
|
5
|
-
titans_pytorch/neural_memory.py,sha256=WAeR-nOpy1XbBP590By1-tCgirulqPbFGut4H1B77-g,24910
|
6
|
-
titans_pytorch-0.2.7.dist-info/METADATA,sha256=ndFb28pAe8xWmNU6oncV8VJDDPImo3aCuBv0d0JylIs,6811
|
7
|
-
titans_pytorch-0.2.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
-
titans_pytorch-0.2.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
-
titans_pytorch-0.2.7.dist-info/RECORD,,
|
File without changes
|
File without changes
|