nested-learning 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. nested_learning/__init__.py +12 -0
  2. nested_learning/__main__.py +12 -0
  3. nested_learning/assoc_memory.py +23 -0
  4. nested_learning/backbones.py +147 -0
  5. nested_learning/capabilities.py +104 -0
  6. nested_learning/cli.py +253 -0
  7. nested_learning/cms.py +92 -0
  8. nested_learning/config_utils.py +50 -0
  9. nested_learning/configs/ablations/cms_sparse.yaml +46 -0
  10. nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
  11. nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
  12. nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
  13. nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
  14. nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
  15. nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
  16. nested_learning/configs/data/continual_segments_sample.yaml +9 -0
  17. nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
  18. nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
  19. nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
  20. nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
  21. nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
  22. nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
  23. nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
  24. nested_learning/configs/deepspeed/zero3.json +25 -0
  25. nested_learning/configs/hope/mid.yaml +118 -0
  26. nested_learning/configs/hope/mid_fsdp.yaml +47 -0
  27. nested_learning/configs/hope/pilot.yaml +2 -0
  28. nested_learning/configs/hope/pilot_attention.yaml +9 -0
  29. nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
  30. nested_learning/configs/hope/pilot_transformer.yaml +9 -0
  31. nested_learning/configs/hope/target.yaml +145 -0
  32. nested_learning/configs/hope/target_fsdp.yaml +47 -0
  33. nested_learning/configs/mid_smoke.yaml +99 -0
  34. nested_learning/configs/mid_stage2.yaml +110 -0
  35. nested_learning/configs/mid_stage2_smoke.yaml +102 -0
  36. nested_learning/configs/mid_titan_baseline.yaml +92 -0
  37. nested_learning/configs/pilot.yaml +127 -0
  38. nested_learning/configs/pilot_paper_faithful.yaml +42 -0
  39. nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
  40. nested_learning/configs/pilot_smoke.yaml +80 -0
  41. nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
  42. nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
  43. nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
  44. nested_learning/continual_classification.py +136 -0
  45. nested_learning/continual_streaming.py +283 -0
  46. nested_learning/data.py +153 -0
  47. nested_learning/device.py +21 -0
  48. nested_learning/eval_state.py +72 -0
  49. nested_learning/fast_state.py +108 -0
  50. nested_learning/functional.py +69 -0
  51. nested_learning/hope/__init__.py +0 -0
  52. nested_learning/hope/block.py +1973 -0
  53. nested_learning/hope/self_mod.py +40 -0
  54. nested_learning/instrumentation.py +38 -0
  55. nested_learning/levels.py +94 -0
  56. nested_learning/logging_utils.py +64 -0
  57. nested_learning/memorize.py +382 -0
  58. nested_learning/model.py +604 -0
  59. nested_learning/optim/__init__.py +0 -0
  60. nested_learning/optim/deep.py +102 -0
  61. nested_learning/optim/factory.py +13 -0
  62. nested_learning/optim/m3.py +121 -0
  63. nested_learning/optim/manager.py +151 -0
  64. nested_learning/titan/__init__.py +0 -0
  65. nested_learning/titan/memory.py +88 -0
  66. nested_learning/titan/model.py +412 -0
  67. nested_learning/titan/self_modifying.py +724 -0
  68. nested_learning/tokenizer.py +28 -0
  69. nested_learning/tokenizer_coverage.py +77 -0
  70. nested_learning/training.py +1600 -0
  71. nested_learning/transformer.py +104 -0
  72. nested_learning-0.2.0.dist-info/METADATA +390 -0
  73. nested_learning-0.2.0.dist-info/RECORD +76 -0
  74. nested_learning-0.2.0.dist-info/WHEEL +4 -0
  75. nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
  76. nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,145 @@
