titans-pytorch 0.0.37__py3-none-any.whl → 0.0.39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +10 -5
- titans_pytorch/titans.py +8 -5
- {titans_pytorch-0.0.37.dist-info → titans_pytorch-0.0.39.dist-info}/METADATA +2 -1
- titans_pytorch-0.0.39.dist-info/RECORD +9 -0
- titans_pytorch-0.0.37.dist-info/RECORD +0 -9
- {titans_pytorch-0.0.37.dist-info → titans_pytorch-0.0.39.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.37.dist-info → titans_pytorch-0.0.39.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py CHANGED
@@ -7,7 +7,7 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
-from einops import repeat, rearrange, pack, unpack
+from einops import einsum, repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
 from hyper_connections import get_init_and_expand_reduce_stream_functions
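The only change in this hunk is pulling `einsum` into the einops import. As a reminder of what that utility does, here is a minimal, self-contained sketch of `einops.einsum` with illustrative shapes (none of these tensors come from the package):

```python
import torch
from einops import einsum

# illustrative shapes only: batch, heads, sequence, head dimension
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)

# contract over the head dimension `d` to get pairwise attention scores
scores = einsum(q, k, 'b h i d, b h j d -> b h i j')
print(scores.shape)  # torch.Size([2, 8, 16, 16])
```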
@@ -16,6 +16,7 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
+from x_transformers.attend import Attend
 
 # proposed neural memory
 
@@ -93,6 +94,7 @@ class SegmentedAttention(Module):
         num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
+        attend_kwargs: dict = dict()
     ):
         super().__init__()
         self.norm = nn.RMSNorm(dim)
@@ -101,6 +103,8 @@ class SegmentedAttention(Module):
 
         self.rotary_emb = RotaryEmbedding(dim_head)
 
+        self.attend = Attend(causal = True, **attend_kwargs)
+
         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
         self.to_out = LinearNoBias(dim_inner, dim)
 
@@ -145,9 +149,9 @@ class SegmentedAttention(Module):
         k = cat((pmk, k), dim = -2)
         v = cat((pmv, v), dim = -2)
 
-        #
+        # attention
 
-        out =
+        out, _ = self.attend(q, k, v)
 
         out = self.merge_heads(out)
 
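Taken together, the three hunks above route attention through x_transformers' `Attend` module: an `attend_kwargs` dict is accepted in the constructor, an `Attend(causal = True, **attend_kwargs)` instance is built once, and the forward pass calls it in place of the previous inlined attention (whose removed lines are truncated in this rendering). A minimal standalone sketch of that call pattern; the `flash = True` kwarg is only an example of what `attend_kwargs` might forward, not something this diff specifies:

```python
import torch
from x_transformers.attend import Attend

# Attend returns (output, intermediates), which is why the new code unpacks `out, _`
attend = Attend(causal = True, flash = True)

# assumed layout: (batch, heads, sequence, dim_head)
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

out, _ = attend(q, k, v)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```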
@@ -288,7 +292,8 @@ class MemoryAsContextTransformer(Module):
         for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
 
             if exists(maybe_neural_mem):
-
+                x = maybe_neural_mem(x)
+
             x = attn(x)
 
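The hunk above makes the per-layer loop actually invoke the neural memory module before attention (the body of the removed line is truncated in this rendering). A schematic of the resulting control flow, using stand-in modules and omitting the residual / hyper-connection wiring the real class has; every name here is a placeholder rather than the package's API:

```python
def exists(v):
    return v is not None

def run_layers(x, layers, neural_mem_layers):
    # `layers` pairs (attention, feedforward) blocks; `neural_mem_layers` holds an
    # optional neural memory module per layer, or None where the layer has none
    for (attn, ff), maybe_neural_mem in zip(layers, neural_mem_layers):

        if exists(maybe_neural_mem):
            x = maybe_neural_mem(x)  # consult / update long-term memory first

        x = attn(x)                  # then segmented attention
        x = ff(x)                    # then the feedforward block
    return x
```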
@@ -300,7 +305,7 @@ class MemoryAsContextTransformer(Module):
 
         x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
-        x,
+        x, _ = unpack(x, mem_ps, 'b * d')
 
         x = inverse_segment(x)
 
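The completed `unpack` call above splits the packed sequence back into its original pieces using the packing shapes `mem_ps` recorded earlier (the matching `pack` call lies outside this diff, and the removed line is truncated in this rendering). A minimal einops `pack` / `unpack` round trip with assumed shapes:

```python
import torch
from einops import pack, unpack

a = torch.randn(2, 4, 512)    # assumed: one group of tokens, e.g. memory tokens
b = torch.randn(2, 128, 512)  # assumed: the main token sequence

# concatenate along the wildcard axis, remembering how to split again
packed, ps = pack((a, b), 'b * d')
print(packed.shape)           # torch.Size([2, 132, 512])

# unpack inverts the pack, returning the chunks in their original order;
# `x, _ = unpack(x, mem_ps, 'b * d')` keeps only the first chunk and drops the rest
a_out, b_out = unpack(packed, ps, 'b * d')
assert torch.equal(a, a_out) and torch.equal(b, b_out)
```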
titans_pytorch/titans.py CHANGED
@@ -27,9 +27,7 @@ n - sequence
 d - feature dimension
 c - intra-chunk
 """
-
-# constants
-
+
 LinearNoBias = partial(Linear, bias = False)
 
 # functions
@@ -390,7 +388,10 @@ class NeuralMemory(Module):
 
         padding = next_seq_len - curtailed_seq_len
 
-
+        needs_pad = padding > 0
+
+        if needs_pad:
+            seq = pad_at_dim(seq, (0, padding), dim = 1)
 
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
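The hunk above introduces a `needs_pad` flag so the sequence is only right-padded up to the next chunk boundary when there is actually padding to add (the removed unconditional line is truncated in this rendering). A minimal sketch of the same pattern under assumed shapes; this `pad_at_dim` helper and the length arithmetic are illustrative stand-ins, not copied from the package:

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # pad a single dimension; F.pad expects pad amounts ordered from the last dim backwards
    dims_from_right = (t.ndim - dim - 1) if dim >= 0 else (-dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

seq = torch.randn(2, 10, 64)   # assumed layout: (batch, seq, dim)
chunk_size = 4

next_seq_len = ((seq.shape[1] + chunk_size - 1) // chunk_size) * chunk_size
padding = next_seq_len - seq.shape[1]

needs_pad = padding > 0

if needs_pad:
    seq = pad_at_dim(seq, (0, padding), dim = 1)  # right-pad along the sequence dim

print(seq.shape)  # torch.Size([2, 12, 64])
```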
@@ -442,7 +443,9 @@ class NeuralMemory(Module):
         empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
         values = torch.cat((empty_memory_embeds, values), dim = -2)
 
-
+        if needs_pad:
+            values = values[:, :-padding]
+
         return values
 
     def forward(
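The matching hunk above trims the padded positions back off the retrieved values, and the `needs_pad` guard is doing real work: when `padding == 0`, a naive `values[:, :-padding]` slices with `:-0`, i.e. `slice(None, 0)`, and returns an empty tensor instead of the full one. A small demonstration:

```python
import torch

values = torch.randn(2, 12, 64)
padding = 0

# naive trim: `:-0` selects everything up to index 0, i.e. nothing at all
print(values[:, :-padding].shape)  # torch.Size([2, 0, 64])

# guarded trim, as in the diff
needs_pad = padding > 0
trimmed = values[:, :-padding] if needs_pad else values
print(trimmed.shape)               # torch.Size([2, 12, 64])
```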
{titans_pytorch-0.0.37.dist-info → titans_pytorch-0.0.39.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.37
+Version: 0.0.39
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -43,6 +43,7 @@ Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
+Requires-Dist: x-transformers
 Provides-Extra: examples
 Requires-Dist: local-attention>=1.10.1; extra == 'examples'
 Requires-Dist: taylor-series-linear-attention; extra == 'examples'
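The metadata change above bumps the version and adds `x-transformers` as a hard dependency, matching the new import in mac_transformer.py. A quick sanity check one might run after upgrading; it assumes an environment where titans-pytorch 0.0.39 has just been installed:

```python
import importlib.metadata as md

# both distributions should now resolve in a fresh install of 0.0.39
print(md.version('titans-pytorch'))  # expected: 0.0.39
print(md.version('x-transformers'))  # pulled in via the new Requires-Dist entry

# and the module that mac_transformer.py now imports should be importable
from x_transformers.attend import Attend
```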
titans_pytorch-0.0.39.dist-info/RECORD ADDED
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=h58sHfufxMnSXZXyWuW-KBwzq8xwBYmFjU2XtOjUixk,8512
+titans_pytorch/titans.py,sha256=bv2Ceq-_4nNb5FNx4hLd2inC93m5MmJxO2-Mbf6PKK4,14378
+titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
+titans_pytorch-0.0.39.dist-info/METADATA,sha256=3KD2hmJ-uOyQ87Z3VB6JfaKtDcLBnoKA8037DpzJuPE,3968
+titans_pytorch-0.0.39.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.39.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.39.dist-info/RECORD,,
titans_pytorch-0.0.37.dist-info/RECORD DELETED
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=JjKGEMBit_SvyAsxq5v08614YBcLVx3OkM6pf0rADsA,8400
-titans_pytorch/titans.py,sha256=ALICGfc6p2bD2QkaibyIceVLvBIRKXmDm-w7RjnVOe4,14304
-titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.37.dist-info/METADATA,sha256=RNokG8101_tlR0BiF-AxqYLZpXqafMSiN1Rg_pZe2-o,3938
-titans_pytorch-0.0.37.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.37.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.37.dist-info/RECORD,,
{titans_pytorch-0.0.37.dist-info → titans_pytorch-0.0.39.dist-info}/WHEEL: File without changes
{titans_pytorch-0.0.37.dist-info → titans_pytorch-0.0.39.dist-info}/licenses/LICENSE: File without changes