titans-pytorch 0.0.61__py3-none-any.whl → 0.0.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/__init__.py +2 -1
- titans_pytorch/titans.py +50 -6
- {titans_pytorch-0.0.61.dist-info → titans_pytorch-0.0.62.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.62.dist-info/RECORD +8 -0
- titans_pytorch-0.0.61.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.61.dist-info → titans_pytorch-0.0.62.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.61.dist-info → titans_pytorch-0.0.62.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/__init__.py
CHANGED
titans_pytorch/titans.py
CHANGED
@@ -6,7 +6,7 @@ from functools import partial
 import torch
 from torch import nn, Tensor
 import torch.nn.functional as F
-from torch.nn import Linear, Module
+from torch.nn import Linear, Module, Parameter, ParameterList
 from torch.func import functional_call, vmap, grad

 from tensordict import TensorDict
@@ -88,7 +88,7 @@ class MultiheadRMSNorm(Module):
     def __init__(self, dim, heads):
         super().__init__()
         self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
-        self.gamma =
+        self.gamma = Parameter(torch.zeros(heads, 1, dim))

     def forward(self, x):
         return self.rmsnorm(x) * (self.gamma + 1.)
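Since gamma is zero-initialized and the forward pass scales by (self.gamma + 1.), the module starts out as a plain RMSNorm and only learns a per-head, per-channel scale during training. A minimal sketch of that behavior, assuming the class exactly as shown above and a (batch, heads, seq, dim) input layout (the layout is an assumption, not stated in the diff):

```python
import torch
from titans_pytorch.titans import MultiheadRMSNorm  # class shown in the diff above

norm = MultiheadRMSNorm(dim = 64, heads = 4)

# assumed layout: (batch, heads, seq, dim); gamma of shape (heads, 1, dim) broadcasts over it
x = torch.randn(2, 4, 128, 64)

# gamma starts at zero, so at initialization the output is just the unscaled RMSNorm
assert torch.allclose(norm(x), norm.rmsnorm(x))
```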
@@ -102,7 +102,10 @@ class MemoryMLP(Module):
         depth
     ):
         super().__init__()
-        self.weights =
+        self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])
+
+        for weight in self.weights:
+            nn.init.xavier_uniform_(weight)

     def forward(
         self,
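The MemoryMLP weights are still created with torch.randn but are now re-initialized in place with Xavier/Glorot uniform, which bounds each entry by sqrt(6 / (fan_in + fan_out)) instead of leaving unit-variance Gaussian entries. A small sketch of what that re-initialization does, using example sizes:

```python
import math
import torch
from torch import nn

dim = 64
weight = nn.Parameter(torch.randn(dim, dim))  # unit-variance entries, std ~ 1.0

nn.init.xavier_uniform_(weight)               # in-place re-init, as added in the diff

# Xavier uniform bound for a square weight: sqrt(6 / (fan_in + fan_out))
bound = math.sqrt(6.0 / (dim + dim))
assert weight.abs().max().item() <= bound     # entries now lie in [-bound, bound] (~0.217 for dim = 64)
```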
@@ -118,13 +121,50 @@ class MemoryMLP(Module):

         return x

+# memory mlp with factorized weights
+# so can tradeoff capacity for smaller chunk sizes
+
+class FactorizedMemoryMLP(Module):
+    def __init__(
+        self,
+        dim,
+        depth,
+        k = 32
+    ):
+        super().__init__()
+        self.weights = ParameterList([
+            ParameterList([
+                Parameter(torch.randn(dim, k)),
+                Parameter(torch.randn(k, dim)),
+            ]) for _ in range(depth)
+        ])
+
+        for weight1, weight2 in self.weights:
+            nn.init.xavier_uniform_(weight1)
+            nn.init.xavier_uniform_(weight2)
+
+    def forward(
+        self,
+        x
+    ):
+        for ind, (weight1, weight2) in enumerate(self.weights):
+            is_first = ind == 0
+
+            if not is_first:
+                x = F.silu(x)
+
+            x = x @ weight1 @ weight2
+
+        return x
+
 # improvised attention as memory module

 class MemoryAttention(Module):
     def __init__(
         self,
         dim,
-        scale = 8
+        scale = 8.,
+        expansion_factor = 2.
     ):
         super().__init__()
         self.scale = scale
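The new FactorizedMemoryMLP swaps each dim × dim weight for a dim × k and k × dim pair, so per-layer parameters fall from dim² to 2·dim·k, which is the "tradeoff capacity for smaller chunk sizes" the comment describes. A rough comparison, assuming the two memory modules exactly as defined in this diff (dim and depth below are example values; k = 32 matches the new default):

```python
from titans_pytorch.titans import MemoryMLP, FactorizedMemoryMLP  # classes shown above

dim, depth, k = 384, 2, 32  # dim/depth are example sizes, not package defaults

mlp  = MemoryMLP(dim = dim, depth = depth)
fmlp = FactorizedMemoryMLP(dim = dim, depth = depth, k = k)

full_params       = sum(p.numel() for p in mlp.parameters())   # depth * dim * dim   = 294_912
factorized_params = sum(p.numel() for p in fmlp.parameters())  # depth * 2 * dim * k =  49_152

print(full_params, factorized_params)  # roughly a 6x parameter reduction at k = 32
```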
@@ -133,10 +173,13 @@ class MemoryAttention(Module):
             nn.Parameter(torch.randn(dim, dim)), # queries
             nn.Parameter(torch.randn(dim, dim)), # keys
             nn.Parameter(torch.randn(dim, dim)), # values
-            nn.Parameter(torch.randn(dim, dim *
-            nn.Parameter(torch.randn(dim *
+            nn.Parameter(torch.randn(dim, dim * expansion_factor)), # ff w1
+            nn.Parameter(torch.randn(dim * expansion_factor, dim)), # ff w2
         ])

+        for weight in self.weights:
+            nn.init.xavier_uniform_(weight)
+
     def forward(self, x):
         wq, wk, wv, ffw1, ffw2 = self.weights

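The diff shows only the weight definitions and the unpacking line of MemoryAttention.forward, so the following is purely a hypothetical sketch of how five such matrices (query, key, value projections plus a two-layer feedforward with SiLU) are commonly applied; it is not the package's actual forward pass:

```python
import torch
import torch.nn.functional as F

def attention_memory_sketch(x, wq, wk, wv, ffw1, ffw2, scale = 8.):
    # hypothetical illustration only -- not taken from titans_pytorch
    q, k, v = x @ wq, x @ wk, x @ wv             # project input into queries / keys / values
    sim = (q @ k.transpose(-2, -1)) * scale      # scaled pairwise similarities
    out = sim.softmax(dim = -1) @ v              # attention-weighted combination of values
    hidden = F.silu(out @ ffw1)                  # feedforward up-projection (dim -> dim * expansion_factor)
    return hidden @ ffw2                         # feedforward down-projection back to dim
```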
@@ -536,6 +579,7 @@ class NeuralMemory(Module):

         past_weights, _ = past_state

+
         retrieved = self.retrieve_memories(seq, past_weights + updates)

         if not return_aux_kv_loss:
titans_pytorch-0.0.62.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=R9Xu-BjqAe9ZY60IGk4aNXBx_L8THsjJ4QrkbTnRNHo,15346
+titans_pytorch/titans.py,sha256=95J6UL44lOrdZSXdm7p36xC9tDeSmRBwdjig9T82PzI,17452
+titans_pytorch-0.0.62.dist-info/METADATA,sha256=08Blaa9Ehyv09rSA9uWguxbhKpbrd7C53Ya13E1VbpU,4457
+titans_pytorch-0.0.62.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.62.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.62.dist-info/RECORD,,
titans_pytorch-0.0.61.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=R9Xu-BjqAe9ZY60IGk4aNXBx_L8THsjJ4QrkbTnRNHo,15346
-titans_pytorch/titans.py,sha256=5wuAoDULbgXTM8Nbq8bXrW3Fd2nsn22kpERRfJOwZiU,16367
-titans_pytorch-0.0.61.dist-info/METADATA,sha256=Cfhqnse_9nnFNqVGo9p_kxO_LVawwv4uuZOx4anqhf0,4457
-titans_pytorch-0.0.61.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.61.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.61.dist-info/RECORD,,
{titans_pytorch-0.0.61.dist-info → titans_pytorch-0.0.62.dist-info}/WHEEL
File without changes
{titans_pytorch-0.0.61.dist-info → titans_pytorch-0.0.62.dist-info}/licenses/LICENSE
File without changes