titans-memory 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_memory-0.1.0/PKG-INFO +264 -0
- titans_memory-0.1.0/README.md +226 -0
- titans_memory-0.1.0/pyproject.toml +66 -0
- titans_memory-0.1.0/setup.cfg +4 -0
- titans_memory-0.1.0/setup.py +8 -0
- titans_memory-0.1.0/tests/test_memory.py +85 -0
- titans_memory-0.1.0/tests/test_models.py +151 -0
- titans_memory-0.1.0/tests/test_scan.py +65 -0
- titans_memory-0.1.0/titans/__init__.py +38 -0
- titans_memory-0.1.0/titans/memory/__init__.py +4 -0
- titans_memory-0.1.0/titans/memory/neural_memory.py +351 -0
- titans_memory-0.1.0/titans/memory/persistent_memory.py +92 -0
- titans_memory-0.1.0/titans/models/__init__.py +6 -0
- titans_memory-0.1.0/titans/models/lmm.py +210 -0
- titans_memory-0.1.0/titans/models/mac.py +260 -0
- titans_memory-0.1.0/titans/models/mag.py +246 -0
- titans_memory-0.1.0/titans/models/mal.py +227 -0
- titans_memory-0.1.0/titans/ops/__init__.py +4 -0
- titans_memory-0.1.0/titans/ops/attention.py +107 -0
- titans_memory-0.1.0/titans/ops/scan.py +47 -0
- titans_memory-0.1.0/titans/utils/__init__.py +12 -0
- titans_memory-0.1.0/titans/utils/config.py +135 -0
- titans_memory-0.1.0/titans/utils/factory.py +57 -0
- titans_memory-0.1.0/titans/utils/training.py +77 -0
- titans_memory-0.1.0/titans_memory.egg-info/PKG-INFO +264 -0
- titans_memory-0.1.0/titans_memory.egg-info/SOURCES.txt +27 -0
- titans_memory-0.1.0/titans_memory.egg-info/dependency_links.txt +1 -0
- titans_memory-0.1.0/titans_memory.egg-info/requires.txt +19 -0
- titans_memory-0.1.0/titans_memory.egg-info/top_level.txt +1 -0
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: titans-memory
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: PyTorch implementation of Titans: Learning to Memorize at Test Time (Behrouz, Zhong & Mirrokni, 2024)
|
|
5
|
+
Author: Neuranox, Implementation of arXiv:2501.00663
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/Neuranox/titans-memory
|
|
8
|
+
Project-URL: Bug Tracker, https://github.com/Neuranox/titans-memory/issues
|
|
9
|
+
Project-URL: Paper, https://arxiv.org/abs/2501.00663
|
|
10
|
+
Keywords: deep-learning,transformers,long-context,memory,titans,neural-memory
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
20
|
+
Requires-Python: >=3.9
|
|
21
|
+
Description-Content-Type: text/markdown
|
|
22
|
+
Requires-Dist: torch>=2.1.0
|
|
23
|
+
Provides-Extra: dev
|
|
24
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
25
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
26
|
+
Provides-Extra: train
|
|
27
|
+
Requires-Dist: datasets>=2.16; extra == "train"
|
|
28
|
+
Requires-Dist: transformers>=4.38; extra == "train"
|
|
29
|
+
Requires-Dist: tqdm; extra == "train"
|
|
30
|
+
Requires-Dist: tensorboard; extra == "train"
|
|
31
|
+
Provides-Extra: all
|
|
32
|
+
Requires-Dist: datasets>=2.16; extra == "all"
|
|
33
|
+
Requires-Dist: transformers>=4.38; extra == "all"
|
|
34
|
+
Requires-Dist: tqdm; extra == "all"
|
|
35
|
+
Requires-Dist: tensorboard; extra == "all"
|
|
36
|
+
Requires-Dist: pytest>=7.0; extra == "all"
|
|
37
|
+
Requires-Dist: pytest-cov; extra == "all"
|
|
38
|
+
|
|
39
|
+
# Titans: Learning to Memorize at Test Time
|
|
40
|
+
|
|
41
|
+
[](https://python.org)
|
|
42
|
+
[](https://pytorch.org)
|
|
43
|
+
[](https://github.com/Neuranox/titans-memory)
|
|
44
|
+
[](LICENSE)
|
|
45
|
+
[](https://arxiv.org/abs/2501.00663)
|
|
46
|
+
|
|
47
|
+
A clean, highly-optimized PyTorch implementation of the **Titans** architecture from:
|
|
48
|
+
|
|
49
|
+
> **Titans: Learning to Memorize at Test Time**
|
|
50
|
+
> Ali Behrouz, Peilin Zhong, Vahab Mirrokni — Google Research, 2024
|
|
51
|
+
> [arXiv:2501.00663](https://arxiv.org/abs/2501.00663)
|
|
52
|
+
|
|
53
|
+
<p align="center">
|
|
54
|
+
<img src="assets/image.png" alt="Titans Architecture Overview" width="80%">
|
|
55
|
+
</p>
|
|
56
|
+
|
|
57
|
+
---
|
|
58
|
+
|
|
59
|
+
## What's Inside
|
|
60
|
+
|
|
61
|
+
| Module | Description |
|
|
62
|
+
|---|---|
|
|
63
|
+
| `NeuralMemory` | Deep MLP that learns to memorize via gradient descent with **momentum** + **weight-decay forgetting** (§3) |
|
|
64
|
+
| `PersistentMemory` | Learnable task-knowledge tokens prepended to every sequence (§3.3) |
|
|
65
|
+
| `TitansMAC` | **Memory as a Context** — retrieves long-term memory as prefix to attention window (§4.1) |
|
|
66
|
+
| `TitansMAG` | **Memory as a Gate** — SWA ⊗ NeuralMemory gated branch (§4.2) |
|
|
67
|
+
| `TitansMAL` | **Memory as a Layer** — sequential LMM → SWA stack (§4.3) |
|
|
68
|
+
| `TitansLMM` | **Standalone LMM** — neural memory without attention (§4.3) |
|
|
69
|
+
|
|
70
|
+
---
|
|
71
|
+
|
|
72
|
+
## Installation
|
|
73
|
+
|
|
74
|
+
```bash
|
|
75
|
+
# Install directly from GitHub
|
|
76
|
+
pip install git+https://github.com/Neuranox/titans-memory.git
|
|
77
|
+
|
|
78
|
+
# Or clone and install locally (editable — recommended for development)
|
|
79
|
+
git clone https://github.com/Neuranox/titans-memory.git
|
|
80
|
+
cd titans-memory
|
|
81
|
+
pip install -e .
|
|
82
|
+
```
|
|
83
|
+
|
|
84
|
+
---
|
|
85
|
+
|
|
86
|
+
## Quick Start
|
|
87
|
+
|
|
88
|
+
```python
|
|
89
|
+
import torch
|
|
90
|
+
from titans import TitansMAC, TitansMAG, TitansMAL, TitansLMM
|
|
91
|
+
from titans.utils import TitansConfig, build_model, count_parameters
|
|
92
|
+
|
|
93
|
+
# ── Build from config ──────────────────────────────────────────────────
|
|
94
|
+
cfg = TitansConfig.small(variant="MAC") # ~170 M params
|
|
95
|
+
cfg.vocab_size = 32_000
|
|
96
|
+
model = build_model(cfg)
|
|
97
|
+
print(f"Parameters: {count_parameters(model):,}")
|
|
98
|
+
|
|
99
|
+
# ── Forward pass ───────────────────────────────────────────────────────
|
|
100
|
+
input_ids = torch.randint(0, 32_000, (2, 512))
|
|
101
|
+
labels = input_ids.clone()
|
|
102
|
+
|
|
103
|
+
out = model(input_ids, labels=labels)
|
|
104
|
+
print(out["logits"].shape) # (2, 512, 32000)
|
|
105
|
+
print(out["loss"].item())
|
|
106
|
+
|
|
107
|
+
# ── Generation ─────────────────────────────────────────────────────────
|
|
108
|
+
prompt = torch.randint(0, 32_000, (1, 8))
|
|
109
|
+
generated = model.generate(prompt, max_new_tokens=50, top_k=50)
|
|
110
|
+
```
|
|
111
|
+
|
|
112
|
+
---
|
|
113
|
+
|
|
114
|
+
## All Four Variants
|
|
115
|
+
|
|
116
|
+
```python
|
|
117
|
+
VOCAB = 32_000
|
|
118
|
+
D = 512
|
|
119
|
+
|
|
120
|
+
models = {
|
|
121
|
+
"LMM": TitansLMM(VOCAB, d_model=D, n_layers=12, mem_layers=2),
|
|
122
|
+
"MAC": TitansMAC(VOCAB, d_model=D, n_layers=12, mem_layers=2, chunk_size=128),
|
|
123
|
+
"MAG": TitansMAG(VOCAB, d_model=D, n_layers=12, mem_layers=2, window=512),
|
|
124
|
+
"MAL": TitansMAL(VOCAB, d_model=D, n_layers=12, mem_layers=2, window=512),
|
|
125
|
+
}
|
|
126
|
+
```
|
|
127
|
+
|
|
128
|
+
---
|
|
129
|
+
|
|
130
|
+
## TitansConfig — Paper-Scale Presets
|
|
131
|
+
|
|
132
|
+
```python
|
|
133
|
+
from titans.utils import TitansConfig, build_model
|
|
134
|
+
|
|
135
|
+
cfg = TitansConfig.tiny(variant="MAC") # ~30 M — quick experiments
|
|
136
|
+
cfg = TitansConfig.small(variant="MAC") # ~170 M — paper Table 1
|
|
137
|
+
cfg = TitansConfig.medium(variant="MAC") # ~340 M — paper Table 1
|
|
138
|
+
cfg = TitansConfig.large(variant="MAC") # ~760 M — paper Table 1
|
|
139
|
+
|
|
140
|
+
# JSON save / load
|
|
141
|
+
cfg.to_json("config.json")
|
|
142
|
+
cfg = TitansConfig.from_json("config.json")
|
|
143
|
+
```
|
|
144
|
+
|
|
145
|
+
---
|
|
146
|
+
|
|
147
|
+
## Training
|
|
148
|
+
|
|
149
|
+
```python
|
|
150
|
+
from titans.utils.training import build_optimizer, get_cosine_schedule_with_warmup
|
|
151
|
+
|
|
152
|
+
optim = build_optimizer(model, lr=4e-4, weight_decay=0.1) # AdamW, no wd on bias/norm
|
|
153
|
+
sched = get_cosine_schedule_with_warmup(optim,
|
|
154
|
+
warmup_steps=2000, total_steps=100_000, min_lr_ratio=0.1)
|
|
155
|
+
|
|
156
|
+
for batch in dataloader:
|
|
157
|
+
out = model(batch["input_ids"], labels=batch["labels"])
|
|
158
|
+
out["loss"].backward()
|
|
159
|
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
|
160
|
+
optim.step(); sched.step(); optim.zero_grad()
|
|
161
|
+
```
|
|
162
|
+
|
|
163
|
+
See `examples/02_training_loop.py` for a complete runnable example.
|
|
164
|
+
|
|
165
|
+
---
|
|
166
|
+
|
|
167
|
+
## Architecture Overview
|
|
168
|
+
|
|
169
|
+
<p align="center">
|
|
170
|
+
<img src="assets/image1.png" alt="Titans Detail" width="80%">
|
|
171
|
+
</p>
|
|
172
|
+
|
|
173
|
+
```
|
|
174
|
+
Titans (MAC) — Memory as a Context
|
|
175
|
+
───────────────────────────────────────────────────────────
|
|
176
|
+
For each segment S^(t):
|
|
177
|
+
h_t = M*_{t-1}(q_t) ← retrieve long-term memory
|
|
178
|
+
S̃^(t) = [P || h_t || S^(t)] ← augment with persistent + history
|
|
179
|
+
y_t = Attention(S̃^(t)) ← full causal attention over window
|
|
180
|
+
M_t = M_{t-1}.update(y_t) ← write: gradient descent w/ momentum
|
|
181
|
+
o_t = y_t ⊗ M*_t(y_t) ← gated output
|
|
182
|
+
|
|
183
|
+
Titans (MAG) — Memory as a Gate
|
|
184
|
+
───────────────────────────────────────────────────────────
|
|
185
|
+
x̃ = [P || x]
|
|
186
|
+
y = SW-Attn(x̃) ← precise short-term memory (sliding window)
|
|
187
|
+
o = y ⊗ M(x̃) ← gated with neural long-term memory
|
|
188
|
+
|
|
189
|
+
Titans (MAL) — Memory as a Layer
|
|
190
|
+
───────────────────────────────────────────────────────────
|
|
191
|
+
x̃ = [P || x]
|
|
192
|
+
y = M(x̃) ← memory compresses context
|
|
193
|
+
o = SW-Attn(y) ← attention refines compressed representation
|
|
194
|
+
```
|
|
195
|
+
|
|
196
|
+
---
|
|
197
|
+
|
|
198
|
+
## Neural Memory — Key Equations
|
|
199
|
+
|
|
200
|
+
| Component | Equation | Description |
|
|
201
|
+
|---|---|---|
|
|
202
|
+
| Momentary surprise | `∇ℓ(M_{t-1}; x_t)` | How unexpected is `x_t`? |
|
|
203
|
+
| Surprise with momentum | `S_t = η_t S_{t-1} − θ_t ∇ℓ` | Eq. 10 — carries information flow |
|
|
204
|
+
| Forgetting gate | `M_t = (1−α_t) M_{t-1} + S_t` | Eq. 13 — weight-decay style |
|
|
205
|
+
| Retrieval | `y_t = M*(q_t)` | Eq. 15 — inference, no update |
|
|
206
|
+
|
|
207
|
+
---
|
|
208
|
+
|
|
209
|
+
## Running Tests
|
|
210
|
+
|
|
211
|
+
```bash
|
|
212
|
+
cd titans-memory
|
|
213
|
+
pip install -e .[dev]
|
|
214
|
+
pytest
|
|
215
|
+
```
|
|
216
|
+
|
|
217
|
+
---
|
|
218
|
+
|
|
219
|
+
## Project Structure
|
|
220
|
+
|
|
221
|
+
```
|
|
222
|
+
titans-memory/
|
|
223
|
+
├── titans/
|
|
224
|
+
│ ├── __init__.py ← public API
|
|
225
|
+
│ ├── memory/
|
|
226
|
+
│ │ ├── neural_memory.py ← NeuralMemory (LMM core)
|
|
227
|
+
│ │ └── persistent_memory.py
|
|
228
|
+
│ ├── models/
|
|
229
|
+
│ │ ├── lmm.py ← TitansLMM
|
|
230
|
+
│ │ ├── mac.py ← TitansMAC
|
|
231
|
+
│ │ ├── mag.py ← TitansMAG
|
|
232
|
+
│ │ └── mal.py ← TitansMAL
|
|
233
|
+
│ ├── ops/
|
|
234
|
+
│ │ ├── scan.py ← parallel associative scan
|
|
235
|
+
│ │ └── attention.py ← causal + sliding-window attention
|
|
236
|
+
│ └── utils/
|
|
237
|
+
│ ├── config.py ← TitansConfig dataclass
|
|
238
|
+
│ ├── factory.py ← build_model()
|
|
239
|
+
│ └── training.py ← optimizer + LR schedule helpers
|
|
240
|
+
├── tests/
|
|
241
|
+
│ ├── test_scan.py
|
|
242
|
+
│ ├── test_memory.py
|
|
243
|
+
│ └── test_models.py
|
|
244
|
+
├── examples/
|
|
245
|
+
│ ├── 01_quickstart.py
|
|
246
|
+
│ ├── 02_training_loop.py
|
|
247
|
+
│ └── 03_memory_standalone.py
|
|
248
|
+
├── pyproject.toml
|
|
249
|
+
├── setup.py
|
|
250
|
+
└── README.md
|
|
251
|
+
```
|
|
252
|
+
|
|
253
|
+
---
|
|
254
|
+
|
|
255
|
+
## Citation
|
|
256
|
+
|
|
257
|
+
```bibtex
|
|
258
|
+
@article{behrouz2024titans,
|
|
259
|
+
title = {Titans: Learning to Memorize at Test Time},
|
|
260
|
+
author = {Behrouz, Ali and Zhong, Peilin and Mirrokni, Vahab},
|
|
261
|
+
journal = {arXiv preprint arXiv:2501.00663},
|
|
262
|
+
year = {2024}
|
|
263
|
+
}
|
|
264
|
+
```
|
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
# Titans: Learning to Memorize at Test Time
|
|
2
|
+
|
|
3
|
+
[](https://python.org)
|
|
4
|
+
[](https://pytorch.org)
|
|
5
|
+
[](https://github.com/Neuranox/titans-memory)
|
|
6
|
+
[](LICENSE)
|
|
7
|
+
[](https://arxiv.org/abs/2501.00663)
|
|
8
|
+
|
|
9
|
+
A clean, highly-optimized PyTorch implementation of the **Titans** architecture from:
|
|
10
|
+
|
|
11
|
+
> **Titans: Learning to Memorize at Test Time**
|
|
12
|
+
> Ali Behrouz, Peilin Zhong, Vahab Mirrokni — Google Research, 2024
|
|
13
|
+
> [arXiv:2501.00663](https://arxiv.org/abs/2501.00663)
|
|
14
|
+
|
|
15
|
+
<p align="center">
|
|
16
|
+
<img src="assets/image.png" alt="Titans Architecture Overview" width="80%">
|
|
17
|
+
</p>
|
|
18
|
+
|
|
19
|
+
---
|
|
20
|
+
|
|
21
|
+
## What's Inside
|
|
22
|
+
|
|
23
|
+
| Module | Description |
|
|
24
|
+
|---|---|
|
|
25
|
+
| `NeuralMemory` | Deep MLP that learns to memorize via gradient descent with **momentum** + **weight-decay forgetting** (§3) |
|
|
26
|
+
| `PersistentMemory` | Learnable task-knowledge tokens prepended to every sequence (§3.3) |
|
|
27
|
+
| `TitansMAC` | **Memory as a Context** — retrieves long-term memory as prefix to attention window (§4.1) |
|
|
28
|
+
| `TitansMAG` | **Memory as a Gate** — SWA ⊗ NeuralMemory gated branch (§4.2) |
|
|
29
|
+
| `TitansMAL` | **Memory as a Layer** — sequential LMM → SWA stack (§4.3) |
|
|
30
|
+
| `TitansLMM` | **Standalone LMM** — neural memory without attention (§4.3) |
|
|
31
|
+
|
|
32
|
+
---
|
|
33
|
+
|
|
34
|
+
## Installation
|
|
35
|
+
|
|
36
|
+
```bash
|
|
37
|
+
# Install directly from GitHub
|
|
38
|
+
pip install git+https://github.com/Neuranox/titans-memory.git
|
|
39
|
+
|
|
40
|
+
# Or clone and install locally (editable — recommended for development)
|
|
41
|
+
git clone https://github.com/Neuranox/titans-memory.git
|
|
42
|
+
cd titans-memory
|
|
43
|
+
pip install -e .
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
---
|
|
47
|
+
|
|
48
|
+
## Quick Start
|
|
49
|
+
|
|
50
|
+
```python
|
|
51
|
+
import torch
|
|
52
|
+
from titans import TitansMAC, TitansMAG, TitansMAL, TitansLMM
|
|
53
|
+
from titans.utils import TitansConfig, build_model, count_parameters
|
|
54
|
+
|
|
55
|
+
# ── Build from config ──────────────────────────────────────────────────
|
|
56
|
+
cfg = TitansConfig.small(variant="MAC") # ~170 M params
|
|
57
|
+
cfg.vocab_size = 32_000
|
|
58
|
+
model = build_model(cfg)
|
|
59
|
+
print(f"Parameters: {count_parameters(model):,}")
|
|
60
|
+
|
|
61
|
+
# ── Forward pass ───────────────────────────────────────────────────────
|
|
62
|
+
input_ids = torch.randint(0, 32_000, (2, 512))
|
|
63
|
+
labels = input_ids.clone()
|
|
64
|
+
|
|
65
|
+
out = model(input_ids, labels=labels)
|
|
66
|
+
print(out["logits"].shape) # (2, 512, 32000)
|
|
67
|
+
print(out["loss"].item())
|
|
68
|
+
|
|
69
|
+
# ── Generation ─────────────────────────────────────────────────────────
|
|
70
|
+
prompt = torch.randint(0, 32_000, (1, 8))
|
|
71
|
+
generated = model.generate(prompt, max_new_tokens=50, top_k=50)
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
---
|
|
75
|
+
|
|
76
|
+
## All Four Variants
|
|
77
|
+
|
|
78
|
+
```python
|
|
79
|
+
VOCAB = 32_000
|
|
80
|
+
D = 512
|
|
81
|
+
|
|
82
|
+
models = {
|
|
83
|
+
"LMM": TitansLMM(VOCAB, d_model=D, n_layers=12, mem_layers=2),
|
|
84
|
+
"MAC": TitansMAC(VOCAB, d_model=D, n_layers=12, mem_layers=2, chunk_size=128),
|
|
85
|
+
"MAG": TitansMAG(VOCAB, d_model=D, n_layers=12, mem_layers=2, window=512),
|
|
86
|
+
"MAL": TitansMAL(VOCAB, d_model=D, n_layers=12, mem_layers=2, window=512),
|
|
87
|
+
}
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
---
|
|
91
|
+
|
|
92
|
+
## TitansConfig — Paper-Scale Presets
|
|
93
|
+
|
|
94
|
+
```python
|
|
95
|
+
from titans.utils import TitansConfig, build_model
|
|
96
|
+
|
|
97
|
+
cfg = TitansConfig.tiny(variant="MAC") # ~30 M — quick experiments
|
|
98
|
+
cfg = TitansConfig.small(variant="MAC") # ~170 M — paper Table 1
|
|
99
|
+
cfg = TitansConfig.medium(variant="MAC") # ~340 M — paper Table 1
|
|
100
|
+
cfg = TitansConfig.large(variant="MAC") # ~760 M — paper Table 1
|
|
101
|
+
|
|
102
|
+
# JSON save / load
|
|
103
|
+
cfg.to_json("config.json")
|
|
104
|
+
cfg = TitansConfig.from_json("config.json")
|
|
105
|
+
```
|
|
106
|
+
|
|
107
|
+
---
|
|
108
|
+
|
|
109
|
+
## Training
|
|
110
|
+
|
|
111
|
+
```python
|
|
112
|
+
from titans.utils.training import build_optimizer, get_cosine_schedule_with_warmup
|
|
113
|
+
|
|
114
|
+
optim = build_optimizer(model, lr=4e-4, weight_decay=0.1) # AdamW, no wd on bias/norm
|
|
115
|
+
sched = get_cosine_schedule_with_warmup(optim,
|
|
116
|
+
warmup_steps=2000, total_steps=100_000, min_lr_ratio=0.1)
|
|
117
|
+
|
|
118
|
+
for batch in dataloader:
|
|
119
|
+
out = model(batch["input_ids"], labels=batch["labels"])
|
|
120
|
+
out["loss"].backward()
|
|
121
|
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
|
122
|
+
optim.step(); sched.step(); optim.zero_grad()
|
|
123
|
+
```
|
|
124
|
+
|
|
125
|
+
See `examples/02_training_loop.py` for a complete runnable example.
|
|
126
|
+
|
|
127
|
+
---
|
|
128
|
+
|
|
129
|
+
## Architecture Overview
|
|
130
|
+
|
|
131
|
+
<p align="center">
|
|
132
|
+
<img src="assets/image1.png" alt="Titans Detail" width="80%">
|
|
133
|
+
</p>
|
|
134
|
+
|
|
135
|
+
```
|
|
136
|
+
Titans (MAC) — Memory as a Context
|
|
137
|
+
───────────────────────────────────────────────────────────
|
|
138
|
+
For each segment S^(t):
|
|
139
|
+
h_t = M*_{t-1}(q_t) ← retrieve long-term memory
|
|
140
|
+
S̃^(t) = [P || h_t || S^(t)] ← augment with persistent + history
|
|
141
|
+
y_t = Attention(S̃^(t)) ← full causal attention over window
|
|
142
|
+
M_t = M_{t-1}.update(y_t) ← write: gradient descent w/ momentum
|
|
143
|
+
o_t = y_t ⊗ M*_t(y_t) ← gated output
|
|
144
|
+
|
|
145
|
+
Titans (MAG) — Memory as a Gate
|
|
146
|
+
───────────────────────────────────────────────────────────
|
|
147
|
+
x̃ = [P || x]
|
|
148
|
+
y = SW-Attn(x̃) ← precise short-term memory (sliding window)
|
|
149
|
+
o = y ⊗ M(x̃) ← gated with neural long-term memory
|
|
150
|
+
|
|
151
|
+
Titans (MAL) — Memory as a Layer
|
|
152
|
+
───────────────────────────────────────────────────────────
|
|
153
|
+
x̃ = [P || x]
|
|
154
|
+
y = M(x̃) ← memory compresses context
|
|
155
|
+
o = SW-Attn(y) ← attention refines compressed representation
|
|
156
|
+
```
|
|
157
|
+
|
|
158
|
+
---
|
|
159
|
+
|
|
160
|
+
## Neural Memory — Key Equations
|
|
161
|
+
|
|
162
|
+
| Component | Equation | Description |
|
|
163
|
+
|---|---|---|
|
|
164
|
+
| Momentary surprise | `∇ℓ(M_{t-1}; x_t)` | How unexpected is `x_t`? |
|
|
165
|
+
| Surprise with momentum | `S_t = η_t S_{t-1} − θ_t ∇ℓ` | Eq. 10 — carries information flow |
|
|
166
|
+
| Forgetting gate | `M_t = (1−α_t) M_{t-1} + S_t` | Eq. 13 — weight-decay style |
|
|
167
|
+
| Retrieval | `y_t = M*(q_t)` | Eq. 15 — inference, no update |
|
|
168
|
+
|
|
169
|
+
---
|
|
170
|
+
|
|
171
|
+
## Running Tests
|
|
172
|
+
|
|
173
|
+
```bash
|
|
174
|
+
cd titans-memory
|
|
175
|
+
pip install -e .[dev]
|
|
176
|
+
pytest
|
|
177
|
+
```
|
|
178
|
+
|
|
179
|
+
---
|
|
180
|
+
|
|
181
|
+
## Project Structure
|
|
182
|
+
|
|
183
|
+
```
|
|
184
|
+
titans-memory/
|
|
185
|
+
├── titans/
|
|
186
|
+
│ ├── __init__.py ← public API
|
|
187
|
+
│ ├── memory/
|
|
188
|
+
│ │ ├── neural_memory.py ← NeuralMemory (LMM core)
|
|
189
|
+
│ │ └── persistent_memory.py
|
|
190
|
+
│ ├── models/
|
|
191
|
+
│ │ ├── lmm.py ← TitansLMM
|
|
192
|
+
│ │ ├── mac.py ← TitansMAC
|
|
193
|
+
│ │ ├── mag.py ← TitansMAG
|
|
194
|
+
│ │ └── mal.py ← TitansMAL
|
|
195
|
+
│ ├── ops/
|
|
196
|
+
│ │ ├── scan.py ← parallel associative scan
|
|
197
|
+
│ │ └── attention.py ← causal + sliding-window attention
|
|
198
|
+
│ └── utils/
|
|
199
|
+
│ ├── config.py ← TitansConfig dataclass
|
|
200
|
+
│ ├── factory.py ← build_model()
|
|
201
|
+
│ └── training.py ← optimizer + LR schedule helpers
|
|
202
|
+
├── tests/
|
|
203
|
+
│ ├── test_scan.py
|
|
204
|
+
│ ├── test_memory.py
|
|
205
|
+
│ └── test_models.py
|
|
206
|
+
├── examples/
|
|
207
|
+
│ ├── 01_quickstart.py
|
|
208
|
+
│ ├── 02_training_loop.py
|
|
209
|
+
│ └── 03_memory_standalone.py
|
|
210
|
+
├── pyproject.toml
|
|
211
|
+
├── setup.py
|
|
212
|
+
└── README.md
|
|
213
|
+
```
|
|
214
|
+
|
|
215
|
+
---
|
|
216
|
+
|
|
217
|
+
## Citation
|
|
218
|
+
|
|
219
|
+
```bibtex
|
|
220
|
+
@article{behrouz2024titans,
|
|
221
|
+
title = {Titans: Learning to Memorize at Test Time},
|
|
222
|
+
author = {Behrouz, Ali and Zhong, Peilin and Mirrokni, Vahab},
|
|
223
|
+
journal = {arXiv preprint arXiv:2501.00663},
|
|
224
|
+
year = {2024}
|
|
225
|
+
}
|
|
226
|
+
```
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=61", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "titans-memory"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "PyTorch implementation of Titans: Learning to Memorize at Test Time (Behrouz, Zhong & Mirrokni, 2024)"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
license = { text = "MIT" }
|
|
11
|
+
requires-python = ">=3.9"
|
|
12
|
+
keywords = ["deep-learning", "transformers", "long-context", "memory", "titans", "neural-memory"]
|
|
13
|
+
|
|
14
|
+
authors = [
|
|
15
|
+
{ name = "Neuranox" },
|
|
16
|
+
{ name = "Implementation of arXiv:2501.00663" },
|
|
17
|
+
]
|
|
18
|
+
|
|
19
|
+
classifiers = [
|
|
20
|
+
"Development Status :: 3 - Alpha",
|
|
21
|
+
"Intended Audience :: Science/Research",
|
|
22
|
+
"License :: OSI Approved :: MIT License",
|
|
23
|
+
"Programming Language :: Python :: 3",
|
|
24
|
+
"Programming Language :: Python :: 3.9",
|
|
25
|
+
"Programming Language :: Python :: 3.10",
|
|
26
|
+
"Programming Language :: Python :: 3.11",
|
|
27
|
+
"Programming Language :: Python :: 3.12",
|
|
28
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
29
|
+
]
|
|
30
|
+
|
|
31
|
+
dependencies = [
|
|
32
|
+
"torch>=2.1.0",
|
|
33
|
+
]
|
|
34
|
+
|
|
35
|
+
[project.optional-dependencies]
|
|
36
|
+
dev = [
|
|
37
|
+
"pytest>=7.0",
|
|
38
|
+
"pytest-cov",
|
|
39
|
+
]
|
|
40
|
+
train = [
|
|
41
|
+
"datasets>=2.16",
|
|
42
|
+
"transformers>=4.38",
|
|
43
|
+
"tqdm",
|
|
44
|
+
"tensorboard",
|
|
45
|
+
]
|
|
46
|
+
all = [
|
|
47
|
+
"datasets>=2.16",
|
|
48
|
+
"transformers>=4.38",
|
|
49
|
+
"tqdm",
|
|
50
|
+
"tensorboard",
|
|
51
|
+
"pytest>=7.0",
|
|
52
|
+
"pytest-cov",
|
|
53
|
+
]
|
|
54
|
+
|
|
55
|
+
[project.urls]
|
|
56
|
+
"Homepage" = "https://github.com/Neuranox/titans-memory"
|
|
57
|
+
"Bug Tracker" = "https://github.com/Neuranox/titans-memory/issues"
|
|
58
|
+
"Paper" = "https://arxiv.org/abs/2501.00663"
|
|
59
|
+
|
|
60
|
+
[tool.setuptools.packages.find]
|
|
61
|
+
where = ["."]
|
|
62
|
+
include = ["titans*"]
|
|
63
|
+
|
|
64
|
+
[tool.pytest.ini_options]
|
|
65
|
+
testpaths = ["tests"]
|
|
66
|
+
addopts = "-v --tb=short"
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tests for NeuralMemory and PersistentMemory modules.
|
|
3
|
+
"""
|
|
4
|
+
import pytest
|
|
5
|
+
import torch
|
|
6
|
+
from titans.memory import NeuralMemory, PersistentMemory
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestNeuralMemory:
    """Shape, configuration, and gradient-flow checks for ``NeuralMemory``.

    All tensors are tiny so the suite runs in well under a second.
    """

    @pytest.fixture
    def mem(self):
        # Small default instance shared by the shape/gradient tests.
        return NeuralMemory(d_model=32, n_layers=2, chunk_size=8)

    def test_forward_shape(self, mem):
        # The module must be shape-preserving: (B, T, D) in -> (B, T, D) out.
        seq = torch.randn(2, 16, 32)
        result = mem(seq)
        assert result.shape == seq.shape, f"Expected {seq.shape}, got {result.shape}"

    def test_retrieve_shape(self, mem):
        # Retrieval (inference-only read path) is also shape-preserving.
        query = torch.randn(2, 8, 32)
        result = mem.retrieve(query)
        assert result.shape == query.shape

    @pytest.mark.parametrize("lm", [1, 2, 3])
    def test_memory_depth(self, lm):
        # Memory MLP depth must not affect the output shape contract.
        memory = NeuralMemory(d_model=16, n_layers=lm, chunk_size=4)
        seq = torch.randn(1, 8, 16)
        result = memory(seq)
        assert result.shape == seq.shape

    def test_no_momentum(self):
        # Disabling the surprise-momentum term still yields a valid forward pass.
        memory = NeuralMemory(d_model=16, n_layers=1, chunk_size=4, use_momentum=False)
        seq = torch.randn(1, 8, 16)
        result = memory(seq)
        assert result.shape == seq.shape

    def test_no_decay(self):
        # Disabling the weight-decay forgetting gate still yields a valid forward pass.
        memory = NeuralMemory(d_model=16, n_layers=1, chunk_size=4, use_decay=False)
        seq = torch.randn(1, 8, 16)
        result = memory(seq)
        assert result.shape == seq.shape

    def test_gradients_flow(self, mem):
        # Backprop through the memory update must reach the input without NaNs.
        seq = torch.randn(1, 8, 32, requires_grad=True)
        mem(seq).sum().backward()
        assert seq.grad is not None
        assert not torch.isnan(seq.grad).any()

    def test_chunk_size_larger_than_seq(self):
        # Sequences shorter than one chunk (T < chunk_size) must still work.
        memory = NeuralMemory(d_model=16, n_layers=1, chunk_size=128)
        seq = torch.randn(1, 10, 16)
        result = memory(seq)
        assert result.shape == seq.shape
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class TestPersistentMemory:
    """Checks for ``PersistentMemory``: prepend/strip symmetry, freezing,
    and gradient flow into the learnable token bank ``P``.
    """

    def test_prepend_shape(self):
        # Prepending n_tokens learnable tokens grows the time dimension by n_tokens.
        pm = PersistentMemory(n_tokens=8, d_model=32)
        tokens = torch.randn(2, 16, 32)
        augmented = pm(tokens)
        assert augmented.shape == (2, 24, 32)

    def test_strip_shape(self):
        # strip() must undo the prepend and restore the original shape.
        pm = PersistentMemory(n_tokens=8, d_model=32)
        tokens = torch.randn(2, 16, 32)
        restored = pm.strip(pm(tokens))
        assert restored.shape == tokens.shape

    def test_freeze_unfreeze(self):
        # freeze()/unfreeze() toggle requires_grad on the persistent tokens.
        pm = PersistentMemory(n_tokens=4, d_model=8)
        pm.freeze()
        assert not pm.P.requires_grad
        pm.unfreeze()
        assert pm.P.requires_grad

    def test_gradient_through_persistent(self):
        # A backward pass through the augmented sequence must populate P.grad.
        pm = PersistentMemory(n_tokens=4, d_model=8)
        tokens = torch.randn(1, 4, 8)
        pm(tokens).sum().backward()
        assert pm.P.grad is not None
|