nested-learning 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76):
  1. nested_learning/__init__.py +12 -0
  2. nested_learning/__main__.py +12 -0
  3. nested_learning/assoc_memory.py +23 -0
  4. nested_learning/backbones.py +147 -0
  5. nested_learning/capabilities.py +104 -0
  6. nested_learning/cli.py +253 -0
  7. nested_learning/cms.py +92 -0
  8. nested_learning/config_utils.py +50 -0
  9. nested_learning/configs/ablations/cms_sparse.yaml +46 -0
  10. nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
  11. nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
  12. nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
  13. nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
  14. nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
  15. nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
  16. nested_learning/configs/data/continual_segments_sample.yaml +9 -0
  17. nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
  18. nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
  19. nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
  20. nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
  21. nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
  22. nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
  23. nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
  24. nested_learning/configs/deepspeed/zero3.json +25 -0
  25. nested_learning/configs/hope/mid.yaml +118 -0
  26. nested_learning/configs/hope/mid_fsdp.yaml +47 -0
  27. nested_learning/configs/hope/pilot.yaml +2 -0
  28. nested_learning/configs/hope/pilot_attention.yaml +9 -0
  29. nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
  30. nested_learning/configs/hope/pilot_transformer.yaml +9 -0
  31. nested_learning/configs/hope/target.yaml +145 -0
  32. nested_learning/configs/hope/target_fsdp.yaml +47 -0
  33. nested_learning/configs/mid_smoke.yaml +99 -0
  34. nested_learning/configs/mid_stage2.yaml +110 -0
  35. nested_learning/configs/mid_stage2_smoke.yaml +102 -0
  36. nested_learning/configs/mid_titan_baseline.yaml +92 -0
  37. nested_learning/configs/pilot.yaml +127 -0
  38. nested_learning/configs/pilot_paper_faithful.yaml +42 -0
  39. nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
  40. nested_learning/configs/pilot_smoke.yaml +80 -0
  41. nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
  42. nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
  43. nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
  44. nested_learning/continual_classification.py +136 -0
  45. nested_learning/continual_streaming.py +283 -0
  46. nested_learning/data.py +153 -0
  47. nested_learning/device.py +21 -0
  48. nested_learning/eval_state.py +72 -0
  49. nested_learning/fast_state.py +108 -0
  50. nested_learning/functional.py +69 -0
  51. nested_learning/hope/__init__.py +0 -0
  52. nested_learning/hope/block.py +1973 -0
  53. nested_learning/hope/self_mod.py +40 -0
  54. nested_learning/instrumentation.py +38 -0
  55. nested_learning/levels.py +94 -0
  56. nested_learning/logging_utils.py +64 -0
  57. nested_learning/memorize.py +382 -0
  58. nested_learning/model.py +604 -0
  59. nested_learning/optim/__init__.py +0 -0
  60. nested_learning/optim/deep.py +102 -0
  61. nested_learning/optim/factory.py +13 -0
  62. nested_learning/optim/m3.py +121 -0
  63. nested_learning/optim/manager.py +151 -0
  64. nested_learning/titan/__init__.py +0 -0
  65. nested_learning/titan/memory.py +88 -0
  66. nested_learning/titan/model.py +412 -0
  67. nested_learning/titan/self_modifying.py +724 -0
  68. nested_learning/tokenizer.py +28 -0
  69. nested_learning/tokenizer_coverage.py +77 -0
  70. nested_learning/training.py +1600 -0
  71. nested_learning/transformer.py +104 -0
  72. nested_learning-0.2.0.dist-info/METADATA +390 -0
  73. nested_learning-0.2.0.dist-info/RECORD +76 -0
  74. nested_learning-0.2.0.dist-info/WHEEL +4 -0
  75. nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
  76. nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,724 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Callable
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+ from torch.func import grad, vmap
10
+
11
+
12
+ @dataclass(frozen=True)
13
+ class SelfModifyingTitansConfig:
14
+ dim: int
15
+ eta_scale: float = 1e-3
16
+ chunk_size_other: int = 1
17
+ chunk_size_memory: int | None = None
18
+ objective: str = "l2"
19
+ stopgrad_vhat: bool = True
20
+ use_rank1_precond: bool = True
21
+ use_alpha: bool = True
22
+ momentum: float = 0.0
23
+ qk_l2_norm: bool = True
24
+ adaptive_q: bool = False
25
+ use_skip: bool = True
26
+ local_conv_window: int | None = 4
27
+ eps: float = 1e-6
28
+
29
+ def __post_init__(self) -> None:
30
+ if self.dim <= 0:
31
+ raise ValueError("dim must be positive")
32
+ if self.eta_scale <= 0:
33
+ raise ValueError("eta_scale must be positive")
34
+ if self.chunk_size_other <= 0:
35
+ raise ValueError("chunk_size_other must be positive")
36
+ if self.chunk_size_memory is not None and self.chunk_size_memory <= 0:
37
+ raise ValueError("chunk_size_memory must be positive")
38
+ if self.objective not in {"l2", "dot"}:
39
+ raise ValueError("objective must be one of {'l2', 'dot'}")
40
+ if not (0.0 <= self.momentum < 1.0):
41
+ raise ValueError("momentum must be in [0, 1)")
42
+ if self.local_conv_window is not None and int(self.local_conv_window) <= 0:
43
+ raise ValueError("local_conv_window must be positive")
44
+ if self.chunk_size_memory is None:
45
+ object.__setattr__(self, "chunk_size_memory", int(self.chunk_size_other))
46
+
47
+
48
+ @dataclass
49
+ class ResidualMLPMemoryState:
50
+ w1: torch.Tensor
51
+ w2: torch.Tensor
52
+ w_skip: torch.Tensor | None = None
53
+ m_w1: torch.Tensor | None = None
54
+ m_w2: torch.Tensor | None = None
55
+ m_w_skip: torch.Tensor | None = None
56
+
57
+ def clone(self) -> "ResidualMLPMemoryState":
58
+ return ResidualMLPMemoryState(
59
+ w1=self.w1.detach().clone(),
60
+ w2=self.w2.detach().clone(),
61
+ w_skip=None if self.w_skip is None else self.w_skip.detach().clone(),
62
+ m_w1=None if self.m_w1 is None else self.m_w1.detach().clone(),
63
+ m_w2=None if self.m_w2 is None else self.m_w2.detach().clone(),
64
+ m_w_skip=None if self.m_w_skip is None else self.m_w_skip.detach().clone(),
65
+ )
66
+
67
+
68
@dataclass
class SelfModifyingTitansState:
    """
    Fast state for self-modifying Titans.

    Holds one fast residual-MLP memory (Eq. 91) per role. The initial weights are
    meta-learned (they live in the module) and are cloned into this state once per
    context; the update rule then rewrites these copies.
    """

    k: ResidualMLPMemoryState  # key memory M_k
    v: ResidualMLPMemoryState  # value memory M_v
    q: ResidualMLPMemoryState  # query memory M_q (read only when adaptive_q is enabled)
    eta: ResidualMLPMemoryState  # step-size memory M_eta (scalar output per token)
    alpha: ResidualMLPMemoryState  # gate memory M_alpha (scalar output per token)
    memory: ResidualMLPMemoryState  # main memory M_memory, read with queries

    def clone(self) -> "SelfModifyingTitansState":
        """Return a copy in which every memory has been cloned (detached tensors)."""
        roles = ("k", "v", "q", "eta", "alpha", "memory")
        copied = {role: getattr(self, role).clone() for role in roles}
        return SelfModifyingTitansState(**copied)
