titans_pytorch-0.0.36-py3-none-any.whl → titans_pytorch-0.0.37-py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- titans_pytorch/mac_transformer.py +10 -19
- titans_pytorch/titans.py +1 -1
- {titans_pytorch-0.0.36.dist-info → titans_pytorch-0.0.37.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.37.dist-info/RECORD +9 -0
- titans_pytorch-0.0.36.dist-info/RECORD +0 -9
- {titans_pytorch-0.0.36.dist-info → titans_pytorch-0.0.37.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.36.dist-info → titans_pytorch-0.0.37.dist-info}/licenses/LICENSE +0 -0
@@ -7,7 +7,7 @@ from torch import nn, cat
|
|
7
7
|
import torch.nn.functional as F
|
8
8
|
from torch.nn import Module, ModuleList, Linear
|
9
9
|
|
10
|
-
from einops import repeat, rearrange
|
10
|
+
from einops import repeat, rearrange, pack, unpack
|
11
11
|
from einops.layers.torch import Rearrange
|
12
12
|
|
13
13
|
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
@@ -214,7 +214,12 @@ class MemoryAsContextTransformer(Module):
|
|
214
214
|
if layer in neural_memory_layers:
|
215
215
|
assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'
|
216
216
|
|
217
|
-
mem = NeuralMemory(
|
217
|
+
mem = NeuralMemory(
|
218
|
+
dim = dim,
|
219
|
+
chunk_size = num_longterm_mem_tokens + segment_len,
|
220
|
+
**neural_memory_kwargs
|
221
|
+
)
|
222
|
+
|
218
223
|
mem = init_hyper_conn(dim = dim, branch = mem)
|
219
224
|
|
220
225
|
self.neural_mem_layers.append(mem)
|
@@ -266,7 +271,7 @@ class MemoryAsContextTransformer(Module):
|
|
266
271
|
x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
|
267
272
|
|
268
273
|
mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
|
269
|
-
x =
|
274
|
+
x, mem_ps = pack((x, mems), 'b * d')
|
270
275
|
|
271
276
|
x = inverse_segment(x)
|
272
277
|
|
@@ -283,21 +288,7 @@ class MemoryAsContextTransformer(Module):
|
|
283
288
|
for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
|
284
289
|
|
285
290
|
if exists(maybe_neural_mem):
|
286
|
-
|
287
|
-
|
288
|
-
x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
|
289
|
-
|
290
|
-
longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
|
291
|
-
|
292
|
-
longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
|
293
|
-
|
294
|
-
longterm_mems = maybe_neural_mem(longterm_mems)
|
295
|
-
|
296
|
-
longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
|
297
|
-
|
298
|
-
x = cat((longterm_mems, x), dim = -2)
|
299
|
-
|
300
|
-
x = inverse_segment(x)
|
291
|
+
mems = maybe_neural_mem(mems)
|
301
292
|
|
302
293
|
x = attn(x)
|
303
294
|
|
@@ -309,7 +300,7 @@ class MemoryAsContextTransformer(Module):
|
|
309
300
|
|
310
301
|
x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
|
311
302
|
|
312
|
-
x = x
|
303
|
+
x, mem = unpack(x, mem_ps, 'b * d')
|
313
304
|
|
314
305
|
x = inverse_segment(x)
|
315
306
|
|
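titans_pytorch/titans.py
CHANGED
(+1 -1; the hunk for this file is not included in this extract)
{titans_pytorch-0.0.36.dist-info → titans_pytorch-0.0.37.dist-info}/METADATA
CHANGED
(+1 -1; the hunk for this file is not included in this extract)
titans_pytorch-0.0.37.dist-info/RECORD
ADDED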
@@ -0,0 +1,9 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=JjKGEMBit_SvyAsxq5v08614YBcLVx3OkM6pf0rADsA,8400
|
4
|
+
titans_pytorch/titans.py,sha256=ALICGfc6p2bD2QkaibyIceVLvBIRKXmDm-w7RjnVOe4,14304
|
5
|
+
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
|
6
|
+
titans_pytorch-0.0.37.dist-info/METADATA,sha256=RNokG8101_tlR0BiF-AxqYLZpXqafMSiN1Rg_pZe2-o,3938
|
7
|
+
titans_pytorch-0.0.37.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.0.37.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.0.37.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=xXQ9GvtvNArYidV1OOhUhCJ0pIxniElTLnL0_eIZtEE,8821
|
4
|
-
titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
|
5
|
-
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
|
6
|
-
titans_pytorch-0.0.36.dist-info/METADATA,sha256=7Bum0wO6e6BsB7TShLBZWALyStcs1LLaEv5vvnVlQ9c,3938
|
7
|
-
titans_pytorch-0.0.36.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
-
titans_pytorch-0.0.36.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
-
titans_pytorch-0.0.36.dist-info/RECORD,,
|
File without changes
|
File without changes
|