titans-pytorch 0.0.27__py3-none-any.whl → 0.0.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +29 -2
- {titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.29.dist-info}/METADATA +3 -1
- {titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.29.dist-info}/RECORD +5 -5
- {titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.29.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.29.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py

@@ -1,5 +1,5 @@
 from __future__ import annotations
-import
+from math import ceil
 from functools import partial
 
 import torch
@@ -10,8 +10,14 @@ from torch.nn import Module, ModuleList, Linear
 from einops import repeat
 from einops.layers.torch import Rearrange
 
+
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
+# absolute and relative positions
+
+from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+from rotary_embedding_torch import RotaryEmbedding
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
@@ -25,7 +31,7 @@ def default(v, d):
     return v if exists(v) else d
 
 def round_up_multiple(seq, mult):
-    return
+    return ceil(seq / mult) * mult
 
 # feedforward and attention
 
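The new helper rounds a sequence length up to the nearest multiple of the segment length, using `ceil` from the new import above. A minimal standalone sketch of the same function, with illustrative values (the specific numbers are not from the package):

```python
from math import ceil

def round_up_multiple(seq, mult):
    # round a sequence length up to the nearest multiple of `mult`
    return ceil(seq / mult) * mult

assert round_up_multiple(1000, 128) == 1024   # 8 segments of 128 cover 1000 tokens
assert round_up_multiple(1024, 128) == 1024   # already a multiple, unchanged
```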
@@ -58,6 +64,8 @@ class SegmentedAttention(Module):
 
         dim_inner = dim_head * heads
 
+        self.rotary_emb = RotaryEmbedding(dim_head)
+
         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
         self.to_out = LinearNoBias(dim_inner, dim)
 
@@ -99,6 +107,12 @@ class SegmentedAttention(Module):
 
         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
 
+        # relative positions
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # persistent memory
+
         k = cat((pmk, k), dim = -2)
         v = cat((pmv, v), dim = -2)
 
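The rotary embedding is applied to queries and keys via `rotate_queries_with_cached_keys` before the persistent-memory keys and values are prepended, so the persistent tokens presumably receive no rotary offset of their own. A minimal standalone sketch of that call, assuming rotary-embedding-torch is installed; the tensor shapes (batch, heads, seq, dim_head) and sizes are assumptions for illustration:

```python
import torch
from rotary_embedding_torch import RotaryEmbedding

# hypothetical shapes for illustration: (batch, heads, seq, dim_head)
b, h, n, dim_head = 1, 8, 32, 64
q = torch.randn(b, h, n, dim_head)
k = torch.randn(b, h, n, dim_head)

rotary_emb = RotaryEmbedding(dim_head)

# rotate queries and keys by position, as the diff does
# before concatenating persistent memory onto k / v
q, k = rotary_emb.rotate_queries_with_cached_keys(q, k)

print(q.shape, k.shape)  # torch.Size([1, 8, 32, 64]) torch.Size([1, 8, 32, 64])
```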
@@ -133,6 +147,9 @@ class MemoryAsContextTransformer(Module):
     ):
         super().__init__()
 
+        self.segment_len = segment_len
+        self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
+
         self.token_emb = nn.Embedding(num_tokens, dim)
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
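The constructor now keeps `segment_len` and builds a two-axis continuous axial positional embedding, one axis for the window index and one for the position within a window. A minimal sketch that mirrors the instantiation and call seen in this diff; the output shape is only inferred from the broadcasted addition in `forward` below, and the sizes are illustrative:

```python
import torch
from axial_positional_embedding import ContinuousAxialPositionalEmbedding

dim = 384            # model dimension, arbitrary for illustration
segment_len = 128
windows = 8          # number of segments the sequence is split into

axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)

# flatten = True collapses the (windows, segment_len) grid into a single
# position axis; the shape below is inferred, not taken from the library docs
pos_emb = axial_pos_emb((windows, segment_len), flatten = True)
print(pos_emb.shape)  # expected: torch.Size([1024, 384])
```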
@@ -160,9 +177,19 @@ class MemoryAsContextTransformer(Module):
         self.to_logits = LinearNoBias(dim, num_tokens)
 
     def forward(self, x):
+        seq_len, segment_len = x.shape[-1], self.segment_len
+        windows = ceil(seq_len / segment_len)
 
         x = self.token_emb(x)
 
+        # apply axial positional embedding
+        # so intra and inter segment can be more easily discerned by the network
+
+        pos_emb = self.axial_pos_emb((windows, segment_len), flatten = True)
+        x = x + pos_emb[:seq_len]
+
+        # expand and reduce streams for hyper connections
+
         x = self.expand_streams(x)
 
         for attn, ff in self.layers:
{titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.29.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.27
+Version: 0.0.29
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,10 +35,12 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
+Requires-Dist: axial-positional-embedding>=0.3.5
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
+Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
 Provides-Extra: examples
{titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.29.dist-info}/RECORD

@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=
+titans_pytorch/mac_transformer.py,sha256=JBrJah7gfQDPizYRcBvpUKinrd2I9KMB997f3RIR8TA,5568
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
-titans_pytorch-0.0.
+titans_pytorch-0.0.29.dist-info/METADATA,sha256=EhS4E9SAoqzDa0PIjZpQmUSYAo5IS-XePofWlZZnIS0,3938
+titans_pytorch-0.0.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.29.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.29.dist-info/RECORD,,
File without changes
|
File without changes
|