titans-pytorch 0.1.35__py3-none-any.whl → 0.1.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/__init__.py +3 -0
- titans_pytorch/memory_models.py +153 -0
- titans_pytorch/neural_memory.py +10 -152
- {titans_pytorch-0.1.35.dist-info → titans_pytorch-0.1.37.dist-info}/METADATA +1 -1
- titans_pytorch-0.1.37.dist-info/RECORD +9 -0
- titans_pytorch-0.1.35.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.35.dist-info → titans_pytorch-0.1.37.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.35.dist-info → titans_pytorch-0.1.37.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/memory_models.py
CHANGED
@@ -0,0 +1,153 @@
|
|
1
|
+
import torch
|
2
|
+
from torch import nn
|
3
|
+
import torch.nn.functional as F
|
4
|
+
from torch.nn import Module, ModuleList, Parameter, ParameterList
|
5
|
+
|
6
|
+
class MemoryMLP(Module):
    """Bias-free MLP memory built from `depth` square (dim, dim) matrices.

    SiLU is applied between consecutive matmuls: never before the first
    matrix and never after the last one. With depth == 0 the input is
    returned unchanged.
    """

    def __init__(
        self,
        dim,
        depth
    ):
        super().__init__()

        # allocate every layer first, then overwrite with xavier-uniform values
        layers = [Parameter(torch.randn(dim, dim)) for _ in range(depth)]
        self.weights = ParameterList(layers)

        for layer_weight in self.weights:
            nn.init.xavier_uniform_(layer_weight)

    def forward(
        self,
        x
    ):
        # assumes the last dimension of x equals `dim` - TODO confirm at call sites
        h = x
        for layer_idx, layer_weight in enumerate(self.weights):
            # activation only goes in front of the second and later layers
            h = (F.silu(h) if layer_idx else h) @ layer_weight
        return h
|
31
|
+
|
32
|
+
# memory mlp, but with gated residual + final projection
|
33
|
+
|
34
|
+
class GatedResidualMemoryMLP(Module):
    """Memory MLP whose layers are combined through learned gated residuals.

    Each of the `depth` layers runs a SiLU feed-forward branch
    (dim -> int(dim * expansion_factor) -> dim). It then blends the branch
    output with the layer input using a sigmoid gate. The gate is projected
    from the concatenation [branch_out, input]; gate 0 keeps the input and
    gate 1 takes the branch output. A final (dim, dim) projection is applied
    after the last layer.
    """

    def __init__(
        self,
        dim,
        depth,
        expansion_factor = 2.
    ):
        super().__init__()
        dim_hidden = int(dim * expansion_factor)

        # per layer: [ff in-projection, ff out-projection, gate projection from concat]
        self.weights = ParameterList([
            ParameterList([
                Parameter(torch.randn(dim, dim_hidden)),
                Parameter(torch.randn(dim_hidden, dim)),
                Parameter(torch.randn(dim * 2, dim)),
            ]) for _ in range(depth)
        ])

        self.final_proj = Parameter(torch.randn(dim, dim))

        # every parameter here is a 2d matrix, so xavier init applies to all of them
        for param in self.parameters():
            nn.init.xavier_uniform_(param)

    def forward(
        self,
        x
    ):
        for weight1, weight2, to_gates in self.weights:
            res = x

            hidden = x @ weight1
            hidden = F.silu(hidden)
            branch_out = hidden @ weight2

            # gated residual
            # torch.cat is used explicitly: this module does not import a bare `cat`

            gates = torch.cat((branch_out, res), dim = -1) @ to_gates
            x = res.lerp(branch_out, gates.sigmoid())

        return x @ self.final_proj
