titans-pytorch 0.4.3__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
@@ -1,5 +1,7 @@
1
1
  from titans_pytorch.neural_memory import (
2
2
  NeuralMemory,
3
+ NeuralMemState,
4
+ mem_state_detach
3
5
  )
4
6
 
5
7
  from titans_pytorch.memory_models import (
@@ -7,10 +7,11 @@ from itertools import zip_longest
7
7
  from collections import namedtuple
8
8
 
9
9
  import torch
10
- from torch import nn, stack, cat, tensor, Tensor
10
+ from torch import nn, stack, cat, is_tensor, tensor, Tensor
11
11
  import torch.nn.functional as F
12
12
  from torch.nn import Linear, Module, Parameter, ParameterList, ParameterDict
13
13
  from torch.func import functional_call, vmap, grad
14
+ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
14
15
 
15
16
  from tensordict import TensorDict
16
17
 
@@ -40,6 +41,8 @@ u - key / value updates - allowing a token to emit multiple key / values
40
41
 
41
42
  LinearNoBias = partial(Linear, bias = False)
42
43
 
44
+ # neural mem state related
45
+
43
46
  NeuralMemState = namedtuple('NeuralMemState', [
44
47
  'seq_index',
45
48
  'weights',
@@ -48,6 +51,13 @@ NeuralMemState = namedtuple('NeuralMemState', [
48
51
  'updates',
49
52
  ])
50
53
 
54
+ def mem_state_detach(
55
+ state: NeuralMemState
56
+ ):
57
+ assert isinstance(state, NeuralMemState)
58
+ state = tree_map(lambda t: t.detach() if is_tensor(t) else t, tuple(state))
59
+ return NeuralMemState(*state)
60
+
51
61
  # functions
52
62
 
53
63
  def exists(v):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.4.3
3
+ Version: 0.4.4
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -56,7 +56,7 @@ Description-Content-Type: text/markdown
56
56
 
57
57
  <img src="./fig1.png" width="400px"></img>
58
58
 
59
- ## Titans - Pytorch (wip)
59
+ ## Titans - Pytorch
60
60
 
61
61
  Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
62
62
 
@@ -0,0 +1,9 @@
1
+ titans_pytorch/__init__.py,sha256=sVTOuRUkaIYabFExdLY6s1qXm1UwHHz_J19H8ZV-X74,338
2
+ titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
3
+ titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
4
+ titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
5
+ titans_pytorch/neural_memory.py,sha256=eDC8rl8S241zp8pxKrIEi_6nFm6CEoaH9K4hnDfgzu8,33145
6
+ titans_pytorch-0.4.4.dist-info/METADATA,sha256=CWciTl1VeOvwyL_lqr0JsdmDDIjXG85N8ykwd3w2TxQ,6810
7
+ titans_pytorch-0.4.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.4.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.4.4.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
2
- titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
3
- titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
4
- titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
5
- titans_pytorch/neural_memory.py,sha256=HdBaRGURJ84Qy-a6PdfeQoc5ZzY7H0c5YHUASaSVu1A,32824
6
- titans_pytorch-0.4.3.dist-info/METADATA,sha256=SIq5KS2xehsUAwuFpRSFNdnLbgamWUMLN5xj4MJGRe0,6816
7
- titans_pytorch-0.4.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.4.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.4.3.dist-info/RECORD,,