titans-pytorch 0.0.27__py3-none-any.whl → 0.0.30__py3-none-any.whl
This diff compares publicly released versions of the package as they appear in their public registries and is provided for informational purposes only.
- titans_pytorch/mac_transformer.py +81 -24
- {titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.30.dist-info}/METADATA +3 -1
- {titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.30.dist-info}/RECORD +5 -5
- {titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.30.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.30.dist-info}/licenses/LICENSE +0 -0

titans_pytorch/mac_transformer.py

@@ -1,5 +1,5 @@
 from __future__ import annotations
-import
+from math import ceil
 from functools import partial

 import torch
@@ -7,11 +7,16 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear

-from einops import repeat
+from einops import repeat, rearrange
 from einops.layers.torch import Rearrange

 from hyper_connections import get_init_and_expand_reduce_stream_functions

+# absolute and relative positions
+
+from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+from rotary_embedding_torch import RotaryEmbedding
+
 # constants

 LinearNoBias = partial(Linear, bias = False)
@@ -25,7 +30,7 @@ def default(v, d):
     return v if exists(v) else d

 def round_up_multiple(seq, mult):
-    return
+    return ceil(seq / mult) * mult

 # feedforward and attention

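The rewritten helper above rounds a sequence length up to the next multiple of the (now longer) segment size. A minimal sketch of the behavior, with the example lengths below chosen purely for illustration:

```python
from math import ceil

def round_up_multiple(seq, mult):
    # smallest multiple of `mult` that is >= `seq`
    return ceil(seq / mult) * mult

# e.g. a 1023-token sequence rounded up to windows of 128 tokens
assert round_up_multiple(1023, 128) == 1024
assert round_up_multiple(1024, 128) == 1024
```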
@@ -49,7 +54,8 @@ class SegmentedAttention(Module):
         self,
         dim,
         segment_len,
-        num_persist_mem_tokens,
+        num_persist_mem_tokens = 0,
+        num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
     ):
@@ -58,29 +64,37 @@ class SegmentedAttention(Module):

         dim_inner = dim_head * heads

+        self.rotary_emb = RotaryEmbedding(dim_head)
+
         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
         self.to_out = LinearNoBias(dim_inner, dim)

         self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        total_segment_len = segment_len + num_longterm_mem_tokens

         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')

-        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n =
-        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n =
+        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = total_segment_len)
+        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = total_segment_len)

         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))

     def forward(self, seq):
+        segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
         batch, seq_len = seq.shape[:2]

         # auto pad to multiple
         # todo - get rid of logic with flex attention

-        need_segment = seq_len >=
+        need_segment = seq_len >= total_segment_len

         if need_segment:
-            next_seq_len = round_up_multiple(seq_len,
+            next_seq_len = round_up_multiple(seq_len, total_segment_len)
             padding = next_seq_len - seq_len

             if padding > 0:
@@ -99,6 +113,12 @@ class SegmentedAttention(Module):

         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)

+        # relative positions
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # persistent memory
+
         k = cat((pmk, k), dim = -2)
         v = cat((pmv, v), dim = -2)

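The hunk above applies rotary relative positions to the queries and keys before the persistent-memory keys/values are prepended, so the persistent tokens receive no positional rotation. A standalone sketch of that ordering, with the batch, head, segment, and memory sizes below assumed for illustration:

```python
import torch
from rotary_embedding_torch import RotaryEmbedding

dim_head = 64
rotary_emb = RotaryEmbedding(dim_head)

# assumed shapes: (batch, heads, tokens, dim_head), e.g. one 132-token segment
q = torch.randn(2, 8, 132, dim_head)
k = torch.randn(2, 8, 132, dim_head)
v = torch.randn(2, 8, 132, dim_head)

# relative positions are applied to q and k only
q, k = rotary_emb.rotate_queries_with_cached_keys(q, k)

# persistent memory keys/values are concatenated afterwards,
# so they act as position-free learned tokens
pmk = torch.randn(2, 8, 16, dim_head)
pmv = torch.randn(2, 8, 16, dim_head)
k = torch.cat((pmk, k), dim = -2)
v = torch.cat((pmv, v), dim = -2)
```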
@@ -125,7 +145,8 @@ class MemoryAsContextTransformer(Module):
         dim,
         depth,
         segment_len,
-        num_persist_mem_tokens,
+        num_longterm_mem_tokens = 0,
+        num_persist_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -135,6 +156,17 @@ class MemoryAsContextTransformer(Module):

         self.token_emb = nn.Embedding(num_tokens, dim)

+        self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
+
+        # long term mem tokens
+
+        self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
+
+        # hyper conection
+
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)

         self.layers = ModuleList([])
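The new ContinuousAxialPositionalEmbedding is only exercised in this diff through two calls (the constructor above and a forward pass over a (windows, total_segment_len) grid with flatten = True, shown further down). A hedged sketch of that usage, with the sizes below assumed for illustration:

```python
import torch
from axial_positional_embedding import ContinuousAxialPositionalEmbedding

dim = 256
axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)

# assumed sizes: 8 windows of 128 + 4 = 132 tokens each
windows, total_segment_len = 8, 132
pos_emb = axial_pos_emb((windows, total_segment_len), flatten = True)

# the transformer adds the first seq_len positions onto the token embeddings
seq_len = 1023
x = torch.randn(1, seq_len, dim)
x = x + pos_emb[:seq_len]
```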
@@ -145,6 +177,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
             )

@@ -161,8 +194,45 @@ class MemoryAsContextTransformer(Module):

     def forward(self, x):

+        # math
+
+        batch, seq_len, segment_len, num_longterm_mem_tokens= *x.shape, self.segment_len, self.num_longterm_mem_tokens
+
+        windows = ceil(seq_len / segment_len)
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
+        # token embedding
+
         x = self.token_emb(x)

+        # intersperse longterm memory
+
+        need_segment = seq_len >= segment_len
+
+        if need_segment:
+            next_seq_len = round_up_multiple(seq_len, segment_len)
+            padding = next_seq_len - seq_len
+
+            if padding > 0:
+                x = F.pad(x, (0, 0, 0, padding))
+
+            x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
+
+        mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
+        x = torch.cat((mems, x), dim = -2)
+
+        if need_segment:
+            x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)
+            x = x[:, :seq_len]
+
+        # apply axial positional embedding
+        # so intra and inter segment can be more easily discerned by the network
+
+        pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
+        x = x + pos_emb[:seq_len]
+
+        # expand and reduce streams for hyper connections
+
         x = self.expand_streams(x)

         for attn, ff in self.layers:
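The intersperse step added above pads the sequence to a multiple of segment_len, folds the windows into the batch dimension, prepends the learned long-term memory tokens to every window, and merges the windows back. A self-contained sketch of that reshaping, with tensor sizes assumed for illustration:

```python
import torch
import torch.nn.functional as F
from math import ceil
from einops import rearrange, repeat

batch, seq_len, dim = 1, 1023, 256
segment_len, num_longterm_mem_tokens = 128, 4

x = torch.randn(batch, seq_len, dim)
longterm_mems = torch.randn(num_longterm_mem_tokens, dim) * 0.02

# pad to a multiple of segment_len, then fold the windows into the batch
padding = ceil(seq_len / segment_len) * segment_len - seq_len
x = F.pad(x, (0, 0, 0, padding))
x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)

# prepend the long-term memory tokens to every window
mems = repeat(longterm_mems, 'n d -> b n d', b = x.shape[0])
x = torch.cat((mems, x), dim = -2)

# merge the windows back into one sequence; each window now carries
# num_longterm_mem_tokens extra tokens at its front
x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)
print(x.shape)  # torch.Size([1, 1056, 256])
```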
@@ -171,21 +241,8 @@ class MemoryAsContextTransformer(Module):

         x = self.reduce_streams(x)

+        # to logits
+
         x = self.norm(x)

         return self.to_logits(x)
-
-# main
-
-if __name__ == '__main__':
-    transformer = MemoryAsContextTransformer(
-        num_tokens = 256,
-        dim = 256,
-        depth = 2,
-        num_persist_mem_tokens = 16,
-        segment_len = 128,
-    )
-
-    x = torch.randint(0, 256, (1, 1023))
-
-    logits = transformer(x)
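The `__main__` demo removed in this hunk still sketches the intended API; adapted below to the new constructor signature, with num_longterm_mem_tokens set to an assumed small value (the new argument defaults to 0):

```python
import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,
    num_persist_mem_tokens = 16,
    num_longterm_mem_tokens = 4,   # new in 0.0.30; value assumed for the example
)

x = torch.randint(0, 256, (1, 1023))
logits = transformer(x)  # expected shape (1, 1023, 256)
```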

{titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.30.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.27
+Version: 0.0.30
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,10 +35,12 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
+Requires-Dist: axial-positional-embedding>=0.3.5
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
+Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
 Provides-Extra: examples

{titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.30.dist-info}/RECORD

@@ -1,9 +1,9 @@
 titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
 titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=
+titans_pytorch/mac_transformer.py,sha256=XDLc9NWXpVMza03XjU0lkw5lRvtJ25ReTPKNoGslOOk,6773
 titans_pytorch/titans.py,sha256=Kx_tl_QkeDccvkMwPZ0xQ_saYjZfbKzDNPTTSHNWYcA,14304
 titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
-titans_pytorch-0.0.27.dist-info/METADATA,sha256=
-titans_pytorch-0.0.27.dist-info/WHEEL,sha256=
-titans_pytorch-0.0.27.dist-info/licenses/LICENSE,sha256=
-titans_pytorch-0.0.27.dist-info/RECORD,,
+titans_pytorch-0.0.30.dist-info/METADATA,sha256=7H6WPsgfBE9ByEUP7r6C-cfBX6K13yNCaVhZMzwUvf8,3938
+titans_pytorch-0.0.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.30.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.30.dist-info/RECORD,,

{titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.30.dist-info}/WHEEL
File without changes

{titans_pytorch-0.0.27.dist-info → titans_pytorch-0.0.30.dist-info}/licenses/LICENSE
File without changes