titans-pytorch 0.0.31__py3-none-any.whl → 0.0.32__py3-none-any.whl
- titans_pytorch/mac_transformer.py +45 -5
- {titans_pytorch-0.0.31.dist-info → titans_pytorch-0.0.32.dist-info}/METADATA +1 -1
- {titans_pytorch-0.0.31.dist-info → titans_pytorch-0.0.32.dist-info}/RECORD +5 -5
- {titans_pytorch-0.0.31.dist-info → titans_pytorch-0.0.32.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.31.dist-info → titans_pytorch-0.0.32.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py
@@ -17,6 +17,10 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
 
+# proposed neural memory
+
+from titans_pytorch.titans import NeuralMemory
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
@@ -161,7 +165,9 @@ class MemoryAsContextTransformer(Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        num_residual_streams = 4
+        num_residual_streams = 4,
+        neural_memory_kwargs: dict = dict(),
+        neural_memory_layers: tuple[int, ...] | None = None,
     ):
         super().__init__()
 
@@ -181,8 +187,25 @@ class MemoryAsContextTransformer(Module):
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
+        self.neural_mem_layers = ModuleList([])
+
+        layers = tuple(range(1, depth + 1))
+        neural_memory_layers = set(default(neural_memory_layers, layers))
+
+        for layer in layers:
+
+            # neural memory
+
+            mem = None
+
+            if num_longterm_mem_tokens > 0 and layer in neural_memory_layers:
+                mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
+                mem = init_hyper_conn(dim = dim, branch = mem)
+
+            self.neural_mem_layers.append(mem)
+
+            # attention and feedforward
 
-        for _ in range(depth):
             attn = SegmentedAttention(
                 dim = dim,
                 dim_head = dim_head,
@@ -221,7 +244,7 @@ class MemoryAsContextTransformer(Module):
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x =
+        x = cat((mems, x), dim = -2)
 
         x = inverse_segment(x)
 
@@ -235,7 +258,24 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for attn, ff in self.layers:
+        for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
+
+            if exists(maybe_neural_mem):
+                batch_streams = x.shape[0]
+                x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
+
+                longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
+
+                longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
+
+                longterm_mems = maybe_neural_mem(longterm_mems)
+
+                longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
+
+                x = cat((longterm_mems, x), dim = -2)
+
+                x = inverse_segment(x)
+
             x = attn(x)
             x = ff(x)
 
@@ -245,7 +285,7 @@ class MemoryAsContextTransformer(Module):
 
         x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
-        x = x[:,
+        x = x[:, num_longterm_mem_tokens:]
 
         x = inverse_segment(x)
 
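In short, 0.0.32 threads an optional NeuralMemory branch through MemoryAsContextTransformer: the constructor gains `neural_memory_layers` and `neural_memory_kwargs`, and each layer listed in `neural_memory_layers` passes the long-term memory tokens through its NeuralMemory before attention. The sketch below illustrates the new arguments; it is only a sketch, and every concrete value plus any parameter not visible in the hunks above (for example `num_tokens`) is an assumption, not taken from the release.

```python
# Minimal usage sketch of the 0.0.32 additions. Only the class name and the
# keyword arguments shown in the diff above are confirmed; all values and the
# remaining parameter names are assumptions for illustration.
import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,                # assumed vocabulary-size parameter, not shown in this diff
    dim = 256,                       # assumed model dimension
    depth = 4,                       # assumed number of layers
    segment_len = 128,               # referenced in the forward hunks; value assumed
    num_longterm_mem_tokens = 4,     # must be > 0 for a layer to receive a NeuralMemory branch
    neural_memory_layers = (3, 4),   # new in 0.0.32: 1-indexed layers that get neural memory
    neural_memory_kwargs = dict(),   # new in 0.0.32: defaults to an empty dict
)

ids = torch.randint(0, 256, (1, 1023))  # assumed token-id input
out = transformer(ids)                   # assumed forward call
```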
{titans_pytorch-0.0.31.dist-info → titans_pytorch-0.0.32.dist-info}/RECORD
@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=
+titans_pytorch/mac_transformer.py,sha256=ohZWVhMBtpm0Iz3w5g7pD3WXSXrvhwzZvfRplwhe1Qo,8149
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
+titans_pytorch-0.0.32.dist-info/METADATA,sha256=9X9nWfgIVS-9XIeLHQY53HXSMA6rMemPfyVC2bRrJOQ,3938
+titans_pytorch-0.0.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.32.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.32.dist-info/RECORD,,
{titans_pytorch-0.0.31.dist-info → titans_pytorch-0.0.32.dist-info}/WHEEL: file without changes
{titans_pytorch-0.0.31.dist-info → titans_pytorch-0.0.32.dist-info}/licenses/LICENSE: file without changes