titans-pytorch 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
@@ -3,11 +3,33 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Parameter, ParameterList
 
+from einops import rearrange
+
 # functions
 
 def l2norm(t):
     return F.normalize(t, dim = -1)
 
+# norms
+
+class LayerNorm(Module):
+    def __init__(
+        self,
+        dim
+    ):
+        super().__init__()
+
+        self.ln = nn.LayerNorm(dim, elementwise_affine = False)
+        self.gamma = Parameter(torch.zeros(dim))
+
+    def forward(self, x):
+        gamma = self.gamma
+
+        if gamma.ndim == 2:
+            gamma = rearrange(gamma, 'b d -> b 1 d')
+
+        return self.ln(x) * (gamma + 1.)
+
 # memory mlp proposed in TTT
 
 class MemoryMLP(Module):
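The headline change in the memory models is this new `LayerNorm`: a non-affine `nn.LayerNorm` followed by a learned gain `(gamma + 1.)` that is zero-initialized, so at init the module reduces to plain normalization; the `ndim == 2` branch presumably accommodates batched per-sample gammas. A minimal standalone sketch of that behavior (toy shapes, not package code):

```python
import torch
from torch import nn

dim = 16
ln = nn.LayerNorm(dim, elementwise_affine = False)
gamma = torch.zeros(dim)  # mirrors Parameter(torch.zeros(dim))

x = torch.randn(2, 4, dim)
out = ln(x) * (gamma + 1.)  # (gamma + 1.) == 1. at init

# zero-initialized gain means the module starts out as a plain layernorm
assert torch.allclose(out, ln(x))
```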
@@ -19,6 +41,8 @@ class MemoryMLP(Module):
         super().__init__()
         self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])
 
+        self.ln = LayerNorm(dim)
+
         for weight in self.weights:
             nn.init.xavier_uniform_(weight)
 
@@ -36,7 +60,7 @@ class MemoryMLP(Module):
 
             x = x @ weight
 
-        return x + residual
+        return self.ln(x) + residual
 
 # memory mlp, but with gated residual + final projection
 
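The same edit recurs in the hunks below for `GatedResidualMemoryMLP`, `FactorizedMemoryMLP`, and `MemoryAttention`: the memory model's branch output passes through the new `LayerNorm` before the residual is added, so the residual stream itself stays unnormalized. A toy before/after sketch (illustrative tensors, not package code):

```python
import torch
from torch import nn

dim = 16
ln = nn.LayerNorm(dim, elementwise_affine = False)

residual = torch.randn(2, 8, dim)   # input saved at the top of forward
x = torch.randn(2, 8, dim)          # branch output after the weight matmuls

out_v027 = x + residual             # 0.2.7: raw branch plus residual
out_v029 = ln(x) + residual         # 0.2.9: branch normalized first
```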
@@ -60,6 +84,8 @@ class GatedResidualMemoryMLP(Module):
 
         self.final_proj = Parameter(torch.randn(dim, dim))
 
+        self.ln = LayerNorm(dim)
+
         for param in self.parameters():
             nn.init.xavier_uniform_(param)
 
@@ -83,7 +109,7 @@ class GatedResidualMemoryMLP(Module):
 
         out = x @ self.final_proj
 
-        return out + residual
+        return self.ln(out) + residual
 
 # memory mlp with factorized weights
 # so can tradeoff capacity for smaller chunk sizes
@@ -103,6 +129,8 @@ class FactorizedMemoryMLP(Module):
             ]) for _ in range(depth)
         ])
 
+        self.ln = LayerNorm(dim)
+
         for weight1, weight2 in self.weights:
             nn.init.xavier_uniform_(weight1)
             nn.init.xavier_uniform_(weight2)
@@ -121,7 +149,7 @@ class FactorizedMemoryMLP(Module):
 
             x = x @ weight1 @ weight2
 
-        return x + residual
+        return self.ln(x) + residual
 
 # improvised attention as memory module
 
@@ -144,6 +172,8 @@ class MemoryAttention(Module):
             nn.Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
         ])
 
+        self.ln = LayerNorm(dim)
+
         for weight in self.weights:
             nn.init.xavier_uniform_(weight)
 
@@ -168,4 +198,4 @@ class MemoryAttention(Module):
         h = F.gelu(x @ ffw1)
         ff_out = h @ ffw2
 
-        return attn_out + ff_out + residual
+        return self.ln(attn_out + ff_out) + residual
@@ -742,7 +742,7 @@ class NeuralMemory(Module):
 
         # retrieve
 
-        retrieved = self.retrieve_memories(token, weights, chunk_size = 1)
+        retrieved = self.retrieve_memories(token, updates, chunk_size = 1)
 
         # next state tuple
 
@@ -801,10 +801,9 @@ class NeuralMemory(Module):
 
         updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
 
-
         retrieved = self.retrieve_memories(
             seq,
-            mem_model_weights,
+            updates,
             chunk_size = chunk_size,
             prev_layer_updates = prev_layer_updates
         )
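Both `NeuralMemory` hunks change the second argument of `retrieve_memories` from the pre-step parameters (`weights` / `mem_model_weights`) to `updates`, so retrieval reads from the just-updated memories rather than stale ones. A toy sketch of the intent, with hypothetical names (`retrieve` below stands in for `retrieve_memories` and is not the package API):

```python
import torch

# hypothetical stand-in for NeuralMemory.retrieve_memories: reads from
# whichever set of memory-model parameters it is handed
def retrieve(token, mem_params):
    return token @ mem_params['w']

token = torch.randn(2, 16)

weights = {'w': torch.randn(16, 16)}                         # before the inner-loop step
updates = {'w': weights['w'] - 1e-2 * torch.randn(16, 16)}   # after the step

stale = retrieve(token, weights)   # 0.2.7: pre-update parameters
fresh = retrieve(token, updates)   # 0.2.9: freshly updated memories
```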
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.7
+Version: 0.2.9
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
+titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
+titans_pytorch/neural_memory.py,sha256=YVbKl7DYKFWUgCawDTxXIEgJAcl7nq5OaZytmovIl8Q,24899
+titans_pytorch-0.2.9.dist-info/METADATA,sha256=fSFt54zXLKB5gRhLTJd9551O0pF2qcYNlR7039yJiD0,6811
+titans_pytorch-0.2.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.2.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.2.9.dist-info/RECORD,,
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
-titans_pytorch/memory_models.py,sha256=CD8pQ-IUfTDvPmekuPTsZHE3Vy265QtbiUn_siJhA78,4064
-titans_pytorch/neural_memory.py,sha256=WAeR-nOpy1XbBP590By1-tCgirulqPbFGut4H1B77-g,24910
-titans_pytorch-0.2.7.dist-info/METADATA,sha256=ndFb28pAe8xWmNU6oncV8VJDDPImo3aCuBv0d0JylIs,6811
-titans_pytorch-0.2.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.2.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.2.7.dist-info/RECORD,,