titans-pytorch 0.0.4__tar.gz → 0.0.6__tar.gz

--- a/PKG-INFO
+++ b/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.4
+Version: 0.0.6
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -39,6 +39,7 @@ Requires-Dist: einx>=0.3.0
 Requires-Dist: tensordict>=0.6.2
 Requires-Dist: torch>=2.3
 Provides-Extra: examples
+Requires-Dist: local-attention>=1.9.15; extra == 'examples'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
 Description-Content-Type: text/markdown
@@ -65,7 +66,10 @@ from titans_pytorch import NeuralMemory
 
 x = torch.randn(2, 64, 32)
 
-mem = NeuralMemory(32)
+mem = NeuralMemory(
+    dim = 32,
+    chunk_size = 2
+)
 
 out = mem(x)
 
--- a/README.md
+++ b/README.md
@@ -20,7 +20,10 @@ from titans_pytorch import NeuralMemory
 
 x = torch.randn(2, 64, 32)
 
-mem = NeuralMemory(32)
+mem = NeuralMemory(
+    dim = 32,
+    chunk_size = 2
+)
 
 out = mem(x)
 
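The updated README snippet above runs as written; a minimal end-to-end sketch with an added shape check (the assertion is not in the README, it follows from the retrieval logic later in this diff, which restores the original sequence length):

import torch
from titans_pytorch import NeuralMemory

x = torch.randn(2, 64, 32)   # (batch, seq_len, dim)

# the README now passes chunk_size explicitly: memories are written
# and fetched per chunk of 2 tokens
mem = NeuralMemory(
    dim = 32,
    chunk_size = 2
)

out = mem(x)
assert out.shape == x.shape  # (2, 64, 32)
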
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.4"
+version = "0.0.6"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -36,7 +36,9 @@ Homepage = "https://pypi.org/project/titans-pytorch/"
 Repository = "https://github.com/lucidrains/titans-pytorch"
 
 [project.optional-dependencies]
-examples = []
+examples = [
+    "local-attention>=1.9.15"
+]
 test = [
     "pytest"
 ]
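A plain pip install titans-pytorch is unaffected by this change; pip install 'titans-pytorch[examples]' now additionally pulls in local-attention>=1.9.15, presumably for the example scripts.
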
--- a/titans_pytorch/titans.py
+++ b/titans_pytorch/titans.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import math
 from functools import partial
 
 import torch
@@ -11,7 +12,8 @@ from tensordict import TensorDict
 
 from titans_pytorch.associative_scan import (
     associative_scan,
-    binary_operator
+    binary_operator,
+    pad_at_dim
 )
 
 import einx
@@ -41,6 +43,9 @@ def default(v, d):
 def round_down_multiple(seq, mult):
     return seq // mult * mult
 
+def round_up_multiple(seq, mult):
+    return math.ceil(seq / mult) * mult
+
 def pack_one_with_inverse(t, pattern):
     packed, packed_shape = pack([t], pattern)
 
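The new round_up_multiple mirrors the existing round_down_multiple; a quick worked example of the pair (self-contained, definitions copied from the hunk above):

import math

def round_down_multiple(seq, mult):
    return seq // mult * mult

def round_up_multiple(seq, mult):
    return math.ceil(seq / mult) * mult

assert round_down_multiple(10, 4) == 8   # floor to the nearest multiple
assert round_up_multiple(10, 4) == 12    # ceil to the nearest multiple
assert round_up_multiple(12, 4) == 12    # exact multiples pass through
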
@@ -159,7 +164,7 @@ class NeuralMemory(Module):
         # curtail sequence by multiple of the chunk size
         # only a complete chunk of the sequence provides the memory for the next chunk
 
-        seq_len = seq.shape[-2]
+        seq_len, chunk_size = seq.shape[-2], self.chunk_size
         round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
 
         seq = seq[:, :round_down_seq_len]
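Per the comment in that hunk, storage consumes only complete chunks; tracing the arithmetic with a hypothetical odd length:

chunk_size = 2
seq_len = 65

round_down_seq_len = seq_len // chunk_size * chunk_size  # 64
# the trailing token does not complete a chunk, so it contributes no
# memory update; retrieval for such positions is handled in
# retrieve_memories below
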
@@ -229,14 +234,26 @@
 
         next_state = (curr_weights + last_update, next_momentum)
 
-        return updates, next_state, aux_store_loss.mean()
+        return updates, next_state, aux_store_loss.mean() / chunk_size
 
     def retrieve_memories(
        self,
        seq,
        past_weights: dict[str, Tensor] | None = None,
    ):
-        batch = seq.shape[0]
+        chunk_size = self.chunk_size
+        batch, seq_len = seq.shape[:2]
+
+        assert seq_len >= chunk_size
+
+        seq = seq[:, (chunk_size - 1):]
+        curtailed_seq_len = seq.shape[-2]
+
+        next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)
+
+        padding = next_seq_len - curtailed_seq_len
+
+        seq = pad_at_dim(seq, (0, padding), dim = 1)
 
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
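Two things happen in this hunk: the auxiliary store loss is divided by chunk_size, presumably to keep its scale independent of the chosen chunk size, and retrieve_memories gains a preamble that shifts the sequence so each position reads only memories written by fully completed earlier chunks, then pads back up to a chunk multiple. Traced by hand for the README's seq_len = 64, chunk_size = 2:

import math

chunk_size = 2
seq_len = 64

curtailed_seq_len = seq_len - (chunk_size - 1)                         # 63
next_seq_len = math.ceil(curtailed_seq_len / chunk_size) * chunk_size  # 64
padding = next_seq_len - curtailed_seq_len                             # 1

# after the fetch, chunk_size - 1 positions are re-padded on the left and
# `padding` positions trimmed on the right, restoring the original length
assert (chunk_size - 1) + next_seq_len - padding == seq_len
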
@@ -256,7 +273,7 @@ class NeuralMemory(Module):
         # fetch values from memory model
 
         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
-        queries = rearrange(queries, 'b n d -> (b n) 1 d')
+        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
 
         # forward functional call
 
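Where the old code fetched a single query per weight snapshot, each snapshot now serves a whole chunk of queries; a standalone shape check of the new rearrange (all sizes below are hypothetical):

import torch
from einops import rearrange

batch, num_chunks, chunk_size, dim = 2, 32, 2, 32
queries = torch.randn(batch, num_chunks * chunk_size, dim)     # (2, 64, 32)

chunked = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
assert chunked.shape == (batch * num_chunks, chunk_size, dim)  # (64, 2, 32)
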
@@ -264,7 +281,12 @@
 
         # reconstitute batch dimension
 
-        values = rearrange(values, '(b n) 1 d -> b n d', b = batch)
+        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+
+        # restore
+
+        values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
+        values = values[:, :-padding]
 
         return values
 
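One edge case worth flagging in the new trim: when curtailed_seq_len is already a multiple of chunk_size (i.e. seq_len ≡ chunk_size - 1 mod chunk_size, e.g. any odd seq_len with chunk_size = 2), padding is 0 and values[:, :-padding] slices away the whole sequence. A defensive sketch, not the released code:

# hypothetical guard around the final trim in retrieve_memories
if padding > 0:
    values = values[:, :-padding]
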
3 files without changes