titans-pytorch 0.2.7__tar.gz → 0.2.8__tar.gz

This diff compares two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in that registry.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.2.7
+ Version: 0.2.8
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
  [project]
  name = "titans-pytorch"
- version = "0.2.7"
+ version = "0.2.8"
  description = "Titans"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -3,11 +3,33 @@ from torch import nn, cat
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Parameter, ParameterList

+ from einops import rearrange
+
  # functions

  def l2norm(t):
      return F.normalize(t, dim = -1)

+ # norms
+
+ class LayerNorm(Module):
+     def __init__(
+         self,
+         dim
+     ):
+         super().__init__()
+
+         self.ln = nn.LayerNorm(dim, elementwise_affine = False)
+         self.gamma = Parameter(torch.zeros(dim))
+
+     def forward(self, x):
+         gamma = self.gamma
+
+         if gamma.ndim == 2:
+             gamma = rearrange(gamma, 'b d -> b 1 d')
+
+         return self.ln(x) * (gamma + 1.)
+
  # memory mlp proposed in TTT

  class MemoryMLP(Module):
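
The hunk above introduces the LayerNorm that the rest of the release threads through every memory model: an affine-free nn.LayerNorm followed by a learned gain initialized to zero, so at the start of training the module behaves as a plain unit-gain layer norm. Below is a minimal, self-contained sketch of that class with a small usage check; the note about a batched (b, d) gamma is an assumption about the intent of the ndim == 2 branch, not something stated in the diff.

import torch
from torch import nn
from torch.nn import Module, Parameter
from einops import rearrange

class LayerNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.ln = nn.LayerNorm(dim, elementwise_affine = False)  # normalize only, no built-in scale/shift
        self.gamma = Parameter(torch.zeros(dim))                 # zero-init gain, so (gamma + 1) == 1 at init

    def forward(self, x):
        gamma = self.gamma

        if gamma.ndim == 2:
            # assumed: a batched gamma of shape (b, d), broadcast over the sequence dimension
            gamma = rearrange(gamma, 'b d -> b 1 d')

        return self.ln(x) * (gamma + 1.)

if __name__ == '__main__':
    norm = LayerNorm(64)
    x = torch.randn(2, 16, 64)
    out = norm(x)
    # at initialization the module reduces to a plain affine-free layer norm
    assert torch.allclose(out, nn.functional.layer_norm(x, (64,)))
    print(out.shape)  # torch.Size([2, 16, 64])
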
@@ -19,6 +41,8 @@ class MemoryMLP(Module):
          super().__init__()
          self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])

+         self.ln = LayerNorm(dim)
+
          for weight in self.weights:
              nn.init.xavier_uniform_(weight)

@@ -36,7 +60,7 @@ class MemoryMLP(Module):

              x = x @ weight

-         return x + residual
+         return self.ln(x) + residual

  # memory mlp, but with gated residual + final projection

@@ -60,6 +84,8 @@ class GatedResidualMemoryMLP(Module):

          self.final_proj = Parameter(torch.randn(dim, dim))

+         self.ln = LayerNorm(dim)
+
          for param in self.parameters():
              nn.init.xavier_uniform_(param)

@@ -83,7 +109,7 @@ class GatedResidualMemoryMLP(Module):

          out = x @ self.final_proj

-         return out + residual
+         return self.ln(out) + residual

  # memory mlp with factorized weights
  # so can tradeoff capacity for smaller chunk sizes
@@ -103,6 +129,8 @@ class FactorizedMemoryMLP(Module):
              ]) for _ in range(depth)
          ])

+         self.ln = LayerNorm(dim)
+
          for weight1, weight2 in self.weights:
              nn.init.xavier_uniform_(weight1)
              nn.init.xavier_uniform_(weight2)
@@ -121,7 +149,7 @@ class FactorizedMemoryMLP(Module):

              x = x @ weight1 @ weight2

-         return x + residual
+         return self.ln(x) + residual

  # improvised attention as memory module

@@ -144,6 +172,8 @@ class MemoryAttention(Module):
              nn.Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
          ])

+         self.ln = LayerNorm(dim)
+
          for weight in self.weights:
              nn.init.xavier_uniform_(weight)

@@ -168,4 +198,4 @@ class MemoryAttention(Module):
          h = F.gelu(x @ ffw1)
          ff_out = h @ ffw2

-         return attn_out + ff_out + residual
+         return self.ln(attn_out + ff_out) + residual
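
Taken together, the remaining hunks make one repeated change: MemoryMLP, GatedResidualMemoryMLP, FactorizedMemoryMLP and MemoryAttention each gain a self.ln = LayerNorm(dim) in __init__, and their forward passes normalize the branch output before adding the residual, turning x + residual into self.ln(x) + residual. The sketch below illustrates the pattern on a MemoryMLP-like module; the inter-layer activation (the diff does not show that part of forward) and the plain affine-free norm standing in for the zero-init LayerNorm above are assumptions for illustration only.

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Module, Parameter, ParameterList

class MemoryMLPSketch(Module):
    def __init__(self, dim, depth):
        super().__init__()
        self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])

        # 0.2.8: the memory model owns a norm for its branch output
        # (a plain affine-free LayerNorm stands in for the zero-init-gain LayerNorm above)
        self.ln = nn.LayerNorm(dim, elementwise_affine = False)

        for weight in self.weights:
            nn.init.xavier_uniform_(weight)

    def forward(self, x):
        residual = x

        for ind, weight in enumerate(self.weights):
            if ind > 0:
                x = F.gelu(x)  # assumed inter-layer activation; not part of the diff

            x = x @ weight

        # 0.2.7 returned x + residual; 0.2.8 normalizes the branch first
        return self.ln(x) + residual

if __name__ == '__main__':
    mlp = MemoryMLPSketch(dim = 64, depth = 2)
    x = torch.randn(2, 16, 64)
    print(mlp(x).shape)  # torch.Size([2, 16, 64])
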
4 files without changes