titans-pytorch 0.0.29__tar.gz → 0.0.30__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/PKG-INFO +1 -1
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/pyproject.toml +1 -1
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/tests/test_titans.py +8 -2
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/titans_pytorch/mac_transformer.py +57 -27
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/.gitignore +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/LICENSE +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/README.md +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/data/README.md +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/fig1.png +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/fig2.png +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/requirements.txt +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/train.py +0 -0
{titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/tests/test_titans.py

@@ -34,14 +34,20 @@ def test_titans_attn_memory():
 
     assert seq.shape == retrieved.shape
 
-def test_mac():
+@pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
+@pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
+def test_mac(
+    num_persist_mem_tokens,
+    num_longterm_mem_tokens
+):
     from titans_pytorch.mac_transformer import MemoryAsContextTransformer
 
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
         dim = 256,
         depth = 2,
-        num_persist_mem_tokens =
+        num_persist_mem_tokens = num_persist_mem_tokens,
+        num_longterm_mem_tokens = num_longterm_mem_tokens,
         segment_len = 128,
     )
 
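For orientation, a minimal usage sketch of what the updated test exercises: constructing MemoryAsContextTransformer with both memory-token arguments and running a forward pass. The concrete values (16 memory tokens, sequence length 1023) are illustrative only, mirroring the parametrized test and the smoke test removed from mac_transformer.py below; they are not prescribed by the package.

    # illustrative sketch based on the updated test; values are not prescribed by the package
    import torch
    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        num_persist_mem_tokens = 16,   # persistent memory tokens prepended inside attention
        num_longterm_mem_tokens = 16,  # learned long-term memory tokens interspersed per window (new in 0.0.30)
        segment_len = 128,
    )

    x = torch.randint(0, 256, (1, 1023))  # token ids
    logits = transformer(x)               # logits of shape (batch, seq_len, num_tokens)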
{titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/titans_pytorch/mac_transformer.py

@@ -7,10 +7,9 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
-from einops import repeat
+from einops import repeat, rearrange
 from einops.layers.torch import Rearrange
 
-
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # absolute and relative positions
@@ -55,7 +54,8 @@ class SegmentedAttention(Module):
         self,
         dim,
         segment_len,
-        num_persist_mem_tokens,
+        num_persist_mem_tokens = 0,
+        num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
     ):
@@ -70,25 +70,31 @@ class SegmentedAttention(Module):
         self.to_out = LinearNoBias(dim_inner, dim)
 
         self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        total_segment_len = segment_len + num_longterm_mem_tokens
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
-        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
+        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = total_segment_len)
+        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = total_segment_len)
 
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
+        segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
         batch, seq_len = seq.shape[:2]
 
         # auto pad to multiple
         # todo - get rid of logic with flex attention
 
-        need_segment = seq_len >= segment_len
+        need_segment = seq_len >= total_segment_len
 
         if need_segment:
-            next_seq_len = round_up_multiple(seq_len, segment_len)
+            next_seq_len = round_up_multiple(seq_len, total_segment_len)
             padding = next_seq_len - seq_len
 
             if padding > 0:
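The net effect of this hunk is that SegmentedAttention now windows the sequence by total_segment_len = segment_len + num_longterm_mem_tokens rather than segment_len alone, padding up to the next multiple before splitting. The package's own round_up_multiple is not shown in this diff; a plausible stand-in with the behavior the padding logic implies, for illustration only:

    from math import ceil

    # illustration only - round_up_multiple itself is not part of this diff,
    # but the padding logic above implies this rounding behavior
    def round_up_multiple(seq_len, mult):
        return ceil(seq_len / mult) * mult

    # e.g. with segment_len = 128 and num_longterm_mem_tokens = 16,
    # total_segment_len = 144 and a length-1023 sequence pads to 1152 (8 windows of 144)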
@@ -139,7 +145,8 @@ class MemoryAsContextTransformer(Module):
         dim,
         depth,
         segment_len,
-        num_persist_mem_tokens,
+        num_longterm_mem_tokens = 0,
+        num_persist_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -147,10 +154,18 @@ class MemoryAsContextTransformer(Module):
     ):
         super().__init__()
 
-        self.
+        self.token_emb = nn.Embedding(num_tokens, dim)
+
         self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
 
-
+        # long term mem tokens
+
+        self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
+
+        # hyper conection
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
@@ -162,6 +177,7 @@ class MemoryAsContextTransformer(Module):
             dim_head = dim_head,
             heads = heads,
             segment_len = segment_len,
+            num_longterm_mem_tokens = num_longterm_mem_tokens,
             num_persist_mem_tokens = num_persist_mem_tokens
         )
 
@@ -177,15 +193,42 @@ class MemoryAsContextTransformer(Module):
         self.to_logits = LinearNoBias(dim, num_tokens)
 
     def forward(self, x):
-
+
+        # math
+
+        batch, seq_len, segment_len, num_longterm_mem_tokens= *x.shape, self.segment_len, self.num_longterm_mem_tokens
+
         windows = ceil(seq_len / segment_len)
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
+        # token embedding
 
         x = self.token_emb(x)
 
+        # intersperse longterm memory
+
+        need_segment = seq_len >= segment_len
+
+        if need_segment:
+            next_seq_len = round_up_multiple(seq_len, segment_len)
+            padding = next_seq_len - seq_len
+
+            if padding > 0:
+                x = F.pad(x, (0, 0, 0, padding))
+
+            x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
+
+        mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
+        x = torch.cat((mems, x), dim = -2)
+
+        if need_segment:
+            x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)
+            x = x[:, :seq_len]
+
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
-        pos_emb = self.axial_pos_emb((windows,
+        pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
         x = x + pos_emb[:seq_len]
 
         # expand and reduce streams for hyper connections
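The "intersperse longterm memory" block above is the core of this release: the embedded sequence is cut into windows of segment_len, the shared learned memory tokens are prepended to every window, and the windows are flattened back into one sequence. A self-contained sketch of that pattern follows; shapes and names are illustrative, not the package API, and it omits the padding and the final truncation back to seq_len shown in the diff.

    import torch
    from einops import rearrange, repeat

    # illustrative standalone sketch of the interspersing pattern used above
    batch, seq_len, dim = 2, 256, 64
    segment_len, num_mem = 128, 16

    x = torch.randn(batch, seq_len, dim)    # token embeddings; seq_len divisible by segment_len here
    mems = torch.randn(num_mem, dim)        # stands in for the learned longterm_mems parameter

    windows = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)   # (batch * num_windows, segment_len, dim)
    mems = repeat(mems, 'n d -> bw n d', bw = windows.shape[0])         # copy the memories to every window
    windows = torch.cat((mems, windows), dim = -2)                      # (batch * num_windows, num_mem + segment_len, dim)

    out = rearrange(windows, '(b w) n d -> b (w n) d', b = batch)       # (batch, num_windows * (num_mem + segment_len), dim)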
@@ -198,21 +241,8 @@ class MemoryAsContextTransformer(Module):
 
         x = self.reduce_streams(x)
 
+        # to logits
+
         x = self.norm(x)
 
         return self.to_logits(x)
-
-# main
-
-if __name__ == '__main__':
-    transformer = MemoryAsContextTransformer(
-        num_tokens = 256,
-        dim = 256,
-        depth = 2,
-        num_persist_mem_tokens = 16,
-        segment_len = 128,
-    )
-
-    x = torch.randint(0, 256, (1, 1023))
-
-    logits = transformer(x)