titans-pytorch 0.1.5__tar.gz → 0.1.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/PKG-INFO +2 -1
- {titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/pyproject.toml +2 -1
- {titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/titans_pytorch/__init__.py +2 -1
- {titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/titans_pytorch/titans.py +42 -1
- {titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/train_mac.py +5 -3
- {titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/.gitignore +0 -0
- {titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/LICENSE +0 -0
- {titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/README.md +0 -0
- {titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/data/README.md +0 -0
- {titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/fig1.png +0 -0
- {titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/fig2.png +0 -0
- {titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/tests/test_titans.py +0 -0
- {titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/titans_pytorch/mac_transformer.py +0 -0
{titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.5
+Version: 0.1.6
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -46,6 +46,7 @@ Requires-Dist: torch>=2.2
 Requires-Dist: tqdm
 Requires-Dist: x-transformers
 Provides-Extra: examples
+Requires-Dist: adam-atan2-pytorch>=0.1.18; extra == 'examples'
 Requires-Dist: wandb; extra == 'examples'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
{titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.5"
+version = "0.1.6"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -45,6 +45,7 @@ Repository = "https://github.com/lucidrains/titans-pytorch"
 [project.optional-dependencies]

 examples = [
+    "adam-atan2-pytorch>=0.1.18",
     "wandb"
 ]
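Note that the new optimizer package is declared only under the `examples` extra, so running the updated training script presumably requires installing with that extra, e.g. `pip install 'titans-pytorch[examples]'`, or installing `adam-atan2-pytorch` directly; the core library itself does not depend on it.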
{titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/titans_pytorch/titans.py

@@ -1,10 +1,11 @@
 from __future__ import annotations
 from typing import Callable
+
 import math
 from functools import partial

 import torch
-from torch import nn, Tensor
+from torch import nn, cat, Tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, Parameter, ParameterList
 from torch.func import functional_call, vmap, grad
@@ -154,6 +155,46 @@ class MemoryMLP(Module):

         return x

+# memory mlp, but with gated residual + final projection
+
+class GatedResidualMemoryMLP(Module):
+    def __init__(
+        self,
+        dim,
+        depth
+    ):
+        super().__init__()
+        self.depth = depth
+
+        self.weights = ParameterList([
+            ParameterList([
+                Parameter(torch.randn(dim, dim)),
+                Parameter(torch.randn(dim * 2, dim)),
+            ]) for _ in range(depth)
+        ])
+
+        self.final_proj = Parameter(torch.randn(dim, dim))
+
+        for param in self.parameters():
+            nn.init.xavier_uniform_(param)
+
+    def forward(
+        self,
+        x
+    ):
+        for weight, to_gates in self.weights:
+            res = x
+
+            x = x @ weight
+            x = F.silu(x)
+
+            # gated residual
+
+            gates = cat((x, res), dim = -1) @ to_gates
+            x = res.lerp(x, gates.sigmoid())
+
+        return x @ self.final_proj
+
 # memory mlp with factorized weights
 # so can tradeoff capacity for smaller chunk sizes
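The hunk above adds a new memory model variant alongside `MemoryMLP`. Below is a minimal sketch (not from the diff) exercising the new class directly from the module it is defined in; the batch, sequence, and feature dimensions are arbitrary, and whether the class is also re-exported at package level depends on the `__init__.py` change, which is not shown here.

```python
import torch
from titans_pytorch.titans import GatedResidualMemoryMLP

# hypothetical sizes, chosen only for illustration
mem_model = GatedResidualMemoryMLP(dim = 64, depth = 2)

x = torch.randn(2, 32, 64)   # (batch, seq, dim)
out = mem_model(x)           # each layer: SiLU projection, then gated residual merge

assert out.shape == x.shape  # final_proj maps dim -> dim, so the shape is preserved
```

The design differs from the plain `MemoryMLP` in two ways visible in the diff: each layer blends its SiLU-activated projection back into the residual through a learned sigmoid gate (`res.lerp(x, gates.sigmoid())`), and a final dim-by-dim projection is applied after the loop.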
{titans_pytorch-0.1.5 → titans_pytorch-0.1.6}/train_mac.py

@@ -5,10 +5,10 @@ import numpy as np

 import torch
 from torch import nn, Tensor
-from torch.optim import Adam
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset

+from adam_atan2_pytorch import AdoptAtan2
 from titans_pytorch import MemoryAsContextTransformer

 # constants
@@ -34,6 +34,7 @@ NEURAL_MEM_GATE_ATTN_OUTPUT = True
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2 # set smaller for more granularity for learning rate / momentum etc
 SLIDING_WINDOWS = True
+STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
 KV_RECON_LOSS_WEIGHT = 0.
 LEARNED_MEM_MODEL_WEIGHTS = True

@@ -86,6 +87,7 @@ model = MemoryAsContextTransformer(
     neural_memory_kwargs = dict(
         dim_head = 64,
         heads = 4,
+        attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
         use_accelerated_scan = USE_ACCELERATED_SCAN,
         learned_mem_model_weights = LEARNED_MEM_MODEL_WEIGHTS,
         default_model_kwargs = dict(
@@ -122,11 +124,11 @@ val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

 # optimizer

-optim = Adam(model.parameters(), lr = LEARNING_RATE)
+optim = AdoptAtan2(model.parameters(), lr = LEARNING_RATE)

 # training

-for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
     model.train()

     for __ in range(GRADIENT_ACCUMULATE_EVERY):
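The script-level change is essentially an optimizer swap plus one new neural-memory flag. Below is a minimal, self-contained sketch of the swap, using a stand-in `nn.Linear` in place of the full `MemoryAsContextTransformer` and a placeholder learning rate, since the script's constants are not part of this diff:

```python
import torch
from torch import nn
from adam_atan2_pytorch import AdoptAtan2  # replaces torch.optim.Adam in the example script

model = nn.Linear(512, 512)  # stand-in module, for illustration only

# previously: optim = Adam(model.parameters(), lr = ...)
optim = AdoptAtan2(model.parameters(), lr = 1e-4)  # placeholder lr

loss = model(torch.randn(4, 512)).pow(2).mean()
loss.backward()

optim.step()
optim.zero_grad()
```

The other addition, `attn_pool_chunks = STORE_ATTN_POOL_CHUNKS`, is simply threaded through `neural_memory_kwargs`; per the inline comment in the diff, it toggles attention pooling over chunks when deriving per-chunk momentum, per-layer learning-rate modulation, and decay.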