titans-pytorch 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- titans_pytorch/titans.py +28 -6
- {titans_pytorch-0.0.4.dist-info → titans_pytorch-0.0.6.dist-info}/METADATA +6 -2
- titans_pytorch-0.0.6.dist-info/RECORD +7 -0
- titans_pytorch-0.0.4.dist-info/RECORD +0 -7
- {titans_pytorch-0.0.4.dist-info → titans_pytorch-0.0.6.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.4.dist-info → titans_pytorch-0.0.6.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED

@@ -1,4 +1,5 @@
 from __future__ import annotations
+import math
 from functools import partial

 import torch
@@ -11,7 +12,8 @@ from tensordict import TensorDict

 from titans_pytorch.associative_scan import (
     associative_scan,
-    binary_operator
+    binary_operator,
+    pad_at_dim
 )

 import einx
@@ -41,6 +43,9 @@ def default(v, d):
 def round_down_multiple(seq, mult):
     return seq // mult * mult

+def round_up_multiple(seq, mult):
+    return math.ceil(seq / mult) * mult
+
 def pack_one_with_inverse(t, pattern):
     packed, packed_shape = pack([t], pattern)

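The new round_up_multiple mirrors the existing round_down_multiple: the store path truncates the sequence to whole chunks, while the new retrieval path pads it up to whole chunks. A small worked example of the arithmetic (helpers copied from the hunk above; the sample numbers are illustrative only):

import math

def round_down_multiple(seq, mult):
    # largest multiple of `mult` not exceeding seq
    return seq // mult * mult

def round_up_multiple(seq, mult):
    # smallest multiple of `mult` not below seq
    return math.ceil(seq / mult) * mult

# e.g. a 10-token sequence with chunk_size = 4
assert round_down_multiple(10, 4) == 8    # store_memories keeps only complete chunks
assert round_up_multiple(10, 4) == 12     # retrieve_memories pads up to complete chunks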
@@ -159,7 +164,7 @@ class NeuralMemory(Module):
         # curtail sequence by multiple of the chunk size
         # only a complete chunk of the sequence provides the memory for the next chunk

-        seq_len = seq.shape[-2]
+        seq_len, chunk_size = seq.shape[-2], self.chunk_size
         round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)

         seq = seq[:, :round_down_seq_len]
@@ -229,14 +234,26 @@ class NeuralMemory(Module):

         next_state = (curr_weights + last_update, next_momentum)

-        return updates, next_state, aux_store_loss.mean()
+        return updates, next_state, aux_store_loss.mean() / chunk_size

     def retrieve_memories(
         self,
         seq,
         past_weights: dict[str, Tensor] | None = None,
     ):
-
+        chunk_size = self.chunk_size
+        batch, seq_len = seq.shape[:2]
+
+        assert seq_len >= chunk_size
+
+        seq = seq[:, (chunk_size - 1):]
+        curtailed_seq_len = seq.shape[-2]
+
+        next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)
+
+        padding = next_seq_len - curtailed_seq_len
+
+        seq = pad_at_dim(seq, (0, padding), dim = 1)

         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
@@ -256,7 +273,7 @@ class NeuralMemory(Module):
         # fetch values from memory model

         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
-        queries = rearrange(queries, 'b n d -> (b n)
+        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)

         # forward functional call

@@ -264,7 +281,12 @@ class NeuralMemory(Module):

         # reconstitute batch dimension

-        values = rearrange(values, '(b n)
+        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+
+        # restore
+
+        values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
+        values = values[:, :-padding]

         return values

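Taken together, the retrieval changes offset the sequence by chunk_size - 1, right-pad it to a whole number of chunks, group queries chunk-wise for the functional call, and then left-pad and trim the retrieved values back to the original length. Below is a minimal, shape-only sketch of that bookkeeping, assuming an input of 2 sequences of 10 tokens with chunk_size = 4; it uses torch.nn.functional.pad in place of the package's pad_at_dim and plain reshapes in place of einops, purely for illustration:

import math
import torch
import torch.nn.functional as F

def round_up_multiple(seq, mult):
    return math.ceil(seq / mult) * mult

batch, seq_len, dim, chunk_size = 2, 10, 32, 4
seq = torch.randn(batch, seq_len, dim)

# positions are retrieved with weights from the last complete chunk before them,
# so the first (chunk_size - 1) tokens are dropped up front
seq = seq[:, (chunk_size - 1):]                                               # (2, 7, 32)

# right-pad so the remainder splits evenly into chunks
curtailed_seq_len = seq.shape[1]
padding = round_up_multiple(curtailed_seq_len, chunk_size) - curtailed_seq_len  # 1
seq = F.pad(seq, (0, 0, 0, padding))                                          # (2, 8, 32)

# queries are grouped per chunk before the memory model is applied
queries = seq.reshape(batch * seq.shape[1] // chunk_size, chunk_size, dim)    # (4, 4, 32)

# retrieved values are flattened back, left-padded with zeros for the first
# (chunk_size - 1) positions that had no stored memory yet, and trimmed to length
values = queries.reshape(batch, -1, dim)                                      # (2, 8, 32)
values = F.pad(values, (0, 0, chunk_size - 1, 0), value = 0.)                 # (2, 11, 32)
values = values[:, :-padding]                                                 # (2, 10, 32)

assert values.shape == (batch, seq_len, dim)   # output length matches the input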
{titans_pytorch-0.0.4.dist-info → titans_pytorch-0.0.6.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.4
+Version: 0.0.6
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -39,6 +39,7 @@ Requires-Dist: einx>=0.3.0
 Requires-Dist: tensordict>=0.6.2
 Requires-Dist: torch>=2.3
 Provides-Extra: examples
+Requires-Dist: local-attention>=1.9.15; extra == 'examples'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
 Description-Content-Type: text/markdown
@@ -65,7 +66,10 @@ from titans_pytorch import NeuralMemory

 x = torch.randn(2, 64, 32)

-mem = NeuralMemory(
+mem = NeuralMemory(
+    dim = 32,
+    chunk_size = 2
+)

 out = mem(x)

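The README snippet in METADATA now passes dim and chunk_size explicitly, which is consistent with the new assert in retrieve_memories that the sequence must be at least one chunk long. A self-contained version of that snippet; the comment on the output shape is an assumption, not something stated in the README:

import torch
from titans_pytorch import NeuralMemory

# 2 sequences of 64 tokens, feature dimension 32; 64 is at least chunk_size and a
# multiple of it, so both the store and retrieve paths see complete chunks
x = torch.randn(2, 64, 32)

mem = NeuralMemory(
    dim = 32,
    chunk_size = 2
)

out = mem(x)
print(out.shape)  # presumably torch.Size([2, 64, 32]), mirroring the input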
titans_pytorch-0.0.6.dist-info/RECORD
ADDED

@@ -0,0 +1,7 @@
+titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=S8J8B9o7Rlnj2hU3FZgpn28GTmis3ZbenLqjB_uny54,9470
+titans_pytorch-0.0.6.dist-info/METADATA,sha256=t4HXD6sZT7_pgcwD8TBY6ojYHUHiZ05J6t19wRKtHNc,3092
+titans_pytorch-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.6.dist-info/RECORD,,
titans_pytorch-0.0.4.dist-info/RECORD
DELETED

@@ -1,7 +0,0 @@
-titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=Cue4Q3OCRPh-lUF99x-1LXmjIecbWOj8bDz8-xP-Rt0,8719
-titans_pytorch-0.0.4.dist-info/METADATA,sha256=IDXO4RWPda9jJak-7_Y0lEW2ADEao79XQqxcYVaLWxI,3000
-titans_pytorch-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.4.dist-info/RECORD,,
{titans_pytorch-0.0.4.dist-info → titans_pytorch-0.0.6.dist-info}/WHEEL
File without changes

{titans_pytorch-0.0.4.dist-info → titans_pytorch-0.0.6.dist-info}/licenses/LICENSE
File without changes