titans-pytorch 0.0.30__tar.gz → 0.0.32__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/PKG-INFO +1 -1
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/pyproject.toml +1 -1
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/titans_pytorch/mac_transformer.py +82 -34
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/.gitignore +0 -0
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/LICENSE +0 -0
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/README.md +0 -0
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/data/README.md +0 -0
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/fig1.png +0 -0
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/fig2.png +0 -0
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/requirements.txt +0 -0
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/train.py +0 -0
{titans_pytorch-0.0.30 → titans_pytorch-0.0.32}/titans_pytorch/mac_transformer.py

@@ -17,6 +17,10 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
 
+# proposed neural memory
+
+from titans_pytorch.titans import NeuralMemory
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
@@ -29,9 +33,33 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def identity(t):
+    return t
+
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult
 
+def pad_and_segment_with_inverse(seq, segment_len):
+    batch, seq_len = seq.shape[:2]
+
+    need_segment = seq_len >= segment_len
+
+    if not need_segment:
+        return seq, identity
+
+    next_seq_len_mult = round_up_multiple(seq_len, segment_len)
+
+    padding = next_seq_len_mult - seq_len
+    seq = F.pad(seq, (0, 0, 0, padding))
+
+    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
+
+    def inverse(out):
+        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+        return out[:, :-padding]
+
+    return seq, inverse
+
 # feedforward and attention
 
 class GEGLU(Module):
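The helper introduced above centralizes the pad → segment → inverse pattern reused throughout the file. A minimal usage sketch, assuming the function can be imported directly from titans_pytorch.mac_transformer (illustrative only, not part of the diff):

# sketch: pad to a multiple of the segment length, fold windows into the batch,
# then undo both steps with the returned inverse callable
import torch
from titans_pytorch.mac_transformer import pad_and_segment_with_inverse

x = torch.randn(2, 10, 512)                             # (batch, seq_len, dim)

segments, inverse = pad_and_segment_with_inverse(x, 4)  # pads 10 -> 12, yields shape (2 * 3, 4, 512)
restored = inverse(segments)                            # unfolds the windows and strips the padding

assert restored.shape == x.shape                        # back to (2, 10, 512)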
@@ -77,9 +105,6 @@ class SegmentedAttention(Module):
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = total_segment_len)
-        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = total_segment_len)
-
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
@@ -91,16 +116,7 @@ class SegmentedAttention(Module):
         # auto pad to multiple
         # todo - get rid of logic with flex attention
 
-
-
-        if need_segment:
-            next_seq_len = round_up_multiple(seq_len, total_segment_len)
-            padding = next_seq_len - seq_len
-
-            if padding > 0:
-                seq = F.pad(seq, (0, 0, 0, padding))
-
-        seq = self.segment_seq(seq)
+        seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len)
 
         # attention
 
@@ -130,10 +146,9 @@ class SegmentedAttention(Module):
 
         out = self.to_out(out)
 
-
-        out = self.merge_seq_back(out)
+        out = inverse_segment(out)
 
-        return out
+        return out
 
 # MAC transformer
 
@@ -150,7 +165,9 @@ class MemoryAsContextTransformer(Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        num_residual_streams = 4
+        num_residual_streams = 4,
+        neural_memory_kwargs: dict = dict(),
+        neural_memory_layers: tuple[int, ...] | None = None,
     ):
         super().__init__()
 
@@ -170,8 +187,25 @@ class MemoryAsContextTransformer(Module):
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
+        self.neural_mem_layers = ModuleList([])
+
+        layers = tuple(range(1, depth + 1))
+        neural_memory_layers = set(default(neural_memory_layers, layers))
+
+        for layer in layers:
+
+            # neural memory
+
+            mem = None
+
+            if num_longterm_mem_tokens > 0 and layer in neural_memory_layers:
+                mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
+                mem = init_hyper_conn(dim = dim, branch = mem)
+
+            self.neural_mem_layers.append(mem)
+
+            # attention and feedforward
 
-        for _ in range(depth):
             attn = SegmentedAttention(
                 dim = dim,
                 dim_head = dim_head,
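Together, the new constructor arguments and the per-layer loop above let callers choose which layers receive a NeuralMemory branch (1-indexed, defaulting to all of them). A hedged construction sketch, assuming the remaining arguments (num_tokens, dim, depth, segment_len, num_longterm_mem_tokens, num_persist_mem_tokens) keep the meaning shown in the package README; the values are illustrative:

from titans_pytorch.mac_transformer import MemoryAsContextTransformer

# sketch only - argument values are placeholders, not taken from the diff
transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 4,
    segment_len = 128,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
    neural_memory_layers = (2, 3),   # only layers 2 and 3 get a NeuralMemory branch
    neural_memory_kwargs = dict()    # reserved for NeuralMemory options; not yet forwarded in the hunks shown here
)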
@@ -207,40 +241,54 @@ class MemoryAsContextTransformer(Module):
 
         # intersperse longterm memory
 
-
-
-        if need_segment:
-            next_seq_len = round_up_multiple(seq_len, segment_len)
-            padding = next_seq_len - seq_len
-
-            if padding > 0:
-                x = F.pad(x, (0, 0, 0, padding))
-
-        x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
+        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x =
+        x = cat((mems, x), dim = -2)
 
-
-        x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)
-        x = x[:, :seq_len]
+        x = inverse_segment(x)
 
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
         pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
-        x = x + pos_emb[:
+        x = x + pos_emb[:x.shape[-2]]
 
         # expand and reduce streams for hyper connections
 
         x = self.expand_streams(x)
 
-        for attn, ff in self.layers:
+        for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
+
+            if exists(maybe_neural_mem):
+                batch_streams = x.shape[0]
+                x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
+
+                longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
+
+                longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
+
+                longterm_mems = maybe_neural_mem(longterm_mems)
+
+                longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
+
+                x = cat((longterm_mems, x), dim = -2)
+
+                x = inverse_segment(x)
+
             x = attn(x)
             x = ff(x)
 
         x = self.reduce_streams(x)
 
+        # excise out the memories
+
+        x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
+
+        x = x[:, num_longterm_mem_tokens:]
+
+        x = inverse_segment(x)
+
         # to logits
 
         x = self.norm(x)
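For readers following the forward-pass changes: each window of segment_len tokens is prefixed with num_longterm_mem_tokens memory tokens before the layer stack, and those tokens are excised again before the logits. A standalone shape sketch of that intersperse / excise bookkeeping (toy sizes chosen so no padding is needed; not taken from the package):

import torch
from einops import rearrange, repeat

batch, seq_len, dim = 2, 256, 64
segment_len, num_longterm_mem_tokens = 8, 4
total_segment_len = segment_len + num_longterm_mem_tokens

x = torch.randn(batch, seq_len, dim)
longterm_mems = torch.randn(num_longterm_mem_tokens, dim)    # stand-in for self.longterm_mems

# intersperse: prepend the memory tokens to every window of segment_len tokens
windows = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
mems = repeat(longterm_mems, 'n d -> b n d', b = windows.shape[0])
windows = torch.cat((mems, windows), dim = -2)               # (batch * w, total_segment_len, dim)
x = rearrange(windows, '(b w) n d -> b (w n) d', b = batch)  # back to one flat sequence per batch element

# excise: segment by total_segment_len and drop the memory tokens again
windows = rearrange(x, 'b (w n) d -> (b w) n d', n = total_segment_len)
windows = windows[:, num_longterm_mem_tokens:]
x = rearrange(windows, '(b w) n d -> b (w n) d', b = batch)

assert x.shape == (batch, seq_len, dim)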