nested-learning 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nested_learning/__init__.py +12 -0
- nested_learning/__main__.py +12 -0
- nested_learning/assoc_memory.py +23 -0
- nested_learning/backbones.py +147 -0
- nested_learning/capabilities.py +104 -0
- nested_learning/cli.py +253 -0
- nested_learning/cms.py +92 -0
- nested_learning/config_utils.py +50 -0
- nested_learning/configs/ablations/cms_sparse.yaml +46 -0
- nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
- nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
- nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
- nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
- nested_learning/configs/data/continual_segments_sample.yaml +9 -0
- nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
- nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
- nested_learning/configs/deepspeed/zero3.json +25 -0
- nested_learning/configs/hope/mid.yaml +118 -0
- nested_learning/configs/hope/mid_fsdp.yaml +47 -0
- nested_learning/configs/hope/pilot.yaml +2 -0
- nested_learning/configs/hope/pilot_attention.yaml +9 -0
- nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
- nested_learning/configs/hope/pilot_transformer.yaml +9 -0
- nested_learning/configs/hope/target.yaml +145 -0
- nested_learning/configs/hope/target_fsdp.yaml +47 -0
- nested_learning/configs/mid_smoke.yaml +99 -0
- nested_learning/configs/mid_stage2.yaml +110 -0
- nested_learning/configs/mid_stage2_smoke.yaml +102 -0
- nested_learning/configs/mid_titan_baseline.yaml +92 -0
- nested_learning/configs/pilot.yaml +127 -0
- nested_learning/configs/pilot_paper_faithful.yaml +42 -0
- nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
- nested_learning/configs/pilot_smoke.yaml +80 -0
- nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
- nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
- nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
- nested_learning/continual_classification.py +136 -0
- nested_learning/continual_streaming.py +283 -0
- nested_learning/data.py +153 -0
- nested_learning/device.py +21 -0
- nested_learning/eval_state.py +72 -0
- nested_learning/fast_state.py +108 -0
- nested_learning/functional.py +69 -0
- nested_learning/hope/__init__.py +0 -0
- nested_learning/hope/block.py +1973 -0
- nested_learning/hope/self_mod.py +40 -0
- nested_learning/instrumentation.py +38 -0
- nested_learning/levels.py +94 -0
- nested_learning/logging_utils.py +64 -0
- nested_learning/memorize.py +382 -0
- nested_learning/model.py +604 -0
- nested_learning/optim/__init__.py +0 -0
- nested_learning/optim/deep.py +102 -0
- nested_learning/optim/factory.py +13 -0
- nested_learning/optim/m3.py +121 -0
- nested_learning/optim/manager.py +151 -0
- nested_learning/titan/__init__.py +0 -0
- nested_learning/titan/memory.py +88 -0
- nested_learning/titan/model.py +412 -0
- nested_learning/titan/self_modifying.py +724 -0
- nested_learning/tokenizer.py +28 -0
- nested_learning/tokenizer_coverage.py +77 -0
- nested_learning/training.py +1600 -0
- nested_learning/transformer.py +104 -0
- nested_learning-0.2.0.dist-info/METADATA +390 -0
- nested_learning-0.2.0.dist-info/RECORD +76 -0
- nested_learning-0.2.0.dist-info/WHEEL +4 -0
- nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
- nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Nested Learning (HOPE) reproduction package."""
|
|
2
|
+
|
|
3
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
4
|
+
|
|
5
|
+
from .levels import LevelClock, LevelSpec # noqa: F401
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
__version__ = version("nested-learning")
|
|
9
|
+
except PackageNotFoundError: # pragma: no cover - editable/local source tree
|
|
10
|
+
__version__ = "0.2.0"
|
|
11
|
+
|
|
12
|
+
__all__ = ["LevelClock", "LevelSpec", "__version__"]
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Protocol
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class AssocMemory(nn.Module):
    """Base class for associative memories with explicit update hooks.

    Subclasses implement ``forward`` (read) and ``update`` (write); the
    base class only fixes the calling convention.
    """

    def forward(self, query: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        # Read path: map a query tensor to the retrieved value.
        raise NotImplementedError

    @torch.no_grad()
    def update(
        self, *, key: torch.Tensor, value: torch.Tensor, error_signal: torch.Tensor | None = None
    ) -> None:
        # Write path: runs without autograd tracking (decorator above).
        raise NotImplementedError
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SupportsReset(Protocol):
    """Structural type for objects exposing a ``reset_state`` hook."""

    def reset_state(self) -> None: ...
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
|
|
9
|
+
from .fast_state import AttentionKVCache
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class AttentionConfig:
    """Hyper-parameters for :class:`SelfAttention`."""

    dim: int  # model width; must be divisible by ``heads``
    heads: int  # number of attention heads
    dropout: float = 0.0  # residual dropout probability
    use_flash: bool = True  # prefer the flash SDPA kernel on CUDA
    causal: bool = True  # apply an autoregressive mask
    qk_l2_norm: bool = False  # L2-normalise queries/keys before attention
    qk_norm_eps: float = 1e-6  # epsilon for q/k normalisation
    local_conv_window: int | None = None  # optional causal depthwise conv width
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SelfAttention(nn.Module):
    """Pre-residual multi-head self-attention with post-LayerNorm.

    Optionally applies a causal depthwise convolution to the attention
    input and supports incremental decoding via an external KV cache
    (incompatible with the depthwise conv).
    """

    def __init__(self, config: AttentionConfig):
        super().__init__()
        if config.dim % config.heads != 0:
            msg = f"dim must be divisible by heads (got dim={config.dim}, heads={config.heads})"
            raise ValueError(msg)
        self.config = config
        self.heads = config.heads
        self.head_dim = config.dim // config.heads
        # Fused q/k/v projection; output projection back to model width.
        self.qkv = nn.Linear(config.dim, config.dim * 3, bias=False)
        self.out_proj = nn.Linear(config.dim, config.dim, bias=False)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.norm = nn.LayerNorm(config.dim)
        self.local_conv: nn.Conv1d | None = None
        if config.local_conv_window is not None:
            window = int(config.local_conv_window)
            if window <= 0:
                raise ValueError("local_conv_window must be positive")
            # Depthwise (groups=dim) conv; causal padding is applied in forward.
            self.local_conv = nn.Conv1d(
                config.dim,
                config.dim,
                kernel_size=window,
                groups=config.dim,
                padding=0,
                bias=False,
            )

    def forward(  # type: ignore[override]
        self,
        x: torch.Tensor,
        *,
        kv_cache: AttentionKVCache | None = None,
        return_kv_cache: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, AttentionKVCache]:
        skip = x
        features = x
        if kv_cache is not None and self.local_conv is not None:
            raise RuntimeError(
                "kv_cache with local_conv_window is not supported in this implementation."
            )
        if self.local_conv is not None:
            width = self.local_conv.kernel_size[0]
            # Causal depthwise conv: left-pad so each position sees only the past.
            features = F.pad(features.transpose(1, 2), (width - 1, 0))
            features = self.local_conv(features).transpose(1, 2)
        q, k, v = self._compute_qkv(features)
        past_len = 0
        keys, values = k, v
        if kv_cache is not None:
            if kv_cache.key.size(0) != k.size(0):
                raise ValueError("kv_cache batch dimension must match input batch dimension")
            if kv_cache.key.size(1) != k.size(1) or kv_cache.key.size(-1) != k.size(-1):
                raise ValueError("kv_cache shape is incompatible with attention heads/head_dim")
            past_len = int(kv_cache.key.size(2))
            keys = torch.cat([kv_cache.key, k], dim=2)
            values = torch.cat([kv_cache.value, v], dim=2)
        attended = self._scaled_dot_product_attn(q, keys, values, past_len=past_len)
        # (B, H, T, Dh) -> (B, T, H*Dh)
        attended = attended.transpose(1, 2).contiguous().view(x.size(0), x.size(1), -1)
        attended = self.resid_dropout(self.out_proj(attended))
        out = self.norm(skip + attended)
        if return_kv_cache:
            return out, AttentionKVCache(key=keys.detach(), value=values.detach())
        return out

    def _compute_qkv(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project to per-head q/k/v, shaped (B, H, T, head_dim)."""
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        split = (x.size(0), x.size(1), self.heads, self.head_dim)
        q = q.view(*split).transpose(1, 2)
        k = k.view(*split).transpose(1, 2)
        v = v.view(*split).transpose(1, 2)
        if self.config.qk_l2_norm:
            eps = self.config.qk_norm_eps
            q = F.normalize(q, dim=-1, eps=eps)
            k = F.normalize(k, dim=-1, eps=eps)
        return q, k, v

    def _scaled_dot_product_attn(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        *,
        past_len: int = 0,
    ) -> torch.Tensor:
        """Run SDPA, building an explicit causal mask when a cache offsets queries."""
        dropout_p = self.config.dropout if self.training else 0.0
        attn_mask = None
        if self.config.causal and past_len > 0:
            # Query i (global position past_len+i) may attend to keys <= its position.
            n_queries = int(q.size(-2))
            n_keys = int(k.size(-2))
            key_pos = torch.arange(n_keys, device=q.device)
            query_pos = past_len + torch.arange(n_queries, device=q.device)
            attn_mask = key_pos.unsqueeze(0) <= query_pos.unsqueeze(1)
        # Without an explicit mask, defer causality to the fused kernel.
        is_causal = self.config.causal and attn_mask is None
        sdpa_kwargs = {"attn_mask": attn_mask, "dropout_p": dropout_p, "is_causal": is_causal}
        if (
            q.device.type == "cuda"
            and torch.cuda.is_available()
            and hasattr(torch.backends, "cuda")
            and hasattr(torch.backends.cuda, "sdp_kernel")
        ):
            with torch.backends.cuda.sdp_kernel(  # type: ignore[attr-defined]
                enable_flash=self.config.use_flash,
                enable_mem_efficient=True,
                enable_math=not self.config.use_flash,
            ):
                return F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs)
        return F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs)
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import platform
|
|
4
|
+
import sys
|
|
5
|
+
from dataclasses import asdict, dataclass, field
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class RuntimeCapabilities:
    """Snapshot of the host's torch/platform capabilities.

    Produced by :func:`collect_runtime_capabilities`; defaults describe a
    bare CPU-only environment.
    """

    python_version: str
    platform: str
    machine: str
    torch_version: str
    cuda_available: bool
    cuda_device_count: int
    cuda_devices: list[str] = field(default_factory=list)
    mps_available: bool = False
    mps_built: bool = False
    distributed_available: bool = False
    compile_available: bool = False
    sdpa_flash_available: bool = False
    sdpa_mem_efficient_available: bool = False
    sdpa_math_available: bool = True
    bf16_supported: bool = False
    fp16_supported: bool = False
    default_device: str = "cpu"
    warnings: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict view suitable for JSON serialisation."""
        return asdict(self)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def collect_runtime_capabilities() -> RuntimeCapabilities:
    """Probe torch and the host platform, returning a capability snapshot.

    Best-effort: backend-specific query failures are recorded in the
    ``warnings`` list rather than raised.
    """
    has_cuda = bool(torch.cuda.is_available())
    n_cuda = int(torch.cuda.device_count() if has_cuda else 0)
    device_names: list[str] = []
    notes: list[str] = []

    if has_cuda:
        for idx in range(n_cuda):
            try:
                name = torch.cuda.get_device_name(idx)
            except Exception as err:  # pragma: no cover - backend specific
                notes.append(f"failed to query cuda:{idx}: {err}")
            else:
                device_names.append(f"cuda:{idx} {name}")

    mps = getattr(torch.backends, "mps", None)
    has_mps = bool(mps and mps.is_available())
    built_mps = bool(mps and mps.is_built())

    # SDPA backend flags are only meaningful when torch was built with CUDA.
    sdpa_flash = False
    sdpa_mem_eff = False
    sdpa_math = True
    if hasattr(torch.backends, "cuda") and torch.backends.cuda.is_built():
        try:
            sdpa_flash = bool(torch.backends.cuda.flash_sdp_enabled())
            sdpa_mem_eff = bool(torch.backends.cuda.mem_efficient_sdp_enabled())
            sdpa_math = bool(torch.backends.cuda.math_sdp_enabled())
        except Exception as err:  # pragma: no cover - backend specific
            notes.append(f"failed to query SDPA backend flags: {err}")

    bf16_ok = False
    fp16_ok = False
    if has_cuda:
        try:
            bf16_ok = bool(torch.cuda.is_bf16_supported())
            fp16_ok = True
        except Exception as err:  # pragma: no cover
            notes.append(f"failed to query CUDA dtype support: {err}")
    elif has_mps:
        fp16_ok = True

    # Device preference order: CUDA > MPS > CPU.
    if has_cuda:
        preferred = "cuda:0"
    elif has_mps:
        preferred = "mps"
    else:
        preferred = "cpu"

    return RuntimeCapabilities(
        python_version=sys.version.split()[0],
        platform=platform.platform(),
        machine=platform.machine(),
        torch_version=torch.__version__,
        cuda_available=has_cuda,
        cuda_device_count=n_cuda,
        cuda_devices=device_names,
        mps_available=has_mps,
        mps_built=built_mps,
        distributed_available=bool(torch.distributed.is_available()),
        compile_available=bool(hasattr(torch, "compile")),
        sdpa_flash_available=sdpa_flash,
        sdpa_mem_efficient_available=sdpa_mem_eff,
        sdpa_math_available=sdpa_math,
        bf16_supported=bf16_ok,
        fp16_supported=fp16_ok,
        default_device=preferred,
        warnings=notes,
    )
|
nested_learning/cli.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Annotated
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import typer
|
|
9
|
+
from omegaconf import OmegaConf
|
|
10
|
+
|
|
11
|
+
from .capabilities import collect_runtime_capabilities
|
|
12
|
+
from .config_utils import compose_config
|
|
13
|
+
from .device import resolve_device
|
|
14
|
+
from .training import build_model_from_cfg
|
|
15
|
+
|
|
16
|
+
# Single Typer application shared by all subcommands below.
app = typer.Typer(
    add_completion=False,
    no_args_is_help=True,
    help="Nested Learning CLI (training, diagnostics, and smoke checks).",
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _resolve_cli_device(device: str) -> torch.device:
    """Turn a CLI device string into a ``torch.device``.

    The special value ``"auto"`` (case-insensitive) picks the best device
    reported by the runtime-capability probe.
    """
    if device.strip().lower() == "auto":
        capabilities = collect_runtime_capabilities()
        return resolve_device(capabilities.default_device)
    return resolve_device(device)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@app.command("doctor")
|
|
32
|
+
def doctor(
|
|
33
|
+
as_json: Annotated[
|
|
34
|
+
bool,
|
|
35
|
+
typer.Option("--json", help="Emit machine-readable JSON only."),
|
|
36
|
+
] = False,
|
|
37
|
+
output: Annotated[
|
|
38
|
+
Path | None,
|
|
39
|
+
typer.Option(
|
|
40
|
+
"--output",
|
|
41
|
+
"-o",
|
|
42
|
+
help="Optional path for writing doctor output JSON.",
|
|
43
|
+
dir_okay=False,
|
|
44
|
+
writable=True,
|
|
45
|
+
),
|
|
46
|
+
] = None,
|
|
47
|
+
) -> None:
|
|
48
|
+
"""Inspect runtime capabilities for backend/device compatibility."""
|
|
49
|
+
payload = collect_runtime_capabilities().to_dict()
|
|
50
|
+
rendered = json.dumps(payload, indent=2, sort_keys=True)
|
|
51
|
+
if output is not None:
|
|
52
|
+
output.parent.mkdir(parents=True, exist_ok=True)
|
|
53
|
+
output.write_text(rendered + "\n", encoding="utf-8")
|
|
54
|
+
if as_json:
|
|
55
|
+
typer.echo(rendered)
|
|
56
|
+
return
|
|
57
|
+
|
|
58
|
+
typer.echo("Runtime Doctor")
|
|
59
|
+
typer.echo(f"python: {payload['python_version']}")
|
|
60
|
+
typer.echo(f"platform: {payload['platform']} ({payload['machine']})")
|
|
61
|
+
typer.echo(f"torch: {payload['torch_version']}")
|
|
62
|
+
typer.echo(f"default_device: {payload['default_device']}")
|
|
63
|
+
typer.echo(
|
|
64
|
+
"cuda_available: {available} ({count} device(s))".format(
|
|
65
|
+
available=payload["cuda_available"],
|
|
66
|
+
count=payload["cuda_device_count"],
|
|
67
|
+
)
|
|
68
|
+
)
|
|
69
|
+
for name in payload["cuda_devices"]:
|
|
70
|
+
typer.echo(f" - {name}")
|
|
71
|
+
typer.echo(f"mps_available: {payload['mps_available']} (built={payload['mps_built']})")
|
|
72
|
+
typer.echo(f"distributed_available: {payload['distributed_available']}")
|
|
73
|
+
typer.echo(f"compile_available: {payload['compile_available']}")
|
|
74
|
+
typer.echo(
|
|
75
|
+
"sdpa backends: flash={flash} mem_efficient={mem} math={math}".format(
|
|
76
|
+
flash=payload["sdpa_flash_available"],
|
|
77
|
+
mem=payload["sdpa_mem_efficient_available"],
|
|
78
|
+
math=payload["sdpa_math_available"],
|
|
79
|
+
)
|
|
80
|
+
)
|
|
81
|
+
typer.echo(f"dtype support: bf16={payload['bf16_supported']} fp16={payload['fp16_supported']}")
|
|
82
|
+
if payload["warnings"]:
|
|
83
|
+
typer.echo("warnings:")
|
|
84
|
+
for warning in payload["warnings"]:
|
|
85
|
+
typer.echo(f" - {warning}")
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@app.command("smoke")
|
|
89
|
+
def smoke(
|
|
90
|
+
config_name: Annotated[
|
|
91
|
+
str,
|
|
92
|
+
typer.Option("--config-name", "-c", help="Hydra config name (e.g. pilot, hope/mid)."),
|
|
93
|
+
] = "pilot_smoke",
|
|
94
|
+
override: Annotated[
|
|
95
|
+
list[str] | None,
|
|
96
|
+
typer.Option(
|
|
97
|
+
"--override",
|
|
98
|
+
"-O",
|
|
99
|
+
help="Hydra override(s), may be passed multiple times.",
|
|
100
|
+
),
|
|
101
|
+
] = None,
|
|
102
|
+
config_dir: Annotated[
|
|
103
|
+
Path | None,
|
|
104
|
+
typer.Option(
|
|
105
|
+
"--config-dir",
|
|
106
|
+
help="Optional explicit config directory.",
|
|
107
|
+
exists=True,
|
|
108
|
+
file_okay=False,
|
|
109
|
+
dir_okay=True,
|
|
110
|
+
readable=True,
|
|
111
|
+
),
|
|
112
|
+
] = None,
|
|
113
|
+
device: Annotated[
|
|
114
|
+
str,
|
|
115
|
+
typer.Option(
|
|
116
|
+
"--device",
|
|
117
|
+
help="Device string for smoke pass (cpu, cuda:0, mps, auto).",
|
|
118
|
+
),
|
|
119
|
+
] = "cpu",
|
|
120
|
+
batch_size: Annotated[
|
|
121
|
+
int,
|
|
122
|
+
typer.Option("--batch-size", help="Synthetic smoke batch size."),
|
|
123
|
+
] = 1,
|
|
124
|
+
seq_len: Annotated[
|
|
125
|
+
int,
|
|
126
|
+
typer.Option("--seq-len", help="Synthetic smoke sequence length."),
|
|
127
|
+
] = 32,
|
|
128
|
+
) -> None:
|
|
129
|
+
"""Run a lightweight forward-pass smoke test with composed config."""
|
|
130
|
+
cfg = compose_config(config_name, overrides=override or [], config_dir=config_dir)
|
|
131
|
+
model_cfg = cfg.model
|
|
132
|
+
torch_device = _resolve_cli_device(device)
|
|
133
|
+
model = build_model_from_cfg(model_cfg).to(torch_device)
|
|
134
|
+
model.eval()
|
|
135
|
+
|
|
136
|
+
vocab_size = int(model_cfg.vocab_size)
|
|
137
|
+
tokens = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch_device)
|
|
138
|
+
with torch.no_grad():
|
|
139
|
+
outputs = model(tokens)
|
|
140
|
+
if isinstance(outputs, tuple):
|
|
141
|
+
logits = outputs[0]
|
|
142
|
+
else:
|
|
143
|
+
logits = outputs
|
|
144
|
+
typer.echo(
|
|
145
|
+
json.dumps(
|
|
146
|
+
{
|
|
147
|
+
"status": "ok",
|
|
148
|
+
"config_name": config_name,
|
|
149
|
+
"device": str(torch_device),
|
|
150
|
+
"batch_size": batch_size,
|
|
151
|
+
"seq_len": seq_len,
|
|
152
|
+
"logits_shape": list(logits.shape),
|
|
153
|
+
"dtype": str(logits.dtype),
|
|
154
|
+
},
|
|
155
|
+
sort_keys=True,
|
|
156
|
+
)
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
@app.command("train")
|
|
161
|
+
def train(
|
|
162
|
+
config_name: Annotated[
|
|
163
|
+
str,
|
|
164
|
+
typer.Option("--config-name", "-c", help="Hydra config name for training."),
|
|
165
|
+
] = "pilot",
|
|
166
|
+
override: Annotated[
|
|
167
|
+
list[str] | None,
|
|
168
|
+
typer.Option("--override", "-O", help="Hydra override(s), may be passed multiple times."),
|
|
169
|
+
] = None,
|
|
170
|
+
config_dir: Annotated[
|
|
171
|
+
Path | None,
|
|
172
|
+
typer.Option(
|
|
173
|
+
"--config-dir",
|
|
174
|
+
help="Optional explicit config directory.",
|
|
175
|
+
exists=True,
|
|
176
|
+
file_okay=False,
|
|
177
|
+
dir_okay=True,
|
|
178
|
+
readable=True,
|
|
179
|
+
),
|
|
180
|
+
] = None,
|
|
181
|
+
device: Annotated[
|
|
182
|
+
str | None,
|
|
183
|
+
typer.Option(
|
|
184
|
+
"--device",
|
|
185
|
+
help="Override cfg.train.device (e.g. cpu, cuda:1, auto).",
|
|
186
|
+
),
|
|
187
|
+
] = None,
|
|
188
|
+
dry_run: Annotated[
|
|
189
|
+
bool,
|
|
190
|
+
typer.Option("--dry-run", help="Print resolved config and exit."),
|
|
191
|
+
] = False,
|
|
192
|
+
) -> None:
|
|
193
|
+
"""Launch a local (single-process) training loop."""
|
|
194
|
+
from .training import run_training_loop
|
|
195
|
+
|
|
196
|
+
cfg = compose_config(config_name, overrides=override or [], config_dir=config_dir)
|
|
197
|
+
if device is not None:
|
|
198
|
+
cfg.train.device = device
|
|
199
|
+
if dry_run:
|
|
200
|
+
typer.echo(OmegaConf.to_yaml(cfg))
|
|
201
|
+
return
|
|
202
|
+
train_device = _resolve_cli_device(str(cfg.train.device))
|
|
203
|
+
run_training_loop(cfg, device=train_device, distributed=False, dist_ctx=None)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
@app.command("audit")
|
|
207
|
+
def audit(
|
|
208
|
+
config_name: Annotated[
|
|
209
|
+
str,
|
|
210
|
+
typer.Option("--config-name", "-c", help="Hydra config name to audit."),
|
|
211
|
+
] = "pilot_paper_faithful",
|
|
212
|
+
override: Annotated[
|
|
213
|
+
list[str] | None,
|
|
214
|
+
typer.Option("--override", "-O", help="Hydra override(s), may be passed multiple times."),
|
|
215
|
+
] = None,
|
|
216
|
+
config_dir: Annotated[
|
|
217
|
+
Path | None,
|
|
218
|
+
typer.Option(
|
|
219
|
+
"--config-dir",
|
|
220
|
+
help="Optional explicit config directory.",
|
|
221
|
+
exists=True,
|
|
222
|
+
file_okay=False,
|
|
223
|
+
dir_okay=True,
|
|
224
|
+
readable=True,
|
|
225
|
+
),
|
|
226
|
+
] = None,
|
|
227
|
+
) -> None:
|
|
228
|
+
"""Run static architecture checks on a composed config."""
|
|
229
|
+
cfg = compose_config(config_name, overrides=override or [], config_dir=config_dir)
|
|
230
|
+
model = build_model_from_cfg(cfg.model)
|
|
231
|
+
has_embed = hasattr(model, "embed")
|
|
232
|
+
has_lm_head = hasattr(model, "lm_head")
|
|
233
|
+
tied_weights = False
|
|
234
|
+
if has_embed and has_lm_head:
|
|
235
|
+
embed = getattr(model, "embed")
|
|
236
|
+
lm_head = getattr(model, "lm_head")
|
|
237
|
+
tied_weights = bool(embed.weight.data_ptr() == lm_head.weight.data_ptr())
|
|
238
|
+
|
|
239
|
+
report = {
|
|
240
|
+
"status": "ok",
|
|
241
|
+
"config_name": config_name,
|
|
242
|
+
"model_type": str(cfg.model.get("type", "hope")),
|
|
243
|
+
"block_variant": str(cfg.model.get("block_variant", "hope_hybrid")),
|
|
244
|
+
"surprise_metric": str(cfg.model.get("surprise_metric", "l2")),
|
|
245
|
+
"surprise_threshold": cfg.model.get("surprise_threshold"),
|
|
246
|
+
"teach_scale": float(cfg.model.get("teach_scale", 1.0)),
|
|
247
|
+
"teach_clip": float(cfg.model.get("teach_clip", 0.0)),
|
|
248
|
+
"freeze_backbone": bool(cfg.model.get("freeze_backbone", False)),
|
|
249
|
+
"has_embed": has_embed,
|
|
250
|
+
"has_lm_head": has_lm_head,
|
|
251
|
+
"lm_tied_to_embedding": tied_weights,
|
|
252
|
+
}
|
|
253
|
+
typer.echo(json.dumps(report, sort_keys=True))
|
nested_learning/cms.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Dict, Sequence
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
|
|
8
|
+
from .levels import LevelSpec, ensure_level_specs
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CMSBlock(nn.Module):
    """Residual MLP block used as one level of the CMS.

    Computes ``x + net(x)`` where ``net`` is (optional LayerNorm) ->
    Linear -> activation -> Linear. During training the update ``delta``
    is norm-clipped per position to at most ``grad_clip``.

    Args:
        dim: feature width of the input/output.
        hidden_multiplier: hidden width = ``dim * hidden_multiplier``.
        activation: one of ``"relu"``, ``"silu"``; anything else -> GELU.
        grad_clip: max per-position L2 norm of the residual update
            during training; ``<= 0`` disables clipping.
        use_layernorm: prepend a LayerNorm to the MLP when True.
    """

    def __init__(
        self,
        dim: int,
        hidden_multiplier: int = 4,
        activation: str = "gelu",
        grad_clip: float = 1.0,
        use_layernorm: bool = True,
    ):
        super().__init__()
        hidden = dim * hidden_multiplier
        act: nn.Module
        if activation == "relu":
            act = nn.ReLU()
        elif activation == "silu":
            act = nn.SiLU()
        else:
            act = nn.GELU()
        norm: nn.Module = nn.LayerNorm(dim) if use_layernorm else nn.Identity()
        self.net = nn.Sequential(
            norm,
            nn.Linear(dim, hidden),
            act,
            nn.Linear(hidden, dim),
        )
        self.grad_clip = grad_clip

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        delta = self.net(x)
        if self.training and self.grad_clip > 0:
            # Compute the clipping factor without tracking gradients, but
            # apply the division ON the autograd graph. Rescaling inside
            # no_grad would detach `delta` and silence all gradients
            # through self.net during training.
            with torch.no_grad():
                norm = delta.norm(dim=-1, keepdim=True)
                scale = torch.clamp(norm / self.grad_clip, min=1.0)
            delta = delta / scale
        return x + delta
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class CMS(nn.Module):
    """Continuum Memory System with multi-frequency updates.

    A stack of :class:`CMSBlock` levels, one per :class:`LevelSpec`,
    applied sequentially in level order. ``forward`` can optionally
    expose each level's input/output for instrumentation.
    """

    def __init__(
        self,
        *,
        dim: int,
        levels: Sequence[LevelSpec],
        hidden_multiplier: int = 4,
        activation: str = "gelu",
        use_layernorm: bool = True,
    ) -> None:
        super().__init__()
        self.level_specs: Sequence[LevelSpec] = tuple(ensure_level_specs(levels))
        level_blocks: Dict[str, CMSBlock] = {}
        for spec in self.level_specs:
            level_blocks[spec.name] = CMSBlock(
                dim,
                hidden_multiplier=hidden_multiplier,
                activation=activation,
                grad_clip=1.0,
                use_layernorm=use_layernorm,
            )
        self.blocks = nn.ModuleDict(level_blocks)

    def forward(
        self,
        x: torch.Tensor,
        *,
        return_intermediates: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        hidden = x
        level_inputs: Dict[str, torch.Tensor] = {}
        level_outputs: Dict[str, torch.Tensor] = {}
        for spec in self.level_specs:
            level_inputs[spec.name] = hidden
            hidden = self.blocks[spec.name](hidden)
            level_outputs[spec.name] = hidden
        if return_intermediates:
            return hidden, level_inputs, level_outputs
        return hidden
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
from importlib.resources import as_file, files
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Iterator
|
|
7
|
+
|
|
8
|
+
from hydra import compose, initialize_config_dir
|
|
9
|
+
from hydra.core.global_hydra import GlobalHydra
|
|
10
|
+
from omegaconf import DictConfig
|
|
11
|
+
|
|
12
|
+
from .training import unwrap_config
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def find_repo_root(start: Path | None = None) -> Path | None:
    """Walk upwards from *start* (default: cwd) looking for the repo root.

    A directory qualifies when it contains both a ``.git`` entry and a
    ``configs`` directory; returns ``None`` when no ancestor qualifies.
    """
    base = (start or Path.cwd()).resolve()
    for folder in [base, *base.parents]:
        looks_like_root = (folder / ".git").exists() and (folder / "configs").exists()
        if looks_like_root:
            return folder
    return None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@contextmanager
def resolved_config_dir(config_dir: Path | None = None) -> Iterator[Path]:
    """Yield the directory Hydra should read configs from.

    Resolution order: an explicitly supplied ``config_dir``, then the
    repository checkout's top-level ``configs`` directory, then the
    configs bundled inside the installed package (extracted to a real
    path for the duration of the context).
    """
    if config_dir is not None:
        yield config_dir.resolve()
        return

    checkout_configs = Path(__file__).resolve().parents[2] / "configs"
    if checkout_configs.exists():
        yield checkout_configs
        return

    bundled = files("nested_learning").joinpath("configs")
    with as_file(bundled) as extracted:
        yield Path(extracted)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def compose_config(
    config_name: str,
    *,
    overrides: list[str] | None = None,
    config_dir: Path | None = None,
) -> DictConfig:
    """Compose a Hydra config by name and return it unwrapped.

    Clears any pre-existing global Hydra state so repeated calls from the
    same process (tests, CLI subcommands) do not collide.
    """
    with resolved_config_dir(config_dir) as cfg_dir:
        GlobalHydra.instance().clear()
        with initialize_config_dir(version_base=None, config_dir=str(cfg_dir)):
            composed = compose(config_name=config_name, overrides=overrides or [])
            return unwrap_config(composed)