titans-pytorch 0.3.4__tar.gz → 0.3.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.4
+Version: 0.3.6
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.3.4"
+version = "0.3.6"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -30,6 +30,25 @@ class LayerNorm(Module):
 
         return self.ln(x) * (gamma + 1.)
 
+# norm + residual wrapper, as used in original TTT paper
+# but could be removed
+
+class ResidualNorm(Module):
+    def __init__(
+        self,
+        dim,
+        model: Module
+    ):
+        super().__init__()
+        self.norm = LayerNorm(dim)
+        self.model = model
+
+    def forward(self, x):
+
+        out = self.model(x)
+
+        return self.norm(out) + x
+
 # memory mlp proposed in TTT
 
 class MemoryMLP(Module):
@@ -45,8 +64,6 @@ class MemoryMLP(Module):
 
         self.weights = ParameterList([Parameter(torch.randn(dim_in, dim_out)) for dim_in, dim_out in zip(dims[:-1], dims[1:])])
 
-        self.ln = LayerNorm(dim)
-
         for weight in self.weights:
             nn.init.xavier_uniform_(weight)
 
@@ -54,8 +71,6 @@ class MemoryMLP(Module):
         self,
         x
     ):
-        residual = x
-
         for ind, weight in enumerate(self.weights):
             is_first = ind == 0
 
@@ -64,7 +79,7 @@ class MemoryMLP(Module):
 
             x = x @ weight
 
-        return self.ln(x) + residual
+        return x
 
 # memory mlp, but with gated residual + final projection
 
@@ -88,8 +103,6 @@ class GatedResidualMemoryMLP(Module):
 
         self.final_proj = Parameter(torch.randn(dim, dim))
 
-        self.ln = LayerNorm(dim)
-
         for param in self.parameters():
             nn.init.xavier_uniform_(param)
 
@@ -97,7 +110,6 @@ class GatedResidualMemoryMLP(Module):
         self,
         x
     ):
-        residual = x
 
         for weight1, weight2, to_gates in self.weights:
             res = x
@@ -111,9 +123,7 @@ class GatedResidualMemoryMLP(Module):
             gates = cat((branch_out, res), dim = -1) @ to_gates
             x = res.lerp(branch_out, gates.sigmoid())
 
-        out = x @ self.final_proj
-
-        return self.ln(out) + residual
+        return x @ self.final_proj
 
 # memory mlp with factorized weights
 # so can tradeoff capacity for smaller chunk sizes
@@ -133,8 +143,6 @@ class FactorizedMemoryMLP(Module):
             ]) for _ in range(depth)
         ])
 
-        self.ln = LayerNorm(dim)
-
         for weight1, weight2 in self.weights:
             nn.init.xavier_uniform_(weight1)
             nn.init.xavier_uniform_(weight2)
@@ -143,7 +151,6 @@ class FactorizedMemoryMLP(Module):
         self,
         x
     ):
-        residual = x
 
        for ind, (weight1, weight2) in enumerate(self.weights):
            is_first = ind == 0
@@ -153,7 +160,7 @@ class FactorizedMemoryMLP(Module):
 
             x = x @ weight1 @ weight2
 
-        return self.ln(x) + residual
+        return x
 
 # improvised attention as memory module
 
@@ -176,13 +183,10 @@ class MemoryAttention(Module):
             nn.Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
         ])
 
-        self.ln = LayerNorm(dim)
-
         for weight in self.weights:
             nn.init.xavier_uniform_(weight)
 
     def forward(self, x):
-        residual = x
 
         wq, wk, wv, ffw1, ffw2 = self.weights
 
@@ -202,4 +206,4 @@ class MemoryAttention(Module):
         h = F.gelu(x @ ffw1)
         ff_out = h @ ffw2
 
-        return self.ln(attn_out + ff_out) + residual
+        return attn_out + ff_out
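
The net effect in the memory models above is that the per-model LayerNorm + residual is factored out into the new ResidualNorm wrapper. A minimal sketch of how the wrapper composes with one of the memory models (class names and the dim/depth keywords mirror this diff and its default_model_kwargs; the dimensions and batch shape are illustrative assumptions):

import torch
from titans_pytorch.memory_models import MemoryMLP, ResidualNorm

dim = 64  # illustrative hidden dimension

# the memory model itself no longer norms its output or adds the input back
mem = MemoryMLP(dim = dim, depth = 2)

# the wrapper reinstates what each model previously did internally:
# out = LayerNorm(model(x)) + x
wrapped = ResidualNorm(dim = dim, model = mem)

x = torch.randn(2, 16, dim)
assert wrapped(x).shape == x.shape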
@@ -16,7 +16,8 @@ from tensordict import TensorDict
 from titans_pytorch.associative_scan import AssocScan
 
 from titans_pytorch.memory_models import(
-    MemoryMLP
+    MemoryMLP,
+    ResidualNorm
 )
 
 import einx
@@ -234,6 +235,7 @@ class NeuralMemory(Module):
         init_decay_bias = None,
         accept_weight_residual = False,
         gated_transition = False,
+        mem_model_norm_add_residual = True, # by default, layernorm output and add residual as proposed in TTT paper, but could be removed
         default_model_kwargs: dict = dict(
             depth = 2,
             expansion_factor = 4.
@@ -304,6 +306,9 @@ class NeuralMemory(Module):
 
         # the memory is the weights of the model
 
+        if mem_model_norm_add_residual:
+            model = ResidualNorm(dim = dim_head, model = model)
+
         self.memory_model = model
 
         mem_model_params = dict(model.named_parameters())
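
On the NeuralMemory side the wrapper is opt-out via the new mem_model_norm_add_residual flag. A rough usage sketch, assuming the dim / chunk_size constructor arguments from the project README (the values are illustrative; only mem_model_norm_add_residual itself comes from this diff):

from titans_pytorch import NeuralMemory

# default: the chosen memory model is wrapped in ResidualNorm (TTT-style norm + residual)
mem = NeuralMemory(dim = 384, chunk_size = 64)

# opt out, leaving the raw memory model as-is
mem_plain = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    mem_model_norm_add_residual = False
)

# the wrapping is visible on the stored module
print(type(mem.memory_model).__name__)        # expected: ResidualNorm
print(type(mem_plain.memory_model).__name__)  # expected: the underlying memory model (MemoryMLP by default)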