nested-learning 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. nested_learning/__init__.py +12 -0
  2. nested_learning/__main__.py +12 -0
  3. nested_learning/assoc_memory.py +23 -0
  4. nested_learning/backbones.py +147 -0
  5. nested_learning/capabilities.py +104 -0
  6. nested_learning/cli.py +253 -0
  7. nested_learning/cms.py +92 -0
  8. nested_learning/config_utils.py +50 -0
  9. nested_learning/configs/ablations/cms_sparse.yaml +46 -0
  10. nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
  11. nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
  12. nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
  13. nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
  14. nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
  15. nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
  16. nested_learning/configs/data/continual_segments_sample.yaml +9 -0
  17. nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
  18. nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
  19. nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
  20. nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
  21. nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
  22. nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
  23. nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
  24. nested_learning/configs/deepspeed/zero3.json +25 -0
  25. nested_learning/configs/hope/mid.yaml +118 -0
  26. nested_learning/configs/hope/mid_fsdp.yaml +47 -0
  27. nested_learning/configs/hope/pilot.yaml +2 -0
  28. nested_learning/configs/hope/pilot_attention.yaml +9 -0
  29. nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
  30. nested_learning/configs/hope/pilot_transformer.yaml +9 -0
  31. nested_learning/configs/hope/target.yaml +145 -0
  32. nested_learning/configs/hope/target_fsdp.yaml +47 -0
  33. nested_learning/configs/mid_smoke.yaml +99 -0
  34. nested_learning/configs/mid_stage2.yaml +110 -0
  35. nested_learning/configs/mid_stage2_smoke.yaml +102 -0
  36. nested_learning/configs/mid_titan_baseline.yaml +92 -0
  37. nested_learning/configs/pilot.yaml +127 -0
  38. nested_learning/configs/pilot_paper_faithful.yaml +42 -0
  39. nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
  40. nested_learning/configs/pilot_smoke.yaml +80 -0
  41. nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
  42. nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
  43. nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
  44. nested_learning/continual_classification.py +136 -0
  45. nested_learning/continual_streaming.py +283 -0
  46. nested_learning/data.py +153 -0
  47. nested_learning/device.py +21 -0
  48. nested_learning/eval_state.py +72 -0
  49. nested_learning/fast_state.py +108 -0
  50. nested_learning/functional.py +69 -0
  51. nested_learning/hope/__init__.py +0 -0
  52. nested_learning/hope/block.py +1973 -0
  53. nested_learning/hope/self_mod.py +40 -0
  54. nested_learning/instrumentation.py +38 -0
  55. nested_learning/levels.py +94 -0
  56. nested_learning/logging_utils.py +64 -0
  57. nested_learning/memorize.py +382 -0
  58. nested_learning/model.py +604 -0
  59. nested_learning/optim/__init__.py +0 -0
  60. nested_learning/optim/deep.py +102 -0
  61. nested_learning/optim/factory.py +13 -0
  62. nested_learning/optim/m3.py +121 -0
  63. nested_learning/optim/manager.py +151 -0
  64. nested_learning/titan/__init__.py +0 -0
  65. nested_learning/titan/memory.py +88 -0
  66. nested_learning/titan/model.py +412 -0
  67. nested_learning/titan/self_modifying.py +724 -0
  68. nested_learning/tokenizer.py +28 -0
  69. nested_learning/tokenizer_coverage.py +77 -0
  70. nested_learning/training.py +1600 -0
  71. nested_learning/transformer.py +104 -0
  72. nested_learning-0.2.0.dist-info/METADATA +390 -0
  73. nested_learning-0.2.0.dist-info/RECORD +76 -0
  74. nested_learning-0.2.0.dist-info/WHEEL +4 -0
  75. nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
  76. nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
"""Nested Learning (HOPE) reproduction package."""

from importlib.metadata import PackageNotFoundError, version

# Re-exported so callers can write `from nested_learning import LevelSpec`.
from .levels import LevelClock, LevelSpec  # noqa: F401

try:
    # Prefer the installed distribution's recorded version.
    __version__ = version("nested-learning")
