titans-pytorch 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/titans.py CHANGED
@@ -1,4 +1,5 @@
1
1
  from __future__ import annotations
2
+ import math
2
3
  from functools import partial
3
4
 
4
5
  import torch
@@ -11,7 +12,8 @@ from tensordict import TensorDict
11
12
 
12
13
  from titans_pytorch.associative_scan import (
13
14
  associative_scan,
14
- binary_operator
15
+ binary_operator,
16
+ pad_at_dim
15
17
  )
16
18
 
17
19
  import einx
@@ -41,6 +43,9 @@ def default(v, d):
41
43
  def round_down_multiple(seq, mult):
42
44
  return seq // mult * mult
43
45
 
46
+ def round_up_multiple(seq, mult):
47
+ return math.ceil(seq / mult) * mult
48
+
44
49
  def pack_one_with_inverse(t, pattern):
45
50
  packed, packed_shape = pack([t], pattern)
46
51
 
@@ -236,7 +241,19 @@ class NeuralMemory(Module):
236
241
  seq,
237
242
  past_weights: dict[str, Tensor] | None = None,
238
243
  ):
239
- batch = seq.shape[0]
244
+ chunk_size = self.chunk_size
245
+ batch, seq_len = seq.shape[:2]
246
+
247
+ assert seq_len >= chunk_size
248
+
249
+ seq = seq[:, (chunk_size - 1):]
250
+ curtailed_seq_len = seq.shape[-2]
251
+
252
+ next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)
253
+
254
+ padding = next_seq_len - curtailed_seq_len
255
+
256
+ seq = pad_at_dim(seq, (0, padding), dim = 1)
240
257
 
241
258
  # the parameters of the memory model stores the memories of the key / values
242
259
  # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
@@ -256,7 +273,7 @@ class NeuralMemory(Module):
256
273
  # fetch values from memory model
257
274
 
258
275
  curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
259
- queries = rearrange(queries, 'b n d -> (b n) 1 d')
276
+ queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
260
277
 
261
278
  # forward functional call
262
279
 
@@ -264,7 +281,12 @@ class NeuralMemory(Module):
264
281
 
265
282
  # reconstitute batch dimension
266
283
 
267
- values = rearrange(values, '(b n) 1 d -> b n d', b = batch)
284
+ values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
285
+
286
+ # restore
287
+
288
+ values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
289
+ values = values[:, :-padding]
268
290
 
269
291
  return values
270
292
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.4
3
+ Version: 0.0.5
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -65,7 +65,10 @@ from titans_pytorch import NeuralMemory
65
65
 
66
66
  x = torch.randn(2, 64, 32)
67
67
 
68
- mem = NeuralMemory(32)
68
+ mem = NeuralMemory(
69
+ dim = 32,
70
+ chunk_size = 2
71
+ )
69
72
 
70
73
  out = mem(x)
71
74
 
@@ -0,0 +1,7 @@
1
+ titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/titans.py,sha256=3Mewuysj0g7iAlfjdqMlJhn9-pKJuOerB1frQmQYXuc,9428
4
+ titans_pytorch-0.0.5.dist-info/METADATA,sha256=f1DgCKZz9nqNfZOrqbOpyn-yEx2v5M5zgGIW0Zeu84I,3032
5
+ titans_pytorch-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
+ titans_pytorch-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
+ titans_pytorch-0.0.5.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/titans.py,sha256=Cue4Q3OCRPh-lUF99x-1LXmjIecbWOj8bDz8-xP-Rt0,8719
4
- titans_pytorch-0.0.4.dist-info/METADATA,sha256=IDXO4RWPda9jJak-7_Y0lEW2ADEao79XQqxcYVaLWxI,3000
5
- titans_pytorch-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
- titans_pytorch-0.0.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
- titans_pytorch-0.0.4.dist-info/RECORD,,