titans-pytorch 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,7 @@
1
1
  from titans_pytorch.neural_memory import (
2
2
  NeuralMemory,
3
+ NeuralMemState,
4
+ mem_state_detach
3
5
  )
4
6
 
5
7
  from titans_pytorch.memory_models import (
@@ -7,10 +7,11 @@ from itertools import zip_longest
7
7
  from collections import namedtuple
8
8
 
9
9
  import torch
10
- from torch import nn, stack, cat, tensor, Tensor
10
+ from torch import nn, stack, cat, is_tensor, tensor, Tensor
11
11
  import torch.nn.functional as F
12
12
  from torch.nn import Linear, Module, Parameter, ParameterList, ParameterDict
13
13
  from torch.func import functional_call, vmap, grad
14
+ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
14
15
 
15
16
  from tensordict import TensorDict
16
17
 
@@ -40,6 +41,8 @@ u - key / value updates - allowing a token to emit multiple key / values
40
41
 
41
42
  LinearNoBias = partial(Linear, bias = False)
42
43
 
44
+ # neural mem state related
45
+
43
46
  NeuralMemState = namedtuple('NeuralMemState', [
44
47
  'seq_index',
45
48
  'weights',
@@ -48,6 +51,13 @@ NeuralMemState = namedtuple('NeuralMemState', [
48
51
  'updates',
49
52
  ])
50
53
 
54
+ def mem_state_detach(
55
+ state: NeuralMemState
56
+ ):
57
+ assert isinstance(state, NeuralMemState)
58
+ state = tree_map(lambda t: t.detach() if is_tensor(t) else t, tuple(state))
59
+ return NeuralMemState(*state)
60
+
51
61
  # functions
52
62
 
53
63
  def exists(v):
@@ -940,7 +950,7 @@ class NeuralMemory(Module):
940
950
 
941
951
  # whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch
942
952
 
943
- surprises = None
953
+ surprises = (None, None)
944
954
  gate = None
945
955
 
946
956
  if exists(self.transition_gate):
@@ -967,7 +977,7 @@ class NeuralMemory(Module):
967
977
 
968
978
  updates = accum_updates(updates, next_updates)
969
979
 
970
- surprises = safe_cat((surprises, chunk_surprises), dim = -1)
980
+ surprises = tuple(safe_cat(args, dim = -1) for args in zip(surprises, chunk_surprises))
971
981
 
972
982
  if is_last and not update_after_final_store:
973
983
  continue
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.4.2
3
+ Version: 0.4.4
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -56,7 +56,7 @@ Description-Content-Type: text/markdown
56
56
 
57
57
  <img src="./fig1.png" width="400px"></img>
58
58
 
59
- ## Titans - Pytorch (wip)
59
+ ## Titans - Pytorch
60
60
 
61
61
  Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
62
62
 
@@ -0,0 +1,9 @@
1
+ titans_pytorch/__init__.py,sha256=sVTOuRUkaIYabFExdLY6s1qXm1UwHHz_J19H8ZV-X74,338
2
+ titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
3
+ titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
4
+ titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
5
+ titans_pytorch/neural_memory.py,sha256=eDC8rl8S241zp8pxKrIEi_6nFm6CEoaH9K4hnDfgzu8,33145
6
+ titans_pytorch-0.4.4.dist-info/METADATA,sha256=CWciTl1VeOvwyL_lqr0JsdmDDIjXG85N8ykwd3w2TxQ,6810
7
+ titans_pytorch-0.4.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.4.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.4.4.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
2
- titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
3
- titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
4
- titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
5
- titans_pytorch/neural_memory.py,sha256=D7jzi2SjcVj89F3Ws-zyOp04mCg5sJuUFXC6GPRdiz8,32789
6
- titans_pytorch-0.4.2.dist-info/METADATA,sha256=HNJZM3kvMlnRLVN9i4hLecWSL93q0Fg7nqq8xz-BT2o,6816
7
- titans_pytorch-0.4.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.4.2.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.4.2.dist-info/RECORD,,