titans-pytorch 0.2.4__tar.gz → 0.2.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/PKG-INFO +2 -2
- {titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/README.md +1 -1
- {titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/pyproject.toml +1 -1
- {titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/titans_pytorch/mac_transformer.py +1 -1
- {titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/titans_pytorch/memory_models.py +32 -14
- {titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/titans_pytorch/neural_memory.py +21 -4
- {titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/train_mac.py +8 -6
- {titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/.gitignore +0 -0
- {titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/LICENSE +0 -0
- {titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/data/README.md +0 -0
- {titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/data/enwik8.gz +0 -0
- {titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/fig1.png +0 -0
- {titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/fig2.png +0 -0
- {titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/tests/test_titans.py +0 -0
- {titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/titans_pytorch/associative_scan.py +0 -0
{titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.4
+Version: 0.2.6
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -56,7 +56,7 @@ Description-Content-Type: text/markdown
 
 <img src="./fig1.png" width="400px"></img>
 
-## Titans - Pytorch
+## Titans - Pytorch (wip)
 
 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
{titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/README.md

@@ -2,7 +2,7 @@
 
 <img src="./fig1.png" width="400px"></img>
 
-## Titans - Pytorch
+## Titans - Pytorch (wip)
 
 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
{titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/titans_pytorch/mac_transformer.py

@@ -488,7 +488,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_model: Module | None = None,
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
-        aux_kv_recon_loss_weight =
+        aux_kv_recon_loss_weight = 1.,
         use_flex_attn = False,
         sliding_window_attn = False,
         weight_tie_memory_model = False,
{titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/titans_pytorch/memory_models.py

@@ -1,8 +1,15 @@
 import torch
-from torch import nn
+from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Parameter, ParameterList
 
+# functions
+
+def l2norm(t):
+    return F.normalize(t, dim = -1)
+
+# memory mlp proposed in TTT
+
 class MemoryMLP(Module):
     def __init__(
         self,
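The new `l2norm` helper is a thin wrapper around `F.normalize` and is used further down in `MemoryAttention` to normalize queries and keys. A quick standalone check of what it does (illustrative only, not part of the package):

```python
import torch
import torch.nn.functional as F

def l2norm(t):
    # same as the helper added above: scale vectors along the last dim to unit L2 norm
    return F.normalize(t, dim = -1)

x = torch.randn(2, 4, 8)
normed = l2norm(x)

# every vector along the last dimension now has unit length
assert torch.allclose(normed.norm(dim = -1), torch.ones(2, 4), atol = 1e-6)
```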
@@ -19,15 +26,17 @@ class MemoryMLP(Module):
         self,
         x
     ):
+        residual = x
+
         for ind, weight in enumerate(self.weights):
             is_first = ind == 0
 
             if not is_first:
-                x = F.
+                x = F.gelu(x)
 
             x = x @ weight
 
-        return x
+        return x + residual
 
 # memory mlp, but with gated residual + final projection
 
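The two changes above give the TTT-style memory MLP a GELU between matmuls and a residual around the whole block. A minimal, self-contained sketch of the updated forward pass follows; the `__init__` is not part of this hunk, so the `ParameterList` of square `dim x dim` weights below is an assumption made for illustration.

```python
import torch
import torch.nn.functional as F
from torch.nn import Module, Parameter, ParameterList

class MemoryMLP(Module):
    # sketch only: the package's actual __init__ (weight shapes, init) is not shown in this diff
    def __init__(self, dim, depth):
        super().__init__()
        self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])

    def forward(self, x):
        residual = x              # added: remember the block input

        for ind, weight in enumerate(self.weights):
            is_first = ind == 0

            if not is_first:
                x = F.gelu(x)     # activation between matmuls

            x = x @ weight

        return x + residual       # added: residual around the whole MLP

mlp = MemoryMLP(dim = 16, depth = 2)
assert mlp(torch.randn(3, 2, 16)).shape == (3, 2, 16)
```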
@@ -36,7 +45,7 @@ class GatedResidualMemoryMLP(Module):
         self,
         dim,
         depth,
-        expansion_factor =
+        expansion_factor = 4.
     ):
         super().__init__()
         dim_hidden = int(dim * expansion_factor)
@@ -58,11 +67,13 @@ class GatedResidualMemoryMLP(Module):
         self,
         x
     ):
+        residual = x
+
         for weight1, weight2, to_gates in self.weights:
             res = x
 
             hidden = x @ weight1
-            hidden = F.
+            hidden = F.gelu(hidden)
             branch_out = hidden @ weight2
 
             # gated residual
@@ -70,7 +81,9 @@ class GatedResidualMemoryMLP(Module):
             gates = cat((branch_out, res), dim = -1) @ to_gates
             x = res.lerp(branch_out, gates.sigmoid())
 
-
+        out = x @ self.final_proj
+
+        return out + residual
 
 # memory mlp with factorized weights
 # so can tradeoff capacity for smaller chunk sizes
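Inside the loop, each branch output is merged with its input through a learned gate and `lerp`, and the new lines add a final projection plus an outer residual over the whole stack. The snippet below isolates just the gating step with stand-in tensors; the gate projection shape (`dim * 2 -> dim`, one gate per feature) is an assumption, since `to_gates` is initialized outside this hunk.

```python
import torch
from torch import cat

dim = 16
res = torch.randn(3, 2, dim)           # stand-in for the branch input
branch_out = torch.randn(3, 2, dim)    # stand-in for hidden @ weight2

# assumed gate projection: concatenated (branch_out, res) -> one gate per feature
to_gates = torch.randn(dim * 2, dim)

gates = cat((branch_out, res), dim = -1) @ to_gates

# lerp(a, b, w) = a + w * (b - a), so the sigmoid gate interpolates
# between keeping the residual and taking the branch output
x = res.lerp(branch_out, gates.sigmoid())

assert x.shape == res.shape
```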
@@ -98,15 +111,17 @@ class FactorizedMemoryMLP(Module):
         self,
         x
     ):
+        residual = x
+
         for ind, (weight1, weight2) in enumerate(self.weights):
             is_first = ind == 0
 
             if not is_first:
-                x = F.
+                x = F.gelu(x)
 
             x = x @ weight1 @ weight2
 
-        return x
+        return x + residual
 
 # improvised attention as memory module
 
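The earlier comment explains the point of the factorized weights: each `dim x dim` memory matrix becomes a `dim x k` and `k x dim` pair, which shrinks the number of memory parameters that have to be updated per chunk. A rough illustration with made-up sizes (the library's actual factorization rank is not shown in this diff):

```python
import torch

dim, k = 512, 32   # k is an illustrative factorization rank

full = torch.randn(dim, dim)                       # unfactorized memory weight
w1, w2 = torch.randn(dim, k), torch.randn(k, dim)  # factorized counterpart

print(full.numel())             # 262144 parameters
print(w1.numel() + w2.numel())  # 32768 parameters, 8x fewer

x = torch.randn(3, 2, dim)
out = x @ w1 @ w2               # same input/output shape as x @ full, but a rank-k map
assert out.shape == x.shape
```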
@@ -133,10 +148,12 @@ class MemoryAttention(Module):
             nn.init.xavier_uniform_(weight)
 
     def forward(self, x):
+        residual = x
+
         wq, wk, wv, ffw1, ffw2 = self.weights
 
-        q =
-        k =
+        q = l2norm(x @ wq)
+        k = l2norm(x @ wk)
         v = x @ wv
 
         attn_out = F.scaled_dot_product_attention(
@@ -145,9 +162,10 @@ class MemoryAttention(Module):
             is_causal = True
         )
 
-
+        # parallel attention + feedforward block
+        # as in PaLM + Gpt-J
 
-        h = F.
-
+        h = F.gelu(x @ ffw1)
+        ff_out = h @ ffw2
 
-        return
+        return attn_out + ff_out + residual
{titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/titans_pytorch/neural_memory.py

@@ -304,13 +304,26 @@ class NeuralMemory(Module):
             nn.Sigmoid()
         ) if heads > 1 else None
 
-        # memory
+        # memory model
 
         if not exists(model):
             model = MemoryMLP(dim_head, **default_model_kwargs)
 
+        # validate memory model
+
         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
 
+        test_shape = (3, 2, dim_head)
+
+        with torch.no_grad():
+            try:
+                test_input = torch.randn(test_shape)
+                mem_model_output = model(test_input)
+            except:
+                raise RuntimeError(f'memory model unable to accept a tensor of shape {test_shape}')
+
+            assert mem_model_output.shape == test_shape, 'output of memory model needs to be same shape as input'
+
         # the memory is the weights of the model
 
         self.memory_model = model
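The new validation block runs a dummy `(3, 2, dim_head)` tensor through the supplied memory model and requires a same-shaped output, on top of the existing rule that the model may not register buffers. The snippet below mirrors those checks against a hypothetical custom memory model, just to show what a replacement for the default `MemoryMLP` has to satisfy:

```python
import torch
from torch import nn

dim_head = 64

# hypothetical custom memory model: shape-preserving over (..., dim_head), no buffers
custom_memory = nn.Sequential(
    nn.Linear(dim_head, dim_head * 2),
    nn.GELU(),
    nn.Linear(dim_head * 2, dim_head)
)

# the buffer check: e.g. BatchNorm1d would fail here, since it registers running stats
assert next(custom_memory.buffers(), None) is None

# the new shape check
test_shape = (3, 2, dim_head)
with torch.no_grad():
    out = custom_memory(torch.randn(test_shape))

assert out.shape == test_shape
```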
@@ -772,10 +785,14 @@ class NeuralMemory(Module):
 
         # retrieve
 
-
-
+        retrieve_chunk_size = default(chunk_size, self.retrieve_chunk_size)
+
+        if retrieve_chunk_size != 1:
+            if exists(prev_layer_updates):
+                prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+
+            updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
 
-        updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
 
         retrieved = self.retrieve_memories(
             seq,
{titans_pytorch-0.2.4 → titans_pytorch-0.2.6}/train_mac.py

@@ -30,17 +30,18 @@ SEQ_LEN = 512
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
-NEURAL_MEM_LAYERS = (2, 4, 6)
+NEURAL_MEM_LAYERS = (2, 4, 6) # layers 2, 4, 6 have neural memory, can add more
 NEURAL_MEM_GATE_ATTN_OUTPUT = False
 NEURAL_MEM_MOMENTUM = True
-NEURAL_MEM_QK_NORM =
+NEURAL_MEM_QK_NORM = True
 WINDOW_SIZE = 32
-NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2
+NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2 # set smaller for more granularity for learning rate / momentum etc
 SLIDING_WINDOWS = True
-WEIGHT_TIE_MEMORY_MODEL =
-
+WEIGHT_TIE_MEMORY_MODEL = False # set to have memory MLP shared across layers
+PREV_MEM_UPDATE_FOR_WEIGHTS = True,
+STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
 MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
-KV_RECON_LOSS_WEIGHT =
+KV_RECON_LOSS_WEIGHT = 1.
 
 # experiment related
 
@@ -90,6 +91,7 @@ model = MemoryAsContextTransformer(
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
     weight_tie_memory_model = WEIGHT_TIE_MEMORY_MODEL,
+    prev_neural_mem_update_for_weights = PREV_MEM_UPDATE_FOR_WEIGHTS,
     neural_memory_model = MemoryMLP(
         dim = 64,
         depth = NEURAL_MEMORY_DEPTH
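For reference, here is roughly how the constants above plug into the constructor, following the arguments visible elsewhere in this diff. The keyword arguments not shown anywhere in the diff (`num_tokens`, `dim`, `depth`, `segment_len`, `num_persist_mem_tokens`, `num_longterm_mem_tokens`) and the `return_loss` flag are assumptions about the library's API, so treat this as a sketch and check it against the installed version.

```python
import torch
from titans_pytorch import MemoryAsContextTransformer, MemoryMLP

model = MemoryAsContextTransformer(
    num_tokens = 256,                              # assumed
    dim = 384,                                     # assumed
    depth = 8,                                     # assumed
    segment_len = 32,                              # assumed, WINDOW_SIZE
    num_persist_mem_tokens = 4,                    # assumed, NUM_PERSIST_MEM
    num_longterm_mem_tokens = 4,                   # assumed, NUM_LONGTERM_MEM
    neural_memory_layers = (2, 4, 6),              # NEURAL_MEM_LAYERS
    aux_kv_recon_loss_weight = 1.,                 # KV_RECON_LOSS_WEIGHT
    sliding_window_attn = True,                    # SLIDING_WINDOWS
    weight_tie_memory_model = False,               # WEIGHT_TIE_MEMORY_MODEL
    prev_neural_mem_update_for_weights = True,     # PREV_MEM_UPDATE_FOR_WEIGHTS
    neural_memory_model = MemoryMLP(dim = 64, depth = 2)
)

ids = torch.randint(0, 256, (1, 512))
loss = model(ids, return_loss = True)              # return_loss assumed
loss.backward()
```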