titans-pytorch 0.0.29__tar.gz → 0.0.31__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/PKG-INFO +1 -1
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/pyproject.toml +1 -1
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/tests/test_titans.py +8 -2
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/titans_pytorch/mac_transformer.py +76 -38
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/.gitignore +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/LICENSE +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/README.md +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/data/README.md +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/fig1.png +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/fig2.png +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/requirements.txt +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/train.py +0 -0
tests/test_titans.py

@@ -34,14 +34,20 @@ def test_titans_attn_memory():
 
     assert seq.shape == retrieved.shape
 
-
+@pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
+@pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
+def test_mac(
+    num_persist_mem_tokens,
+    num_longterm_mem_tokens
+):
     from titans_pytorch.mac_transformer import MemoryAsContextTransformer
 
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
         dim = 256,
         depth = 2,
-        num_persist_mem_tokens =
+        num_persist_mem_tokens = num_persist_mem_tokens,
+        num_longterm_mem_tokens = num_longterm_mem_tokens,
         segment_len = 128,
     )
 
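For readers trying the new options, here is a minimal usage sketch that mirrors the parametrized test above. The constructor keyword arguments are taken from the diff; the sample sequence length and the expected logits shape are assumptions for illustration, not part of the package diff.

```python
import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

# mirrors the constructor arguments exercised by test_mac above
transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    num_persist_mem_tokens = 16,
    num_longterm_mem_tokens = 16,
    segment_len = 128,
)

ids = torch.randint(0, 256, (1, 1023))   # assumed: a batch of token ids of arbitrary length
logits = transformer(ids)

print(logits.shape)   # expected: torch.Size([1, 1023, 256]) - one logit row per input token
```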
titans_pytorch/mac_transformer.py

@@ -7,10 +7,9 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
-from einops import repeat
+from einops import repeat, rearrange
 from einops.layers.torch import Rearrange
 
-
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # absolute and relative positions
@@ -30,9 +29,33 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def identity(t):
+    return t
+
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult
 
+def pad_and_segment_with_inverse(seq, segment_len):
+    batch, seq_len = seq.shape[:2]
+
+    need_segment = seq_len >= segment_len
+
+    if not need_segment:
+        return seq, identity
+
+    next_seq_len_mult = round_up_multiple(seq_len, segment_len)
+
+    padding = next_seq_len_mult - seq_len
+    seq = F.pad(seq, (0, 0, 0, padding))
+
+    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
+
+    def inverse(out):
+        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+        return out[:, :-padding]
+
+    return seq, inverse
+
 # feedforward and attention
 
 class GEGLU(Module):
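The new `pad_and_segment_with_inverse` helper is the centerpiece of this release: it right-pads a `(batch, seq, dim)` tensor to a multiple of `segment_len`, folds the windows into the batch dimension, and returns an `inverse` closure that undoes both steps. A shape-only sketch follows; the non-multiple sequence length is deliberate, since with an exact multiple `padding` is zero and the `[:, :-padding]` slice inside `inverse` would return an empty tensor.

```python
import torch
from titans_pytorch.mac_transformer import pad_and_segment_with_inverse

x = torch.randn(2, 300, 64)   # 300 is deliberately not a multiple of 128

segments, inverse = pad_and_segment_with_inverse(x, 128)
print(segments.shape)    # torch.Size([6, 128, 64]) - padded to 384, 3 windows per batch element

restored = inverse(segments)
print(restored.shape)    # torch.Size([2, 300, 64]) - padding trimmed back off
```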
@@ -55,7 +78,8 @@ class SegmentedAttention(Module):
         self,
         dim,
         segment_len,
-        num_persist_mem_tokens,
+        num_persist_mem_tokens = 0,
+        num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
     ):
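Both memory-token counts now default to 0, so the attention block can be constructed with just `dim` and `segment_len`. A hedged sketch of standalone use; the output shape is an expectation inferred from the pad-and-restore logic in `forward`, not asserted by the diff.

```python
import torch
from titans_pytorch.mac_transformer import SegmentedAttention

# num_persist_mem_tokens and num_longterm_mem_tokens both default to 0 now
attn = SegmentedAttention(dim = 256, segment_len = 128)

seq = torch.randn(1, 300, 256)   # assumed: an already-embedded sequence
out = attn(seq)                  # padded and windowed internally, then restored

print(out.shape)   # expected: torch.Size([1, 300, 256])
```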
@@ -70,31 +94,25 @@ class SegmentedAttention(Module):
         self.to_out = LinearNoBias(dim_inner, dim)
 
         self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        total_segment_len = segment_len + num_longterm_mem_tokens
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
-        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
-
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
+        segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
         batch, seq_len = seq.shape[:2]
 
         # auto pad to multiple
         # todo - get rid of logic with flex attention
 
-
-
-        if need_segment:
-            next_seq_len = round_up_multiple(seq_len, self.segment_len)
-            padding = next_seq_len - seq_len
-
-            if padding > 0:
-                seq = F.pad(seq, (0, 0, 0, padding))
-
-            seq = self.segment_seq(seq)
+        seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len)
 
         # attention
 
@@ -124,10 +142,9 @@ class SegmentedAttention(Module):
 
         out = self.to_out(out)
 
-
-        out = self.merge_seq_back(out)
+        out = inverse_segment(out)
 
-        return out
+        return out
 
 # MAC transformer
 
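The forward path now delegates windowing to `pad_and_segment_with_inverse` and undoes it with the returned closure, instead of the removed `segment_seq` / `merge_seq_back` rearranges. Below is a standalone sketch of that fold-windows-into-the-batch pattern; it is illustrative only (no projections, persistent memory or output head), and `windowed_attention` is a name chosen here, not the library's.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def windowed_attention(x, window_len, heads = 8):
    b, n, d = x.shape
    assert n % window_len == 0, 'pad first, e.g. with pad_and_segment_with_inverse'

    # fold windows into the batch so attention never crosses a window boundary
    x = rearrange(x, 'b (w n) d -> (b w) n d', n = window_len)

    q = k = v = rearrange(x, 'bw n (h dh) -> bw h n dh', h = heads)
    out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

    out = rearrange(out, 'bw h n dh -> bw n (h dh)')
    return rearrange(out, '(b w) n d -> b (w n) d', b = b)

x = torch.randn(1, 256, 64)                 # two windows of 128 tokens
print(windowed_attention(x, 128).shape)     # torch.Size([1, 256, 64])
```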
@@ -139,7 +156,8 @@ class MemoryAsContextTransformer(Module):
         dim,
         depth,
         segment_len,
-
+        num_longterm_mem_tokens = 0,
+        num_persist_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -147,10 +165,18 @@ class MemoryAsContextTransformer(Module):
     ):
         super().__init__()
 
-        self.
+        self.token_emb = nn.Embedding(num_tokens, dim)
+
         self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
 
-
+        # long term mem tokens
+
+        self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
+
+        # hyper conection
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
@@ -162,6 +188,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
             )
 
@@ -177,16 +204,32 @@ class MemoryAsContextTransformer(Module):
         self.to_logits = LinearNoBias(dim, num_tokens)
 
     def forward(self, x):
-
+
+        # math
+
+        batch, seq_len, segment_len, num_longterm_mem_tokens = *x.shape, self.segment_len, self.num_longterm_mem_tokens
+
         windows = ceil(seq_len / segment_len)
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
+        # token embedding
 
         x = self.token_emb(x)
 
+        # intersperse longterm memory
+
+        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
+
+        mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
+        x = torch.cat((mems, x), dim = -2)
+
+        x = inverse_segment(x)
+
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
-        pos_emb = self.axial_pos_emb((windows,
-        x = x + pos_emb[:
+        pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
+        x = x + pos_emb[:x.shape[-2]]
 
         # expand and reduce streams for hyper connections
 
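The interspersing step above prepends the same learned `longterm_mems` block to every window before the transformer layers, so each window of `segment_len` tokens is preceded by `num_longterm_mem_tokens` memory slots. Here is a self-contained sketch of just that step, assuming a sequence length that is already an exact multiple of the window size so no padding logic is needed; the tensors stand in for the module's learned parameters.

```python
import torch
from einops import rearrange, repeat

batch, seq_len, dim = 2, 256, 64
segment_len, num_longterm_mem_tokens = 128, 16

x = torch.randn(batch, seq_len, dim)
longterm_mems = torch.randn(num_longterm_mem_tokens, dim) * 0.02   # stands in for the learned parameter

windows = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
mems = repeat(longterm_mems, 'n d -> bw n d', bw = windows.shape[0])
windows = torch.cat((mems, windows), dim = -2)   # every window now starts with the memory tokens

x_with_mems = rearrange(windows, '(b w) n d -> b (w n) d', b = batch)
print(x_with_mems.shape)   # torch.Size([2, 288, 64]) -> 2 windows * (16 + 128) tokens each
```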
@@ -198,21 +241,16 @@ class MemoryAsContextTransformer(Module):
 
         x = self.reduce_streams(x)
 
-
+        # excise out the memories
 
-
+        x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
-
+        x = x[:, self.num_longterm_mem_tokens:]
 
-
-
-
-        dim = 256,
-        depth = 2,
-        num_persist_mem_tokens = 16,
-        segment_len = 128,
-    )
+        x = inverse_segment(x)
+
+        # to logits
 
-
+        x = self.norm(x)
 
-
+        return self.to_logits(x)
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|