titans-pytorch 0.0.36__tar.gz → 0.0.38__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/PKG-INFO +1 -1
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/pyproject.toml +1 -1
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/mac_transformer.py +10 -18
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/titans.py +9 -6
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/train_mac.py +5 -5
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/.gitignore +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/LICENSE +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/README.md +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/data/README.md +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/fig1.png +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/fig2.png +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/requirements.txt +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/train.py +0 -0
titans_pytorch/mac_transformer.py

```diff
@@ -7,7 +7,7 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
-from einops import repeat, rearrange
+from einops import repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
 from hyper_connections import get_init_and_expand_reduce_stream_functions
```
```diff
@@ -214,7 +214,12 @@ class MemoryAsContextTransformer(Module):
             if layer in neural_memory_layers:
                 assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'
 
-                mem = NeuralMemory(
+                mem = NeuralMemory(
+                    dim = dim,
+                    chunk_size = num_longterm_mem_tokens + segment_len,
+                    **neural_memory_kwargs
+                )
+
                 mem = init_hyper_conn(dim = dim, branch = mem)
 
             self.neural_mem_layers.append(mem)
```
```diff
@@ -266,7 +271,7 @@ class MemoryAsContextTransformer(Module):
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x =
+        x, mem_ps = pack((x, mems), 'b * d')
 
         x = inverse_segment(x)
 
```
```diff
@@ -283,21 +288,8 @@ class MemoryAsContextTransformer(Module):
         for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
 
             if exists(maybe_neural_mem):
-
-
-                x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
-
-                longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
-
-                longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
-
-                longterm_mems = maybe_neural_mem(longterm_mems)
-
-                longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
-
-                x = cat((longterm_mems, x), dim = -2)
+                x = maybe_neural_mem(x)
 
-                x = inverse_segment(x)
 
             x = attn(x)
 
```
```diff
@@ -309,7 +301,7 @@ class MemoryAsContextTransformer(Module):
 
         x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
-        x = x
+        x, _ = unpack(x, mem_ps, 'b * d')
 
         x = inverse_segment(x)
 
```
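For context on the two hunks above: `pack`/`unpack` from einops concatenate a group of tensors along the wildcard axis and remember each original extent, so the long-term memory tokens can be appended to the segment and later split back off without manual index arithmetic. A minimal standalone sketch, with tensor names and sizes chosen purely for illustration:

```python
import torch
from einops import pack, unpack

x = torch.randn(2, 32, 512)    # (batch, segment tokens, dim) -- illustrative sizes
mems = torch.randn(2, 4, 512)  # (batch, long-term memory tokens, dim)

# pack concatenates along the axis marked `*` and records each part's length
packed, ps = pack((x, mems), 'b * d')    # packed has shape (2, 36, 512)

# ... the transformer layers would operate on the combined sequence here ...

# unpack splits the packed axis back into the original pieces
x_out, mems_out = unpack(packed, ps, 'b * d')
assert x_out.shape == x.shape and mems_out.shape == mems.shape
```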
titans_pytorch/titans.py

```diff
@@ -27,9 +27,7 @@ n - sequence
 d - feature dimension
 c - intra-chunk
 """
-
-# constants
-
+
 LinearNoBias = partial(Linear, bias = False)
 
 # functions
```
```diff
@@ -132,7 +130,7 @@ class NeuralMemory(Module):
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         default_mlp_kwargs: dict = dict(
-            depth =
+            depth = 2
         )
     ):
         super().__init__()
```
```diff
@@ -390,7 +388,10 @@ class NeuralMemory(Module):
 
         padding = next_seq_len - curtailed_seq_len
 
-
+        needs_pad = padding > 0
+
+        if needs_pad:
+            seq = pad_at_dim(seq, (0, padding), dim = 1)
 
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
```
```diff
@@ -442,7 +443,9 @@ class NeuralMemory(Module):
             empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
             values = torch.cat((empty_memory_embeds, values), dim = -2)
 
-
+        if needs_pad:
+            values = values[:, :-padding]
+
         return values
 
     def forward(
```
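The `needs_pad` guard introduced in these two hunks matters because negative-stop slicing degenerates at zero: when `padding == 0`, `values[:, :-padding]` is the same as `values[:, :0]` and returns an empty tensor instead of leaving the sequence untouched. A quick illustration with arbitrary shapes:

```python
import torch

values = torch.randn(2, 10, 8)
padding = 0

# unguarded truncation silently drops every token when nothing was padded,
# because values[:, :-0] is equivalent to values[:, :0]
print(values[:, :-padding].shape)  # torch.Size([2, 0, 8])

# guarded version, mirroring the needs_pad check above
needs_pad = padding > 0
out = values[:, :-padding] if needs_pad else values
print(out.shape)                   # torch.Size([2, 10, 8])
```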
train_mac.py

```diff
@@ -24,13 +24,13 @@ SHOULD_GENERATE = False
 SEQ_LEN = 512
 
 PROJECT_NAME = 'titans-mac-transformer'
-WANDB_ONLINE =
+WANDB_ONLINE = True # turn this on to pipe experiment to cloud
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
-NEURAL_MEM_LAYERS = (
+NEURAL_MEM_LAYERS = (4,)
 WINDOW_SIZE = 32
-RUN_NAME = 'mac - 4 longterm mems, layers (
+RUN_NAME = 'mac - 4 longterm mems, layers (4,)'
 
 # wandb experiment tracker
 
```
```diff
@@ -63,10 +63,10 @@ model = MemoryAsContextTransformer(
     num_longterm_mem_tokens = NUM_LONGTERM_MEM,
     neural_memory_layers = NEURAL_MEM_LAYERS,
     neural_memory_kwargs = dict(
+        dim_head = 64,
+        heads = 4,
         default_mlp_kwargs = dict(
             depth = NEURAL_MEMORY_DEPTH,
-            dim_head = 64,
-            heads = 4
         )
     )
 ).cuda()
```
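A small but easy-to-miss detail in the new `NEURAL_MEM_LAYERS = (4,)` value: the trailing comma is what makes it a tuple. Because `MemoryAsContextTransformer` tests membership with `layer in neural_memory_layers` (see the mac_transformer.py hunk above), a bare `(4)` would just be the integer 4 and the membership test would raise. A quick illustration (variable names here are ours, not the script's):

```python
layers_with_memory = (4,)  # the trailing comma makes this a one-element tuple
layers_broken = (4)        # parentheses alone do not: this is just the int 4

print(4 in layers_with_memory)  # True -> layer 4 receives a neural memory block
# print(4 in layers_broken)     # TypeError: argument of type 'int' is not iterable
```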