titans-pytorch 0.0.36__tar.gz → 0.0.37__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/PKG-INFO +1 -1
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/pyproject.toml +1 -1
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/titans_pytorch/mac_transformer.py +10 -19
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/titans_pytorch/titans.py +1 -1
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/train_mac.py +5 -5
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/.gitignore +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/LICENSE +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/README.md +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/data/README.md +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/fig1.png +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/fig2.png +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/requirements.txt +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/train.py +0 -0
{titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/titans_pytorch/mac_transformer.py

@@ -7,7 +7,7 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear

-from einops import repeat, rearrange
+from einops import repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange

 from hyper_connections import get_init_and_expand_reduce_stream_functions
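The only import change pulls in einops' `pack` and `unpack`, which the forward pass now uses to append the long-term memory tokens to the token sequence and later split them back off. A minimal sketch of that round trip, with illustrative shapes that are not taken from the package:

```python
import torch
from einops import pack, unpack

x = torch.randn(2, 48, 512)    # token embeddings (illustrative shapes)
mems = torch.randn(2, 4, 512)  # long-term memory tokens

# pack concatenates along the '*' axis and records each tensor's size in `ps`
packed, ps = pack((x, mems), 'b * d')   # packed.shape == (2, 52, 512)

# unpack uses the recorded sizes to recover the original tensors
x_out, mems_out = unpack(packed, ps, 'b * d')
assert x_out.shape == x.shape and mems_out.shape == mems.shape
```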
@@ -214,7 +214,12 @@ class MemoryAsContextTransformer(Module):
             if layer in neural_memory_layers:
                 assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'

-                mem = NeuralMemory(
+                mem = NeuralMemory(
+                    dim = dim,
+                    chunk_size = num_longterm_mem_tokens + segment_len,
+                    **neural_memory_kwargs
+                )
+
                 mem = init_hyper_conn(dim = dim, branch = mem)

             self.neural_mem_layers.append(mem)
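With this change the transformer supplies `dim` and `chunk_size` to `NeuralMemory` itself (the chunk now spans a full segment plus its long-term memory tokens), and only the remaining options arrive through `neural_memory_kwargs`. A hedged standalone sketch of the equivalent construction; the concrete values are made up for illustration, and the single-tensor forward signature is assumed from how the module is called later in this diff:

```python
import torch
from titans_pytorch import NeuralMemory

segment_len = 32               # illustrative values mirroring the constructor call above
num_longterm_mem_tokens = 4
dim = 384

mem = NeuralMemory(
    dim = dim,
    chunk_size = num_longterm_mem_tokens + segment_len,  # one memory chunk per augmented segment
    dim_head = 64,             # the kind of options forwarded via **neural_memory_kwargs
    heads = 4
)

# sequence length kept a multiple of chunk_size to stay on the safe side
seq = torch.randn(1, (num_longterm_mem_tokens + segment_len) * 4, dim)
retrieved = mem(seq)           # retrieve from the memory fit to `seq` at test time
```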
@@ -266,7 +271,7 @@ class MemoryAsContextTransformer(Module):
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)

         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x =
+        x, mem_ps = pack((x, mems), 'b * d')

         x = inverse_segment(x)

@@ -283,21 +288,7 @@ class MemoryAsContextTransformer(Module):
         for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):

             if exists(maybe_neural_mem):
-
-
-                x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
-
-                longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
-
-                longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
-
-                longterm_mems = maybe_neural_mem(longterm_mems)
-
-                longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
-
-                x = cat((longterm_mems, x), dim = -2)
-
-                x = inverse_segment(x)
+                mems = maybe_neural_mem(mems)

             x = attn(x)

@@ -309,7 +300,7 @@ class MemoryAsContextTransformer(Module):

         x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)

-        x = x
+        x, mem = unpack(x, mem_ps, 'b * d')

         x = inverse_segment(x)

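Taken together, the three forward-pass hunks replace the per-layer slice/rearrange/concatenate dance with one `pack` of the memory tokens up front, a direct call of the neural memory on those tokens inside the layer loop, and one `unpack` at the end. A self-contained toy illustration of that data flow, ignoring the windowed segmentation the real forward pass applies and using `nn.Identity` stand-ins rather than the library's actual attention and memory modules:

```python
import torch
from torch import nn
from einops import repeat, pack, unpack

batch, seq_len, dim, num_mems = 2, 64, 16, 4

layers = nn.ModuleList([nn.Identity(), nn.Identity()])  # stand-ins for the (attn, ff) blocks
neural_mems = [nn.Identity(), None]                      # a "neural memory" on the first layer only

longterm_mems = nn.Parameter(torch.randn(num_mems, dim))
x = torch.randn(batch, seq_len, dim)

mems = repeat(longterm_mems, 'n d -> b n d', b = batch)
x, mem_ps = pack((x, mems), 'b * d')                     # memory tokens appended once, up front

for layer, maybe_neural_mem in zip(layers, neural_mems):
    if maybe_neural_mem is not None:
        mems = maybe_neural_mem(mems)                    # memory now operates on the mem tokens directly

    x = layer(x)

x, mems = unpack(x, mem_ps, 'b * d')                     # memory tokens stripped before the output head
assert x.shape == (batch, seq_len, dim)
```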
{titans_pytorch-0.0.36 → titans_pytorch-0.0.37}/train_mac.py

@@ -24,13 +24,13 @@ SHOULD_GENERATE = False
 SEQ_LEN = 512

 PROJECT_NAME = 'titans-mac-transformer'
-WANDB_ONLINE =
+WANDB_ONLINE = True # turn this on to pipe experiment to cloud
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
-NEURAL_MEM_LAYERS = (
+NEURAL_MEM_LAYERS = (4,)
 WINDOW_SIZE = 32
-RUN_NAME = 'mac - 4 longterm mems, layers (
+RUN_NAME = 'mac - 4 longterm mems, layers (4,)'

 # wandb experiment tracker

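These training-script changes only flip run configuration: experiment tracking goes online, the neural memory is placed on layer 4, and the run name is updated to match. For orientation, a hedged sketch of how constants like these typically drive the wandb tracker referenced by the comment above; the script's actual wandb calls are not part of this hunk:

```python
import wandb

PROJECT_NAME = 'titans-mac-transformer'
WANDB_ONLINE = True   # turn this on to pipe experiment to cloud
RUN_NAME = 'mac - 4 longterm mems, layers (4,)'

# log to the cloud when WANDB_ONLINE is set, otherwise keep the run local
wandb.init(
    project = PROJECT_NAME,
    name = RUN_NAME,
    mode = 'online' if WANDB_ONLINE else 'offline'
)
```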
@@ -63,10 +63,10 @@ model = MemoryAsContextTransformer(
     num_longterm_mem_tokens = NUM_LONGTERM_MEM,
     neural_memory_layers = NEURAL_MEM_LAYERS,
     neural_memory_kwargs = dict(
+        dim_head = 64,
+        heads = 4,
         default_mlp_kwargs = dict(
             depth = NEURAL_MEMORY_DEPTH,
-            dim_head = 64,
-            heads = 4
         )
     )
 ).cuda()
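The attention-style `dim_head` and `heads` options now belong to the neural memory itself rather than to its inner MLP, so they move up one level inside `neural_memory_kwargs`. A hedged sketch of the full model construction this implies; everything not visible in the diff (`num_tokens`, `dim`, `depth`, and the forward call) is an assumption for illustration, and `.cuda()` is dropped so the snippet runs on CPU:

```python
import torch
from titans_pytorch import MemoryAsContextTransformer

model = MemoryAsContextTransformer(
    num_tokens = 256,               # assumed: byte-level vocab for the bundled enwik8 data
    dim = 384,                      # assumed
    depth = 8,                      # assumed
    segment_len = 32,               # WINDOW_SIZE
    num_persist_mem_tokens = 4,     # NUM_PERSIST_MEM
    num_longterm_mem_tokens = 4,    # NUM_LONGTERM_MEM
    neural_memory_layers = (4,),    # NEURAL_MEM_LAYERS
    neural_memory_kwargs = dict(
        dim_head = 64,              # now consumed by NeuralMemory directly ...
        heads = 4,
        default_mlp_kwargs = dict(
            depth = 2               # ... while the memory MLP keeps only its depth
        )
    )
)

ids = torch.randint(0, 256, (1, 512))
logits = model(ids)                 # assumed forward signature: token ids in, logits out
```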