nested-learning 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nested_learning/__init__.py +12 -0
- nested_learning/__main__.py +12 -0
- nested_learning/assoc_memory.py +23 -0
- nested_learning/backbones.py +147 -0
- nested_learning/capabilities.py +104 -0
- nested_learning/cli.py +253 -0
- nested_learning/cms.py +92 -0
- nested_learning/config_utils.py +50 -0
- nested_learning/configs/ablations/cms_sparse.yaml +46 -0
- nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
- nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
- nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
- nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
- nested_learning/configs/data/continual_segments_sample.yaml +9 -0
- nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
- nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
- nested_learning/configs/deepspeed/zero3.json +25 -0
- nested_learning/configs/hope/mid.yaml +118 -0
- nested_learning/configs/hope/mid_fsdp.yaml +47 -0
- nested_learning/configs/hope/pilot.yaml +2 -0
- nested_learning/configs/hope/pilot_attention.yaml +9 -0
- nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
- nested_learning/configs/hope/pilot_transformer.yaml +9 -0
- nested_learning/configs/hope/target.yaml +145 -0
- nested_learning/configs/hope/target_fsdp.yaml +47 -0
- nested_learning/configs/mid_smoke.yaml +99 -0
- nested_learning/configs/mid_stage2.yaml +110 -0
- nested_learning/configs/mid_stage2_smoke.yaml +102 -0
- nested_learning/configs/mid_titan_baseline.yaml +92 -0
- nested_learning/configs/pilot.yaml +127 -0
- nested_learning/configs/pilot_paper_faithful.yaml +42 -0
- nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
- nested_learning/configs/pilot_smoke.yaml +80 -0
- nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
- nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
- nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
- nested_learning/continual_classification.py +136 -0
- nested_learning/continual_streaming.py +283 -0
- nested_learning/data.py +153 -0
- nested_learning/device.py +21 -0
- nested_learning/eval_state.py +72 -0
- nested_learning/fast_state.py +108 -0
- nested_learning/functional.py +69 -0
- nested_learning/hope/__init__.py +0 -0
- nested_learning/hope/block.py +1973 -0
- nested_learning/hope/self_mod.py +40 -0
- nested_learning/instrumentation.py +38 -0
- nested_learning/levels.py +94 -0
- nested_learning/logging_utils.py +64 -0
- nested_learning/memorize.py +382 -0
- nested_learning/model.py +604 -0
- nested_learning/optim/__init__.py +0 -0
- nested_learning/optim/deep.py +102 -0
- nested_learning/optim/factory.py +13 -0
- nested_learning/optim/m3.py +121 -0
- nested_learning/optim/manager.py +151 -0
- nested_learning/titan/__init__.py +0 -0
- nested_learning/titan/memory.py +88 -0
- nested_learning/titan/model.py +412 -0
- nested_learning/titan/self_modifying.py +724 -0
- nested_learning/tokenizer.py +28 -0
- nested_learning/tokenizer_coverage.py +77 -0
- nested_learning/training.py +1600 -0
- nested_learning/transformer.py +104 -0
- nested_learning-0.2.0.dist-info/METADATA +390 -0
- nested_learning-0.2.0.dist-info/RECORD +76 -0
- nested_learning-0.2.0.dist-info/WHEEL +4 -0
- nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
- nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- _self_
|
|
3
|
+
|
|
4
|
+
hydra:
|
|
5
|
+
run:
|
|
6
|
+
dir: .
|
|
7
|
+
output_subdir: null
|
|
8
|
+
job:
|
|
9
|
+
chdir: false
|
|
10
|
+
|
|
11
|
+
model:
|
|
12
|
+
vocab_size: 32000
|
|
13
|
+
dim: 1536
|
|
14
|
+
num_layers: 32
|
|
15
|
+
heads: 24
|
|
16
|
+
surprise_threshold: null
|
|
17
|
+
freeze_backbone: false
|
|
18
|
+
titan_level:
|
|
19
|
+
name: titan
|
|
20
|
+
update_period: 32
|
|
21
|
+
optimizer_key: titan_opt
|
|
22
|
+
cms_levels:
|
|
23
|
+
- name: cms_fast
|
|
24
|
+
update_period: 1
|
|
25
|
+
optimizer_key: cms_fast_opt
|
|
26
|
+
- name: cms_mid
|
|
27
|
+
update_period: 4
|
|
28
|
+
optimizer_key: cms_mid_opt
|
|
29
|
+
- name: cms_slow
|
|
30
|
+
update_period: 32
|
|
31
|
+
optimizer_key: cms_slow_opt
|
|
32
|
+
- name: cms_ultra
|
|
33
|
+
update_period: 128
|
|
34
|
+
optimizer_key: cms_slow_opt
|
|
35
|
+
- name: cms_anchor
|
|
36
|
+
update_period: 512
|
|
37
|
+
optimizer_key: cms_anchor_opt
|
|
38
|
+
optimizers:
|
|
39
|
+
titan_opt:
|
|
40
|
+
type: deep_momentum
|
|
41
|
+
lr: 6.0e-4
|
|
42
|
+
params:
|
|
43
|
+
beta: 0.9
|
|
44
|
+
beta2: 0.999
|
|
45
|
+
variant: nl_l2_precond
|
|
46
|
+
cms_fast_opt:
|
|
47
|
+
type: deep_momentum
|
|
48
|
+
lr: 3.0e-4
|
|
49
|
+
params:
|
|
50
|
+
beta: 0.9
|
|
51
|
+
beta2: 0.999
|
|
52
|
+
variant: nl_l2_precond
|
|
53
|
+
cms_mid_opt:
|
|
54
|
+
type: deep_momentum
|
|
55
|
+
lr: 2.5e-4
|
|
56
|
+
params:
|
|
57
|
+
beta: 0.9
|
|
58
|
+
beta2: 0.999
|
|
59
|
+
variant: nl_l2_precond
|
|
60
|
+
cms_slow_opt:
|
|
61
|
+
type: deep_momentum
|
|
62
|
+
lr: 2.0e-4
|
|
63
|
+
params:
|
|
64
|
+
beta: 0.9
|
|
65
|
+
beta2: 0.999
|
|
66
|
+
variant: nl_l2_precond
|
|
67
|
+
cms_anchor_opt:
|
|
68
|
+
type: deep_momentum
|
|
69
|
+
lr: 1.5e-4
|
|
70
|
+
params:
|
|
71
|
+
beta: 0.9
|
|
72
|
+
beta2: 0.999
|
|
73
|
+
variant: nl_l2_precond
|
|
74
|
+
|
|
75
|
+
data:
|
|
76
|
+
source: mixture
|
|
77
|
+
batch_size: 32
|
|
78
|
+
num_workers: 8
|
|
79
|
+
mixture:
|
|
80
|
+
samples_per_epoch: 32768
|
|
81
|
+
seed: 123
|
|
82
|
+
sources:
|
|
83
|
+
- name: refinedweb
|
|
84
|
+
shards_dir: data/shards/refinedweb_filtered
|
|
85
|
+
weight: 0.35
|
|
86
|
+
- name: wikipedia
|
|
87
|
+
shards_dir: data/shards/wikipedia_filtered
|
|
88
|
+
weight: 0.2
|
|
89
|
+
- name: c4
|
|
90
|
+
shards_dir: data/shards/c4_filtered
|
|
91
|
+
weight: 0.15
|
|
92
|
+
- name: redpajama
|
|
93
|
+
shards_dir: data/shards/redpajama_filtered
|
|
94
|
+
weight: 0.2
|
|
95
|
+
- name: code
|
|
96
|
+
shards_dir: data/shards/code_filtered
|
|
97
|
+
weight: 0.1
|
|
98
|
+
|
|
99
|
+
train:
|
|
100
|
+
strict_streaming_contract: false
|
|
101
|
+
online_updates: true
|
|
102
|
+
online_chunk_size: 0
|
|
103
|
+
online_boundary_targets: false
|
|
104
|
+
online_carry_attention_cache: false
|
|
105
|
+
per_layer_teach_signal: true
|
|
106
|
+
steps: 200
|
|
107
|
+
log_interval: 10
|
|
108
|
+
device: "cuda:1"
|
|
109
|
+
seed: 9001
|
|
110
|
+
deterministic: false
|
|
111
|
+
step_offset: 0
|
|
112
|
+
mixed_precision:
|
|
113
|
+
enabled: true
|
|
114
|
+
dtype: bf16
|
|
115
|
+
compile:
|
|
116
|
+
enable: true
|
|
117
|
+
mode: max-autotune
|
|
118
|
+
fsdp:
|
|
119
|
+
auto_wrap_min_params: 2000000
|
|
120
|
+
cpu_offload: false
|
|
121
|
+
checkpoint:
|
|
122
|
+
enable: true
|
|
123
|
+
dir: checkpoints/target
|
|
124
|
+
save_interval: 100
|
|
125
|
+
resume_path: null
|
|
126
|
+
resume_tag: null
|
|
127
|
+
|
|
128
|
+
optim:
|
|
129
|
+
type: muon
|
|
130
|
+
lr: 1.5e-4
|
|
131
|
+
weight_decay: 0.02
|
|
132
|
+
momentum: 0.95
|
|
133
|
+
betas:
|
|
134
|
+
- 0.9
|
|
135
|
+
- 0.999
|
|
136
|
+
|
|
137
|
+
logging:
|
|
138
|
+
enabled: false
|
|
139
|
+
backend: wandb
|
|
140
|
+
project: nested-learning
|
|
141
|
+
run_name: target-${now:%Y%m%d%H%M%S}
|
|
142
|
+
path: logs/target_metrics.json
|
|
143
|
+
|
|
144
|
+
deepspeed:
|
|
145
|
+
config: configs/deepspeed/zero3.json
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- target
|
|
3
|
+
- _self_
|
|
4
|
+
|
|
5
|
+
model:
|
|
6
|
+
gradient_checkpointing: true
|
|
7
|
+
|
|
8
|
+
data:
|
|
9
|
+
batch_size: 4 # per-rank micro-batch
|
|
10
|
+
num_workers: 8
|
|
11
|
+
|
|
12
|
+
train:
|
|
13
|
+
strict_streaming_contract: false
|
|
14
|
+
online_updates: true
|
|
15
|
+
online_chunk_size: 0
|
|
16
|
+
online_boundary_targets: false
|
|
17
|
+
online_carry_attention_cache: false
|
|
18
|
+
per_layer_teach_signal: true
|
|
19
|
+
steps: 300000
|
|
20
|
+
log_interval: 20
|
|
21
|
+
device: "cuda"
|
|
22
|
+
mixed_precision:
|
|
23
|
+
enabled: true
|
|
24
|
+
dtype: bf16
|
|
25
|
+
compile:
|
|
26
|
+
enable: false
|
|
27
|
+
fsdp:
|
|
28
|
+
auto_wrap_min_params: 2500000
|
|
29
|
+
cpu_offload: false
|
|
30
|
+
checkpoint:
|
|
31
|
+
enable: true
|
|
32
|
+
dir: artifacts/checkpoints/target_fsdp
|
|
33
|
+
save_interval: 1000
|
|
34
|
+
resume_path: null
|
|
35
|
+
resume_tag: null
|
|
36
|
+
|
|
37
|
+
optim:
|
|
38
|
+
type: muon
|
|
39
|
+
lr: 1.5e-4
|
|
40
|
+
weight_decay: 0.01
|
|
41
|
+
|
|
42
|
+
logging:
|
|
43
|
+
enabled: true
|
|
44
|
+
backend: wandb
|
|
45
|
+
project: nested-learning
|
|
46
|
+
run_name: hope-target-fsdp-${now:%Y%m%d%H%M%S}
|
|
47
|
+
path: logs/target_fsdp_metrics.json
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
hydra:
|
|
2
|
+
run:
|
|
3
|
+
dir: .
|
|
4
|
+
output_subdir: null
|
|
5
|
+
job:
|
|
6
|
+
chdir: false
|
|
7
|
+
|
|
8
|
+
model:
|
|
9
|
+
vocab_size: 32000
|
|
10
|
+
dim: 256
|
|
11
|
+
num_layers: 4
|
|
12
|
+
heads: 8
|
|
13
|
+
titan_level:
|
|
14
|
+
name: titan
|
|
15
|
+
update_period: 16
|
|
16
|
+
optimizer_key: titan_opt
|
|
17
|
+
cms_levels:
|
|
18
|
+
- name: cms_fast
|
|
19
|
+
update_period: 1
|
|
20
|
+
optimizer_key: cms_opt
|
|
21
|
+
- name: cms_mid
|
|
22
|
+
update_period: 4
|
|
23
|
+
optimizer_key: cms_opt
|
|
24
|
+
- name: cms_slow
|
|
25
|
+
update_period: 16
|
|
26
|
+
optimizer_key: cms_opt
|
|
27
|
+
- name: cms_ultra
|
|
28
|
+
update_period: 64
|
|
29
|
+
optimizer_key: cms_opt
|
|
30
|
+
optimizers:
|
|
31
|
+
titan_opt:
|
|
32
|
+
type: deep_momentum
|
|
33
|
+
lr: 8.0e-4
|
|
34
|
+
params:
|
|
35
|
+
beta: 0.9
|
|
36
|
+
beta2: 0.999
|
|
37
|
+
cms_opt:
|
|
38
|
+
type: deep_momentum
|
|
39
|
+
lr: 4.0e-4
|
|
40
|
+
params:
|
|
41
|
+
beta: 0.9
|
|
42
|
+
beta2: 0.999
|
|
43
|
+
|
|
44
|
+
data:
|
|
45
|
+
source: mixture
|
|
46
|
+
batch_size: 4
|
|
47
|
+
num_workers: 0
|
|
48
|
+
mixture:
|
|
49
|
+
samples_per_epoch: 128
|
|
50
|
+
seed: 0
|
|
51
|
+
sources:
|
|
52
|
+
- name: refinedweb
|
|
53
|
+
shards_dir: data/shards/refinedweb_filtered
|
|
54
|
+
weight: 0.4
|
|
55
|
+
- name: wikipedia
|
|
56
|
+
shards_dir: data/shards/wikipedia_filtered
|
|
57
|
+
weight: 0.2
|
|
58
|
+
- name: c4
|
|
59
|
+
shards_dir: data/shards/c4_filtered
|
|
60
|
+
weight: 0.15
|
|
61
|
+
- name: redpajama
|
|
62
|
+
shards_dir: data/shards/redpajama_filtered
|
|
63
|
+
weight: 0.15
|
|
64
|
+
- name: code
|
|
65
|
+
shards_dir: data/shards/code_filtered
|
|
66
|
+
weight: 0.1
|
|
67
|
+
|
|
68
|
+
train:
|
|
69
|
+
strict_streaming_contract: false
|
|
70
|
+
online_updates: true
|
|
71
|
+
online_chunk_size: 0
|
|
72
|
+
online_boundary_targets: false
|
|
73
|
+
online_carry_attention_cache: false
|
|
74
|
+
per_layer_teach_signal: true
|
|
75
|
+
steps: 10
|
|
76
|
+
log_interval: 1
|
|
77
|
+
device: "cpu"
|
|
78
|
+
seed: 2024
|
|
79
|
+
deterministic: true
|
|
80
|
+
mixed_precision:
|
|
81
|
+
enabled: false
|
|
82
|
+
dtype: bf16
|
|
83
|
+
compile:
|
|
84
|
+
enable: false
|
|
85
|
+
checkpoint:
|
|
86
|
+
enable: true
|
|
87
|
+
dir: artifacts/checkpoints/mid_smoke
|
|
88
|
+
save_interval: 10
|
|
89
|
+
save_last: true
|
|
90
|
+
|
|
91
|
+
optim:
|
|
92
|
+
type: adamw
|
|
93
|
+
lr: 2.0e-4
|
|
94
|
+
fused: false
|
|
95
|
+
|
|
96
|
+
logging:
|
|
97
|
+
enabled: true
|
|
98
|
+
backend: json
|
|
99
|
+
path: logs/mid_smoke.json
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
hydra:
|
|
2
|
+
run:
|
|
3
|
+
dir: .
|
|
4
|
+
output_subdir: null
|
|
5
|
+
job:
|
|
6
|
+
chdir: false
|
|
7
|
+
|
|
8
|
+
model:
|
|
9
|
+
vocab_size: 32000
|
|
10
|
+
dim: 768
|
|
11
|
+
num_layers: 18
|
|
12
|
+
heads: 12
|
|
13
|
+
teach_scale: 0.05
|
|
14
|
+
teach_clip: 5.0
|
|
15
|
+
teach_schedule:
|
|
16
|
+
warmup_steps: 20
|
|
17
|
+
decay_start: 80
|
|
18
|
+
decay_duration: 40
|
|
19
|
+
titan_level:
|
|
20
|
+
name: titan
|
|
21
|
+
update_period: 16
|
|
22
|
+
optimizer_key: titan_opt
|
|
23
|
+
cms_levels:
|
|
24
|
+
- name: cms_fast
|
|
25
|
+
update_period: 1
|
|
26
|
+
optimizer_key: cms_opt
|
|
27
|
+
- name: cms_mid
|
|
28
|
+
update_period: 4
|
|
29
|
+
optimizer_key: cms_opt
|
|
30
|
+
- name: cms_slow
|
|
31
|
+
update_period: 32
|
|
32
|
+
optimizer_key: cms_opt
|
|
33
|
+
- name: cms_ultra
|
|
34
|
+
update_period: 128
|
|
35
|
+
optimizer_key: cms_opt
|
|
36
|
+
optimizers:
|
|
37
|
+
titan_opt:
|
|
38
|
+
type: deep_momentum
|
|
39
|
+
lr: 8.0e-4
|
|
40
|
+
params:
|
|
41
|
+
beta: 0.9
|
|
42
|
+
beta2: 0.999
|
|
43
|
+
cms_opt:
|
|
44
|
+
type: deep_momentum
|
|
45
|
+
lr: 4.0e-4
|
|
46
|
+
params:
|
|
47
|
+
beta: 0.9
|
|
48
|
+
beta2: 0.999
|
|
49
|
+
|
|
50
|
+
data:
|
|
51
|
+
source: mixture
|
|
52
|
+
batch_size: 8
|
|
53
|
+
num_workers: 2
|
|
54
|
+
mixture:
|
|
55
|
+
samples_per_epoch: 1024
|
|
56
|
+
seed: 42
|
|
57
|
+
sources:
|
|
58
|
+
- name: refinedweb
|
|
59
|
+
shards_dir: data/shards/refinedweb_full
|
|
60
|
+
weight: 0.4
|
|
61
|
+
- name: wikipedia
|
|
62
|
+
shards_dir: data/shards/wikipedia_full
|
|
63
|
+
weight: 0.2
|
|
64
|
+
- name: c4
|
|
65
|
+
shards_dir: data/shards/c4_full
|
|
66
|
+
weight: 0.15
|
|
67
|
+
- name: redpajama
|
|
68
|
+
shards_dir: data/shards/redpajama_full
|
|
69
|
+
weight: 0.15
|
|
70
|
+
- name: code
|
|
71
|
+
shards_dir: data/shards/code_full
|
|
72
|
+
weight: 0.1
|
|
73
|
+
|
|
74
|
+
train:
|
|
75
|
+
strict_streaming_contract: false
|
|
76
|
+
online_updates: true
|
|
77
|
+
online_chunk_size: 0
|
|
78
|
+
online_boundary_targets: false
|
|
79
|
+
online_carry_attention_cache: false
|
|
80
|
+
per_layer_teach_signal: true
|
|
81
|
+
steps: 100
|
|
82
|
+
log_interval: 10
|
|
83
|
+
device: "cuda"
|
|
84
|
+
seed: 3401
|
|
85
|
+
deterministic: false
|
|
86
|
+
mixed_precision:
|
|
87
|
+
enabled: true
|
|
88
|
+
dtype: bf16
|
|
89
|
+
compile:
|
|
90
|
+
enable: true
|
|
91
|
+
mode: max-autotune
|
|
92
|
+
fsdp:
|
|
93
|
+
auto_wrap_min_params: 2000000
|
|
94
|
+
cpu_offload: false
|
|
95
|
+
checkpoint:
|
|
96
|
+
enable: true
|
|
97
|
+
dir: artifacts/checkpoints/mid_stage2
|
|
98
|
+
save_interval: 100
|
|
99
|
+
resume_path: null
|
|
100
|
+
resume_tag: null
|
|
101
|
+
|
|
102
|
+
optim:
|
|
103
|
+
type: adamw
|
|
104
|
+
lr: 3.0e-5
|
|
105
|
+
fused: auto
|
|
106
|
+
|
|
107
|
+
logging:
|
|
108
|
+
enabled: true
|
|
109
|
+
backend: json
|
|
110
|
+
path: logs/mid_stage2.json
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
hydra:
|
|
2
|
+
run:
|
|
3
|
+
dir: .
|
|
4
|
+
output_subdir: null
|
|
5
|
+
job:
|
|
6
|
+
chdir: false
|
|
7
|
+
|
|
8
|
+
model:
|
|
9
|
+
vocab_size: 32000
|
|
10
|
+
dim: 512
|
|
11
|
+
num_layers: 12
|
|
12
|
+
heads: 8
|
|
13
|
+
teach_scale: 0.2
|
|
14
|
+
teach_clip: 2.0
|
|
15
|
+
titan_level:
|
|
16
|
+
name: titan
|
|
17
|
+
update_period: 16
|
|
18
|
+
optimizer_key: titan_opt
|
|
19
|
+
cms_levels:
|
|
20
|
+
- name: cms_fast
|
|
21
|
+
update_period: 1
|
|
22
|
+
optimizer_key: cms_opt
|
|
23
|
+
- name: cms_mid
|
|
24
|
+
update_period: 4
|
|
25
|
+
optimizer_key: cms_opt
|
|
26
|
+
- name: cms_slow
|
|
27
|
+
update_period: 16
|
|
28
|
+
optimizer_key: cms_opt
|
|
29
|
+
optimizers:
|
|
30
|
+
titan_opt:
|
|
31
|
+
type: deep_momentum
|
|
32
|
+
lr: 6.0e-4
|
|
33
|
+
params:
|
|
34
|
+
beta: 0.9
|
|
35
|
+
beta2: 0.999
|
|
36
|
+
cms_opt:
|
|
37
|
+
type: deep_momentum
|
|
38
|
+
lr: 3.0e-4
|
|
39
|
+
params:
|
|
40
|
+
beta: 0.9
|
|
41
|
+
beta2: 0.999
|
|
42
|
+
|
|
43
|
+
data:
|
|
44
|
+
source: mixture
|
|
45
|
+
batch_size: 8
|
|
46
|
+
num_workers: 2
|
|
47
|
+
mixture:
|
|
48
|
+
samples_per_epoch: 512
|
|
49
|
+
seed: 0
|
|
50
|
+
sources:
|
|
51
|
+
- name: refinedweb
|
|
52
|
+
shards_dir: data/shards/refinedweb_filtered
|
|
53
|
+
weight: 0.4
|
|
54
|
+
- name: wikipedia
|
|
55
|
+
shards_dir: data/shards/wikipedia_filtered
|
|
56
|
+
weight: 0.2
|
|
57
|
+
- name: c4
|
|
58
|
+
shards_dir: data/shards/c4_filtered
|
|
59
|
+
weight: 0.15
|
|
60
|
+
- name: redpajama
|
|
61
|
+
shards_dir: data/shards/redpajama_filtered
|
|
62
|
+
weight: 0.15
|
|
63
|
+
- name: code
|
|
64
|
+
shards_dir: data/shards/code_filtered
|
|
65
|
+
weight: 0.1
|
|
66
|
+
|
|
67
|
+
train:
|
|
68
|
+
strict_streaming_contract: false
|
|
69
|
+
online_updates: true
|
|
70
|
+
online_chunk_size: 0
|
|
71
|
+
online_boundary_targets: false
|
|
72
|
+
online_carry_attention_cache: false
|
|
73
|
+
per_layer_teach_signal: true
|
|
74
|
+
steps: 60
|
|
75
|
+
log_interval: 5
|
|
76
|
+
device: "cuda"
|
|
77
|
+
seed: 777
|
|
78
|
+
deterministic: false
|
|
79
|
+
mixed_precision:
|
|
80
|
+
enabled: true
|
|
81
|
+
dtype: bf16
|
|
82
|
+
compile:
|
|
83
|
+
enable: false
|
|
84
|
+
fsdp:
|
|
85
|
+
auto_wrap_min_params: 1500000
|
|
86
|
+
cpu_offload: false
|
|
87
|
+
checkpoint:
|
|
88
|
+
enable: true
|
|
89
|
+
dir: artifacts/checkpoints/mid_stage2_smoke
|
|
90
|
+
save_interval: 60
|
|
91
|
+
resume_path: null
|
|
92
|
+
resume_tag: null
|
|
93
|
+
|
|
94
|
+
optim:
|
|
95
|
+
type: adamw
|
|
96
|
+
lr: 1.0e-4
|
|
97
|
+
fused: auto
|
|
98
|
+
|
|
99
|
+
logging:
|
|
100
|
+
enabled: true
|
|
101
|
+
backend: json
|
|
102
|
+
path: logs/mid_stage2_smoke.json
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
hydra:
|
|
2
|
+
run:
|
|
3
|
+
dir: .
|
|
4
|
+
output_subdir: null
|
|
5
|
+
job:
|
|
6
|
+
chdir: false
|
|
7
|
+
|
|
8
|
+
model:
|
|
9
|
+
type: titan
|
|
10
|
+
vocab_size: 32000
|
|
11
|
+
dim: 768
|
|
12
|
+
num_layers: 18
|
|
13
|
+
heads: 12
|
|
14
|
+
surprise_threshold: 0.02
|
|
15
|
+
freeze_backbone: false
|
|
16
|
+
titan_level:
|
|
17
|
+
name: titan
|
|
18
|
+
update_period: 16
|
|
19
|
+
optimizer_key: titan_opt
|
|
20
|
+
optimizers:
|
|
21
|
+
titan_opt:
|
|
22
|
+
type: deep_momentum
|
|
23
|
+
lr: 8.0e-4
|
|
24
|
+
params:
|
|
25
|
+
beta: 0.9
|
|
26
|
+
beta2: 0.999
|
|
27
|
+
teach_scale: 0.10
|
|
28
|
+
teach_clip: 4.0
|
|
29
|
+
teach_schedule:
|
|
30
|
+
warmup_steps: 60
|
|
31
|
+
decay_start: 140
|
|
32
|
+
decay_duration: 80
|
|
33
|
+
|
|
34
|
+
data:
|
|
35
|
+
source: mixture
|
|
36
|
+
batch_size: 4
|
|
37
|
+
num_workers: 2
|
|
38
|
+
mixture:
|
|
39
|
+
samples_per_epoch: 1024
|
|
40
|
+
seed: 42
|
|
41
|
+
sources:
|
|
42
|
+
- name: refinedweb
|
|
43
|
+
shards_dir: data/shards/refinedweb_full
|
|
44
|
+
weight: 0.4
|
|
45
|
+
- name: wikipedia
|
|
46
|
+
shards_dir: data/shards/wikipedia_full
|
|
47
|
+
weight: 0.2
|
|
48
|
+
- name: c4
|
|
49
|
+
shards_dir: data/shards/c4_full
|
|
50
|
+
weight: 0.15
|
|
51
|
+
- name: redpajama
|
|
52
|
+
shards_dir: data/shards/redpajama_full
|
|
53
|
+
weight: 0.15
|
|
54
|
+
- name: code
|
|
55
|
+
shards_dir: data/shards/code_full
|
|
56
|
+
weight: 0.1
|
|
57
|
+
|
|
58
|
+
train:
|
|
59
|
+
strict_streaming_contract: false
|
|
60
|
+
online_updates: true
|
|
61
|
+
online_chunk_size: 0
|
|
62
|
+
online_boundary_targets: false
|
|
63
|
+
online_carry_attention_cache: false
|
|
64
|
+
per_layer_teach_signal: true
|
|
65
|
+
steps: 220
|
|
66
|
+
log_interval: 20
|
|
67
|
+
device: "cuda:1"
|
|
68
|
+
seed: 451
|
|
69
|
+
deterministic: false
|
|
70
|
+
step_offset: 0
|
|
71
|
+
mixed_precision:
|
|
72
|
+
enabled: true
|
|
73
|
+
dtype: bf16
|
|
74
|
+
compile:
|
|
75
|
+
enable: false
|
|
76
|
+
checkpoint:
|
|
77
|
+
enable: true
|
|
78
|
+
dir: artifacts/checkpoints/mid_titan_baseline
|
|
79
|
+
save_interval: 100
|
|
80
|
+
resume_path: null
|
|
81
|
+
resume_tag: null
|
|
82
|
+
|
|
83
|
+
optim:
|
|
84
|
+
type: adamw
|
|
85
|
+
lr: 1.0e-5
|
|
86
|
+
fused: auto
|
|
87
|
+
|
|
88
|
+
logging:
|
|
89
|
+
enabled: true
|
|
90
|
+
backend: json
|
|
91
|
+
path: logs/mid_titan_baseline.json
|
|
92
|
+
run_name: mid_titan_baseline
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- _self_
|
|
3
|
+
|
|
4
|
+
hydra:
|
|
5
|
+
run:
|
|
6
|
+
dir: .
|
|
7
|
+
output_subdir: null
|
|
8
|
+
job:
|
|
9
|
+
chdir: false
|
|
10
|
+
|
|
11
|
+
model:
|
|
12
|
+
vocab_size: 32000
|
|
13
|
+
dim: 512
|
|
14
|
+
num_layers: 12
|
|
15
|
+
heads: 8
|
|
16
|
+
teach_scale: 0.10
|
|
17
|
+
teach_clip: 5.0
|
|
18
|
+
surprise_threshold: 0.02
|
|
19
|
+
freeze_backbone: false
|
|
20
|
+
self_mod_lr: 0.001
|
|
21
|
+
teach_schedule:
|
|
22
|
+
warmup_steps: 2000
|
|
23
|
+
decay_start: 120000
|
|
24
|
+
decay_duration: 20000
|
|
25
|
+
titan_level:
|
|
26
|
+
name: titan
|
|
27
|
+
update_period: 8
|
|
28
|
+
optimizer_key: titan_opt
|
|
29
|
+
cms_levels:
|
|
30
|
+
- name: cms_fast
|
|
31
|
+
update_period: 1
|
|
32
|
+
optimizer_key: cms_opt
|
|
33
|
+
- name: cms_mid
|
|
34
|
+
update_period: 4
|
|
35
|
+
optimizer_key: cms_opt
|
|
36
|
+
- name: cms_slow
|
|
37
|
+
update_period: 32
|
|
38
|
+
optimizer_key: cms_opt
|
|
39
|
+
- name: cms_ultra
|
|
40
|
+
update_period: 128
|
|
41
|
+
optimizer_key: cms_opt
|
|
42
|
+
optimizers:
|
|
43
|
+
titan_opt:
|
|
44
|
+
type: deep_momentum
|
|
45
|
+
lr: 6.0e-4
|
|
46
|
+
params:
|
|
47
|
+
beta: 0.9
|
|
48
|
+
beta2: 0.999
|
|
49
|
+
# Best-effort paper mapping: rank-1 context projection preconditioner.
|
|
50
|
+
variant: nl_l2_precond
|
|
51
|
+
cms_opt:
|
|
52
|
+
type: deep_momentum
|
|
53
|
+
lr: 3.0e-4
|
|
54
|
+
params:
|
|
55
|
+
beta: 0.9
|
|
56
|
+
beta2: 0.999
|
|
57
|
+
# Best-effort paper mapping: rank-1 context projection preconditioner.
|
|
58
|
+
variant: nl_l2_precond
|
|
59
|
+
|
|
60
|
+
data:
|
|
61
|
+
source: mixture
|
|
62
|
+
seq_len: 2048
|
|
63
|
+
batch_size: 6
|
|
64
|
+
num_workers: 4
|
|
65
|
+
mixture:
|
|
66
|
+
samples_per_epoch: 65536
|
|
67
|
+
seed: 1337
|
|
68
|
+
sources:
|
|
69
|
+
- name: refinedweb
|
|
70
|
+
shards_dir: data/shards/refinedweb_filtered
|
|
71
|
+
weight: 0.4
|
|
72
|
+
- name: wikipedia
|
|
73
|
+
shards_dir: data/shards/wikipedia_filtered
|
|
74
|
+
weight: 0.2
|
|
75
|
+
- name: c4
|
|
76
|
+
shards_dir: data/shards/c4_filtered
|
|
77
|
+
weight: 0.15
|
|
78
|
+
- name: redpajama
|
|
79
|
+
shards_dir: data/shards/redpajama_filtered
|
|
80
|
+
weight: 0.15
|
|
81
|
+
- name: code
|
|
82
|
+
shards_dir: data/shards/code_filtered
|
|
83
|
+
weight: 0.1
|
|
84
|
+
|
|
85
|
+
train:
|
|
86
|
+
algorithm_mode: two_pass_stopgrad_updates
|
|
87
|
+
strict_streaming_contract: false
|
|
88
|
+
online_updates: true
|
|
89
|
+
online_chunk_size: 0
|
|
90
|
+
online_boundary_targets: false
|
|
91
|
+
online_carry_attention_cache: false
|
|
92
|
+
per_layer_teach_signal: true
|
|
93
|
+
steps: 246667
|
|
94
|
+
log_interval: 50
|
|
95
|
+
device: "cuda:1"
|
|
96
|
+
seed: 1337
|
|
97
|
+
deterministic: false
|
|
98
|
+
step_offset: 0
|
|
99
|
+
mixed_precision:
|
|
100
|
+
enabled: true
|
|
101
|
+
dtype: bf16
|
|
102
|
+
compile:
|
|
103
|
+
enable: false
|
|
104
|
+
mode: max-autotune
|
|
105
|
+
checkpoint:
|
|
106
|
+
enable: true
|
|
107
|
+
dir: artifacts/checkpoints/pilot
|
|
108
|
+
save_interval: 1000
|
|
109
|
+
save_last: true
|
|
110
|
+
resume_path: null
|
|
111
|
+
resume_tag: null
|
|
112
|
+
|
|
113
|
+
optim:
|
|
114
|
+
type: muon
|
|
115
|
+
lr: 2.5e-4
|
|
116
|
+
weight_decay: 0.02
|
|
117
|
+
momentum: 0.95
|
|
118
|
+
betas:
|
|
119
|
+
- 0.9
|
|
120
|
+
- 0.999
|
|
121
|
+
|
|
122
|
+
logging:
|
|
123
|
+
enabled: true
|
|
124
|
+
backend: json
|
|
125
|
+
path: logs/pilot_metrics.json
|
|
126
|
+
project: nested-learning
|
|
127
|
+
run_name: pilot-main
|