titans-pytorch 0.0.39.tar.gz → 0.0.41.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/PKG-INFO +1 -1
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/pyproject.toml +1 -1
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/titans_pytorch/mac_transformer.py +36 -4
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/train_mac.py +1 -1
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/.gitignore +0 -0
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/LICENSE +0 -0
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/README.md +0 -0
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/data/README.md +0 -0
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/fig1.png +0 -0
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/fig2.png +0 -0
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/requirements.txt +0 -0
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.39 → titans_pytorch-0.0.41}/train.py +0 -0
titans_pytorch/mac_transformer.py

```diff
@@ -94,6 +94,7 @@ class SegmentedAttention(Module):
         num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
+        accept_value_residual = False,
         attend_kwargs: dict = dict()
     ):
         super().__init__()
@@ -108,6 +109,12 @@ class SegmentedAttention(Module):
         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
         self.to_out = LinearNoBias(dim_inner, dim)
 
+        self.to_learned_v_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if accept_value_residual else None
+
         self.segment_len = segment_len
         self.num_longterm_mem_tokens = num_longterm_mem_tokens
 
@@ -118,7 +125,13 @@ class SegmentedAttention(Module):
 
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
-    def forward(
+    def forward(
+        self,
+        seq,
+        value_residual = None
+    ):
+        assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
+
         segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
         total_segment_len = segment_len + num_longterm_mem_tokens
 
@@ -136,6 +149,14 @@ class SegmentedAttention(Module):
         q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
         q, k, v = map(self.split_heads, (q, k, v))
 
+        # value residual
+
+        orig_v = v
+
+        if exists(self.to_learned_v_mix):
+            mix = self.to_learned_v_mix(seq)
+            v = v.lerp(value_residual, mix)
+
         # take care of persistent memory key / values
 
         pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
```
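The hunks above give `SegmentedAttention` an optional value-residual path: when `accept_value_residual = True`, a per-token, per-head sigmoid gate computed from the input sequence interpolates the layer's own values toward the `value_residual` it receives, while `orig_v` keeps the unmixed values for the caller. Below is a minimal, self-contained sketch of just that gating step; the shapes are arbitrary example values and the tensors are random stand-ins:

```python
# Minimal sketch of the learned value-residual mix added above.
# Module names mirror the diff; shapes and inputs are illustrative only.
import torch
from torch import nn
from einops.layers.torch import Rearrange

batch, seq_len, dim, heads, dim_head = 2, 16, 64, 8, 32

to_learned_v_mix = nn.Sequential(
    nn.Linear(dim, heads),          # one gate logit per head, per token
    Rearrange('b n h -> b h n 1'),  # line gates up with the (b, h, n, d) value tensor
    nn.Sigmoid()                    # gate in [0, 1]
)

seq = torch.randn(batch, seq_len, dim)
v = torch.randn(batch, heads, seq_len, dim_head)               # this layer's values
value_residual = torch.randn(batch, heads, seq_len, dim_head)  # residual values from an earlier layer

mix = to_learned_v_mix(seq)      # (b, h, n, 1)
v = v.lerp(value_residual, mix)  # v * (1 - mix) + value_residual * mix
```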
```diff
@@ -159,7 +180,7 @@ class SegmentedAttention(Module):
 
         out = inverse_segment(out)
 
-        return out
+        return out, orig_v
 
 # MAC transformer
 
@@ -171,6 +192,7 @@ class MemoryAsContextTransformer(Module):
         dim,
         depth,
         segment_len,
+        neural_memory_segment_len = None,
         num_longterm_mem_tokens = 0,
         num_persist_mem_tokens = 0,
         dim_head = 64,
@@ -200,7 +222,9 @@ class MemoryAsContextTransformer(Module):
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
+
         self.neural_mem_layers = ModuleList([])
+        neural_memory_segment_len = default(neural_memory_segment_len, num_longterm_mem_tokens + segment_len)
 
         layers = tuple(range(1, depth + 1))
 
```
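The new `neural_memory_segment_len` argument controls the chunk size later handed to `NeuralMemory` (see the `chunk_size = neural_memory_segment_len` hunk below); when the caller omits it, it falls back to the full attention window, `segment_len + num_longterm_mem_tokens`. A small sketch of that resolution, assuming the codebase's usual `default(value, fallback)` semantics and arbitrary example numbers:

```python
# Sketch of how `neural_memory_segment_len` resolves, assuming the helper
# `default(value, fallback)` keeps `value` unless it is None (the convention
# used in the diff). Numbers are arbitrary example values.

def default(value, fallback):
    return value if value is not None else fallback

segment_len = 128
num_longterm_mem_tokens = 4

# kwarg omitted -> NeuralMemory chunks over the full window (segment + long-term memory tokens)
chunk_size = default(None, num_longterm_mem_tokens + segment_len)
assert chunk_size == 132

# kwarg given -> passed through unchanged as NeuralMemory's `chunk_size`
chunk_size = default(64, num_longterm_mem_tokens + segment_len)
assert chunk_size == 64
```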
```diff
@@ -210,6 +234,7 @@ class MemoryAsContextTransformer(Module):
         assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
 
         for layer in layers:
+            is_first = layer == 1
 
             # neural memory
 
@@ -220,7 +245,7 @@ class MemoryAsContextTransformer(Module):
 
                 mem = NeuralMemory(
                     dim = dim,
-                    chunk_size =
+                    chunk_size = neural_memory_segment_len,
                     **neural_memory_kwargs
                 )
 
@@ -235,6 +260,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                accept_value_residual = not is_first,
                 num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
             )
@@ -285,6 +311,10 @@ class MemoryAsContextTransformer(Module):
         pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
         x = x + pos_emb[:x.shape[-2]]
 
+        # value residual
+
+        value_residual = None
+
         # expand and reduce streams for hyper connections
 
         x = self.expand_streams(x)
@@ -295,7 +325,9 @@ class MemoryAsContextTransformer(Module):
                 x = maybe_neural_mem(x)
 
 
-            x = attn(x)
+            x, values = attn(x, value_residual = value_residual)
+
+            value_residual = default(value_residual, values)
 
             x = ff(x)
 
```
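Taken together, the loop changes above thread a single value residual through the network: `value_residual` starts as `None`, the first attention block (built with `accept_value_residual = False`) returns its unmixed values, and `default(value_residual, values)` fixes those first-layer values as the residual fed to every deeper block. A toy sketch of that control flow, with a stand-in for `SegmentedAttention`:

```python
# Toy sketch of the value-residual threading in the layer loop above.
# `toy_attn` stands in for SegmentedAttention; only the control flow matters.

def default(value, fallback):
    # same convention as the codebase helper: keep `value` unless it is None
    return value if value is not None else fallback

def toy_attn(x, value_residual = None):
    values = [x]   # stand-in for this block's (unmixed) value projections
    out = x + 1    # stand-in for the attention output
    return out, values

x = 0
value_residual = None

for depth in range(4):
    x, values = toy_attn(x, value_residual = value_residual)

    # set once from the first block, then left untouched by `default`
    value_residual = default(value_residual, values)

print(value_residual)  # [0] -> always the first block's values
```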
train_mac.py

```diff
@@ -24,7 +24,7 @@ SHOULD_GENERATE = False
 SEQ_LEN = 512
 
 PROJECT_NAME = 'titans-mac-transformer'
-WANDB_ONLINE =
+WANDB_ONLINE = False # turn this on to pipe experiment to cloud
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
```