titans-pytorch 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
titans_pytorch/mac_transformer.py

@@ -488,7 +488,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_model: Module | None = None,
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
-        aux_kv_recon_loss_weight = 0.,
+        aux_kv_recon_loss_weight = 1.,
         use_flex_attn = False,
         sliding_window_attn = False,
         weight_tie_memory_model = False,
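
The only change to `MemoryAsContextTransformer` here is the default for `aux_kv_recon_loss_weight`, which turns the auxiliary key/value reconstruction loss on by default. A minimal sketch of how such a weighted auxiliary term is typically folded into the training objective (the helper name and signature below are illustrative, not the library's API):

```python
import torch

def combine_losses(
    main_loss: torch.Tensor,
    aux_kv_recon_loss: torch.Tensor,
    aux_kv_recon_loss_weight: float = 1.   # 0.2.6 default; was 0. in 0.2.4
) -> torch.Tensor:
    # illustrative helper: the auxiliary reconstruction loss is scaled by its
    # weight and added to the main objective; a weight of 0. disables it
    return main_loss + aux_kv_recon_loss * aux_kv_recon_loss_weight
```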
titans_pytorch/memory_models.py

@@ -1,8 +1,15 @@
 import torch
-from torch import nn
+from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Parameter, ParameterList
 
+# functions
+
+def l2norm(t):
+    return F.normalize(t, dim = -1)
+
+# memory mlp proposed in TTT
+
 class MemoryMLP(Module):
     def __init__(
         self,
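
The new `l2norm` helper simply wraps `F.normalize` over the last dimension. A quick standalone check of what it does:

```python
import torch
import torch.nn.functional as F

def l2norm(t):
    # same as the helper added in 0.2.6: unit-normalize the feature dimension
    return F.normalize(t, dim = -1)

x = torch.randn(2, 4, 8)
print(l2norm(x).norm(dim = -1))  # all values ~1.0
```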
@@ -19,15 +26,17 @@ class MemoryMLP(Module):
         self,
         x
     ):
+        residual = x
+
         for ind, weight in enumerate(self.weights):
             is_first = ind == 0
 
             if not is_first:
-                x = F.silu(x)
+                x = F.gelu(x)
 
             x = x @ weight
 
-        return x
+        return x + residual
 
 # memory mlp, but with gated residual + final projection
 
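Putting the hunk together, `MemoryMLP.forward` now uses GELU between layers and wraps the whole stack in an outer residual connection. A self-contained sketch of the resulting behaviour (the weight setup below is an assumption, since `__init__` is not shown in this hunk):

```python
import torch
import torch.nn.functional as F
from torch.nn import Module, Parameter, ParameterList

class MemoryMLPSketch(Module):
    # sketch of the 0.2.6 forward pass: GELU between matmuls plus an
    # outer residual; the square-weight initialization here is assumed
    def __init__(self, dim, depth):
        super().__init__()
        self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])
        for weight in self.weights:
            torch.nn.init.xavier_uniform_(weight)

    def forward(self, x):
        residual = x

        for ind, weight in enumerate(self.weights):
            if ind != 0:          # no activation before the first layer
                x = F.gelu(x)
            x = x @ weight

        return x + residual

mlp = MemoryMLPSketch(dim = 64, depth = 2)
out = mlp(torch.randn(3, 2, 64))   # same shape in, same shape out
```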
@@ -36,7 +45,7 @@ class GatedResidualMemoryMLP(Module):
         self,
         dim,
         depth,
-        expansion_factor = 2.
+        expansion_factor = 4.
     ):
         super().__init__()
         dim_hidden = int(dim * expansion_factor)
@@ -58,11 +67,13 @@ class GatedResidualMemoryMLP(Module):
         self,
         x
     ):
+        residual = x
+
         for weight1, weight2, to_gates in self.weights:
             res = x
 
             hidden = x @ weight1
-            hidden = F.silu(hidden)
+            hidden = F.gelu(hidden)
             branch_out = hidden @ weight2
 
             # gated residual
@@ -70,7 +81,9 @@ class GatedResidualMemoryMLP(Module):
             gates = cat((branch_out, res), dim = -1) @ to_gates
             x = res.lerp(branch_out, gates.sigmoid())
 
-        return x @ self.final_proj
+        out = x @ self.final_proj
+
+        return out + residual
 
 # memory mlp with factorized weights
 # so can tradeoff capacity for smaller chunk sizes
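
`GatedResidualMemoryMLP` gets the same GELU swap and outer residual, plus a wider default hidden layer (expansion factor 4 instead of 2), while keeping its per-layer gated residual. A small sketch of that gating step in isolation (the gate projection shape below is assumed for illustration):

```python
import torch

def gated_residual(branch_out, res, to_gates):
    # a learned gate blends the branch output with its own input:
    # gate -> 1 keeps the branch, gate -> 0 keeps the input
    gates = torch.cat((branch_out, res), dim = -1) @ to_gates
    return res.lerp(branch_out, gates.sigmoid())

dim = 8
res = torch.randn(2, 4, dim)
branch_out = torch.randn(2, 4, dim)
to_gates = torch.randn(dim * 2, dim)              # assumed projection shape
out = gated_residual(branch_out, res, to_gates)   # (2, 4, 8)
```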
@@ -98,15 +111,17 @@ class FactorizedMemoryMLP(Module):
         self,
         x
     ):
+        residual = x
+
         for ind, (weight1, weight2) in enumerate(self.weights):
             is_first = ind == 0
 
             if not is_first:
-                x = F.silu(x)
+                x = F.gelu(x)
 
             x = x @ weight1 @ weight2
 
-        return x
+        return x + residual
 
 # improvised attention as memory module
 
 
@@ -133,10 +148,12 @@ class MemoryAttention(Module):
133
148
  nn.init.xavier_uniform_(weight)
134
149
 
135
150
  def forward(self, x):
151
+ residual = x
152
+
136
153
  wq, wk, wv, ffw1, ffw2 = self.weights
137
154
 
138
- q = F.normalize(x @ wq, dim = -1)
139
- k = F.normalize(x @ wk, dim = -1)
155
+ q = l2norm(x @ wq)
156
+ k = l2norm(x @ wk)
140
157
  v = x @ wv
141
158
 
142
159
  attn_out = F.scaled_dot_product_attention(
@@ -145,9 +162,10 @@ class MemoryAttention(Module):
             is_causal = True
         )
 
-        x = x + attn_out
+        # parallel attention + feedforward block
+        # as in PaLM + Gpt-J
 
-        h = F.silu(x @ ffw1)
-        out = h @ ffw2
+        h = F.gelu(x @ ffw1)
+        ff_out = h @ ffw2
 
-        return out
+        return attn_out + ff_out + residual
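
`MemoryAttention` now follows the parallel attention + feedforward formulation from PaLM / GPT-J: both branches read the same input and their outputs are summed with the residual, instead of feeding the attention output into the feedforward. A compact sketch of the new data flow (the weight shapes below are assumed, square projections with a 4x feedforward expansion):

```python
import torch
import torch.nn.functional as F

def parallel_attention_ff(x, wq, wk, wv, ffw1, ffw2):
    # attention branch
    q = F.normalize(x @ wq, dim = -1)
    k = F.normalize(x @ wk, dim = -1)
    v = x @ wv
    attn_out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

    # feedforward branch reads the same input x (parallel, not sequential)
    ff_out = F.gelu(x @ ffw1) @ ffw2

    return attn_out + ff_out + x   # both branches plus the residual

dim = 16
x = torch.randn(2, 8, dim)
wq, wk, wv = (torch.randn(dim, dim) for _ in range(3))
ffw1, ffw2 = torch.randn(dim, dim * 4), torch.randn(dim * 4, dim)
out = parallel_attention_ff(x, wq, wk, wv, ffw1, ffw2)   # (2, 8, 16)
```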
titans_pytorch/neural_memory.py

@@ -304,13 +304,26 @@ class NeuralMemory(Module):
             nn.Sigmoid()
         ) if heads > 1 else None
 
-        # memory mlp
+        # memory model
 
         if not exists(model):
             model = MemoryMLP(dim_head, **default_model_kwargs)
 
+        # validate memory model
+
         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
 
+        test_shape = (3, 2, dim_head)
+
+        with torch.no_grad():
+            try:
+                test_input = torch.randn(test_shape)
+                mem_model_output = model(test_input)
+            except:
+                raise RuntimeError(f'memory model unable to accept a tensor of shape {test_shape}')
+
+        assert mem_model_output.shape == test_shape, 'output of memory model needs to be same shape as input'
+
         # the memory is the weights of the model
 
         self.memory_model = model
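
The constructor now probes any user-supplied memory model with a dummy `(batch, seq, dim_head)` tensor and insists the output shape matches the input. A standalone illustration of the contract a custom model must satisfy (the `nn.Sequential` below is an arbitrary example, not the library's default memory model):

```python
import torch
from torch import nn

dim_head = 64

# any custom memory model must map (batch, seq, dim_head) -> same shape
custom_memory = nn.Sequential(
    nn.Linear(dim_head, dim_head * 2),
    nn.GELU(),
    nn.Linear(dim_head * 2, dim_head),
)

test_shape = (3, 2, dim_head)
with torch.no_grad():
    out = custom_memory(torch.randn(test_shape))

assert out.shape == test_shape   # otherwise NeuralMemory now raises at init
```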
@@ -772,10 +785,14 @@ class NeuralMemory(Module):
 
         # retrieve
 
-        if exists(prev_layer_updates):
-            prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+        retrieve_chunk_size = default(chunk_size, self.retrieve_chunk_size)
+
+        if retrieve_chunk_size != 1:
+            if exists(prev_layer_updates):
+                prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+
+            updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
 
-        updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
 
         retrieved = self.retrieve_memories(
             seq,
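
Retrieval now only shifts the update tensors when the retrieve chunk size is greater than 1; the shift itself is a one-position pad along the chunk dimension. A sketch of a `pad_at_dim` helper with the behaviour used here (the library defines its own version, so treat this signature as an assumption):

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # pad `pad = (left, right)` positions along an arbitrary dimension
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

updates = torch.randn(2, 4, 64)                 # e.g. (batch, chunks, ...)
shifted = pad_at_dim(updates, (1, 0), dim = 1)  # (2, 5, 64): a zeroed first chunk
```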
titans_pytorch-0.2.4.dist-info/METADATA → titans_pytorch-0.2.6.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.4
+Version: 0.2.6
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -56,7 +56,7 @@ Description-Content-Type: text/markdown
 
 <img src="./fig1.png" width="400px"></img>
 
-## Titans - Pytorch
+## Titans - Pytorch (wip)
 
 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
titans_pytorch-0.2.6.dist-info/RECORD (added)

@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
+titans_pytorch/memory_models.py,sha256=CD8pQ-IUfTDvPmekuPTsZHE3Vy265QtbiUn_siJhA78,4064
+titans_pytorch/neural_memory.py,sha256=UNST32JulrDw3_dPljSFU3ZCLofDH-KBoFxx8j1Oii4,24733
+titans_pytorch-0.2.6.dist-info/METADATA,sha256=ZsjIkdKM2zanf74Q9EOqhCreVIQsdRNnOEIRBTISwAI,6825
+titans_pytorch-0.2.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.2.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.2.6.dist-info/RECORD,,
titans_pytorch-0.2.4.dist-info/RECORD (removed)

@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=g-Rx8zwTUbMv-XBYWPe9abFVVSUFLxOn_yVQ-wWvG5M,26039
-titans_pytorch/memory_models.py,sha256=LI9T36XB6YXIvvGWRw0ZMDlGpRC6KIv03OPzME2VAaU,3772
-titans_pytorch/neural_memory.py,sha256=3ykFukUDp3dW1QwDmS3jZ2wFysiZE2ippcOoMFall34,24143
-titans_pytorch-0.2.4.dist-info/METADATA,sha256=2yY3d58zPQ1uyvnTX4Dml7a2dd2jRu3TR5NhBpPNmdY,6819
-titans_pytorch-0.2.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.2.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.2.4.dist-info/RECORD,,