|
74
|
+
|
75
|
+
# memory mlp with factorized weights
|
76
|
+
# so that memory capacity can be traded off for smaller chunk sizes
|
77
|
+
|
78
|
+
class FactorizedMemoryMLP(Module):
    """Bias-free MLP memory where each layer is a rank-`k` factorized matrix.

    Layer i computes h @ A_i @ B_i, with A_i of shape (dim, k) and B_i of
    shape (k, dim). SiLU is applied between layers, but not before the first.
    """

    def __init__(
        self,
        dim,
        depth,
        k = 32
    ):
        super().__init__()

        def make_factor_pair():
            # (dim -> k) factor followed by (k -> dim) factor
            return ParameterList([
                Parameter(torch.randn(dim, k)),
                Parameter(torch.randn(k, dim)),
            ])

        self.weights = ParameterList([make_factor_pair() for _ in range(depth)])

        # re-initialize layer by layer, down-factor before up-factor
        for factor_pair in self.weights:
            for factor in factor_pair:
                nn.init.xavier_uniform_(factor)

    def forward(
        self,
        x
    ):
        out = x
        for layer_idx, (down, up) in enumerate(self.weights):
            if layer_idx > 0:
                out = F.silu(out)
            out = (out @ down) @ up
        return out
|
110
|
+
|
111
|
+
# improvised attention as memory module
|
112
|
+
|
113
|
+
class MemoryAttention(Module):
|
114
|
+
def __init__(
|
115
|
+
self,
|
116
|
+
dim,
|
117
|
+
scale = 8.,
|
118
|
+
expansion_factor = 2.
|
119
|
+
):
|
120
|
+
super().__init__()
|
121
|
+
self.scale = scale
|
122
|
+
dim_ff_hidden = int(dim * expansion_factor)
|
123
|
+
|
124
|
+
self.weights = nn.ParameterList([
|
125
|
+
nn.Parameter(torch.randn(dim, dim)), # queries
|
126
|
+
nn.Parameter(torch.randn(dim, dim)), # keys
|
127
|
+
nn.Parameter(torch.randn(dim, dim)), # values
|
128
|
+
nn.Parameter(torch.randn(dim, dim_ff_hidden)), # ff w1
|
129
|
+
nn.Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
|
130
|
+
])
|
131
|
+
|
132
|
+
for weight in self.weights:
|
133
|
+
nn.init.xavier_uniform_(weight)
|
134
|
+
|
135
|
+
def forward(self, x):
|
136
|
+
wq, wk, wv, ffw1, ffw2 = self.weights
|
137
|
+
|
138
|
+
q = F.normalize(x @ wq, dim = -1)
|
139
|
+
k = F.normalize(x @ wk, dim = -1)
|
140
|
+
v = x @ wv
|
141
|
+
|
142
|
+
attn_out = F.scaled_dot_product_attention(
|
143
|
+
q, k, v,
|
144
|
+
scale = self.scale,
|
145
|
+
is_causal = True
|
146
|
+
)
|
147
|
+
|
148
|
+
x = x + attn_out
|
149
|
+
|
150
|
+
h = F.silu(x @ ffw1)
|
151
|
+
out = h @ ffw2
|
152
|
+
|
153
|
+
return out
|
titans_pytorch/neural_memory.py
CHANGED
@@ -19,6 +19,10 @@ from titans_pytorch.associative_scan import (
|
|
19
19
|
pad_at_dim
|
20
20
|
)
|
21
21
|
|
22
|
+
from titans_pytorch.memory_models import(
|
23
|
+
MemoryMLP
|
24
|
+
)
|
25
|
+
|
22
26
|
import einx
|
23
27
|
from einops import rearrange, repeat, reduce, pack, unpack
|
24
28
|
from einops.layers.torch import Rearrange, Reduce
|
@@ -169,157 +173,6 @@ class AttentionPool(Module):
|
|
169
173
|
|
170
174
|
return reduce(x * attn, 'b n c d -> b n d', 'sum')
|
171
175
|
|
172
|
-
# classes
|
173
|
-
|
174
|
-
class MemoryMLP(Module):
    """Bias-free MLP memory built from `depth` square (dim, dim) matrices.

    SiLU is applied between consecutive matmuls: never before the first
    matrix and never after the last one. With depth == 0 the input is
    returned unchanged.
    """

    def __init__(
        self,
        dim,
        depth
    ):
        super().__init__()

        # allocate every layer first, then overwrite with xavier-uniform values
        layers = [Parameter(torch.randn(dim, dim)) for _ in range(depth)]
        self.weights = ParameterList(layers)

        for layer_weight in self.weights:
            nn.init.xavier_uniform_(layer_weight)

    def forward(
        self,
        x
    ):
        # assumes the last dimension of x equals `dim` - TODO confirm at call sites
        h = x
        for layer_idx, layer_weight in enumerate(self.weights):
            # activation only goes in front of the second and later layers
            h = (F.silu(h) if layer_idx else h) @ layer_weight
        return h
