nested-learning 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. nested_learning/__init__.py +12 -0
  2. nested_learning/__main__.py +12 -0
  3. nested_learning/assoc_memory.py +23 -0
  4. nested_learning/backbones.py +147 -0
  5. nested_learning/capabilities.py +104 -0
  6. nested_learning/cli.py +253 -0
  7. nested_learning/cms.py +92 -0
  8. nested_learning/config_utils.py +50 -0
  9. nested_learning/configs/ablations/cms_sparse.yaml +46 -0
  10. nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
  11. nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
  12. nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
  13. nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
  14. nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
  15. nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
  16. nested_learning/configs/data/continual_segments_sample.yaml +9 -0
  17. nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
  18. nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
  19. nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
  20. nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
  21. nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
  22. nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
  23. nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
  24. nested_learning/configs/deepspeed/zero3.json +25 -0
  25. nested_learning/configs/hope/mid.yaml +118 -0
  26. nested_learning/configs/hope/mid_fsdp.yaml +47 -0
  27. nested_learning/configs/hope/pilot.yaml +2 -0
  28. nested_learning/configs/hope/pilot_attention.yaml +9 -0
  29. nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
  30. nested_learning/configs/hope/pilot_transformer.yaml +9 -0
  31. nested_learning/configs/hope/target.yaml +145 -0
  32. nested_learning/configs/hope/target_fsdp.yaml +47 -0
  33. nested_learning/configs/mid_smoke.yaml +99 -0
  34. nested_learning/configs/mid_stage2.yaml +110 -0
  35. nested_learning/configs/mid_stage2_smoke.yaml +102 -0
  36. nested_learning/configs/mid_titan_baseline.yaml +92 -0
  37. nested_learning/configs/pilot.yaml +127 -0
  38. nested_learning/configs/pilot_paper_faithful.yaml +42 -0
  39. nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
  40. nested_learning/configs/pilot_smoke.yaml +80 -0
  41. nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
  42. nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
  43. nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
  44. nested_learning/continual_classification.py +136 -0
  45. nested_learning/continual_streaming.py +283 -0
  46. nested_learning/data.py +153 -0
  47. nested_learning/device.py +21 -0
  48. nested_learning/eval_state.py +72 -0
  49. nested_learning/fast_state.py +108 -0
  50. nested_learning/functional.py +69 -0
  51. nested_learning/hope/__init__.py +0 -0
  52. nested_learning/hope/block.py +1973 -0
  53. nested_learning/hope/self_mod.py +40 -0
  54. nested_learning/instrumentation.py +38 -0
  55. nested_learning/levels.py +94 -0
  56. nested_learning/logging_utils.py +64 -0
  57. nested_learning/memorize.py +382 -0
  58. nested_learning/model.py +604 -0
  59. nested_learning/optim/__init__.py +0 -0
  60. nested_learning/optim/deep.py +102 -0
  61. nested_learning/optim/factory.py +13 -0
  62. nested_learning/optim/m3.py +121 -0
  63. nested_learning/optim/manager.py +151 -0
  64. nested_learning/titan/__init__.py +0 -0
  65. nested_learning/titan/memory.py +88 -0
  66. nested_learning/titan/model.py +412 -0
  67. nested_learning/titan/self_modifying.py +724 -0
  68. nested_learning/tokenizer.py +28 -0
  69. nested_learning/tokenizer_coverage.py +77 -0
  70. nested_learning/training.py +1600 -0
  71. nested_learning/transformer.py +104 -0
  72. nested_learning-0.2.0.dist-info/METADATA +390 -0
  73. nested_learning-0.2.0.dist-info/RECORD +76 -0
  74. nested_learning-0.2.0.dist-info/WHEEL +4 -0
  75. nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
  76. nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,46 @@
