titans-pytorch 0.0.35__tar.gz → 0.0.36__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
Potentially problematic release.
This version of titans-pytorch might be problematic. Click here for more details.
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/PKG-INFO +1 -1
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/pyproject.toml +1 -1
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/titans_pytorch/mac_transformer.py +10 -2
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/train_mac.py +11 -7
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/.gitignore +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/LICENSE +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/README.md +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/data/README.md +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/fig1.png +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/fig2.png +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/requirements.txt +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.35 → titans_pytorch-0.0.36}/train.py +0 -0
@@ -185,7 +185,9 @@ class MemoryAsContextTransformer(Module):
|
|
185
185
|
# long term mem tokens
|
186
186
|
|
187
187
|
self.segment_len = segment_len
|
188
|
+
|
188
189
|
self.num_longterm_mem_tokens = num_longterm_mem_tokens
|
190
|
+
has_longterm_mems = num_longterm_mem_tokens > 0
|
189
191
|
|
190
192
|
self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
|
191
193
|
|
@@ -197,7 +199,11 @@ class MemoryAsContextTransformer(Module):
|
|
197
199
|
self.neural_mem_layers = ModuleList([])
|
198
200
|
|
199
201
|
layers = tuple(range(1, depth + 1))
|
200
|
-
|
202
|
+
|
203
|
+
if not exists(neural_memory_layers):
|
204
|
+
neural_memory_layers = layers if has_longterm_mems else ()
|
205
|
+
|
206
|
+
assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
|
201
207
|
|
202
208
|
for layer in layers:
|
203
209
|
|
@@ -205,7 +211,9 @@ class MemoryAsContextTransformer(Module):
|
|
205
211
|
|
206
212
|
mem = None
|
207
213
|
|
208
|
-
if
|
214
|
+
if layer in neural_memory_layers:
|
215
|
+
assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'
|
216
|
+
|
209
217
|
mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
|
210
218
|
mem = init_hyper_conn(dim = dim, branch = mem)
|
211
219
|
|
@@ -25,10 +25,12 @@ SEQ_LEN = 512
|
|
25
25
|
|
26
26
|
PROJECT_NAME = 'titans-mac-transformer'
|
27
27
|
WANDB_ONLINE = False # turn this on to pipe experiment to cloud
|
28
|
-
GLOBAL_LAYERS = (2, 4)
|
29
28
|
NEURAL_MEMORY_DEPTH = 2
|
30
|
-
|
31
|
-
|
29
|
+
NUM_PERSIST_MEM = 4
|
30
|
+
NUM_LONGTERM_MEM = 4
|
31
|
+
NEURAL_MEM_LAYERS = (2, 4, 6)
|
32
|
+
WINDOW_SIZE = 32
|
33
|
+
RUN_NAME = 'mac - 4 longterm mems, layers (2, 4, 6)'
|
32
34
|
|
33
35
|
# wandb experiment tracker
|
34
36
|
|
@@ -57,12 +59,14 @@ model = MemoryAsContextTransformer(
|
|
57
59
|
dim = 384,
|
58
60
|
depth = 8,
|
59
61
|
segment_len = WINDOW_SIZE,
|
60
|
-
num_persist_mem_tokens =
|
61
|
-
num_longterm_mem_tokens =
|
62
|
-
neural_memory_layers =
|
62
|
+
num_persist_mem_tokens = NUM_PERSIST_MEM,
|
63
|
+
num_longterm_mem_tokens = NUM_LONGTERM_MEM,
|
64
|
+
neural_memory_layers = NEURAL_MEM_LAYERS,
|
63
65
|
neural_memory_kwargs = dict(
|
64
66
|
default_mlp_kwargs = dict(
|
65
|
-
depth = NEURAL_MEMORY_DEPTH
|
67
|
+
depth = NEURAL_MEMORY_DEPTH,
|
68
|
+
dim_head = 64,
|
69
|
+
heads = 4
|
66
70
|
)
|
67
71
|
)
|
68
72
|
).cuda()
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|