titans-pytorch 0.0.27__py3-none-any.whl → 0.0.29__py3-none-any.whl
- titans_pytorch/mac_transformer.py +29 -2
- {titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.29.dist-info}/METADATA +3 -1
- {titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.29.dist-info}/RECORD +5 -5
- {titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.29.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.29.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py

@@ -1,5 +1,5 @@
 from __future__ import annotations
-import
+from math import ceil
 from functools import partial

 import torch
@@ -10,8 +10,14 @@ from torch.nn import Module, ModuleList, Linear
 from einops import repeat
 from einops.layers.torch import Rearrange

+
 from hyper_connections import get_init_and_expand_reduce_stream_functions

+# absolute and relative positions
+
+from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+from rotary_embedding_torch import RotaryEmbedding
+
 # constants

 LinearNoBias = partial(Linear, bias = False)
@@ -25,7 +31,7 @@ def default(v, d):
     return v if exists(v) else d

 def round_up_multiple(seq, mult):
-    return
+    return ceil(seq / mult) * mult

 # feedforward and attention

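The helper above now rounds a sequence length up to the nearest multiple of the segment size. A minimal standalone check of the arithmetic (the values below are illustrative, not from the package):

```python
from math import ceil

def round_up_multiple(seq, mult):
    # smallest multiple of `mult` that is >= `seq`
    return ceil(seq / mult) * mult

assert round_up_multiple(1000, 512) == 1024   # a partial final window rounds up
assert round_up_multiple(1024, 512) == 1024   # an exact multiple is unchanged
```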
@@ -58,6 +64,8 @@ class SegmentedAttention(Module):

         dim_inner = dim_head * heads

+        self.rotary_emb = RotaryEmbedding(dim_head)
+
         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
         self.to_out = LinearNoBias(dim_inner, dim)

@@ -99,6 +107,12 @@ class SegmentedAttention(Module):

         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)

+        # relative positions
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # persistent memory
+
         k = cat((pmk, k), dim = -2)
         v = cat((pmv, v), dim = -2)

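For orientation, the hunk above rotates queries and keys (relative positions) before the position-free persistent memory tokens are prepended to the keys and values. A minimal sketch of that ordering, assuming made-up shapes and a persistent-memory size of 16; in the module itself q, k, v come from `to_qkv` and the head dimension matches `dim_head`:

```python
import torch
from rotary_embedding_torch import RotaryEmbedding

b, h, n, dim_head = 2, 8, 256, 64          # assumed shapes for illustration
q = torch.randn(b, h, n, dim_head)
k = torch.randn(b, h, n, dim_head)

rotary_emb = RotaryEmbedding(dim_head)

# relative positions: rotate queries and keys before anything position-free is added
q, k = rotary_emb.rotate_queries_with_cached_keys(q, k)

# persistent memory: prepend learned, position-free key tokens, as in the diff
pmk = torch.randn(b, h, 16, dim_head)      # stand-in for the learned persistent keys
k = torch.cat((pmk, k), dim = -2)
```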
@@ -133,6 +147,9 @@ class MemoryAsContextTransformer(Module):
     ):
         super().__init__()

+        self.segment_len = segment_len
+        self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
+
         self.token_emb = nn.Embedding(num_tokens, dim)

         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
@@ -160,9 +177,19 @@ class MemoryAsContextTransformer(Module):
         self.to_logits = LinearNoBias(dim, num_tokens)

     def forward(self, x):
+        seq_len, segment_len = x.shape[-1], self.segment_len
+        windows = ceil(seq_len / segment_len)

         x = self.token_emb(x)

+        # apply axial positional embedding
+        # so intra and inter segment can be more easily discerned by the network
+
+        pos_emb = self.axial_pos_emb((windows, segment_len), flatten = True)
+        x = x + pos_emb[:seq_len]
+
+        # expand and reduce streams for hyper connections
+
         x = self.expand_streams(x)

         for attn, ff in self.layers:
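The forward-pass changes above amount to shape bookkeeping: the sequence is treated as a grid of windows × segment_len positions, the axial embedding is generated for that full grid, flattened, and truncated back to the actual sequence length. A minimal sketch with assumed hyperparameters (`dim = 512`, `segment_len = 128`, `seq_len = 300`), following the calls shown in the diff:

```python
from math import ceil
import torch
from axial_positional_embedding import ContinuousAxialPositionalEmbedding

dim, segment_len = 512, 128                    # assumed values for illustration
axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)

seq_len = 300                                  # deliberately not a multiple of segment_len
windows = ceil(seq_len / segment_len)          # 3 windows cover the sequence

x = torch.randn(2, seq_len, dim)               # stand-in for the token embeddings

# one axis indexes the window, the other the position within a window,
# so intra- and inter-segment positions stay distinguishable after flattening
pos_emb = axial_pos_emb((windows, segment_len), flatten = True)
x = x + pos_emb[:seq_len]                      # drop the padded tail of the last window
```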
{titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.29.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.27
+Version: 0.0.29
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,10 +35,12 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
+Requires-Dist: axial-positional-embedding>=0.3.5
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
+Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
 Provides-Extra: examples
{titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.29.dist-info}/RECORD

@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=
+titans_pytorch/mac_transformer.py,sha256=JBrJah7gfQDPizYRcBvpUKinrd2I9KMB997f3RIR8TA,5568
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
+titans_pytorch-0.0.29.dist-info/METADATA,sha256=EhS4E9SAoqzDa0PIjZpQmUSYAo5IS-XePofWlZZnIS0,3938
+titans_pytorch-0.0.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.29.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.29.dist-info/RECORD,,
{titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.29.dist-info}/WHEEL: file without changes
{titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.29.dist-info}/licenses/LICENSE: file without changes