1
+ defaults:
2
+ - /pilot
3
+ - _self_
4
+
5
+ model:
6
+ dim: 384
7
+ num_layers: 8
8
+ heads: 6
9
+ titan_level:
10
+ name: titan
11
+ update_period: 8
12
+ optimizer_key: titan_opt
13
+ cms_hidden_multiplier: 2
14
+ cms_levels:
15
+ - name: cms_fast
16
+ update_period: 8
17
+ optimizer_key: cms_opt
18
+ - name: cms_mid
19
+ update_period: 32
20
+ optimizer_key: cms_opt
21
+ - name: cms_slow
22
+ update_period: 128
23
+ optimizer_key: cms_opt
24
+ - name: cms_ultra
25
+ update_period: 512
26
+ optimizer_key: cms_opt
27
+
28
+ data:
29
+ seq_len: 1024
30
+ batch_size: 2
31
+ num_workers: 2
32
+
33
+ train:
34
+ online_updates: true
35
+ online_chunk_size: 0
36
+ per_layer_teach_signal: true
37
+ steps: 5000
38
+ device: "cuda:1"
39
+ checkpoint:
40
+ dir: artifacts/checkpoints/pilot_cms_sparse
41
+ save_interval: 1000
42
+ log_interval: 25
43
+
44
+ logging:
45
+ path: logs/pilot_cms_sparse_metrics.json
46
+ run_name: pilot-cms-sparse
@@ -0,0 +1,24 @@
1
+ defaults:
2
+ - /pilot
3
+ - _self_
4
+
5
+ model:
6
+ block_variant: hope_selfmod
7
+ self_mod_chunk_size: 8
8
+ self_mod_chunk_size_memory: 64
9
+
10
+ train:
11
+ online_updates: true
12
+ online_chunk_size: 0
13
+ per_layer_teach_signal: true
14
+ steps: 5000
15
+ device: "cuda:1"
16
+ checkpoint:
17
+ dir: artifacts/checkpoints/pilot_selfmod_chunked_8_64
18
+ save_interval: 1000
19
+
20
+ logging:
21
+ enabled: true
22
+ backend: json
23
+ path: logs/pilot_selfmod_chunked_8_64_metrics.json
24
+ run_name: pilot-selfmod-chunked-8-64
@@ -0,0 +1,23 @@
1
+ defaults:
2
+ - /pilot
3
+ - _self_
4
+
5
+ model:
6
+ block_variant: hope_selfmod
7
+ self_mod_momentum: 0.0
8
+
9
+ train:
10
+ online_updates: true
11
+ online_chunk_size: 0
12
+ per_layer_teach_signal: true
13
+ steps: 5000
14
+ device: "cuda:1"
15
+ checkpoint:
16
+ dir: artifacts/checkpoints/pilot_selfmod_momentum_off
17
+ save_interval: 1000
18
+
19
+ logging:
20
+ enabled: true
21
+ backend: json
22
+ path: logs/pilot_selfmod_momentum_off_metrics.json
23
+ run_name: pilot-selfmod-momentum-off
@@ -0,0 +1,23 @@
1
+ defaults:
2
+ - /pilot
3
+ - _self_
4
+
5
+ model:
6
+ block_variant: hope_selfmod
7
+ self_mod_momentum: 0.9
8
+
9
+ train:
10
+ online_updates: true
11
+ online_chunk_size: 0
12
+ per_layer_teach_signal: true
13
+ steps: 5000
14
+ device: "cuda:1"
15
+ checkpoint:
16
+ dir: artifacts/checkpoints/pilot_selfmod_momentum_on
17
+ save_interval: 1000
18
+
19
+ logging:
20
+ enabled: true
21
+ backend: json
22
+ path: logs/pilot_selfmod_momentum_on_metrics.json
23
+ run_name: pilot-selfmod-momentum-on
@@ -0,0 +1,23 @@
1
+ defaults:
2
+ - /pilot
3
+ - _self_
4
+
5
+ model:
6
+ block_variant: hope_selfmod
7
+ self_mod_use_alpha: false
8
+
9
+ train:
10
+ online_updates: true
11
+ online_chunk_size: 0
12
+ per_layer_teach_signal: true
13
+ steps: 5000
14
+ device: "cuda:1"
15
+ checkpoint:
16
+ dir: artifacts/checkpoints/pilot_selfmod_no_alpha
17
+ save_interval: 1000
18
+
19
+ logging:
20
+ enabled: true
21
+ backend: json
22
+ path: logs/pilot_selfmod_no_alpha_metrics.json
23
+ run_name: pilot-selfmod-no-alpha
@@ -0,0 +1,23 @@
1
+ defaults:
2
+ - /pilot
3
+ - _self_
4
+
5
+ model:
6
+ block_variant: hope_selfmod
7
+ cms_levels: []
8
+
9
+ train:
10
+ online_updates: true
11
+ online_chunk_size: 0
12
+ per_layer_teach_signal: true
13
+ steps: 5000
14
+ device: "cuda:1"
15
+ checkpoint:
16
+ dir: artifacts/checkpoints/pilot_selfmod_no_cms
17
+ save_interval: 1000
18
+
19
+ logging:
20
+ enabled: true
21
+ backend: json
22
+ path: logs/pilot_selfmod_no_cms_metrics.json
23
+ run_name: pilot-selfmod-no-cms
@@ -0,0 +1,23 @@
1
+ defaults:
2
+ - /pilot
3
+ - _self_
4
+
5
+ model:
6
+ block_variant: hope_selfmod
7
+ self_mod_use_rank1_precond: false
8
+
9
+ train:
10
+ online_updates: true
11
+ online_chunk_size: 0
12
+ per_layer_teach_signal: true
13
+ steps: 5000
14
+ device: "cuda:1"
15
+ checkpoint:
16
+ dir: artifacts/checkpoints/pilot_selfmod_rank1_off
17
+ save_interval: 1000
18
+
19
+ logging:
20
+ enabled: true
21
+ backend: json
22
+ path: logs/pilot_selfmod_rank1_off_metrics.json
23
+ run_name: pilot-selfmod-rank1-off
@@ -0,0 +1,9 @@
1
+ segments:
2
+ - name: refinedweb_2018
3
+ shards_dir: data/shards/refinedweb_sample
4
+ - name: wikipedia_sample
5
+ shards_dir: data/shards/wikipedia_sample
6
+ - name: c4_sample
7
+ shards_dir: data/shards/c4_sample
8
+ - name: redpajama_sample
9
+ shards_dir: data/shards/redpajama_sample
@@ -0,0 +1,14 @@
1
+ name: fineweb_edu_longdoc_filtered_sample
2
+ tokenizer_output_dir: artifacts/tokenizer/fineweb_edu_longdoc
3
+ datasets:
4
+ - name: fineweb_edu_longdoc
5
+ dataset: text
6
+ split: train
7
+ text_column: text
8
+ data_files: data/filtered/fineweb_edu_longdoc_en_sample.txt
9
+ sample_limit: 5000
10
+ seq_len: 4096
11
+ sequences_per_shard: 1024
12
+ output_dir: data/shards/fineweb_edu_longdoc_sample
13
+ max_records: null
14
+
@@ -0,0 +1,14 @@
1
+ name: fineweb_edu_full
2
+ tokenizer_output_dir: artifacts/tokenizer/fineweb_edu
3
+ datasets:
4
+ - name: fineweb_edu
5
+ dataset: HuggingFaceFW/fineweb-edu
6
+ subset: sample-100BT
7
+ split: train
8
+ text_column: text
9
+ sample_limit: 100000
10
+ seq_len: 4096
11
+ sequences_per_shard: 1024
12
+ output_dir: data/shards/fineweb_edu_full
13
+ max_records: null
14
+
@@ -0,0 +1,14 @@
1
+ name: fineweb_edu_sample
2
+ tokenizer_output_dir: artifacts/tokenizer/fineweb_edu
3
+ datasets:
4
+ - name: fineweb_edu
5
+ dataset: HuggingFaceFW/fineweb-edu
6
+ subset: sample-10BT
7
+ split: train
8
+ text_column: text
9
+ sample_limit: 5000
10
+ seq_len: 2048
11
+ sequences_per_shard: 1024
12
+ output_dir: data/shards/fineweb_edu_sample
13
+ max_records: 10000
14
+
@@ -0,0 +1,48 @@
1
+ name: refinedweb_mix_v1
2
+ tokenizer_output_dir: artifacts/tokenizer/refinedweb_mix
3
+ datasets:
4
+ - name: refinedweb
5
+ dataset: text
6
+ split: train
7
+ text_column: text
8
+ data_files: data/filtered/refinedweb_en_full.txt
9
+ seq_len: 2048
10
+ sequences_per_shard: 2048
11
+ output_dir: data/shards/refinedweb
12
+ max_records: null
13
+ - name: books
14
+ dataset: text
15
+ split: train
16
+ text_column: text
17
+ data_files: data/filtered/wikipedia_en_full.txt
18
+ seq_len: 2048
19
+ sequences_per_shard: 2048
20
+ output_dir: data/shards/wikipedia
21
+ max_records: null
22
+ - name: c4
23
+ dataset: text
24
+ split: train
25
+ text_column: text
26
+ data_files: data/filtered/c4_en_full.txt
27
+ seq_len: 2048
28
+ sequences_per_shard: 2048
29
+ output_dir: data/shards/c4
30
+ max_records: null
31
+ - name: redpajama
32
+ dataset: text
33
+ split: train
34
+ text_column: text
35
+ data_files: data/filtered/redpajama_en_full.txt
36
+ seq_len: 2048
37
+ sequences_per_shard: 2048
38
+ output_dir: data/shards/redpajama
39
+ max_records: null
40
+ - name: code
41
+ dataset: text
42
+ split: train
43
+ text_column: text
44
+ data_files: data/filtered/code_en_full.txt
45
+ seq_len: 2048
46
+ sequences_per_shard: 2048
47
+ output_dir: data/shards/code
48
+ max_records: null
@@ -0,0 +1,48 @@
1
+ name: refinedweb_mix_filtered
2
+ tokenizer_output_dir: artifacts/tokenizer/refinedweb_mix
3
+ datasets:
4
+ - name: refinedweb
5
+ dataset: text
6
+ split: train
7
+ text_column: text
8
+ data_files: data/filtered/refinedweb_en_sample.txt
9
+ seq_len: 512
10
+ sequences_per_shard: 512
11
+ output_dir: data/shards/refinedweb_filtered
12
+ max_records: null
13
+ - name: wikipedia
14
+ dataset: text
15
+ split: train
16
+ text_column: text
17
+ data_files: data/filtered/wikipedia_en_sample.txt
18
+ seq_len: 512
19
+ sequences_per_shard: 512
20
+ output_dir: data/shards/wikipedia_filtered
21
+ max_records: null
22
+ - name: c4
23
+ dataset: text
24
+ split: train
25
+ text_column: text
26
+ data_files: data/filtered/c4_en_sample.txt
27
+ seq_len: 512
28
+ sequences_per_shard: 512
29
+ output_dir: data/shards/c4_filtered
30
+ max_records: null
31
+ - name: redpajama
32
+ dataset: text
33
+ split: train
34
+ text_column: text
35
+ data_files: data/filtered/redpajama_en_sample.txt
36
+ seq_len: 512
37
+ sequences_per_shard: 512
38
+ output_dir: data/shards/redpajama_filtered
39
+ max_records: null
40
+ - name: code
41
+ dataset: text
42
+ split: train
43
+ text_column: text
44
+ data_files: data/filtered/code_en_sample.txt
45
+ seq_len: 512
46
+ sequences_per_shard: 512
47
+ output_dir: data/shards/code_filtered
48
+ max_records: null
@@ -0,0 +1,48 @@
1
+ name: refinedweb_mix_full
2
+ tokenizer_output_dir: artifacts/tokenizer/refinedweb_mix
3
+ datasets:
4
+ - name: refinedweb
5
+ dataset: text
6
+ split: train
7
+ text_column: text
8
+ data_files: data/filtered/refinedweb_en_full.txt
9
+ seq_len: 2048
10
+ sequences_per_shard: 1024
11
+ output_dir: data/shards/refinedweb_full
12
+ max_records: null
13
+ - name: wikipedia
14
+ dataset: text
15
+ split: train
16
+ text_column: text
17
+ data_files: data/filtered/wikipedia_en_full.txt
18
+ seq_len: 2048
19
+ sequences_per_shard: 1024
20
+ output_dir: data/shards/wikipedia_full
21
+ max_records: null
22
+ - name: c4
23
+ dataset: text
24
+ split: train
25
+ text_column: text
26
+ data_files: data/filtered/c4_en_full.txt
27
+ seq_len: 2048
28
+ sequences_per_shard: 1024
29
+ output_dir: data/shards/c4_full
30
+ max_records: null
31
+ - name: redpajama
32
+ dataset: text
33
+ split: train
34
+ text_column: text
35
+ data_files: data/filtered/redpajama_en_full.txt
36
+ seq_len: 2048
37
+ sequences_per_shard: 1024
38
+ output_dir: data/shards/redpajama_full
39
+ max_records: null
40
+ - name: code
41
+ dataset: text
42
+ split: train
43
+ text_column: text
44
+ data_files: data/filtered/code_en_full.txt
45
+ seq_len: 2048
46
+ sequences_per_shard: 1024
47
+ output_dir: data/shards/code_full
48
+ max_records: null
@@ -0,0 +1,51 @@
1
+ name: refinedweb_mix_sample
2
+ tokenizer_output_dir: artifacts/tokenizer/refinedweb_mix
3
+ datasets:
4
+ - name: refinedweb
5
+ dataset: HuggingFaceFW/fineweb
6
+ subset: sample-10BT
7
+ split: train
8
+ text_column: text
9
+ sample_limit: 5000
10
+ seq_len: 512
11
+ sequences_per_shard: 512
12
+ output_dir: data/shards/refinedweb_sample
13
+ max_records: 10000
14
+ - name: books
15
+ dataset: wikimedia/wikipedia
16
+ subset: 20231101.en
17
+ split: train
18
+ text_column: text
19
+ sample_limit: 2000
20
+ seq_len: 512
21
+ sequences_per_shard: 512
22
+ output_dir: data/shards/wikipedia_sample
23
+ max_records: 5000
24
+ - name: c4
25
+ dataset: allenai/c4
26
+ subset: en
27
+ split: train
28
+ text_column: text
29
+ sample_limit: 2000
30
+ seq_len: 512
31
+ sequences_per_shard: 512
32
+ output_dir: data/shards/c4_sample
33
+ max_records: 4000
34
+ - name: redpajama
35
+ dataset: cerebras/SlimPajama-627B
36
+ split: train
37
+ text_column: text
38
+ sample_limit: 2000
39
+ seq_len: 512
40
+ sequences_per_shard: 512
41
+ output_dir: data/shards/redpajama_sample
42
+ max_records: 4000
43
+ - name: code
44
+ dataset: codeparrot/codeparrot-clean-train
45
+ split: train
46
+ text_column: content
47
+ sample_limit: 2000
48
+ seq_len: 512
49
+ sequences_per_shard: 512
50
+ output_dir: data/shards/code_sample
51
+ max_records: 4000
@@ -0,0 +1,25 @@
1
+ {
2
+ "bf16": {
3
+ "enabled": true
4
+ },
5
+ "train_batch_size": 64,
6
+ "gradient_accumulation_steps": 1,
7
+ "zero_optimization": {
8
+ "stage": 3,
9
+ "reduce_bucket_size": 50000000,
10
+ "stage3_param_persistence_threshold": 100000,
11
+ "stage3_prefetch_bucket_size": 50000000
12
+ },
13
+ "optimizer": {
14
+ "type": "AdamW",
15
+ "params": {
16
+ "lr": 0.0002,
17
+ "betas": [
18
+ 0.9,
19
+ 0.95
20
+ ],
21
+ "eps": 1e-08,
22
+ "weight_decay": 0.01
23
+ }
24
+ }
25
+ }
@@ -0,0 +1,118 @@
1
+ defaults:
2
+ - _self_
3
+
4
+ hydra:
5
+ run:
6
+ dir: .
7
+ output_subdir: null
8
+ job:
9
+ chdir: false
10
+
11
+ model:
12
+ vocab_size: 32000
13
+ dim: 1024
14
+ num_layers: 24
15
+ heads: 16
16
+ surprise_threshold: null
17
+ freeze_backbone: false
18
+ titan_level:
19
+ name: titan
20
+ update_period: 16
21
+ optimizer_key: titan_opt
22
+ cms_levels:
23
+ - name: cms_fast
24
+ update_period: 1
25
+ optimizer_key: cms_opt
26
+ - name: cms_mid
27
+ update_period: 4
28
+ optimizer_key: cms_opt
29
+ - name: cms_slow
30
+ update_period: 32
31
+ optimizer_key: cms_opt
32
+ - name: cms_ultra
33
+ update_period: 128
34
+ optimizer_key: cms_opt
35
+ optimizers:
36
+ titan_opt:
37
+ type: deep_momentum
38
+ lr: 8.0e-4
39
+ params:
40
+ beta: 0.9
41
+ beta2: 0.999
42
+ variant: nl_l2_precond
43
+ cms_opt:
44
+ type: deep_momentum
45
+ lr: 4.0e-4
46
+ params:
47
+ beta: 0.9
48
+ beta2: 0.999
49
+ variant: nl_l2_precond
50
+
51
+ data:
52
+ source: mixture
53
+ batch_size: 16
54
+ num_workers: 4
55
+ mixture:
56
+ samples_per_epoch: 8192
57
+ seed: 42
58
+ sources:
59
+ - name: refinedweb
60
+ shards_dir: data/shards/refinedweb_full
61
+ weight: 0.4
62
+ - name: wikipedia
63
+ shards_dir: data/shards/wikipedia_full
64
+ weight: 0.2
65
+ - name: c4
66
+ shards_dir: data/shards/c4_full
67
+ weight: 0.15
68
+ - name: redpajama
69
+ shards_dir: data/shards/redpajama_full
70
+ weight: 0.15
71
+ - name: code
72
+ shards_dir: data/shards/code_full
73
+ weight: 0.1
74
+
75
+ train:
76
+ strict_streaming_contract: false
77
+ online_updates: true
78
+ online_chunk_size: 0
79
+ online_boundary_targets: false
80
+ online_carry_attention_cache: false
81
+ per_layer_teach_signal: true
82
+ steps: 100
83
+ log_interval: 10
84
+ device: "cuda:1"
85
+ seed: 808
86
+ deterministic: false
87
+ step_offset: 0
88
+ mixed_precision:
89
+ enabled: true
90
+ dtype: bf16
91
+ compile:
92
+ enable: true
93
+ mode: max-autotune
94
+ fsdp:
95
+ auto_wrap_min_params: 2000000
96
+ cpu_offload: false
97
+ checkpoint:
98
+ enable: true
99
+ dir: checkpoints/mid
100
+ save_interval: 50
101
+ resume_path: null
102
+ resume_tag: null
103
+
104
+ optim:
105
+ type: muon
106
+ lr: 2.0e-4
107
+ weight_decay: 0.02
108
+ momentum: 0.95
109
+ betas:
110
+ - 0.9
111
+ - 0.999
112
+
113
+ logging:
114
+ enabled: false
115
+ backend: wandb
116
+ project: nested-learning
117
+ run_name: mid-${now:%Y%m%d%H%M%S}
118
+ path: logs/mid_metrics.json
@@ -0,0 +1,47 @@
1
+ defaults:
2
+ - mid
3
+ - _self_
4
+
5
+ model:
6
+ gradient_checkpointing: true
7
+
8
+ data:
9
+ batch_size: 8 # per-rank micro-batch for 2× RTX 6000 Ada
10
+ num_workers: 6
11
+
12
+ train:
13
+ strict_streaming_contract: false
14
+ online_updates: true
15
+ online_chunk_size: 0
16
+ online_boundary_targets: false
17
+ online_carry_attention_cache: false
18
+ per_layer_teach_signal: true
19
+ steps: 250000
20
+ log_interval: 20
21
+ device: "cuda"
22
+ mixed_precision:
23
+ enabled: true
24
+ dtype: bf16
25
+ compile:
26
+ enable: false
27
+ fsdp:
28
+ auto_wrap_min_params: 2000000
29
+ cpu_offload: false
30
+ checkpoint:
31
+ enable: true
32
+ dir: artifacts/checkpoints/mid_fsdp
33
+ save_interval: 1000
34
+ resume_path: null
35
+ resume_tag: null
36
+
37
+ optim:
38
+ type: muon
39
+ lr: 2.0e-4
40
+ weight_decay: 0.01
41
+
42
+ logging:
43
+ enabled: true
44
+ backend: wandb
45
+ project: nested-learning
46
+ run_name: hope-mid-fsdp-${now:%Y%m%d%H%M%S}
47
+ path: logs/mid_fsdp_metrics.json
@@ -0,0 +1,2 @@
1
+ defaults:
2
+ - /pilot
@@ -0,0 +1,9 @@
1
+ defaults:
2
+ - /pilot
3
+ - _self_
4
+
5
+ model:
6
+ block_variant: hope_attention
7
+ qk_l2_norm: true
8
+ local_conv_window: 4
9
+
@@ -0,0 +1,20 @@
1
+ defaults:
2
+ - /pilot
3
+ - _self_
4
+
5
+ model:
6
+ block_variant: hope_selfmod
7
+ # Chunk update cadence (paper §8.2): other memories update more often than M_memory.
8
+ self_mod_chunk_size: 8
9
+ self_mod_chunk_size_memory: 64
10
+
11
+ train:
12
+ online_updates: true
13
+ online_chunk_size: 0
14
+ per_layer_teach_signal: true
15
+ checkpoint:
16
+ dir: artifacts/checkpoints/pilot_selfmod
17
+
18
+ logging:
19
+ run_name: pilot-selfmod
20
+ path: logs/pilot_selfmod_metrics.json
@@ -0,0 +1,9 @@
1
+ defaults:
2
+ - /pilot
3
+ - _self_
4
+
5
+ model:
6
+ block_variant: transformer
7
+ qk_l2_norm: true
8
+ local_conv_window: 4
9
+