|
199
|
-
|
200
|
-
# memory mlp, but with gated residual + final projection
|
201
|
-
|
202
|
-
class GatedResidualMemoryMLP(Module):
    """Memory MLP whose layers are combined through learned gated residuals.

    Each layer runs a SiLU feed-forward branch, then interpolates between
    the layer input and the branch output. The interpolation uses a sigmoid
    gate projected from [branch_out, input]. A final (dim, dim) projection
    follows the last layer.
    """

    def __init__(
        self,
        dim,
        depth,
        expansion_factor = 2.
    ):
        super().__init__()
        dim_hidden = int(dim * expansion_factor)

        def make_layer():
            # [ff in-projection, ff out-projection, gate projection from concat]
            return ParameterList([
                Parameter(torch.randn(dim, dim_hidden)),
                Parameter(torch.randn(dim_hidden, dim)),
                Parameter(torch.randn(dim * 2, dim)),
            ])

        self.weights = ParameterList([make_layer() for _ in range(depth)])
        self.final_proj = Parameter(torch.randn(dim, dim))

        # all parameters are 2d matrices
        for param in self.parameters():
            nn.init.xavier_uniform_(param)

    def forward(
        self,
        x
    ):
        for w_in, w_out, w_gate in self.weights:
            branch = F.silu(x @ w_in) @ w_out

            # gated residual: gate -> 0 keeps x, gate -> 1 takes the branch
            gate_logits = torch.cat((branch, x), dim = -1) @ w_gate
            x = x.lerp(branch, gate_logits.sigmoid())

        return x @ self.final_proj
|
242
|
-
|
243
|
-
# memory mlp with factorized weights
|
244
|
-
# so that memory capacity can be traded off for smaller chunk sizes
|
245
|
-
|
246
|
-
class FactorizedMemoryMLP(Module):
    """Bias-free MLP memory where each layer is a rank-`k` factorized matrix.

    Layer i computes h @ A_i @ B_i, with A_i of shape (dim, k) and B_i of
    shape (k, dim). SiLU is applied between layers, but not before the first.
    """

    def __init__(
        self,
        dim,
        depth,
        k = 32
    ):
        super().__init__()

        def make_factor_pair():
            # (dim -> k) factor followed by (k -> dim) factor
            return ParameterList([
                Parameter(torch.randn(dim, k)),
                Parameter(torch.randn(k, dim)),
            ])

        self.weights = ParameterList([make_factor_pair() for _ in range(depth)])

        # re-initialize layer by layer, down-factor before up-factor
        for factor_pair in self.weights:
            for factor in factor_pair:
                nn.init.xavier_uniform_(factor)

    def forward(
        self,
        x
    ):
        out = x
        for layer_idx, (down, up) in enumerate(self.weights):
            if layer_idx > 0:
                out = F.silu(out)
            out = (out @ down) @ up
        return out
