nested-learning 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nested_learning/__init__.py +12 -0
- nested_learning/__main__.py +12 -0
- nested_learning/assoc_memory.py +23 -0
- nested_learning/backbones.py +147 -0
- nested_learning/capabilities.py +104 -0
- nested_learning/cli.py +253 -0
- nested_learning/cms.py +92 -0
- nested_learning/config_utils.py +50 -0
- nested_learning/configs/ablations/cms_sparse.yaml +46 -0
- nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
- nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
- nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
- nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
- nested_learning/configs/data/continual_segments_sample.yaml +9 -0
- nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
- nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
- nested_learning/configs/deepspeed/zero3.json +25 -0
- nested_learning/configs/hope/mid.yaml +118 -0
- nested_learning/configs/hope/mid_fsdp.yaml +47 -0
- nested_learning/configs/hope/pilot.yaml +2 -0
- nested_learning/configs/hope/pilot_attention.yaml +9 -0
- nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
- nested_learning/configs/hope/pilot_transformer.yaml +9 -0
- nested_learning/configs/hope/target.yaml +145 -0
- nested_learning/configs/hope/target_fsdp.yaml +47 -0
- nested_learning/configs/mid_smoke.yaml +99 -0
- nested_learning/configs/mid_stage2.yaml +110 -0
- nested_learning/configs/mid_stage2_smoke.yaml +102 -0
- nested_learning/configs/mid_titan_baseline.yaml +92 -0
- nested_learning/configs/pilot.yaml +127 -0
- nested_learning/configs/pilot_paper_faithful.yaml +42 -0
- nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
- nested_learning/configs/pilot_smoke.yaml +80 -0
- nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
- nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
- nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
- nested_learning/continual_classification.py +136 -0
- nested_learning/continual_streaming.py +283 -0
- nested_learning/data.py +153 -0
- nested_learning/device.py +21 -0
- nested_learning/eval_state.py +72 -0
- nested_learning/fast_state.py +108 -0
- nested_learning/functional.py +69 -0
- nested_learning/hope/__init__.py +0 -0
- nested_learning/hope/block.py +1973 -0
- nested_learning/hope/self_mod.py +40 -0
- nested_learning/instrumentation.py +38 -0
- nested_learning/levels.py +94 -0
- nested_learning/logging_utils.py +64 -0
- nested_learning/memorize.py +382 -0
- nested_learning/model.py +604 -0
- nested_learning/optim/__init__.py +0 -0
- nested_learning/optim/deep.py +102 -0
- nested_learning/optim/factory.py +13 -0
- nested_learning/optim/m3.py +121 -0
- nested_learning/optim/manager.py +151 -0
- nested_learning/titan/__init__.py +0 -0
- nested_learning/titan/memory.py +88 -0
- nested_learning/titan/model.py +412 -0
- nested_learning/titan/self_modifying.py +724 -0
- nested_learning/tokenizer.py +28 -0
- nested_learning/tokenizer_coverage.py +77 -0
- nested_learning/training.py +1600 -0
- nested_learning/transformer.py +104 -0
- nested_learning-0.2.0.dist-info/METADATA +390 -0
- nested_learning-0.2.0.dist-info/RECORD +76 -0
- nested_learning-0.2.0.dist-info/WHEEL +4 -0
- nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
- nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,724 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Callable
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
from torch import nn
|
|
9
|
+
from torch.func import grad, vmap
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass(frozen=True)
class SelfModifyingTitansConfig:
    """Hyperparameters for :class:`SelfModifyingTitans`.

    Validation runs in ``__post_init__``; an unset ``chunk_size_memory`` is resolved to
    ``chunk_size_other``, so after construction it is always an ``int``.
    """

    dim: int  # model width D; every memory reads D-dimensional inputs
    eta_scale: float = 1e-3  # multiplier applied to softplus(M_eta(x))
    chunk_size_other: int = 1  # update period (tokens) for the k/v/eta(/q/alpha) memories
    chunk_size_memory: int | None = None  # update period for M_memory; None -> chunk_size_other
    objective: str = "l2"  # inner loss: "l2" (squared error) or "dot" (negative dot product)
    stopgrad_vhat: bool = True  # detach the self-generated target vhat = M(v)
    use_rank1_precond: bool = True  # right-multiply weights by (alpha*I - eta*k k^T)
    use_alpha: bool = True  # gate with sigmoid(M_alpha(x)); otherwise alpha == 1
    momentum: float = 0.0  # momentum coefficient on inner gradients; 0 disables it
    qk_l2_norm: bool = True  # L2-normalise keys and queries
    adaptive_q: bool = False  # queries from fast memory M_q instead of the static w_q
    use_skip: bool = True  # forwarded to every ResidualMLPMemory
    local_conv_window: int | None = 4  # causal depthwise conv kernel size; None disables
    eps: float = 1e-6  # epsilon for F.normalize

    def __post_init__(self) -> None:
        """Reject invalid hyperparameters, then fill in ``chunk_size_memory``.

        Raises:
            ValueError: for the first violated constraint, checked in declaration order.
        """
        # Predicates are lambdas so they are evaluated lazily and in order, exactly like a
        # chain of ``if`` statements (e.g. ``int(local_conv_window)`` only runs if every
        # earlier check passed).
        rules = (
            (lambda: self.dim <= 0, "dim must be positive"),
            (lambda: self.eta_scale <= 0, "eta_scale must be positive"),
            (lambda: self.chunk_size_other <= 0, "chunk_size_other must be positive"),
            (
                lambda: self.chunk_size_memory is not None and self.chunk_size_memory <= 0,
                "chunk_size_memory must be positive",
            ),
            (
                lambda: self.objective not in {"l2", "dot"},
                "objective must be one of {'l2', 'dot'}",
            ),
            (lambda: not (0.0 <= self.momentum < 1.0), "momentum must be in [0, 1)"),
            (
                lambda: self.local_conv_window is not None and int(self.local_conv_window) <= 0,
                "local_conv_window must be positive",
            ),
        )
        for violated, message in rules:
            if violated():
                raise ValueError(message)

        if self.chunk_size_memory is None:
            # The dataclass is frozen, so bypass its __setattr__ guard.
            object.__setattr__(self, "chunk_size_memory", int(self.chunk_size_other))
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class ResidualMLPMemoryState:
    """Fast (per-context) weights of one residual MLP memory, plus optional momentum buffers.

    ``w1``/``w2``/``w_skip`` mirror the ``weight`` tensors of the corresponding
    ``ResidualMLPMemory`` linears; the ``m_*`` fields hold momentum buffers and stay
    ``None`` until momentum is first applied.
    """

    w1: torch.Tensor  # output projection (hidden -> out)
    w2: torch.Tensor  # input projection (in -> hidden)
    w_skip: torch.Tensor | None = None  # projected skip (in -> out), when present
    m_w1: torch.Tensor | None = None  # momentum buffer for w1
    m_w2: torch.Tensor | None = None  # momentum buffer for w2
    m_w_skip: torch.Tensor | None = None  # momentum buffer for w_skip

    def clone(self) -> "ResidualMLPMemoryState":
        """Return an independent copy whose tensors are detached from any autograd graph."""

        def _copy_optional(tensor: torch.Tensor | None) -> torch.Tensor | None:
            if tensor is None:
                return None
            return tensor.detach().clone()

        # w1/w2 are mandatory, so they are copied unconditionally.
        copied_w1 = self.w1.detach().clone()
        copied_w2 = self.w2.detach().clone()
        return ResidualMLPMemoryState(
            w1=copied_w1,
            w2=copied_w2,
            w_skip=_copy_optional(self.w_skip),
            m_w1=_copy_optional(self.m_w1),
            m_w2=_copy_optional(self.m_w2),
            m_w_skip=_copy_optional(self.m_w_skip),
        )
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass
class SelfModifyingTitansState:
    """
    Fast state for self-modifying Titans.

    Each memory M_□ is a residual MLP (Eq. 91) whose initial parameters are meta-learned
    (stored in the module) and cloned into this fast state per context.
    """

    k: ResidualMLPMemoryState  # key generator M_k
    v: ResidualMLPMemoryState  # value generator M_v
    q: ResidualMLPMemoryState  # query generator M_q (read only when adaptive_q is enabled)
    eta: ResidualMLPMemoryState  # learning-rate generator M_eta
    alpha: ResidualMLPMemoryState  # alpha generator M_alpha
    memory: ResidualMLPMemoryState  # main read/write memory M_memory

    def clone(self) -> "SelfModifyingTitansState":
        """Return a copy in which every per-memory state has been cloned via its ``clone()``."""
        # Same field order as the original explicit constructor call.
        field_names = ("k", "v", "q", "eta", "alpha", "memory")
        return SelfModifyingTitansState(
            **{field_name: getattr(self, field_name).clone() for field_name in field_names}
        )
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class ResidualMLPMemory(nn.Module):
    """Two-layer residual MLP memory: ``skip(x) + w1(act(w2(x)))``.

    The skip path is a learned projection ``w_skip`` when ``use_skip`` is set and the
    input/output widths differ. Otherwise it is the identity whenever the widths match.

    NOTE(review): with ``use_skip=False`` and ``in_dim == out_dim``, ``forward`` still adds
    the identity residual, because only the projected skip is gated by ``use_skip``.
    Confirm this is intended.
    """

    def __init__(
        self,
        *,
        in_dim: int,
        out_dim: int,
        hidden_dim: int,
        activation: Callable[[torch.Tensor], torch.Tensor],
        use_skip: bool = True,
    ) -> None:
        """Build the memory.

        Raises:
            ValueError: if any of the widths is not positive.
        """
        super().__init__()
        if min(in_dim, out_dim, hidden_dim) <= 0:
            raise ValueError("in_dim/out_dim/hidden_dim must be positive")
        self.in_dim, self.out_dim, self.hidden_dim = int(in_dim), int(out_dim), int(hidden_dim)
        self.activation = activation
        self.use_skip = bool(use_skip)
        # Creation order (w2, w1, w_skip) is kept fixed: it determines RNG consumption during
        # initialisation and the parameter/state_dict ordering.
        self.w2 = nn.Linear(self.in_dim, self.hidden_dim, bias=False)
        self.w1 = nn.Linear(self.hidden_dim, self.out_dim, bias=False)
        needs_projection = self.use_skip and self.in_dim != self.out_dim
        self.w_skip: nn.Linear | None = (
            nn.Linear(self.in_dim, self.out_dim, bias=False) if needs_projection else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Apply the residual MLP to the last dimension of ``x``."""
        mlp_out = self.w1(self.activation(self.w2(x)))
        if self.w_skip is not None:
            return self.w_skip(x) + mlp_out
        # Identity residual only when the widths line up.
        return x + mlp_out if mlp_out.shape[-1] == x.shape[-1] else mlp_out
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class SelfModifyingTitans(nn.Module):
    """
    Self-modifying Titans (Nested Learning paper, Eqs. 83–93), correctness-first.

    - Multiple memories: M_k, M_v, M_q, M_eta, M_alpha, M_memory.
    - Each memory is a 2-layer residual MLP (Eq. 91).
    - Updates are performed on fast state using chunked DGD-like rule (Eq. 90/93).

    Entry points:

    - ``forward``: read-only pass through the meta-learned (module) weights.
    - ``forward_with_state``: read-only pass through a fast state; gradients reach the
      module weights via a straight-through term.
    - ``forward_with_updates``: chunked read + write pass over a fast state (no autograd).
    - ``apply_updates_inplace``: run ``forward_with_updates`` from the module weights and
      copy the batch-averaged result back into the module parameters.

    Note: This implementation prioritizes semantic fidelity and testability over speed.
    """

    def __init__(self, config: SelfModifyingTitansConfig):
        super().__init__()
        self.config = config
        dim = config.dim
        hidden = dim
        # _memory_forward hard-codes F.gelu, so the module memories must use it too.
        act = F.gelu
        self.local_conv: nn.Conv1d | None = None
        if config.local_conv_window is not None:
            window = int(config.local_conv_window)
            # Depthwise (groups=dim) conv without padding; causal left padding is added
            # explicitly in _apply_local_conv.
            self.local_conv = nn.Conv1d(
                dim,
                dim,
                kernel_size=window,
                groups=dim,
                padding=0,
                bias=False,
            )
        # Creation order below is kept fixed (RNG consumption / parameter ordering).
        self.w_q = nn.Linear(dim, dim, bias=False)
        self.m_k = ResidualMLPMemory(
            in_dim=dim, out_dim=dim, hidden_dim=hidden, activation=act, use_skip=config.use_skip
        )
        self.m_v = ResidualMLPMemory(
            in_dim=dim, out_dim=dim, hidden_dim=hidden, activation=act, use_skip=config.use_skip
        )
        self.m_q = ResidualMLPMemory(
            in_dim=dim, out_dim=dim, hidden_dim=hidden, activation=act, use_skip=config.use_skip
        )
        # eta / alpha memories emit a single channel per token.
        self.m_eta = ResidualMLPMemory(
            in_dim=dim, out_dim=1, hidden_dim=hidden, activation=act, use_skip=config.use_skip
        )
        self.m_alpha = ResidualMLPMemory(
            in_dim=dim, out_dim=1, hidden_dim=hidden, activation=act, use_skip=config.use_skip
        )
        self.m_memory = ResidualMLPMemory(
            in_dim=dim, out_dim=dim, hidden_dim=hidden, activation=act, use_skip=config.use_skip
        )

    def init_fast_state(self) -> SelfModifyingTitansState:
        """Snapshot the current module weights into a fresh, unbatched (2D) fast state."""
        return SelfModifyingTitansState(
            k=self._init_memory_state(self.m_k),
            v=self._init_memory_state(self.m_v),
            q=self._init_memory_state(self.m_q),
            eta=self._init_memory_state(self.m_eta),
            alpha=self._init_memory_state(self.m_alpha),
            memory=self._init_memory_state(self.m_memory),
        )

    def apply_updates_inplace(
        self,
        x: torch.Tensor,
        *,
        chunk_size_other: int | None = None,
        chunk_size_memory: int | None = None,
    ) -> None:
        """
        Apply the self-modifying update rule to the *module parameters* in-place.

        This is intended to be called in an explicit "update pass" under `torch.no_grad()`
        (e.g., after an outer backward), so we avoid mixing differentiable reads with
        in-place writes during the same autograd graph.

        The fast state starts from the current module weights. After the update pass, the
        per-example states are averaged over the batch and copied back.
        """
        state = self.init_fast_state()
        _out, updated = self.forward_with_updates(
            x,
            state,
            chunk_size_other=chunk_size_other,
            chunk_size_memory=chunk_size_memory,
        )
        self._load_state_mean_(updated)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Read from M_memory using the module (meta) weights; no fast state involved."""
        x = self._apply_local_conv(x)
        q = self.m_q(x) if self.config.adaptive_q else self.w_q(x)
        if self.config.qk_l2_norm:
            q = F.normalize(q, dim=-1, eps=self.config.eps)
        return self.m_memory(q)

    def forward_with_state(
        self,
        x: torch.Tensor,
        state: SelfModifyingTitansState,
    ) -> torch.Tensor:
        """Read from a fast state without updating it.

        The fast weights are combined with the module weights through
        ``_straight_through_meta``, so gradients flow to the meta parameters.

        Raises:
            ValueError: if ``x`` is not (B, T, D) with D == config.dim, or the state's batch
                size does not match.
        """
        if x.ndim != 3:
            raise ValueError("Expected x to have shape (B, T, D)")
        batch, _seq_len, dim = x.shape
        if dim != self.config.dim:
            raise ValueError(f"Expected dim={self.config.dim}, got {dim}")
        state = self._ensure_batched_state(state, batch)
        x = self._apply_local_conv(x)
        q = (
            self._memory_forward(x, state.q, meta=self.m_q)
            if self.config.adaptive_q
            else self.w_q(x)
        )
        if self.config.qk_l2_norm:
            q = F.normalize(q, dim=-1, eps=self.config.eps)
        return self._memory_forward(q, state.memory, meta=self.m_memory)

    def forward_with_updates(
        self,
        x: torch.Tensor,
        state: SelfModifyingTitansState,
        *,
        chunk_size_other: int | None = None,
        chunk_size_memory: int | None = None,
    ) -> tuple[torch.Tensor, SelfModifyingTitansState]:
        """Process ``x`` chunk by chunk, reading from and writing to the fast state.

        The k/v/eta memories (plus q when ``adaptive_q`` is set, and alpha when
        ``use_alpha`` is set) are updated at every ``chunk_size_other`` boundary. M_memory
        is updated at every ``chunk_size_memory`` boundary. Outputs for tokens inside one
        memory chunk are all read from the M_memory state as of the start of that chunk.
        A trailing partial chunk is flushed at the end of the sequence.

        The whole pass runs under ``torch.no_grad()``, so the returned outputs carry no
        autograd graph.

        Args:
            x: input of shape (B, T, D) with D == config.dim.
            state: unbatched (2D weights) or batched (3D weights, leading dim B) fast state.
                A 2D state is expanded into a fresh batched copy. A batched state is
                updated in place, and the same object is returned.
            chunk_size_other: per-call override of ``config.chunk_size_other``.
            chunk_size_memory: per-call override of ``config.chunk_size_memory``.

        Returns:
            ``(outputs, state)``. ``outputs`` has shape (B, T, D); for T == 0 it is an empty
            (B, 0, D) tensor.

        Raises:
            ValueError: on wrong input rank/width, batch mismatch, or non-positive chunk
                sizes.
        """
        if x.ndim != 3:
            raise ValueError("Expected x to have shape (B, T, D)")
        batch, seq_len, dim = x.shape
        if dim != self.config.dim:
            raise ValueError(f"Expected dim={self.config.dim}, got {dim}")
        state = self._ensure_batched_state(state, batch)
        other_chunk = int(
            self.config.chunk_size_other if chunk_size_other is None else chunk_size_other
        )
        memory_chunk_cfg = self.config.chunk_size_memory
        if memory_chunk_cfg is None:
            memory_chunk_cfg = self.config.chunk_size_other
        memory_chunk = int(memory_chunk_cfg if chunk_size_memory is None else chunk_size_memory)
        if other_chunk <= 0 or memory_chunk <= 0:
            raise ValueError("chunk sizes must be positive")
        if seq_len == 0:
            # Nothing to read or write (torch.cat below would reject an empty list).
            return x.new_zeros(batch, 0, dim), state

        other_names = self._other_memory_names()
        signal_keys = ("k", "v", "eta", "alpha")
        # Pending per-token signals since the last boundary of each schedule.
        other_buf: dict[str, list[torch.Tensor]] = {key: [] for key in signal_keys}
        memory_buf: dict[str, list[torch.Tensor]] = {key: [] for key in signal_keys}
        outputs: list[torch.Tensor] = []

        def _next_boundary(idx: int, chunk_size: int) -> int:
            # First multiple of chunk_size strictly after idx, clamped to the sequence end.
            return min(((idx // chunk_size) + 1) * chunk_size, seq_len)

        def _flush(buf: dict[str, list[torch.Tensor]], memories: tuple[str, ...]) -> None:
            # Apply one chunked update from the buffered signals, then clear the buffer.
            if not buf["k"]:
                return
            self._apply_chunk_update_seq(
                state,
                k_seq=torch.cat(buf["k"], dim=1),
                v_seq=torch.cat(buf["v"], dim=1),
                eta_seq=torch.cat(buf["eta"], dim=1),
                alpha_seq=torch.cat(buf["alpha"], dim=1),
                memories=memories,
            )
            for pending in buf.values():
                pending.clear()

        with torch.no_grad():
            x = self._apply_local_conv(x)
            idx = 0
            while idx < seq_len:
                next_other = _next_boundary(idx, other_chunk)
                next_memory = _next_boundary(idx, memory_chunk)
                # Advance to whichever schedule hits its boundary first.
                end = min(next_other, next_memory, seq_len)
                x_chunk = x[:, idx:end, :]

                k_chunk, v_chunk, q_chunk, eta_chunk, alpha_chunk = self._chunk_signals(
                    x_chunk, state
                )
                # Read before write: outputs use the state prior to this span's update.
                outputs.append(self._memory_forward(q_chunk, state.memory))

                signals = {"k": k_chunk, "v": v_chunk, "eta": eta_chunk, "alpha": alpha_chunk}
                for buf in (other_buf, memory_buf):
                    for key, value in signals.items():
                        buf[key].append(value)

                idx = end
                if idx == next_other:
                    _flush(other_buf, other_names)
                if idx == next_memory:
                    _flush(memory_buf, ("memory",))

            # Boundaries are clamped to seq_len, so both buffers are normally empty here;
            # flushing again is a cheap safeguard.
            _flush(other_buf, other_names)
            _flush(memory_buf, ("memory",))

        return torch.cat(outputs, dim=1), state

    def _other_memory_names(self) -> tuple[str, ...]:
        """Names of the fast memories updated on the ``chunk_size_other`` schedule."""
        names: tuple[str, ...] = ("k", "v", "eta")
        if self.config.adaptive_q:
            names = (*names, "q")
        if self.config.use_alpha:
            names = (*names, "alpha")
        return names

    def _chunk_signals(
        self,
        x_chunk: torch.Tensor,
        state: SelfModifyingTitansState,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute (k, v, q, eta, alpha) for a chunk from the current fast state.

        k/v/q keep the (B, T, D) layout. eta and alpha are squeezed to (B, T) per-token
        scalars.
        """
        k_chunk = self._memory_forward(x_chunk, state.k)
        v_chunk = self._memory_forward(x_chunk, state.v)
        q_chunk = (
            self._memory_forward(x_chunk, state.q)
            if self.config.adaptive_q
            else self.w_q(x_chunk)
        )
        if self.config.qk_l2_norm:
            k_chunk = F.normalize(k_chunk, dim=-1, eps=self.config.eps)
            q_chunk = F.normalize(q_chunk, dim=-1, eps=self.config.eps)
        # Positive, scaled learning rate per token.
        eta_chunk = self._memory_forward(x_chunk, state.eta).squeeze(-1)
        eta_chunk = F.softplus(eta_chunk) * self.config.eta_scale
        if self.config.use_alpha:
            # alpha in (0, 1) multiplies the previous weights in the update rule.
            alpha_chunk = torch.sigmoid(self._memory_forward(x_chunk, state.alpha).squeeze(-1))
        else:
            alpha_chunk = torch.ones_like(eta_chunk)
        return k_chunk, v_chunk, q_chunk, eta_chunk, alpha_chunk

    def _apply_local_conv(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the causal depthwise conv over the time axis (identity when disabled).

        Raises:
            ValueError: if the conv is enabled and ``x`` is not 3D (B, T, D).
        """
        if self.local_conv is None:
            return x
        if x.ndim != 3:
            raise ValueError("Expected x to have shape (B, T, D)")
        if x.size(1) == 0:
            # An empty sequence stays empty. Conv1d would reject the padded length
            # (kernel - 1), which is shorter than the kernel.
            return x
        kernel = int(self.local_conv.kernel_size[0])
        # Causal depthwise conv: pad only on the left so each step sees past tokens only.
        x_t = F.pad(x.transpose(1, 2), (kernel - 1, 0))
        return self.local_conv(x_t).transpose(1, 2)

    def _load_state_mean_(self, state: SelfModifyingTitansState) -> None:
        """Copy batch-averaged fast weights back into the module parameters (no grad).

        Only the memories that ``forward_with_updates`` actually updates are copied: q
        only when ``adaptive_q`` is set, and alpha only when ``use_alpha`` is set.
        """

        def _mean_weight(weight: torch.Tensor) -> torch.Tensor:
            # Batched (3D) weights are averaged over the leading batch dimension.
            return weight.mean(dim=0) if weight.ndim == 3 else weight

        def _copy(module: ResidualMLPMemory, mem: ResidualMLPMemoryState) -> None:
            module.w1.weight.copy_(_mean_weight(mem.w1))
            module.w2.weight.copy_(_mean_weight(mem.w2))
            if module.w_skip is None:
                return
            if mem.w_skip is None:
                raise RuntimeError("Expected w_skip state for projected residual memory")
            module.w_skip.weight.copy_(_mean_weight(mem.w_skip))

        with torch.no_grad():
            _copy(self.m_k, state.k)
            _copy(self.m_v, state.v)
            _copy(self.m_eta, state.eta)
            if self.config.use_alpha:
                _copy(self.m_alpha, state.alpha)
            _copy(self.m_memory, state.memory)
            if self.config.adaptive_q:
                _copy(self.m_q, state.q)

    def _apply_chunk_update(
        self,
        state: SelfModifyingTitansState,
        buffer: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
        *,
        memories: tuple[str, ...],
    ) -> None:
        """Per-token-buffer variant of ``_apply_chunk_update_seq``.

        Each buffer item is ``(k_t, v_t, eta_t, alpha_t)`` for a single token. The items
        are stacked along a new time dimension (dim=1) before delegating.
        """
        if not buffer:
            return
        k_seq = torch.stack([item[0] for item in buffer], dim=1)
        v_seq = torch.stack([item[1] for item in buffer], dim=1)
        eta_seq = torch.stack([item[2] for item in buffer], dim=1)
        alpha_seq = torch.stack([item[3] for item in buffer], dim=1)
        self._apply_chunk_update_seq(
            state,
            k_seq=k_seq,
            v_seq=v_seq,
            eta_seq=eta_seq,
            alpha_seq=alpha_seq,
            memories=memories,
        )

    def _apply_chunk_update_seq(
        self,
        state: SelfModifyingTitansState,
        *,
        k_seq: torch.Tensor,
        v_seq: torch.Tensor,
        eta_seq: torch.Tensor,
        alpha_seq: torch.Tensor,
        memories: tuple[str, ...],
    ) -> None:
        """Apply one chunked DGD-style update to each named memory of ``state``.

        Per-token gradients are evaluated once at the chunk-boundary snapshot of each
        memory (parallel within the chunk). The per-token updates are then applied
        sequentially. ``k_seq``/``v_seq`` are (B, T, D); ``eta_seq``/``alpha_seq`` are
        (B, T).
        """
        steps = k_seq.size(1)
        eye: torch.Tensor | None = None
        if self.config.use_rank1_precond:
            # Only needed for the (alpha*I - eta*k k^T) preconditioner.
            dim = self.config.dim
            eye = (
                torch.eye(dim, device=k_seq.device, dtype=k_seq.dtype)
                .unsqueeze(0)
                .expand(k_seq.size(0), -1, -1)
            )

        boundary: dict[str, ResidualMLPMemoryState] = {
            name: getattr(state, name).clone() for name in memories
        }
        grads = {name: self._memory_grads_chunk(boundary[name], k_seq, v_seq) for name in memories}

        for t in range(steps):
            k_t = k_seq[:, t, :]
            eta_t = eta_seq[:, t]
            alpha_t = alpha_seq[:, t]
            precond: torch.Tensor | None = None
            if eye is not None:
                kk = torch.einsum("bi,bj->bij", k_t, k_t)
                precond = alpha_t[:, None, None] * eye - eta_t[:, None, None] * kk
            for name in memories:
                g1, g2, gskip = grads[name]
                self._apply_param_update(
                    getattr(state, name),
                    (
                        g1[:, t, ...],
                        g2[:, t, ...],
                        None if gskip is None else gskip[:, t, ...],
                    ),
                    eta_t,
                    alpha_t,
                    precond,
                )

    def _token_loss(
        self,
        mem: ResidualMLPMemoryState,
        k_t: torch.Tensor,
        v_t: torch.Tensor,
    ) -> torch.Tensor:
        """Scalar self-referential inner loss comparing ``mem(k)`` to ``vhat = mem(v)``.

        ``objective == "dot"`` uses the negative dot product; otherwise it is the squared
        L2 error. Per-token losses are summed into one scalar.
        """
        pred = self._memory_forward(k_t, mem)
        vhat = self._memory_forward(v_t, mem)
        if self.config.stopgrad_vhat:
            vhat = vhat.detach()
        if self.config.objective == "dot":
            loss = -(pred * vhat).sum(dim=-1)
        else:
            loss = (pred - vhat).pow(2).sum(dim=-1)
        return loss.sum()

    def _memory_grads(
        self,
        frozen: ResidualMLPMemoryState,
        k_t: torch.Tensor,
        v_t: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
        """Gradient of the inner loss w.r.t. (w1, w2, w_skip) via ``torch.autograd.grad``.

        Returns ``None`` for the skip gradient when the memory has no projected skip.
        """
        with torch.enable_grad():
            w1 = frozen.w1.detach().requires_grad_(True)
            w2 = frozen.w2.detach().requires_grad_(True)
            w_skip = None
            if frozen.w_skip is not None:
                w_skip = frozen.w_skip.detach().requires_grad_(True)

            mem = ResidualMLPMemoryState(w1=w1, w2=w2, w_skip=w_skip)
            loss_scalar = self._token_loss(mem, k_t, v_t)

            grads = torch.autograd.grad(
                loss_scalar,
                (w1, w2, w_skip) if w_skip is not None else (w1, w2),
                retain_graph=False,
                create_graph=False,
                allow_unused=False,
            )
        if w_skip is None:
            g1, g2 = grads
            return g1, g2, None
        g1, g2, gskip = grads
        return g1, g2, gskip

    def _memory_grads_chunk(
        self,
        frozen: ResidualMLPMemoryState,
        k_seq: torch.Tensor,
        v_seq: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
        """
        Compute per-token gradients for an entire chunk in parallel (paper §8.2).

        Returns gradients with leading shape (B, T, ...).
        """
        w1 = frozen.w1.detach()
        w2 = frozen.w2.detach()
        w_skip = None if frozen.w_skip is None else frozen.w_skip.detach()

        # vmap maps over dim 0, so put the token axis first: (T, B, D).
        k_tokens = k_seq.transpose(0, 1)
        v_tokens = v_seq.transpose(0, 1)

        if w_skip is None:

            def loss_fn_noskip(
                w1_t: torch.Tensor,
                w2_t: torch.Tensor,
                k_t: torch.Tensor,
                v_t: torch.Tensor,
            ) -> torch.Tensor:
                return self._token_loss(ResidualMLPMemoryState(w1=w1_t, w2=w2_t), k_t, v_t)

            grad_fn = grad(loss_fn_noskip, argnums=(0, 1))
            g1_tokens, g2_tokens = vmap(grad_fn, in_dims=(None, None, 0, 0))(
                w1,
                w2,
                k_tokens,
                v_tokens,
            )
            return g1_tokens.transpose(0, 1), g2_tokens.transpose(0, 1), None

        def loss_fn_skip(
            w1_t: torch.Tensor,
            w2_t: torch.Tensor,
            w_skip_t: torch.Tensor,
            k_t: torch.Tensor,
            v_t: torch.Tensor,
        ) -> torch.Tensor:
            mem = ResidualMLPMemoryState(w1=w1_t, w2=w2_t, w_skip=w_skip_t)
            return self._token_loss(mem, k_t, v_t)

        grad_fn = grad(loss_fn_skip, argnums=(0, 1, 2))
        g1_tokens, g2_tokens, gskip_tokens = vmap(
            grad_fn,
            in_dims=(None, None, None, 0, 0),
        )(w1, w2, w_skip, k_tokens, v_tokens)
        return (
            g1_tokens.transpose(0, 1),
            g2_tokens.transpose(0, 1),
            gskip_tokens.transpose(0, 1),
        )

    def _apply_param_update(
        self,
        fast: ResidualMLPMemoryState,
        grads: tuple[torch.Tensor, torch.Tensor, torch.Tensor | None],
        eta_t: torch.Tensor,
        alpha_t: torch.Tensor,
        precond: torch.Tensor | None,
    ) -> None:
        """One per-token update of ``fast`` (its tensor attributes are reassigned).

        With ``use_rank1_precond`` enabled, w2 and w_skip use ``W @ precond - eta * g``.
        Otherwise they use ``alpha * W - eta * g``. w1 always uses
        ``alpha * w1 - eta * g1``. Gradients pass through momentum first when
        ``config.momentum > 0``.

        Raises:
            RuntimeError: if ``precond`` is missing while ``use_rank1_precond`` is enabled,
                or if the skip gradient is missing for a memory with ``w_skip``.
        """
        use_precond = self.config.use_rank1_precond
        if use_precond and precond is None:
            raise RuntimeError("Expected precond when use_rank1_precond is enabled")
        g1, g2, gskip = grads
        eta = eta_t[:, None, None]
        alpha = alpha_t[:, None, None]

        g1 = self._apply_momentum(fast, "m_w1", g1)
        g2 = self._apply_momentum(fast, "m_w2", g2)
        if use_precond:
            fast.w2 = torch.matmul(fast.w2, precond) - eta * g2
        else:
            fast.w2 = alpha * fast.w2 - eta * g2
        fast.w1 = alpha * fast.w1 - eta * g1

        if fast.w_skip is None:
            return
        if gskip is None:
            raise RuntimeError("Expected w_skip grad to be present")
        gskip = self._apply_momentum(fast, "m_w_skip", gskip)
        if use_precond:
            fast.w_skip = torch.matmul(fast.w_skip, precond) - eta * gskip
        else:
            fast.w_skip = alpha * fast.w_skip - eta * gskip

    def _apply_momentum(
        self,
        fast: ResidualMLPMemoryState,
        attr_name: str,
        grad: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``beta * buf + grad`` and store it back in ``fast.<attr_name>``.

        Returns ``grad`` unchanged (and stores nothing) when momentum is disabled. A missing
        buffer starts at zero.
        """
        beta = float(self.config.momentum)
        if beta <= 0.0:
            return grad
        buf = getattr(fast, attr_name)
        if buf is None:
            buf = torch.zeros_like(grad)
        buf = beta * buf + grad
        setattr(fast, attr_name, buf)
        return buf

    def _init_memory_state(self, module: ResidualMLPMemory) -> ResidualMLPMemoryState:
        """Detached copy of a module memory's weights (no momentum buffers)."""
        skip = None if module.w_skip is None else module.w_skip.weight.detach().clone()
        return ResidualMLPMemoryState(
            w1=module.w1.weight.detach().clone(),
            w2=module.w2.weight.detach().clone(),
            w_skip=skip,
        )

    def _ensure_batched_state(
        self, state: SelfModifyingTitansState, batch: int
    ) -> SelfModifyingTitansState:
        """Return a state with 3D (batch-leading) weights.

        A 2D state is expanded into a new batched copy. A 3D state is validated and
        returned as-is, so later updates mutate the caller's object.

        Raises:
            ValueError: if the weights are neither 2D nor 3D, or the batch size mismatches.
        """
        if state.k.w1.ndim == 2:
            return SelfModifyingTitansState(
                k=self._expand_memory_state(state.k, batch),
                v=self._expand_memory_state(state.v, batch),
                q=self._expand_memory_state(state.q, batch),
                eta=self._expand_memory_state(state.eta, batch),
                alpha=self._expand_memory_state(state.alpha, batch),
                memory=self._expand_memory_state(state.memory, batch),
            )
        if state.k.w1.ndim != 3:
            raise ValueError("SelfModifyingTitansState weights must be 2D or 3D tensors")
        if state.k.w1.size(0) != batch:
            raise ValueError(
                f"State batch mismatch: expected batch={batch}, got {state.k.w1.size(0)}"
            )
        return state

    def _expand_memory_state(
        self, mem: ResidualMLPMemoryState, batch: int
    ) -> ResidualMLPMemoryState:
        """Repeat every 2D tensor of ``mem`` ``batch`` times along a new leading dim."""

        def _expand(t: torch.Tensor) -> torch.Tensor:
            # repeat (not expand) so each batch element owns independent storage.
            return t.detach().clone().unsqueeze(0).repeat(batch, 1, 1)

        def _expand_opt(t: torch.Tensor | None) -> torch.Tensor | None:
            return None if t is None else _expand(t)

        return ResidualMLPMemoryState(
            w1=_expand(mem.w1),
            w2=_expand(mem.w2),
            w_skip=_expand_opt(mem.w_skip),
            m_w1=_expand_opt(mem.m_w1),
            m_w2=_expand_opt(mem.m_w2),
            m_w_skip=_expand_opt(mem.m_w_skip),
        )

    def _memory_forward(
        self,
        x: torch.Tensor,
        mem: ResidualMLPMemoryState,
        *,
        meta: ResidualMLPMemory | None = None,
    ) -> torch.Tensor:
        """Evaluate a residual MLP memory from explicit (possibly batched) weights.

        Mirrors ``ResidualMLPMemory.forward`` with F.gelu as the activation. A 2D ``x`` is
        treated as a single time step. When ``meta`` is given, gradients additionally flow
        to the module weights via ``_straight_through_meta``.

        Raises:
            RuntimeError: if the state has ``w_skip`` but ``meta`` does not.
        """
        if meta is None:
            w2 = mem.w2
            w1 = mem.w1
            w_skip = mem.w_skip
        else:
            w2 = self._straight_through_meta(mem.w2, meta.w2.weight)
            w1 = self._straight_through_meta(mem.w1, meta.w1.weight)
            w_skip = None
            if mem.w_skip is not None:
                if meta.w_skip is None:
                    raise RuntimeError("Expected meta w_skip for projected residual memory")
                w_skip = self._straight_through_meta(mem.w_skip, meta.w_skip.weight)
        if x.ndim == 2:
            x_seq = x.unsqueeze(1)
            squeeze = True
        else:
            x_seq = x
            squeeze = False
        hidden = F.gelu(torch.matmul(x_seq, w2.transpose(-1, -2)))
        out = torch.matmul(hidden, w1.transpose(-1, -2))
        if w_skip is not None:
            out = out + torch.matmul(x_seq, w_skip.transpose(-1, -2))
        elif out.size(-1) == x_seq.size(-1):
            out = out + x_seq
        if squeeze:
            return out.squeeze(1)
        return out

    @staticmethod
    def _straight_through_meta(fast: torch.Tensor, meta: torch.Tensor) -> torch.Tensor:
        """Return a tensor numerically equal to ``fast`` whose gradient also reaches ``meta``.

        ``meta`` is unsqueezed on the left until it matches ``fast``'s rank, then
        ``expanded - expanded.detach()`` (numerically zero) is added.

        Raises:
            ValueError: if ``meta`` has a higher rank than ``fast``.
        """
        if meta.ndim > fast.ndim:
            raise ValueError("meta tensor must have <= fast tensor rank")
        expanded = meta
        while expanded.ndim < fast.ndim:
            expanded = expanded.unsqueeze(0)
        return fast + (expanded - expanded.detach())
|