titans-memory 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,264 @@
1
+ Metadata-Version: 2.4
2
+ Name: titans-memory
3
+ Version: 0.1.0
4
+ Summary: PyTorch implementation of Titans: Learning to Memorize at Test Time (Behrouz, Zhong & Mirrokni, 2024)
5
+ Author: Neuranox, Implementation of arXiv:2501.00663
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/Neuranox/titans-memory
8
+ Project-URL: Bug Tracker, https://github.com/Neuranox/titans-memory/issues
9
+ Project-URL: Paper, https://arxiv.org/abs/2501.00663
10
+ Keywords: deep-learning,transformers,long-context,memory,titans,neural-memory
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Science/Research
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Programming Language :: Python :: 3.9
16
+ Classifier: Programming Language :: Python :: 3.10
17
+ Classifier: Programming Language :: Python :: 3.11
18
+ Classifier: Programming Language :: Python :: 3.12
19
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
+ Requires-Python: >=3.9
21
+ Description-Content-Type: text/markdown
22
+ Requires-Dist: torch>=2.1.0
23
+ Provides-Extra: dev
24
+ Requires-Dist: pytest>=7.0; extra == "dev"
25
+ Requires-Dist: pytest-cov; extra == "dev"
26
+ Provides-Extra: train
27
+ Requires-Dist: datasets>=2.16; extra == "train"
28
+ Requires-Dist: transformers>=4.38; extra == "train"
29
+ Requires-Dist: tqdm; extra == "train"
30
+ Requires-Dist: tensorboard; extra == "train"
31
+ Provides-Extra: all
32
+ Requires-Dist: datasets>=2.16; extra == "all"
33
+ Requires-Dist: transformers>=4.38; extra == "all"
34
+ Requires-Dist: tqdm; extra == "all"
35
+ Requires-Dist: tensorboard; extra == "all"
36
+ Requires-Dist: pytest>=7.0; extra == "all"
37
+ Requires-Dist: pytest-cov; extra == "all"
38
+
39
+ # Titans: Learning to Memorize at Test Time
40
+
41
+ [![Python](https://img.shields.io/badge/python-3.9%2B-blue)](https://python.org)
42
+ [![PyTorch](https://img.shields.io/badge/pytorch-2.1%2B-orange)](https://pytorch.org)
43
+ [![GitHub release](https://img.shields.io/github/v/release/Neuranox/titans-memory)](https://github.com/Neuranox/titans-memory)
44
+ [![License: MIT](https://img.shields.io/badge/License-MIT-green)](LICENSE)
45
+ [![arXiv](https://img.shields.io/badge/arXiv-2501.00663-red)](https://arxiv.org/abs/2501.00663)
46
+
47
+ A clean, highly-optimized PyTorch implementation of the **Titans** architecture from:
48
+
49
+ > **Titans: Learning to Memorize at Test Time**
50
+ > Ali Behrouz, Peilin Zhong, Vahab Mirrokni — Google Research, 2024
51
+ > [arXiv:2501.00663](https://arxiv.org/abs/2501.00663)
52
+
53
+ <p align="center">
54
+ <img src="assets/image.png" alt="Titans Architecture Overview" width="80%">
55
+ </p>
56
+
57
+ ---
58
+
59
+ ## What's Inside
60
+
61
+ | Module | Description |
62
+ |---|---|
63
+ | `NeuralMemory` | Deep MLP that learns to memorize via gradient descent with **momentum** + **weight-decay forgetting** (§3) |
64
+ | `PersistentMemory` | Learnable task-knowledge tokens prepended to every sequence (§3.3) |
65
+ | `TitansMAC` | **Memory as a Context** — retrieves long-term memory as prefix to attention window (§4.1) |
66
+ | `TitansMAG` | **Memory as a Gate** — SWA ⊗ NeuralMemory gated branch (§4.2) |
67
+ | `TitansMAL` | **Memory as a Layer** — sequential LMM → SWA stack (§4.3) |
68
+ | `TitansLMM` | **Standalone LMM** — neural memory without attention (§4.3) |
69
+
70
+ ---
71
+
72
+ ## Installation
73
+
74
+ ```bash
75
+ # Install directly from GitHub
76
+ pip install git+https://github.com/Neuranox/titans-memory.git
77
+
78
+ # Or clone and install locally (editable — recommended for development)
79
+ git clone https://github.com/Neuranox/titans-memory.git
80
+ cd titans-memory
81
+ pip install -e .
82
+ ```
83
+
84
+ ---
85
+
86
+ ## Quick Start
87
+
88
+ ```python
89
+ import torch
90
+ from titans import TitansMAC, TitansMAG, TitansMAL, TitansLMM
91
+ from titans.utils import TitansConfig, build_model, count_parameters
92
+
93
+ # ── Build from config ──────────────────────────────────────────────────
94
+ cfg = TitansConfig.small(variant="MAC") # ~170 M params
95
+ cfg.vocab_size = 32_000
96
+ model = build_model(cfg)
97
+ print(f"Parameters: {count_parameters(model):,}")
98
+
99
+ # ── Forward pass ───────────────────────────────────────────────────────
100
+ input_ids = torch.randint(0, 32_000, (2, 512))
101
+ labels = input_ids.clone()
102
+
103
+ out = model(input_ids, labels=labels)
104
+ print(out["logits"].shape) # (2, 512, 32000)
105
+ print(out["loss"].item())
106
+
107
+ # ── Generation ─────────────────────────────────────────────────────────
108
+ prompt = torch.randint(0, 32_000, (1, 8))
109
+ generated = model.generate(prompt, max_new_tokens=50, top_k=50)
110
+ ```
111
+
112
+ ---
113
+
114
+ ## All Four Variants
115
+
116
+ ```python
117
+ VOCAB = 32_000
118
+ D = 512
119
+
120
+ models = {
121
+ "LMM": TitansLMM(VOCAB, d_model=D, n_layers=12, mem_layers=2),
122
+ "MAC": TitansMAC(VOCAB, d_model=D, n_layers=12, mem_layers=2, chunk_size=128),
123
+ "MAG": TitansMAG(VOCAB, d_model=D, n_layers=12, mem_layers=2, window=512),
124
+ "MAL": TitansMAL(VOCAB, d_model=D, n_layers=12, mem_layers=2, window=512),
125
+ }
126
+ ```
127
+
128
+ ---
129
+
130
+ ## TitansConfig — Paper-Scale Presets
131
+
132
+ ```python
133
+ from titans.utils import TitansConfig, build_model
134
+
135
+ cfg = TitansConfig.tiny(variant="MAC") # ~30 M — quick experiments
136
+ cfg = TitansConfig.small(variant="MAC") # ~170 M — paper Table 1
137
+ cfg = TitansConfig.medium(variant="MAC") # ~340 M — paper Table 1
138
+ cfg = TitansConfig.large(variant="MAC") # ~760 M — paper Table 1
139
+
140
+ # JSON save / load
141
+ cfg.to_json("config.json")
142
+ cfg = TitansConfig.from_json("config.json")
143
+ ```
144
+
145
+ ---
146
+
147
+ ## Training
148
+
149
+ ```python
150
+ from titans.utils.training import build_optimizer, get_cosine_schedule_with_warmup
151
+
152
+ optim = build_optimizer(model, lr=4e-4, weight_decay=0.1) # AdamW, no wd on bias/norm
153
+ sched = get_cosine_schedule_with_warmup(optim,
154
+ warmup_steps=2000, total_steps=100_000, min_lr_ratio=0.1)
155
+
156
+ for batch in dataloader:
157
+ out = model(batch["input_ids"], labels=batch["labels"])
158
+ out["loss"].backward()
159
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
160
+ optim.step(); sched.step(); optim.zero_grad()
161
+ ```
162
+
163
+ See `examples/02_training_loop.py` for a complete runnable example.
164
+
165
+ ---
166
+
167
+ ## Architecture Overview
168
+
169
+ <p align="center">
170
+ <img src="assets/image1.png" alt="Titans Detail" width="80%">
171
+ </p>
172
+
173
+ ```
174
+ Titans (MAC) — Memory as a Context
175
+ ───────────────────────────────────────────────────────────
176
+ For each segment S^(t):
177
+ h_t = M*_{t-1}(q_t) ← retrieve long-term memory
178
+ S̃^(t) = [P || h_t || S^(t)] ← augment with persistent + history
179
+ y_t = Attention(S̃^(t)) ← full causal attention over window
180
+ M_t = M_{t-1}.update(y_t) ← write: gradient descent w/ momentum
181
+ o_t = y_t ⊗ M*_t(y_t) ← gated output
182
+
183
+ Titans (MAG) — Memory as a Gate
184
+ ───────────────────────────────────────────────────────────
185
+ x̃ = [P || x]
186
+ y = SW-Attn(x̃) ← precise short-term memory (sliding window)
187
+ o = y ⊗ M(x̃) ← gated with neural long-term memory
188
+
189
+ Titans (MAL) — Memory as a Layer
190
+ ───────────────────────────────────────────────────────────
191
+ x̃ = [P || x]
192
+ y = M(x̃) ← memory compresses context
193
+ o = SW-Attn(y) ← attention refines compressed representation
194
+ ```
195
+
196
+ ---
197
+
198
+ ## Neural Memory — Key Equations
199
+
200
+ | Component | Equation | Description |
201
+ |---|---|---|
202
+ | Momentary surprise | `∇ℓ(M_{t-1}; x_t)` | How unexpected is `x_t`? |
203
+ | Surprise with momentum | `S_t = η_t S_{t-1} − θ_t ∇ℓ` | Eq. 10 — carries information flow |
204
+ | Forgetting gate | `M_t = (1−α_t) M_{t-1} + S_t` | Eq. 13 — weight-decay style |
205
+ | Retrieval | `y_t = M*(q_t)` | Eq. 15 — inference, no update |
206
+
207
+ ---
208
+
209
+ ## Running Tests
210
+
211
+ ```bash
212
+ cd "F:\Titan Model"
213
+ pip install -e .[dev]
214
+ pytest
215
+ ```
216
+
217
+ ---
218
+
219
+ ## Project Structure
220
+
221
+ ```
222
+ titans-memory/
223
+ ├── titans/
224
+ │ ├── __init__.py ← public API
225
+ │ ├── memory/
226
+ │ │ ├── neural_memory.py ← NeuralMemory (LMM core)
227
+ │ │ └── persistent_memory.py
228
+ │ ├── models/
229
+ │ │ ├── lmm.py ← TitansLMM
230
+ │ │ ├── mac.py ← TitansMAC
231
+ │ │ ├── mag.py ← TitansMAG
232
+ │ │ └── mal.py ← TitansMAL
233
+ │ ├── ops/
234
+ │ │ ├── scan.py ← parallel associative scan
235
+ │ │ └── attention.py ← causal + sliding-window attention
236
+ │ └── utils/
237
+ │ ├── config.py ← TitansConfig dataclass
238
+ │ ├── factory.py ← build_model()
239
+ │ └── training.py ← optimizer + LR schedule helpers
240
+ ├── tests/
241
+ │ ├── test_scan.py
242
+ │ ├── test_memory.py
243
+ │ └── test_models.py
244
+ ├── examples/
245
+ │ ├── 01_quickstart.py
246
+ │ ├── 02_training_loop.py
247
+ │ └── 03_memory_standalone.py
248
+ ├── pyproject.toml
249
+ ├── setup.py
250
+ └── README.md
251
+ ```
252
+
253
+ ---
254
+
255
+ ## Citation
256
+
257
+ ```bibtex
258
+ @article{behrouz2024titans,
259
+ title = {Titans: Learning to Memorize at Test Time},
260
+ author = {Behrouz, Ali and Zhong, Peilin and Mirrokni, Vahab},
261
+ journal = {arXiv preprint arXiv:2501.00663},
262
+ year = {2024}
263
+ }
264
+ ```
@@ -0,0 +1,226 @@
1
+ # Titans: Learning to Memorize at Test Time
2
+
3
+ [![Python](https://img.shields.io/badge/python-3.9%2B-blue)](https://python.org)
4
+ [![PyTorch](https://img.shields.io/badge/pytorch-2.1%2B-orange)](https://pytorch.org)
5
+ [![GitHub release](https://img.shields.io/github/v/release/Neuranox/titans-memory)](https://github.com/Neuranox/titans-memory)
6
+ [![License: MIT](https://img.shields.io/badge/License-MIT-green)](LICENSE)
7
+ [![arXiv](https://img.shields.io/badge/arXiv-2501.00663-red)](https://arxiv.org/abs/2501.00663)
8
+
9
+ A clean, highly-optimized PyTorch implementation of the **Titans** architecture from:
10
+
11
+ > **Titans: Learning to Memorize at Test Time**
12
+ > Ali Behrouz, Peilin Zhong, Vahab Mirrokni — Google Research, 2024
13
+ > [arXiv:2501.00663](https://arxiv.org/abs/2501.00663)
14
+
15
+ <p align="center">
16
+ <img src="assets/image.png" alt="Titans Architecture Overview" width="80%">
17
+ </p>
18
+
19
+ ---
20
+
21
+ ## What's Inside
22
+
23
+ | Module | Description |
24
+ |---|---|
25
+ | `NeuralMemory` | Deep MLP that learns to memorize via gradient descent with **momentum** + **weight-decay forgetting** (§3) |
26
+ | `PersistentMemory` | Learnable task-knowledge tokens prepended to every sequence (§3.3) |
27
+ | `TitansMAC` | **Memory as a Context** — retrieves long-term memory as prefix to attention window (§4.1) |
28
+ | `TitansMAG` | **Memory as a Gate** — SWA ⊗ NeuralMemory gated branch (§4.2) |
29
+ | `TitansMAL` | **Memory as a Layer** — sequential LMM → SWA stack (§4.3) |
30
+ | `TitansLMM` | **Standalone LMM** — neural memory without attention (§4.3) |
31
+
32
+ ---
33
+
34
+ ## Installation
35
+
36
+ ```bash
37
+ # Install directly from GitHub
38
+ pip install git+https://github.com/Neuranox/titans-memory.git
39
+
40
+ # Or clone and install locally (editable — recommended for development)
41
+ git clone https://github.com/Neuranox/titans-memory.git
42
+ cd titans-memory
43
+ pip install -e .
44
+ ```
45
+
46
+ ---
47
+
48
+ ## Quick Start
49
+
50
+ ```python
51
+ import torch
52
+ from titans import TitansMAC, TitansMAG, TitansMAL, TitansLMM
53
+ from titans.utils import TitansConfig, build_model, count_parameters
54
+
55
+ # ── Build from config ──────────────────────────────────────────────────
56
+ cfg = TitansConfig.small(variant="MAC") # ~170 M params
57
+ cfg.vocab_size = 32_000
58
+ model = build_model(cfg)
59
+ print(f"Parameters: {count_parameters(model):,}")
60
+
61
+ # ── Forward pass ───────────────────────────────────────────────────────
62
+ input_ids = torch.randint(0, 32_000, (2, 512))
63
+ labels = input_ids.clone()
64
+
65
+ out = model(input_ids, labels=labels)
66
+ print(out["logits"].shape) # (2, 512, 32000)
67
+ print(out["loss"].item())
68
+
69
+ # ── Generation ─────────────────────────────────────────────────────────
70
+ prompt = torch.randint(0, 32_000, (1, 8))
71
+ generated = model.generate(prompt, max_new_tokens=50, top_k=50)
72
+ ```
73
+
74
+ ---
75
+
76
+ ## All Four Variants
77
+
78
+ ```python
79
+ VOCAB = 32_000
80
+ D = 512
81
+
82
+ models = {
83
+ "LMM": TitansLMM(VOCAB, d_model=D, n_layers=12, mem_layers=2),
84
+ "MAC": TitansMAC(VOCAB, d_model=D, n_layers=12, mem_layers=2, chunk_size=128),
85
+ "MAG": TitansMAG(VOCAB, d_model=D, n_layers=12, mem_layers=2, window=512),
86
+ "MAL": TitansMAL(VOCAB, d_model=D, n_layers=12, mem_layers=2, window=512),
87
+ }
88
+ ```
89
+
90
+ ---
91
+
92
+ ## TitansConfig — Paper-Scale Presets
93
+
94
+ ```python
95
+ from titans.utils import TitansConfig, build_model
96
+
97
+ cfg = TitansConfig.tiny(variant="MAC") # ~30 M — quick experiments
98
+ cfg = TitansConfig.small(variant="MAC") # ~170 M — paper Table 1
99
+ cfg = TitansConfig.medium(variant="MAC") # ~340 M — paper Table 1
100
+ cfg = TitansConfig.large(variant="MAC") # ~760 M — paper Table 1
101
+
102
+ # JSON save / load
103
+ cfg.to_json("config.json")
104
+ cfg = TitansConfig.from_json("config.json")
105
+ ```
106
+
107
+ ---
108
+
109
+ ## Training
110
+
111
+ ```python
112
+ from titans.utils.training import build_optimizer, get_cosine_schedule_with_warmup
113
+
114
+ optim = build_optimizer(model, lr=4e-4, weight_decay=0.1) # AdamW, no wd on bias/norm
115
+ sched = get_cosine_schedule_with_warmup(optim,
116
+ warmup_steps=2000, total_steps=100_000, min_lr_ratio=0.1)
117
+
118
+ for batch in dataloader:
119
+ out = model(batch["input_ids"], labels=batch["labels"])
120
+ out["loss"].backward()
121
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
122
+ optim.step(); sched.step(); optim.zero_grad()
123
+ ```
124
+
125
+ See `examples/02_training_loop.py` for a complete runnable example.
126
+
127
+ ---
128
+
129
+ ## Architecture Overview
130
+
131
+ <p align="center">
132
+ <img src="assets/image1.png" alt="Titans Detail" width="80%">
133
+ </p>
134
+
135
+ ```
136
+ Titans (MAC) — Memory as a Context
137
+ ───────────────────────────────────────────────────────────
138
+ For each segment S^(t):
139
+ h_t = M*_{t-1}(q_t) ← retrieve long-term memory
140
+ S̃^(t) = [P || h_t || S^(t)] ← augment with persistent + history
141
+ y_t = Attention(S̃^(t)) ← full causal attention over window
142
+ M_t = M_{t-1}.update(y_t) ← write: gradient descent w/ momentum
143
+ o_t = y_t ⊗ M*_t(y_t) ← gated output
144
+
145
+ Titans (MAG) — Memory as a Gate
146
+ ───────────────────────────────────────────────────────────
147
+ x̃ = [P || x]
148
+ y = SW-Attn(x̃) ← precise short-term memory (sliding window)
149
+ o = y ⊗ M(x̃) ← gated with neural long-term memory
150
+
151
+ Titans (MAL) — Memory as a Layer
152
+ ───────────────────────────────────────────────────────────
153
+ x̃ = [P || x]
154
+ y = M(x̃) ← memory compresses context
155
+ o = SW-Attn(y) ← attention refines compressed representation
156
+ ```
157
+
158
+ ---
159
+
160
+ ## Neural Memory — Key Equations
161
+
162
+ | Component | Equation | Description |
163
+ |---|---|---|
164
+ | Momentary surprise | `∇ℓ(M_{t-1}; x_t)` | How unexpected is `x_t`? |
165
+ | Surprise with momentum | `S_t = η_t S_{t-1} − θ_t ∇ℓ` | Eq. 10 — carries information flow |
166
+ | Forgetting gate | `M_t = (1−α_t) M_{t-1} + S_t` | Eq. 13 — weight-decay style |
167
+ | Retrieval | `y_t = M*(q_t)` | Eq. 15 — inference, no update |
168
+
169
+ ---
170
+
171
+ ## Running Tests
172
+
173
+ ```bash
174
+ cd "F:\Titan Model"
175
+ pip install -e .[dev]
176
+ pytest
177
+ ```
178
+
179
+ ---
180
+
181
+ ## Project Structure
182
+
183
+ ```
184
+ titans-memory/
185
+ ├── titans/
186
+ │ ├── __init__.py ← public API
187
+ │ ├── memory/
188
+ │ │ ├── neural_memory.py ← NeuralMemory (LMM core)
189
+ │ │ └── persistent_memory.py
190
+ │ ├── models/
191
+ │ │ ├── lmm.py ← TitansLMM
192
+ │ │ ├── mac.py ← TitansMAC
193
+ │ │ ├── mag.py ← TitansMAG
194
+ │ │ └── mal.py ← TitansMAL
195
+ │ ├── ops/
196
+ │ │ ├── scan.py ← parallel associative scan
197
+ │ │ └── attention.py ← causal + sliding-window attention
198
+ │ └── utils/
199
+ │ ├── config.py ← TitansConfig dataclass
200
+ │ ├── factory.py ← build_model()
201
+ │ └── training.py ← optimizer + LR schedule helpers
202
+ ├── tests/
203
+ │ ├── test_scan.py
204
+ │ ├── test_memory.py
205
+ │ └── test_models.py
206
+ ├── examples/
207
+ │ ├── 01_quickstart.py
208
+ │ ├── 02_training_loop.py
209
+ │ └── 03_memory_standalone.py
210
+ ├── pyproject.toml
211
+ ├── setup.py
212
+ └── README.md
213
+ ```
214
+
215
+ ---
216
+
217
+ ## Citation
218
+
219
+ ```bibtex
220
+ @article{behrouz2024titans,
221
+ title = {Titans: Learning to Memorize at Test Time},
222
+ author = {Behrouz, Ali and Zhong, Peilin and Mirrokni, Vahab},
223
+ journal = {arXiv preprint arXiv:2501.00663},
224
+ year = {2024}
225
+ }
226
+ ```
@@ -0,0 +1,66 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "titans-memory"
7
+ version = "0.1.0"
8
+ description = "PyTorch implementation of Titans: Learning to Memorize at Test Time (Behrouz, Zhong & Mirrokni, 2024)"
9
+ readme = "README.md"
10
+ license = { text = "MIT" }
11
+ requires-python = ">=3.9"
12
+ keywords = ["deep-learning", "transformers", "long-context", "memory", "titans", "neural-memory"]
13
+
14
+ authors = [
15
+ { name = "Neuranox" },
16
+ { name = "Implementation of arXiv:2501.00663" },
17
+ ]
18
+
19
+ classifiers = [
20
+ "Development Status :: 3 - Alpha",
21
+ "Intended Audience :: Science/Research",
22
+ "License :: OSI Approved :: MIT License",
23
+ "Programming Language :: Python :: 3",
24
+ "Programming Language :: Python :: 3.9",
25
+ "Programming Language :: Python :: 3.10",
26
+ "Programming Language :: Python :: 3.11",
27
+ "Programming Language :: Python :: 3.12",
28
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
29
+ ]
30
+
31
+ dependencies = [
32
+ "torch>=2.1.0",
33
+ ]
34
+
35
+ [project.optional-dependencies]
36
+ dev = [
37
+ "pytest>=7.0",
38
+ "pytest-cov",
39
+ ]
40
+ train = [
41
+ "datasets>=2.16",
42
+ "transformers>=4.38",
43
+ "tqdm",
44
+ "tensorboard",
45
+ ]
46
+ all = [
47
+ "datasets>=2.16",
48
+ "transformers>=4.38",
49
+ "tqdm",
50
+ "tensorboard",
51
+ "pytest>=7.0",
52
+ "pytest-cov",
53
+ ]
54
+
55
+ [project.urls]
56
+ "Homepage" = "https://github.com/Neuranox/titans-memory"
57
+ "Bug Tracker" = "https://github.com/Neuranox/titans-memory/issues"
58
+ "Paper" = "https://arxiv.org/abs/2501.00663"
59
+
60
+ [tool.setuptools.packages.find]
61
+ where = ["."]
62
+ include = ["titans*"]
63
+
64
+ [tool.pytest.ini_options]
65
+ testpaths = ["tests"]
66
+ addopts = "-v --tb=short"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,8 @@
1
+ """
2
+ Fallback setup.py for environments that don't support PEP 517 build.
3
+ All configuration is in pyproject.toml.
4
+ """
5
+ from setuptools import setup
6
+
7
+ if __name__ == "__main__":
8
+ setup()
@@ -0,0 +1,85 @@
1
+ """
2
+ Tests for NeuralMemory and PersistentMemory modules.
3
+ """
4
+ import pytest
5
+ import torch
6
+ from titans.memory import NeuralMemory, PersistentMemory
7
+
8
+
9
class TestNeuralMemory:
    """Shape, configuration, and gradient-flow checks for ``NeuralMemory``."""

    @pytest.fixture
    def mem(self):
        # Small shared module for the fixture-based tests below.
        return NeuralMemory(d_model=32, n_layers=2, chunk_size=8)

    def test_forward_shape(self, mem):
        inputs = torch.randn(2, 16, 32)
        result = mem(inputs)
        assert result.shape == inputs.shape, f"Expected {inputs.shape}, got {result.shape}"

    def test_retrieve_shape(self, mem):
        queries = torch.randn(2, 8, 32)
        assert mem.retrieve(queries).shape == queries.shape

    @pytest.mark.parametrize("lm", [1, 2, 3])
    def test_memory_depth(self, lm):
        # Output shape must be preserved regardless of memory depth.
        module = NeuralMemory(d_model=16, n_layers=lm, chunk_size=4)
        batch = torch.randn(1, 8, 16)
        assert module(batch).shape == batch.shape

    def test_no_momentum(self):
        module = NeuralMemory(d_model=16, n_layers=1, chunk_size=4, use_momentum=False)
        batch = torch.randn(1, 8, 16)
        assert module(batch).shape == batch.shape

    def test_no_decay(self):
        module = NeuralMemory(d_model=16, n_layers=1, chunk_size=4, use_decay=False)
        batch = torch.randn(1, 8, 16)
        assert module(batch).shape == batch.shape

    def test_gradients_flow(self, mem):
        inputs = torch.randn(1, 8, 32, requires_grad=True)
        mem(inputs).sum().backward()
        assert inputs.grad is not None
        assert not torch.isnan(inputs.grad).any()

    def test_chunk_size_larger_than_seq(self):
        # A sequence shorter than chunk_size must still round-trip cleanly.
        module = NeuralMemory(d_model=16, n_layers=1, chunk_size=128)
        batch = torch.randn(1, 10, 16)  # T < chunk_size
        assert module(batch).shape == batch.shape
56
+
57
+
58
class TestPersistentMemory:
    """Prepend/strip round-trips and parameter freezing for ``PersistentMemory``."""

    def test_prepend_shape(self):
        pm = PersistentMemory(n_tokens=8, d_model=32)
        augmented = pm(torch.randn(2, 16, 32))
        # 8 persistent tokens prepended to 16 input tokens -> 24 total.
        assert augmented.shape == (2, 24, 32)

    def test_strip_shape(self):
        pm = PersistentMemory(n_tokens=8, d_model=32)
        original = torch.randn(2, 16, 32)
        assert pm.strip(pm(original)).shape == original.shape

    def test_freeze_unfreeze(self):
        pm = PersistentMemory(n_tokens=4, d_model=8)
        pm.freeze()
        assert not pm.P.requires_grad
        pm.unfreeze()
        assert pm.P.requires_grad

    def test_gradient_through_persistent(self):
        # Backprop through the augmented sequence must reach the P tokens.
        pm = PersistentMemory(n_tokens=4, d_model=8)
        pm(torch.randn(1, 4, 8)).sum().backward()
        assert pm.P.grad is not None