titans-pytorch 0.1.36__py3-none-any.whl → 0.1.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/__init__.py +3 -0
- titans_pytorch/memory_models.py +153 -0
- titans_pytorch/neural_memory.py +13 -153
- {titans_pytorch-0.1.36.dist-info → titans_pytorch-0.1.38.dist-info}/METADATA +1 -1
- titans_pytorch-0.1.38.dist-info/RECORD +9 -0
- titans_pytorch-0.1.36.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.36.dist-info → titans_pytorch-0.1.38.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.36.dist-info → titans_pytorch-0.1.38.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/__init__.py
CHANGED
titans_pytorch/memory_models.py
ADDED
@@ -0,0 +1,153 @@
+import torch
+from torch import nn, cat
+import torch.nn.functional as F
+from torch.nn import Module, ModuleList, Parameter, ParameterList
+
+class MemoryMLP(Module):
+    def __init__(
+        self,
+        dim,
+        depth
+    ):
+        super().__init__()
+        self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])
+
+        for weight in self.weights:
+            nn.init.xavier_uniform_(weight)
+
+    def forward(
+        self,
+        x
+    ):
+        for ind, weight in enumerate(self.weights):
+            is_first = ind == 0
+
+            if not is_first:
+                x = F.silu(x)
+
+            x = x @ weight
+
+        return x
+
+# memory mlp, but with gated residual + final projection
+
+class GatedResidualMemoryMLP(Module):
+    def __init__(
+        self,
+        dim,
+        depth,
+        expansion_factor = 2.
+    ):
+        super().__init__()
+        dim_hidden = int(dim * expansion_factor)
+
+        self.weights = ParameterList([
+            ParameterList([
+                Parameter(torch.randn(dim, dim_hidden)),
+                Parameter(torch.randn(dim_hidden, dim)),
+                Parameter(torch.randn(dim * 2, dim)),
+            ]) for _ in range(depth)
+        ])
+
+        self.final_proj = Parameter(torch.randn(dim, dim))
+
+        for param in self.parameters():
+            nn.init.xavier_uniform_(param)
+
+    def forward(
+        self,
+        x
+    ):
+        for weight1, weight2, to_gates in self.weights:
+            res = x
+
+            hidden = x @ weight1
+            hidden = F.silu(hidden)
+            branch_out = hidden @ weight2
+
+            # gated residual
+
+            gates = cat((branch_out, res), dim = -1) @ to_gates
+            x = res.lerp(branch_out, gates.sigmoid())
+
+        return x @ self.final_proj
+
+# memory mlp with factorized weights
+# so can tradeoff capacity for smaller chunk sizes
+
+class FactorizedMemoryMLP(Module):
+    def __init__(
+        self,
+        dim,
+        depth,
+        k = 32
+    ):
+        super().__init__()
+        self.weights = ParameterList([
+            ParameterList([
+                Parameter(torch.randn(dim, k)),
+                Parameter(torch.randn(k, dim)),
+            ]) for _ in range(depth)
+        ])
+
+        for weight1, weight2 in self.weights:
+            nn.init.xavier_uniform_(weight1)
+            nn.init.xavier_uniform_(weight2)
+
+    def forward(
+        self,
+        x
+    ):
+        for ind, (weight1, weight2) in enumerate(self.weights):
+            is_first = ind == 0
+
+            if not is_first:
+                x = F.silu(x)
+
+            x = x @ weight1 @ weight2
+
+        return x
+
+# improvised attention as memory module
+
+class MemoryAttention(Module):
+    def __init__(
+        self,
+        dim,
+        scale = 8.,
+        expansion_factor = 2.
+    ):
+        super().__init__()
+        self.scale = scale
+        dim_ff_hidden = int(dim * expansion_factor)
+
+        self.weights = nn.ParameterList([
+            nn.Parameter(torch.randn(dim, dim)), # queries
+            nn.Parameter(torch.randn(dim, dim)), # keys
+            nn.Parameter(torch.randn(dim, dim)), # values
+            nn.Parameter(torch.randn(dim, dim_ff_hidden)), # ff w1
+            nn.Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
+        ])
+
+        for weight in self.weights:
+            nn.init.xavier_uniform_(weight)
+
+    def forward(self, x):
+        wq, wk, wv, ffw1, ffw2 = self.weights
+
+        q = F.normalize(x @ wq, dim = -1)
+        k = F.normalize(x @ wk, dim = -1)
+        v = x @ wv
+
+        attn_out = F.scaled_dot_product_attention(
+            q, k, v,
+            scale = self.scale,
+            is_causal = True
+        )
+
+        x = x + attn_out
+
+        h = F.silu(x @ ffw1)
+        out = h @ ffw2
+
+        return out
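The four classes above are standalone memory networks that each map an input of shape (..., dim) back to (..., dim). A minimal usage sketch, not part of the diff; the batch, sequence, and feature sizes are arbitrary illustration values:

```python
import torch

# the four memory modules introduced in titans_pytorch/memory_models.py (0.1.38)
from titans_pytorch.memory_models import (
    MemoryMLP,
    GatedResidualMemoryMLP,
    FactorizedMemoryMLP,
    MemoryAttention,
)

x = torch.randn(2, 128, 64)  # (batch, seq, dim) - arbitrary example sizes

modules = (
    MemoryMLP(dim = 64, depth = 2),
    GatedResidualMemoryMLP(dim = 64, depth = 2),
    FactorizedMemoryMLP(dim = 64, depth = 2, k = 16),
    MemoryAttention(dim = 64),
)

# every memory model preserves the trailing feature dimension
for module in modules:
    assert module(x).shape == x.shape
```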
titans_pytorch/neural_memory.py
CHANGED
@@ -19,6 +19,10 @@ from titans_pytorch.associative_scan import (
     pad_at_dim
 )
 
+from titans_pytorch.memory_models import(
+    MemoryMLP
+)
+
 import einx
 from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
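This hunk establishes the new import path for the relocated memory models. The 3 lines added to `titans_pytorch/__init__.py` (listed in the file summary but not shown in this diff) presumably re-export them at the package root; the sketch below relies only on the path confirmed by this hunk, with the root-level import left as a commented assumption:

```python
# the new module path introduced in 0.1.38 (confirmed by the hunk above)
from titans_pytorch.memory_models import MemoryMLP

# presumably also re-exported at the package root by the 3 lines added to
# __init__.py (not shown in this diff), i.e. something like:
# from titans_pytorch import MemoryMLP

memory_model = MemoryMLP(dim = 384, depth = 2)  # arbitrary illustration values
```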
@@ -169,157 +173,6 @@ class AttentionPool(Module):
 
         return reduce(x * attn, 'b n c d -> b n d', 'sum')
 
-# classes
-
-class MemoryMLP(Module):
-    def __init__(
-        self,
-        dim,
-        depth
-    ):
-        super().__init__()
-        self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])
-
-        for weight in self.weights:
-            nn.init.xavier_uniform_(weight)
-
-    def forward(
-        self,
-        x
-    ):
-        for ind, weight in enumerate(self.weights):
-            is_first = ind == 0
-
-            if not is_first:
-                x = F.silu(x)
-
-            x = x @ weight
-
-        return x
-
-# memory mlp, but with gated residual + final projection
-
-class GatedResidualMemoryMLP(Module):
-    def __init__(
-        self,
-        dim,
-        depth,
-        expansion_factor = 2.
-    ):
-        super().__init__()
-        dim_hidden = int(dim * expansion_factor)
-
-        self.weights = ParameterList([
-            ParameterList([
-                Parameter(torch.randn(dim, dim_hidden)),
-                Parameter(torch.randn(dim_hidden, dim)),
-                Parameter(torch.randn(dim * 2, dim)),
-            ]) for _ in range(depth)
-        ])
-
-        self.final_proj = Parameter(torch.randn(dim, dim))
-
-        for param in self.parameters():
-            nn.init.xavier_uniform_(param)
-
-    def forward(
-        self,
-        x
-    ):
-        for weight1, weight2, to_gates in self.weights:
-            res = x
-
-            hidden = x @ weight1
-            hidden = F.silu(hidden)
-            branch_out = hidden @ weight2
-
-            # gated residual
-
-            gates = cat((branch_out, res), dim = -1) @ to_gates
-            x = res.lerp(branch_out, gates.sigmoid())
-
-        return x @ self.final_proj
-
-# memory mlp with factorized weights
-# so can tradeoff capacity for smaller chunk sizes
-
-class FactorizedMemoryMLP(Module):
-    def __init__(
-        self,
-        dim,
-        depth,
-        k = 32
-    ):
-        super().__init__()
-        self.weights = ParameterList([
-            ParameterList([
-                Parameter(torch.randn(dim, k)),
-                Parameter(torch.randn(k, dim)),
-            ]) for _ in range(depth)
-        ])
-
-        for weight1, weight2 in self.weights:
-            nn.init.xavier_uniform_(weight1)
-            nn.init.xavier_uniform_(weight2)
-
-    def forward(
-        self,
-        x
-    ):
-        for ind, (weight1, weight2) in enumerate(self.weights):
-            is_first = ind == 0
-
-            if not is_first:
-                x = F.silu(x)
-
-            x = x @ weight1 @ weight2
-
-        return x
-
-# improvised attention as memory module
-
-class MemoryAttention(Module):
-    def __init__(
-        self,
-        dim,
-        scale = 8.,
-        expansion_factor = 2.
-    ):
-        super().__init__()
-        self.scale = scale
-        dim_ff_hidden = int(dim * expansion_factor)
-
-        self.weights = nn.ParameterList([
-            nn.Parameter(torch.randn(dim, dim)), # queries
-            nn.Parameter(torch.randn(dim, dim)), # keys
-            nn.Parameter(torch.randn(dim, dim)), # values
-            nn.Parameter(torch.randn(dim, dim_ff_hidden)), # ff w1
-            nn.Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
-        ])
-
-        for weight in self.weights:
-            nn.init.xavier_uniform_(weight)
-
-    def forward(self, x):
-        wq, wk, wv, ffw1, ffw2 = self.weights
-
-        q = F.normalize(x @ wq, dim = -1)
-        k = F.normalize(x @ wk, dim = -1)
-        v = x @ wv
-
-        attn_out = F.scaled_dot_product_attention(
-            q, k, v,
-            scale = self.scale,
-            is_causal = True
-        )
-
-        x = x + attn_out
-
-        h = F.silu(x @ ffw1)
-        out = h @ ffw2
-
-        return out
-
 # associative scan wrapper
 
 class AssocScan(Module):
@@ -853,7 +706,12 @@ class NeuralMemory(Module):
         if curr_seq_len < self.chunk_size:
             empty_mem = self.init_empty_memory_embed(batch, 1)
 
-            return empty_mem, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)
+            output = empty_mem, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)
+
+            if return_values:
+                output = (*output, self.zero)
+
+            return output
 
         # store if storage sequence cache hits the chunk size
 
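The early-return path now builds the same `(retrieved, NeuralMemCache(...))` pair as the main path and appends `self.zero` when `return_values` is requested. As a reading aid, a hypothetical sketch of the cache's shape, inferred only from the call site `NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)`; the field names are illustrative assumptions, not the library's actual definition:

```python
from typing import Any, NamedTuple

# Hypothetical stand-in for the NeuralMemCache used above. Only the arity and
# the rough meaning of each slot can be read off the diff; names are assumed.
class NeuralMemCacheSketch(NamedTuple):
    seq_len: int           # curr_seq_len at the early-return call site
    cache_store_seq: Any   # partially accumulated store sequence
    past_states: Any       # previous memory model states
    updates: Any           # accumulated memory updates
```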
@@ -870,6 +728,8 @@ class NeuralMemory(Module):
             prev_layer_updates = TensorDict(prev_layer_updates)
             prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])
 
+        values = None
+
         if store_seq_cache_len == self.chunk_size:
 
             next_updates, next_states, values = self.store_memories(
@@ -917,7 +777,7 @@ class NeuralMemory(Module):
         if seq_len < self.retrieve_chunk_size:
             out = self.init_empty_memory_embed(batch, seq_len)
 
-            next_store_state = (seq_len, seq, None, None)
+            next_store_state = NeuralMemCache(seq_len, seq, None, None)
 
             out = (out, next_store_state)
 
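Here the last remaining plain 4-tuple is replaced with the `NeuralMemCache` constructor used elsewhere, so every path returns one cache type. A small self-contained sketch of why the swap stays backward compatible for callers that unpack positionally (field names are again assumptions):

```python
from collections import namedtuple

# illustrative stand-in; field names are assumptions mirroring the call sites
NeuralMemCacheSketch = namedtuple(
    'NeuralMemCacheSketch',
    ['seq_len', 'cache_store_seq', 'past_states', 'updates']
)

state = NeuralMemCacheSketch(3, None, None, None)

# positional unpacking behaves exactly like the old plain (seq_len, seq, None, None) tuple
seq_len, cache_store_seq, past_states, updates = state

# but fields can now also be accessed by name
assert state.seq_len == seq_len == 3
```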
titans_pytorch-0.1.38.dist-info/RECORD
ADDED
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=KBwo-Fr_fDzVaAa7xg1ggEpNlE4vRUoGMEjB-I2ZWTU,26463
+titans_pytorch/memory_models.py,sha256=LI9T36XB6YXIvvGWRw0ZMDlGpRC6KIv03OPzME2VAaU,3772
+titans_pytorch/neural_memory.py,sha256=vmKPOAlXBPXBnYPODrg_reWaIcr1xwtfQmuptGS6e5A,25559
+titans_pytorch-0.1.38.dist-info/METADATA,sha256=8ZmlPJotNIMGAqW8nYWJiM06MvCXJ2SKTGVKarWeOAQ,6826
+titans_pytorch-0.1.38.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.38.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.38.dist-info/RECORD,,
titans_pytorch-0.1.36.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=rMT99CPQFH4Gudp0FmVPWGKfhBf6xksGEaFEcOVdqjs,230
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=KBwo-Fr_fDzVaAa7xg1ggEpNlE4vRUoGMEjB-I2ZWTU,26463
-titans_pytorch/neural_memory.py,sha256=U4qJvwN3otGPifQJfRBIFilduWYI22DAsTtdc3kOYFY,29009
-titans_pytorch-0.1.36.dist-info/METADATA,sha256=dx1D5t2njrP-1zG4JzLIrimS6rvaN1UzFISm0wrf-l8,6826
-titans_pytorch-0.1.36.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.36.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.36.dist-info/RECORD,,
{titans_pytorch-0.1.36.dist-info → titans_pytorch-0.1.38.dist-info}/WHEEL
File without changes
{titans_pytorch-0.1.36.dist-info → titans_pytorch-0.1.38.dist-info}/licenses/LICENSE
File without changes