titans-pytorch 0.1.6__tar.gz → 0.1.7__tar.gz

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between the package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.1.6
3
+ Version: 0.1.7
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "titans-pytorch"
3
- version = "0.1.6"
3
+ version = "0.1.7"
4
4
  description = "Titans"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -161,14 +161,16 @@ class GatedResidualMemoryMLP(Module):
161
161
  def __init__(
162
162
  self,
163
163
  dim,
164
- depth
164
+ depth,
165
+ expansion_factor = 2.
165
166
  ):
166
167
  super().__init__()
167
- self.depth = depth
168
+ dim_hidden = int(dim * expansion_factor)
168
169
 
169
170
  self.weights = ParameterList([
170
171
  ParameterList([
171
- Parameter(torch.randn(dim, dim)),
172
+ Parameter(torch.randn(dim, dim_hidden)),
173
+ Parameter(torch.randn(dim_hidden, dim)),
172
174
  Parameter(torch.randn(dim * 2, dim)),
173
175
  ]) for _ in range(depth)
174
176
  ])
@@ -182,16 +184,17 @@ class GatedResidualMemoryMLP(Module):
182
184
  self,
183
185
  x
184
186
  ):
185
- for weight, to_gates in self.weights:
187
+ for weight1, weight2, to_gates in self.weights:
186
188
  res = x
187
189
 
188
- x = x @ weight
189
- x = F.silu(x)
190
+ hidden = x @ weight1
191
+ hidden = F.silu(hidden)
192
+ branch_out = hidden @ weight2
190
193
 
191
194
  # gated residual
192
195
 
193
- gates = cat((x, res), dim = -1) @ to_gates
194
- x = res.lerp(x, gates.sigmoid())
196
+ gates = cat((branch_out, res), dim = -1) @ to_gates
197
+ x = res.lerp(branch_out, gates.sigmoid())
195
198
 
196
199
  return x @ self.final_proj
197
200
 
File without changes
File without changes
File without changes
File without changes