titans-pytorch 0.0.35.tar.gz → 0.0.37.tar.gz
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/PKG-INFO +1 -1
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/pyproject.toml +1 -1
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/titans_pytorch/mac_transformer.py +20 -21
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/titans_pytorch/titans.py +1 -1
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/train_mac.py +12 -8
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/.gitignore +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/LICENSE +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/README.md +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/data/README.md +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/fig1.png +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/fig2.png +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/requirements.txt +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.37}/train.py +0 -0
titans_pytorch/mac_transformer.py

@@ -7,7 +7,7 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
-from einops import repeat, rearrange
+from einops import repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
 from hyper_connections import get_init_and_expand_reduce_stream_functions
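For context, einops' pack and unpack concatenate tensors along a single wildcard axis while recording each part's shape, so the concatenation can be inverted exactly. A minimal sketch, with illustrative shapes that are not taken from the diff:

import torch
from einops import pack, unpack

x    = torch.randn(2, 32, 384)   # token sequence
mems = torch.randn(2, 4, 384)    # long term memory tokens

# concatenate along the wildcard axis, recording each part's shape
packed, ps = pack((x, mems), 'b * d')   # packed: (2, 36, 384)

# invert the packing, recovering both parts unchanged in shape
x_out, mems_out = unpack(packed, ps, 'b * d')
assert x_out.shape == x.shape and mems_out.shape == mems.shape

The forward-pass hunks further down use exactly this round trip to splice the longterm memory tokens into the sequence and back out.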
@@ -185,7 +185,9 @@ class MemoryAsContextTransformer(Module):
         # long term mem tokens
 
         self.segment_len = segment_len
+
         self.num_longterm_mem_tokens = num_longterm_mem_tokens
+        has_longterm_mems = num_longterm_mem_tokens > 0
 
         self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
@@ -197,7 +199,11 @@ class MemoryAsContextTransformer(Module):
         self.neural_mem_layers = ModuleList([])
 
         layers = tuple(range(1, depth + 1))
-
+
+        if not exists(neural_memory_layers):
+            neural_memory_layers = layers if has_longterm_mems else ()
+
+        assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
 
         for layer in layers:
 
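Restated outside the class, the new defaulting rule amounts to the following hypothetical helper (exists is the usual None-check used throughout the file; the helper name is ours, not the library's):

def exists(v):
    return v is not None

# hypothetical helper restating the defaulting rule above: when
# `neural_memory_layers` is unspecified, use every layer if longterm
# memory tokens are present, otherwise no layer at all
def default_neural_memory_layers(neural_memory_layers, depth, num_longterm_mem_tokens):
    layers = tuple(range(1, depth + 1))
    has_longterm_mems = num_longterm_mem_tokens > 0

    if not exists(neural_memory_layers):
        neural_memory_layers = layers if has_longterm_mems else ()

    assert not (has_longterm_mems and len(neural_memory_layers) == 0), \
        'empty `neural_memory_layers` when longterm memory tokens are present'

    return neural_memory_layers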
@@ -205,8 +211,15 @@ class MemoryAsContextTransformer(Module):
 
             mem = None
 
-            if
-
+            if layer in neural_memory_layers:
+                assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'
+
+                mem = NeuralMemory(
+                    dim = dim,
+                    chunk_size = num_longterm_mem_tokens + segment_len,
+                    **neural_memory_kwargs
+                )
+
                 mem = init_hyper_conn(dim = dim, branch = mem)
 
             self.neural_mem_layers.append(mem)
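Note that chunk_size now counts both the segment and its prepended longterm memory tokens. A sketch of an equivalent standalone construction, assuming the values used by train_mac.py further down (dim = 384, segment_len = 32, num_longterm_mem_tokens = 4):

from titans_pytorch import NeuralMemory

# each neural memory layer chunks over a window plus its
# prepended longterm memory tokens
mem = NeuralMemory(
    dim = 384,
    chunk_size = 4 + 32,   # num_longterm_mem_tokens + segment_len = 36
)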
@@ -258,7 +271,7 @@ class MemoryAsContextTransformer(Module):
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x =
+        x, mem_ps = pack((x, mems), 'b * d')
 
         x = inverse_segment(x)
@@ -275,21 +288,7 @@ class MemoryAsContextTransformer(Module):
         for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
 
             if exists(maybe_neural_mem):
-
-
-                x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
-
-                longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
-
-                longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
-
-                longterm_mems = maybe_neural_mem(longterm_mems)
-
-                longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
-
-                x = cat((longterm_mems, x), dim = -2)
-
-                x = inverse_segment(x)
+                mems = maybe_neural_mem(mems)
 
             x = attn(x)
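With packing now done once where the tokens are embedded, the per-layer loop no longer re-segments x. Schematically, as a self-contained sketch (hyper-connection plumbing omitted; the function name and arguments stand in for the module's state):

def exists(v):
    return v is not None

# schematic of the simplified per-layer flow after this change
def run_layers(layers, neural_mem_layers, x, mems):
    for (attn, ff), maybe_neural_mem in zip(layers, neural_mem_layers):
        if exists(maybe_neural_mem):
            # neural memory now transforms the memory tokens directly,
            # instead of slicing them back out of each segment
            mems = maybe_neural_mem(mems)

        x = attn(x)
        x = ff(x)

    return x, mems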
@@ -301,7 +300,7 @@ class MemoryAsContextTransformer(Module):
 
         x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
-        x = x
+        x, mem = unpack(x, mem_ps, 'b * d')
 
         x = inverse_segment(x)
 
train_mac.py

@@ -24,11 +24,13 @@ SHOULD_GENERATE = False
 SEQ_LEN = 512
 
 PROJECT_NAME = 'titans-mac-transformer'
-WANDB_ONLINE =
-GLOBAL_LAYERS = (2, 4)
+WANDB_ONLINE = True # turn this on to pipe experiment to cloud
 NEURAL_MEMORY_DEPTH = 2
-
-
+NUM_PERSIST_MEM = 4
+NUM_LONGTERM_MEM = 4
+NEURAL_MEM_LAYERS = (4,)
+WINDOW_SIZE = 32
+RUN_NAME = 'mac - 4 longterm mems, layers (4,)'
 
 # wandb experiment tracker
 
@@ -57,12 +59,14 @@ model = MemoryAsContextTransformer(
     dim = 384,
     depth = 8,
     segment_len = WINDOW_SIZE,
-    num_persist_mem_tokens =
-    num_longterm_mem_tokens =
-    neural_memory_layers =
+    num_persist_mem_tokens = NUM_PERSIST_MEM,
+    num_longterm_mem_tokens = NUM_LONGTERM_MEM,
+    neural_memory_layers = NEURAL_MEM_LAYERS,
     neural_memory_kwargs = dict(
+        dim_head = 64,
+        heads = 4,
         default_mlp_kwargs = dict(
-            depth = NEURAL_MEMORY_DEPTH
+            depth = NEURAL_MEMORY_DEPTH,
         )
     )
 ).cuda()
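Putting the new hyperparameters together, the training script now builds the model as below. This is a sketch, not the script itself: num_tokens = 256 is an assumption (byte-level enwik8), since that line is not part of this hunk.

import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

model = MemoryAsContextTransformer(
    num_tokens = 256,            # assumed byte-level vocab, not shown in this hunk
    dim = 384,
    depth = 8,
    segment_len = 32,            # WINDOW_SIZE
    num_persist_mem_tokens = 4,  # NUM_PERSIST_MEM
    num_longterm_mem_tokens = 4, # NUM_LONGTERM_MEM
    neural_memory_layers = (4,), # NEURAL_MEM_LAYERS
    neural_memory_kwargs = dict(
        dim_head = 64,
        heads = 4,
        default_mlp_kwargs = dict(
            depth = 2            # NEURAL_MEMORY_DEPTH
        )
    )
)

With these values the neural memory sits at layer 4 only, chunking over 4 + 32 = 36 tokens at a time.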