titans-pytorch 0.0.27__tar.gz → 0.0.29__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/PKG-INFO +3 -1
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/pyproject.toml +3 -1
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/tests/test_titans.py +16 -0
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/titans_pytorch/mac_transformer.py +29 -2
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/.gitignore +0 -0
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/LICENSE +0 -0
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/README.md +0 -0
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/data/README.md +0 -0
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/fig1.png +0 -0
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/fig2.png +0 -0
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/requirements.txt +0 -0
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/train.py +0 -0
{titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.27
+Version: 0.0.29
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,10 +35,12 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
+Requires-Dist: axial-positional-embedding>=0.3.5
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
+Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
 Provides-Extra: examples
{titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.27"
+version = "0.0.29"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -26,10 +26,12 @@ classifiers=[
 
 dependencies = [
     "accelerated-scan>=0.2.0",
+    "axial_positional_embedding>=0.3.5",
     "einx>=0.3.0",
     "einops>=0.8.0",
     "hyper-connections>=0.1.8",
     "Ninja",
+    "rotary-embedding-torch",
     "tensordict",
     "torch>=2.2",
 ]
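For reference, the two new requirement strings map onto underscore import names: the packages are imported as axial_positional_embedding and rotary_embedding_torch in the mac_transformer.py diff below. A minimal smoke check, assuming both packages are installed alongside this release; the dimension values here are only illustrative, not pinned by the diff:

from axial_positional_embedding import ContinuousAxialPositionalEmbedding
from rotary_embedding_torch import RotaryEmbedding

# constructor calls mirror those added in mac_transformer.py in this release;
# dim = 256 and dim_head = 64 are illustrative choices
axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = 256, num_axial_dims = 2)
rotary_emb = RotaryEmbedding(64)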
{titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/tests/test_titans.py

@@ -33,3 +33,19 @@ def test_titans_attn_memory():
     retrieved = mem(seq)
 
     assert seq.shape == retrieved.shape
+
+def test_mac():
+    from titans_pytorch.mac_transformer import MemoryAsContextTransformer
+
+    transformer = MemoryAsContextTransformer(
+        num_tokens = 256,
+        dim = 256,
+        depth = 2,
+        num_persist_mem_tokens = 16,
+        segment_len = 128,
+    )
+
+    x = torch.randint(0, 256, (1, 1023))
+
+    logits = transformer(x)
+    assert logits.shape == (1, 1023, 256)
{titans_pytorch-0.0.27 → titans_pytorch-0.0.29}/titans_pytorch/mac_transformer.py

@@ -1,5 +1,5 @@
 from __future__ import annotations
-import math
+from math import ceil
 from functools import partial
 
 import torch
@@ -10,8 +10,14 @@ from torch.nn import Module, ModuleList, Linear
 from einops import repeat
 from einops.layers.torch import Rearrange
 
+
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
+# absolute and relative positions
+
+from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+from rotary_embedding_torch import RotaryEmbedding
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
@@ -25,7 +31,7 @@ def default(v, d):
     return v if exists(v) else d
 
 def round_up_multiple(seq, mult):
-    return math.ceil(seq / mult) * mult
+    return ceil(seq / mult) * mult
 
 # feedforward and attention
 
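As a quick worked example of the ceil based rounding, using the lengths exercised by the new test (a 1023 token sequence with segment_len = 128):

from math import ceil

def round_up_multiple(seq, mult):
    # same form as the helper in the hunk above
    return ceil(seq / mult) * mult

# 1023 tokens round up to 8 segments of 128, i.e. a 1024 position grid
assert ceil(1023 / 128) == 8
assert round_up_multiple(1023, 128) == 1024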
@@ -58,6 +64,8 @@ class SegmentedAttention(Module):
 
         dim_inner = dim_head * heads
 
+        self.rotary_emb = RotaryEmbedding(dim_head)
+
         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
         self.to_out = LinearNoBias(dim_inner, dim)
 
@@ -99,6 +107,12 @@ class SegmentedAttention(Module):
 
         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
 
+        # relative positions
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # persistent memory
+
         k = cat((pmk, k), dim = -2)
         v = cat((pmv, v), dim = -2)
 
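A standalone sketch of the rotary step added above, assuming the usual (batch, heads, seq, dim_head) layout and an illustrative dim_head of 64 (neither is fixed by this hunk). Since the rotation happens before the persistent memory keys and values are concatenated, those tokens receive no rotary position:

import torch
from rotary_embedding_torch import RotaryEmbedding

rotary_emb = RotaryEmbedding(64)     # dim_head = 64, illustrative

q = torch.randn(2, 8, 128, 64)       # (batch, heads, seq, dim_head), assumed layout
k = torch.randn(2, 8, 128, 64)

# relative positions via rotary embeddings, as in the hunk above
q, k = rotary_emb.rotate_queries_with_cached_keys(q, k)

assert q.shape == (2, 8, 128, 64) and k.shape == (2, 8, 128, 64)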
@@ -133,6 +147,9 @@ class MemoryAsContextTransformer(Module):
     ):
         super().__init__()
 
+        self.segment_len = segment_len
+        self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
+
         self.token_emb = nn.Embedding(num_tokens, dim)
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
@@ -160,9 +177,19 @@ class MemoryAsContextTransformer(Module):
         self.to_logits = LinearNoBias(dim, num_tokens)
 
     def forward(self, x):
+        seq_len, segment_len = x.shape[-1], self.segment_len
+        windows = ceil(seq_len / segment_len)
 
         x = self.token_emb(x)
 
+        # apply axial positional embedding
+        # so intra and inter segment can be more easily discerned by the network
+
+        pos_emb = self.axial_pos_emb((windows, segment_len), flatten = True)
+        x = x + pos_emb[:seq_len]
+
+        # expand and reduce streams for hyper connections
+
         x = self.expand_streams(x)
 
         for attn, ff in self.layers:
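Taken together, the forward changes layer an absolute position signal over a (window, offset within window) grid on top of the rotary relative positions added in SegmentedAttention. A sketch of just the axial step, assuming (per the calls above) that flatten = True returns a (windows * segment_len, dim) tensor that can be trimmed to the actual sequence length:

import torch
from math import ceil
from axial_positional_embedding import ContinuousAxialPositionalEmbedding

dim, segment_len = 256, 128                 # values from the new test
axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)

x = torch.randn(1, 1023, dim)               # token embeddings, seq_len = 1023

seq_len = x.shape[-2]
windows = ceil(seq_len / segment_len)       # 8 windows

# one axis indexes the segment (window), the other the offset inside it,
# so intra and inter segment positions stay distinguishable
pos_emb = axial_pos_emb((windows, segment_len), flatten = True)
x = x + pos_emb[:seq_len]                   # drop the padded tail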