|
278
|
-
|
279
|
-
# improvised attention as memory module
|
280
|
-
|
281
|
-
class MemoryAttention(Module):
|
282
|
-
def __init__(
|
283
|
-
self,
|
284
|
-
dim,
|
285
|
-
scale = 8.,
|
286
|
-
expansion_factor = 2.
|
287
|
-
):
|
288
|
-
super().__init__()
|
289
|
-
self.scale = scale
|
290
|
-
dim_ff_hidden = int(dim * expansion_factor)
|
291
|
-
|
292
|
-
self.weights = nn.ParameterList([
|
293
|
-
nn.Parameter(torch.randn(dim, dim)), # queries
|
294
|
-
nn.Parameter(torch.randn(dim, dim)), # keys
|
295
|
-
nn.Parameter(torch.randn(dim, dim)), # values
|
296
|
-
nn.Parameter(torch.randn(dim, dim_ff_hidden)), # ff w1
|
297
|
-
nn.Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
|
298
|
-
])
|
299
|
-
|
300
|
-
for weight in self.weights:
|
301
|
-
nn.init.xavier_uniform_(weight)
|
302
|
-
|
303
|
-
def forward(self, x):
|
304
|
-
wq, wk, wv, ffw1, ffw2 = self.weights
|
305
|
-
|
306
|
-
q = F.normalize(x @ wq, dim = -1)
|
307
|
-
k = F.normalize(x @ wk, dim = -1)
|
308
|
-
v = x @ wv
|
309
|
-
|
310
|
-
attn_out = F.scaled_dot_product_attention(
|
311
|
-
q, k, v,
|
312
|
-
scale = self.scale,
|
313
|
-
is_causal = True
|
314
|
-
)
|
315
|
-
|
316
|
-
x = x + attn_out
|
317
|
-
|
318
|
-
h = F.silu(x @ ffw1)
|
319
|
-
out = h @ ffw2
|
320
|
-
|
321
|
-
return out
|
322
|
-
|
323
176
|
# associative scan wrapper
|
324
177
|
|
325
178
|
class AssocScan(Module):
|
@@ -870,6 +723,8 @@ class NeuralMemory(Module):
|
|
870
723
|
prev_layer_updates = TensorDict(prev_layer_updates)
|
871
724
|
prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])
|
872
725
|
|
726
|
+
values = None
|
727
|
+
|
873
728
|
if store_seq_cache_len == self.chunk_size:
|
874
729
|
|
875
730
|
next_updates, next_states, values = self.store_memories(
|
@@ -917,10 +772,13 @@ class NeuralMemory(Module):
|
|
917
772
|
if seq_len < self.retrieve_chunk_size:
|
918
773
|
out = self.init_empty_memory_embed(batch, seq_len)
|
919
774
|
|
920
|
-
next_store_state = (seq_len, seq, None, None)
|
775
|
+
next_store_state = NeuralMemCache(seq_len, seq, None, None)
|
921
776
|
|
922
777
|
out = (out, next_store_state)
|
923
778
|
|
779
|
+
if return_values:
|
780
|
+
out = (*out, self.zero)
|
781
|
+
|
924
782
|
if not return_aux_kv_loss:
|
925
783
|
return out
|
926
784
|
|
@@ -0,0 +1,9 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=KBwo-Fr_fDzVaAa7xg1ggEpNlE4vRUoGMEjB-I2ZWTU,26463
|
4
|
+
titans_pytorch/memory_models.py,sha256=LI9T36XB6YXIvvGWRw0ZMDlGpRC6KIv03OPzME2VAaU,3772
|
5
|
+
titans_pytorch/neural_memory.py,sha256=B6nTdsq7Tp6lfmUwQakOtODImwZzMCDkDfIv5CIKlbQ,25453
|
6
|
+
titans_pytorch-0.1.37.dist-info/METADATA,sha256=BdkDj71kq320M_vbZOGPrRBR91ycgqBJaUNm9m7tA5U,6826
|
7
|
+
titans_pytorch-0.1.37.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.1.37.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.1.37.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=rMT99CPQFH4Gudp0FmVPWGKfhBf6xksGEaFEcOVdqjs,230
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=KBwo-Fr_fDzVaAa7xg1ggEpNlE4vRUoGMEjB-I2ZWTU,26463
|
4
|
-
titans_pytorch/neural_memory.py,sha256=wFOLFe3ViXiQfNvUiAGJ6BfiaDhr0BYDRDnLNMHWQhU,28938
|
5
|
-
titans_pytorch-0.1.35.dist-info/METADATA,sha256=5e5qPt4hAOhxhDWqdjutjJuUmht44zYq_KqgagKjqxE,6826
|
6
|
-
titans_pytorch-0.1.35.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
-
titans_pytorch-0.1.35.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
-
titans_pytorch-0.1.35.dist-info/RECORD,,
|
File without changes
|
File without changes
|