1
+ defaults:
2
+ - _self_
3
+
4
+ hydra:
5
+ run:
6
+ dir: .
7
+ output_subdir: null
8
+ job:
9
+ chdir: false
10
+
11
+ model:
12
+ vocab_size: 32000
13
+ dim: 1536
14
+ num_layers: 32
15
+ heads: 24
16
+ surprise_threshold: null
17
+ freeze_backbone: false
18
+ titan_level:
19
+ name: titan
20
+ update_period: 32
21
+ optimizer_key: titan_opt
22
+ cms_levels:
23
+ - name: cms_fast
24
+ update_period: 1
25
+ optimizer_key: cms_fast_opt
26
+ - name: cms_mid
27
+ update_period: 4
28
+ optimizer_key: cms_mid_opt
29
+ - name: cms_slow
30
+ update_period: 32
31
+ optimizer_key: cms_slow_opt
32
+ - name: cms_ultra
33
+ update_period: 128
34
+ optimizer_key: cms_slow_opt
35
+ - name: cms_anchor
36
+ update_period: 512
37
+ optimizer_key: cms_anchor_opt
38
+ optimizers:
39
+ titan_opt:
40
+ type: deep_momentum
41
+ lr: 6.0e-4
42
+ params:
43
+ beta: 0.9
44
+ beta2: 0.999
45
+ variant: nl_l2_precond
46
+ cms_fast_opt:
47
+ type: deep_momentum
48
+ lr: 3.0e-4
49
+ params:
50
+ beta: 0.9
51
+ beta2: 0.999
52
+ variant: nl_l2_precond
53
+ cms_mid_opt:
54
+ type: deep_momentum
55
+ lr: 2.5e-4
56
+ params:
57
+ beta: 0.9
58
+ beta2: 0.999
59
+ variant: nl_l2_precond
60
+ cms_slow_opt:
61
+ type: deep_momentum
62
+ lr: 2.0e-4
63
+ params:
64
+ beta: 0.9
65
+ beta2: 0.999
66
+ variant: nl_l2_precond
67
+ cms_anchor_opt:
68
+ type: deep_momentum
69
+ lr: 1.5e-4
70
+ params:
71
+ beta: 0.9
72
+ beta2: 0.999
73
+ variant: nl_l2_precond
74
+
75
+ data:
76
+ source: mixture
77
+ batch_size: 32
78
+ num_workers: 8
79
+ mixture:
80
+ samples_per_epoch: 32768
81
+ seed: 123
82
+ sources:
83
+ - name: refinedweb
84
+ shards_dir: data/shards/refinedweb_filtered
85
+ weight: 0.35
86
+ - name: wikipedia
87
+ shards_dir: data/shards/wikipedia_filtered
88
+ weight: 0.2
89
+ - name: c4
90
+ shards_dir: data/shards/c4_filtered
91
+ weight: 0.15
92
+ - name: redpajama
93
+ shards_dir: data/shards/redpajama_filtered
94
+ weight: 0.2
95
+ - name: code
96
+ shards_dir: data/shards/code_filtered
97
+ weight: 0.1
98
+
99
+ train:
100
+ strict_streaming_contract: false
101
+ online_updates: true
102
+ online_chunk_size: 0
103
+ online_boundary_targets: false
104
+ online_carry_attention_cache: false
105
+ per_layer_teach_signal: true
106
+ steps: 200
107
+ log_interval: 10
108
+ device: "cuda:1"
109
+ seed: 9001
110
+ deterministic: false
111
+ step_offset: 0
112
+ mixed_precision:
113
+ enabled: true
114
+ dtype: bf16
115
+ compile:
116
+ enable: true
117
+ mode: max-autotune
118
+ fsdp:
119
+ auto_wrap_min_params: 2000000
120
+ cpu_offload: false
121
+ checkpoint:
122
+ enable: true
123
+ dir: checkpoints/target
124
+ save_interval: 100
125
+ resume_path: null
126
+ resume_tag: null
127
+
128
+ optim:
129
+ type: muon
130
+ lr: 1.5e-4
131
+ weight_decay: 0.02
132
+ momentum: 0.95
133
+ betas:
134
+ - 0.9
135
+ - 0.999
136
+
137
+ logging:
138
+ enabled: false
139
+ backend: wandb
140
+ project: nested-learning
141
+ run_name: target-${now:%Y%m%d%H%M%S}
142
+ path: logs/target_metrics.json
143
+
144
+ deepspeed:
145
+ config: configs/deepspeed/zero3.json
@@ -0,0 +1,47 @@
1
+ defaults:
2
+ - target
3
+ - _self_
4
+
5
+ model:
6
+ gradient_checkpointing: true
7
+
8
+ data:
9
+ batch_size: 4 # per-rank micro-batch
10
+ num_workers: 8
11
+
12
+ train:
13
+ strict_streaming_contract: false
14
+ online_updates: true
15
+ online_chunk_size: 0
16
+ online_boundary_targets: false
17
+ online_carry_attention_cache: false
18
+ per_layer_teach_signal: true
19
+ steps: 300000
20
+ log_interval: 20
21
+ device: "cuda"
22
+ mixed_precision:
23
+ enabled: true
24
+ dtype: bf16
25
+ compile:
26
+ enable: false
27
+ fsdp:
28
+ auto_wrap_min_params: 2500000
29
+ cpu_offload: false
30
+ checkpoint:
31
+ enable: true
32
+ dir: artifacts/checkpoints/target_fsdp
33
+ save_interval: 1000
34
+ resume_path: null
35
+ resume_tag: null
36
+
37
+ optim:
38
+ type: muon
39
+ lr: 1.5e-4
40
+ weight_decay: 0.01
41
+
42
+ logging:
43
+ enabled: true
44
+ backend: wandb
45
+ project: nested-learning
46
+ run_name: hope-target-fsdp-${now:%Y%m%d%H%M%S}
47
+ path: logs/target_fsdp_metrics.json
@@ -0,0 +1,99 @@
1
+ hydra:
2
+ run:
3
+ dir: .
4
+ output_subdir: null
5
+ job:
6
+ chdir: false
7
+
8
+ model:
9
+ vocab_size: 32000
10
+ dim: 256
11
+ num_layers: 4
12
+ heads: 8
13
+ titan_level:
14
+ name: titan
15
+ update_period: 16
16
+ optimizer_key: titan_opt
17
+ cms_levels:
18
+ - name: cms_fast
19
+ update_period: 1
20
+ optimizer_key: cms_opt
21
+ - name: cms_mid
22
+ update_period: 4
23
+ optimizer_key: cms_opt
24
+ - name: cms_slow
25
+ update_period: 16
26
+ optimizer_key: cms_opt
27
+ - name: cms_ultra
28
+ update_period: 64
29
+ optimizer_key: cms_opt
30
+ optimizers:
31
+ titan_opt:
32
+ type: deep_momentum
33
+ lr: 8.0e-4
34
+ params:
35
+ beta: 0.9
36
+ beta2: 0.999
37
+ cms_opt:
38
+ type: deep_momentum
39
+ lr: 4.0e-4
40
+ params:
41
+ beta: 0.9
42
+ beta2: 0.999
43
+
44
+ data:
45
+ source: mixture
46
+ batch_size: 4
47
+ num_workers: 0
48
+ mixture:
49
+ samples_per_epoch: 128
50
+ seed: 0
51
+ sources:
52
+ - name: refinedweb
53
+ shards_dir: data/shards/refinedweb_filtered
54
+ weight: 0.4
55
+ - name: wikipedia
56
+ shards_dir: data/shards/wikipedia_filtered
57
+ weight: 0.2
58
+ - name: c4
59
+ shards_dir: data/shards/c4_filtered
60
+ weight: 0.15
61
+ - name: redpajama
62
+ shards_dir: data/shards/redpajama_filtered
63
+ weight: 0.15
64
+ - name: code
65
+ shards_dir: data/shards/code_filtered
66
+ weight: 0.1
67
+
68
+ train:
69
+ strict_streaming_contract: false
70
+ online_updates: true
71
+ online_chunk_size: 0
72
+ online_boundary_targets: false
73
+ online_carry_attention_cache: false
74
+ per_layer_teach_signal: true
75
+ steps: 10
76
+ log_interval: 1
77
+ device: "cpu"
78
+ seed: 2024
79
+ deterministic: true
80
+ mixed_precision:
81
+ enabled: false
82
+ dtype: bf16
83
+ compile:
84
+ enable: false
85
+ checkpoint:
86
+ enable: true
87
+ dir: artifacts/checkpoints/mid_smoke
88
+ save_interval: 10
89
+ save_last: true
90
+
91
+ optim:
92
+ type: adamw
93
+ lr: 2.0e-4
94
+ fused: false
95
+
96
+ logging:
97
+ enabled: true
98
+ backend: json
99
+ path: logs/mid_smoke.json
@@ -0,0 +1,110 @@
1
+ hydra:
2
+ run:
3
+ dir: .
4
+ output_subdir: null
5
+ job:
6
+ chdir: false
7
+
8
+ model:
9
+ vocab_size: 32000
10
+ dim: 768
11
+ num_layers: 18
12
+ heads: 12
13
+ teach_scale: 0.05
14
+ teach_clip: 5.0
15
+ teach_schedule:
16
+ warmup_steps: 20
17
+ decay_start: 80
18
+ decay_duration: 40
19
+ titan_level:
20
+ name: titan
21
+ update_period: 16
22
+ optimizer_key: titan_opt
23
+ cms_levels:
24
+ - name: cms_fast
25
+ update_period: 1
26
+ optimizer_key: cms_opt
27
+ - name: cms_mid
28
+ update_period: 4
29
+ optimizer_key: cms_opt
30
+ - name: cms_slow
31
+ update_period: 32
32
+ optimizer_key: cms_opt
33
+ - name: cms_ultra
34
+ update_period: 128
35
+ optimizer_key: cms_opt
36
+ optimizers:
37
+ titan_opt:
38
+ type: deep_momentum
39
+ lr: 8.0e-4
40
+ params:
41
+ beta: 0.9
42
+ beta2: 0.999
43
+ cms_opt:
44
+ type: deep_momentum
45
+ lr: 4.0e-4
46
+ params:
47
+ beta: 0.9
48
+ beta2: 0.999
49
+
50
+ data:
51
+ source: mixture
52
+ batch_size: 8
53
+ num_workers: 2
54
+ mixture:
55
+ samples_per_epoch: 1024
56
+ seed: 42
57
+ sources:
58
+ - name: refinedweb
59
+ shards_dir: data/shards/refinedweb_full
60
+ weight: 0.4
61
+ - name: wikipedia
62
+ shards_dir: data/shards/wikipedia_full
63
+ weight: 0.2
64
+ - name: c4
65
+ shards_dir: data/shards/c4_full
66
+ weight: 0.15
67
+ - name: redpajama
68
+ shards_dir: data/shards/redpajama_full
69
+ weight: 0.15
70
+ - name: code
71
+ shards_dir: data/shards/code_full
72
+ weight: 0.1
73
+
74
+ train:
75
+ strict_streaming_contract: false
76
+ online_updates: true
77
+ online_chunk_size: 0
78
+ online_boundary_targets: false
79
+ online_carry_attention_cache: false
80
+ per_layer_teach_signal: true
81
+ steps: 100
82
+ log_interval: 10
83
+ device: "cuda"
84
+ seed: 3401
85
+ deterministic: false
86
+ mixed_precision:
87
+ enabled: true
88
+ dtype: bf16
89
+ compile:
90
+ enable: true
91
+ mode: max-autotune
92
+ fsdp:
93
+ auto_wrap_min_params: 2000000
94
+ cpu_offload: false
95
+ checkpoint:
96
+ enable: true
97
+ dir: artifacts/checkpoints/mid_stage2
98
+ save_interval: 100
99
+ resume_path: null
100
+ resume_tag: null
101
+
102
+ optim:
103
+ type: adamw
104
+ lr: 3.0e-5
105
+ fused: auto
106
+
107
+ logging:
108
+ enabled: true
109
+ backend: json
110
+ path: logs/mid_stage2.json
@@ -0,0 +1,102 @@
1
+ hydra:
2
+ run:
3
+ dir: .
4
+ output_subdir: null
5
+ job:
6
+ chdir: false
7
+
8
+ model:
9
+ vocab_size: 32000
10
+ dim: 512
11
+ num_layers: 12
12
+ heads: 8
13
+ teach_scale: 0.2
14
+ teach_clip: 2.0
15
+ titan_level:
16
+ name: titan
17
+ update_period: 16
18
+ optimizer_key: titan_opt
19
+ cms_levels:
20
+ - name: cms_fast
21
+ update_period: 1
22
+ optimizer_key: cms_opt
23
+ - name: cms_mid
24
+ update_period: 4
25
+ optimizer_key: cms_opt
26
+ - name: cms_slow
27
+ update_period: 16
28
+ optimizer_key: cms_opt
29
+ optimizers:
30
+ titan_opt:
31
+ type: deep_momentum
32
+ lr: 6.0e-4
33
+ params:
34
+ beta: 0.9
35
+ beta2: 0.999
36
+ cms_opt:
37
+ type: deep_momentum
38
+ lr: 3.0e-4
39
+ params:
40
+ beta: 0.9
41
+ beta2: 0.999
42
+
43
+ data:
44
+ source: mixture
45
+ batch_size: 8
46
+ num_workers: 2
47
+ mixture:
48
+ samples_per_epoch: 512
49
+ seed: 0
50
+ sources:
51
+ - name: refinedweb
52
+ shards_dir: data/shards/refinedweb_filtered
53
+ weight: 0.4
54
+ - name: wikipedia
55
+ shards_dir: data/shards/wikipedia_filtered
56
+ weight: 0.2
57
+ - name: c4
58
+ shards_dir: data/shards/c4_filtered
59
+ weight: 0.15
60
+ - name: redpajama
61
+ shards_dir: data/shards/redpajama_filtered
62
+ weight: 0.15
63
+ - name: code
64
+ shards_dir: data/shards/code_filtered
65
+ weight: 0.1
66
+
67
+ train:
68
+ strict_streaming_contract: false
69
+ online_updates: true
70
+ online_chunk_size: 0
71
+ online_boundary_targets: false
72
+ online_carry_attention_cache: false
73
+ per_layer_teach_signal: true
74
+ steps: 60
75
+ log_interval: 5
76
+ device: "cuda"
77
+ seed: 777
78
+ deterministic: false
79
+ mixed_precision:
80
+ enabled: true
81
+ dtype: bf16
82
+ compile:
83
+ enable: false
84
+ fsdp:
85
+ auto_wrap_min_params: 1500000
86
+ cpu_offload: false
87
+ checkpoint:
88
+ enable: true
89
+ dir: artifacts/checkpoints/mid_stage2_smoke
90
+ save_interval: 60
91
+ resume_path: null
92
+ resume_tag: null
93
+
94
+ optim:
95
+ type: adamw
96
+ lr: 1.0e-4
97
+ fused: auto
98
+
99
+ logging:
100
+ enabled: true
101
+ backend: json
102
+ path: logs/mid_stage2_smoke.json
@@ -0,0 +1,92 @@
1
+ hydra:
2
+ run:
3
+ dir: .
4
+ output_subdir: null
5
+ job:
6
+ chdir: false
7
+
8
+ model:
9
+ type: titan
10
+ vocab_size: 32000
11
+ dim: 768
12
+ num_layers: 18
13
+ heads: 12
14
+ surprise_threshold: 0.02
15
+ freeze_backbone: false
16
+ titan_level:
17
+ name: titan
18
+ update_period: 16
19
+ optimizer_key: titan_opt
20
+ optimizers:
21
+ titan_opt:
22
+ type: deep_momentum
23
+ lr: 8.0e-4
24
+ params:
25
+ beta: 0.9
26
+ beta2: 0.999
27
+ teach_scale: 0.10
28
+ teach_clip: 4.0
29
+ teach_schedule:
30
+ warmup_steps: 60
31
+ decay_start: 140
32
+ decay_duration: 80
33
+
34
+ data:
35
+ source: mixture
36
+ batch_size: 4
37
+ num_workers: 2
38
+ mixture:
39
+ samples_per_epoch: 1024
40
+ seed: 42
41
+ sources:
42
+ - name: refinedweb
43
+ shards_dir: data/shards/refinedweb_full
44
+ weight: 0.4
45
+ - name: wikipedia
46
+ shards_dir: data/shards/wikipedia_full
47
+ weight: 0.2
48
+ - name: c4
49
+ shards_dir: data/shards/c4_full
50
+ weight: 0.15
51
+ - name: redpajama
52
+ shards_dir: data/shards/redpajama_full
53
+ weight: 0.15
54
+ - name: code
55
+ shards_dir: data/shards/code_full
56
+ weight: 0.1
57
+
58
+ train:
59
+ strict_streaming_contract: false
60
+ online_updates: true
61
+ online_chunk_size: 0
62
+ online_boundary_targets: false
63
+ online_carry_attention_cache: false
64
+ per_layer_teach_signal: true
65
+ steps: 220
66
+ log_interval: 20
67
+ device: "cuda:1"
68
+ seed: 451
69
+ deterministic: false
70
+ step_offset: 0
71
+ mixed_precision:
72
+ enabled: true
73
+ dtype: bf16
74
+ compile:
75
+ enable: false
76
+ checkpoint:
77
+ enable: true
78
+ dir: artifacts/checkpoints/mid_titan_baseline
79
+ save_interval: 100
80
+ resume_path: null
81
+ resume_tag: null
82
+
83
+ optim:
84
+ type: adamw
85
+ lr: 1.0e-5
86
+ fused: auto
87
+
88
+ logging:
89
+ enabled: true
90
+ backend: json
91
+ path: logs/mid_titan_baseline.json
92
+ run_name: mid_titan_baseline
@@ -0,0 +1,127 @@
1
+ defaults:
2
+ - _self_
3
+
4
+ hydra:
5
+ run:
6
+ dir: .
7
+ output_subdir: null
8
+ job:
9
+ chdir: false
10
+
11
+ model:
12
+ vocab_size: 32000
13
+ dim: 512
14
+ num_layers: 12
15
+ heads: 8
16
+ teach_scale: 0.10
17
+ teach_clip: 5.0
18
+ surprise_threshold: 0.02
19
+ freeze_backbone: false
20
+ self_mod_lr: 0.001
21
+ teach_schedule:
22
+ warmup_steps: 2000
23
+ decay_start: 120000
24
+ decay_duration: 20000
25
+ titan_level:
26
+ name: titan
27
+ update_period: 8
28
+ optimizer_key: titan_opt
29
+ cms_levels:
30
+ - name: cms_fast
31
+ update_period: 1
32
+ optimizer_key: cms_opt
33
+ - name: cms_mid
34
+ update_period: 4
35
+ optimizer_key: cms_opt
36
+ - name: cms_slow
37
+ update_period: 32
38
+ optimizer_key: cms_opt
39
+ - name: cms_ultra
40
+ update_period: 128
41
+ optimizer_key: cms_opt
42
+ optimizers:
43
+ titan_opt:
44
+ type: deep_momentum
45
+ lr: 6.0e-4
46
+ params:
47
+ beta: 0.9
48
+ beta2: 0.999
49
+ # Best-effort paper mapping: rank-1 context projection preconditioner.
50
+ variant: nl_l2_precond
51
+ cms_opt:
52
+ type: deep_momentum
53
+ lr: 3.0e-4
54
+ params:
55
+ beta: 0.9
56
+ beta2: 0.999
57
+ # Best-effort paper mapping: rank-1 context projection preconditioner.
58
+ variant: nl_l2_precond
59
+
60
+ data:
61
+ source: mixture
62
+ seq_len: 2048
63
+ batch_size: 6
64
+ num_workers: 4
65
+ mixture:
66
+ samples_per_epoch: 65536
67
+ seed: 1337
68
+ sources:
69
+ - name: refinedweb
70
+ shards_dir: data/shards/refinedweb_filtered
71
+ weight: 0.4
72
+ - name: wikipedia
73
+ shards_dir: data/shards/wikipedia_filtered
74
+ weight: 0.2
75
+ - name: c4
76
+ shards_dir: data/shards/c4_filtered
77
+ weight: 0.15
78
+ - name: redpajama
79
+ shards_dir: data/shards/redpajama_filtered
80
+ weight: 0.15
81
+ - name: code
82
+ shards_dir: data/shards/code_filtered
83
+ weight: 0.1
84
+
85
+ train:
86
+ algorithm_mode: two_pass_stopgrad_updates
87
+ strict_streaming_contract: false
88
+ online_updates: true
89
+ online_chunk_size: 0
90
+ online_boundary_targets: false
91
+ online_carry_attention_cache: false
92
+ per_layer_teach_signal: true
93
+ steps: 246667
94
+ log_interval: 50
95
+ device: "cuda:1"
96
+ seed: 1337
97
+ deterministic: false
98
+ step_offset: 0
99
+ mixed_precision:
100
+ enabled: true
101
+ dtype: bf16
102
+ compile:
103
+ enable: false
104
+ mode: max-autotune
105
+ checkpoint:
106
+ enable: true
107
+ dir: artifacts/checkpoints/pilot
108
+ save_interval: 1000
109
+ save_last: true
110
+ resume_path: null
111
+ resume_tag: null
112
+
113
+ optim:
114
+ type: muon
115
+ lr: 2.5e-4
116
+ weight_decay: 0.02
117
+ momentum: 0.95
118
+ betas:
119
+ - 0.9
120
+ - 0.999
121
+
122
+ logging:
123
+ enabled: true
124
+ backend: json
125
+ path: logs/pilot_metrics.json
126
+ project: nested-learning
127
+ run_name: pilot-main