nested_learning-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. nested_learning/__init__.py +12 -0
  2. nested_learning/__main__.py +12 -0
  3. nested_learning/assoc_memory.py +23 -0
  4. nested_learning/backbones.py +147 -0
  5. nested_learning/capabilities.py +104 -0
  6. nested_learning/cli.py +253 -0
  7. nested_learning/cms.py +92 -0
  8. nested_learning/config_utils.py +50 -0
  9. nested_learning/configs/ablations/cms_sparse.yaml +46 -0
  10. nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
  11. nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
  12. nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
  13. nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
  14. nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
  15. nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
  16. nested_learning/configs/data/continual_segments_sample.yaml +9 -0
  17. nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
  18. nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
  19. nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
  20. nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
  21. nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
  22. nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
  23. nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
  24. nested_learning/configs/deepspeed/zero3.json +25 -0
  25. nested_learning/configs/hope/mid.yaml +118 -0
  26. nested_learning/configs/hope/mid_fsdp.yaml +47 -0
  27. nested_learning/configs/hope/pilot.yaml +2 -0
  28. nested_learning/configs/hope/pilot_attention.yaml +9 -0
  29. nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
  30. nested_learning/configs/hope/pilot_transformer.yaml +9 -0
  31. nested_learning/configs/hope/target.yaml +145 -0
  32. nested_learning/configs/hope/target_fsdp.yaml +47 -0
  33. nested_learning/configs/mid_smoke.yaml +99 -0
  34. nested_learning/configs/mid_stage2.yaml +110 -0
  35. nested_learning/configs/mid_stage2_smoke.yaml +102 -0
  36. nested_learning/configs/mid_titan_baseline.yaml +92 -0
  37. nested_learning/configs/pilot.yaml +127 -0
  38. nested_learning/configs/pilot_paper_faithful.yaml +42 -0
  39. nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
  40. nested_learning/configs/pilot_smoke.yaml +80 -0
  41. nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
  42. nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
  43. nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
  44. nested_learning/continual_classification.py +136 -0
  45. nested_learning/continual_streaming.py +283 -0
  46. nested_learning/data.py +153 -0
  47. nested_learning/device.py +21 -0
  48. nested_learning/eval_state.py +72 -0
  49. nested_learning/fast_state.py +108 -0
  50. nested_learning/functional.py +69 -0
  51. nested_learning/hope/__init__.py +0 -0
  52. nested_learning/hope/block.py +1973 -0
  53. nested_learning/hope/self_mod.py +40 -0
  54. nested_learning/instrumentation.py +38 -0
  55. nested_learning/levels.py +94 -0
  56. nested_learning/logging_utils.py +64 -0
  57. nested_learning/memorize.py +382 -0
  58. nested_learning/model.py +604 -0
  59. nested_learning/optim/__init__.py +0 -0
  60. nested_learning/optim/deep.py +102 -0
  61. nested_learning/optim/factory.py +13 -0
  62. nested_learning/optim/m3.py +121 -0
  63. nested_learning/optim/manager.py +151 -0
  64. nested_learning/titan/__init__.py +0 -0
  65. nested_learning/titan/memory.py +88 -0
  66. nested_learning/titan/model.py +412 -0
  67. nested_learning/titan/self_modifying.py +724 -0
  68. nested_learning/tokenizer.py +28 -0
  69. nested_learning/tokenizer_coverage.py +77 -0
  70. nested_learning/training.py +1600 -0
  71. nested_learning/transformer.py +104 -0
  72. nested_learning-0.2.0.dist-info/METADATA +390 -0
  73. nested_learning-0.2.0.dist-info/RECORD +76 -0
  74. nested_learning-0.2.0.dist-info/WHEEL +4 -0
  75. nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
  76. nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
nested_learning/hope/block.py
@@ -0,0 +1,1973 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Dict, Sequence, Set
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..backbones import AttentionConfig, SelfAttention
+from ..cms import CMS
+from ..fast_state import AttentionKVCache, BlockFastState, CMSChunkBuffer
+from ..functional import (
+    call_with_deltas,
+    call_with_params,
+    grads_to_dict,
+    params_with_deltas,
+    require_grad_params,
+)
+from ..levels import LevelSpec
+from ..optim.manager import LevelConfig, LevelOptimizerManager
+from ..titan.memory import TitanMemory, TitanMemoryConfig
+from ..titan.self_modifying import SelfModifyingTitans, SelfModifyingTitansConfig
+from .self_mod import SelfModifier
+
+
+def _chunk_loss(
+    prediction: torch.Tensor,
+    delta_target: torch.Tensor,
+    mask_f: torch.Tensor,
+    *,
+    reduction: str,
+    differentiable_target: bool = False,
+) -> torch.Tensor:
+    if differentiable_target:
+        target = prediction.detach() - delta_target
+    else:
+        target = (prediction.detach() - delta_target).detach()
+    diff_sq = (prediction - target).pow(2)
+    masked = diff_sq * mask_f
+    if reduction == "mean":
+        return masked.sum() / mask_f.sum().clamp(min=1.0)
+    if reduction == "sum":
+        return masked.sum()
+    raise ValueError(f"Unsupported cms_chunk_reduction={reduction}")
+
+
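The pseudo-target construction above is easy to misread: `_chunk_loss` regresses the prediction onto `prediction.detach() - delta_target`, so the gradient with respect to the memory parameters is proportional to `delta_target` on active tokens and zero elsewhere. A standalone sketch of the semantics (illustrative shapes, not package code):

import torch

pred = torch.zeros(1, 4, 2, requires_grad=True)             # (batch, tokens, dim)
delta = torch.ones(1, 4, 2)                                  # teaching signal
mask_f = torch.tensor([[1.0, 1.0, 0.0, 0.0]]).unsqueeze(-1)  # first two tokens active
target = (pred.detach() - delta).detach()                    # differentiable_target=False branch
loss = ((pred - target).pow(2) * mask_f).sum()               # reduction="sum"
loss.backward()
print(pred.grad)  # 2 * delta on active tokens, 0 on masked tokens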
+def _min_update_period(levels: Sequence[LevelSpec]) -> int:
+    periods = [int(spec.update_period) for spec in levels if int(spec.update_period) > 0]
+    return min(periods) if periods else 1
+
+
+@dataclass
+class _CmsBuffer:
+    inputs: list[torch.Tensor]
+    teach: list[torch.Tensor]
+    active: list[torch.Tensor]
+    count: int = 0
+
+
+def _clear_buffer(buffer: _CmsBuffer | CMSChunkBuffer) -> None:
+    buffer.inputs.clear()
+    buffer.teach.clear()
+    buffer.active.clear()
+    buffer.count = 0
+
+
+def _fast_state_buffers(
+    fast_state: BlockFastState, levels: Sequence[LevelSpec]
+) -> dict[str, CMSChunkBuffer]:
+    buffers = fast_state.cms_online_buffers
+    for spec in levels:
+        if spec.name not in buffers:
+            buffers[spec.name] = CMSChunkBuffer()
+    return buffers
+
+
+def _pop_buffer_chunk(
+    buffer: _CmsBuffer | CMSChunkBuffer,
+    count: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    if count <= 0:
+        raise ValueError("count must be positive")
+    result_inputs: list[torch.Tensor] = []
+    result_teach: list[torch.Tensor] = []
+    result_active: list[torch.Tensor] = []
+    remaining = count
+    while remaining > 0:
+        first = buffer.inputs[0]
+        chunk_len = first.size(1)
+        take = min(remaining, chunk_len)
+        src_inputs = buffer.inputs[0]
+        src_teach = buffer.teach[0]
+        src_active = buffer.active[0]
+        result_inputs.append(src_inputs[:, :take])
+        result_teach.append(src_teach[:, :take])
+        result_active.append(src_active[:, :take])
+        if take == chunk_len:
+            buffer.inputs.pop(0)
+            buffer.teach.pop(0)
+            buffer.active.pop(0)
+        else:
+            buffer.inputs[0] = src_inputs[:, take:]
+            buffer.teach[0] = src_teach[:, take:]
+            buffer.active[0] = src_active[:, take:]
+        remaining -= take
+    return (
+        torch.cat(result_inputs, dim=1),
+        torch.cat(result_teach, dim=1),
+        torch.cat(result_active, dim=1),
+    )
+
+
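`_pop_buffer_chunk` drains buffered chunks FIFO and splits the head chunk when only part of it is needed; note that it never touches `buffer.count`, so every caller decrements the count itself right after popping. A quick check, assuming the wheel is installed so the private helpers import directly:

import torch
from nested_learning.hope.block import _CmsBuffer, _pop_buffer_chunk

buf = _CmsBuffer(
    inputs=[torch.arange(6.0).view(1, 6, 1)],
    teach=[torch.zeros(1, 6, 1)],
    active=[torch.ones(1, 6, dtype=torch.bool)],
    count=6,
)
inputs, teach, active = _pop_buffer_chunk(buf, 4)  # take 4 of the 6 buffered tokens
buf.count -= 4                                     # callers own the bookkeeping
assert inputs.shape[1] == 4 and buf.inputs[0].shape[1] == 2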
+@dataclass
+class HOPEBlockConfig:
+    dim: int
+    heads: int
+    titan_level: LevelSpec
+    cms_levels: Sequence[LevelSpec]
+    titan_hidden_multiplier: int = 4
+    cms_hidden_multiplier: int = 4
+    cms_use_layernorm: bool = True
+    activation: str = "gelu"
+    qk_l2_norm: bool = False
+    local_conv_window: int | None = None
+    self_mod_hidden: int = 4
+    self_mod_lr: float = 1e-3
+    cms_chunk_reduction: str = "sum"
+    cms_online_updates: bool = True
+    cms_flush_partial_at_end: bool = False
+    optimizer_configs: Dict[str, dict] = field(default_factory=dict)
+
+
+@dataclass
+class HOPEAttentionBlockConfig:
+    dim: int
+    heads: int
+    cms_levels: Sequence[LevelSpec]
+    cms_hidden_multiplier: int = 4
+    cms_use_layernorm: bool = True
+    activation: str = "gelu"
+    qk_l2_norm: bool = False
+    local_conv_window: int | None = None
+    self_mod_lr: float = 1e-3
+    cms_chunk_reduction: str = "sum"
+    cms_online_updates: bool = True
+    cms_flush_partial_at_end: bool = False
+    optimizer_configs: Dict[str, dict] = field(default_factory=dict)
+
+
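Both config dataclasses only ever read `name` and `update_period` from each `LevelSpec`; a minimal construction sketch, assuming those two fields are also `LevelSpec`'s init arguments:

from nested_learning.hope.block import HOPEAttentionBlockConfig
from nested_learning.levels import LevelSpec

levels = [
    LevelSpec(name="fast", update_period=8),   # updates every 8 tokens
    LevelSpec(name="slow", update_period=64),  # updates every 64 tokens
]
cfg = HOPEAttentionBlockConfig(dim=256, heads=4, cms_levels=levels)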
+class HOPEAttentionBlock(nn.Module):
+    """
+    Paper-defined HOPE-Attention variant: softmax attention followed by CMS.
+
+    Reference: Nested Learning paper, HOPE-Attention note under Eqs. 94–97.
+    """
+
+    def __init__(self, config: HOPEAttentionBlockConfig):
+        super().__init__()
+        self.config = config
+        self.last_update_stats: Dict[str, Dict[str, float]] = {}
+        self.surprise_threshold: float | None = None
+        self.surprise_metric: str = "l2"
+        self.allowed_levels: Set[str] | None = None
+        self.attn = SelfAttention(
+            AttentionConfig(
+                dim=config.dim,
+                heads=config.heads,
+                qk_l2_norm=config.qk_l2_norm,
+                local_conv_window=config.local_conv_window,
+            )
+        )
+        self.cms = CMS(
+            dim=config.dim,
+            levels=config.cms_levels,
+            hidden_multiplier=config.cms_hidden_multiplier,
+            activation=config.activation,
+            use_layernorm=config.cms_use_layernorm,
+        )
+        level_config = LevelConfig(
+            specs=config.cms_levels,
+            optimizer_configs=config.optimizer_configs,
+            default_lr=config.self_mod_lr,
+        )
+        self.level_manager = LevelOptimizerManager(level_config)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        *,
+        teach_signal: torch.Tensor | None = None,
+        surprise_value: float | None = None,
+        fast_state: BlockFastState | None = None,
+        finalize_updates: bool = True,
+        attention_cache: AttentionKVCache | None = None,
+        return_attention_cache: bool = False,
+        differentiable_updates: bool = False,
+    ) -> torch.Tensor | tuple[torch.Tensor, AttentionKVCache]:
+        next_attn_cache: AttentionKVCache | None = None
+        if return_attention_cache:
+            attn_out, next_attn_cache = self.attn(
+                x,
+                kv_cache=attention_cache,
+                return_kv_cache=True,
+            )
+        else:
+            attn_out = self.attn(x, kv_cache=attention_cache)
+        if fast_state is None:
+            if teach_signal is not None and self.config.cms_online_updates:
+                cms_out = self._cms_forward_online(
+                    attn_out,
+                    teach_signal,
+                    surprise_value,
+                    finalize_updates=finalize_updates,
+                )
+            else:
+                cms_result = self.cms(attn_out, return_intermediates=True)
+                cms_out, cms_inputs, cms_outputs = cms_result
+                if teach_signal is not None:
+                    self._update_cms(cms_inputs, cms_outputs, teach_signal, surprise_value)
+            self.level_manager.tick()
+            return cms_out
+        if teach_signal is not None and self.config.cms_online_updates:
+            cms_out = self._cms_forward_online_fast(
+                attn_out,
+                fast_state,
+                teach_signal,
+                surprise_value,
+                finalize_updates=finalize_updates,
+                differentiable_updates=differentiable_updates,
+            )
+        else:
+            cms_out, cms_inputs = self._cms_forward_fast(attn_out, fast_state)
+            if teach_signal is not None:
+                self._update_cms_fast(
+                    fast_state,
+                    cms_inputs,
+                    teach_signal,
+                    surprise_value,
+                    differentiable_updates=differentiable_updates,
+                )
+        fast_state.level_manager.tick()
+        if return_attention_cache:
+            assert next_attn_cache is not None
+            return cms_out, next_attn_cache
+        return cms_out
+
+    def set_surprise_threshold(self, threshold: float | None) -> None:
+        self.surprise_threshold = threshold
+
+    def set_surprise_metric(self, metric: str) -> None:
+        self.surprise_metric = str(metric).strip().lower()
+
+    def set_allowed_levels(self, allowed: Set[str] | None) -> None:
+        self.allowed_levels = allowed.copy() if allowed is not None else None
+
+    def pop_update_stats(self) -> Dict[str, Dict[str, float]]:
+        stats = self.last_update_stats
+        self.last_update_stats = {}
+        return stats
+
+    def _cms_forward_fast(
+        self,
+        x: torch.Tensor,
+        fast_state: BlockFastState,
+    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+        current = x
+        inputs: dict[str, torch.Tensor] = {}
+        for spec in self.config.cms_levels:
+            level_name = spec.name
+            inputs[level_name] = current
+            params = fast_state.cms_params[level_name]
+            current = call_with_deltas(self.cms.blocks[level_name], params, current)
+        return current, inputs
+
+    def _cms_forward_online(
+        self,
+        x: torch.Tensor,
+        teach_signal: torch.Tensor,
+        surprise_value: float | None,
+        *,
+        finalize_updates: bool = True,
+    ) -> torch.Tensor:
+        seq_len = x.shape[1]
+        base_chunk = _min_update_period(self.config.cms_levels)
+        active_mask = teach_signal.detach().abs().sum(dim=-1) > 0
+        outputs: list[torch.Tensor] = []
+        stats: dict[str, Dict[str, float]] = {}
+        buffers: dict[str, _CmsBuffer] = {}
+        for spec in self.config.cms_levels:
+            buffers[spec.name] = _CmsBuffer(inputs=[], teach=[], active=[], count=0)
+            stats[spec.name] = {
+                "grad_norm": 0.0,
+                "chunk_tokens": 0.0,
+                "gate_hit": 0.0,
+                "gate_hits": 0.0,
+                "updates_applied": 0.0,
+                "tokens_flushed": 0.0,
+                "pending_tokens": 0.0,
+            }
+
+        for start in range(0, seq_len, base_chunk):
+            end = min(start + base_chunk, seq_len)
+            chunk_in = x[:, start:end, :]
+            chunk_teach = teach_signal[:, start:end, :]
+            chunk_active = active_mask[:, start:end]
+
+            current = chunk_in
+            level_inputs: dict[str, torch.Tensor] = {}
+            for spec in self.config.cms_levels:
+                level_name = spec.name
+                level_inputs[level_name] = current
+                current = self.cms.blocks[level_name](current)
+            outputs.append(current)
+
+            for spec in self.config.cms_levels:
+                level_name = spec.name
+                buffer = buffers[level_name]
+                buffer.inputs.append(level_inputs[level_name].detach())
+                buffer.teach.append(chunk_teach)
+                buffer.active.append(chunk_active)
+                buffer.count += end - start
+                update_period = int(spec.update_period)
+                while update_period > 0 and buffer.count >= update_period:
+                    # Pop under fresh names: reusing chunk_teach/chunk_active here
+                    # would clobber the slices still to be appended for later levels.
+                    pop_inputs, pop_teach, pop_active = _pop_buffer_chunk(
+                        buffer, update_period
+                    )
+                    buffer.count -= update_period
+                    magnitude = self._update_cms_chunk(
+                        level_name,
+                        pop_inputs,
+                        pop_teach,
+                        pop_active,
+                        surprise_value,
+                    )
+                    if magnitude > 0:
+                        stats[level_name]["grad_norm"] += magnitude
+                        stats[level_name]["chunk_tokens"] += float(update_period)
+                        stats[level_name]["gate_hit"] += 1.0
+                        stats[level_name]["gate_hits"] += 1.0
+                        stats[level_name]["updates_applied"] += 1.0
+        if self.config.cms_flush_partial_at_end and finalize_updates:
+            for spec in self.config.cms_levels:
+                level_name = spec.name
+                buffer = buffers[level_name]
+                remaining = int(buffer.count)
+                if remaining <= 0:
+                    continue
+                pop_inputs, pop_teach, pop_active = _pop_buffer_chunk(buffer, remaining)
+                buffer.count -= remaining
+                if not bool(pop_active.any()):
+                    continue
+                magnitude = self._update_cms_chunk(
+                    level_name,
+                    pop_inputs,
+                    pop_teach,
+                    pop_active,
+                    surprise_value,
+                )
+                if magnitude > 0:
+                    stats[level_name]["grad_norm"] += magnitude
+                    stats[level_name]["chunk_tokens"] += float(remaining)
+                    stats[level_name]["gate_hit"] += 1.0
+                    stats[level_name]["gate_hits"] += 1.0
+                    stats[level_name]["updates_applied"] += 1.0
+                    stats[level_name]["tokens_flushed"] += float(remaining)
+        for spec in self.config.cms_levels:
+            stats[spec.name]["pending_tokens"] = float(buffers[spec.name].count)
+        for level_name, payload in stats.items():
+            if (
+                payload["updates_applied"] <= 0
+                and payload["pending_tokens"] <= 0
+                and payload["tokens_flushed"] <= 0
+            ):
+                continue
+            if surprise_value is not None:
+                payload["surprise_value"] = surprise_value
+            self.last_update_stats[f"cms.{level_name}"] = payload
+        return torch.cat(outputs, dim=1)
+
+    def _cms_forward_online_fast(
+        self,
+        x: torch.Tensor,
+        fast_state: BlockFastState,
+        teach_signal: torch.Tensor,
+        surprise_value: float | None,
+        *,
+        finalize_updates: bool = True,
+        differentiable_updates: bool = False,
+    ) -> torch.Tensor:
+        seq_len = x.shape[1]
+        base_chunk = _min_update_period(self.config.cms_levels)
+        active_mask = teach_signal.detach().abs().sum(dim=-1) > 0
+        outputs: list[torch.Tensor] = []
+        stats: dict[str, Dict[str, float]] = {}
+        buffers = _fast_state_buffers(fast_state, self.config.cms_levels)
+        for spec in self.config.cms_levels:
+            stats[spec.name] = {
+                "grad_norm": 0.0,
+                "chunk_tokens": 0.0,
+                "gate_hit": 0.0,
+                "gate_hits": 0.0,
+                "updates_applied": 0.0,
+                "tokens_flushed": 0.0,
+                "pending_tokens": 0.0,
+            }
+
+        for start in range(0, seq_len, base_chunk):
+            end = min(start + base_chunk, seq_len)
+            chunk_in = x[:, start:end, :]
+            chunk_teach = teach_signal[:, start:end, :]
+            chunk_active = active_mask[:, start:end]
+
+            current = chunk_in
+            level_inputs: dict[str, torch.Tensor] = {}
+            for spec in self.config.cms_levels:
+                level_name = spec.name
+                level_inputs[level_name] = current
+                params = fast_state.cms_params[level_name]
+                current = call_with_deltas(self.cms.blocks[level_name], params, current)
+            outputs.append(current)
+
+            for spec in self.config.cms_levels:
+                level_name = spec.name
+                buffer = buffers[level_name]
+                if differentiable_updates:
+                    buffer.inputs.append(level_inputs[level_name])
+                else:
+                    buffer.inputs.append(level_inputs[level_name].detach())
+                buffer.teach.append(chunk_teach)
+                buffer.active.append(chunk_active)
+                buffer.count += end - start
+                update_period = int(spec.update_period)
+                while update_period > 0 and buffer.count >= update_period:
+                    # Pop under fresh names: reusing chunk_teach/chunk_active here
+                    # would clobber the slices still to be appended for later levels.
+                    pop_inputs, pop_teach, pop_active = _pop_buffer_chunk(
+                        buffer, update_period
+                    )
+                    buffer.count -= update_period
+                    magnitude = self._update_cms_chunk_fast(
+                        fast_state,
+                        level_name,
+                        pop_inputs,
+                        pop_teach,
+                        pop_active,
+                        surprise_value,
+                        differentiable_updates=differentiable_updates,
+                    )
+                    if magnitude > 0:
+                        stats[level_name]["grad_norm"] += magnitude
+                        stats[level_name]["chunk_tokens"] += float(update_period)
+                        stats[level_name]["gate_hit"] += 1.0
+                        stats[level_name]["gate_hits"] += 1.0
+                        stats[level_name]["updates_applied"] += 1.0
+        if finalize_updates:
+            if self.config.cms_flush_partial_at_end:
+                for spec in self.config.cms_levels:
+                    level_name = spec.name
+                    buffer = buffers[level_name]
+                    remaining = int(buffer.count)
+                    if remaining <= 0:
+                        continue
+                    pop_inputs, pop_teach, pop_active = _pop_buffer_chunk(buffer, remaining)
+                    buffer.count -= remaining
+                    if not bool(pop_active.any()):
+                        continue
+                    magnitude = self._update_cms_chunk_fast(
+                        fast_state,
+                        level_name,
+                        pop_inputs,
+                        pop_teach,
+                        pop_active,
+                        surprise_value,
+                        differentiable_updates=differentiable_updates,
+                    )
+                    if magnitude > 0:
+                        stats[level_name]["grad_norm"] += magnitude
+                        stats[level_name]["chunk_tokens"] += float(remaining)
+                        stats[level_name]["gate_hit"] += 1.0
+                        stats[level_name]["gate_hits"] += 1.0
+                        stats[level_name]["updates_applied"] += 1.0
+                        stats[level_name]["tokens_flushed"] += float(remaining)
+            for spec in self.config.cms_levels:
+                _clear_buffer(buffers[spec.name])
+        for spec in self.config.cms_levels:
+            stats[spec.name]["pending_tokens"] = float(buffers[spec.name].count)
+        for level_name, payload in stats.items():
+            if (
+                payload["updates_applied"] <= 0
+                and payload["pending_tokens"] <= 0
+                and payload["tokens_flushed"] <= 0
+            ):
+                continue
+            if surprise_value is not None:
+                payload["surprise_value"] = surprise_value
+            self.last_update_stats[f"cms.{level_name}"] = payload
+        return torch.cat(outputs, dim=1)
+
+    def _update_cms_fast(
+        self,
+        fast_state: BlockFastState,
+        cms_inputs: dict[str, torch.Tensor],
+        teach_signal: torch.Tensor,
+        surprise_value: float | None,
+        *,
+        differentiable_updates: bool = False,
+    ) -> None:
+        teach = teach_signal if differentiable_updates else teach_signal.detach()
+        active_mask = teach.abs().sum(dim=-1) > 0
+        for spec in self.config.cms_levels:
+            level_name = spec.name
+            if not self._is_level_allowed(level_name):
+                continue
+            if not self._passes_surprise(surprise_value):
+                self._record_gate(level_name, hit=False)
+                continue
+            inputs = cms_inputs[level_name]
+            seq_len = inputs.shape[1]
+            chunk_size = int(spec.update_period)
+            if chunk_size <= 0:
+                continue
+            total_norm = 0.0
+            update_events = 0
+            token_events = 0
+            for start in range(0, seq_len, chunk_size):
+                end = min(start + chunk_size, seq_len)
+                chunk_len = end - start
+                chunk_inputs = (
+                    inputs[:, start:end, :]
+                    if differentiable_updates
+                    else inputs[:, start:end, :].detach()
+                )
+                chunk_teach = teach[:, start:end, :]
+                chunk_active = active_mask[:, start:end]
+                if not bool(chunk_active.any()):
+                    continue
+                magnitude = self._update_cms_chunk_fast(
+                    fast_state,
+                    level_name,
+                    chunk_inputs,
+                    chunk_teach,
+                    chunk_active,
+                    surprise_value,
+                    differentiable_updates=differentiable_updates,
+                )
+                if magnitude <= 0:
+                    continue
+                total_norm += magnitude
+                token_events += chunk_len
+                update_events += 1
+            if update_events == 0:
+                continue
+            stats_payload: Dict[str, float] = {
+                "grad_norm": total_norm,
+                "chunk_tokens": float(token_events),
+                "gate_hit": float(update_events),
+            }
+            if surprise_value is not None:
+                stats_payload["surprise_value"] = surprise_value
+            self.last_update_stats[f"cms.{level_name}"] = stats_payload
+
+    def _is_level_allowed(self, level_name: str) -> bool:
+        if self.allowed_levels is None:
+            return True
+        return level_name in self.allowed_levels
+
+    def _passes_surprise(self, surprise_value: float | None) -> bool:
+        if self.surprise_threshold is None:
+            return True
+        if surprise_value is None:
+            return False
+        return surprise_value >= self.surprise_threshold
+
+    def _record_gate(self, level_name: str, *, hit: bool) -> None:
+        stats_key = f"gate.{level_name}"
+        self.last_update_stats.setdefault(stats_key, {})
+        self.last_update_stats[stats_key]["gate_hit"] = 1.0 if hit else 0.0
+
+    def _update_cms(
+        self,
+        cms_inputs: dict[str, torch.Tensor],
+        cms_outputs: dict[str, torch.Tensor],
+        teach_signal: torch.Tensor,
+        surprise_value: float | None,
+    ) -> None:
+        teach = teach_signal.detach()
+        active_mask = teach.abs().sum(dim=-1) > 0
+        for spec in self.config.cms_levels:
+            level_name = spec.name
+            if not self._is_level_allowed(level_name):
+                continue
+            if not self._passes_surprise(surprise_value):
+                self._record_gate(level_name, hit=False)
+                continue
+            inputs = cms_inputs[level_name]
+            seq_len = inputs.shape[1]
+            chunk_size = int(spec.update_period)
+            if chunk_size <= 0:
+                continue
+            total_norm = 0.0
+            update_events = 0
+            token_events = 0
+            for start in range(0, seq_len, chunk_size):
+                end = min(start + chunk_size, seq_len)
+                chunk_len = end - start
+                chunk_inputs = inputs[:, start:end, :].detach()
+                chunk_teach = teach[:, start:end, :]
+                chunk_active = active_mask[:, start:end]
+                if not bool(chunk_active.any()):
+                    continue
+                magnitude = self._update_cms_chunk(
+                    level_name,
+                    chunk_inputs,
+                    chunk_teach,
+                    chunk_active,
+                    surprise_value,
+                )
+                if magnitude <= 0:
+                    continue
+                total_norm += magnitude
+                token_events += chunk_len
+                update_events += 1
+            if update_events == 0:
+                continue
+            stats_payload: Dict[str, float] = {
+                "grad_norm": total_norm,
+                "chunk_tokens": float(token_events),
+                "gate_hit": float(update_events),
+            }
+            if surprise_value is not None:
+                stats_payload["surprise_value"] = surprise_value
+            self.last_update_stats[f"cms.{level_name}"] = stats_payload
+
+    def _update_cms_chunk(
+        self,
+        level_name: str,
+        chunk_inputs: torch.Tensor,
+        chunk_teach: torch.Tensor,
+        chunk_active: torch.Tensor,
+        surprise_value: float | None,
+    ) -> float:
+        if not self._is_level_allowed(level_name):
+            return 0.0
+        if not self._passes_surprise(surprise_value):
+            self._record_gate(level_name, hit=False)
+            return 0.0
+        mask_f = chunk_active.unsqueeze(-1).float()
+        with torch.enable_grad():
+            prediction = self.cms.blocks[level_name](chunk_inputs)
+            loss = _chunk_loss(
+                prediction,
+                chunk_teach,
+                mask_f,
+                reduction=self.config.cms_chunk_reduction,
+                differentiable_target=False,
+            )
+        context_vec = chunk_inputs.mean(dim=(0, 1))
+        magnitude = self.level_manager.optimize(
+            level_name,
+            self.cms.blocks[level_name],
+            loss,
+            context=context_vec,
+            force=True,
+        )
+        self.level_manager.pop_last_metrics(level_name)
+        return magnitude
+
+    def _update_cms_chunk_fast(
+        self,
+        fast_state: BlockFastState,
+        level_name: str,
+        chunk_inputs: torch.Tensor,
+        chunk_teach: torch.Tensor,
+        chunk_active: torch.Tensor,
+        surprise_value: float | None,
+        *,
+        differentiable_updates: bool = False,
+    ) -> float:
+        if not self._is_level_allowed(level_name):
+            return 0.0
+        if not self._passes_surprise(surprise_value):
+            self._record_gate(level_name, hit=False)
+            return 0.0
+        mask_f = chunk_active.unsqueeze(-1).float()
+        base_params = fast_state.cms_params[level_name]
+        forward_params = params_with_deltas(self.cms.blocks[level_name], base_params)
+        params_req = require_grad_params(forward_params, detach=not differentiable_updates)
+        with torch.enable_grad():
+            prediction = call_with_params(self.cms.blocks[level_name], params_req, chunk_inputs)
+            loss = _chunk_loss(
+                prediction,
+                chunk_teach,
+                mask_f,
+                reduction=self.config.cms_chunk_reduction,
+                differentiable_target=differentiable_updates,
+            )
+            grads = torch.autograd.grad(
+                loss,
+                tuple(params_req.values()),
+                retain_graph=differentiable_updates,
+                create_graph=differentiable_updates,
+                allow_unused=True,
+            )
+        grads_dict = grads_to_dict(params_req, grads)
+        context_vec = chunk_inputs.mean(dim=(0, 1))
+        updated, magnitude = fast_state.level_manager.apply_grads(
+            level_name,
+            base_params,
+            grads_dict,
+            context=context_vec,
+            force=True,
+            differentiable=differentiable_updates,
+        )
+        fast_state.cms_params[level_name] = updated
+        fast_state.level_manager.pop_last_metrics(level_name)
+        return magnitude
+
+
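A smoke-test sketch of the two slow-path modes of `HOPEAttentionBlock.forward` (no `fast_state`), assuming the `LevelSpec` construction above and that the referenced `SelfAttention`/`CMS` modules work with their default wiring:

import torch
from nested_learning.hope.block import HOPEAttentionBlock, HOPEAttentionBlockConfig
from nested_learning.levels import LevelSpec

cfg = HOPEAttentionBlockConfig(dim=64, heads=4, cms_levels=[LevelSpec(name="l0", update_period=8)])
block = HOPEAttentionBlock(cfg)
x = torch.randn(2, 16, 64)
y = block(x)                            # read-only: attention -> CMS, no inner updates
teach = torch.randn(2, 16, 64)
y2 = block(x, teach_signal=teach)       # online: chunked CMS updates while streaming
print(block.pop_update_stats())         # per-level grad norms, gate hits, pending tokens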
+@dataclass
+class HOPESelfModBlockConfig:
+    dim: int
+    cms_levels: Sequence[LevelSpec]
+    cms_hidden_multiplier: int = 4
+    cms_use_layernorm: bool = True
+    activation: str = "gelu"
+    qk_l2_norm: bool = True
+    cms_flush_partial_at_end: bool = False
+    selfmod_adaptive_q: bool = False
+    selfmod_local_conv_window: int | None = 4
+    eta_scale: float = 1e-3
+    selfmod_chunk_size: int = 1
+    selfmod_chunk_size_memory: int | None = None
+    selfmod_objective: str = "l2"
+    selfmod_stopgrad_vhat: bool = True
+    selfmod_use_rank1_precond: bool = True
+    selfmod_use_alpha: bool = True
+    selfmod_use_skip: bool = True
+    selfmod_momentum: float = 0.0
+    selfmod_online_updates: bool = True
+    self_mod_lr: float = 1e-3
+    cms_chunk_reduction: str = "sum"
+    cms_online_updates: bool = True
+    optimizer_configs: Dict[str, dict] = field(default_factory=dict)
+
+
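Judging by their names, the ablation YAMLs shipped in this wheel (`configs/ablations/selfmod_momentum_on.yaml`, `selfmod_no_alpha.yaml`, `selfmod_rank1_precond_off.yaml`, `selfmod_chunked_8_64.yaml`) line up with the `selfmod_*` knobs above. For example, a momentum-on, alpha-off variant (again assuming `LevelSpec(name=..., update_period=...)`):

from nested_learning.hope.block import HOPESelfModBlockConfig
from nested_learning.levels import LevelSpec

cfg = HOPESelfModBlockConfig(
    dim=128,
    cms_levels=[LevelSpec(name="l0", update_period=8)],
    selfmod_momentum=0.9,        # cf. selfmod_momentum_on ablation
    selfmod_use_alpha=False,     # cf. selfmod_no_alpha ablation
)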
+class HOPESelfModBlock(nn.Module):
+    """
+    Paper-defined HOPE block (Eqs. 94–97): self-modifying Titans followed by CMS.
+
+    Fast-state is required for in-context self-mod updates.
+    """
+
+    def __init__(self, config: HOPESelfModBlockConfig):
+        super().__init__()
+        self.config = config
+        self.last_update_stats: Dict[str, Dict[str, float]] = {}
+        self.surprise_threshold: float | None = None
+        self.surprise_metric: str = "l2"
+        self.allowed_levels: Set[str] | None = None
+        self.selfmod = SelfModifyingTitans(
+            SelfModifyingTitansConfig(
+                dim=config.dim,
+                eta_scale=config.eta_scale,
+                chunk_size_other=config.selfmod_chunk_size,
+                chunk_size_memory=config.selfmod_chunk_size_memory,
+                objective=config.selfmod_objective,
+                stopgrad_vhat=config.selfmod_stopgrad_vhat,
+                use_rank1_precond=config.selfmod_use_rank1_precond,
+                use_alpha=config.selfmod_use_alpha,
+                use_skip=config.selfmod_use_skip,
+                momentum=config.selfmod_momentum,
+                qk_l2_norm=config.qk_l2_norm,
+                adaptive_q=config.selfmod_adaptive_q,
+                local_conv_window=config.selfmod_local_conv_window,
+            )
+        )
+        self.cms = CMS(
+            dim=config.dim,
+            levels=config.cms_levels,
+            hidden_multiplier=config.cms_hidden_multiplier,
+            activation=config.activation,
+            use_layernorm=config.cms_use_layernorm,
+        )
+        level_config = LevelConfig(
+            specs=config.cms_levels,
+            optimizer_configs=config.optimizer_configs,
+            default_lr=config.self_mod_lr,
+        )
+        self.level_manager = LevelOptimizerManager(level_config)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        *,
+        teach_signal: torch.Tensor | None = None,
+        surprise_value: float | None = None,
+        fast_state: BlockFastState | None = None,
+        finalize_updates: bool = True,
+        attention_cache: AttentionKVCache | None = None,
+        return_attention_cache: bool = False,
+        differentiable_updates: bool = False,
+    ) -> torch.Tensor | tuple[torch.Tensor, AttentionKVCache | None]:
+        _ = (attention_cache, differentiable_updates)
+        if fast_state is None:
+            # Differentiable read path (used for the outer loss).
+            o = self.selfmod(x)
+            # Explicit update pass (typically called under `torch.no_grad()` after backward).
+            if teach_signal is not None and self.config.selfmod_online_updates:
+                self.selfmod.apply_updates_inplace(x)
+            if teach_signal is not None and self.config.cms_online_updates:
+                cms_out = self._cms_forward_online(
+                    o,
+                    teach_signal,
+                    surprise_value,
+                    finalize_updates=finalize_updates,
+                )
+            else:
+                cms_out, cms_inputs, cms_outputs = self.cms(o, return_intermediates=True)
+                if teach_signal is not None:
+                    self._update_cms(cms_inputs, cms_outputs, teach_signal, surprise_value)
+            self.level_manager.tick()
+            return cms_out
+
+        if fast_state.selfmod_state is None:
+            raise ValueError("fast_state.selfmod_state is required for hope_selfmod variant")
+        if self.config.selfmod_online_updates and teach_signal is not None:
+            o, updated = self.selfmod.forward_with_updates(x, fast_state.selfmod_state)
+            fast_state.selfmod_state = updated
+        else:
+            o = self.selfmod.forward_with_state(x, fast_state.selfmod_state)
+        if teach_signal is not None and self.config.cms_online_updates:
+            cms_out = self._cms_forward_online_fast(
+                o,
+                fast_state,
+                teach_signal,
+                surprise_value,
+                finalize_updates=finalize_updates,
+            )
+        else:
+            cms_out, cms_inputs = self._cms_forward_fast(o, fast_state)
+            if teach_signal is not None:
+                self._update_cms_fast(fast_state, cms_inputs, teach_signal, surprise_value)
+        fast_state.level_manager.tick()
+        if return_attention_cache:
+            return cms_out, None
+        return cms_out
+
+    def set_surprise_threshold(self, threshold: float | None) -> None:
+        self.surprise_threshold = threshold
+
+    def set_surprise_metric(self, metric: str) -> None:
+        self.surprise_metric = str(metric).strip().lower()
+
+    def set_allowed_levels(self, allowed: Set[str] | None) -> None:
+        self.allowed_levels = allowed.copy() if allowed is not None else None
+
+    def pop_update_stats(self) -> Dict[str, Dict[str, float]]:
+        stats = self.last_update_stats
+        self.last_update_stats = {}
+        return stats
+
+    def _cms_forward_fast(
+        self,
+        x: torch.Tensor,
+        fast_state: BlockFastState,
+    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+        current = x
+        inputs: dict[str, torch.Tensor] = {}
+        for spec in self.config.cms_levels:
+            level_name = spec.name
+            inputs[level_name] = current
+            params = fast_state.cms_params[level_name]
+            current = call_with_deltas(self.cms.blocks[level_name], params, current)
+        return current, inputs
+
+    def _cms_forward_online(
+        self,
+        x: torch.Tensor,
+        teach_signal: torch.Tensor,
+        surprise_value: float | None,
+        *,
+        finalize_updates: bool = True,
+    ) -> torch.Tensor:
+        seq_len = x.shape[1]
+        base_chunk = _min_update_period(self.config.cms_levels)
+        active_mask = teach_signal.detach().abs().sum(dim=-1) > 0
+        outputs: list[torch.Tensor] = []
+        stats: dict[str, Dict[str, float]] = {}
+        buffers: dict[str, _CmsBuffer] = {}
+        for spec in self.config.cms_levels:
+            buffers[spec.name] = _CmsBuffer(inputs=[], teach=[], active=[], count=0)
+            stats[spec.name] = {
+                "grad_norm": 0.0,
+                "chunk_tokens": 0.0,
+                "gate_hit": 0.0,
+                "gate_hits": 0.0,
+                "updates_applied": 0.0,
+                "tokens_flushed": 0.0,
+                "pending_tokens": 0.0,
+            }
+
+        for start in range(0, seq_len, base_chunk):
+            end = min(start + base_chunk, seq_len)
+            chunk_in = x[:, start:end, :]
+            chunk_teach = teach_signal[:, start:end, :]
+            chunk_active = active_mask[:, start:end]
+
+            current = chunk_in
+            level_inputs: dict[str, torch.Tensor] = {}
+            for spec in self.config.cms_levels:
+                level_name = spec.name
+                level_inputs[level_name] = current
+                current = self.cms.blocks[level_name](current)
+            outputs.append(current)
+
+            for spec in self.config.cms_levels:
+                level_name = spec.name
+                buffer = buffers[level_name]
+                buffer.inputs.append(level_inputs[level_name].detach())
+                buffer.teach.append(chunk_teach)
+                buffer.active.append(chunk_active)
+                buffer.count += end - start
+                update_period = int(spec.update_period)
+                while update_period > 0 and buffer.count >= update_period:
+                    # Pop under fresh names: reusing chunk_teach/chunk_active here
+                    # would clobber the slices still to be appended for later levels.
+                    pop_inputs, pop_teach, pop_active = _pop_buffer_chunk(
+                        buffer, update_period
+                    )
+                    buffer.count -= update_period
+                    magnitude = self._update_cms_chunk(
+                        level_name,
+                        pop_inputs,
+                        pop_teach,
+                        pop_active,
+                        surprise_value,
+                    )
+                    if magnitude > 0:
+                        stats[level_name]["grad_norm"] += magnitude
+                        stats[level_name]["chunk_tokens"] += float(update_period)
+                        stats[level_name]["gate_hit"] += 1.0
+                        stats[level_name]["gate_hits"] += 1.0
+                        stats[level_name]["updates_applied"] += 1.0
+        if self.config.cms_flush_partial_at_end and finalize_updates:
+            for spec in self.config.cms_levels:
+                level_name = spec.name
+                buffer = buffers[level_name]
+                remaining = int(buffer.count)
+                if remaining <= 0:
+                    continue
+                pop_inputs, pop_teach, pop_active = _pop_buffer_chunk(buffer, remaining)
+                buffer.count -= remaining
+                if not bool(pop_active.any()):
+                    continue
+                magnitude = self._update_cms_chunk(
+                    level_name,
+                    pop_inputs,
+                    pop_teach,
+                    pop_active,
+                    surprise_value,
+                )
+                if magnitude > 0:
+                    stats[level_name]["grad_norm"] += magnitude
+                    stats[level_name]["chunk_tokens"] += float(remaining)
+                    stats[level_name]["gate_hit"] += 1.0
+                    stats[level_name]["gate_hits"] += 1.0
+                    stats[level_name]["updates_applied"] += 1.0
+                    stats[level_name]["tokens_flushed"] += float(remaining)
+        for spec in self.config.cms_levels:
+            stats[spec.name]["pending_tokens"] = float(buffers[spec.name].count)
+        for level_name, payload in stats.items():
+            if (
+                payload["updates_applied"] <= 0
+                and payload["pending_tokens"] <= 0
+                and payload["tokens_flushed"] <= 0
+            ):
+                continue
+            if surprise_value is not None:
+                payload["surprise_value"] = surprise_value
+            self.last_update_stats[f"cms.{level_name}"] = payload
+        return torch.cat(outputs, dim=1)
+
+    def _cms_forward_online_fast(
+        self,
+        x: torch.Tensor,
+        fast_state: BlockFastState,
+        teach_signal: torch.Tensor,
+        surprise_value: float | None,
+        *,
+        finalize_updates: bool = True,
+    ) -> torch.Tensor:
+        seq_len = x.shape[1]
+        base_chunk = _min_update_period(self.config.cms_levels)
+        active_mask = teach_signal.detach().abs().sum(dim=-1) > 0
+        outputs: list[torch.Tensor] = []
+        stats: dict[str, Dict[str, float]] = {}
+        buffers = _fast_state_buffers(fast_state, self.config.cms_levels)
+        for spec in self.config.cms_levels:
+            stats[spec.name] = {
+                "grad_norm": 0.0,
+                "chunk_tokens": 0.0,
+                "gate_hit": 0.0,
+                "gate_hits": 0.0,
+                "updates_applied": 0.0,
+                "tokens_flushed": 0.0,
+                "pending_tokens": 0.0,
+            }
+
+        for start in range(0, seq_len, base_chunk):
+            end = min(start + base_chunk, seq_len)
+            chunk_in = x[:, start:end, :]
+            chunk_teach = teach_signal[:, start:end, :]
+            chunk_active = active_mask[:, start:end]
+
+            current = chunk_in
+            level_inputs: dict[str, torch.Tensor] = {}
+            for spec in self.config.cms_levels:
+                level_name = spec.name
+                level_inputs[level_name] = current
+                params = fast_state.cms_params[level_name]
+                current = call_with_deltas(self.cms.blocks[level_name], params, current)
+            outputs.append(current)
+
+            for spec in self.config.cms_levels:
+                level_name = spec.name
+                buffer = buffers[level_name]
+                buffer.inputs.append(level_inputs[level_name].detach())
+                buffer.teach.append(chunk_teach)
+                buffer.active.append(chunk_active)
+                buffer.count += end - start
+                update_period = int(spec.update_period)
+                while update_period > 0 and buffer.count >= update_period:
+                    # Pop under fresh names: reusing chunk_teach/chunk_active here
+                    # would clobber the slices still to be appended for later levels.
+                    pop_inputs, pop_teach, pop_active = _pop_buffer_chunk(
+                        buffer, update_period
+                    )
+                    buffer.count -= update_period
+                    magnitude = self._update_cms_chunk_fast(
+                        fast_state,
+                        level_name,
+                        pop_inputs,
+                        pop_teach,
+                        pop_active,
+                        surprise_value,
+                    )
+                    if magnitude > 0:
+                        stats[level_name]["grad_norm"] += magnitude
+                        stats[level_name]["chunk_tokens"] += float(update_period)
+                        stats[level_name]["gate_hit"] += 1.0
+                        stats[level_name]["gate_hits"] += 1.0
+                        stats[level_name]["updates_applied"] += 1.0
+        if finalize_updates:
+            if self.config.cms_flush_partial_at_end:
+                for spec in self.config.cms_levels:
+                    level_name = spec.name
+                    buffer = buffers[level_name]
+                    remaining = int(buffer.count)
+                    if remaining <= 0:
+                        continue
+                    pop_inputs, pop_teach, pop_active = _pop_buffer_chunk(buffer, remaining)
+                    buffer.count -= remaining
+                    if not bool(pop_active.any()):
+                        continue
+                    magnitude = self._update_cms_chunk_fast(
+                        fast_state,
+                        level_name,
+                        pop_inputs,
+                        pop_teach,
+                        pop_active,
+                        surprise_value,
+                    )
+                    if magnitude > 0:
+                        stats[level_name]["grad_norm"] += magnitude
+                        stats[level_name]["chunk_tokens"] += float(remaining)
+                        stats[level_name]["gate_hit"] += 1.0
+                        stats[level_name]["gate_hits"] += 1.0
+                        stats[level_name]["updates_applied"] += 1.0
+                        stats[level_name]["tokens_flushed"] += float(remaining)
+            for spec in self.config.cms_levels:
+                _clear_buffer(buffers[spec.name])
+        for spec in self.config.cms_levels:
+            stats[spec.name]["pending_tokens"] = float(buffers[spec.name].count)
+        for level_name, payload in stats.items():
+            if (
+                payload["updates_applied"] <= 0
+                and payload["pending_tokens"] <= 0
+                and payload["tokens_flushed"] <= 0
+            ):
+                continue
+            if surprise_value is not None:
+                payload["surprise_value"] = surprise_value
+            self.last_update_stats[f"cms.{level_name}"] = payload
+        return torch.cat(outputs, dim=1)
+
+    def _is_level_allowed(self, level_name: str) -> bool:
+        if self.allowed_levels is None:
+            return True
+        return level_name in self.allowed_levels
+
+    def _passes_surprise(self, surprise_value: float | None) -> bool:
+        if self.surprise_threshold is None:
+            return True
+        if surprise_value is None:
+            return False
+        return surprise_value >= self.surprise_threshold
+
+    def _record_gate(self, level_name: str, *, hit: bool) -> None:
+        stats_key = f"gate.{level_name}"
+        self.last_update_stats.setdefault(stats_key, {})
+        self.last_update_stats[stats_key]["gate_hit"] = 1.0 if hit else 0.0
+
+    def _update_cms(
+        self,
+        cms_inputs: dict[str, torch.Tensor],
+        cms_outputs: dict[str, torch.Tensor],
+        teach_signal: torch.Tensor,
+        surprise_value: float | None,
+    ) -> None:
+        teach = teach_signal.detach()
+        active_mask = teach.abs().sum(dim=-1) > 0
+        for spec in self.config.cms_levels:
+            level_name = spec.name
+            if not self._is_level_allowed(level_name):
+                continue
+            if not self._passes_surprise(surprise_value):
+                self._record_gate(level_name, hit=False)
+                continue
+            inputs = cms_inputs[level_name]
+            seq_len = inputs.shape[1]
+            chunk_size = int(spec.update_period)
+            if chunk_size <= 0:
+                continue
+            total_norm = 0.0
+            update_events = 0
+            token_events = 0
+            for start in range(0, seq_len, chunk_size):
+                end = min(start + chunk_size, seq_len)
+                chunk_len = end - start
+                chunk_inputs = inputs[:, start:end, :].detach()
+                chunk_teach = teach[:, start:end, :]
+                chunk_active = active_mask[:, start:end]
+                if not bool(chunk_active.any()):
+                    continue
+                magnitude = self._update_cms_chunk(
+                    level_name,
+                    chunk_inputs,
+                    chunk_teach,
+                    chunk_active,
+                    surprise_value,
+                )
+                if magnitude <= 0:
+                    continue
+                total_norm += magnitude
+                token_events += chunk_len
+                update_events += 1
+            if update_events == 0:
+                continue
+            stats_payload: Dict[str, float] = {
+                "grad_norm": total_norm,
+                "chunk_tokens": float(token_events),
+                "gate_hit": float(update_events),
+            }
+            if surprise_value is not None:
+                stats_payload["surprise_value"] = surprise_value
+            self.last_update_stats[f"cms.{level_name}"] = stats_payload
+
+    def _update_cms_fast(
+        self,
+        fast_state: BlockFastState,
+        cms_inputs: dict[str, torch.Tensor],
+        teach_signal: torch.Tensor,
+        surprise_value: float | None,
+    ) -> None:
+        teach = teach_signal.detach()
+        active_mask = teach.abs().sum(dim=-1) > 0
+        for spec in self.config.cms_levels:
+            level_name = spec.name
+            if not self._is_level_allowed(level_name):
+                continue
+            if not self._passes_surprise(surprise_value):
+                self._record_gate(level_name, hit=False)
+                continue
+            inputs = cms_inputs[level_name]
+            seq_len = inputs.shape[1]
+            chunk_size = int(spec.update_period)
+            if chunk_size <= 0:
+                continue
+            total_norm = 0.0
+            update_events = 0
+            token_events = 0
+            for start in range(0, seq_len, chunk_size):
+                end = min(start + chunk_size, seq_len)
+                chunk_len = end - start
+                chunk_inputs = inputs[:, start:end, :].detach()
+                chunk_teach = teach[:, start:end, :]
+                chunk_active = active_mask[:, start:end]
+                if not bool(chunk_active.any()):
+                    continue
+                magnitude = self._update_cms_chunk_fast(
+                    fast_state,
+                    level_name,
+                    chunk_inputs,
+                    chunk_teach,
+                    chunk_active,
+                    surprise_value,
+                )
+                if magnitude <= 0:
+                    continue
+                total_norm += magnitude
+                token_events += chunk_len
+                update_events += 1
+            if update_events == 0:
+                continue
+            stats_payload: Dict[str, float] = {
+                "grad_norm": total_norm,
+                "chunk_tokens": float(token_events),
+                "gate_hit": float(update_events),
+            }
+            if surprise_value is not None:
+                stats_payload["surprise_value"] = surprise_value
+            self.last_update_stats[f"cms.{level_name}"] = stats_payload
+
+    def _update_cms_chunk(
+        self,
+        level_name: str,
+        chunk_inputs: torch.Tensor,
+        chunk_teach: torch.Tensor,
+        chunk_active: torch.Tensor,
+        surprise_value: float | None,
+    ) -> float:
+        if not self._is_level_allowed(level_name):
+            return 0.0
+        if not self._passes_surprise(surprise_value):
+            self._record_gate(level_name, hit=False)
+            return 0.0
+        mask_f = chunk_active.unsqueeze(-1).float()
+        with torch.enable_grad():
+            prediction = self.cms.blocks[level_name](chunk_inputs)
+            loss = _chunk_loss(
+                prediction,
+                chunk_teach,
+                mask_f,
+                reduction=self.config.cms_chunk_reduction,
+            )
+        context_vec = chunk_inputs.mean(dim=(0, 1))
+        magnitude = self.level_manager.optimize(
+            level_name,
+            self.cms.blocks[level_name],
+            loss,
+            context=context_vec,
+            force=True,
+        )
+        self.level_manager.pop_last_metrics(level_name)
+        return magnitude
+
+    def _update_cms_chunk_fast(
+        self,
+        fast_state: BlockFastState,
+        level_name: str,
+        chunk_inputs: torch.Tensor,
+        chunk_teach: torch.Tensor,
+        chunk_active: torch.Tensor,
+        surprise_value: float | None,
+    ) -> float:
+        if not self._is_level_allowed(level_name):
+            return 0.0
+        if not self._passes_surprise(surprise_value):
+            self._record_gate(level_name, hit=False)
+            return 0.0
+        mask_f = chunk_active.unsqueeze(-1).float()
+        base_params = fast_state.cms_params[level_name]
+        forward_params = params_with_deltas(self.cms.blocks[level_name], base_params)
+        params_req = require_grad_params(forward_params)
+        with torch.enable_grad():
+            prediction = call_with_params(self.cms.blocks[level_name], params_req, chunk_inputs)
+            loss = _chunk_loss(
+                prediction,
+                chunk_teach,
+                mask_f,
+                reduction=self.config.cms_chunk_reduction,
+            )
+            grads = torch.autograd.grad(
+                loss,
+                tuple(params_req.values()),
+                retain_graph=False,
+                allow_unused=True,
+            )
+        grads_dict = grads_to_dict(params_req, grads)
+        context_vec = chunk_inputs.mean(dim=(0, 1))
+        updated, magnitude = fast_state.level_manager.apply_grads(
+            level_name,
+            base_params,
+            grads_dict,
+            context=context_vec,
+            force=True,
+        )
+        fast_state.cms_params[level_name] = updated
+        fast_state.level_manager.pop_last_metrics(level_name)
+        return magnitude
+
+
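Both fast-path update methods above share one pattern: module weights are never mutated in place; the chunk loss is differentiated with `torch.autograd.grad` against a functional parameter view (base weights plus fast deltas), and the optimizer's result is written back into `fast_state.cms_params`. A toy mirror of that pattern, independent of the package's `params_with_deltas`/`apply_grads` helpers:

import torch
import torch.nn as nn
import torch.nn.functional as F

net = nn.Linear(4, 4)  # frozen base weights
deltas = {k: torch.zeros_like(v) for k, v in net.named_parameters()}

x, teach = torch.randn(2, 4), torch.randn(2, 4)
# Functional view: effective params = detached base + fast deltas.
params = {k: (v.detach() + deltas[k]).requires_grad_(True) for k, v in net.named_parameters()}
pred = F.linear(x, params["weight"], params["bias"])
target = (pred.detach() - teach).detach()  # same pseudo-target as _chunk_loss
loss = (pred - target).pow(2).sum()
grads = torch.autograd.grad(loss, tuple(params.values()))
# One SGD step on the deltas; net.parameters() stay untouched.
deltas = {k: d - 1e-3 * g for (k, d), g in zip(deltas.items(), grads)}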
+class HOPEBlock(nn.Module):
+    def __init__(self, config: HOPEBlockConfig):
+        super().__init__()
+        self.config = config
+        self.last_update_stats: Dict[str, Dict[str, float]] = {}
+        self.surprise_threshold: float | None = None
+        self.surprise_metric: str = "l2"
+        self.allowed_levels: Set[str] | None = None
+        self.attn = SelfAttention(
+            AttentionConfig(
+                dim=config.dim,
+                heads=config.heads,
+                qk_l2_norm=config.qk_l2_norm,
+                local_conv_window=config.local_conv_window,
+            )
+        )
+        titan_config = TitanMemoryConfig(
+            dim=config.dim,
+            hidden_multiplier=config.titan_hidden_multiplier,
+            activation=config.activation,
+        )
+        self.titan_memory = TitanMemory(titan_config)
+        self.cms = CMS(
+            dim=config.dim,
+            levels=config.cms_levels,
+            hidden_multiplier=config.cms_hidden_multiplier,
+            activation=config.activation,
+            use_layernorm=config.cms_use_layernorm,
+        )
+        self.self_modifier = SelfModifier(config.dim, hidden_multiplier=config.self_mod_hidden)
+        self.dropout = nn.Dropout(0.0)
+        specs = [config.titan_level, *config.cms_levels]
+        level_config = LevelConfig(
+            specs=specs,
+            optimizer_configs=config.optimizer_configs,
+            default_lr=config.self_mod_lr,
+        )
+        self.level_manager = LevelOptimizerManager(level_config)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        *,
+        teach_signal: torch.Tensor | None = None,
+        surprise_value: float | None = None,
+        fast_state: BlockFastState | None = None,
+        finalize_updates: bool = True,
+        attention_cache: AttentionKVCache | None = None,
+        return_attention_cache: bool = False,
+        differentiable_updates: bool = False,
+    ) -> torch.Tensor | tuple[torch.Tensor, AttentionKVCache]:
+        _ = differentiable_updates
+        next_attn_cache: AttentionKVCache | None = None
+        if return_attention_cache:
+            attn_out, next_attn_cache = self.attn(
+                x,
+                kv_cache=attention_cache,
+                return_kv_cache=True,
+            )
+        else:
+            attn_out = self.attn(x, kv_cache=attention_cache)
+        if fast_state is None:
+            mem_out = self.titan_memory(attn_out)
+            combined = attn_out + mem_out
+            if teach_signal is not None and self.config.cms_online_updates:
+                cms_out = self._cms_forward_online(
+                    combined,
+                    teach_signal,
+                    surprise_value,
+                    finalize_updates=finalize_updates,
+                )
+                self._update_titan(attn_out, mem_out, teach_signal, surprise_value)
+            else:
+                cms_result = self.cms(combined, return_intermediates=True)
+                cms_out, cms_inputs, cms_outputs = cms_result
+                if teach_signal is not None:
+                    self._update_titan(attn_out, mem_out, teach_signal, surprise_value)
+                    self._update_cms(cms_inputs, cms_outputs, teach_signal, surprise_value)
+            self.level_manager.tick()
+            return cms_out
+
+        if fast_state.titan_params is None:
+            raise ValueError("fast_state.titan_params is required for HOPEBlock fast-state forward")
+        mem_out = call_with_deltas(self.titan_memory, fast_state.titan_params, attn_out)
+        combined = attn_out + mem_out
+        if teach_signal is not None and self.config.cms_online_updates:
+            cms_out = self._cms_forward_online_fast(
+                combined,
+                fast_state,
+                teach_signal,
+                surprise_value,
+                finalize_updates=finalize_updates,
+            )
+            self._update_titan_fast(fast_state, attn_out, mem_out, teach_signal, surprise_value)
+        else:
+            cms_out, cms_inputs = self._cms_forward_fast(combined, fast_state)
+            if teach_signal is not None:
+                self._update_titan_fast(fast_state, attn_out, mem_out, teach_signal, surprise_value)
+                self._update_cms_fast(fast_state, cms_inputs, teach_signal, surprise_value)
+        fast_state.level_manager.tick()
+        if return_attention_cache:
+            assert next_attn_cache is not None
+            return cms_out, next_attn_cache
+        return cms_out
+
+    def set_surprise_threshold(self, threshold: float | None) -> None:
+        self.surprise_threshold = threshold
+
+    def set_surprise_metric(self, metric: str) -> None:
+        self.surprise_metric = str(metric).strip().lower()
+
+    def set_allowed_levels(self, allowed: Set[str] | None) -> None:
+        self.allowed_levels = allowed.copy() if allowed is not None else None
+
+    def _cms_forward_fast(
+        self,
+        x: torch.Tensor,
+        fast_state: BlockFastState,
+    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+        current = x
+        inputs: dict[str, torch.Tensor] = {}
+        for spec in self.config.cms_levels:
+            level_name = spec.name
+            inputs[level_name] = current
+            params = fast_state.cms_params[level_name]
+            current = call_with_deltas(self.cms.blocks[level_name], params, current)
+        return current, inputs
+
+
+    def _cms_forward_online(
+        self,
+        x: torch.Tensor,
+        teach_signal: torch.Tensor,
+        surprise_value: float | None,
+        *,
+        finalize_updates: bool = True,
+    ) -> torch.Tensor:
+        seq_len = x.shape[1]
+        base_chunk = _min_update_period(self.config.cms_levels)
+        active_mask = teach_signal.detach().abs().sum(dim=-1) > 0
+        outputs: list[torch.Tensor] = []
+        stats: dict[str, Dict[str, float]] = {}
+        buffers: dict[str, _CmsBuffer] = {}
+        for spec in self.config.cms_levels:
+            buffers[spec.name] = _CmsBuffer(inputs=[], teach=[], active=[], count=0)
+            stats[spec.name] = {
+                "grad_norm": 0.0,
+                "chunk_tokens": 0.0,
+                "gate_hit": 0.0,
+                "gate_hits": 0.0,
+                "updates_applied": 0.0,
+                "tokens_flushed": 0.0,
+                "pending_tokens": 0.0,
+            }
+
+        for start in range(0, seq_len, base_chunk):
+            end = min(start + base_chunk, seq_len)
+            chunk_in = x[:, start:end, :]
+            chunk_teach = teach_signal[:, start:end, :]
+            chunk_active = active_mask[:, start:end]
+
+            current = chunk_in
+            level_inputs: dict[str, torch.Tensor] = {}
+            for spec in self.config.cms_levels:
+                level_name = spec.name
+                level_inputs[level_name] = current
+                current = self.cms.blocks[level_name](current)
+            outputs.append(current)
+
+            for spec in self.config.cms_levels:
+                level_name = spec.name
+                buffer = buffers[level_name]
+                buffer.inputs.append(level_inputs[level_name].detach())
+                buffer.teach.append(chunk_teach)
+                buffer.active.append(chunk_active)
+                buffer.count += end - start
+                update_period = int(spec.update_period)
+                while update_period > 0 and buffer.count >= update_period:
+                    # Pop into fresh names so chunk_teach/chunk_active stay
+                    # intact for the levels that follow in this chunk.
+                    popped_inputs, popped_teach, popped_active = _pop_buffer_chunk(
+                        buffer, update_period
+                    )
+                    buffer.count -= update_period
+                    magnitude = self._update_cms_chunk(
+                        level_name,
+                        popped_inputs,
+                        popped_teach,
+                        popped_active,
+                        surprise_value,
+                    )
+                    if magnitude > 0:
+                        stats[level_name]["grad_norm"] += magnitude
+                        stats[level_name]["chunk_tokens"] += float(update_period)
+                        stats[level_name]["gate_hit"] += 1.0
+                        stats[level_name]["gate_hits"] += 1.0
+                        stats[level_name]["updates_applied"] += 1.0
+        if self.config.cms_flush_partial_at_end and finalize_updates:
+            for spec in self.config.cms_levels:
+                level_name = spec.name
+                buffer = buffers[level_name]
+                remaining = int(buffer.count)
+                if remaining <= 0:
+                    continue
+                popped_inputs, popped_teach, popped_active = _pop_buffer_chunk(buffer, remaining)
+                buffer.count -= remaining
+                if not bool(popped_active.any()):
+                    continue
+                magnitude = self._update_cms_chunk(
+                    level_name,
+                    popped_inputs,
+                    popped_teach,
+                    popped_active,
+                    surprise_value,
+                )
+                if magnitude > 0:
+                    stats[level_name]["grad_norm"] += magnitude
+                    stats[level_name]["chunk_tokens"] += float(remaining)
+                    stats[level_name]["gate_hit"] += 1.0
+                    stats[level_name]["gate_hits"] += 1.0
+                    stats[level_name]["updates_applied"] += 1.0
+                    stats[level_name]["tokens_flushed"] += float(remaining)
+        for spec in self.config.cms_levels:
+            stats[spec.name]["pending_tokens"] = float(buffers[spec.name].count)
+        for level_name, payload in stats.items():
+            if (
+                payload["updates_applied"] <= 0
+                and payload["pending_tokens"] <= 0
+                and payload["tokens_flushed"] <= 0
+            ):
+                continue
+            if surprise_value is not None:
+                payload["surprise_value"] = surprise_value
+            self.last_update_stats[f"cms.{level_name}"] = payload
+        return torch.cat(outputs, dim=1)
+
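+    # Fast-state variant of the online CMS forward: the same chunked update
+    # schedule, but forward passes go through `call_with_deltas` with the
+    # per-level parameter deltas in `fast_state`, buffers persist inside
+    # `fast_state` across calls, and they are cleared once `finalize_updates`
+    # is set.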
+    def _cms_forward_online_fast(
+        self,
+        x: torch.Tensor,
+        fast_state: BlockFastState,
+        teach_signal: torch.Tensor,
+        surprise_value: float | None,
+        *,
+        finalize_updates: bool = True,
+    ) -> torch.Tensor:
+        seq_len = x.shape[1]
+        base_chunk = _min_update_period(self.config.cms_levels)
+        active_mask = teach_signal.detach().abs().sum(dim=-1) > 0
+        outputs: list[torch.Tensor] = []
+        stats: dict[str, Dict[str, float]] = {}
+        buffers = _fast_state_buffers(fast_state, self.config.cms_levels)
+        for spec in self.config.cms_levels:
+            stats[spec.name] = {
+                "grad_norm": 0.0,
+                "chunk_tokens": 0.0,
+                "gate_hit": 0.0,
+                "gate_hits": 0.0,
+                "updates_applied": 0.0,
+                "tokens_flushed": 0.0,
+                "pending_tokens": 0.0,
+            }
+
+        for start in range(0, seq_len, base_chunk):
+            end = min(start + base_chunk, seq_len)
+            chunk_in = x[:, start:end, :]
+            chunk_teach = teach_signal[:, start:end, :]
+            chunk_active = active_mask[:, start:end]
+
+            current = chunk_in
+            level_inputs: dict[str, torch.Tensor] = {}
+            for spec in self.config.cms_levels:
+                level_name = spec.name
+                level_inputs[level_name] = current
+                params = fast_state.cms_params[level_name]
+                current = call_with_deltas(self.cms.blocks[level_name], params, current)
+            outputs.append(current)
+
+            for spec in self.config.cms_levels:
+                level_name = spec.name
+                buffer = buffers[level_name]
+                buffer.inputs.append(level_inputs[level_name].detach())
+                buffer.teach.append(chunk_teach)
+                buffer.active.append(chunk_active)
+                buffer.count += end - start
+                update_period = int(spec.update_period)
+                while update_period > 0 and buffer.count >= update_period:
+                    # Pop into fresh names so chunk_teach/chunk_active stay
+                    # intact for the levels that follow in this chunk.
+                    popped_inputs, popped_teach, popped_active = _pop_buffer_chunk(
+                        buffer, update_period
+                    )
+                    buffer.count -= update_period
+                    magnitude = self._update_cms_chunk_fast(
+                        fast_state,
+                        level_name,
+                        popped_inputs,
+                        popped_teach,
+                        popped_active,
+                        surprise_value,
+                    )
+                    if magnitude > 0:
+                        stats[level_name]["grad_norm"] += magnitude
+                        stats[level_name]["chunk_tokens"] += float(update_period)
+                        stats[level_name]["gate_hit"] += 1.0
+                        stats[level_name]["gate_hits"] += 1.0
+                        stats[level_name]["updates_applied"] += 1.0
+        if finalize_updates:
+            if self.config.cms_flush_partial_at_end:
+                for spec in self.config.cms_levels:
+                    level_name = spec.name
+                    buffer = buffers[level_name]
+                    remaining = int(buffer.count)
+                    if remaining <= 0:
+                        continue
+                    popped_inputs, popped_teach, popped_active = _pop_buffer_chunk(buffer, remaining)
+                    buffer.count -= remaining
+                    if not bool(popped_active.any()):
+                        continue
+                    magnitude = self._update_cms_chunk_fast(
+                        fast_state,
+                        level_name,
+                        popped_inputs,
+                        popped_teach,
+                        popped_active,
+                        surprise_value,
+                    )
+                    if magnitude > 0:
+                        stats[level_name]["grad_norm"] += magnitude
+                        stats[level_name]["chunk_tokens"] += float(remaining)
+                        stats[level_name]["gate_hit"] += 1.0
+                        stats[level_name]["gate_hits"] += 1.0
+                        stats[level_name]["updates_applied"] += 1.0
+                        stats[level_name]["tokens_flushed"] += float(remaining)
+            for spec in self.config.cms_levels:
+                _clear_buffer(buffers[spec.name])
+        for spec in self.config.cms_levels:
+            stats[spec.name]["pending_tokens"] = float(buffers[spec.name].count)
+        for level_name, payload in stats.items():
+            if (
+                payload["updates_applied"] <= 0
+                and payload["pending_tokens"] <= 0
+                and payload["tokens_flushed"] <= 0
+            ):
+                continue
+            if surprise_value is not None:
+                payload["surprise_value"] = surprise_value
+            self.last_update_stats[f"cms.{level_name}"] = payload
+        return torch.cat(outputs, dim=1)
+
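+    # One gated update of the titan memory: the self-modifier produces a
+    # target from (attn_out, mem_out, teach_signal), an MSE loss is taken
+    # against the memory's prediction over the active (and, for the "l2"
+    # metric, sufficiently surprising) tokens, and the resulting gradients
+    # are applied through the level manager.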
+    def _update_titan(
+        self,
+        attn_out: torch.Tensor,
+        mem_out: torch.Tensor,
+        teach_signal: torch.Tensor,
+        surprise_value: float | None,
+    ) -> None:
+        level_name = self.config.titan_level.name
+        if not self._is_level_allowed("titan"):
+            return
+        if not self.level_manager.should_update(level_name):
+            return
+        if not self._passes_surprise(surprise_value):
+            self._record_gate(level_name, hit=False)
+            return
+        # Use the full sequence for granular updates (Critique P1).
+        # Note: we intentionally do not pool over dim=1 (sequence) here;
+        # teach_signal and attn_out are both (B, T, D).
+        modifier = self.self_modifier(
+            key=attn_out.detach(),
+            value=mem_out.detach(),
+            error_signal=teach_signal.detach(),
+        )
+        context_vec = attn_out.detach().mean(dim=(0, 1))
+
+        with torch.enable_grad():
+            query = attn_out.detach()
+            target = (modifier - teach_signal.detach()).detach()
+            base_params = {name: param for name, param in self.titan_memory.named_parameters()}
+            params_req = require_grad_params(base_params)
+            prediction = call_with_params(self.titan_memory, params_req, query)
+            loss_terms = F.mse_loss(prediction, target, reduction="none")
+            active = teach_signal.detach().abs().sum(dim=-1, keepdim=True) > 0
+            mask = active.float()
+            if self.surprise_threshold is not None and self.surprise_metric == "l2":
+                norms = teach_signal.norm(dim=-1, keepdim=True)
+                mask = mask * (norms >= self.surprise_threshold).float()
+            loss = (loss_terms * mask).sum() / mask.sum().clamp(min=1.0)
+
+            grads = torch.autograd.grad(
+                loss,
+                tuple(params_req.values()),
+                retain_graph=False,
+                allow_unused=True,
+            )
+        grads_dict = grads_to_dict(params_req, grads)
+        magnitude = self.level_manager.apply_module_grads(
+            level_name,
+            self.titan_memory,
+            grads_dict,
+            context=context_vec,
+            force=True,
+        )
+        extra_metrics = self.level_manager.pop_last_metrics(level_name)
+        stats = {"grad_norm": magnitude, "gate_hit": 1.0}
+        if surprise_value is not None:
+            stats["surprise_value"] = surprise_value
+        stats.update(extra_metrics)
+        self.last_update_stats[f"titan.{level_name}"] = stats
+
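+    # Fast-state counterpart of `_update_titan`: gradients are taken with
+    # respect to the delta-adjusted parameters in `fast_state.titan_params`
+    # and the updated deltas are written back, leaving the module's own
+    # parameters untouched.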
+    def _update_titan_fast(
+        self,
+        fast_state: BlockFastState,
+        attn_out: torch.Tensor,
+        mem_out: torch.Tensor,
+        teach_signal: torch.Tensor,
+        surprise_value: float | None,
+    ) -> None:
+        level_name = self.config.titan_level.name
+        if not self._is_level_allowed("titan"):
+            return
+        if not fast_state.level_manager.should_update(level_name):
+            return
+        if not self._passes_surprise(surprise_value):
+            self._record_gate(level_name, hit=False)
+            return
+        if fast_state.titan_params is None:
+            return
+        modifier = self.self_modifier(
+            key=attn_out.detach(),
+            value=mem_out.detach(),
+            error_signal=teach_signal.detach(),
+        )
+        context_vec = attn_out.detach().mean(dim=(0, 1))
+        base_params = fast_state.titan_params
+        forward_params = params_with_deltas(self.titan_memory, base_params)
+        params_req = require_grad_params(forward_params)
+        with torch.enable_grad():
+            query = attn_out.detach()
+            target = (modifier - teach_signal.detach()).detach()
+            prediction = call_with_params(self.titan_memory, params_req, query)
+            loss_terms = F.mse_loss(prediction, target, reduction="none")
+            active = teach_signal.detach().abs().sum(dim=-1, keepdim=True) > 0
+            mask = active.float()
+            if self.surprise_threshold is not None and self.surprise_metric == "l2":
+                norms = teach_signal.norm(dim=-1, keepdim=True)
+                mask = mask * (norms >= self.surprise_threshold).float()
+            loss = (loss_terms * mask).sum() / mask.sum().clamp(min=1.0)
+            grads = torch.autograd.grad(
+                loss,
+                tuple(params_req.values()),
+                retain_graph=False,
+                allow_unused=True,
+            )
+        grads_dict = grads_to_dict(params_req, grads)
+        updated, magnitude = fast_state.level_manager.apply_grads(
+            level_name,
+            base_params,
+            grads_dict,
+            context=context_vec,
+            force=False,
+        )
+        fast_state.titan_params = updated
+        extra_metrics = fast_state.level_manager.pop_last_metrics(level_name)
+        stats = {"grad_norm": magnitude, "gate_hit": 1.0}
+        if surprise_value is not None:
+            stats["surprise_value"] = surprise_value
+        stats.update(extra_metrics)
+        self.last_update_stats[f"titan.{level_name}"] = stats
+
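+    # Post-hoc CMS update over a full sequence: each allowed level re-chunks
+    # its recorded inputs by its update_period and applies one update per
+    # chunk that contains at least one active token.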
+    def _update_cms(
+        self,
+        cms_inputs: dict[str, torch.Tensor],
+        cms_outputs: dict[str, torch.Tensor],
+        teach_signal: torch.Tensor,
+        surprise_value: float | None,
+    ) -> None:
+        teach = teach_signal.detach()
+        active_mask = teach.abs().sum(dim=-1) > 0
+        for spec in self.config.cms_levels:
+            level_name = spec.name
+            if not self._is_level_allowed(level_name):
+                continue
+            if not self._passes_surprise(surprise_value):
+                self._record_gate(level_name, hit=False)
+                continue
+            inputs = cms_inputs[level_name]
+            seq_len = inputs.shape[1]
+            chunk_size = int(spec.update_period)
+            if chunk_size <= 0:
+                continue
+            total_norm = 0.0
+            update_events = 0
+            token_events = 0
+            for start in range(0, seq_len, chunk_size):
+                end = min(start + chunk_size, seq_len)
+                chunk_len = end - start
+                chunk_inputs = inputs[:, start:end, :].detach()
+                chunk_teach = teach[:, start:end, :]
+                chunk_active = active_mask[:, start:end]
+                if not bool(chunk_active.any()):
+                    continue
+                magnitude = self._update_cms_chunk(
+                    level_name,
+                    chunk_inputs,
+                    chunk_teach,
+                    chunk_active,
+                    surprise_value,
+                )
+                if magnitude <= 0:
+                    continue
+                total_norm += magnitude
+                token_events += chunk_len
+                update_events += 1
+            if update_events == 0:
+                continue
+            stats_payload: Dict[str, float] = {
+                "grad_norm": total_norm,
+                "chunk_tokens": float(token_events),
+                "gate_hit": float(update_events),
+            }
+            if surprise_value is not None:
+                stats_payload["surprise_value"] = surprise_value
+            self.last_update_stats[f"cms.{level_name}"] = stats_payload
+
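+    # Same chunked CMS update, but routed through `_update_cms_chunk_fast`
+    # so the parameter deltas live in `fast_state` rather than the module.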
+    def _update_cms_fast(
+        self,
+        fast_state: BlockFastState,
+        cms_inputs: dict[str, torch.Tensor],
+        teach_signal: torch.Tensor,
+        surprise_value: float | None,
+    ) -> None:
+        teach = teach_signal.detach()
+        active_mask = teach.abs().sum(dim=-1) > 0
+        for spec in self.config.cms_levels:
+            level_name = spec.name
+            if not self._is_level_allowed(level_name):
+                continue
+            if not self._passes_surprise(surprise_value):
+                self._record_gate(level_name, hit=False)
+                continue
+            inputs = cms_inputs[level_name]
+            seq_len = inputs.shape[1]
+            chunk_size = int(spec.update_period)
+            if chunk_size <= 0:
+                continue
+            total_norm = 0.0
+            update_events = 0
+            token_events = 0
+            for start in range(0, seq_len, chunk_size):
+                end = min(start + chunk_size, seq_len)
+                chunk_len = end - start
+                chunk_inputs = inputs[:, start:end, :].detach()
+                chunk_teach = teach[:, start:end, :]
+                chunk_active = active_mask[:, start:end]
+                if not bool(chunk_active.any()):
+                    continue
+                magnitude = self._update_cms_chunk_fast(
+                    fast_state,
+                    level_name,
+                    chunk_inputs,
+                    chunk_teach,
+                    chunk_active,
+                    surprise_value,
+                )
+                if magnitude <= 0:
+                    continue
+                total_norm += magnitude
+                token_events += chunk_len
+                update_events += 1
+            if update_events == 0:
+                continue
+            stats_payload: Dict[str, float] = {
+                "grad_norm": total_norm,
+                "chunk_tokens": float(token_events),
+                "gate_hit": float(update_events),
+            }
+            if surprise_value is not None:
+                stats_payload["surprise_value"] = surprise_value
+            self.last_update_stats[f"cms.{level_name}"] = stats_payload
+
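+    # Single-chunk CMS update: gate on allowed levels and surprise, build a
+    # masked chunk loss between the level's prediction and the teach signal,
+    # and let the level manager optimize the block in place. Returns the
+    # applied gradient magnitude (0.0 when gated out).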
+    def _update_cms_chunk(
+        self,
+        level_name: str,
+        chunk_inputs: torch.Tensor,
+        chunk_teach: torch.Tensor,
+        chunk_active: torch.Tensor,
+        surprise_value: float | None,
+    ) -> float:
+        if not self._is_level_allowed(level_name):
+            return 0.0
+        if not self._passes_surprise(surprise_value):
+            self._record_gate(level_name, hit=False)
+            return 0.0
+        mask_f = chunk_active.unsqueeze(-1).float()
+        with torch.enable_grad():
+            prediction = self.cms.blocks[level_name](chunk_inputs)
+            loss = _chunk_loss(
+                prediction,
+                chunk_teach,
+                mask_f,
+                reduction=self.config.cms_chunk_reduction,
+            )
+            context_vec = chunk_inputs.mean(dim=(0, 1))
+            magnitude = self.level_manager.optimize(
+                level_name,
+                self.cms.blocks[level_name],
+                loss,
+                context=context_vec,
+                force=True,
+            )
+        self.level_manager.pop_last_metrics(level_name)
+        return magnitude
+
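+    # Functional single-chunk CMS update: gradients are computed against the
+    # delta-adjusted parameters and the new deltas are stored back into
+    # `fast_state.cms_params[level_name]`.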
+    def _update_cms_chunk_fast(
+        self,
+        fast_state: BlockFastState,
+        level_name: str,
+        chunk_inputs: torch.Tensor,
+        chunk_teach: torch.Tensor,
+        chunk_active: torch.Tensor,
+        surprise_value: float | None,
+    ) -> float:
+        if not self._is_level_allowed(level_name):
+            return 0.0
+        if not self._passes_surprise(surprise_value):
+            self._record_gate(level_name, hit=False)
+            return 0.0
+        mask_f = chunk_active.unsqueeze(-1).float()
+        base_params = fast_state.cms_params[level_name]
+        forward_params = params_with_deltas(self.cms.blocks[level_name], base_params)
+        params_req = require_grad_params(forward_params)
+        with torch.enable_grad():
+            prediction = call_with_params(self.cms.blocks[level_name], params_req, chunk_inputs)
+            loss = _chunk_loss(
+                prediction,
+                chunk_teach,
+                mask_f,
+                reduction=self.config.cms_chunk_reduction,
+            )
+            grads = torch.autograd.grad(
+                loss,
+                tuple(params_req.values()),
+                retain_graph=False,
+                allow_unused=True,
+            )
+        grads_dict = grads_to_dict(params_req, grads)
+        context_vec = chunk_inputs.mean(dim=(0, 1))
+        updated, magnitude = fast_state.level_manager.apply_grads(
+            level_name,
+            base_params,
+            grads_dict,
+            context=context_vec,
+            force=True,
+        )
+        fast_state.cms_params[level_name] = updated
+        fast_state.level_manager.pop_last_metrics(level_name)
+        return magnitude
+
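+    # Drain-and-reset accessor for the per-level update statistics gathered
+    # since the previous call.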
+    def pop_update_stats(self) -> Dict[str, Dict[str, float]]:
+        stats = self.last_update_stats
+        self.last_update_stats = {}
+        return stats
+
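+    # Surprise gate: with no threshold configured every update passes; with
+    # a threshold, updates lacking a surprise value are rejected outright.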
+    def _passes_surprise(self, surprise_value: float | None) -> bool:
+        if self.surprise_threshold is None:
+            return True
+        if surprise_value is None:
+            return False
+        return surprise_value >= self.surprise_threshold
+
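+    # Level allow-list check; any "titan*" level is admitted by the bare
+    # "titan" entry.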
+    def _is_level_allowed(self, level_name: str) -> bool:
+        if self.allowed_levels is None:
+            return True
+        return level_name in self.allowed_levels or (
+            level_name.startswith("titan") and "titan" in self.allowed_levels
+        )
+
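+    # Record a gate hit/miss for a level in the pending update stats.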
+    def _record_gate(self, level_name: str, *, hit: bool) -> None:
+        stats_key = f"gate.{level_name}"
+        self.last_update_stats.setdefault(stats_key, {})
+        self.last_update_stats[stats_key]["gate_hit"] = 1.0 if hit else 0.0