93
+
94
+
95
class ResidualMLPMemory(nn.Module):
    """
    Two-layer bias-free MLP memory with a residual connection (Eq. 91).

    Output is ``w1(activation(w2(x)))`` plus a residual term:

    - ``w_skip(x)`` (a learned projection) when ``use_skip`` is set and
      ``in_dim != out_dim``;
    - otherwise ``x`` itself whenever the output width equals the input width;
    - otherwise nothing.

    NOTE(review): ``use_skip=False`` does not disable the identity residual when
    ``in_dim == out_dim``; it only suppresses the projected skip. This mirrors
    ``SelfModifyingTitans._memory_forward`` — confirm it is intended.
    """

    def __init__(
        self,
        *,
        in_dim: int,
        out_dim: int,
        hidden_dim: int,
        activation: Callable[[torch.Tensor], torch.Tensor],
        use_skip: bool = True,
    ) -> None:
        super().__init__()
        if any(size <= 0 for size in (in_dim, out_dim, hidden_dim)):
            raise ValueError("in_dim/out_dim/hidden_dim must be positive")
        self.in_dim, self.out_dim, self.hidden_dim = int(in_dim), int(out_dim), int(hidden_dim)
        self.activation = activation
        self.use_skip = bool(use_skip)
        # Construction order (w2, w1, w_skip) fixes parameter registration and RNG use.
        self.w2 = nn.Linear(self.in_dim, self.hidden_dim, bias=False)  # input -> hidden
        self.w1 = nn.Linear(self.hidden_dim, self.out_dim, bias=False)  # hidden -> output
        wants_projection = self.use_skip and self.in_dim != self.out_dim
        self.w_skip: nn.Linear | None = (
            nn.Linear(self.in_dim, self.out_dim, bias=False) if wants_projection else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Apply the MLP to the last dimension of ``x`` and add the residual path."""
        mlp_out = self.w1(self.activation(self.w2(x)))
        if self.w_skip is not None:
            return self.w_skip(x) + mlp_out
        # Identity residual only when the widths line up.
        return x + mlp_out if mlp_out.shape[-1] == x.shape[-1] else mlp_out
127
+
128
+
129
class SelfModifyingTitans(nn.Module):
    """
    Self-modifying Titans (Nested Learning paper, Eqs. 83–93), correctness-first.

    - Multiple memories: M_k, M_v, M_q, M_eta, M_alpha, M_memory.
    - Each memory is a 2-layer residual MLP (Eq. 91).
    - Updates are performed on fast state using chunked DGD-like rule (Eq. 90/93).

    The module parameters of each memory act as meta-learned initial weights.
    ``init_fast_state`` clones them into a ``SelfModifyingTitansState``, and
    ``forward_with_updates`` reads from and rewrites that fast state chunk by chunk.

    Note: This implementation prioritizes semantic fidelity and testability over speed.
    """

    def __init__(self, config: SelfModifyingTitansConfig):
        super().__init__()
        self.config = config
        dim = config.dim
        hidden = dim
        act = F.gelu
        # Optional depthwise conv over time. Causality comes from the manual left
        # padding in _apply_local_conv, so the conv itself uses padding=0.
        self.local_conv: nn.Conv1d | None = None
        if config.local_conv_window is not None:
            window = int(config.local_conv_window)
            self.local_conv = nn.Conv1d(
                dim,
                dim,
                kernel_size=window,
                groups=dim,
                padding=0,
                bias=False,
            )
        # Static query projection, used whenever adaptive_q is False.
        self.w_q = nn.Linear(dim, dim, bias=False)
        # k/v/q/memory map dim -> dim; eta/alpha map dim -> 1 (one scalar per token).
        self.m_k = ResidualMLPMemory(
            in_dim=dim, out_dim=dim, hidden_dim=hidden, activation=act, use_skip=config.use_skip
        )
        self.m_v = ResidualMLPMemory(
            in_dim=dim, out_dim=dim, hidden_dim=hidden, activation=act, use_skip=config.use_skip
        )
        self.m_q = ResidualMLPMemory(
            in_dim=dim, out_dim=dim, hidden_dim=hidden, activation=act, use_skip=config.use_skip
        )
        self.m_eta = ResidualMLPMemory(
            in_dim=dim, out_dim=1, hidden_dim=hidden, activation=act, use_skip=config.use_skip
        )
        self.m_alpha = ResidualMLPMemory(
            in_dim=dim, out_dim=1, hidden_dim=hidden, activation=act, use_skip=config.use_skip
        )
        self.m_memory = ResidualMLPMemory(
            in_dim=dim, out_dim=dim, hidden_dim=hidden, activation=act, use_skip=config.use_skip
        )

    def init_fast_state(self) -> SelfModifyingTitansState:
        """Clone the current module weights of all six memories into a fresh fast state.

        The returned weights are unbatched (2D), detached copies with no momentum buffers.
        """
        return SelfModifyingTitansState(
            k=self._init_memory_state(self.m_k),
            v=self._init_memory_state(self.m_v),
            q=self._init_memory_state(self.m_q),
            eta=self._init_memory_state(self.m_eta),
            alpha=self._init_memory_state(self.m_alpha),
            memory=self._init_memory_state(self.m_memory),
        )

    def apply_updates_inplace(
        self,
        x: torch.Tensor,
        *,
        chunk_size_other: int | None = None,
        chunk_size_memory: int | None = None,
    ) -> None:
        """
        Apply the self-modifying update rule to the *module parameters* in-place.

        This is intended to be called in an explicit "update pass" under `torch.no_grad()`
        (e.g., after an outer backward), so we avoid mixing differentiable reads with
        in-place writes during the same autograd graph.

        A fresh fast state is run through ``forward_with_updates`` on ``x``. The updated
        fast weights are then averaged over the batch and copied into the module
        parameters: always for k, v, eta and memory; for q only when ``adaptive_q``; for
        alpha only when ``use_alpha``.
        """
        state = self.init_fast_state()
        _out, updated = self.forward_with_updates(
            x,
            state,
            chunk_size_other=chunk_size_other,
            chunk_size_memory=chunk_size_memory,
        )
        self._load_state_mean_(updated)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Read the memory using the module (meta) weights, without any fast state."""
        x = self._apply_local_conv(x)
        q = self.m_q(x) if self.config.adaptive_q else self.w_q(x)
        if self.config.qk_l2_norm:
            q = F.normalize(q, dim=-1, eps=self.config.eps)
        return self.m_memory(q)

    def forward_with_state(
        self,
        x: torch.Tensor,
        state: SelfModifyingTitansState,
    ) -> torch.Tensor:
        """
        Read the memory through the given fast state, without updating it.

        Reads use ``_straight_through_meta``: the values are the fast weights, and
        gradients also reach the corresponding module parameters.

        Raises:
            ValueError: if ``x`` is not 3D, its last dim differs from ``config.dim``,
                or a batched state does not match the batch size.
        """
        if x.ndim != 3:
            raise ValueError("Expected x to have shape (B, T, D)")
        batch, _seq_len, dim = x.shape
        if dim != self.config.dim:
            raise ValueError(f"Expected dim={self.config.dim}, got {dim}")
        state = self._ensure_batched_state(state, batch)
        x = self._apply_local_conv(x)
        q = (
            self._memory_forward(x, state.q, meta=self.m_q)
            if self.config.adaptive_q
            else self.w_q(x)
        )
        if self.config.qk_l2_norm:
            q = F.normalize(q, dim=-1, eps=self.config.eps)
        return self._memory_forward(q, state.memory, meta=self.m_memory)

    def _other_memory_names(self) -> tuple[str, ...]:
        """Return the fast memories that are updated on ``chunk_size_other`` boundaries."""
        names: tuple[str, ...] = ("k", "v", "eta")
        if self.config.adaptive_q:
            names = (*names, "q")
        if self.config.use_alpha:
            names = (*names, "alpha")
        return names

    def forward_with_updates(
        self,
        x: torch.Tensor,
        state: SelfModifyingTitansState,
        *,
        chunk_size_other: int | None = None,
        chunk_size_memory: int | None = None,
    ) -> tuple[torch.Tensor, SelfModifyingTitansState]:
        """
        Read from the fast memories and update them online, segment by segment.

        The sequence is walked in segments ending at the next multiple of either chunk
        size. Each segment is read with the fast weights as they stand at its start.
        Buffered tokens then update the k/v/eta (plus q and alpha when enabled)
        memories every ``chunk_size_other`` tokens, and M_memory every
        ``chunk_size_memory`` tokens.

        Args:
            x: Input of shape ``(B, T, D)`` with ``D == config.dim``.
            state: Fast state, either unbatched (2D weights, expanded to ``B`` copies)
                or batched with leading dim ``B``. An already-batched state object is
                updated in place as well as returned.
            chunk_size_other: Optional override of ``config.chunk_size_other``.
            chunk_size_memory: Optional override of ``config.chunk_size_memory``.

        Returns:
            ``(outputs, state)``. ``outputs`` has shape ``(B, T, D)`` and is computed
            under ``torch.no_grad()``. When ``T == 0``, an empty ``(B, 0, D)`` tensor is
            returned and the state is left unchanged.

        Raises:
            ValueError: on a non-3D input, a dim mismatch, a state batch mismatch, or
                a non-positive chunk size.
        """
        if x.ndim != 3:
            raise ValueError("Expected x to have shape (B, T, D)")
        batch, seq_len, dim = x.shape
        if dim != self.config.dim:
            raise ValueError(f"Expected dim={self.config.dim}, got {dim}")
        state = self._ensure_batched_state(state, batch)
        x = self._apply_local_conv(x)
        other_chunk = int(
            self.config.chunk_size_other if chunk_size_other is None else chunk_size_other
        )
        memory_chunk_cfg = self.config.chunk_size_memory
        if memory_chunk_cfg is None:
            memory_chunk_cfg = self.config.chunk_size_other
        memory_chunk = int(memory_chunk_cfg if chunk_size_memory is None else chunk_size_memory)
        if other_chunk <= 0 or memory_chunk <= 0:
            raise ValueError("chunk sizes must be positive")
        if seq_len == 0:
            # Nothing to read or write. Without this guard, torch.cat on an empty
            # outputs list below would raise.
            return x.new_zeros((batch, 0, dim)), state

        outputs: list[torch.Tensor] = []
        other_k: list[torch.Tensor] = []
        other_v: list[torch.Tensor] = []
        other_eta: list[torch.Tensor] = []
        other_alpha: list[torch.Tensor] = []
        memory_k: list[torch.Tensor] = []
        memory_v: list[torch.Tensor] = []
        memory_eta: list[torch.Tensor] = []
        memory_alpha: list[torch.Tensor] = []
        other_memories = self._other_memory_names()

        def _next_boundary(idx: int, *, chunk_size: int) -> int:
            # First multiple of chunk_size strictly after idx, clamped to the sequence end.
            if chunk_size <= 0:
                raise ValueError("chunk_size must be positive")
            return min(((idx // chunk_size) + 1) * chunk_size, seq_len)

        def _flush(
            k_buf: list[torch.Tensor],
            v_buf: list[torch.Tensor],
            eta_buf: list[torch.Tensor],
            alpha_buf: list[torch.Tensor],
            memories: tuple[str, ...],
        ) -> None:
            # Apply one chunked update to `memories` from the buffered tokens, then reset.
            if not k_buf:
                return
            self._apply_chunk_update_seq(
                state,
                k_seq=torch.cat(k_buf, dim=1),
                v_seq=torch.cat(v_buf, dim=1),
                eta_seq=torch.cat(eta_buf, dim=1),
                alpha_seq=torch.cat(alpha_buf, dim=1),
                memories=memories,
            )
            k_buf.clear()
            v_buf.clear()
            eta_buf.clear()
            alpha_buf.clear()

        with torch.no_grad():
            idx = 0
            while idx < seq_len:
                next_other = _next_boundary(idx, chunk_size=other_chunk)
                next_memory = _next_boundary(idx, chunk_size=memory_chunk)
                end = min(next_other, next_memory, seq_len)
                x_chunk = x[:, idx:end, :]

                # Reads use the fast weights as of the start of this segment.
                k_chunk = self._memory_forward(x_chunk, state.k)
                v_chunk = self._memory_forward(x_chunk, state.v)
                q_chunk = (
                    self._memory_forward(x_chunk, state.q)
                    if self.config.adaptive_q
                    else self.w_q(x_chunk)
                )
                if self.config.qk_l2_norm:
                    k_chunk = F.normalize(k_chunk, dim=-1, eps=self.config.eps)
                    q_chunk = F.normalize(q_chunk, dim=-1, eps=self.config.eps)
                # Per-token step size: softplus(M_eta(x)) * eta_scale, shape (B, t).
                eta_chunk = self._memory_forward(x_chunk, state.eta).squeeze(-1)
                eta_chunk = F.softplus(eta_chunk) * self.config.eta_scale
                # Per-token gate: sigmoid(M_alpha(x)), or all ones when disabled.
                if self.config.use_alpha:
                    alpha_chunk = self._memory_forward(x_chunk, state.alpha).squeeze(-1)
                    alpha_chunk = torch.sigmoid(alpha_chunk)
                else:
                    alpha_chunk = torch.ones_like(eta_chunk)
                o_chunk = self._memory_forward(q_chunk, state.memory)
                outputs.append(o_chunk)

                other_k.append(k_chunk)
                other_v.append(v_chunk)
                other_eta.append(eta_chunk)
                other_alpha.append(alpha_chunk)
                memory_k.append(k_chunk)
                memory_v.append(v_chunk)
                memory_eta.append(eta_chunk)
                memory_alpha.append(alpha_chunk)

                idx = end

                if idx == next_other:
                    _flush(other_k, other_v, other_eta, other_alpha, other_memories)
                if idx == next_memory:
                    _flush(memory_k, memory_v, memory_eta, memory_alpha, ("memory",))

            # Both boundaries are clamped to seq_len, so the buffers are normally empty
            # here; flush defensively anyway.
            _flush(other_k, other_v, other_eta, other_alpha, other_memories)
            _flush(memory_k, memory_v, memory_eta, memory_alpha, ("memory",))

        return torch.cat(outputs, dim=1), state

    def _apply_local_conv(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the causal depthwise conv over time, or return ``x`` when disabled.

        Raises:
            ValueError: if the conv is enabled and ``x`` is not 3D.
        """
        if self.local_conv is None:
            return x
        if x.ndim != 3:
            raise ValueError("Expected x to have shape (B, T, D)")
        kernel = int(self.local_conv.kernel_size[0])
        # Causal depthwise conv: only attends to past tokens (left-pad by kernel - 1).
        x_t = x.transpose(1, 2)
        x_t = F.pad(x_t, (kernel - 1, 0))
        x_t = self.local_conv(x_t)
        return x_t.transpose(1, 2)

    def _load_state_mean_(self, state: SelfModifyingTitansState) -> None:
        """Copy fast weights (averaged over the batch if 3D) into the module parameters.

        Momentum buffers are not copied. q is copied only when ``adaptive_q`` and alpha
        only when ``use_alpha``.

        Raises:
            RuntimeError: if a module has a projected skip but the state has none.
        """

        def _mean_weight(weight: torch.Tensor) -> torch.Tensor:
            return weight.mean(dim=0) if weight.ndim == 3 else weight

        def _copy(module: ResidualMLPMemory, mem: ResidualMLPMemoryState) -> None:
            module.w1.weight.copy_(_mean_weight(mem.w1))
            module.w2.weight.copy_(_mean_weight(mem.w2))
            if module.w_skip is None:
                return
            if mem.w_skip is None:
                raise RuntimeError("Expected w_skip state for projected residual memory")
            module.w_skip.weight.copy_(_mean_weight(mem.w_skip))

        with torch.no_grad():
            _copy(self.m_k, state.k)
            _copy(self.m_v, state.v)
            _copy(self.m_eta, state.eta)
            if self.config.use_alpha:
                _copy(self.m_alpha, state.alpha)
            _copy(self.m_memory, state.memory)
            if self.config.adaptive_q:
                _copy(self.m_q, state.q)

    def _apply_chunk_update(
        self,
        state: SelfModifyingTitansState,
        buffer: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
        *,
        memories: tuple[str, ...],
    ) -> None:
        """Stack per-token ``(k, v, eta, alpha)`` tuples along dim 1 and apply one chunk update.

        Not called by ``forward_with_updates``, which uses ``_apply_chunk_update_seq``
        directly. An empty buffer is a no-op.
        """
        if not buffer:
            return
        k_seq = torch.stack([item[0] for item in buffer], dim=1)
        v_seq = torch.stack([item[1] for item in buffer], dim=1)
        eta_seq = torch.stack([item[2] for item in buffer], dim=1)
        alpha_seq = torch.stack([item[3] for item in buffer], dim=1)
        self._apply_chunk_update_seq(
            state,
            k_seq=k_seq,
            v_seq=v_seq,
            eta_seq=eta_seq,
            alpha_seq=alpha_seq,
            memories=memories,
        )

    def _apply_chunk_update_seq(
        self,
        state: SelfModifyingTitansState,
        *,
        k_seq: torch.Tensor,
        v_seq: torch.Tensor,
        eta_seq: torch.Tensor,
        alpha_seq: torch.Tensor,
        memories: tuple[str, ...],
    ) -> None:
        """
        Apply one chunk of updates to the named fast memories, mutating ``state``.

        Per-token gradients are all evaluated at the chunk-boundary snapshot of each
        memory. The parameter updates are then applied token by token, each using the
        rank-1 preconditioner ``alpha_t * I - eta_t * k_t k_t^T``.
        """
        steps = k_seq.size(1)
        dim = self.config.dim
        eye = (
            torch.eye(dim, device=k_seq.device, dtype=k_seq.dtype)
            .unsqueeze(0)
            .expand(k_seq.size(0), -1, -1)
        )

        # Freeze the boundary weights so every token's gradient uses the same snapshot.
        boundary: dict[str, ResidualMLPMemoryState] = {
            name: getattr(state, name).clone() for name in memories
        }
        grads = {name: self._memory_grads_chunk(boundary[name], k_seq, v_seq) for name in memories}

        for t in range(steps):
            k_t = k_seq[:, t, :]
            eta_t = eta_seq[:, t]
            alpha_t = alpha_seq[:, t]
            kk = torch.einsum("bi,bj->bij", k_t, k_t)
            precond = alpha_t[:, None, None] * eye - eta_t[:, None, None] * kk
            for name in memories:
                fast = getattr(state, name)
                g1, g2, gskip = grads[name]
                self._apply_param_update(
                    fast,
                    (
                        g1[:, t, ...],
                        g2[:, t, ...],
                        None if gskip is None else gskip[:, t, ...],
                    ),
                    eta_t,
                    alpha_t,
                    precond,
                )

    def _memory_grads(
        self,
        frozen: ResidualMLPMemoryState,
        k_t: torch.Tensor,
        v_t: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
        """Return gradients of the inner loss w.r.t. ``(w1, w2, w_skip)`` via autograd.

        The loss is summed over all positions of ``k_t``/``v_t``. It uses ``config.objective``
        with the self-generated target ``M(v_t)``, detached when ``stopgrad_vhat``.
        ``enable_grad`` makes this work even when called under ``torch.no_grad()``.
        """
        with torch.enable_grad():
            w1 = frozen.w1.detach().requires_grad_(True)
            w2 = frozen.w2.detach().requires_grad_(True)
            w_skip = None
            if frozen.w_skip is not None:
                w_skip = frozen.w_skip.detach().requires_grad_(True)

            pred = self._memory_forward(k_t, ResidualMLPMemoryState(w1=w1, w2=w2, w_skip=w_skip))
            vhat = self._memory_forward(v_t, ResidualMLPMemoryState(w1=w1, w2=w2, w_skip=w_skip))
            if self.config.stopgrad_vhat:
                vhat = vhat.detach()

            if self.config.objective == "dot":
                loss = -(pred * vhat).sum(dim=-1)
            else:
                loss = (pred - vhat).pow(2).sum(dim=-1)
            loss_scalar = loss.sum()

            grads = torch.autograd.grad(
                loss_scalar,
                (w1, w2, w_skip) if w_skip is not None else (w1, w2),
                retain_graph=False,
                create_graph=False,
                allow_unused=False,
            )
        if w_skip is None:
            g1, g2 = grads
            return g1, g2, None
        g1, g2, gskip = grads
        return g1, g2, gskip

    def _memory_grads_chunk(
        self,
        frozen: ResidualMLPMemoryState,
        k_seq: torch.Tensor,
        v_seq: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
        """
        Compute per-token gradients for an entire chunk in parallel (paper §8.2).

        ``torch.func.grad`` is vmapped over the time axis (dim 1 of ``k_seq``/``v_seq``),
        and every token is evaluated at the same frozen weights. The inner loss matches
        ``_memory_grads``.

        Returns gradients with leading shape (B, T, ...).
        """
        w1 = frozen.w1.detach()
        w2 = frozen.w2.detach()
        w_skip = None if frozen.w_skip is None else frozen.w_skip.detach()

        # Put time first so vmap maps over tokens.
        k_tokens = k_seq.transpose(0, 1)
        v_tokens = v_seq.transpose(0, 1)

        if w_skip is None:

            def loss_fn_noskip(
                w1_t: torch.Tensor,
                w2_t: torch.Tensor,
                k_t: torch.Tensor,
                v_t: torch.Tensor,
            ) -> torch.Tensor:
                mem = ResidualMLPMemoryState(w1=w1_t, w2=w2_t)
                pred = self._memory_forward(k_t, mem)
                vhat = self._memory_forward(v_t, mem)
                if self.config.stopgrad_vhat:
                    vhat = vhat.detach()
                if self.config.objective == "dot":
                    loss = -(pred * vhat).sum(dim=-1)
                else:
                    loss = (pred - vhat).pow(2).sum(dim=-1)
                return loss.sum()

            grad_fn = grad(loss_fn_noskip, argnums=(0, 1))
            g1_tokens, g2_tokens = vmap(grad_fn, in_dims=(None, None, 0, 0))(
                w1,
                w2,
                k_tokens,
                v_tokens,
            )
            return g1_tokens.transpose(0, 1), g2_tokens.transpose(0, 1), None

        def loss_fn_skip(
            w1_t: torch.Tensor,
            w2_t: torch.Tensor,
            w_skip_t: torch.Tensor,
            k_t: torch.Tensor,
            v_t: torch.Tensor,
        ) -> torch.Tensor:
            mem = ResidualMLPMemoryState(w1=w1_t, w2=w2_t, w_skip=w_skip_t)
            pred = self._memory_forward(k_t, mem)
            vhat = self._memory_forward(v_t, mem)
            if self.config.stopgrad_vhat:
                vhat = vhat.detach()
            if self.config.objective == "dot":
                loss = -(pred * vhat).sum(dim=-1)
            else:
                loss = (pred - vhat).pow(2).sum(dim=-1)
            return loss.sum()

        grad_fn = grad(loss_fn_skip, argnums=(0, 1, 2))
        g1_tokens, g2_tokens, gskip_tokens = vmap(
            grad_fn,
            in_dims=(None, None, None, 0, 0),
        )(w1, w2, w_skip, k_tokens, v_tokens)
        return (
            g1_tokens.transpose(0, 1),
            g2_tokens.transpose(0, 1),
            gskip_tokens.transpose(0, 1),
        )

    def _apply_param_update(
        self,
        fast: ResidualMLPMemoryState,
        grads: tuple[torch.Tensor, torch.Tensor, torch.Tensor | None],
        eta_t: torch.Tensor,
        alpha_t: torch.Tensor,
        precond: torch.Tensor,
    ) -> None:
        """
        Apply one token's update to ``fast`` (attributes are reassigned in place).

        Gradients pass through momentum first. Then:
        - w2 and w_skip: ``W @ precond - eta * g`` with the rank-1 preconditioner,
          otherwise ``alpha * W - eta * g``;
        - w1: always ``alpha * W - eta * g``.

        Raises:
            RuntimeError: if ``fast`` has a skip weight but no skip gradient was given.
        """
        g1, g2, gskip = grads
        g1 = self._apply_momentum(fast, "m_w1", g1)
        g2 = self._apply_momentum(fast, "m_w2", g2)
        if self.config.use_rank1_precond:
            fast.w2 = torch.matmul(fast.w2, precond) - eta_t[:, None, None] * g2
        else:
            fast.w2 = alpha_t[:, None, None] * fast.w2 - eta_t[:, None, None] * g2
        fast.w1 = alpha_t[:, None, None] * fast.w1 - eta_t[:, None, None] * g1

        if fast.w_skip is None:
            return
        if gskip is None:
            raise RuntimeError("Expected w_skip grad to be present")
        gskip = self._apply_momentum(fast, "m_w_skip", gskip)
        if self.config.use_rank1_precond:
            fast.w_skip = torch.matmul(fast.w_skip, precond) - eta_t[:, None, None] * gskip
        else:
            fast.w_skip = alpha_t[:, None, None] * fast.w_skip - eta_t[:, None, None] * gskip

    def _apply_momentum(
        self,
        fast: ResidualMLPMemoryState,
        attr_name: str,
        grad: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``beta * buf + grad``, store it in ``fast.<attr_name>``, or pass ``grad`` through when momentum is 0."""
        beta = float(self.config.momentum)
        if beta <= 0.0:
            return grad
        buf = getattr(fast, attr_name)
        if buf is None:
            buf = torch.zeros_like(grad)
        buf = beta * buf + grad
        setattr(fast, attr_name, buf)
        return buf

    def _init_memory_state(self, module: ResidualMLPMemory) -> ResidualMLPMemoryState:
        """Clone (detached) the module's weights into an unbatched fast memory state."""
        skip = None if module.w_skip is None else module.w_skip.weight.detach().clone()
        return ResidualMLPMemoryState(
            w1=module.w1.weight.detach().clone(),
            w2=module.w2.weight.detach().clone(),
            w_skip=skip,
        )

    def _ensure_batched_state(
        self, state: SelfModifyingTitansState, batch: int
    ) -> SelfModifyingTitansState:
        """Return a state whose weights have a leading batch dimension of ``batch``.

        An unbatched (2D) state is expanded into a new state object. A batched (3D)
        state is returned as the *same* object, so later updates mutate the caller's
        state. Only ``state.k.w1`` is inspected.

        Raises:
            ValueError: on weights that are neither 2D nor 3D, or a batch-size mismatch.
        """
        if state.k.w1.ndim == 2:
            return SelfModifyingTitansState(
                k=self._expand_memory_state(state.k, batch),
                v=self._expand_memory_state(state.v, batch),
                q=self._expand_memory_state(state.q, batch),
                eta=self._expand_memory_state(state.eta, batch),
                alpha=self._expand_memory_state(state.alpha, batch),
                memory=self._expand_memory_state(state.memory, batch),
            )
        if state.k.w1.ndim != 3:
            raise ValueError("SelfModifyingTitansState weights must be 2D or 3D tensors")
        if state.k.w1.size(0) != batch:
            raise ValueError(
                f"State batch mismatch: expected batch={batch}, got {state.k.w1.size(0)}"
            )
        return state

    def _expand_memory_state(
        self, mem: ResidualMLPMemoryState, batch: int
    ) -> ResidualMLPMemoryState:
        """Repeat each 2D tensor of ``mem`` ``batch`` times along a new leading dim (independent copies)."""

        def _expand(t: torch.Tensor) -> torch.Tensor:
            return t.detach().clone().unsqueeze(0).repeat(batch, 1, 1)

        def _expand_opt(t: torch.Tensor | None) -> torch.Tensor | None:
            return None if t is None else _expand(t)

        return ResidualMLPMemoryState(
            w1=_expand(mem.w1),
            w2=_expand(mem.w2),
            w_skip=_expand_opt(mem.w_skip),
            m_w1=_expand_opt(mem.m_w1),
            m_w2=_expand_opt(mem.m_w2),
            m_w_skip=_expand_opt(mem.m_w_skip),
        )

    def _memory_forward(
        self,
        x: torch.Tensor,
        mem: ResidualMLPMemoryState,
        *,
        meta: ResidualMLPMemory | None = None,
    ) -> torch.Tensor:
        """
        Evaluate a residual MLP memory with explicit (possibly batched) fast weights.

        A 2D ``x`` is treated as a single step: a length-1 axis is inserted at dim 1 and
        removed again. The activation is always ``F.gelu``, matching the activation the
        memories are built with in ``__init__``. When ``meta`` is given, weights are
        routed through ``_straight_through_meta`` so gradients reach the module parameters.

        Raises:
            RuntimeError: if ``mem`` has a skip weight but ``meta`` has none.
        """
        if meta is None:
            w2 = mem.w2
            w1 = mem.w1
            w_skip = mem.w_skip
        else:
            w2 = self._straight_through_meta(mem.w2, meta.w2.weight)
            w1 = self._straight_through_meta(mem.w1, meta.w1.weight)
            w_skip = None
            if mem.w_skip is not None:
                if meta.w_skip is None:
                    raise RuntimeError("Expected meta w_skip for projected residual memory")
                w_skip = self._straight_through_meta(mem.w_skip, meta.w_skip.weight)
        if x.ndim == 2:
            x_seq = x.unsqueeze(1)
            squeeze = True
        else:
            x_seq = x
            squeeze = False
        w2_t = w2.transpose(-1, -2)
        hidden = torch.matmul(x_seq, w2_t)
        hidden = F.gelu(hidden)
        w1_t = w1.transpose(-1, -2)
        out = torch.matmul(hidden, w1_t)
        if w_skip is not None:
            w_skip_t = w_skip.transpose(-1, -2)
            out = out + torch.matmul(x_seq, w_skip_t)
        elif out.size(-1) == x_seq.size(-1):
            # Identity residual when widths match (mirrors ResidualMLPMemory.forward).
            out = out + x_seq
        if squeeze:
            return out.squeeze(1)
        return out

    @staticmethod
    def _straight_through_meta(fast: torch.Tensor, meta: torch.Tensor) -> torch.Tensor:
        """Return a tensor equal in value to ``fast`` whose gradient also flows to ``meta``.

        ``meta`` gets leading singleton dims added until its rank matches ``fast``, and
        ``meta - meta.detach()`` (numerically zero) is added to ``fast``.

        Raises:
            ValueError: if ``meta`` has a higher rank than ``fast``.
        """
        if meta.ndim > fast.ndim:
            raise ValueError("meta tensor must have <= fast tensor rank")
        expanded = meta
        while expanded.ndim < fast.ndim:
            expanded = expanded.unsqueeze(0)
        return fast + (expanded - expanded.detach())