titans-pytorch 0.0.4__tar.gz → 0.0.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.4 → titans_pytorch-0.0.5}/PKG-INFO +5 -2
- {titans_pytorch-0.0.4 → titans_pytorch-0.0.5}/README.md +4 -1
- {titans_pytorch-0.0.4 → titans_pytorch-0.0.5}/pyproject.toml +1 -1
- {titans_pytorch-0.0.4 → titans_pytorch-0.0.5}/titans_pytorch/titans.py +26 -4
- {titans_pytorch-0.0.4 → titans_pytorch-0.0.5}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.4 → titans_pytorch-0.0.5}/.gitignore +0 -0
- {titans_pytorch-0.0.4 → titans_pytorch-0.0.5}/LICENSE +0 -0
- {titans_pytorch-0.0.4 → titans_pytorch-0.0.5}/fig1.png +0 -0
- {titans_pytorch-0.0.4 → titans_pytorch-0.0.5}/fig2.png +0 -0
- {titans_pytorch-0.0.4 → titans_pytorch-0.0.5}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.4 → titans_pytorch-0.0.5}/titans_pytorch/associative_scan.py +0 -0
{titans_pytorch-0.0.4 → titans_pytorch-0.0.5}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.4
+Version: 0.0.5
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -65,7 +65,10 @@ from titans_pytorch import NeuralMemory
 
 x = torch.randn(2, 64, 32)
 
-mem = NeuralMemory(dim = 32)
+mem = NeuralMemory(
+    dim = 32,
+    chunk_size = 2
+)
 
 out = mem(x)
 
```
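The README usage now constructs `NeuralMemory` with an explicit `chunk_size`. A minimal sketch of the new call; the shape comments are inferred from the retrieval changes later in this diff rather than stated by the README:

```python
import torch
from titans_pytorch import NeuralMemory

x = torch.randn(2, 64, 32)  # (batch, seq_len, dim); seq_len must be >= chunk_size per the new assert

mem = NeuralMemory(
    dim = 32,
    chunk_size = 2
)

out = mem(x)
print(out.shape)  # expected (2, 64, 32): retrieval pads, then trims back to the input length
```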
{titans_pytorch-0.0.4 → titans_pytorch-0.0.5}/titans_pytorch/titans.py

```diff
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import math
 from functools import partial
 
 import torch
@@ -11,7 +12,8 @@ from tensordict import TensorDict
 
 from titans_pytorch.associative_scan import (
     associative_scan,
-    binary_operator
+    binary_operator,
+    pad_at_dim
 )
 
 import einx
```
titans_pytorch/titans.py (continued)

```diff
@@ -41,6 +43,9 @@ def default(v, d):
 def round_down_multiple(seq, mult):
     return seq // mult * mult
 
+def round_up_multiple(seq, mult):
+    return math.ceil(seq / mult) * mult
+
 def pack_one_with_inverse(t, pattern):
     packed, packed_shape = pack([t], pattern)
 
```
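The new `round_up_multiple` is the ceiling counterpart of the existing `round_down_multiple`. A quick illustration of the pair, using the chunk size from the README example (the concrete numbers are illustrative only):

```python
import math

def round_down_multiple(seq, mult):
    return seq // mult * mult

def round_up_multiple(seq, mult):
    return math.ceil(seq / mult) * mult

# aligning a sequence length of 63 to chunks of 2
assert round_down_multiple(63, 2) == 62
assert round_up_multiple(63, 2) == 64
assert round_up_multiple(64, 2) == 64  # already-aligned lengths are unchanged
```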
titans_pytorch/titans.py (continued)

```diff
@@ -236,7 +241,19 @@ class NeuralMemory(Module):
         seq,
         past_weights: dict[str, Tensor] | None = None,
     ):
-
+        chunk_size = self.chunk_size
+        batch, seq_len = seq.shape[:2]
+
+        assert seq_len >= chunk_size
+
+        seq = seq[:, (chunk_size - 1):]
+        curtailed_seq_len = seq.shape[-2]
+
+        next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)
+
+        padding = next_seq_len - curtailed_seq_len
+
+        seq = pad_at_dim(seq, (0, padding), dim = 1)
 
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
```
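Traced with illustrative numbers (not taken from the diff), the new retrieval preamble's bookkeeping works out as follows; the causality comment is my reading of the shift, not something the diff states:

```python
import math

chunk_size = 2
seq_len = 64  # illustrative; any length >= chunk_size passes the new assert

# drop the first (chunk_size - 1) positions, presumably so each remaining
# position retrieves against weights written by earlier chunks
curtailed_seq_len = seq_len - (chunk_size - 1)                         # 63

# right-pad up to the next chunk boundary so the sequence splits evenly
next_seq_len = math.ceil(curtailed_seq_len / chunk_size) * chunk_size  # 64
padding = next_seq_len - curtailed_seq_len                             # 1

assert next_seq_len % chunk_size == 0
```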
titans_pytorch/titans.py (continued)

```diff
@@ -256,7 +273,7 @@ class NeuralMemory(Module):
         # fetch values from memory model
 
         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
-        queries = rearrange(queries, 'b n d -> (b n) d')
+        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
 
         # forward functional call
 
```
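The reworked `rearrange` keeps each chunk's positions as an explicit axis, so the queries line up with the per-chunk weights folded into the batch on the previous line. A standalone shape check with illustrative dimensions:

```python
import torch
from einops import rearrange

batch, num_chunks, chunk_size, dim = 2, 32, 2, 32
queries = torch.randn(batch, num_chunks * chunk_size, dim)  # (2, 64, 32)

# fold chunks into batch, matching curr_weights' '(b n) ...' layout,
# while keeping per-chunk positions as a separate axis
queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
assert queries.shape == (batch * num_chunks, chunk_size, dim)  # (64, 2, 32)
```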
titans_pytorch/titans.py (continued)

```diff
@@ -264,7 +281,12 @@ class NeuralMemory(Module):
 
         # reconstitute batch dimension
 
-        values = rearrange(values, '(b n) d -> b n d', b = batch)
+        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+
+        # restore
+
+        values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
+        values = values[:, :-padding]
 
         return values
 
```
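The retrieval epilogue inverts that bookkeeping. A sketch with the same illustrative shapes; `torch.nn.functional.pad` stands in here for the `pad_at_dim` helper imported earlier in the diff:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

batch, num_chunks, chunk_size, dim = 2, 32, 2, 32
padding = 1  # from the round-up step at the start of retrieval

values = torch.randn(batch * num_chunks, chunk_size, dim)

# undo the chunk folding: (64, 2, 32) -> (2, 64, 32)
values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)

# left-pad the (chunk_size - 1) positions dropped earlier, then trim the right padding
values = F.pad(values, (0, 0, chunk_size - 1, 0), value = 0.)  # (2, 65, 32)
values = values[:, :-padding]                                  # (2, 64, 32)

assert values.shape == (batch, num_chunks * chunk_size, dim)
```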