except PackageNotFoundError:  # pragma: no cover - editable/local source tree
    # Fallback when running from a source checkout with no install metadata.
    __version__ = "0.2.0"

__all__ = ["LevelClock", "LevelSpec", "__version__"]
from __future__ import annotations

from .cli import app


def main() -> None:
    """Console entry point: delegate to the Typer CLI application."""
    app()


if __name__ == "__main__":
    main()
from __future__ import annotations

from typing import Protocol

import torch
import torch.nn as nn


class AssocMemory(nn.Module):
    """Base class for associative memories with explicit update hooks."""

    def forward(self, query: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Read from the memory for `query`; subclasses must implement."""
        raise NotImplementedError

    @torch.no_grad()
    def update(
        self, *, key: torch.Tensor, value: torch.Tensor, error_signal: torch.Tensor | None = None
    ) -> None:
        """Write a key/value pair (gradient-free); subclasses must implement."""
        raise NotImplementedError


class SupportsReset(Protocol):
    """Structural type for objects that expose a `reset_state` hook."""

    def reset_state(self) -> None: ...
@@ -0,0 +1,147 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from .fast_state import AttentionKVCache
10
+
11
+
12
+ @dataclass
13
+ class AttentionConfig:
14
+ dim: int
15
+ heads: int
16
+ dropout: float = 0.0
17
+ use_flash: bool = True
18
+ causal: bool = True
19
+ qk_l2_norm: bool = False
20
+ qk_norm_eps: float = 1e-6
21
+ local_conv_window: int | None = None
22
+
23
+
class SelfAttention(nn.Module):
    """Multi-head self-attention with a post-norm residual connection.

    Supports an optional causal depthwise convolution on the attention input
    and incremental decoding through an externally held key/value cache.
    """

    def __init__(self, config: AttentionConfig):
        super().__init__()
        if config.dim % config.heads != 0:
            msg = f"dim must be divisible by heads (got dim={config.dim}, heads={config.heads})"
            raise ValueError(msg)
        self.config = config
        self.heads = config.heads
        self.head_dim = config.dim // config.heads
        # Fused projection producing queries, keys, and values in one matmul.
        self.qkv = nn.Linear(config.dim, config.dim * 3, bias=False)
        self.out_proj = nn.Linear(config.dim, config.dim, bias=False)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.norm = nn.LayerNorm(config.dim)
        self.local_conv: nn.Conv1d | None = None
        if config.local_conv_window is not None:
            window = int(config.local_conv_window)
            if window <= 0:
                raise ValueError("local_conv_window must be positive")
            # Depthwise conv (groups == channels); left padding is applied in
            # forward() so the kernel never sees future tokens.
            self.local_conv = nn.Conv1d(
                config.dim,
                config.dim,
                kernel_size=window,
                groups=config.dim,
                padding=0,
                bias=False,
            )

    def forward(  # type: ignore[override]
        self,
        x: torch.Tensor,
        *,
        kv_cache: AttentionKVCache | None = None,
        return_kv_cache: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, AttentionKVCache]:
        """Attend over `x`, optionally extending and returning a KV cache.

        Assumes `x` is (batch, seq, dim) — consistent with the reshapes below.
        Returns the normalized residual output, plus the updated cache when
        `return_kv_cache` is True.
        """
        residual = x
        attn_inp = x
        if kv_cache is not None and self.local_conv is not None:
            raise RuntimeError(
                "kv_cache with local_conv_window is not supported in this implementation."
            )
        if self.local_conv is not None:
            kernel = self.local_conv.kernel_size[0]
            attn_inp = attn_inp.transpose(1, 2)
            # Causal depthwise conv: only attends to past tokens.
            attn_inp = F.pad(attn_inp, (kernel - 1, 0))
            attn_inp = self.local_conv(attn_inp).transpose(1, 2)
        q, k, v = self._compute_qkv(attn_inp)
        past_len = 0
        k_all = k
        v_all = v
        if kv_cache is not None:
            # Cache layout is (batch, heads, past_seq, head_dim); validate
            # batch / heads / head_dim before concatenating along seq.
            if kv_cache.key.size(0) != k.size(0):
                raise ValueError("kv_cache batch dimension must match input batch dimension")
            if kv_cache.key.size(1) != k.size(1) or kv_cache.key.size(-1) != k.size(-1):
                raise ValueError("kv_cache shape is incompatible with attention heads/head_dim")
            past_len = int(kv_cache.key.size(2))
            k_all = torch.cat([kv_cache.key, k], dim=2)
            v_all = torch.cat([kv_cache.value, v], dim=2)
        attn_output = self._scaled_dot_product_attn(q, k_all, v_all, past_len=past_len)
        # (batch, heads, seq, head_dim) -> (batch, seq, dim)
        attn_output = attn_output.transpose(1, 2).contiguous().view(x.size(0), x.size(1), -1)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)
        # Post-norm residual: LayerNorm applied after the skip connection.
        out = self.norm(residual + attn_output)
        if return_kv_cache:
            # Detach so cached state never carries gradient history across steps.
            return out, AttentionKVCache(key=k_all.detach(), value=v_all.detach())
        return out

    def _compute_qkv(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project `x` and split into per-head q/k/v of shape (B, H, S, Dh)."""
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        shape = (x.size(0), x.size(1), self.heads, self.head_dim)
        q = q.view(*shape).transpose(1, 2)
        k = k.view(*shape).transpose(1, 2)
        v = v.view(*shape).transpose(1, 2)
        if self.config.qk_l2_norm:
            # Bound logit scale by normalizing queries and keys.
            q = F.normalize(q, dim=-1, eps=self.config.qk_norm_eps)
            k = F.normalize(k, dim=-1, eps=self.config.qk_norm_eps)
        return q, k, v

    def _scaled_dot_product_attn(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        *,
        past_len: int = 0,
    ) -> torch.Tensor:
        """Run SDPA, building an explicit mask when cached keys offset queries."""
        dropout_p = self.config.dropout if self.training else 0.0
        attn_mask = None
        if self.config.causal and past_len > 0:
            # With cached keys, query i sits at absolute position past_len + i,
            # so SDPA's built-in is_causal flag would mask the wrong region.
            query_len = int(q.size(-2))
            key_len = int(k.size(-2))
            key_positions = torch.arange(key_len, device=q.device)
            query_positions = past_len + torch.arange(query_len, device=q.device)
            # Boolean mask: True entries may be attended to.
            attn_mask = key_positions.unsqueeze(0) <= query_positions.unsqueeze(1)
        is_causal = self.config.causal and attn_mask is None
        device_type = q.device.type
        if (
            device_type == "cuda"
            and torch.cuda.is_available()
            and hasattr(torch.backends, "cuda")
            and hasattr(torch.backends.cuda, "sdp_kernel")
        ):
            # Steer CUDA kernel selection (flash vs. math fallback); guarded by
            # hasattr since sdp_kernel is version-dependent in torch.
            with torch.backends.cuda.sdp_kernel(  # type: ignore[attr-defined]
                enable_flash=self.config.use_flash,
                enable_mem_efficient=True,
                enable_math=not self.config.use_flash,
            ):
                return F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=attn_mask,
                    dropout_p=dropout_p,
                    is_causal=is_causal,
                )
        return F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )
