titans-pytorch 0.0.38__tar.gz → 0.0.40__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/PKG-INFO +2 -1
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/pyproject.toml +2 -1
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/titans_pytorch/mac_transformer.py +39 -6
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/train_mac.py +1 -1
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/.gitignore +0 -0
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/LICENSE +0 -0
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/README.md +0 -0
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/data/README.md +0 -0
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/fig1.png +0 -0
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/fig2.png +0 -0
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/requirements.txt +0 -0
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/train.py +0 -0
{titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.38
+Version: 0.0.40
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -43,6 +43,7 @@ Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
+Requires-Dist: x-transformers
 Provides-Extra: examples
 Requires-Dist: local-attention>=1.10.1; extra == 'examples'
 Requires-Dist: taylor-series-linear-attention; extra == 'examples'
```
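The only metadata changes are the version bump and a new hard dependency on `x-transformers`. A minimal sketch, using only the standard library, of how the updated metadata could be verified once the new release is installed (assumes `titans-pytorch==0.0.40` is present in the current environment, e.g. via `pip install titans-pytorch==0.0.40`):

```python
# Check the installed distribution metadata shown in the PKG-INFO diff above.
from importlib.metadata import version, requires

assert version('titans-pytorch') == '0.0.40'

# 'x-transformers' should now appear alongside torch>=2.2, tensordict, etc.
deps = requires('titans-pytorch') or []
assert any(dep.startswith('x-transformers') for dep in deps), deps
print(deps)
```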
{titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/pyproject.toml

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.38"
+version = "0.0.40"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -34,6 +34,7 @@ dependencies = [
     "rotary-embedding-torch",
     "tensordict",
     "torch>=2.2",
+    "x-transformers"
 ]
 
 [project.urls]
```
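The dependency added here mirrors the PKG-INFO change and exists to supply the `Attend` module that `mac_transformer.py` switches to in the diff below. As a rough, standalone illustration of what that import provides (a sketch only; keyword arguments accepted by `Attend` beyond `causal` are not shown in this diff, and the tensor sizes below are arbitrary):

```python
# Sketch: standalone use of x_transformers.attend.Attend, the module pulled in
# by the new dependency. Shapes follow the (batch, heads, seq, dim_head) layout
# used by SegmentedAttention.
import torch
from x_transformers.attend import Attend

attend = Attend(causal = True)   # same construction as in the diff below

q = torch.randn(1, 8, 256, 64)
k = torch.randn(1, 8, 256, 64)
v = torch.randn(1, 8, 256, 64)

# returns (output, intermediates), matching the diff's `out, _ = self.attend(q, k, v)`
out, _ = attend(q, k, v)
print(out.shape)                 # torch.Size([1, 8, 256, 64])
```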
{titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/titans_pytorch/mac_transformer.py

```diff
@@ -7,7 +7,7 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
-from einops import repeat, rearrange, pack, unpack
+from einops import einsum, repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
 from hyper_connections import get_init_and_expand_reduce_stream_functions
@@ -16,6 +16,7 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
+from x_transformers.attend import Attend
 
 # proposed neural memory
 
@@ -93,6 +94,8 @@ class SegmentedAttention(Module):
         num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
+        accept_value_residual = False,
+        attend_kwargs: dict = dict()
     ):
         super().__init__()
         self.norm = nn.RMSNorm(dim)
@@ -101,9 +104,17 @@ class SegmentedAttention(Module):
 
         self.rotary_emb = RotaryEmbedding(dim_head)
 
+        self.attend = Attend(causal = True, **attend_kwargs)
+
         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
         self.to_out = LinearNoBias(dim_inner, dim)
 
+        self.to_learned_v_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if accept_value_residual else None
+
         self.segment_len = segment_len
         self.num_longterm_mem_tokens = num_longterm_mem_tokens
 
@@ -114,7 +125,13 @@ class SegmentedAttention(Module):
 
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
-    def forward(
+    def forward(
+        self,
+        seq,
+        value_residual = None
+    ):
+        assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
+
         segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
         total_segment_len = segment_len + num_longterm_mem_tokens
 
@@ -132,6 +149,14 @@ class SegmentedAttention(Module):
         q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
         q, k, v = map(self.split_heads, (q, k, v))
 
+        # value residual
+
+        orig_v = v
+
+        if exists(self.to_learned_v_mix):
+            mix = self.to_learned_v_mix(seq)
+            v = v.lerp(value_residual, mix)
+
         # take care of persistent memory key / values
 
         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
@@ -145,9 +170,9 @@ class SegmentedAttention(Module):
         k = cat((pmk, k), dim = -2)
         v = cat((pmv, v), dim = -2)
 
-        #
+        # attention
 
-        out =
+        out, _ = self.attend(q, k, v)
 
         out = self.merge_heads(out)
 
@@ -155,7 +180,7 @@ class SegmentedAttention(Module):
 
         out = inverse_segment(out)
 
-        return out
+        return out, orig_v
 
 # MAC transformer
 
```
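The main functional change to `SegmentedAttention` above is the learned value-residual mix: the layer now keeps its original values (`orig_v`) to hand back to the caller, and, when a `value_residual` is passed in, interpolates its own values toward it with a per-head sigmoid gate computed from the layer input. A self-contained sketch of just that mixing step, with made-up sizes (`dim = 512`, `heads = 8`, etc.) for illustration:

```python
# Sketch of the value-residual mixing added to SegmentedAttention.forward,
# isolated from the rest of the module. All tensor sizes here are arbitrary.
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, heads, dim_head, seq_len, batch = 512, 8, 64, 32, 2

# same construction as the new `to_learned_v_mix` in the diff
to_learned_v_mix = nn.Sequential(
    nn.Linear(dim, heads),
    Rearrange('b n h -> b h n 1'),
    nn.Sigmoid()
)

seq = torch.randn(batch, seq_len, dim)               # token embeddings entering the layer
v = torch.randn(batch, heads, seq_len, dim_head)     # this layer's values
value_residual = torch.randn_like(v)                 # values carried over from the first layer

mix = to_learned_v_mix(seq)                          # (b, h, n, 1), gate in [0, 1] per head and position
v = v.lerp(value_residual, mix)                      # interpolate toward the first layer's values

print(mix.shape, v.shape)
```

Because the gate is computed from the layer input, each head can decide per position how much of the first layer's value stream to reuse.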
{titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/titans_pytorch/mac_transformer.py (continued: MemoryAsContextTransformer)

```diff
@@ -206,6 +231,7 @@ class MemoryAsContextTransformer(Module):
         assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
 
         for layer in layers:
+            is_first = layer == 1
 
             # neural memory
 
@@ -231,6 +257,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                accept_value_residual = not is_first,
                 num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
             )
@@ -281,6 +308,10 @@ class MemoryAsContextTransformer(Module):
         pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
         x = x + pos_emb[:x.shape[-2]]
 
+        # value residual
+
+        value_residual = None
+
         # expand and reduce streams for hyper connections
 
         x = self.expand_streams(x)
@@ -291,7 +322,9 @@ class MemoryAsContextTransformer(Module):
 
            x = maybe_neural_mem(x)
 
 
-           x = attn(x)
+           x, values = attn(x, value_residual = value_residual)
+
+           value_residual = default(value_residual, values)
 
           x = ff(x)
 
```
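At the transformer level the residual is threaded through the layer loop: `value_residual` starts as `None`, the first attention block (built with `accept_value_residual = False`, since `is_first` is true for layer 1) returns its raw values, and `default(value_residual, values)` then freezes that first result so every later block mixes against the same tensor. A toy sketch of that control flow, where the attention blocks are faked and the `exists`/`default` helpers are assumptions mirroring the usual lucidrains-style utilities (they are not shown in this diff):

```python
# Toy illustration of how value_residual is threaded through the layers.
# Only the threading logic matches the diff; the attention blocks are stubs.
import torch

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def fake_attn(x, value_residual = None):
    # stands in for SegmentedAttention: returns (output, this layer's original values)
    values = torch.randn(2, 8, 32, 64)
    return x, values

layers = [fake_attn] * 4

x = torch.randn(2, 32, 512)
value_residual = None

for attn in layers:
    x, values = attn(x, value_residual = value_residual)

    # after the first layer this is a no-op: the first layer's values are kept
    value_residual = default(value_residual, values)
```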
{titans_pytorch-0.0.38 → titans_pytorch-0.0.40}/train_mac.py

```diff
@@ -24,7 +24,7 @@ SHOULD_GENERATE = False
 SEQ_LEN = 512
 
 PROJECT_NAME = 'titans-mac-transformer'
-WANDB_ONLINE =
+WANDB_ONLINE = False # turn this on to pipe experiment to cloud
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
```
All remaining files listed above are unchanged between 0.0.38 and 0.0.40.