nested_learning-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. nested_learning/__init__.py +12 -0
  2. nested_learning/__main__.py +12 -0
  3. nested_learning/assoc_memory.py +23 -0
  4. nested_learning/backbones.py +147 -0
  5. nested_learning/capabilities.py +104 -0
  6. nested_learning/cli.py +253 -0
  7. nested_learning/cms.py +92 -0
  8. nested_learning/config_utils.py +50 -0
  9. nested_learning/configs/ablations/cms_sparse.yaml +46 -0
  10. nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
  11. nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
  12. nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
  13. nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
  14. nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
  15. nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
  16. nested_learning/configs/data/continual_segments_sample.yaml +9 -0
  17. nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
  18. nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
  19. nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
  20. nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
  21. nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
  22. nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
  23. nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
  24. nested_learning/configs/deepspeed/zero3.json +25 -0
  25. nested_learning/configs/hope/mid.yaml +118 -0
  26. nested_learning/configs/hope/mid_fsdp.yaml +47 -0
  27. nested_learning/configs/hope/pilot.yaml +2 -0
  28. nested_learning/configs/hope/pilot_attention.yaml +9 -0
  29. nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
  30. nested_learning/configs/hope/pilot_transformer.yaml +9 -0
  31. nested_learning/configs/hope/target.yaml +145 -0
  32. nested_learning/configs/hope/target_fsdp.yaml +47 -0
  33. nested_learning/configs/mid_smoke.yaml +99 -0
  34. nested_learning/configs/mid_stage2.yaml +110 -0
  35. nested_learning/configs/mid_stage2_smoke.yaml +102 -0
  36. nested_learning/configs/mid_titan_baseline.yaml +92 -0
  37. nested_learning/configs/pilot.yaml +127 -0
  38. nested_learning/configs/pilot_paper_faithful.yaml +42 -0
  39. nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
  40. nested_learning/configs/pilot_smoke.yaml +80 -0
  41. nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
  42. nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
  43. nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
  44. nested_learning/continual_classification.py +136 -0
  45. nested_learning/continual_streaming.py +283 -0
  46. nested_learning/data.py +153 -0
  47. nested_learning/device.py +21 -0
  48. nested_learning/eval_state.py +72 -0
  49. nested_learning/fast_state.py +108 -0
  50. nested_learning/functional.py +69 -0
  51. nested_learning/hope/__init__.py +0 -0
  52. nested_learning/hope/block.py +1973 -0
  53. nested_learning/hope/self_mod.py +40 -0
  54. nested_learning/instrumentation.py +38 -0
  55. nested_learning/levels.py +94 -0
  56. nested_learning/logging_utils.py +64 -0
  57. nested_learning/memorize.py +382 -0
  58. nested_learning/model.py +604 -0
  59. nested_learning/optim/__init__.py +0 -0
  60. nested_learning/optim/deep.py +102 -0
  61. nested_learning/optim/factory.py +13 -0
  62. nested_learning/optim/m3.py +121 -0
  63. nested_learning/optim/manager.py +151 -0
  64. nested_learning/titan/__init__.py +0 -0
  65. nested_learning/titan/memory.py +88 -0
  66. nested_learning/titan/model.py +412 -0
  67. nested_learning/titan/self_modifying.py +724 -0
  68. nested_learning/tokenizer.py +28 -0
  69. nested_learning/tokenizer_coverage.py +77 -0
  70. nested_learning/training.py +1600 -0
  71. nested_learning/transformer.py +104 -0
  72. nested_learning-0.2.0.dist-info/METADATA +390 -0
  73. nested_learning-0.2.0.dist-info/RECORD +76 -0
  74. nested_learning-0.2.0.dist-info/WHEEL +4 -0
  75. nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
  76. nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
nested_learning/configs/pilot_paper_faithful.yaml
@@ -0,0 +1,42 @@
+ defaults:
+   - /pilot
+   - _self_
+
+ model:
+   # Explicit paper-defined variant (avoid inheriting repo default `hope_hybrid`).
+   block_variant: hope_attention
+   # Paper-faithful: treat "surprise" as the (scaled) teach signal itself, without threshold gating.
+   surprise_threshold: null
+   # Paper updates on the last (possibly partial) chunk; enable flush for non-multiple seq lengths.
+   cms_flush_partial_at_end: true
+   # Paper: q is non-adaptive and uses a fixed projection.
+   self_mod_adaptive_q: false
+   # Paper: local causal conv in the HOPE self-mod module.
+   self_mod_local_conv_window: 4
+
+ data:
+   # Paper-faithful semantics: CMS/TITAN fast state is per-context; this repo currently treats
+   # each *batch* as a single shared context when batch_size>1.
+   batch_size: 1
+
+ train:
+   algorithm_mode: two_pass_stopgrad_updates
+   # Keep this explicit (instead of inherited) so paper-faithful behavior is visible in one file.
+   online_updates: true
+   # Paper: re-initialize fast memories per context (sequence).
+   use_fast_state: true
+   strict_streaming_contract: true
+   # Use explicit boundary-token supervision (no overlap approximation).
+   online_boundary_targets: true
+   # Carry attention state across chunks during online updates.
+   online_carry_attention_cache: true
+   # Fail fast if DDP would silently disable paper-critical features.
+   fail_if_paper_faithful_disabled: true
+
+ optim:
+   # Ensure meta-learning updates include memory module initial states (paper §8.2).
+   param_policy: all
+
+ logging:
+   run_name: pilot-paper-faithful
+   path: logs/pilot_paper_faithful_metrics.json
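The `defaults` list above is Hydra composition: the base `/pilot` config is loaded first, then `_self_` applies this file's keys on top. A minimal sketch of that layering using plain OmegaConf merging (Hydra's own resolver handles the `/pilot` lookup and `_self_` ordering; the file paths are taken from the listing above, and the `defaults` key itself would be carried along by a raw merge):

    from omegaconf import OmegaConf

    base = OmegaConf.load("nested_learning/configs/pilot.yaml")
    overrides = OmegaConf.load("nested_learning/configs/pilot_paper_faithful.yaml")
    cfg = OmegaConf.merge(base, overrides)  # later values win, as `_self_` last implies

    print(cfg.model.block_variant)       # hope_attention (overridden here)
    print(cfg.model.surprise_threshold)  # None (threshold gating disabled)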
nested_learning/configs/pilot_selfmod_paper_faithful.yaml
@@ -0,0 +1,18 @@
+ defaults:
+   - /pilot_paper_faithful
+   - _self_
+
+ model:
+   block_variant: hope_selfmod
+   # Chunk update cadence (paper §8.2): other memories update more often than M_memory.
+   self_mod_chunk_size: 8
+   self_mod_chunk_size_memory: 64
+   self_mod_use_skip: false
+
+ train:
+   checkpoint:
+     dir: artifacts/checkpoints/pilot_selfmod_paper_faithful
+
+ logging:
+   run_name: pilot-selfmod-paper-faithful
+   path: logs/pilot_selfmod_paper_faithful_metrics.json
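An illustration of the cadence these two chunk sizes encode: with `self_mod_chunk_size: 8` and `self_mod_chunk_size_memory: 64`, the faster self-modifying memories would update every 8 tokens while M_memory updates every 64, the relationship the comment above attributes to paper §8.2. The actual trigger logic lives in nested_learning/hope/self_mod.py and may differ; this only counts update events:

    # Illustrative only: update events per 256-token context under the two cadences.
    seq_len = 256
    fast_updates = [t for t in range(1, seq_len + 1) if t % 8 == 0]
    memory_updates = [t for t in range(1, seq_len + 1) if t % 64 == 0]
    print(len(fast_updates), len(memory_updates))  # 32 fast updates vs. 4 memory updates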
nested_learning/configs/pilot_smoke.yaml
@@ -0,0 +1,80 @@
+ hydra:
+   run:
+     dir: .
+   output_subdir: null
+   job:
+     chdir: false
+
+ model:
+   vocab_size: 32000
+   dim: 128
+   num_layers: 2
+   heads: 4
+   titan_level:
+     name: titan
+     update_period: 8
+     optimizer_key: titan_opt
+   cms_levels:
+     - name: cms_fast
+       update_period: 1
+       optimizer_key: cms_opt
+     - name: cms_mid
+       update_period: 4
+       optimizer_key: cms_opt
+     - name: cms_slow
+       update_period: 16
+       optimizer_key: cms_opt
+   optimizers:
+     titan_opt:
+       type: deep_momentum
+       lr: 1.0e-3
+       params:
+         beta: 0.9
+         beta2: 0.999
+     cms_opt:
+       type: deep_momentum
+       lr: 5.0e-4
+       params:
+         beta: 0.9
+         beta2: 0.999
+
+ data:
+   source: synthetic
+   vocab_size: 32000
+   seq_len: 64
+   dataset_size: 1024
+   batch_size: 4
+   num_workers: 0
+
+ train:
+   strict_streaming_contract: false
+   online_updates: true
+   online_chunk_size: 0
+   online_boundary_targets: false
+   online_carry_attention_cache: false
+   per_layer_teach_signal: true
+   steps: 10
+   log_interval: 1
+   device: "cpu"
+   seed: 1234
+   deterministic: true
+   mixed_precision:
+     enabled: false
+     dtype: bf16
+   compile:
+     enable: false
+   checkpoint:
+     enable: true
+     dir: artifacts/checkpoints/pilot_smoke
+     save_interval: 10
+     save_last: true
+
+ optim:
+   type: adamw
+   lr: 3.0e-4
+   fused: false
+
+ logging:
+   enabled: true
+   backend: json
+   path: logs/pilot_smoke.json
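A minimal sketch of the multi-timescale cadence the level configs above encode: a level with `update_period` N applies its optimizer every N steps, so `cms_fast` fires every step while `cms_slow` fires every 16th. Names and periods are taken from this config; the real scheduling lives in nested_learning/levels.py and nested_learning/optim/manager.py and may differ:

    LEVELS = {"cms_fast": 1, "cms_mid": 4, "titan": 8, "cms_slow": 16}

    def levels_due(step: int) -> list[str]:
        """Return the levels whose update fires on this (1-indexed) step."""
        return [name for name, period in LEVELS.items() if step % period == 0]

    for step in (1, 4, 8, 16):
        print(step, levels_due(step))
    # 1 ['cms_fast']
    # 4 ['cms_fast', 'cms_mid']
    # 8 ['cms_fast', 'cms_mid', 'titan']
    # 16 ['cms_fast', 'cms_mid', 'titan', 'cms_slow']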
nested_learning/configs/resolved/cms_sparse_eval.yaml
@@ -0,0 +1,105 @@
+ hydra:
+   run:
+     dir: .
+   output_subdir: null
+   job:
+     chdir: false
+ model:
+   vocab_size: 32000
+   dim: 384
+   num_layers: 8
+   heads: 6
+   teach_scale: 0.1
+   teach_clip: 5.0
+   self_mod_lr: 0.001
+   teach_schedule:
+     warmup_steps: 2000
+     decay_start: 120000
+     decay_duration: 20000
+   titan_level:
+     name: titan
+     update_period: 8
+     optimizer_key: titan_opt
+   cms_levels:
+     - name: cms_fast
+       update_period: 8
+       optimizer_key: cms_opt
+     - name: cms_mid
+       update_period: 32
+       optimizer_key: cms_opt
+     - name: cms_slow
+       update_period: 128
+       optimizer_key: cms_opt
+     - name: cms_ultra
+       update_period: 512
+       optimizer_key: cms_opt
+   optimizers:
+     titan_opt:
+       type: deep_momentum
+       lr: 0.0006
+       params:
+         beta: 0.9
+         beta2: 0.999
+     cms_opt:
+       type: deep_momentum
+       lr: 0.0003
+       params:
+         beta: 0.9
+         beta2: 0.999
+   cms_hidden_multiplier: 2
+ data:
+   source: mixture
+   seq_len: 1024
+   batch_size: 2
+   num_workers: 2
+   mixture:
+     samples_per_epoch: 65536
+     seed: 1337
+     sources:
+       - name: refinedweb
+         shards_dir: data/shards/refinedweb_filtered
+         weight: 0.4
+       - name: wikipedia
+         shards_dir: data/shards/wikipedia_filtered
+         weight: 0.2
+       - name: c4
+         shards_dir: data/shards/c4_filtered
+         weight: 0.15
+       - name: redpajama
+         shards_dir: data/shards/redpajama_filtered
+         weight: 0.15
+       - name: code
+         shards_dir: data/shards/code_filtered
+         weight: 0.1
+ train:
+   online_updates: true
+   online_chunk_size: 0
+   per_layer_teach_signal: true
+   steps: 5000
+   log_interval: 25
+   device: cuda:1
+   seed: 1337
+   deterministic: false
+   mixed_precision:
+     enabled: true
+     dtype: bf16
+   compile:
+     enable: false
+     mode: max-autotune
+   checkpoint:
+     enable: true
+     dir: artifacts/checkpoints/pilot_cms_sparse
+     save_interval: 1000
+     save_last: true
+     resume_path: null
+     resume_tag: null
+ optim:
+   type: adamw
+   lr: 0.00025
+   fused: auto
+ logging:
+   enabled: true
+   backend: json
+   path: logs/pilot_cms_sparse_metrics.json
+   project: nested-learning
+   run_name: pilot-cms-sparse
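A sketch of how the mixture weights above could drive source sampling: each example is drawn from a source with probability proportional to its weight (0.4 + 0.2 + 0.15 + 0.15 + 0.1 = 1.0). The package's actual shard-level mixing lives in nested_learning/data.py and may work per-shard rather than per-example; this only illustrates the weighting:

    import random

    SOURCES = {"refinedweb": 0.4, "wikipedia": 0.2, "c4": 0.15,
               "redpajama": 0.15, "code": 0.1}

    rng = random.Random(1337)  # seed from the mixture config above
    names, weights = zip(*SOURCES.items())
    draws = rng.choices(names, weights=weights, k=65536)  # samples_per_epoch
    # Empirical fractions land close to the configured weights.
    print({n: round(draws.count(n) / len(draws), 3) for n in names})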
nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml
@@ -0,0 +1,49 @@
+ model:
+   vocab_size: 32000
+   dim: 512
+   num_layers: 12
+   heads: 8
+   teach_scale: 0.10
+   teach_clip: 5.0
+   surprise_threshold: 0.02
+   freeze_backbone: false
+   qk_l2_norm: true
+   local_conv_window: 4
+   block_variant: hope_attention
+   teach_schedule:
+     warmup_steps: 2000
+     decay_start: 120000
+     decay_duration: 20000
+   titan_level:
+     name: titan
+     update_period: 8
+     optimizer_key: titan_opt
+   cms_levels:
+     - name: cms_fast
+       update_period: 1
+       optimizer_key: cms_opt
+     - name: cms_mid
+       update_period: 4
+       optimizer_key: cms_opt
+     - name: cms_slow
+       update_period: 32
+       optimizer_key: cms_opt
+     - name: cms_ultra
+       update_period: 128
+       optimizer_key: cms_opt
+   optimizers:
+     titan_opt:
+       type: deep_momentum
+       lr: 6.0e-4
+       params:
+         beta: 0.9
+         beta2: 0.999
+       variant: nl_l2_precond
+     cms_opt:
+       type: deep_momentum
+       lr: 3.0e-4
+       params:
+         beta: 0.9
+         beta2: 0.999
+       variant: nl_l2_precond
+
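A hand-wavy sketch of what a momentum step with an L2-style preconditioner could look like given the hyperparameters above: `beta` as a momentum EMA and `beta2` as a second-moment EMA used to precondition the update. The package's real `deep_momentum` / `nl_l2_precond` update is in nested_learning/optim/deep.py and is almost certainly richer (the "deep" part learns the update rule); treat everything here as an assumption:

    import torch

    def precond_momentum_step(p, grad, m, v, lr=6e-4, beta=0.9, beta2=0.999, eps=1e-8):
        # m: momentum EMA of gradients; v: EMA of squared gradients (preconditioner).
        m.mul_(beta).add_(grad, alpha=1 - beta)
        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        p.sub_(lr * m / (v.sqrt() + eps))  # preconditioned momentum step
        return p, m, v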
nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml
@@ -0,0 +1,49 @@
+ model:
+   vocab_size: 32000
+   dim: 512
+   num_layers: 12
+   heads: 8
+   teach_scale: 0.10
+   teach_clip: 5.0
+   surprise_threshold: 0.02
+   freeze_backbone: false
+   qk_l2_norm: true
+   local_conv_window: 4
+   block_variant: transformer
+   teach_schedule:
+     warmup_steps: 2000
+     decay_start: 120000
+     decay_duration: 20000
+   titan_level:
+     name: titan
+     update_period: 8
+     optimizer_key: titan_opt
+   cms_levels:
+     - name: cms_fast
+       update_period: 1
+       optimizer_key: cms_opt
+     - name: cms_mid
+       update_period: 4
+       optimizer_key: cms_opt
+     - name: cms_slow
+       update_period: 32
+       optimizer_key: cms_opt
+     - name: cms_ultra
+       update_period: 128
+       optimizer_key: cms_opt
+   optimizers:
+     titan_opt:
+       type: deep_momentum
+       lr: 6.0e-4
+       params:
+         beta: 0.9
+         beta2: 0.999
+       variant: nl_l2_precond
+     cms_opt:
+       type: deep_momentum
+       lr: 3.0e-4
+       params:
+         beta: 0.9
+         beta2: 0.999
+       variant: nl_l2_precond
+
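Both eval configs share the same `teach_schedule`. One plausible reading of its three fields (warmup_steps: 2000, decay_start: 120000, decay_duration: 20000) is a trapezoid: the teach scale ramps linearly from 0 to 1 over warmup, holds at 1, then decays linearly to 0 over the decay window. The authoritative schedule is implemented in the package itself; this reading is an assumption:

    def teach_scale_multiplier(step, warmup=2000, decay_start=120000, decay_len=20000):
        # Linear warmup, flat plateau, then linear decay to zero.
        if step < warmup:
            return step / warmup
        if step < decay_start:
            return 1.0
        return max(0.0, 1.0 - (step - decay_start) / decay_len)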
nested_learning/continual_classification.py
@@ -0,0 +1,136 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Iterable, List, Sequence
+
+
+ @dataclass(frozen=True)
+ class ClassificationExample:
+     text: str
+     label: str
+
+
+ @dataclass(frozen=True)
+ class LoadedClassificationDataset:
+     name: str
+     split: str
+     examples: List[ClassificationExample]
+     label_names: List[str]
+
+
+ def load_hf_classification_dataset(
+     dataset: str,
+     *,
+     split: str,
+     text_field: str,
+     label_field: str,
+     name: str | None = None,
+     max_samples: int | None = None,
+ ) -> LoadedClassificationDataset:
+     """
+     Load a HuggingFace `datasets` text classification dataset into a simple in-memory format.
+
+     This is used by the Phase 4 continual-learning harness (CLINC/Banking/DBpedia).
+     """
+     try:
+         from datasets import load_dataset  # type: ignore[import-not-found]
+     except Exception as exc:  # pragma: no cover
+         raise RuntimeError(
+             "`datasets` dependency is required for continual classification."
+         ) from exc
+
+     ds = load_dataset(dataset, name=name, split=split)
+     features = getattr(ds, "features", None)
+     label_names: List[str] = []
+     if features is not None and label_field in features:
+         feature = features[label_field]
+         if getattr(feature, "names", None) is not None:
+             label_names = list(feature.names)
+
+     examples: List[ClassificationExample] = []
+     count = 0
+     for row in ds:
+         if max_samples is not None and count >= max_samples:
+             break
+         text = str(row[text_field])
+         raw_label = row[label_field]
+         if isinstance(raw_label, int) and label_names:
+             label = label_names[raw_label]
+         else:
+             label = str(raw_label)
+         examples.append(ClassificationExample(text=text, label=label))
+         count += 1
+
+     if not label_names:
+         label_names = sorted({ex.label for ex in examples})
+
+     return LoadedClassificationDataset(
+         name=dataset if name is None else f"{dataset}:{name}",
+         split=split,
+         examples=examples,
+         label_names=label_names,
+     )
+
+
+ def load_clinc_oos(
+     *,
+     split: str = "test",
+     max_samples: int | None = None,
+ ) -> LoadedClassificationDataset:
+     # HF dataset: "clinc_oos" with fields {"text", "intent"}.
+     return load_hf_classification_dataset(
+         "clinc_oos",
+         split=split,
+         text_field="text",
+         label_field="intent",
+         max_samples=max_samples,
+     )
+
+
+ def load_banking77(
+     *,
+     split: str = "test",
+     max_samples: int | None = None,
+ ) -> LoadedClassificationDataset:
+     # HF dataset: "banking77" with fields {"text", "label"}.
+     return load_hf_classification_dataset(
+         "banking77",
+         split=split,
+         text_field="text",
+         label_field="label",
+         max_samples=max_samples,
+     )
+
+
+ def load_dbpedia14(
+     *,
+     split: str = "test",
+     max_samples: int | None = None,
+ ) -> LoadedClassificationDataset:
+     # HF dataset: "dbpedia_14" with fields {"content", "label"}.
+     return load_hf_classification_dataset(
+         "dbpedia_14",
+         split=split,
+         text_field="content",
+         label_field="label",
+         max_samples=max_samples,
+     )
+
+
+ def unique_labels(examples: Iterable[ClassificationExample]) -> List[str]:
+     seen = set()
+     ordered: List[str] = []
+     for ex in examples:
+         if ex.label in seen:
+             continue
+         seen.add(ex.label)
+         ordered.append(ex.label)
+     return ordered
+
+
+ def filter_examples_by_labels(
+     examples: Sequence[ClassificationExample],
+     *,
+     allowed: set[str],
+ ) -> List[ClassificationExample]:
+     return [ex for ex in examples if ex.label in allowed]
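Example use of the loaders above, splitting one dataset into a first "task" by label, as a continual-learning harness would (requires network access and the `datasets` package; dataset and field names are exactly those hard-coded in this module):

    from nested_learning.continual_classification import (
        filter_examples_by_labels, load_banking77, unique_labels,
    )

    ds = load_banking77(split="test", max_samples=500)
    labels = unique_labels(ds.examples)  # labels in first-seen order
    first_task = filter_examples_by_labels(ds.examples, allowed=set(labels[:10]))
    print(ds.name, len(ds.examples), len(first_task))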