titans-pytorch 0.0.52__py3-none-any.whl → 0.0.54__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/__init__.py +1 -0
- titans_pytorch/mac_transformer.py +2 -1
- titans_pytorch/titans.py +12 -4
- {titans_pytorch-0.0.52.dist-info → titans_pytorch-0.0.54.dist-info}/METADATA +1 -2
- titans_pytorch-0.0.54.dist-info/RECORD +8 -0
- titans_pytorch-0.0.52.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.52.dist-info → titans_pytorch-0.0.54.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.52.dist-info → titans_pytorch-0.0.54.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/__init__.py
CHANGED
titans_pytorch/mac_transformer.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
from typing import Callable
|
|
2
3
|
from math import ceil
|
|
3
4
|
from functools import partial
|
|
4
5
|
|
|
@@ -32,7 +33,7 @@ def create_mac_block_mask(seq_len, window_size, persist_mem_len):
|
|
|
32
33
|
|
|
33
34
|
# einstein notation related
|
|
34
35
|
|
|
35
|
-
from einops import
|
|
36
|
+
from einops import repeat, rearrange, pack, unpack
|
|
36
37
|
from einops.layers.torch import Rearrange
|
|
37
38
|
|
|
38
39
|
# b - batch
|
titans_pytorch/titans.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
from typing import Callable
|
|
2
3
|
import math
|
|
3
4
|
from functools import partial
|
|
4
5
|
|
|
@@ -16,7 +17,6 @@ from titans_pytorch.associative_scan import (
|
|
|
16
17
|
pad_at_dim
|
|
17
18
|
)
|
|
18
19
|
|
|
19
|
-
import einx
|
|
20
20
|
from einops import rearrange, repeat, pack, unpack
|
|
21
21
|
from einops.layers.torch import Rearrange, Reduce
|
|
22
22
|
|
|
@@ -123,9 +123,12 @@ class MemoryMLP(Module):
|
|
|
123
123
|
class MemoryAttention(Module):
|
|
124
124
|
def __init__(
|
|
125
125
|
self,
|
|
126
|
-
dim
|
|
126
|
+
dim,
|
|
127
|
+
scale = 8.
|
|
127
128
|
):
|
|
128
129
|
super().__init__()
|
|
130
|
+
self.scale = scale
|
|
131
|
+
|
|
129
132
|
self.weights = nn.ParameterList([
|
|
130
133
|
nn.Parameter(torch.randn(dim, dim)), # queries
|
|
131
134
|
nn.Parameter(torch.randn(dim, dim)), # keys
|
|
@@ -143,6 +146,7 @@ class MemoryAttention(Module):
|
|
|
143
146
|
|
|
144
147
|
attn_out = F.scaled_dot_product_attention(
|
|
145
148
|
q, k, v,
|
|
149
|
+
scale = self.scale,
|
|
146
150
|
is_causal = True
|
|
147
151
|
)
|
|
148
152
|
|
|
@@ -174,6 +178,7 @@ class NeuralMemory(Module):
|
|
|
174
178
|
default_step_transform_max_lr = 1e-2,
|
|
175
179
|
pre_rmsnorm = True,
|
|
176
180
|
post_rmsnorm = True,
|
|
181
|
+
learned_mem_model_weights = True,
|
|
177
182
|
max_grad_norm: float | None = None,
|
|
178
183
|
use_accelerated_scan = False,
|
|
179
184
|
activation: Module | None = None,
|
|
@@ -212,6 +217,9 @@ class NeuralMemory(Module):
|
|
|
212
217
|
if not exists(model):
|
|
213
218
|
model = MemoryMLP(dim_head, **default_model_kwargs)
|
|
214
219
|
|
|
220
|
+
if not learned_mem_model_weights:
|
|
221
|
+
model.requires_grad_(False)
|
|
222
|
+
|
|
215
223
|
assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
|
|
216
224
|
|
|
217
225
|
# the memory is the weights of the model
|
|
@@ -338,9 +346,9 @@ class NeuralMemory(Module):
|
|
|
338
346
|
|
|
339
347
|
# take care of chunking
|
|
340
348
|
|
|
341
|
-
keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c =
|
|
349
|
+
keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = chunk_size) for t in (keys, values))
|
|
342
350
|
|
|
343
|
-
adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c =
|
|
351
|
+
adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
|
|
344
352
|
|
|
345
353
|
# get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
|
|
346
354
|
|
|
{titans_pytorch-0.0.52.dist-info → titans_pytorch-0.0.54.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: titans-pytorch
|
|
3
|
-
Version: 0.0.52
|
|
3
|
+
Version: 0.0.54
|
|
4
4
|
Summary: Titans
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
|
|
@@ -37,7 +37,6 @@ Requires-Python: >=3.9
|
|
|
37
37
|
Requires-Dist: accelerated-scan>=0.2.0
|
|
38
38
|
Requires-Dist: axial-positional-embedding>=0.3.5
|
|
39
39
|
Requires-Dist: einops>=0.8.0
|
|
40
|
-
Requires-Dist: einx>=0.3.0
|
|
41
40
|
Requires-Dist: hyper-connections>=0.1.8
|
|
42
41
|
Requires-Dist: ninja
|
|
43
42
|
Requires-Dist: rotary-embedding-torch
|
|
titans_pytorch-0.0.54.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
|
|
|
1
|
+
titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
|
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=GdUAYq6MDRGY0l2ESBH_kM01AEzVztiKmWfblSKxBEM,13212
|
|
4
|
+
titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
|
|
5
|
+
titans_pytorch-0.0.54.dist-info/METADATA,sha256=bxbC3NBO4Sjii7DpFPcmNsO9M1kX76vj947H2DeUceg,4457
|
|
6
|
+
titans_pytorch-0.0.54.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
titans_pytorch-0.0.54.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
titans_pytorch-0.0.54.dist-info/RECORD,,
|
|
titans_pytorch-0.0.52.dist-info/RECORD
REMOVED
@@ -1,8 +0,0 @@
|
|
|
1
|
-
titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
|
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=khfjpbsy-uT9NIG3dZLsLOG_XSEi7EqcyfbPr7EQc2Q,13192
|
|
4
|
-
titans_pytorch/titans.py,sha256=T04onF0xhcrosS-Qkx7fcx-Cqgh0TdU5JLdq9l8ayGg,15911
|
|
5
|
-
titans_pytorch-0.0.52.dist-info/METADATA,sha256=coC9ExIuNvmab0BktSE1NwUgxRaBUV7h_cTHeoJkRJo,4484
|
|
6
|
-
titans_pytorch-0.0.52.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
-
titans_pytorch-0.0.52.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
-
titans_pytorch-0.0.52.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|