titans-pytorch 0.0.53-py3-none-any.whl → 0.0.54-py3-none-any.whl
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- titans_pytorch/__init__.py +1 -0
- titans_pytorch/titans.py +9 -1
- {titans_pytorch-0.0.53.dist-info → titans_pytorch-0.0.54.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.54.dist-info/RECORD +8 -0
- titans_pytorch-0.0.53.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.53.dist-info → titans_pytorch-0.0.54.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.53.dist-info → titans_pytorch-0.0.54.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/__init__.py
CHANGED
titans_pytorch/titans.py
CHANGED
@@ -123,9 +123,12 @@ class MemoryMLP(Module):
 class MemoryAttention(Module):
     def __init__(
         self,
-        dim
+        dim,
+        scale = 8.
     ):
         super().__init__()
+        self.scale = scale
+
         self.weights = nn.ParameterList([
             nn.Parameter(torch.randn(dim, dim)), # queries
             nn.Parameter(torch.randn(dim, dim)), # keys
@@ -143,6 +146,7 @@ class MemoryAttention(Module):
 
         attn_out = F.scaled_dot_product_attention(
             q, k, v,
+            scale = self.scale,
             is_causal = True
         )
 
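Net effect of the two hunks above: MemoryAttention gains a scale argument (default 8.) and forwards it to F.scaled_dot_product_attention, overriding PyTorch's default softmax scaling of 1/sqrt(head_dim). A fixed scale of this kind typically pairs with l2-normalized queries and keys, though any such normalization is outside these hunks. The sketch below illustrates the behavioral difference only; the tensor shapes are arbitrary and this is not the module's actual forward pass (the scale keyword requires PyTorch 2.1+):

import torch
import torch.nn.functional as F

# Illustrative shapes (batch, heads, seq, head_dim); not taken from the library.
q = torch.randn(1, 4, 16, 64)
k = torch.randn(1, 4, 16, 64)
v = torch.randn(1, 4, 16, 64)

# Before this change: attention weights are softmax(q @ k^T / sqrt(64))
out_default = F.scaled_dot_product_attention(q, k, v, is_causal = True)

# After this change: attention weights are softmax(q @ k^T * scale), scale = 8. by default
out_scaled = F.scaled_dot_product_attention(q, k, v, scale = 8., is_causal = True)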
@@ -174,6 +178,7 @@ class NeuralMemory(Module):
         default_step_transform_max_lr = 1e-2,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        learned_mem_model_weights = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         activation: Module | None = None,
@@ -212,6 +217,9 @@ class NeuralMemory(Module):
         if not exists(model):
             model = MemoryMLP(dim_head, **default_model_kwargs)
 
+        if not learned_mem_model_weights:
+            model.requires_grad_(False)
+
         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
 
         # the memory is the weights of the model
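The NeuralMemory hunks add a learned_mem_model_weights flag (default True). When it is False, the inner memory model is frozen at construction via requires_grad_(False): its parameters stop receiving autograd gradients from the outer training loop, and the comment kept in the hunk ("the memory is the weights of the model") suggests the intent is to fix the initial memory weights rather than learn them. A minimal sketch of the freeze, using a stand-in module since MemoryMLP's definition is not part of this diff:

from torch import nn

# Stand-in for the inner memory model; the real default is
# MemoryMLP(dim_head, **default_model_kwargs) as in the hunk above.
model = nn.Sequential(nn.Linear(64, 64), nn.SiLU(), nn.Linear(64, 64))

# What learned_mem_model_weights = False now triggers:
model.requires_grad_(False)

# All parameters are frozen, so optimizers that filter on
# requires_grad will skip them.
assert all(not p.requires_grad for p in model.parameters())

Downstream this would be toggled as NeuralMemory(..., learned_mem_model_weights = False); the keyword name comes from the hunk, while the elided constructor arguments are whatever the caller already passes.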
titans_pytorch-0.0.54.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=GdUAYq6MDRGY0l2ESBH_kM01AEzVztiKmWfblSKxBEM,13212
+titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
+titans_pytorch-0.0.54.dist-info/METADATA,sha256=bxbC3NBO4Sjii7DpFPcmNsO9M1kX76vj947H2DeUceg,4457
+titans_pytorch-0.0.54.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.54.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.54.dist-info/RECORD,,
titans_pytorch-0.0.53.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=GdUAYq6MDRGY0l2ESBH_kM01AEzVztiKmWfblSKxBEM,13212
-titans_pytorch/titans.py,sha256=ALJHJOkxgMZr-zFOZQB8Y7KJr-6FcSw8Jefyp4ElXho,15917
-titans_pytorch-0.0.53.dist-info/METADATA,sha256=o7Shd67cH6O8YWJQcTcTwuDPPHwiVRY2hNIPxu2DEHA,4457
-titans_pytorch-0.0.53.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.53.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.53.dist-info/RECORD,,
{titans_pytorch-0.0.53.dist-info → titans_pytorch-0.0.54.dist-info}/WHEEL
File without changes
{titans_pytorch-0.0.53.dist-info → titans_pytorch-0.0.54.dist-info}/licenses/LICENSE
File without changes