titans-pytorch 0.0.35__tar.gz → 0.0.37__tar.gz
This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/PKG-INFO +1 -1
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/pyproject.toml +1 -1
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/titans_pytorch/mac_transformer.py +20 -21
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/titans_pytorch/titans.py +1 -1
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/train_mac.py +12 -8
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/.gitignore +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/LICENSE +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/README.md +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/data/README.md +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/fig1.png +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/fig2.png +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/requirements.txt +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/train.py +0 -0
titans_pytorch/mac_transformer.py

@@ -7,7 +7,7 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
-from einops import repeat, rearrange
+from einops import repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
 from hyper_connections import get_init_and_expand_reduce_stream_functions
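The only substantive change in this hunk is pulling in einops' pack and unpack, which the forward pass below adopts in place of manual concatenation and slicing of the memory tokens. A minimal round-trip sketch of that API (the shapes are invented for illustration):

import torch
from einops import pack, unpack

x = torch.randn(2, 32, 384)    # (batch, window, dim)
mems = torch.randn(2, 4, 384)  # (batch, memory tokens, dim)

# pack concatenates along the `*` axis and records each part's shape
packed, ps = pack((x, mems), 'b * d')    # -> (2, 36, 384)

# unpack is the exact inverse, restoring the original shapes from `ps`
x2, mems2 = unpack(packed, ps, 'b * d')
assert x2.shape == x.shape and mems2.shape == mems.shape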
@@ -185,7 +185,9 @@ class MemoryAsContextTransformer(Module):
         # long term mem tokens
 
         self.segment_len = segment_len
+
         self.num_longterm_mem_tokens = num_longterm_mem_tokens
+        has_longterm_mems = num_longterm_mem_tokens > 0
 
         self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
 
@@ -197,7 +199,11 @@ class MemoryAsContextTransformer(Module):
         self.neural_mem_layers = ModuleList([])
 
         layers = tuple(range(1, depth + 1))
-
+
+        if not exists(neural_memory_layers):
+            neural_memory_layers = layers if has_longterm_mems else ()
+
+        assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
 
         for layer in layers:
 
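The constructor now defaults `neural_memory_layers`: when it is not supplied, every layer gets a neural memory as long as long-term memory tokens are enabled, and an empty list combined with memory tokens is rejected. A standalone restatement of that rule (the helper name is hypothetical, not part of the library):

# hypothetical helper restating the defaulting rule from the hunk above
def resolve_neural_memory_layers(depth, num_longterm_mem_tokens, neural_memory_layers = None):
    layers = tuple(range(1, depth + 1))

    if neural_memory_layers is None:
        neural_memory_layers = layers if num_longterm_mem_tokens > 0 else ()

    assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0)
    return neural_memory_layers

assert resolve_neural_memory_layers(8, 4) == (1, 2, 3, 4, 5, 6, 7, 8)  # defaults to all layers
assert resolve_neural_memory_layers(8, 0) == ()                        # no mem tokens, no mem layers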
@@ -205,8 +211,15 @@ class MemoryAsContextTransformer(Module):
 
             mem = None
 
-            if
-
+            if layer in neural_memory_layers:
+                assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'
+
+                mem = NeuralMemory(
+                    dim = dim,
+                    chunk_size = num_longterm_mem_tokens + segment_len,
+                    **neural_memory_kwargs
+                )
+
                 mem = init_hyper_conn(dim = dim, branch = mem)
 
             self.neural_mem_layers.append(mem)
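Note the chunk size wiring: each NeuralMemory now chunks over one attention window plus the long-term memory tokens packed alongside it. With the hyperparameters from this release's train_mac.py, that works out as:

segment_len = 32              # WINDOW_SIZE in train_mac.py
num_longterm_mem_tokens = 4   # NUM_LONGTERM_MEM in train_mac.py

# one attention window plus its accompanying memory tokens
chunk_size = num_longterm_mem_tokens + segment_len
assert chunk_size == 36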
@@ -258,7 +271,7 @@ class MemoryAsContextTransformer(Module):
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x =
+        x, mem_ps = pack((x, mems), 'b * d')
 
         x = inverse_segment(x)
 
@@ -275,21 +288,7 @@ class MemoryAsContextTransformer(Module):
         for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
 
             if exists(maybe_neural_mem):
-
-
-                x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
-
-                longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
-
-                longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
-
-                longterm_mems = maybe_neural_mem(longterm_mems)
-
-                longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
-
-                x = cat((longterm_mems, x), dim = -2)
-
-                x = inverse_segment(x)
+                mems = maybe_neural_mem(mems)
 
             x = attn(x)
 
@@ -301,7 +300,7 @@ class MemoryAsContextTransformer(Module):
 
         x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
-        x = x
+        x, mem = unpack(x, mem_ps, 'b * d')
 
         x = inverse_segment(x)
 
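Taken together, these forward-pass hunks replace the per-layer segment/split/rearrange/concat dance with a single pack up front, per-layer updates to the memory stream, and a single unpack at the end. A simplified sketch of that flow, with the pad_and_segment_with_inverse windowing omitted and shapes assumed:

import torch
from einops import repeat, pack, unpack

x = torch.randn(2, 64, 384)          # token sequence
longterm_mems = torch.randn(4, 384)  # learned long-term memory tokens

mems = repeat(longterm_mems, 'n d -> b n d', b = x.shape[0])
x, mem_ps = pack((x, mems), 'b * d')  # memories ride along with the sequence

# ... attention / feedforward blocks operate on the packed sequence,
# while layers with a neural memory refine only the memory stream,
# i.e. mems = maybe_neural_mem(mems) as in the hunk above ...

x, mems = unpack(x, mem_ps, 'b * d')  # split apart again before the head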
train_mac.py

@@ -24,11 +24,13 @@ SHOULD_GENERATE = False
 SEQ_LEN = 512
 
 PROJECT_NAME = 'titans-mac-transformer'
-WANDB_ONLINE =
-GLOBAL_LAYERS = (2, 4)
+WANDB_ONLINE = True # turn this on to pipe experiment to cloud
 NEURAL_MEMORY_DEPTH = 2
-
-
+NUM_PERSIST_MEM = 4
+NUM_LONGTERM_MEM = 4
+NEURAL_MEM_LAYERS = (4,)
+WINDOW_SIZE = 32
+RUN_NAME = 'mac - 4 longterm mems, layers (4,)'
 
 # wandb experiment tracker
 
@@ -57,12 +59,14 @@ model = MemoryAsContextTransformer(
     dim = 384,
     depth = 8,
     segment_len = WINDOW_SIZE,
-    num_persist_mem_tokens =
-    num_longterm_mem_tokens =
-    neural_memory_layers =
+    num_persist_mem_tokens = NUM_PERSIST_MEM,
+    num_longterm_mem_tokens = NUM_LONGTERM_MEM,
+    neural_memory_layers = NEURAL_MEM_LAYERS,
     neural_memory_kwargs = dict(
+        dim_head = 64,
+        heads = 4,
         default_mlp_kwargs = dict(
-            depth = NEURAL_MEMORY_DEPTH
+            depth = NEURAL_MEMORY_DEPTH,
         )
     )
 ).cuda()
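For reference, the model construction the updated script now performs, as a standalone sketch; `num_tokens = 256` is an assumption here (a byte-level vocabulary for the enwik8 data shipped with the package, not shown in this diff), and `.cuda()` is dropped so it runs on CPU:

from titans_pytorch.mac_transformer import MemoryAsContextTransformer

model = MemoryAsContextTransformer(
    num_tokens = 256,             # assumed byte-level vocab; not part of this diff
    dim = 384,
    depth = 8,
    segment_len = 32,             # WINDOW_SIZE
    num_persist_mem_tokens = 4,   # NUM_PERSIST_MEM
    num_longterm_mem_tokens = 4,  # NUM_LONGTERM_MEM
    neural_memory_layers = (4,),  # NEURAL_MEM_LAYERS
    neural_memory_kwargs = dict(
        dim_head = 64,
        heads = 4,
        default_mlp_kwargs = dict(
            depth = 2             # NEURAL_MEMORY_DEPTH
        )
    )
)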