titans-pytorch 0.0.42__py3-none-any.whl → 0.0.44__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/__init__.py +3 -1
- titans_pytorch/mac_transformer.py +11 -9
- titans_pytorch/titans.py +2 -9
- {titans_pytorch-0.0.42.dist-info → titans_pytorch-0.0.44.dist-info}/METADATA +21 -1
- titans_pytorch-0.0.44.dist-info/RECORD +8 -0
- titans_pytorch-0.0.42.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.42.dist-info → titans_pytorch-0.0.44.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.42.dist-info → titans_pytorch-0.0.44.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py
CHANGED

@@ -224,7 +224,7 @@ class MemoryAsContextTransformer(Module):
         self.layers = ModuleList([])
 
         self.neural_mem_layers = ModuleList([])
-        neural_memory_segment_len = default(neural_memory_segment_len, num_longterm_mem_tokens + segment_len)
+        self.neural_memory_segment_len = default(neural_memory_segment_len, num_longterm_mem_tokens + segment_len)
 
         layers = tuple(range(1, depth + 1))
 
@@ -245,7 +245,7 @@ class MemoryAsContextTransformer(Module):
 
             mem = NeuralMemory(
                 dim = dim,
-                chunk_size = neural_memory_segment_len,
+                chunk_size = self.neural_memory_segment_len,
                 **neural_memory_kwargs
             )
 
@@ -287,10 +287,7 @@ class MemoryAsContextTransformer(Module):
 
         # math
 
-        batch, seq_len, segment_len, num_longterm_mem_tokens= *x.shape, self.segment_len, self.num_longterm_mem_tokens
-
-        windows = ceil(seq_len / segment_len)
-        total_segment_len = segment_len + num_longterm_mem_tokens
+        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens
 
         # token embedding
 
@@ -305,11 +302,16 @@ class MemoryAsContextTransformer(Module):
 
         x = inverse_segment(x)
 
+        seq_len_with_mem = x.shape[-2]
+
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
-        pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
-        x = x + pos_emb[:x.shape[-2]]
+        neural_mem_windows = ceil(seq_len_with_mem / neural_mem_segment_len)
+
+        pos_emb = self.axial_pos_emb((neural_mem_windows, neural_mem_segment_len), flatten = True)
+
+        x = x + pos_emb[:seq_len_with_mem]
 
         # value residual
 
@@ -334,7 +336,7 @@ class MemoryAsContextTransformer(Module):
 
         # excise out the memories
 
-        x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
+        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len + num_longterm_mem_tokens)
 
         x, _ = unpack(x, mem_ps, 'b * d')
 
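Net effect of the mac_transformer.py hunks: the neural memory segment length is kept on the module (`self.neural_memory_segment_len`) and the axial positional embedding is now sized from the sequence length *after* long-term memory tokens are interspersed, rather than from the raw token count. A minimal sketch of the new sizing arithmetic, with hypothetical sizes (not the library's code):

```python
from math import ceil

# hypothetical sizes, mirroring the README example quoted further down
segment_len = 128
num_longterm_mem_tokens = 16
neural_memory_segment_len = num_longterm_mem_tokens + segment_len  # the default set in __init__

# illustrative memory-augmented length: raw tokens plus one block of
# memory tokens per local attention window
seq_len = 1023
seq_len_with_mem = seq_len + ceil(seq_len / segment_len) * num_longterm_mem_tokens

# 0.0.44 derives the axial embedding's window count from the augmented length,
# so the flattened embedding always covers x and pos_emb[:seq_len_with_mem] is valid
neural_mem_windows = ceil(seq_len_with_mem / neural_memory_segment_len)
assert neural_mem_windows * neural_memory_segment_len >= seq_len_with_mem
```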
titans_pytorch/titans.py
CHANGED

@@ -425,11 +425,7 @@ class NeuralMemory(Module):
         next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)
 
         padding = next_seq_len - curtailed_seq_len
-
-        needs_pad = padding > 0
-
-        if needs_pad:
-            seq = pad_at_dim(seq, (0, padding), dim = 1)
+        seq = pad_at_dim(seq, (0, padding), dim = 1)
 
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper

@@ -481,10 +477,7 @@ class NeuralMemory(Module):
         empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
         values = torch.cat((empty_memory_embeds, values), dim = -2)
 
-
-        values = values[:, :-padding]
-
-        return values
+        return values[:, :seq_len]
 
     def forward(
         self,
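The titans.py change drops a now-unneeded guard and fixes the unpadding on retrieval: padding a tensor by zero is a no-op, but slicing with `values[:, :-padding]` returns an empty tensor whenever `padding == 0` (in Python, `:-0` collapses to `:0`), whereas `values[:, :seq_len]` is correct whether or not padding was added. A minimal standalone demonstration of the slicing pitfall (not the library's code):

```python
import torch

values = torch.randn(2, 8, 4)  # (batch, seq, dim), hypothetical shapes
seq_len, padding = 8, 0        # sequence already a multiple of chunk_size, so no padding

print(values[:, :-padding].shape)  # torch.Size([2, 0, 4]) -- empty: :-0 is the same as :0
print(values[:, :seq_len].shape)   # torch.Size([2, 8, 4]) -- keeps the whole sequence
```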
{titans_pytorch-0.0.42.dist-info → titans_pytorch-0.0.44.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.42
+Version: 0.0.44
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

@@ -83,6 +83,26 @@ retrieved = mem(seq)
 assert seq.shape == retrieved.shape
 ```
 
+A transformer with the `MAC` configuration can be used as
+
+```python
+import torch
+from titans_pytorch import MemoryAsContextTransformer
+
+transformer = MemoryAsContextTransformer(
+    num_tokens = 256,
+    dim = 256,
+    depth = 2,
+    segment_len = 128, # local attention window size
+    num_persist_mem_tokens = 4,
+    num_longterm_mem_tokens = 16,
+)
+
+token_ids = torch.randint(0, 256, (1, 1023))
+
+logits = transformer(token_ids) # (1, 1023, 256)
+```
+
 ## Experiments
 
 ```bash
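Taken together, the two code changes make an explicit `neural_memory_segment_len` safe to pass, since it is now stored on the module and reused for both the `NeuralMemory` chunk size and the positional embedding. A sketch of such an override, assuming the constructor keyword visible in the first hunk above (not verified against the release):

```python
import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
    neural_memory_segment_len = 32,  # explicit override, assumed from the __init__ hunk above
)

token_ids = torch.randint(0, 256, (1, 1023))
logits = transformer(token_ids)  # expected (1, 1023, 256), as in the README example
```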
titans_pytorch-0.0.44.dist-info/RECORD
ADDED

@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=mF8PYAjeAjLas1gkYybgzZX1AVK82A_ps_LY00ofYYs,9565
+titans_pytorch/titans.py,sha256=qxQ8pZCz8GEDhKeJMEaeAEzH66GAGVBNaRdNam_-czg,15260
+titans_pytorch-0.0.44.dist-info/METADATA,sha256=QzgJ6YqBfRMvCMUnMPCtiil0Q488hL9BAfHhMEmI2pA,4210
+titans_pytorch-0.0.44.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.44.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.44.dist-info/RECORD,,
titans_pytorch-0.0.42.dist-info/RECORD
REMOVED

@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=F6pV8BamKCsbJFVo5x2hw69vzfJNLy54SwIKIueMdp4,142
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=kSdfWGWwEk6d0lbb0WLVKQwdmG8LAzDg36QZm7aIio0,9451
-titans_pytorch/titans.py,sha256=eA7D9aqfGbtmC2SgGAQnfEVYp5Uza9uebEyDpVpjNQc,15372
-titans_pytorch-0.0.42.dist-info/METADATA,sha256=4lZBFMZPuQRDQGdTK-TWheytxEaQZQv7bdXO7MLBrwI,3744
-titans_pytorch-0.0.42.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.42.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.42.dist-info/RECORD,,
{titans_pytorch-0.0.42.dist-info → titans_pytorch-0.0.44.dist-info}/WHEEL
File without changes

{titans_pytorch-0.0.42.dist-info → titans_pytorch-0.0.44.dist-info}/licenses/LICENSE
File without changes