fusion-bench 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/method/__init__.py +4 -0
- fusion_bench/method/fw_merging/__init__.py +2 -0
- fusion_bench/method/fw_merging/fw_hard.py +448 -0
- fusion_bench/method/fw_merging/fw_soft.py +519 -0
- fusion_bench/method/fw_merging/utils.py +331 -0
- fusion_bench/method/moe_pruner/__init__.py +7 -0
- fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
- fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
- fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
- fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
- fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
- fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
- fusion_bench/method/moe_pruner/utils/data.py +154 -0
- fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
- fusion_bench/method/moe_pruner/utils/prune.py +313 -0
- fusion_bench/method/moe_pruner/utils/score.py +41 -0
- fusion_bench/method/pruning/__init__.py +1 -0
- fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
- fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
- fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
- fusion_bench/method/pruning/wanda_utils/data.py +33 -14
- fusion_bench/method/randes/__init__.py +15 -0
- fusion_bench/method/randes/base_algorithm.py +1013 -0
- fusion_bench/method/randes/modelsoup.py +126 -0
- fusion_bench/method/randes/task_arithmetic.py +318 -0
- fusion_bench/method/sparselo/sparselo.py +20 -2
- fusion_bench/method/tall_mask/__init__.py +1 -0
- fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +73 -10
- fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
- fusion_bench/programs/fabric_fusion_program.py +5 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
- fusion_bench/utils/__init__.py +1 -0
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/lazy_state_dict.py +268 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/state_dict_arithmetic.py +74 -2
- fusion_bench/utils/type.py +1 -0
- {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/METADATA +10 -3
- {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/RECORD +86 -22
- {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/WHEEL +1 -1
- fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
- fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
- fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
- fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
- fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
- fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
- fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B-Instruct.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B-Instruct.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-2b-it.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-2b.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-9b-it.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mergebench/gemma-2-9b.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
- {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.14.dist-info → fusion_bench-0.2.16.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
_target_: fusion_bench.method.randes.SuperposedModelSoupAlgorithm
|
|
2
|
+
#* === base randes options ===
|
|
3
|
+
mode: manual_absorption
|
|
4
|
+
# weights for all mlp and attn layers
|
|
5
|
+
target_layer:
|
|
6
|
+
- mlp_w
|
|
7
|
+
- attn_w
|
|
8
|
+
random_seed: 42 # for random_binary_diagonal_matrix
|
|
9
|
+
different_across_layers: True
|
|
10
|
+
joint_matrix_mode: flatten_hstack
|
|
11
|
+
rank: 1 # for columnwise svd
|
|
12
|
+
random_components: False
|
|
13
|
+
shift_layers: 0
|
|
14
|
+
absorber: None
|
|
15
|
+
debug: 0
|
|
16
|
+
ms_mode: average
|
|
17
|
+
verbose: 0 # level of verbosity
|
|
18
|
+
dropout_rate: 1 # take the target layer per n target layers
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
name: superposed_task_arithmetic
|
|
2
|
+
#* === base randes options ===
|
|
3
|
+
mode: random_binary_diagonal_matrix
|
|
4
|
+
# weights for all mlp and attn layers
|
|
5
|
+
target_layer:
|
|
6
|
+
- mlp_w
|
|
7
|
+
- attn_w
|
|
8
|
+
random_seed: 42 # for random_binary_diagonal_matrix
|
|
9
|
+
different_across_layers: True
|
|
10
|
+
joint_matrix_mode: flatten_hstack
|
|
11
|
+
rank: 1 # for columnwise svd
|
|
12
|
+
random_components: False
|
|
13
|
+
shift_layers: 0
|
|
14
|
+
debug: 0
|
|
15
|
+
verbose: 0
|
|
16
|
+
dropout_rate: 1
|
|
17
|
+
#* === task arithmetic options ===
|
|
18
|
+
scaling_factor: 0.5
|
|
19
|
+
# path to save/load the model
|
|
20
|
+
model_path: null
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
_target_: fusion_bench.method.randes.SuperposedTaskArithmeticLoRAAlgorithm
|
|
2
|
+
#* === base randes options ===
|
|
3
|
+
mode: random_binary_diagonal_matrix
|
|
4
|
+
# weights for all mlp and attn layers
|
|
5
|
+
target_layer:
|
|
6
|
+
- mlp_w
|
|
7
|
+
- attn_w
|
|
8
|
+
random_seed: 42 # for random_binary_diagonal_matrix
|
|
9
|
+
different_across_layers: True
|
|
10
|
+
joint_matrix_mode: flatten_hstack
|
|
11
|
+
rank: 1 # for columnwise svd
|
|
12
|
+
random_components: False
|
|
13
|
+
shift_layers: 0
|
|
14
|
+
debug: 0
|
|
15
|
+
verbose: 0
|
|
16
|
+
dropout_rate: 1
|
|
17
|
+
#* === task arithmetic options ===
|
|
18
|
+
scaling_factor: 0.5
|
|
19
|
+
# path to save/load the model
|
|
20
|
+
model_path: null
|
|
@@ -1,10 +1,11 @@
|
|
|
1
|
-
_target_: fusion_bench.method.
|
|
1
|
+
_target_: fusion_bench.method.sparselo.sparselo.IterativeSparseLoForLlama
|
|
2
2
|
_recursive_: false
|
|
3
3
|
nsamples: 128
|
|
4
4
|
seed: 0
|
|
5
5
|
rank: 128
|
|
6
6
|
num_iterations: 10
|
|
7
7
|
variant: wanda
|
|
8
|
+
use_reference_model: false
|
|
8
9
|
# `prune_type` can be either `unstructured` or `semistructured`
|
|
9
10
|
prune_type: unstructured
|
|
10
11
|
# device and dtype to compute the pruning mask
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# The 20 tasks used in the paper:
|
|
2
|
+
# Wang et al. Localizing Task Information for Improved Model Merging and Compression
|
|
3
|
+
# http://arxiv.org/abs/2405.07813
|
|
4
|
+
defaults:
|
|
5
|
+
# pre-trained model
|
|
6
|
+
- clip-vit-base-patch32
|
|
7
|
+
# eight tasks in the task arithmetic paper
|
|
8
|
+
- clip-vit-base-patch32_sun397
|
|
9
|
+
- clip-vit-base-patch32_stanford-cars
|
|
10
|
+
- clip-vit-base-patch32_resisc45
|
|
11
|
+
- clip-vit-base-patch32_eurosat
|
|
12
|
+
- clip-vit-base-patch32_svhn
|
|
13
|
+
- clip-vit-base-patch32_gtsrb
|
|
14
|
+
- clip-vit-base-patch32_mnist
|
|
15
|
+
- clip-vit-base-patch32_dtd
|
|
16
|
+
# additional 6 tasks in the TALL mask paper (TALL 14)
|
|
17
|
+
- clip-vit-base-patch32_oxford_flowers102
|
|
18
|
+
- clip-vit-base-patch32_pcam
|
|
19
|
+
# - clip-vit-base-patch32_fer2013
|
|
20
|
+
# - clip-vit-base-patch32_oxford-iiit-pet
|
|
21
|
+
# - clip-vit-base-patch32_stl10
|
|
22
|
+
# - clip-vit-base-patch32_cifar100
|
|
23
|
+
# additional 6 tasks in the TALL mask paper (TALL 20)
|
|
24
|
+
# - clip-vit-base-patch32_cifar10
|
|
25
|
+
# - clip-vit-base-patch32_food101
|
|
26
|
+
# - clip-vit-base-patch32_fashion_mnist
|
|
27
|
+
# - clip-vit-base-patch32_emnist_letters
|
|
28
|
+
# - clip-vit-base-patch32_kmnist
|
|
29
|
+
# - clip-vit-base-patch32_rendered-sst2
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# The 20 tasks used in the paper:
|
|
2
|
+
# Wang et al. Localizing Task Information for Improved Model Merging and Compression
|
|
3
|
+
# http://arxiv.org/abs/2405.07813
|
|
4
|
+
defaults:
|
|
5
|
+
# pre-trained model
|
|
6
|
+
- clip-vit-base-patch32
|
|
7
|
+
# eight tasks in the task arithmetic paper
|
|
8
|
+
- clip-vit-base-patch32_sun397
|
|
9
|
+
- clip-vit-base-patch32_stanford-cars
|
|
10
|
+
- clip-vit-base-patch32_resisc45
|
|
11
|
+
- clip-vit-base-patch32_eurosat
|
|
12
|
+
- clip-vit-base-patch32_svhn
|
|
13
|
+
- clip-vit-base-patch32_gtsrb
|
|
14
|
+
- clip-vit-base-patch32_mnist
|
|
15
|
+
- clip-vit-base-patch32_dtd
|
|
16
|
+
# additional 6 tasks in the TALL mask paper (TALL 14)
|
|
17
|
+
- clip-vit-base-patch32_oxford_flowers102
|
|
18
|
+
- clip-vit-base-patch32_pcam
|
|
19
|
+
- clip-vit-base-patch32_fer2013
|
|
20
|
+
- clip-vit-base-patch32_oxford-iiit-pet
|
|
21
|
+
# - clip-vit-base-patch32_stl10
|
|
22
|
+
# - clip-vit-base-patch32_cifar100
|
|
23
|
+
# additional 6 tasks in the TALL mask paper (TALL 20)
|
|
24
|
+
# - clip-vit-base-patch32_cifar10
|
|
25
|
+
# - clip-vit-base-patch32_food101
|
|
26
|
+
# - clip-vit-base-patch32_fashion_mnist
|
|
27
|
+
# - clip-vit-base-patch32_emnist_letters
|
|
28
|
+
# - clip-vit-base-patch32_kmnist
|
|
29
|
+
# - clip-vit-base-patch32_rendered-sst2
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# The 20 tasks used in the paper:
|
|
2
|
+
# Wang et al. Localizing Task Information for Improved Model Merging and Compression
|
|
3
|
+
# http://arxiv.org/abs/2405.07813
|
|
4
|
+
defaults:
|
|
5
|
+
# pre-trained model
|
|
6
|
+
- clip-vit-base-patch32
|
|
7
|
+
# eight tasks in the task arithmetic paper
|
|
8
|
+
- clip-vit-base-patch32_sun397
|
|
9
|
+
- clip-vit-base-patch32_stanford-cars
|
|
10
|
+
- clip-vit-base-patch32_resisc45
|
|
11
|
+
- clip-vit-base-patch32_eurosat
|
|
12
|
+
- clip-vit-base-patch32_svhn
|
|
13
|
+
- clip-vit-base-patch32_gtsrb
|
|
14
|
+
- clip-vit-base-patch32_mnist
|
|
15
|
+
- clip-vit-base-patch32_dtd
|
|
16
|
+
# additional 6 tasks in the TALL mask paper (TALL 14)
|
|
17
|
+
- clip-vit-base-patch32_oxford_flowers102
|
|
18
|
+
- clip-vit-base-patch32_pcam
|
|
19
|
+
- clip-vit-base-patch32_fer2013
|
|
20
|
+
- clip-vit-base-patch32_oxford-iiit-pet
|
|
21
|
+
- clip-vit-base-patch32_stl10
|
|
22
|
+
- clip-vit-base-patch32_cifar100
|
|
23
|
+
# additional 6 tasks in the TALL mask paper (TALL 20)
|
|
24
|
+
- clip-vit-base-patch32_cifar10
|
|
25
|
+
- clip-vit-base-patch32_food101
|
|
26
|
+
# - clip-vit-base-patch32_fashion_mnist
|
|
27
|
+
# - clip-vit-base-patch32_emnist_letters
|
|
28
|
+
# - clip-vit-base-patch32_kmnist
|
|
29
|
+
# - clip-vit-base-patch32_rendered-sst2
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# The 20 tasks used in the paper:
|
|
2
|
+
# Wang et al. Localizing Task Information for Improved Model Merging and Compression
|
|
3
|
+
# http://arxiv.org/abs/2405.07813
|
|
4
|
+
defaults:
|
|
5
|
+
# pre-trained model
|
|
6
|
+
- clip-vit-base-patch32
|
|
7
|
+
# eight tasks in the task arithmetic paper
|
|
8
|
+
- clip-vit-base-patch32_sun397
|
|
9
|
+
- clip-vit-base-patch32_stanford-cars
|
|
10
|
+
- clip-vit-base-patch32_resisc45
|
|
11
|
+
- clip-vit-base-patch32_eurosat
|
|
12
|
+
- clip-vit-base-patch32_svhn
|
|
13
|
+
- clip-vit-base-patch32_gtsrb
|
|
14
|
+
- clip-vit-base-patch32_mnist
|
|
15
|
+
- clip-vit-base-patch32_dtd
|
|
16
|
+
# additional 6 tasks in the TALL mask paper (TALL 14)
|
|
17
|
+
- clip-vit-base-patch32_oxford_flowers102
|
|
18
|
+
- clip-vit-base-patch32_pcam
|
|
19
|
+
- clip-vit-base-patch32_fer2013
|
|
20
|
+
- clip-vit-base-patch32_oxford-iiit-pet
|
|
21
|
+
- clip-vit-base-patch32_stl10
|
|
22
|
+
- clip-vit-base-patch32_cifar100
|
|
23
|
+
# additional 6 tasks in the TALL mask paper (TALL 20)
|
|
24
|
+
- clip-vit-base-patch32_cifar10
|
|
25
|
+
- clip-vit-base-patch32_food101
|
|
26
|
+
- clip-vit-base-patch32_fashion_mnist
|
|
27
|
+
- clip-vit-base-patch32_emnist_letters
|
|
28
|
+
# - clip-vit-base-patch32_kmnist
|
|
29
|
+
# - clip-vit-base-patch32_rendered-sst2
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
# The 20 tasks used in the paper:
|
|
2
|
+
# Wang et al. Localizing Task Information for Improved Model Merging and Compression
|
|
3
|
+
# http://arxiv.org/abs/2405.07813
|
|
4
|
+
defaults:
|
|
5
|
+
- CLIPVisionModelPool@: _template
|
|
6
|
+
- /model/clip-vit@models: clip-vit-base-patch32_TALL10
|
|
7
|
+
- /dataset/image_classification/train@train_datasets: TALL10
|
|
8
|
+
- /dataset/image_classification/test@test_datasets: TALL10
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
# The 20 tasks used in the paper:
|
|
2
|
+
# Wang et al. Localizing Task Information for Improved Model Merging and Compression
|
|
3
|
+
# http://arxiv.org/abs/2405.07813
|
|
4
|
+
defaults:
|
|
5
|
+
- CLIPVisionModelPool@: _template
|
|
6
|
+
- /model/clip-vit@models: clip-vit-base-patch32_TALL12
|
|
7
|
+
- /dataset/image_classification/train@train_datasets: TALL12
|
|
8
|
+
- /dataset/image_classification/test@test_datasets: TALL12
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
# The 20 tasks used in the paper:
|
|
2
|
+
# Wang et al. Localizing Task Information for Improved Model Merging and Compression
|
|
3
|
+
# http://arxiv.org/abs/2405.07813
|
|
4
|
+
defaults:
|
|
5
|
+
- CLIPVisionModelPool@: _template
|
|
6
|
+
- /model/clip-vit@models: clip-vit-base-patch32_TALL16
|
|
7
|
+
- /dataset/image_classification/train@train_datasets: TALL16
|
|
8
|
+
- /dataset/image_classification/test@test_datasets: TALL16
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
# The 20 tasks used in the paper:
|
|
2
|
+
# Wang et al. Localizing Task Information for Improved Model Merging and Compression
|
|
3
|
+
# http://arxiv.org/abs/2405.07813
|
|
4
|
+
defaults:
|
|
5
|
+
- CLIPVisionModelPool@: _template
|
|
6
|
+
- /model/clip-vit@models: clip-vit-base-patch32_TALL18
|
|
7
|
+
- /dataset/image_classification/train@train_datasets: TALL18
|
|
8
|
+
- /dataset/image_classification/test@test_datasets: TALL18
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
_target_: fusion_bench.modelpool.CausalLMPool
|
|
2
|
+
|
|
3
|
+
pretrained_model_name_or_path: deepseek-ai/DeepSeek-V2-Lite
|
|
4
|
+
|
|
5
|
+
models:
|
|
6
|
+
_pretrained_:
|
|
7
|
+
_target_: fusion_bench.models.modeling_deepseek_v2.DeepseekV2ForCausalLM.from_pretrained
|
|
8
|
+
pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
|
|
9
|
+
torch_dtype: bfloat16
|
|
10
|
+
device_map: auto
|
|
11
|
+
trust_remote_code: true
|
|
12
|
+
|
|
13
|
+
tokenizer:
|
|
14
|
+
_target_: transformers.AutoTokenizer.from_pretrained
|
|
15
|
+
pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
_target_: fusion_bench.modelpool.CausalLMPool
|
|
2
|
+
models:
|
|
3
|
+
_pretrained_: meta-llama/Llama-3.1-8B-Instruct
|
|
4
|
+
instruction: MergeBench/Llama-3.1-8B-Instruct_instruction
|
|
5
|
+
math: MergeBench/Llama-3.1-8B-Instruct_math
|
|
6
|
+
coding: MergeBench/Llama-3.1-8B-Instruct_coding
|
|
7
|
+
multilingual: MergeBench/Llama-3.1-8B-Instruct_multilingual
|
|
8
|
+
safety: MergeBench/Llama-3.1-8B-Instruct_safety
|
|
9
|
+
model_kwargs:
|
|
10
|
+
torch_dtype: bfloat16
|
|
11
|
+
tokenizer: meta-llama/Llama-3.1-8B-Instruct
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
_target_: fusion_bench.modelpool.CausalLMPool
|
|
2
|
+
models:
|
|
3
|
+
_pretrained_: meta-llama/Llama-3.1-8B
|
|
4
|
+
instruction: MergeBench/Llama-3.1-8B_instruction
|
|
5
|
+
math: MergeBench/Llama-3.1-8B_math
|
|
6
|
+
coding: MergeBench/Llama-3.1-8B_coding
|
|
7
|
+
multilingual: MergeBench/Llama-3.1-8B_multilingual
|
|
8
|
+
safety: MergeBench/Llama-3.1-8B_safety
|
|
9
|
+
model_kwargs:
|
|
10
|
+
torch_dtype: bfloat16
|
|
11
|
+
tokenizer: meta-llama/Llama-3.1-8B
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
_target_: fusion_bench.modelpool.CausalLMPool
|
|
2
|
+
models:
|
|
3
|
+
_pretrained_: meta-llama/Llama-3.2-3B-Instruct
|
|
4
|
+
instruction: MergeBench/Llama-3.2-3B-Instruct_instruction
|
|
5
|
+
math: MergeBench/Llama-3.2-3B-Instruct_math
|
|
6
|
+
coding: MergeBench/Llama-3.2-3B-Instruct_coding
|
|
7
|
+
multilingual: MergeBench/Llama-3.2-3B-Instruct_multilingual
|
|
8
|
+
safety: MergeBench/Llama-3.2-3B-Instruct_safety
|
|
9
|
+
model_kwargs:
|
|
10
|
+
torch_dtype: bfloat16
|
|
11
|
+
tokenizer: meta-llama/Llama-3.2-3B-Instruct
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
_target_: fusion_bench.modelpool.CausalLMPool
|
|
2
|
+
models:
|
|
3
|
+
_pretrained_: meta-llama/Llama-3.2-3B
|
|
4
|
+
instruction: MergeBench/Llama-3.2-3B_instruction
|
|
5
|
+
math: MergeBench/Llama-3.2-3B_math
|
|
6
|
+
coding: MergeBench/Llama-3.2-3B_coding
|
|
7
|
+
multilingual: MergeBench/Llama-3.2-3B_multilingual
|
|
8
|
+
safety: MergeBench/Llama-3.2-3B_safety
|
|
9
|
+
model_kwargs:
|
|
10
|
+
torch_dtype: bfloat16
|
|
11
|
+
tokenizer: meta-llama/Llama-3.2-3B
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
_target_: fusion_bench.modelpool.CausalLMPool
|
|
2
|
+
models:
|
|
3
|
+
_pretrained_: google/gemma-2-2b-it
|
|
4
|
+
instruction: MergeBench/gemma-2-2b-it_instruction
|
|
5
|
+
math: MergeBench/gemma-2-2b-it_math
|
|
6
|
+
coding: MergeBench/gemma-2-2b-it_coding
|
|
7
|
+
multilingual: MergeBench/gemma-2-2b-it_multilingual
|
|
8
|
+
safety: MergeBench/gemma-2-2b-it_safety
|
|
9
|
+
model_kwargs:
|
|
10
|
+
torch_dtype: bfloat16
|
|
11
|
+
tokenizer: google/gemma-2-2b-it
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
_target_: fusion_bench.modelpool.CausalLMPool
|
|
2
|
+
models:
|
|
3
|
+
_pretrained_: google/gemma-2-2b
|
|
4
|
+
instruction: MergeBench/gemma-2-2b_instruction
|
|
5
|
+
math: MergeBench/gemma-2-2b_math
|
|
6
|
+
coding: MergeBench/gemma-2-2b_coding
|
|
7
|
+
multilingual: MergeBench/gemma-2-2b_multilingual
|
|
8
|
+
safety: MergeBench/gemma-2-2b_safety
|
|
9
|
+
model_kwargs:
|
|
10
|
+
torch_dtype: bfloat16
|
|
11
|
+
tokenizer: google/gemma-2-2b
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
_target_: fusion_bench.modelpool.CausalLMPool
|
|
2
|
+
models:
|
|
3
|
+
_pretrained_: google/gemma-2-9b-it
|
|
4
|
+
instruction: MergeBench/gemma-2-9b-it_instruction
|
|
5
|
+
math: MergeBench/gemma-2-9b-it_math
|
|
6
|
+
coding: MergeBench/gemma-2-9b-it_coding
|
|
7
|
+
multilingual: MergeBench/gemma-2-9b-it_multilingual
|
|
8
|
+
safety: MergeBench/gemma-2-9b-it_safety
|
|
9
|
+
model_kwargs:
|
|
10
|
+
torch_dtype: bfloat16
|
|
11
|
+
tokenizer: google/gemma-2-9b-it
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
_target_: fusion_bench.modelpool.CausalLMPool
|
|
2
|
+
models:
|
|
3
|
+
_pretrained_: google/gemma-2-9b
|
|
4
|
+
instruction: MergeBench/gemma-2-9b_instruction
|
|
5
|
+
math: MergeBench/gemma-2-9b_math
|
|
6
|
+
coding: MergeBench/gemma-2-9b_coding
|
|
7
|
+
multilingual: MergeBench/gemma-2-9b_multilingual
|
|
8
|
+
safety: MergeBench/gemma-2-9b_safety
|
|
9
|
+
model_kwargs:
|
|
10
|
+
torch_dtype: bfloat16
|
|
11
|
+
tokenizer: google/gemma-2-9b
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
_target_: fusion_bench.modelpool.CausalLMPool
|
|
2
|
+
|
|
3
|
+
pretrained_model_name_or_path: mistralai/Mixtral-8x7B-v0.1
|
|
4
|
+
|
|
5
|
+
models:
|
|
6
|
+
_pretrained_:
|
|
7
|
+
_target_: transformers.AutoModelForCausalLM.from_pretrained
|
|
8
|
+
pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
|
|
9
|
+
torch_dtype: bfloat16
|
|
10
|
+
device_map: auto
|
|
11
|
+
|
|
12
|
+
tokenizer:
|
|
13
|
+
_target_: transformers.AutoTokenizer.from_pretrained
|
|
14
|
+
pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- Seq2SeqLMPool@: _template
|
|
3
|
+
- /model/roberta@models:
|
|
4
|
+
- roberta_base
|
|
5
|
+
- roberta_glue-cola
|
|
6
|
+
- roberta_glue-mnli
|
|
7
|
+
- roberta_glue-mrpc
|
|
8
|
+
- roberta_glue-qnli
|
|
9
|
+
- roberta_glue-qqp
|
|
10
|
+
- roberta_glue-rte
|
|
11
|
+
- roberta_glue-sst2
|
|
12
|
+
- roberta_glue-stsb
|
|
13
|
+
# _target_: fusion_bench.modelpool.SequenceClassificationModelPool
|
|
14
|
+
# _recursive_: false
|
|
15
|
+
|
|
16
|
+
_dataset_loader: fusion_bench.tasks.flan_t5_text_generation.glue_load_dataset.load_glue_dataset
|
|
17
|
+
test_datasets:
|
|
18
|
+
glue-cola:
|
|
19
|
+
_target_: ${..._dataset_loader}
|
|
20
|
+
_recursive_: false
|
|
21
|
+
name: cola
|
|
22
|
+
tokenizer: ${...tokenizer}
|
|
23
|
+
split: validation
|
|
24
|
+
glue-mnli:
|
|
25
|
+
_target_: ${..._dataset_loader}
|
|
26
|
+
_recursive_: false
|
|
27
|
+
name: mnli
|
|
28
|
+
tokenizer: ${...tokenizer}
|
|
29
|
+
split: validation_matched
|
|
30
|
+
glue-mrpc:
|
|
31
|
+
_target_: ${..._dataset_loader}
|
|
32
|
+
_recursive_: false
|
|
33
|
+
name: mrpc
|
|
34
|
+
tokenizer: ${...tokenizer}
|
|
35
|
+
split: validation
|
|
36
|
+
glue-qnli:
|
|
37
|
+
_target_: ${..._dataset_loader}
|
|
38
|
+
_recursive_: false
|
|
39
|
+
name: qnli
|
|
40
|
+
tokenizer: ${...tokenizer}
|
|
41
|
+
split: validation
|
|
42
|
+
glue-qqp:
|
|
43
|
+
_target_: ${..._dataset_loader}
|
|
44
|
+
_recursive_: false
|
|
45
|
+
name: qqp
|
|
46
|
+
tokenizer: ${...tokenizer}
|
|
47
|
+
split: validation
|
|
48
|
+
glue-rte:
|
|
49
|
+
_target_: ${..._dataset_loader}
|
|
50
|
+
_recursive_: false
|
|
51
|
+
name: rte
|
|
52
|
+
tokenizer: ${...tokenizer}
|
|
53
|
+
split: validation
|
|
54
|
+
glue-sst2:
|
|
55
|
+
_target_: ${..._dataset_loader}
|
|
56
|
+
_recursive_: false
|
|
57
|
+
name: sst2
|
|
58
|
+
tokenizer: ${...tokenizer}
|
|
59
|
+
split: validation
|
|
60
|
+
glue-stsb:
|
|
61
|
+
_target_: ${..._dataset_loader}
|
|
62
|
+
_recursive_: false
|
|
63
|
+
name: stsb
|
|
64
|
+
tokenizer: ${...tokenizer}
|
|
65
|
+
split: validation
|
|
66
|
+
|
|
67
|
+
tokenizer:
|
|
68
|
+
_target_: transformers.AutoTokenizer.from_pretrained
|
|
69
|
+
pretrained_model_name_or_path: roberta-base
|
|
File without changes
|
|
File without changes
|
|
File without changes
|