titans-pytorch 0.0.36__tar.gz → 0.0.38__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
Potentially problematic release.
This version of titans-pytorch might be problematic. Click here for more details.
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/PKG-INFO +1 -1
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/pyproject.toml +1 -1
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/mac_transformer.py +10 -18
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/titans.py +9 -6
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/train_mac.py +5 -5
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/.gitignore +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/LICENSE +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/README.md +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/data/README.md +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/fig1.png +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/fig2.png +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/requirements.txt +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.36 → titans_pytorch-0.0.38}/train.py +0 -0
@@ -7,7 +7,7 @@ from torch import nn, cat
|
|
7
7
|
import torch.nn.functional as F
|
8
8
|
from torch.nn import Module, ModuleList, Linear
|
9
9
|
|
10
|
-
from einops import repeat, rearrange
|
10
|
+
from einops import repeat, rearrange, pack, unpack
|
11
11
|
from einops.layers.torch import Rearrange
|
12
12
|
|
13
13
|
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
@@ -214,7 +214,12 @@ class MemoryAsContextTransformer(Module):
|
|
214
214
|
if layer in neural_memory_layers:
|
215
215
|
assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'
|
216
216
|
|
217
|
-
mem = NeuralMemory(
|
217
|
+
mem = NeuralMemory(
|
218
|
+
dim = dim,
|
219
|
+
chunk_size = num_longterm_mem_tokens + segment_len,
|
220
|
+
**neural_memory_kwargs
|
221
|
+
)
|
222
|
+
|
218
223
|
mem = init_hyper_conn(dim = dim, branch = mem)
|
219
224
|
|
220
225
|
self.neural_mem_layers.append(mem)
|
@@ -266,7 +271,7 @@ class MemoryAsContextTransformer(Module):
|
|
266
271
|
x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
|
267
272
|
|
268
273
|
mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
|
269
|
-
x =
|
274
|
+
x, mem_ps = pack((x, mems), 'b * d')
|
270
275
|
|
271
276
|
x = inverse_segment(x)
|
272
277
|
|
@@ -283,21 +288,8 @@ class MemoryAsContextTransformer(Module):
|
|
283
288
|
for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
|
284
289
|
|
285
290
|
if exists(maybe_neural_mem):
|
286
|
-
|
287
|
-
|
288
|
-
x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
|
289
|
-
|
290
|
-
longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
|
291
|
-
|
292
|
-
longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
|
293
|
-
|
294
|
-
longterm_mems = maybe_neural_mem(longterm_mems)
|
295
|
-
|
296
|
-
longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
|
297
|
-
|
298
|
-
x = cat((longterm_mems, x), dim = -2)
|
291
|
+
x = maybe_neural_mem(x)
|
299
292
|
|
300
|
-
x = inverse_segment(x)
|
301
293
|
|
302
294
|
x = attn(x)
|
303
295
|
|
@@ -309,7 +301,7 @@ class MemoryAsContextTransformer(Module):
|
|
309
301
|
|
310
302
|
x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
|
311
303
|
|
312
|
-
x = x
|
304
|
+
x, _ = unpack(x, mem_ps, 'b * d')
|
313
305
|
|
314
306
|
x = inverse_segment(x)
|
315
307
|
|
@@ -27,9 +27,7 @@ n - sequence
|
|
27
27
|
d - feature dimension
|
28
28
|
c - intra-chunk
|
29
29
|
"""
|
30
|
-
|
31
|
-
# constants
|
32
|
-
|
30
|
+
7
|
33
31
|
LinearNoBias = partial(Linear, bias = False)
|
34
32
|
|
35
33
|
# functions
|
@@ -132,7 +130,7 @@ class NeuralMemory(Module):
|
|
132
130
|
max_grad_norm: float | None = None,
|
133
131
|
use_accelerated_scan = False,
|
134
132
|
default_mlp_kwargs: dict = dict(
|
135
|
-
depth =
|
133
|
+
depth = 2
|
136
134
|
)
|
137
135
|
):
|
138
136
|
super().__init__()
|
@@ -390,7 +388,10 @@ class NeuralMemory(Module):
|
|
390
388
|
|
391
389
|
padding = next_seq_len - curtailed_seq_len
|
392
390
|
|
393
|
-
|
391
|
+
needs_pad = padding > 0
|
392
|
+
|
393
|
+
if needs_pad:
|
394
|
+
seq = pad_at_dim(seq, (0, padding), dim = 1)
|
394
395
|
|
395
396
|
# the parameters of the memory model stores the memories of the key / values
|
396
397
|
# when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
|
@@ -442,7 +443,9 @@ class NeuralMemory(Module):
|
|
442
443
|
empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
|
443
444
|
values = torch.cat((empty_memory_embeds, values), dim = -2)
|
444
445
|
|
445
|
-
|
446
|
+
if needs_pad:
|
447
|
+
values = values[:, :-padding]
|
448
|
+
|
446
449
|
return values
|
447
450
|
|
448
451
|
def forward(
|
@@ -24,13 +24,13 @@ SHOULD_GENERATE = False
|
|
24
24
|
SEQ_LEN = 512
|
25
25
|
|
26
26
|
PROJECT_NAME = 'titans-mac-transformer'
|
27
|
-
WANDB_ONLINE =
|
27
|
+
WANDB_ONLINE = True # turn this on to pipe experiment to cloud
|
28
28
|
NEURAL_MEMORY_DEPTH = 2
|
29
29
|
NUM_PERSIST_MEM = 4
|
30
30
|
NUM_LONGTERM_MEM = 4
|
31
|
-
NEURAL_MEM_LAYERS = (
|
31
|
+
NEURAL_MEM_LAYERS = (4,)
|
32
32
|
WINDOW_SIZE = 32
|
33
|
-
RUN_NAME = 'mac - 4 longterm mems, layers (
|
33
|
+
RUN_NAME = 'mac - 4 longterm mems, layers (4,)'
|
34
34
|
|
35
35
|
# wandb experiment tracker
|
36
36
|
|
@@ -63,10 +63,10 @@ model = MemoryAsContextTransformer(
|
|
63
63
|
num_longterm_mem_tokens = NUM_LONGTERM_MEM,
|
64
64
|
neural_memory_layers = NEURAL_MEM_LAYERS,
|
65
65
|
neural_memory_kwargs = dict(
|
66
|
+
dim_head = 64,
|
67
|
+
heads = 4,
|
66
68
|
default_mlp_kwargs = dict(
|
67
69
|
depth = NEURAL_MEMORY_DEPTH,
|
68
|
-
dim_head = 64,
|
69
|
-
heads = 4
|
70
70
|
)
|
71
71
|
)
|
72
72
|
).cuda()
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|