@@ -0,0 +1,104 @@
1
+ from __future__ import annotations
2
+
3
+ import platform
4
+ import sys
5
+ from dataclasses import asdict, dataclass, field
6
+ from typing import Any
7
+
8
+ import torch
9
+
10
+
@dataclass
class RuntimeCapabilities:
    """Snapshot of the host's Python / torch runtime features."""

    python_version: str
    platform: str
    machine: str
    torch_version: str
    cuda_available: bool
    cuda_device_count: int
    cuda_devices: list[str] = field(default_factory=list)
    mps_available: bool = False
    mps_built: bool = False
    distributed_available: bool = False
    compile_available: bool = False
    sdpa_flash_available: bool = False
    sdpa_mem_efficient_available: bool = False
    sdpa_math_available: bool = True
    bf16_supported: bool = False
    fp16_supported: bool = False
    default_device: str = "cpu"
    warnings: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain (JSON-friendly) dict."""
        return asdict(self)


def collect_runtime_capabilities() -> RuntimeCapabilities:
    """Probe torch and the host platform without raising on backend quirks.

    Failures while querying a backend are recorded in `warnings` instead of
    propagating, so this is safe to call on any machine.
    """
    issues: list[str] = []

    has_cuda = bool(torch.cuda.is_available())
    device_total = int(torch.cuda.device_count() if has_cuda else 0)
    device_names: list[str] = []
    for idx in range(device_total):
        try:
            device_names.append(f"cuda:{idx} {torch.cuda.get_device_name(idx)}")
        except Exception as err:  # pragma: no cover - backend specific
            issues.append(f"failed to query cuda:{idx}: {err}")

    mps = getattr(torch.backends, "mps", None)
    has_mps = bool(mps and mps.is_available())
    built_mps = bool(mps and mps.is_built())

    # SDPA kernel flags are only meaningful on a CUDA build of torch.
    flash = False
    mem_eff = False
    math_sdp = True
    if hasattr(torch.backends, "cuda") and torch.backends.cuda.is_built():
        try:
            flash = bool(torch.backends.cuda.flash_sdp_enabled())
            mem_eff = bool(torch.backends.cuda.mem_efficient_sdp_enabled())
            math_sdp = bool(torch.backends.cuda.math_sdp_enabled())
        except Exception as err:  # pragma: no cover - backend specific
            issues.append(f"failed to query SDPA backend flags: {err}")

    # Reduced-precision support: CUDA may offer bf16; CUDA and MPS offer fp16.
    bf16 = False
    fp16 = False
    if has_cuda:
        try:
            bf16 = bool(torch.cuda.is_bf16_supported())
            fp16 = True
        except Exception as err:  # pragma: no cover
            issues.append(f"failed to query CUDA dtype support: {err}")
    elif has_mps:
        fp16 = True

    # Preference order for the default device: CUDA, then MPS, then CPU.
    if has_cuda:
        best_device = "cuda:0"
    elif has_mps:
        best_device = "mps"
    else:
        best_device = "cpu"

    return RuntimeCapabilities(
        python_version=sys.version.split()[0],
        platform=platform.platform(),
        machine=platform.machine(),
        torch_version=torch.__version__,
        cuda_available=has_cuda,
        cuda_device_count=device_total,
        cuda_devices=device_names,
        mps_available=has_mps,
        mps_built=built_mps,
        distributed_available=bool(torch.distributed.is_available()),
        compile_available=bool(hasattr(torch, "compile")),
        sdpa_flash_available=flash,
        sdpa_mem_efficient_available=mem_eff,
        sdpa_math_available=math_sdp,
        bf16_supported=bf16,
        fp16_supported=fp16,
        default_device=best_device,
        warnings=issues,
    )
nested_learning/cli.py ADDED
@@ -0,0 +1,253 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Annotated
6
+
7
+ import torch
8
+ import typer
9
+ from omegaconf import OmegaConf
10
+
11
+ from .capabilities import collect_runtime_capabilities
12
+ from .config_utils import compose_config
13
+ from .device import resolve_device
14
+ from .training import build_model_from_cfg
15
+
# Shared Typer application; each sub-command below registers itself on it.
app = typer.Typer(
    add_completion=False,
    no_args_is_help=True,
    help="Nested Learning CLI (training, diagnostics, and smoke checks).",
)
def _resolve_cli_device(device: str) -> torch.device:
    """Map a CLI device string to a concrete `torch.device`.

    The sentinel "auto" selects the best device reported by the runtime
    capability probe; any other string is passed through to `resolve_device`.
    """
    normalized = device.strip().lower()
    if normalized != "auto":
        return resolve_device(device)
    capabilities = collect_runtime_capabilities()
    return resolve_device(capabilities.default_device)
@app.command("doctor")
def doctor(
    as_json: Annotated[
        bool,
        typer.Option("--json", help="Emit machine-readable JSON only."),
    ] = False,
    output: Annotated[
        Path | None,
        typer.Option(
            "--output",
            "-o",
            help="Optional path for writing doctor output JSON.",
            dir_okay=False,
            writable=True,
        ),
    ] = None,
) -> None:
    """Inspect runtime capabilities for backend/device compatibility."""
    payload = collect_runtime_capabilities().to_dict()
    rendered = json.dumps(payload, indent=2, sort_keys=True)
    # The JSON report is written to disk regardless of console output mode.
    if output is not None:
        output.parent.mkdir(parents=True, exist_ok=True)
        output.write_text(rendered + "\n", encoding="utf-8")
    if as_json:
        typer.echo(rendered)
        return

    # Human-readable summary of the same payload.
    typer.echo("Runtime Doctor")
    typer.echo(f"python: {payload['python_version']}")
    typer.echo(f"platform: {payload['platform']} ({payload['machine']})")
    typer.echo(f"torch: {payload['torch_version']}")
    typer.echo(f"default_device: {payload['default_device']}")
    typer.echo(
        "cuda_available: {available} ({count} device(s))".format(
            available=payload["cuda_available"],
            count=payload["cuda_device_count"],
        )
    )
    for name in payload["cuda_devices"]:
        typer.echo(f"  - {name}")
    typer.echo(f"mps_available: {payload['mps_available']} (built={payload['mps_built']})")
    typer.echo(f"distributed_available: {payload['distributed_available']}")
    typer.echo(f"compile_available: {payload['compile_available']}")
    typer.echo(
        "sdpa backends: flash={flash} mem_efficient={mem} math={math}".format(
            flash=payload["sdpa_flash_available"],
            mem=payload["sdpa_mem_efficient_available"],
            math=payload["sdpa_math_available"],
        )
    )
    typer.echo(f"dtype support: bf16={payload['bf16_supported']} fp16={payload['fp16_supported']}")
    if payload["warnings"]:
        typer.echo("warnings:")
        for warning in payload["warnings"]:
            typer.echo(f"  - {warning}")
@app.command("smoke")
def smoke(
    config_name: Annotated[
        str,
        typer.Option("--config-name", "-c", help="Hydra config name (e.g. pilot, hope/mid)."),
    ] = "pilot_smoke",
    override: Annotated[
        list[str] | None,
        typer.Option(
            "--override",
            "-O",
            help="Hydra override(s), may be passed multiple times.",
        ),
    ] = None,
    config_dir: Annotated[
        Path | None,
        typer.Option(
            "--config-dir",
            help="Optional explicit config directory.",
            exists=True,
            file_okay=False,
            dir_okay=True,
            readable=True,
        ),
    ] = None,
    device: Annotated[
        str,
        typer.Option(
            "--device",
            help="Device string for smoke pass (cpu, cuda:0, mps, auto).",
        ),
    ] = "cpu",
    batch_size: Annotated[
        int,
        typer.Option("--batch-size", help="Synthetic smoke batch size."),
    ] = 1,
    seq_len: Annotated[
        int,
        typer.Option("--seq-len", help="Synthetic smoke sequence length."),
    ] = 32,
) -> None:
    """Run a lightweight forward-pass smoke test with composed config."""
    cfg = compose_config(config_name, overrides=override or [], config_dir=config_dir)
    model_cfg = cfg.model
    torch_device = _resolve_cli_device(device)
    model = build_model_from_cfg(model_cfg).to(torch_device)
    model.eval()

    # Random token ids suffice to exercise the forward pass end to end.
    vocab_size = int(model_cfg.vocab_size)
    tokens = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch_device)
    with torch.no_grad():
        outputs = model(tokens)
    # Models may return (logits, aux...) tuples; keep only the logits.
    if isinstance(outputs, tuple):
        logits = outputs[0]
    else:
        logits = outputs
    typer.echo(
        json.dumps(
            {
                "status": "ok",
                "config_name": config_name,
                "device": str(torch_device),
                "batch_size": batch_size,
                "seq_len": seq_len,
                "logits_shape": list(logits.shape),
                "dtype": str(logits.dtype),
            },
            sort_keys=True,
        )
    )
@app.command("train")
def train(
    config_name: Annotated[
        str,
        typer.Option("--config-name", "-c", help="Hydra config name for training."),
    ] = "pilot",
    override: Annotated[
        list[str] | None,
        typer.Option("--override", "-O", help="Hydra override(s), may be passed multiple times."),
    ] = None,
    config_dir: Annotated[
        Path | None,
        typer.Option(
            "--config-dir",
            help="Optional explicit config directory.",
            exists=True,
            file_okay=False,
            dir_okay=True,
            readable=True,
        ),
    ] = None,
    device: Annotated[
        str | None,
        typer.Option(
            "--device",
            help="Override cfg.train.device (e.g. cpu, cuda:1, auto).",
        ),
    ] = None,
    dry_run: Annotated[
        bool,
        typer.Option("--dry-run", help="Print resolved config and exit."),
    ] = False,
) -> None:
    """Launch a local (single-process) training loop."""
    # Imported lazily so other commands (and --help) stay fast to start.
    from .training import run_training_loop

    cfg = compose_config(config_name, overrides=override or [], config_dir=config_dir)
    if device is not None:
        # An explicit CLI flag wins over whatever the config specifies.
        cfg.train.device = device
    if dry_run:
        typer.echo(OmegaConf.to_yaml(cfg))
        return
    train_device = _resolve_cli_device(str(cfg.train.device))
    run_training_loop(cfg, device=train_device, distributed=False, dist_ctx=None)
@app.command("audit")
def audit(
    config_name: Annotated[
        str,
        typer.Option("--config-name", "-c", help="Hydra config name to audit."),
    ] = "pilot_paper_faithful",
    override: Annotated[
        list[str] | None,
        typer.Option("--override", "-O", help="Hydra override(s), may be passed multiple times."),
    ] = None,
    config_dir: Annotated[
        Path | None,
        typer.Option(
            "--config-dir",
            help="Optional explicit config directory.",
            exists=True,
            file_okay=False,
            dir_okay=True,
            readable=True,
        ),
    ] = None,
) -> None:
    """Run static architecture checks on a composed config."""
    cfg = compose_config(config_name, overrides=override or [], config_dir=config_dir)
    model = build_model_from_cfg(cfg.model)
    has_embed = hasattr(model, "embed")
    has_lm_head = hasattr(model, "lm_head")
    tied_weights = False
    if has_embed and has_lm_head:
        embed = getattr(model, "embed")
        lm_head = getattr(model, "lm_head")
        # A shared storage pointer means embedding and LM-head weights are tied.
        tied_weights = bool(embed.weight.data_ptr() == lm_head.weight.data_ptr())

    # JSON summary of architecture-relevant config knobs plus the tying probe.
    report = {
        "status": "ok",
        "config_name": config_name,
        "model_type": str(cfg.model.get("type", "hope")),
        "block_variant": str(cfg.model.get("block_variant", "hope_hybrid")),
        "surprise_metric": str(cfg.model.get("surprise_metric", "l2")),
        "surprise_threshold": cfg.model.get("surprise_threshold"),
        "teach_scale": float(cfg.model.get("teach_scale", 1.0)),
        "teach_clip": float(cfg.model.get("teach_clip", 0.0)),
        "freeze_backbone": bool(cfg.model.get("freeze_backbone", False)),
        "has_embed": has_embed,
        "has_lm_head": has_lm_head,
        "lm_tied_to_embedding": tied_weights,
    }
    typer.echo(json.dumps(report, sort_keys=True))
nested_learning/cms.py ADDED
@@ -0,0 +1,92 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict, Sequence
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .levels import LevelSpec, ensure_level_specs
9
+
10
+
class CMSBlock(nn.Module):
    """Residual MLP block used as one level of the Continuum Memory System.

    Computes ``y = x + clip(MLP(norm(x)))`` where the clip rescales updates
    whose per-token norm exceeds ``grad_clip`` (training mode only).
    """

    def __init__(
        self,
        dim: int,
        hidden_multiplier: int = 4,
        activation: str = "gelu",
        grad_clip: float = 1.0,
        use_layernorm: bool = True,
    ):
        super().__init__()
        # Unknown activation names fall back to GELU, mirroring the default.
        activation_table = {"relu": nn.ReLU, "silu": nn.SiLU}
        nonlinearity: nn.Module = activation_table.get(activation, nn.GELU)()
        pre_norm: nn.Module = nn.LayerNorm(dim) if use_layernorm else nn.Identity()
        width = dim * hidden_multiplier
        self.net = nn.Sequential(
            pre_norm,
            nn.Linear(dim, width),
            nonlinearity,
            nn.Linear(width, dim),
        )
        self.grad_clip = grad_clip

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        update = self.net(x)
        if self.training and self.grad_clip > 0:
            # Compute the rescale factor as a constant (no gradient through it)
            # so oversized updates are shrunk back to the clip threshold.
            with torch.no_grad():
                magnitude = update.norm(dim=-1, keepdim=True)
                shrink = torch.clamp(magnitude / self.grad_clip, min=1.0)
            update = update / shrink
        return x + update
class CMS(nn.Module):
    """Continuum Memory System with multi-frequency updates."""

    def __init__(
        self,
        *,
        dim: int,
        levels: Sequence[LevelSpec],
        hidden_multiplier: int = 4,
        activation: str = "gelu",
        use_layernorm: bool = True,
    ) -> None:
        super().__init__()
        # Normalize/validate and freeze the level ordering.
        specs = tuple(ensure_level_specs(levels))
        self.level_specs: Sequence[LevelSpec] = specs
        level_blocks: Dict[str, nn.Module] = {}
        for spec in specs:
            level_blocks[spec.name] = CMSBlock(
                dim,
                hidden_multiplier=hidden_multiplier,
                activation=activation,
                grad_clip=1.0,
                use_layernorm=use_layernorm,
            )
        self.blocks = nn.ModuleDict(level_blocks)

    def forward(
        self,
        x: torch.Tensor,
        *,
        return_intermediates: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Apply the per-level blocks in order, optionally keeping per-level tensors."""
        hidden = x
        level_inputs: Dict[str, torch.Tensor] = {}
        level_outputs: Dict[str, torch.Tensor] = {}
        for spec in self.level_specs:
            level_inputs[spec.name] = hidden
            hidden = self.blocks[spec.name](hidden)
            level_outputs[spec.name] = hidden
        if return_intermediates:
            return hidden, level_inputs, level_outputs
        return hidden
