titans-pytorch 0.1.1__tar.gz → 0.1.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/PKG-INFO +1 -1
- {titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/pyproject.toml +1 -1
- {titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/tests/test_titans.py +10 -3
- {titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/titans_pytorch/mac_transformer.py +5 -2
- {titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/titans_pytorch/titans.py +43 -4
- {titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/.gitignore +0 -0
- {titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/LICENSE +0 -0
- {titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/README.md +0 -0
- {titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/data/README.md +0 -0
- {titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/fig1.png +0 -0
- {titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/fig2.png +0 -0
- {titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/train_mac.py +0 -0
{titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/tests/test_titans.py

@@ -11,12 +11,14 @@ def exists(v):
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
 @pytest.mark.parametrize('silu', (False, True))
 @pytest.mark.parametrize('learned_mem_model_weights', (False, True))
+@pytest.mark.parametrize('attn_pool_chunks', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
 def test_titans(
     seq_len,
     silu,
     learned_mem_model_weights,
+    attn_pool_chunks,
     max_grad_norm,
     per_parameter_lr_modulation
 ):
@@ -24,6 +26,7 @@ def test_titans(
         dim = 384,
         chunk_size = 64,
         activation = nn.SiLU() if silu else None,
+        attn_pool_chunks = attn_pool_chunks,
         max_grad_norm = max_grad_norm,
         per_parameter_lr_modulation = per_parameter_lr_modulation,
         learned_mem_model_weights = learned_mem_model_weights
@@ -50,10 +53,12 @@ def test_titans_attn_memory():

     assert seq.shape == retrieved.shape

+@pytest.mark.parametrize('seq_len', (1023, 17))
 @pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
 @pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
 @pytest.mark.parametrize('neural_mem_gate_attn_output', (False, True))
 def test_mac(
+    seq_len,
     num_persist_mem_tokens,
     num_longterm_mem_tokens,
     neural_mem_gate_attn_output
@@ -70,13 +75,15 @@ def test_mac(
         neural_mem_gate_attn_output = neural_mem_gate_attn_output
     )

-    x = torch.randint(0, 256, (1,
+    x = torch.randint(0, 256, (1, seq_len))

     logits = transformer(x)

-    assert logits.shape == (1,
+    assert logits.shape == (1, seq_len, 256)

+@pytest.mark.parametrize('seq_len', (1023, 17))
 @pytest.mark.parametrize('sliding', (True, False))
 def test_flex(
+    seq_len,
     sliding
 ):
     if not (torch.cuda.is_available() and exists(flex_attention)):
@@ -91,7 +98,7 @@ def test_flex(
         sliding = sliding
     ).cuda()

-    seq = torch.randn(1,
+    seq = torch.randn(1, seq_len, 512).cuda()

     out_flex, _ = attn(seq)
     out_non_flex, _ = attn(seq, disable_flex_attn = True)
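For reference, a minimal sketch of what the reparametrized `test_mac` now checks: the output logits keep the input sequence length even when it is not a multiple of the segment length (1023 and 17 in the new parametrization). The constructor arguments below, other than `num_tokens`, are illustrative assumptions rather than the exact values used in the test.

```python
# Hedged sketch only: constructor arguments besides num_tokens are assumptions,
# not the exact values from tests/test_titans.py.
import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 32,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 4
)

for seq_len in (1023, 17):   # the two lengths the new parametrization exercises
    x = torch.randint(0, 256, (1, seq_len))
    logits = transformer(x)
    assert logits.shape == (1, seq_len, 256)
```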
{titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/titans_pytorch/mac_transformer.py

@@ -528,7 +528,8 @@ class MemoryAsContextTransformer(Module):
         filter_fn: Callable = min_p_filter,
         filter_kwargs: dict = dict(
             min_p = 0.1,
-        )
+        ),
+        show_progress = True
     ):
         was_training = self.training
         self.eval()
@@ -536,7 +537,9 @@ class MemoryAsContextTransformer(Module):
         prompt_seq_len, out = prompt.shape[-1], prompt.clone()
         sample_num_times = max(0, seq_len - prompt_seq_len)

-        for _ in range(sample_num_times):
+        iter_wrap = tqdm.tqdm if show_progress else identity
+
+        for _ in iter_wrap(range(sample_num_times)):
             logits = self.forward(out, disable_flex_attn = True)
             logits = logits[:, -1]

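These hunks make `.sample()` display a tqdm progress bar by default and let callers turn it off. A hedged usage sketch, assuming a constructed `transformer` and the `prompt` / `seq_len` positional arguments implied by the hunk above:

```python
# Hedged sketch: assumes `transformer` is a MemoryAsContextTransformer instance
# and that sample(prompt, seq_len, ...) matches the parameters seen in the hunk.
import torch

prompt = torch.randint(0, 256, (1, 1))

sampled = transformer.sample(prompt, 128)                               # tqdm bar (show_progress = True by default)
sampled_quiet = transformer.sample(prompt, 128, show_progress = False)  # silent generation loop
```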
{titans_pytorch-0.1.1 → titans_pytorch-0.1.5}/titans_pytorch/titans.py

@@ -18,7 +18,7 @@ from titans_pytorch.associative_scan import (
 )

 import einx
-from einops import rearrange, repeat, pack, unpack
+from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce

 """
@@ -95,6 +95,37 @@ class MultiheadRMSNorm(Module):
     def forward(self, x):
         return self.rmsnorm(x) * (self.gamma + 1.)

+# attention pool
+
+class AttentionPool(Module):
+    def __init__(
+        self,
+        dim,
+        chunk_size
+    ):
+        """
+        taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
+        """
+        super().__init__()
+        self.split_chunks = Rearrange('b (n c) d -> b n c d', c = chunk_size)
+        self.to_attn_logits = nn.Linear(dim, dim)
+
+        # default to average pool
+
+        nn.init.zeros_(self.to_attn_logits.weight)
+        nn.init.zeros_(self.to_attn_logits.bias)
+
+    def forward(
+        self,
+        x
+    ):
+        x = self.split_chunks(x)
+        attn_logits = self.to_attn_logits(x)
+
+        attn = attn_logits.softmax(dim = -2)
+
+        return reduce(x * attn, 'b n c d -> b n d', 'sum')
+
 # classes

 class MemoryMLP(Module):
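A small sketch of what the new `AttentionPool` computes, assuming the class is importable from `titans_pytorch.titans`: it reduces `(b, n*c, d)` to `(b, n, d)`, and because the logit projection is zero-initialized, it behaves exactly like mean pooling at initialization.

```python
# Hedged sketch: assumes AttentionPool is importable from titans_pytorch.titans.
import torch
from titans_pytorch.titans import AttentionPool

pool = AttentionPool(dim = 64, chunk_size = 16)

x = torch.randn(2, 8 * 16, 64)   # 8 chunks of 16 tokens each
out = pool(x)

assert out.shape == (2, 8, 64)

# zero-initialized logits -> uniform softmax over each chunk -> plain average
mean_pooled = x.reshape(2, 8, 16, 64).mean(dim = -2)
assert torch.allclose(out, mean_pooled, atol = 1e-6)
```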
@@ -224,6 +255,7 @@ class NeuralMemory(Module):
         default_step_transform_max_lr = 1e-2,
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
         max_mem_layer_modulation = 1e1, # max of 10.
+        attn_pool_chunks = False,
         pre_rmsnorm = True,
         post_rmsnorm = True,
         learned_mem_model_weights = True,
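A hedged sketch of toggling the new flag: `dim` and `chunk_size` mirror the values in the test diff above, and the retrieval call and shape assertion mirror the existing `assert seq.shape == retrieved.shape` shown there; the remaining constructor defaults are assumed.

```python
# Hedged sketch: dim / chunk_size follow the test parametrization earlier in
# this diff; everything else relies on NeuralMemory's defaults.
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    attn_pool_chunks = True   # attention-pool each chunk instead of taking its mean
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)

assert seq.shape == retrieved.shape
```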
@@ -304,10 +336,17 @@ class NeuralMemory(Module):
         self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
         nn.init.normal_(self.empty_memory_embed, std = 0.02)

+        # whether to use averaging of chunks, or attention pooling
+
+        if not attn_pool_chunks:
+            chunk_reduce_module = Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)
+        else:
+            chunk_reduce_module = AttentionPool(dim, chunk_size = chunk_size)
+
         # learned adaptive learning rate and momentum

         self.to_momentum = Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            chunk_reduce_module,
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
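Both branches produce the same `(b, n, d)` output, so either module drops into the surrounding `Sequential` blocks unchanged. A quick shape sketch, under the same import assumptions as above:

```python
# Hedged sketch: both chunk reducers map (b, n*c, d) -> (b, n, d).
import torch
from einops.layers.torch import Reduce
from titans_pytorch.titans import AttentionPool

chunk_size = 64
mean_reduce = Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)
attn_reduce = AttentionPool(384, chunk_size = chunk_size)

x = torch.randn(1, 4 * chunk_size, 384)

assert mean_reduce(x).shape == attn_reduce(x).shape == (1, 4, 384)
```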
@@ -325,7 +364,7 @@ class NeuralMemory(Module):
         # per layer learning rate modulation

         self.to_layer_modulation = Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            chunk_reduce_module,
             LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
             Rearrange('b n (h w) -> w (b h) n', h = heads),
             nn.Sigmoid()
@@ -340,7 +379,7 @@ class NeuralMemory(Module):
         # weight decay factor

         self.to_decay_factor = Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            chunk_reduce_module,
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )