nested-learning 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nested_learning/__init__.py +12 -0
- nested_learning/__main__.py +12 -0
- nested_learning/assoc_memory.py +23 -0
- nested_learning/backbones.py +147 -0
- nested_learning/capabilities.py +104 -0
- nested_learning/cli.py +253 -0
- nested_learning/cms.py +92 -0
- nested_learning/config_utils.py +50 -0
- nested_learning/configs/ablations/cms_sparse.yaml +46 -0
- nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
- nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
- nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
- nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
- nested_learning/configs/data/continual_segments_sample.yaml +9 -0
- nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
- nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
- nested_learning/configs/deepspeed/zero3.json +25 -0
- nested_learning/configs/hope/mid.yaml +118 -0
- nested_learning/configs/hope/mid_fsdp.yaml +47 -0
- nested_learning/configs/hope/pilot.yaml +2 -0
- nested_learning/configs/hope/pilot_attention.yaml +9 -0
- nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
- nested_learning/configs/hope/pilot_transformer.yaml +9 -0
- nested_learning/configs/hope/target.yaml +145 -0
- nested_learning/configs/hope/target_fsdp.yaml +47 -0
- nested_learning/configs/mid_smoke.yaml +99 -0
- nested_learning/configs/mid_stage2.yaml +110 -0
- nested_learning/configs/mid_stage2_smoke.yaml +102 -0
- nested_learning/configs/mid_titan_baseline.yaml +92 -0
- nested_learning/configs/pilot.yaml +127 -0
- nested_learning/configs/pilot_paper_faithful.yaml +42 -0
- nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
- nested_learning/configs/pilot_smoke.yaml +80 -0
- nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
- nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
- nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
- nested_learning/continual_classification.py +136 -0
- nested_learning/continual_streaming.py +283 -0
- nested_learning/data.py +153 -0
- nested_learning/device.py +21 -0
- nested_learning/eval_state.py +72 -0
- nested_learning/fast_state.py +108 -0
- nested_learning/functional.py +69 -0
- nested_learning/hope/__init__.py +0 -0
- nested_learning/hope/block.py +1973 -0
- nested_learning/hope/self_mod.py +40 -0
- nested_learning/instrumentation.py +38 -0
- nested_learning/levels.py +94 -0
- nested_learning/logging_utils.py +64 -0
- nested_learning/memorize.py +382 -0
- nested_learning/model.py +604 -0
- nested_learning/optim/__init__.py +0 -0
- nested_learning/optim/deep.py +102 -0
- nested_learning/optim/factory.py +13 -0
- nested_learning/optim/m3.py +121 -0
- nested_learning/optim/manager.py +151 -0
- nested_learning/titan/__init__.py +0 -0
- nested_learning/titan/memory.py +88 -0
- nested_learning/titan/model.py +412 -0
- nested_learning/titan/self_modifying.py +724 -0
- nested_learning/tokenizer.py +28 -0
- nested_learning/tokenizer_coverage.py +77 -0
- nested_learning/training.py +1600 -0
- nested_learning/transformer.py +104 -0
- nested_learning-0.2.0.dist-info/METADATA +390 -0
- nested_learning-0.2.0.dist-info/RECORD +76 -0
- nested_learning-0.2.0.dist-info/WHEEL +4 -0
- nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
- nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- /pilot
|
|
3
|
+
- _self_
|
|
4
|
+
|
|
5
|
+
model:
|
|
6
|
+
dim: 384
|
|
7
|
+
num_layers: 8
|
|
8
|
+
heads: 6
|
|
9
|
+
titan_level:
|
|
10
|
+
name: titan
|
|
11
|
+
update_period: 8
|
|
12
|
+
optimizer_key: titan_opt
|
|
13
|
+
cms_hidden_multiplier: 2
|
|
14
|
+
cms_levels:
|
|
15
|
+
- name: cms_fast
|
|
16
|
+
update_period: 8
|
|
17
|
+
optimizer_key: cms_opt
|
|
18
|
+
- name: cms_mid
|
|
19
|
+
update_period: 32
|
|
20
|
+
optimizer_key: cms_opt
|
|
21
|
+
- name: cms_slow
|
|
22
|
+
update_period: 128
|
|
23
|
+
optimizer_key: cms_opt
|
|
24
|
+
- name: cms_ultra
|
|
25
|
+
update_period: 512
|
|
26
|
+
optimizer_key: cms_opt
|
|
27
|
+
|
|
28
|
+
data:
|
|
29
|
+
seq_len: 1024
|
|
30
|
+
batch_size: 2
|
|
31
|
+
num_workers: 2
|
|
32
|
+
|
|
33
|
+
train:
|
|
34
|
+
online_updates: true
|
|
35
|
+
online_chunk_size: 0
|
|
36
|
+
per_layer_teach_signal: true
|
|
37
|
+
steps: 5000
|
|
38
|
+
device: "cuda:1"
|
|
39
|
+
checkpoint:
|
|
40
|
+
dir: artifacts/checkpoints/pilot_cms_sparse
|
|
41
|
+
save_interval: 1000
|
|
42
|
+
log_interval: 25
|
|
43
|
+
|
|
44
|
+
logging:
|
|
45
|
+
path: logs/pilot_cms_sparse_metrics.json
|
|
46
|
+
run_name: pilot-cms-sparse
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- /pilot
|
|
3
|
+
- _self_
|
|
4
|
+
|
|
5
|
+
model:
|
|
6
|
+
block_variant: hope_selfmod
|
|
7
|
+
self_mod_chunk_size: 8
|
|
8
|
+
self_mod_chunk_size_memory: 64
|
|
9
|
+
|
|
10
|
+
train:
|
|
11
|
+
online_updates: true
|
|
12
|
+
online_chunk_size: 0
|
|
13
|
+
per_layer_teach_signal: true
|
|
14
|
+
steps: 5000
|
|
15
|
+
device: "cuda:1"
|
|
16
|
+
checkpoint:
|
|
17
|
+
dir: artifacts/checkpoints/pilot_selfmod_chunked_8_64
|
|
18
|
+
save_interval: 1000
|
|
19
|
+
|
|
20
|
+
logging:
|
|
21
|
+
enabled: true
|
|
22
|
+
backend: json
|
|
23
|
+
path: logs/pilot_selfmod_chunked_8_64_metrics.json
|
|
24
|
+
run_name: pilot-selfmod-chunked-8-64
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- /pilot
|
|
3
|
+
- _self_
|
|
4
|
+
|
|
5
|
+
model:
|
|
6
|
+
block_variant: hope_selfmod
|
|
7
|
+
self_mod_momentum: 0.0
|
|
8
|
+
|
|
9
|
+
train:
|
|
10
|
+
online_updates: true
|
|
11
|
+
online_chunk_size: 0
|
|
12
|
+
per_layer_teach_signal: true
|
|
13
|
+
steps: 5000
|
|
14
|
+
device: "cuda:1"
|
|
15
|
+
checkpoint:
|
|
16
|
+
dir: artifacts/checkpoints/pilot_selfmod_momentum_off
|
|
17
|
+
save_interval: 1000
|
|
18
|
+
|
|
19
|
+
logging:
|
|
20
|
+
enabled: true
|
|
21
|
+
backend: json
|
|
22
|
+
path: logs/pilot_selfmod_momentum_off_metrics.json
|
|
23
|
+
run_name: pilot-selfmod-momentum-off
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- /pilot
|
|
3
|
+
- _self_
|
|
4
|
+
|
|
5
|
+
model:
|
|
6
|
+
block_variant: hope_selfmod
|
|
7
|
+
self_mod_momentum: 0.9
|
|
8
|
+
|
|
9
|
+
train:
|
|
10
|
+
online_updates: true
|
|
11
|
+
online_chunk_size: 0
|
|
12
|
+
per_layer_teach_signal: true
|
|
13
|
+
steps: 5000
|
|
14
|
+
device: "cuda:1"
|
|
15
|
+
checkpoint:
|
|
16
|
+
dir: artifacts/checkpoints/pilot_selfmod_momentum_on
|
|
17
|
+
save_interval: 1000
|
|
18
|
+
|
|
19
|
+
logging:
|
|
20
|
+
enabled: true
|
|
21
|
+
backend: json
|
|
22
|
+
path: logs/pilot_selfmod_momentum_on_metrics.json
|
|
23
|
+
run_name: pilot-selfmod-momentum-on
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- /pilot
|
|
3
|
+
- _self_
|
|
4
|
+
|
|
5
|
+
model:
|
|
6
|
+
block_variant: hope_selfmod
|
|
7
|
+
self_mod_use_alpha: false
|
|
8
|
+
|
|
9
|
+
train:
|
|
10
|
+
online_updates: true
|
|
11
|
+
online_chunk_size: 0
|
|
12
|
+
per_layer_teach_signal: true
|
|
13
|
+
steps: 5000
|
|
14
|
+
device: "cuda:1"
|
|
15
|
+
checkpoint:
|
|
16
|
+
dir: artifacts/checkpoints/pilot_selfmod_no_alpha
|
|
17
|
+
save_interval: 1000
|
|
18
|
+
|
|
19
|
+
logging:
|
|
20
|
+
enabled: true
|
|
21
|
+
backend: json
|
|
22
|
+
path: logs/pilot_selfmod_no_alpha_metrics.json
|
|
23
|
+
run_name: pilot-selfmod-no-alpha
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- /pilot
|
|
3
|
+
- _self_
|
|
4
|
+
|
|
5
|
+
model:
|
|
6
|
+
block_variant: hope_selfmod
|
|
7
|
+
cms_levels: []
|
|
8
|
+
|
|
9
|
+
train:
|
|
10
|
+
online_updates: true
|
|
11
|
+
online_chunk_size: 0
|
|
12
|
+
per_layer_teach_signal: true
|
|
13
|
+
steps: 5000
|
|
14
|
+
device: "cuda:1"
|
|
15
|
+
checkpoint:
|
|
16
|
+
dir: artifacts/checkpoints/pilot_selfmod_no_cms
|
|
17
|
+
save_interval: 1000
|
|
18
|
+
|
|
19
|
+
logging:
|
|
20
|
+
enabled: true
|
|
21
|
+
backend: json
|
|
22
|
+
path: logs/pilot_selfmod_no_cms_metrics.json
|
|
23
|
+
run_name: pilot-selfmod-no-cms
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- /pilot
|
|
3
|
+
- _self_
|
|
4
|
+
|
|
5
|
+
model:
|
|
6
|
+
block_variant: hope_selfmod
|
|
7
|
+
self_mod_use_rank1_precond: false
|
|
8
|
+
|
|
9
|
+
train:
|
|
10
|
+
online_updates: true
|
|
11
|
+
online_chunk_size: 0
|
|
12
|
+
per_layer_teach_signal: true
|
|
13
|
+
steps: 5000
|
|
14
|
+
device: "cuda:1"
|
|
15
|
+
checkpoint:
|
|
16
|
+
dir: artifacts/checkpoints/pilot_selfmod_rank1_off
|
|
17
|
+
save_interval: 1000
|
|
18
|
+
|
|
19
|
+
logging:
|
|
20
|
+
enabled: true
|
|
21
|
+
backend: json
|
|
22
|
+
path: logs/pilot_selfmod_rank1_off_metrics.json
|
|
23
|
+
run_name: pilot-selfmod-rank1-off
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
segments:
|
|
2
|
+
- name: refinedweb_2018
|
|
3
|
+
shards_dir: data/shards/refinedweb_sample
|
|
4
|
+
- name: wikipedia_sample
|
|
5
|
+
shards_dir: data/shards/wikipedia_sample
|
|
6
|
+
- name: c4_sample
|
|
7
|
+
shards_dir: data/shards/c4_sample
|
|
8
|
+
- name: redpajama_sample
|
|
9
|
+
shards_dir: data/shards/redpajama_sample
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
name: fineweb_edu_longdoc_filtered_sample
|
|
2
|
+
tokenizer_output_dir: artifacts/tokenizer/fineweb_edu_longdoc
|
|
3
|
+
datasets:
|
|
4
|
+
- name: fineweb_edu_longdoc
|
|
5
|
+
dataset: text
|
|
6
|
+
split: train
|
|
7
|
+
text_column: text
|
|
8
|
+
data_files: data/filtered/fineweb_edu_longdoc_en_sample.txt
|
|
9
|
+
sample_limit: 5000
|
|
10
|
+
seq_len: 4096
|
|
11
|
+
sequences_per_shard: 1024
|
|
12
|
+
output_dir: data/shards/fineweb_edu_longdoc_sample
|
|
13
|
+
max_records: null
|
|
14
|
+
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
name: fineweb_edu_full
|
|
2
|
+
tokenizer_output_dir: artifacts/tokenizer/fineweb_edu
|
|
3
|
+
datasets:
|
|
4
|
+
- name: fineweb_edu
|
|
5
|
+
dataset: HuggingFaceFW/fineweb-edu
|
|
6
|
+
subset: sample-100BT
|
|
7
|
+
split: train
|
|
8
|
+
text_column: text
|
|
9
|
+
sample_limit: 100000
|
|
10
|
+
seq_len: 4096
|
|
11
|
+
sequences_per_shard: 1024
|
|
12
|
+
output_dir: data/shards/fineweb_edu_full
|
|
13
|
+
max_records: null
|
|
14
|
+
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
name: fineweb_edu_sample
|
|
2
|
+
tokenizer_output_dir: artifacts/tokenizer/fineweb_edu
|
|
3
|
+
datasets:
|
|
4
|
+
- name: fineweb_edu
|
|
5
|
+
dataset: HuggingFaceFW/fineweb-edu
|
|
6
|
+
subset: sample-10BT
|
|
7
|
+
split: train
|
|
8
|
+
text_column: text
|
|
9
|
+
sample_limit: 5000
|
|
10
|
+
seq_len: 2048
|
|
11
|
+
sequences_per_shard: 1024
|
|
12
|
+
output_dir: data/shards/fineweb_edu_sample
|
|
13
|
+
max_records: 10000
|
|
14
|
+
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
name: refinedweb_mix_v1
|
|
2
|
+
tokenizer_output_dir: artifacts/tokenizer/refinedweb_mix
|
|
3
|
+
datasets:
|
|
4
|
+
- name: refinedweb
|
|
5
|
+
dataset: text
|
|
6
|
+
split: train
|
|
7
|
+
text_column: text
|
|
8
|
+
data_files: data/filtered/refinedweb_en_full.txt
|
|
9
|
+
seq_len: 2048
|
|
10
|
+
sequences_per_shard: 2048
|
|
11
|
+
output_dir: data/shards/refinedweb
|
|
12
|
+
max_records: null
|
|
13
|
+
- name: books
|
|
14
|
+
dataset: text
|
|
15
|
+
split: train
|
|
16
|
+
text_column: text
|
|
17
|
+
data_files: data/filtered/wikipedia_en_full.txt
|
|
18
|
+
seq_len: 2048
|
|
19
|
+
sequences_per_shard: 2048
|
|
20
|
+
output_dir: data/shards/wikipedia
|
|
21
|
+
max_records: null
|
|
22
|
+
- name: c4
|
|
23
|
+
dataset: text
|
|
24
|
+
split: train
|
|
25
|
+
text_column: text
|
|
26
|
+
data_files: data/filtered/c4_en_full.txt
|
|
27
|
+
seq_len: 2048
|
|
28
|
+
sequences_per_shard: 2048
|
|
29
|
+
output_dir: data/shards/c4
|
|
30
|
+
max_records: null
|
|
31
|
+
- name: redpajama
|
|
32
|
+
dataset: text
|
|
33
|
+
split: train
|
|
34
|
+
text_column: text
|
|
35
|
+
data_files: data/filtered/redpajama_en_full.txt
|
|
36
|
+
seq_len: 2048
|
|
37
|
+
sequences_per_shard: 2048
|
|
38
|
+
output_dir: data/shards/redpajama
|
|
39
|
+
max_records: null
|
|
40
|
+
- name: code
|
|
41
|
+
dataset: text
|
|
42
|
+
split: train
|
|
43
|
+
text_column: text
|
|
44
|
+
data_files: data/filtered/code_en_full.txt
|
|
45
|
+
seq_len: 2048
|
|
46
|
+
sequences_per_shard: 2048
|
|
47
|
+
output_dir: data/shards/code
|
|
48
|
+
max_records: null
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
name: refinedweb_mix_filtered
|
|
2
|
+
tokenizer_output_dir: artifacts/tokenizer/refinedweb_mix
|
|
3
|
+
datasets:
|
|
4
|
+
- name: refinedweb
|
|
5
|
+
dataset: text
|
|
6
|
+
split: train
|
|
7
|
+
text_column: text
|
|
8
|
+
data_files: data/filtered/refinedweb_en_sample.txt
|
|
9
|
+
seq_len: 512
|
|
10
|
+
sequences_per_shard: 512
|
|
11
|
+
output_dir: data/shards/refinedweb_filtered
|
|
12
|
+
max_records: null
|
|
13
|
+
- name: wikipedia
|
|
14
|
+
dataset: text
|
|
15
|
+
split: train
|
|
16
|
+
text_column: text
|
|
17
|
+
data_files: data/filtered/wikipedia_en_sample.txt
|
|
18
|
+
seq_len: 512
|
|
19
|
+
sequences_per_shard: 512
|
|
20
|
+
output_dir: data/shards/wikipedia_filtered
|
|
21
|
+
max_records: null
|
|
22
|
+
- name: c4
|
|
23
|
+
dataset: text
|
|
24
|
+
split: train
|
|
25
|
+
text_column: text
|
|
26
|
+
data_files: data/filtered/c4_en_sample.txt
|
|
27
|
+
seq_len: 512
|
|
28
|
+
sequences_per_shard: 512
|
|
29
|
+
output_dir: data/shards/c4_filtered
|
|
30
|
+
max_records: null
|
|
31
|
+
- name: redpajama
|
|
32
|
+
dataset: text
|
|
33
|
+
split: train
|
|
34
|
+
text_column: text
|
|
35
|
+
data_files: data/filtered/redpajama_en_sample.txt
|
|
36
|
+
seq_len: 512
|
|
37
|
+
sequences_per_shard: 512
|
|
38
|
+
output_dir: data/shards/redpajama_filtered
|
|
39
|
+
max_records: null
|
|
40
|
+
- name: code
|
|
41
|
+
dataset: text
|
|
42
|
+
split: train
|
|
43
|
+
text_column: text
|
|
44
|
+
data_files: data/filtered/code_en_sample.txt
|
|
45
|
+
seq_len: 512
|
|
46
|
+
sequences_per_shard: 512
|
|
47
|
+
output_dir: data/shards/code_filtered
|
|
48
|
+
max_records: null
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
name: refinedweb_mix_full
|
|
2
|
+
tokenizer_output_dir: artifacts/tokenizer/refinedweb_mix
|
|
3
|
+
datasets:
|
|
4
|
+
- name: refinedweb
|
|
5
|
+
dataset: text
|
|
6
|
+
split: train
|
|
7
|
+
text_column: text
|
|
8
|
+
data_files: data/filtered/refinedweb_en_full.txt
|
|
9
|
+
seq_len: 2048
|
|
10
|
+
sequences_per_shard: 1024
|
|
11
|
+
output_dir: data/shards/refinedweb_full
|
|
12
|
+
max_records: null
|
|
13
|
+
- name: wikipedia
|
|
14
|
+
dataset: text
|
|
15
|
+
split: train
|
|
16
|
+
text_column: text
|
|
17
|
+
data_files: data/filtered/wikipedia_en_full.txt
|
|
18
|
+
seq_len: 2048
|
|
19
|
+
sequences_per_shard: 1024
|
|
20
|
+
output_dir: data/shards/wikipedia_full
|
|
21
|
+
max_records: null
|
|
22
|
+
- name: c4
|
|
23
|
+
dataset: text
|
|
24
|
+
split: train
|
|
25
|
+
text_column: text
|
|
26
|
+
data_files: data/filtered/c4_en_full.txt
|
|
27
|
+
seq_len: 2048
|
|
28
|
+
sequences_per_shard: 1024
|
|
29
|
+
output_dir: data/shards/c4_full
|
|
30
|
+
max_records: null
|
|
31
|
+
- name: redpajama
|
|
32
|
+
dataset: text
|
|
33
|
+
split: train
|
|
34
|
+
text_column: text
|
|
35
|
+
data_files: data/filtered/redpajama_en_full.txt
|
|
36
|
+
seq_len: 2048
|
|
37
|
+
sequences_per_shard: 1024
|
|
38
|
+
output_dir: data/shards/redpajama_full
|
|
39
|
+
max_records: null
|
|
40
|
+
- name: code
|
|
41
|
+
dataset: text
|
|
42
|
+
split: train
|
|
43
|
+
text_column: text
|
|
44
|
+
data_files: data/filtered/code_en_full.txt
|
|
45
|
+
seq_len: 2048
|
|
46
|
+
sequences_per_shard: 1024
|
|
47
|
+
output_dir: data/shards/code_full
|
|
48
|
+
max_records: null
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
name: refinedweb_mix_sample
|
|
2
|
+
tokenizer_output_dir: artifacts/tokenizer/refinedweb_mix
|
|
3
|
+
datasets:
|
|
4
|
+
- name: refinedweb
|
|
5
|
+
dataset: HuggingFaceFW/fineweb
|
|
6
|
+
subset: sample-10BT
|
|
7
|
+
split: train
|
|
8
|
+
text_column: text
|
|
9
|
+
sample_limit: 5000
|
|
10
|
+
seq_len: 512
|
|
11
|
+
sequences_per_shard: 512
|
|
12
|
+
output_dir: data/shards/refinedweb_sample
|
|
13
|
+
max_records: 10000
|
|
14
|
+
- name: books
|
|
15
|
+
dataset: wikimedia/wikipedia
|
|
16
|
+
subset: 20231101.en
|
|
17
|
+
split: train
|
|
18
|
+
text_column: text
|
|
19
|
+
sample_limit: 2000
|
|
20
|
+
seq_len: 512
|
|
21
|
+
sequences_per_shard: 512
|
|
22
|
+
output_dir: data/shards/wikipedia_sample
|
|
23
|
+
max_records: 5000
|
|
24
|
+
- name: c4
|
|
25
|
+
dataset: allenai/c4
|
|
26
|
+
subset: en
|
|
27
|
+
split: train
|
|
28
|
+
text_column: text
|
|
29
|
+
sample_limit: 2000
|
|
30
|
+
seq_len: 512
|
|
31
|
+
sequences_per_shard: 512
|
|
32
|
+
output_dir: data/shards/c4_sample
|
|
33
|
+
max_records: 4000
|
|
34
|
+
- name: redpajama
|
|
35
|
+
dataset: cerebras/SlimPajama-627B
|
|
36
|
+
split: train
|
|
37
|
+
text_column: text
|
|
38
|
+
sample_limit: 2000
|
|
39
|
+
seq_len: 512
|
|
40
|
+
sequences_per_shard: 512
|
|
41
|
+
output_dir: data/shards/redpajama_sample
|
|
42
|
+
max_records: 4000
|
|
43
|
+
- name: code
|
|
44
|
+
dataset: codeparrot/codeparrot-clean-train
|
|
45
|
+
split: train
|
|
46
|
+
text_column: content
|
|
47
|
+
sample_limit: 2000
|
|
48
|
+
seq_len: 512
|
|
49
|
+
sequences_per_shard: 512
|
|
50
|
+
output_dir: data/shards/code_sample
|
|
51
|
+
max_records: 4000
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
{
|
|
2
|
+
"bf16": {
|
|
3
|
+
"enabled": true
|
|
4
|
+
},
|
|
5
|
+
"train_batch_size": 64,
|
|
6
|
+
"gradient_accumulation_steps": 1,
|
|
7
|
+
"zero_optimization": {
|
|
8
|
+
"stage": 3,
|
|
9
|
+
"reduce_bucket_size": 50000000,
|
|
10
|
+
"stage3_param_persistence_threshold": 100000,
|
|
11
|
+
"stage3_prefetch_bucket_size": 50000000
|
|
12
|
+
},
|
|
13
|
+
"optimizer": {
|
|
14
|
+
"type": "AdamW",
|
|
15
|
+
"params": {
|
|
16
|
+
"lr": 0.0002,
|
|
17
|
+
"betas": [
|
|
18
|
+
0.9,
|
|
19
|
+
0.95
|
|
20
|
+
],
|
|
21
|
+
"eps": 1e-08,
|
|
22
|
+
"weight_decay": 0.01
|
|
23
|
+
}
|
|
24
|
+
}
|
|
25
|
+
}
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- _self_
|
|
3
|
+
|
|
4
|
+
hydra:
|
|
5
|
+
run:
|
|
6
|
+
dir: .
|
|
7
|
+
output_subdir: null
|
|
8
|
+
job:
|
|
9
|
+
chdir: false
|
|
10
|
+
|
|
11
|
+
model:
|
|
12
|
+
vocab_size: 32000
|
|
13
|
+
dim: 1024
|
|
14
|
+
num_layers: 24
|
|
15
|
+
heads: 16
|
|
16
|
+
surprise_threshold: null
|
|
17
|
+
freeze_backbone: false
|
|
18
|
+
titan_level:
|
|
19
|
+
name: titan
|
|
20
|
+
update_period: 16
|
|
21
|
+
optimizer_key: titan_opt
|
|
22
|
+
cms_levels:
|
|
23
|
+
- name: cms_fast
|
|
24
|
+
update_period: 1
|
|
25
|
+
optimizer_key: cms_opt
|
|
26
|
+
- name: cms_mid
|
|
27
|
+
update_period: 4
|
|
28
|
+
optimizer_key: cms_opt
|
|
29
|
+
- name: cms_slow
|
|
30
|
+
update_period: 32
|
|
31
|
+
optimizer_key: cms_opt
|
|
32
|
+
- name: cms_ultra
|
|
33
|
+
update_period: 128
|
|
34
|
+
optimizer_key: cms_opt
|
|
35
|
+
optimizers:
|
|
36
|
+
titan_opt:
|
|
37
|
+
type: deep_momentum
|
|
38
|
+
lr: 8.0e-4
|
|
39
|
+
params:
|
|
40
|
+
beta: 0.9
|
|
41
|
+
beta2: 0.999
|
|
42
|
+
variant: nl_l2_precond
|
|
43
|
+
cms_opt:
|
|
44
|
+
type: deep_momentum
|
|
45
|
+
lr: 4.0e-4
|
|
46
|
+
params:
|
|
47
|
+
beta: 0.9
|
|
48
|
+
beta2: 0.999
|
|
49
|
+
variant: nl_l2_precond
|
|
50
|
+
|
|
51
|
+
data:
|
|
52
|
+
source: mixture
|
|
53
|
+
batch_size: 16
|
|
54
|
+
num_workers: 4
|
|
55
|
+
mixture:
|
|
56
|
+
samples_per_epoch: 8192
|
|
57
|
+
seed: 42
|
|
58
|
+
sources:
|
|
59
|
+
- name: refinedweb
|
|
60
|
+
shards_dir: data/shards/refinedweb_full
|
|
61
|
+
weight: 0.4
|
|
62
|
+
- name: wikipedia
|
|
63
|
+
shards_dir: data/shards/wikipedia_full
|
|
64
|
+
weight: 0.2
|
|
65
|
+
- name: c4
|
|
66
|
+
shards_dir: data/shards/c4_full
|
|
67
|
+
weight: 0.15
|
|
68
|
+
- name: redpajama
|
|
69
|
+
shards_dir: data/shards/redpajama_full
|
|
70
|
+
weight: 0.15
|
|
71
|
+
- name: code
|
|
72
|
+
shards_dir: data/shards/code_full
|
|
73
|
+
weight: 0.1
|
|
74
|
+
|
|
75
|
+
train:
|
|
76
|
+
strict_streaming_contract: false
|
|
77
|
+
online_updates: true
|
|
78
|
+
online_chunk_size: 0
|
|
79
|
+
online_boundary_targets: false
|
|
80
|
+
online_carry_attention_cache: false
|
|
81
|
+
per_layer_teach_signal: true
|
|
82
|
+
steps: 100
|
|
83
|
+
log_interval: 10
|
|
84
|
+
device: "cuda:1"
|
|
85
|
+
seed: 808
|
|
86
|
+
deterministic: false
|
|
87
|
+
step_offset: 0
|
|
88
|
+
mixed_precision:
|
|
89
|
+
enabled: true
|
|
90
|
+
dtype: bf16
|
|
91
|
+
compile:
|
|
92
|
+
enable: true
|
|
93
|
+
mode: max-autotune
|
|
94
|
+
fsdp:
|
|
95
|
+
auto_wrap_min_params: 2000000
|
|
96
|
+
cpu_offload: false
|
|
97
|
+
checkpoint:
|
|
98
|
+
enable: true
|
|
99
|
+
dir: checkpoints/mid
|
|
100
|
+
save_interval: 50
|
|
101
|
+
resume_path: null
|
|
102
|
+
resume_tag: null
|
|
103
|
+
|
|
104
|
+
optim:
|
|
105
|
+
type: muon
|
|
106
|
+
lr: 2.0e-4
|
|
107
|
+
weight_decay: 0.02
|
|
108
|
+
momentum: 0.95
|
|
109
|
+
betas:
|
|
110
|
+
- 0.9
|
|
111
|
+
- 0.999
|
|
112
|
+
|
|
113
|
+
logging:
|
|
114
|
+
enabled: false
|
|
115
|
+
backend: wandb
|
|
116
|
+
project: nested-learning
|
|
117
|
+
run_name: mid-${now:%Y%m%d%H%M%S}
|
|
118
|
+
path: logs/mid_metrics.json
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- mid
|
|
3
|
+
- _self_
|
|
4
|
+
|
|
5
|
+
model:
|
|
6
|
+
gradient_checkpointing: true
|
|
7
|
+
|
|
8
|
+
data:
|
|
9
|
+
batch_size: 8 # per-rank micro-batch for 2× RTX 6000 Ada
|
|
10
|
+
num_workers: 6
|
|
11
|
+
|
|
12
|
+
train:
|
|
13
|
+
strict_streaming_contract: false
|
|
14
|
+
online_updates: true
|
|
15
|
+
online_chunk_size: 0
|
|
16
|
+
online_boundary_targets: false
|
|
17
|
+
online_carry_attention_cache: false
|
|
18
|
+
per_layer_teach_signal: true
|
|
19
|
+
steps: 250000
|
|
20
|
+
log_interval: 20
|
|
21
|
+
device: "cuda"
|
|
22
|
+
mixed_precision:
|
|
23
|
+
enabled: true
|
|
24
|
+
dtype: bf16
|
|
25
|
+
compile:
|
|
26
|
+
enable: false
|
|
27
|
+
fsdp:
|
|
28
|
+
auto_wrap_min_params: 2000000
|
|
29
|
+
cpu_offload: false
|
|
30
|
+
checkpoint:
|
|
31
|
+
enable: true
|
|
32
|
+
dir: artifacts/checkpoints/mid_fsdp
|
|
33
|
+
save_interval: 1000
|
|
34
|
+
resume_path: null
|
|
35
|
+
resume_tag: null
|
|
36
|
+
|
|
37
|
+
optim:
|
|
38
|
+
type: muon
|
|
39
|
+
lr: 2.0e-4
|
|
40
|
+
weight_decay: 0.01
|
|
41
|
+
|
|
42
|
+
logging:
|
|
43
|
+
enabled: true
|
|
44
|
+
backend: wandb
|
|
45
|
+
project: nested-learning
|
|
46
|
+
run_name: hope-mid-fsdp-${now:%Y%m%d%H%M%S}
|
|
47
|
+
path: logs/mid_fsdp_metrics.json
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- /pilot
|
|
3
|
+
- _self_
|
|
4
|
+
|
|
5
|
+
model:
|
|
6
|
+
block_variant: hope_selfmod
|
|
7
|
+
# Chunk update cadence (paper §8.2): other memories update more often than M_memory.
|
|
8
|
+
self_mod_chunk_size: 8
|
|
9
|
+
self_mod_chunk_size_memory: 64
|
|
10
|
+
|
|
11
|
+
train:
|
|
12
|
+
online_updates: true
|
|
13
|
+
online_chunk_size: 0
|
|
14
|
+
per_layer_teach_signal: true
|
|
15
|
+
checkpoint:
|
|
16
|
+
dir: artifacts/checkpoints/pilot_selfmod
|
|
17
|
+
|
|
18
|
+
logging:
|
|
19
|
+
run_name: pilot-selfmod
|
|
20
|
+
path: logs/pilot_selfmod_metrics.json
|