@@ -0,0 +1,50 @@
1
+ from __future__ import annotations
2
+
3
+ from contextlib import contextmanager
4
+ from importlib.resources import as_file, files
5
+ from pathlib import Path
6
+ from typing import Iterator
7
+
8
+ from hydra import compose, initialize_config_dir
9
+ from hydra.core.global_hydra import GlobalHydra
10
+ from omegaconf import DictConfig
11
+
12
+ from .training import unwrap_config
13
+
14
+
15
+ def find_repo_root(start: Path | None = None) -> Path | None:
16
+ cursor = (start or Path.cwd()).resolve()
17
+ for candidate in (cursor, *cursor.parents):
18
+ if (candidate / ".git").exists() and (candidate / "configs").exists():
19
+ return candidate
20
+ return None
21
+
22
+
23
+ @contextmanager
24
+ def resolved_config_dir(config_dir: Path | None = None) -> Iterator[Path]:
25
+ if config_dir is not None:
26
+ yield config_dir.resolve()
27
+ return
28
+
29
+ module_path = Path(__file__).resolve()
30
+ repo_config_dir = module_path.parents[2] / "configs"
31
+ if repo_config_dir.exists():
32
+ yield repo_config_dir
33
+ return
34
+
35
+ package_configs = files("nested_learning").joinpath("configs")
36
+ with as_file(package_configs) as pkg_dir:
37
+ yield Path(pkg_dir)
38
+
39
+
def compose_config(
    config_name: str,
    *,
    overrides: list[str] | None = None,
    config_dir: Path | None = None,
) -> DictConfig:
    """Compose a Hydra config by name and return it unwrapped.

    Previously initialized global Hydra state is cleared first, so this is
    safe to call repeatedly within a single process.
    """
    with resolved_config_dir(config_dir) as directory:
        GlobalHydra.instance().clear()
        with initialize_config_dir(version_base=None, config_dir=str(directory)):
            composed = compose(config_name=config_name, overrides=overrides or [])
            return unwrap_config(composed)