titans-pytorch 0.3.4__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/memory_models.py

```diff
@@ -30,6 +30,25 @@ class LayerNorm(Module):
 
         return self.ln(x) * (gamma + 1.)
 
+# norm + residual wrapper, as used in original TTT paper
+# but could be removed
+
+class ResidualNorm(Module):
+    def __init__(
+        self,
+        dim,
+        model: Module
+    ):
+        super().__init__()
+        self.norm = LayerNorm(dim)
+        self.model = model
+
+    def forward(self, x):
+
+        out = self.model(x)
+
+        return self.norm(out) + x
+
 # memory mlp proposed in TTT
 
 class MemoryMLP(Module):
```
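The new wrapper applies the package's LayerNorm to the inner model's output and then adds the original input back, i.e. out = norm(model(x)) + x (a post-norm plus an input residual). A minimal usage sketch, assuming titans-pytorch >= 0.3.5 is installed; the toy inner module and dimensions below are illustrative and not from the package:

```python
import torch
from torch import nn
from titans_pytorch.memory_models import ResidualNorm

dim = 64

# hypothetical inner module standing in for one of the memory models
inner = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

# ResidualNorm(dim, model) computes norm(model(x)) + x
wrapped = ResidualNorm(dim = dim, model = inner)

x = torch.randn(2, 16, dim)
out = wrapped(x)

assert out.shape == x.shape  # the residual addition preserves the input shape
```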
titans_pytorch/memory_models.py (continued)

```diff
@@ -45,8 +64,6 @@ class MemoryMLP(Module):
 
         self.weights = ParameterList([Parameter(torch.randn(dim_in, dim_out)) for dim_in, dim_out in zip(dims[:-1], dims[1:])])
 
-        self.ln = LayerNorm(dim)
-
         for weight in self.weights:
             nn.init.xavier_uniform_(weight)
 
@@ -54,8 +71,6 @@ class MemoryMLP(Module):
         self,
         x
     ):
-        residual = x
-
         for ind, weight in enumerate(self.weights):
             is_first = ind == 0
 
@@ -64,7 +79,7 @@ class MemoryMLP(Module):
 
             x = x @ weight
 
-        return self.ln(x) + residual
+        return x
 
 # memory mlp, but with gated residual + final projection
 
@@ -97,7 +112,6 @@ class GatedResidualMemoryMLP(Module):
         self,
         x
     ):
-        residual = x
 
         for weight1, weight2, to_gates in self.weights:
             res = x
@@ -111,9 +125,7 @@ class GatedResidualMemoryMLP(Module):
             gates = cat((branch_out, res), dim = -1) @ to_gates
             x = res.lerp(branch_out, gates.sigmoid())
 
-        out = x @ self.final_proj
-
-        return self.ln(out) + residual
+        return x @ self.final_proj
 
 # memory mlp with factorized weights
 # so can tradeoff capacity for smaller chunk sizes
@@ -143,7 +155,6 @@ class FactorizedMemoryMLP(Module):
         self,
         x
     ):
-        residual = x
 
         for ind, (weight1, weight2) in enumerate(self.weights):
             is_first = ind == 0
@@ -153,7 +164,7 @@ class FactorizedMemoryMLP(Module):
 
             x = x @ weight1 @ weight2
 
-        return self.ln(x) + residual
+        return x
 
 # improvised attention as memory module
 
@@ -182,7 +193,6 @@ class MemoryAttention(Module):
             nn.init.xavier_uniform_(weight)
 
     def forward(self, x):
-        residual = x
 
         wq, wk, wv, ffw1, ffw2 = self.weights
 
@@ -202,4 +212,4 @@ class MemoryAttention(Module):
         h = F.gelu(x @ ffw1)
         ff_out = h @ ffw2
 
-        return self.ln(attn_out + ff_out) + residual
+        return attn_out + ff_out
```
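These hunks are the mirror image of the new wrapper: MemoryMLP, GatedResidualMemoryMLP, FactorizedMemoryMLP and MemoryAttention all drop their internal `self.ln(...) + residual` and now return raw outputs, so the post-norm and residual live in exactly one place. A hedged sketch of recovering the old behaviour by composing the two pieces; the MemoryMLP constructor arguments are assumed from `default_model_kwargs` in this diff, not verified against the full source:

```python
import torch
from titans_pytorch.memory_models import MemoryMLP, ResidualNorm

dim = 64

# assumed constructor signature, mirroring default_model_kwargs (depth, expansion_factor)
mem = MemoryMLP(dim = dim, depth = 2, expansion_factor = 4.)

# as of 0.3.5, the norm + residual is applied by the wrapper, not inside the model
mem = ResidualNorm(dim = dim, model = mem)

x = torch.randn(2, 16, dim)
out = mem(x)

assert out.shape == x.shape  # the residual path keeps the sequence shape unchanged
```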
titans_pytorch/neural_memory.py

```diff
@@ -16,7 +16,8 @@ from tensordict import TensorDict
 from titans_pytorch.associative_scan import AssocScan
 
 from titans_pytorch.memory_models import(
-    MemoryMLP
+    MemoryMLP,
+    ResidualNorm
 )
 
 import einx
@@ -234,6 +235,7 @@ class NeuralMemory(Module):
         init_decay_bias = None,
         accept_weight_residual = False,
         gated_transition = False,
+        mem_model_norm_add_residual = True, # by default, layernorm output and add residual as proposed in TTT paper, but could be removed
         default_model_kwargs: dict = dict(
             depth = 2,
             expansion_factor = 4.
@@ -304,6 +306,9 @@ class NeuralMemory(Module):
 
         # the memory is the weights of the model
 
+        if mem_model_norm_add_residual:
+            model = ResidualNorm(dim = dim_head, model = model)
+
         self.memory_model = model
 
         mem_model_params = dict(model.named_parameters())
```
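On the NeuralMemory side the change is opt-out: `mem_model_norm_add_residual` defaults to True, so the configured memory model is wrapped in ResidualNorm (at `dim_head`) before being stored as `self.memory_model`. A hedged construction sketch; the `dim` and `chunk_size` arguments are taken from the project README rather than from this diff:

```python
from titans_pytorch import NeuralMemory

# default behaviour in 0.3.5: the memory model is wrapped in ResidualNorm
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    mem_model_norm_add_residual = True
)

# opting out: the memory model is used bare, with no post-norm or input residual
mem_bare = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    mem_model_norm_add_residual = False
)
```

Either way, the (possibly wrapped) model's parameters become the neural memory's weights via `named_parameters()`, as the last hunk above shows.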
titans_pytorch-0.3.4.dist-info/METADATA → titans_pytorch-0.3.5.dist-info/METADATA

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.4
+Version: 0.3.5
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
```
titans_pytorch-0.3.5.dist-info/RECORD (added)

```diff
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=CEPXaZ2fEPWF8ZBe5wihCqPSGi8PNyL0uVSgvY7eV-s,5147
+titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
+titans_pytorch/memory_models.py,sha256=2fma9u0NQmDabgbpG6CLDGBRYzX99yIDQCSYIB0etkU,4989
+titans_pytorch/neural_memory.py,sha256=2Ffq5fob6_vJVUs1jIai3BJhfsfypKffJEb0QwRRdMk,27325
+titans_pytorch-0.3.5.dist-info/METADATA,sha256=R6EL4q-zgW7DV5OyLzqz5XP2IvLNpJkaBylwH8GsyII,6815
+titans_pytorch-0.3.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.3.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.3.5.dist-info/RECORD,,
```

titans_pytorch-0.3.4.dist-info/RECORD (removed)

```diff
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=CEPXaZ2fEPWF8ZBe5wihCqPSGi8PNyL0uVSgvY7eV-s,5147
-titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
-titans_pytorch/memory_models.py,sha256=0KLHZN-y_7lwrhWSnFRaYJ3GiUV3tzVjxS9CxIx_eI8,4843
-titans_pytorch/neural_memory.py,sha256=9eyeEvYsP5OFlwLDRyVut99uVYGvXAElFPabVoZnGJw,27063
-titans_pytorch-0.3.4.dist-info/METADATA,sha256=2ZD_DovSYkVejsTWHq7_IOTN-Je0of1f-HOiojaQBhQ,6815
-titans_pytorch-0.3.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.3.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.3.4.dist-info/RECORD,,
```