titans-pytorch 0.0.27__py3-none-any.whl → 0.0.30__py3-none-any.whl
Potentially problematic release.
This version of titans-pytorch might be problematic.
- titans_pytorch/mac_transformer.py +81 -24
- {titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.30.dist-info}/METADATA +3 -1
- {titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.30.dist-info}/RECORD +5 -5
- {titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.30.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.30.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py

@@ -1,5 +1,5 @@
 from __future__ import annotations
-import
+from math import ceil
 from functools import partial
 
 import torch
@@ -7,11 +7,16 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
-from einops import repeat
+from einops import repeat, rearrange
 from einops.layers.torch import Rearrange
 
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
+# absolute and relative positions
+
+from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+from rotary_embedding_torch import RotaryEmbedding
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
@@ -25,7 +30,7 @@ def default(v, d):
     return v if exists(v) else d
 
 def round_up_multiple(seq, mult):
-    return
+    return ceil(seq / mult) * mult
 
 # feedforward and attention
 
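The new `round_up_multiple` drives the auto-padding in `SegmentedAttention.forward` below: sequences are padded up to the next multiple of the (total) segment length before being windowed. A quick sketch of the arithmetic, using the 1023-token / segment_len = 128 sizes from the demo block this release removes (see the end of this file's diff):

    from math import ceil

    def round_up_multiple(seq, mult):
        return ceil(seq / mult) * mult

    # a 1023-token input with segment_len = 128 is padded up to 1024, i.e. 8 full windows
    assert round_up_multiple(1023, 128) == 1024

    # already a multiple -> no padding needed
    assert round_up_multiple(1024, 128) == 1024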
@@ -49,7 +54,8 @@ class SegmentedAttention(Module):
         self,
         dim,
         segment_len,
-        num_persist_mem_tokens,
+        num_persist_mem_tokens = 0,
+        num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
     ):
@@ -58,29 +64,37 @@ class SegmentedAttention(Module):
 
         dim_inner = dim_head * heads
 
+        self.rotary_emb = RotaryEmbedding(dim_head)
+
         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
         self.to_out = LinearNoBias(dim_inner, dim)
 
         self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        total_segment_len = segment_len + num_longterm_mem_tokens
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n =
-        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n =
+        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = total_segment_len)
+        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = total_segment_len)
 
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
+        segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
         batch, seq_len = seq.shape[:2]
 
         # auto pad to multiple
         # todo - get rid of logic with flex attention
 
-        need_segment = seq_len >=
+        need_segment = seq_len >= total_segment_len
 
         if need_segment:
-            next_seq_len = round_up_multiple(seq_len,
+            next_seq_len = round_up_multiple(seq_len, total_segment_len)
            padding = next_seq_len - seq_len
 
             if padding > 0:
@@ -99,6 +113,12 @@ class SegmentedAttention(Module):
 
         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
 
+        # relative positions
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # persistent memory
+
         k = cat((pmk, k), dim = -2)
         v = cat((pmv, v), dim = -2)
 
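The ordering above matters: rotary positions are applied to the per-window queries and keys first, and only then are the persistent memory keys/values prepended, so the persistent tokens carry no positional encoding. A minimal shape sketch, assuming rotary-embedding-torch as declared in the new imports and illustrative sizes (2 heads, head dim 64, 4 persistent tokens, window of 128):

    import torch
    from rotary_embedding_torch import RotaryEmbedding

    heads, dim_head, num_persist, window = 2, 64, 4, 128

    rotary_emb = RotaryEmbedding(dim_head)

    q = torch.randn(1, heads, window, dim_head)
    k = torch.randn(1, heads, window, dim_head)
    v = torch.randn(1, heads, window, dim_head)

    # relative positions on the sequence tokens only
    q, k = rotary_emb.rotate_queries_with_cached_keys(q, k)

    # persistent memory keys / values are concatenated in front, as in the hunk above
    pmk = torch.zeros(1, heads, num_persist, dim_head)
    pmv = torch.zeros(1, heads, num_persist, dim_head)

    k = torch.cat((pmk, k), dim = -2)
    v = torch.cat((pmv, v), dim = -2)

    assert k.shape == (1, heads, num_persist + window, dim_head)
    assert v.shape == (1, heads, num_persist + window, dim_head)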
@@ -125,7 +145,8 @@ class MemoryAsContextTransformer(Module):
         dim,
         depth,
         segment_len,
-        num_persist_mem_tokens,
+        num_longterm_mem_tokens = 0,
+        num_persist_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -135,6 +156,17 @@ class MemoryAsContextTransformer(Module):
 
         self.token_emb = nn.Embedding(num_tokens, dim)
 
+        self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
+
+        # long term mem tokens
+
+        self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
+
+        # hyper conection
+
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
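The transformer also gains a two-axis continuous axial positional embedding, one axis for the window index and one for the position within the extended (segment + long-term memory) window. A small sketch of how it is used in the rewritten `forward` further down, assuming only the call signature shown in this diff; the concrete sizes are illustrative:

    import torch
    from axial_positional_embedding import ContinuousAxialPositionalEmbedding

    dim = 256
    windows, total_segment_len = 8, 128 + 4   # windows x (segment_len + num_longterm_mem_tokens)

    axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)

    # with flatten = True the (window, intra-window position) grid comes back flattened,
    # so it can be sliced to the sequence length and added to the token embeddings
    pos_emb = axial_pos_emb((windows, total_segment_len), flatten = True)

    seq_len = 1023
    x = torch.randn(1, seq_len, dim)
    x = x + pos_emb[:seq_len]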
@@ -145,6 +177,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
             )
 
@@ -161,8 +194,45 @@ class MemoryAsContextTransformer(Module):
 
     def forward(self, x):
 
+        # math
+
+        batch, seq_len, segment_len, num_longterm_mem_tokens= *x.shape, self.segment_len, self.num_longterm_mem_tokens
+
+        windows = ceil(seq_len / segment_len)
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
+        # token embedding
+
         x = self.token_emb(x)
 
+        # intersperse longterm memory
+
+        need_segment = seq_len >= segment_len
+
+        if need_segment:
+            next_seq_len = round_up_multiple(seq_len, segment_len)
+            padding = next_seq_len - seq_len
+
+            if padding > 0:
+                x = F.pad(x, (0, 0, 0, padding))
+
+            x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
+
+        mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
+        x = torch.cat((mems, x), dim = -2)
+
+        if need_segment:
+            x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)
+            x = x[:, :seq_len]
+
+        # apply axial positional embedding
+        # so intra and inter segment can be more easily discerned by the network
+
+        pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
+        x = x + pos_emb[:seq_len]
+
+        # expand and reduce streams for hyper connections
+
         x = self.expand_streams(x)
 
         for attn, ff in self.layers:
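In other words, the new forward prologue pads to a whole number of windows, folds the windows into the batch, prepends the learned long-term memory tokens to every window, then unfolds again. A toy sketch of that reshape round-trip with illustrative sizes (segment_len = 4, two memory tokens, batch of 2, sequence of 7):

    from math import ceil

    import torch
    import torch.nn.functional as F
    from einops import rearrange, repeat

    batch, seq_len, dim = 2, 7, 16
    segment_len, num_longterm_mem_tokens = 4, 2

    x = torch.randn(batch, seq_len, dim)
    longterm_mems = torch.randn(num_longterm_mem_tokens, dim)

    # pad up to a whole number of windows (7 -> 8, i.e. 2 windows of 4)
    windows = ceil(seq_len / segment_len)
    x = F.pad(x, (0, 0, 0, windows * segment_len - seq_len))

    # fold windows into the batch and prepend the memory tokens to every window
    x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
    mems = repeat(longterm_mems, 'n d -> b n d', b = x.shape[0])
    x = torch.cat((mems, x), dim = -2)

    # unfold back to (batch, windows * (memories + segment), dim)
    x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)

    assert x.shape == (batch, windows * (num_longterm_mem_tokens + segment_len), dim)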
@@ -171,21 +241,8 @@ class MemoryAsContextTransformer(Module):
 
         x = self.reduce_streams(x)
 
+        # to logits
+
         x = self.norm(x)
 
         return self.to_logits(x)
-
-# main
-
-if __name__ == '__main__':
-    transformer = MemoryAsContextTransformer(
-        num_tokens = 256,
-        dim = 256,
-        depth = 2,
-        num_persist_mem_tokens = 16,
-        segment_len = 128,
-    )
-
-    x = torch.randint(0, 256, (1, 1023))
-
-    logits = transformer(x)
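The removed `__main__` smoke test above still shows the intended entry point; here is a minimal usage sketch against the 0.0.30 constructor, adding the new `num_longterm_mem_tokens` argument (the value 4 is illustrative, not taken from the diff):

    import torch
    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 128,
        num_persist_mem_tokens = 16,
        num_longterm_mem_tokens = 4    # new in 0.0.30, defaults to 0
    )

    x = torch.randint(0, 256, (1, 1023))

    logits = transformer(x)   # (1, 1023, 256)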
{titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.30.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.27
+Version: 0.0.30
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,10 +35,12 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
+Requires-Dist: axial-positional-embedding>=0.3.5
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
+Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
 Provides-Extra: examples
{titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.30.dist-info}/RECORD

@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=
+titans_pytorch/mac_transformer.py,sha256=XDLc9NWXpVMza03XjU0lkw5lRvtJ25ReTPKNoGslOOk,6773
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.27.dist-info/METADATA,sha256=
-titans_pytorch-0.0.27.dist-info/WHEEL,sha256=
-titans_pytorch-0.0.27.dist-info/licenses/LICENSE,sha256=
-titans_pytorch-0.0.27.dist-info/RECORD,,
+titans_pytorch-0.0.30.dist-info/METADATA,sha256=7H6WPsgfBE9ByEUP7r6C-cfBX6K13yNCaVhZMzwUvf8,3938
+titans_pytorch-0.0.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.30.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.30.dist-info/RECORD,,
{titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.30.dist-info}/WHEEL

File without changes

{titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.30.dist-info}/licenses/LICENSE

File without changes