titans-pytorch 0.4.3__tar.gz → 0.4.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.4.3
3
+ Version: 0.4.5
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -56,7 +56,7 @@ Description-Content-Type: text/markdown
56
56
 
57
57
  <img src="./fig1.png" width="400px"></img>
58
58
 
59
- ## Titans - Pytorch (wip)
59
+ ## Titans - Pytorch
60
60
 
61
61
  Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
62
62
 
@@ -2,7 +2,7 @@
2
2
 
3
3
  <img src="./fig1.png" width="400px"></img>
4
4
 
5
- ## Titans - Pytorch (wip)
5
+ ## Titans - Pytorch
6
6
 
7
7
  Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
8
8
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "titans-pytorch"
3
- version = "0.4.3"
3
+ version = "0.4.5"
4
4
  description = "Titans"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -405,3 +405,23 @@ def test_assoc_scan(
405
405
  assert second_half.shape == inputs2.shape
406
406
 
407
407
  assert torch.allclose(output[:, -1], second_half[:, -1], atol = 1e-5)
408
+
def test_mem_state_detach():
    # runs NeuralMemory for two steps, detaching the carried-over state with
    # mem_state_detach after each call and backpropagating the retrieved output
    from titans_pytorch.neural_memory import mem_state_detach

    memory = NeuralMemory(
        dim = 384,
        chunk_size = 2,
        qk_rmsnorm = True,
        dim_head = 64,
        heads = 4,
    )

    tokens = torch.randn(4, 64, 384)

    # no prior state on the first step
    mem_state = None

    for _step in range(2):
        retrieved, mem_state = memory(tokens, state = mem_state)
        mem_state = mem_state_detach(mem_state)

        loss = retrieved.sum()
        loss.backward()
@@ -1,5 +1,7 @@
1
1
  from titans_pytorch.neural_memory import (
2
2
  NeuralMemory,
3
+ NeuralMemState,
4
+ mem_state_detach
3
5
  )
4
6
 
5
7
  from titans_pytorch.memory_models import (
@@ -7,10 +7,11 @@ from itertools import zip_longest
7
7
  from collections import namedtuple
8
8
 
9
9
  import torch
10
- from torch import nn, stack, cat, tensor, Tensor
10
+ from torch import nn, stack, cat, is_tensor, tensor, Tensor
11
11
  import torch.nn.functional as F
12
12
  from torch.nn import Linear, Module, Parameter, ParameterList, ParameterDict
13
13
  from torch.func import functional_call, vmap, grad
14
+ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
14
15
 
15
16
  from tensordict import TensorDict
16
17
 
@@ -40,6 +41,8 @@ u - key / value updates - allowing a token to emit multiple key / values
40
41
 
41
42
  # Linear layer constructor with bias disabled (partial application of Linear with bias = False)
  LinearNoBias = partial(Linear, bias = False)
42
43
 
44
+ # neural mem state related
45
+
43
46
  NeuralMemState = namedtuple('NeuralMemState', [
44
47
  'seq_index',
45
48
  'weights',
@@ -48,6 +51,13 @@ NeuralMemState = namedtuple('NeuralMemState', [
48
51
  'updates',
49
52
  ])
50
53
 
def mem_state_detach(
    state: NeuralMemState
):
    """Return a copy of `state` whose tensor leaves are detached from autograd.

    The fields of `state` are walked with `tree_map`. Every leaf that is a
    tensor is replaced by `leaf.detach()`, and every other leaf is returned
    unchanged. The result is rebuilt as a NeuralMemState.

    NOTE(review): whether TensorDict-valued fields are descended into (rather
    than treated as opaque leaves) depends on TensorDict being registered as a
    torch pytree node — confirm.

    Raises:
        AssertionError: if `state` is not a NeuralMemState instance.
    """
    assert isinstance(state, NeuralMemState)

    def detach_if_tensor(leaf):
        return leaf.detach() if is_tensor(leaf) else leaf

    # map over a plain tuple of the fields, then rebuild the namedtuple
    detached_fields = tree_map(detach_if_tensor, tuple(state))
    return NeuralMemState._make(detached_fields)
60
+
51
61
  # functions
52
62
 
53
63
  def exists(v):
@@ -854,6 +864,7 @@ class NeuralMemory(Module):
854
864
  seq,
855
865
  store_seq = None,
856
866
  state: NeuralMemState | None = None,
867
+ detach_mem_state = False,
857
868
  prev_weights = None,
858
869
  store_mask: Tensor | None = None,
859
870
  return_surprises = False
@@ -1003,6 +1014,11 @@ class NeuralMemory(Module):
1003
1014
  updates
1004
1015
  )
1005
1016
 
1017
+ # maybe detach
1018
+
1019
+ if detach_mem_state:
1020
+ next_neural_mem_state = mem_state_detach(next_neural_mem_state)
1021
+
1006
1022
  # returning
1007
1023
 
1008
1024
  if not return_surprises:
File without changes
File without changes
File without changes