titans-pytorch 0.0.31__py3-none-any.whl → 0.0.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +45 -5
- {titans_pytorch-0.0.31.dist-info → titans_pytorch-0.0.32.dist-info}/METADATA +1 -1
- {titans_pytorch-0.0.31.dist-info → titans_pytorch-0.0.32.dist-info}/RECORD +5 -5
- {titans_pytorch-0.0.31.dist-info → titans_pytorch-0.0.32.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.31.dist-info → titans_pytorch-0.0.32.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py

```diff
@@ -17,6 +17,10 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
 
+# proposed neural memory
+
+from titans_pytorch.titans import NeuralMemory
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
```
```diff
@@ -161,7 +165,9 @@ class MemoryAsContextTransformer(Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        num_residual_streams = 4
+        num_residual_streams = 4,
+        neural_memory_kwargs: dict = dict(),
+        neural_memory_layers: tuple[int, ...] | None = None,
     ):
         super().__init__()
 
```
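The constructor gains two options: `neural_memory_layers` selects which layers (1-indexed, `None` meaning all) receive a neural memory module, and `neural_memory_kwargs` is added to the signature, though the hunks shown here do not yet forward it to `NeuralMemory`. A minimal usage sketch follows; every constructor argument other than those visible in this diff is an assumption about the surrounding file:

```python
# hypothetical usage of the new 0.0.32 kwargs; num_tokens, dim, depth,
# segment_len and num_longterm_mem_tokens are assumed from the rest of the file
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

model = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 4,
    segment_len = 128,
    num_longterm_mem_tokens = 4,    # must be > 0 for any memory layer to be created
    neural_memory_layers = (2, 4),  # 1-indexed; None (the default) means every layer
)
```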
```diff
@@ -181,8 +187,25 @@ class MemoryAsContextTransformer(Module):
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
+        self.neural_mem_layers = ModuleList([])
+
+        layers = tuple(range(1, depth + 1))
+        neural_memory_layers = set(default(neural_memory_layers, layers))
+
+        for layer in layers:
+
+            # neural memory
+
+            mem = None
+
+            if num_longterm_mem_tokens > 0 and layer in neural_memory_layers:
+                mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
+                mem = init_hyper_conn(dim = dim, branch = mem)
+
+            self.neural_mem_layers.append(mem)
+
+            # attention and feedforward
 
-        for _ in range(depth):
             attn = SegmentedAttention(
                 dim = dim,
                 dim_head = dim_head,
```
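Layer construction now tracks a parallel `ModuleList` of optional memory modules: layers are numbered from 1, `neural_memory_layers` defaults to all of them, and a layer only gets a `NeuralMemory` (wrapped as a hyper-connection branch) when long-term memory tokens are enabled. A standalone sketch of just the gating logic, with `default` written out the way such helpers are conventionally defined (an assumption; the real helper lives elsewhere in the file):

```python
# standalone sketch of the layer-selection gating from the hunk above
def default(v, d):
    # assumed definition: fall back to d when v is None
    return v if v is not None else d

depth = 4
num_longterm_mem_tokens = 4
neural_memory_layers = None                # None -> memory on every layer

layers = tuple(range(1, depth + 1))        # 1-indexed: (1, 2, 3, 4)
selected = set(default(neural_memory_layers, layers))

for layer in layers:
    has_mem = num_longterm_mem_tokens > 0 and layer in selected
    print(layer, has_mem)                  # True for all four layers here
```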
```diff
@@ -221,7 +244,7 @@ class MemoryAsContextTransformer(Module):
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x =
+        x = cat((mems, x), dim = -2)
 
         x = inverse_segment(x)
 
```
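On the new side, the learned long-term memory tokens are broadcast over the batch of segments and prepended along the sequence axis. A shape-level sketch using torch and einops, with all sizes hypothetical:

```python
# shape-level sketch of the line added above; all sizes are hypothetical
import torch
from einops import repeat

num_longterm_mem_tokens, dim = 4, 512
longterm_mems = torch.randn(num_longterm_mem_tokens, dim)  # stands in for self.longterm_mems

x = torch.randn(6, 128, dim)  # (batch * windows, segment_len, dim) after segmentation

mems = repeat(longterm_mems, 'n d -> b n d', b = x.shape[0])
x = torch.cat((mems, x), dim = -2)  # memory tokens now lead every segment
print(x.shape)  # torch.Size([6, 132, 512])
```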
```diff
@@ -235,7 +258,24 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for attn, ff in self.layers:
+        for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
+
+            if exists(maybe_neural_mem):
+                batch_streams = x.shape[0]
+                x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
+
+                longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
+
+                longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
+
+                longterm_mems = maybe_neural_mem(longterm_mems)
+
+                longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
+
+                x = cat((longterm_mems, x), dim = -2)
+
+                x = inverse_segment(x)
+
             x = attn(x)
             x = ff(x)
 
```
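The two `rearrange` calls are the key move in this hunk: memory tokens from all windows of one sequence are gathered into a single `(w n)` stream so the neural memory sees them in order, then scattered back to their windows after the update. A shape-only illustration, with hypothetical sizes:

```python
# shape-only illustration of the rearrange round trip; sizes are hypothetical
import torch
from einops import rearrange

b, w, n, d = 2, 3, 4, 8                    # batch, windows, mem tokens per window, dim
longterm_mems = torch.randn(b * w, n, d)   # as sliced out of the segmented x

flat = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = b)
print(flat.shape)                          # torch.Size([2, 12, 8]): one stream per sequence

# the neural memory module would transform `flat` here

back = rearrange(flat, 'b (w n) d -> (b w) n d', n = n)
print(torch.equal(back, longterm_mems))    # True: the round trip is lossless
```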
```diff
@@ -245,7 +285,7 @@ class MemoryAsContextTransformer(Module):
 
         x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
-        x = x[:,
+        x = x[:, num_longterm_mem_tokens:]
 
         x = inverse_segment(x)
 
```
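This completes another truncated line and mirrors the earlier prepend: the leading `num_longterm_mem_tokens` positions of each segment are sliced off, undoing the memory-token insertion before `inverse_segment` stitches the windows back together.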
{titans_pytorch-0.0.31.dist-info → titans_pytorch-0.0.32.dist-info}/RECORD

```diff
@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=
+titans_pytorch/mac_transformer.py,sha256=ohZWVhMBtpm0Iz3w5g7pD3WXSXrvhwzZvfRplwhe1Qo,8149
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
+titans_pytorch-0.0.32.dist-info/METADATA,sha256=9X9nWfgIVS-9XIeLHQY53HXSMA6rMemPfyVC2bRrJOQ,3938
+titans_pytorch-0.0.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.32.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.32.dist-info/RECORD,,
```
{titans_pytorch-0.0.31.dist-info → titans_pytorch-0.0.32.dist-info}/WHEEL: file without changes

{titans_pytorch-0.0.31.dist-info → titans_pytorch-0.0.32.dist-info}/licenses/LICENSE: file without changes