fusion-bench 0.2.15__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77) hide show
  1. fusion_bench/method/__init__.py +4 -0
  2. fusion_bench/method/fw_merging/__init__.py +2 -0
  3. fusion_bench/method/fw_merging/fw_hard.py +448 -0
  4. fusion_bench/method/fw_merging/fw_soft.py +519 -0
  5. fusion_bench/method/fw_merging/utils.py +331 -0
  6. fusion_bench/method/moe_pruner/__init__.py +7 -0
  7. fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
  8. fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
  9. fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
  10. fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
  11. fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
  12. fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
  13. fusion_bench/method/moe_pruner/utils/data.py +154 -0
  14. fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
  15. fusion_bench/method/moe_pruner/utils/prune.py +313 -0
  16. fusion_bench/method/moe_pruner/utils/score.py +41 -0
  17. fusion_bench/method/pruning/__init__.py +1 -0
  18. fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
  19. fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
  20. fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
  21. fusion_bench/method/pruning/wanda_utils/data.py +33 -14
  22. fusion_bench/method/randes/__init__.py +15 -0
  23. fusion_bench/method/randes/base_algorithm.py +1013 -0
  24. fusion_bench/method/randes/modelsoup.py +126 -0
  25. fusion_bench/method/randes/task_arithmetic.py +318 -0
  26. fusion_bench/method/sparselo/sparselo.py +20 -2
  27. fusion_bench/method/tall_mask/__init__.py +1 -0
  28. fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
  29. fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
  30. fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
  31. fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
  32. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
  33. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
  34. fusion_bench/programs/fabric_fusion_program.py +5 -0
  35. fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
  36. fusion_bench/utils/__init__.py +1 -0
  37. fusion_bench/utils/data.py +1 -1
  38. fusion_bench/utils/lazy_state_dict.py +268 -0
  39. fusion_bench/utils/parameters.py +33 -0
  40. fusion_bench/utils/state_dict_arithmetic.py +74 -2
  41. fusion_bench/utils/type.py +1 -0
  42. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/METADATA +6 -2
  43. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/RECORD +77 -21
  44. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/WHEEL +1 -1
  45. fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
  46. fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
  47. fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
  48. fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
  49. fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
  50. fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
  51. fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
  52. fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
  53. fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
  54. fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
  55. fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
  56. fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
  57. fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
  58. fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
  59. fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
  60. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
  61. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  62. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  63. fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
  64. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
  65. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
  66. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
  67. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
  68. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
  69. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
  70. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
  71. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
  72. fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
  73. fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
  74. fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
  75. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/entry_points.txt +0 -0
  76. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/licenses/LICENSE +0 -0
  77. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,20 @@
1
+ name: superposed_task_arithmetic
2
+ #* === base randes options ===
3
+ mode: random_binary_diagonal_matrix
4
+ # weights for all mlp and attn layers
5
+ target_layer:
6
+ - mlp_w
7
+ - attn_w
8
+ random_seed: 42 # for random_binary_diagonal_matrix
9
+ different_across_layers: True
10
+ joint_matrix_mode: flatten_hstack
11
+ rank: 1 # for columnwise svd
12
+ random_components: False
13
+ shift_layers: 0
14
+ debug: 0
15
+ verbose: 0
16
+ dropout_rate: 1
17
+ #* === task arithmetic options ===
18
+ scaling_factor: 0.5
19
+ # path to save/load the model
20
+ model_path: null
@@ -0,0 +1,20 @@
1
+ _target_: fusion_bench.method.randes.SuperposedTaskArithmeticLoRAAlgorithm
2
+ #* === base randes options ===
3
+ mode: random_binary_diagonal_matrix
4
+ # weights for all mlp and attn layers
5
+ target_layer:
6
+ - mlp_w
7
+ - attn_w
8
+ random_seed: 42 # for random_binary_diagonal_matrix
9
+ different_across_layers: True
10
+ joint_matrix_mode: flatten_hstack
11
+ rank: 1 # for columnwise svd
12
+ random_components: False
13
+ shift_layers: 0
14
+ debug: 0
15
+ verbose: 0
16
+ dropout_rate: 1
17
+ #* === task arithmetic options ===
18
+ scaling_factor: 0.5
19
+ # path to save/load the model
20
+ model_path: null
@@ -1,10 +1,11 @@
1
- _target_: fusion_bench.method.losparse.sparselo.IterativeSparseLoForLlama
1
+ _target_: fusion_bench.method.sparselo.sparselo.IterativeSparseLoForLlama
2
2
  _recursive_: false
3
3
  nsamples: 128
4
4
  seed: 0
5
5
  rank: 128
6
6
  num_iterations: 10
7
7
  variant: wanda
8
+ use_reference_model: false
8
9
  # `prune_type` can be either `unstructured` or `semistructured`
9
10
  prune_type: unstructured
10
11
  # device and dtype to compute the pruning mask
@@ -1,4 +1,4 @@
1
- _target_: fusion_bench.method.losparse.sparselo.PCPSparseLoForLlama
1
+ _target_: fusion_bench.method.sparselo.sparselo.PCPSparseLoForLlama
2
2
  _recursive_: false
3
3
  nsamples: 128
4
4
  seed: 0
@@ -1,4 +1,4 @@
1
- _target_: fusion_bench.method.losparse.sparselo.SparseLoForLlama
1
+ _target_: fusion_bench.method.sparselo.sparselo.SparseLoForLlama
2
2
  _recursive_: false
3
3
  nsamples: 128
4
4
  seed: 0
@@ -0,0 +1,4 @@
1
+ _target_: fusion_bench.method.tall_mask.TallMaskTaskArithmeticAlgorithm
2
+ tall_mask_lambda: 0.6
3
+ debug: 0
4
+ verbose: 0
@@ -0,0 +1,29 @@
1
+ # The 20 tasks used in the paper:
2
+ # Wang et al. Localizing Task Information for Improved Model Merging and Compression
3
+ # http://arxiv.org/abs/2405.07813
4
+ defaults:
5
+ # pre-trained model
6
+ - clip-vit-base-patch32
7
+ # eight tasks in the task arithmetic paper
8
+ - clip-vit-base-patch32_sun397
9
+ - clip-vit-base-patch32_stanford-cars
10
+ - clip-vit-base-patch32_resisc45
11
+ - clip-vit-base-patch32_eurosat
12
+ - clip-vit-base-patch32_svhn
13
+ - clip-vit-base-patch32_gtsrb
14
+ - clip-vit-base-patch32_mnist
15
+ - clip-vit-base-patch32_dtd
16
+ # additional 6 tasks in the TALL mask paper (TALL 14)
17
+ - clip-vit-base-patch32_oxford_flowers102
18
+ - clip-vit-base-patch32_pcam
19
+ # - clip-vit-base-patch32_fer2013
20
+ # - clip-vit-base-patch32_oxford-iiit-pet
21
+ # - clip-vit-base-patch32_stl10
22
+ # - clip-vit-base-patch32_cifar100
23
+ # additional 6 tasks in the TALL mask paper (TALL 20)
24
+ # - clip-vit-base-patch32_cifar10
25
+ # - clip-vit-base-patch32_food101
26
+ # - clip-vit-base-patch32_fashion_mnist
27
+ # - clip-vit-base-patch32_emnist_letters
28
+ # - clip-vit-base-patch32_kmnist
29
+ # - clip-vit-base-patch32_rendered-sst2
@@ -0,0 +1,29 @@
1
+ # The 20 tasks used in the paper:
2
+ # Wang et al. Localizing Task Information for Improved Model Merging and Compression
3
+ # http://arxiv.org/abs/2405.07813
4
+ defaults:
5
+ # pre-trained model
6
+ - clip-vit-base-patch32
7
+ # eight tasks in the task arithmetic paper
8
+ - clip-vit-base-patch32_sun397
9
+ - clip-vit-base-patch32_stanford-cars
10
+ - clip-vit-base-patch32_resisc45
11
+ - clip-vit-base-patch32_eurosat
12
+ - clip-vit-base-patch32_svhn
13
+ - clip-vit-base-patch32_gtsrb
14
+ - clip-vit-base-patch32_mnist
15
+ - clip-vit-base-patch32_dtd
16
+ # additional 6 tasks in the TALL mask paper (TALL 14)
17
+ - clip-vit-base-patch32_oxford_flowers102
18
+ - clip-vit-base-patch32_pcam
19
+ - clip-vit-base-patch32_fer2013
20
+ - clip-vit-base-patch32_oxford-iiit-pet
21
+ # - clip-vit-base-patch32_stl10
22
+ # - clip-vit-base-patch32_cifar100
23
+ # additional 6 tasks in the TALL mask paper (TALL 20)
24
+ # - clip-vit-base-patch32_cifar10
25
+ # - clip-vit-base-patch32_food101
26
+ # - clip-vit-base-patch32_fashion_mnist
27
+ # - clip-vit-base-patch32_emnist_letters
28
+ # - clip-vit-base-patch32_kmnist
29
+ # - clip-vit-base-patch32_rendered-sst2
@@ -0,0 +1,29 @@
1
+ # The 20 tasks used in the paper:
2
+ # Wang et al. Localizing Task Information for Improved Model Merging and Compression
3
+ # http://arxiv.org/abs/2405.07813
4
+ defaults:
5
+ # pre-trained model
6
+ - clip-vit-base-patch32
7
+ # eight tasks in the task arithmetic paper
8
+ - clip-vit-base-patch32_sun397
9
+ - clip-vit-base-patch32_stanford-cars
10
+ - clip-vit-base-patch32_resisc45
11
+ - clip-vit-base-patch32_eurosat
12
+ - clip-vit-base-patch32_svhn
13
+ - clip-vit-base-patch32_gtsrb
14
+ - clip-vit-base-patch32_mnist
15
+ - clip-vit-base-patch32_dtd
16
+ # additional 6 tasks in the TALL mask paper (TALL 14)
17
+ - clip-vit-base-patch32_oxford_flowers102
18
+ - clip-vit-base-patch32_pcam
19
+ - clip-vit-base-patch32_fer2013
20
+ - clip-vit-base-patch32_oxford-iiit-pet
21
+ - clip-vit-base-patch32_stl10
22
+ - clip-vit-base-patch32_cifar100
23
+ # additional 6 tasks in the TALL mask paper (TALL 20)
24
+ - clip-vit-base-patch32_cifar10
25
+ - clip-vit-base-patch32_food101
26
+ # - clip-vit-base-patch32_fashion_mnist
27
+ # - clip-vit-base-patch32_emnist_letters
28
+ # - clip-vit-base-patch32_kmnist
29
+ # - clip-vit-base-patch32_rendered-sst2
@@ -0,0 +1,29 @@
1
+ # The 20 tasks used in the paper:
2
+ # Wang et al. Localizing Task Information for Improved Model Merging and Compression
3
+ # http://arxiv.org/abs/2405.07813
4
+ defaults:
5
+ # pre-trained model
6
+ - clip-vit-base-patch32
7
+ # eight tasks in the task arithmetic paper
8
+ - clip-vit-base-patch32_sun397
9
+ - clip-vit-base-patch32_stanford-cars
10
+ - clip-vit-base-patch32_resisc45
11
+ - clip-vit-base-patch32_eurosat
12
+ - clip-vit-base-patch32_svhn
13
+ - clip-vit-base-patch32_gtsrb
14
+ - clip-vit-base-patch32_mnist
15
+ - clip-vit-base-patch32_dtd
16
+ # additional 6 tasks in the TALL mask paper (TALL 14)
17
+ - clip-vit-base-patch32_oxford_flowers102
18
+ - clip-vit-base-patch32_pcam
19
+ - clip-vit-base-patch32_fer2013
20
+ - clip-vit-base-patch32_oxford-iiit-pet
21
+ - clip-vit-base-patch32_stl10
22
+ - clip-vit-base-patch32_cifar100
23
+ # additional 6 tasks in the TALL mask paper (TALL 20)
24
+ - clip-vit-base-patch32_cifar10
25
+ - clip-vit-base-patch32_food101
26
+ - clip-vit-base-patch32_fashion_mnist
27
+ - clip-vit-base-patch32_emnist_letters
28
+ # - clip-vit-base-patch32_kmnist
29
+ # - clip-vit-base-patch32_rendered-sst2
@@ -0,0 +1,8 @@
1
+ # The 20 tasks used in the paper:
2
+ # Wang et al. Localizing Task Information for Improved Model Merging and Compression
3
+ # http://arxiv.org/abs/2405.07813
4
+ defaults:
5
+ - CLIPVisionModelPool@: _template
6
+ - /model/clip-vit@models: clip-vit-base-patch32_TALL10
7
+ - /dataset/image_classification/train@train_datasets: TALL10
8
+ - /dataset/image_classification/test@test_datasets: TALL10
@@ -0,0 +1,8 @@
1
+ # The 20 tasks used in the paper:
2
+ # Wang et al. Localizing Task Information for Improved Model Merging and Compression
3
+ # http://arxiv.org/abs/2405.07813
4
+ defaults:
5
+ - CLIPVisionModelPool@: _template
6
+ - /model/clip-vit@models: clip-vit-base-patch32_TALL12
7
+ - /dataset/image_classification/train@train_datasets: TALL12
8
+ - /dataset/image_classification/test@test_datasets: TALL12
@@ -0,0 +1,8 @@
1
+ # The 20 tasks used in the paper:
2
+ # Wang et al. Localizing Task Information for Improved Model Merging and Compression
3
+ # http://arxiv.org/abs/2405.07813
4
+ defaults:
5
+ - CLIPVisionModelPool@: _template
6
+ - /model/clip-vit@models: clip-vit-base-patch32_TALL16
7
+ - /dataset/image_classification/train@train_datasets: TALL16
8
+ - /dataset/image_classification/test@test_datasets: TALL16
@@ -0,0 +1,8 @@
1
+ # The 20 tasks used in the paper:
2
+ # Wang et al. Localizing Task Information for Improved Model Merging and Compression
3
+ # http://arxiv.org/abs/2405.07813
4
+ defaults:
5
+ - CLIPVisionModelPool@: _template
6
+ - /model/clip-vit@models: clip-vit-base-patch32_TALL18
7
+ - /dataset/image_classification/train@train_datasets: TALL18
8
+ - /dataset/image_classification/test@test_datasets: TALL18
@@ -0,0 +1,15 @@
1
+ _target_: fusion_bench.modelpool.CausalLMPool
2
+
3
+ pretrained_model_name_or_path: deepseek-ai/DeepSeek-V2-Lite
4
+
5
+ models:
6
+ _pretrained_:
7
+ _target_: fusion_bench.models.modeling_deepseek_v2.DeepseekV2ForCausalLM.from_pretrained
8
+ pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
9
+ torch_dtype: bfloat16
10
+ device_map: auto
11
+ trust_remote_code: true
12
+
13
+ tokenizer:
14
+ _target_: transformers.AutoTokenizer.from_pretrained
15
+ pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
@@ -0,0 +1,14 @@
1
+ _target_: fusion_bench.modelpool.CausalLMPool
2
+
3
+ pretrained_model_name_or_path: mistralai/Mixtral-8x7B-v0.1
4
+
5
+ models:
6
+ _pretrained_:
7
+ _target_: transformers.AutoModelForCausalLM.from_pretrained
8
+ pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
9
+ torch_dtype: bfloat16
10
+ device_map: auto
11
+
12
+ tokenizer:
13
+ _target_: transformers.AutoTokenizer.from_pretrained
14
+ pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
@@ -0,0 +1,69 @@
1
+ defaults:
2
+ - Seq2SeqLMPool@: _template
3
+ - /model/roberta@models:
4
+ - roberta_base
5
+ - roberta_glue-cola
6
+ - roberta_glue-mnli
7
+ - roberta_glue-mrpc
8
+ - roberta_glue-qnli
9
+ - roberta_glue-qqp
10
+ - roberta_glue-rte
11
+ - roberta_glue-sst2
12
+ - roberta_glue-stsb
13
+ # _target_: fusion_bench.modelpool.SequenceClassificationModelPool
14
+ # _recursive_: false
15
+
16
+ _dataset_loader: fusion_bench.tasks.flan_t5_text_generation.glue_load_dataset.load_glue_dataset
17
+ test_datasets:
18
+ glue-cola:
19
+ _target_: ${..._dataset_loader}
20
+ _recursive_: false
21
+ name: cola
22
+ tokenizer: ${...tokenizer}
23
+ split: validation
24
+ glue-mnli:
25
+ _target_: ${..._dataset_loader}
26
+ _recursive_: false
27
+ name: mnli
28
+ tokenizer: ${...tokenizer}
29
+ split: validation_matched
30
+ glue-mrpc:
31
+ _target_: ${..._dataset_loader}
32
+ _recursive_: false
33
+ name: mrpc
34
+ tokenizer: ${...tokenizer}
35
+ split: validation
36
+ glue-qnli:
37
+ _target_: ${..._dataset_loader}
38
+ _recursive_: false
39
+ name: qnli
40
+ tokenizer: ${...tokenizer}
41
+ split: validation
42
+ glue-qqp:
43
+ _target_: ${..._dataset_loader}
44
+ _recursive_: false
45
+ name: qqp
46
+ tokenizer: ${...tokenizer}
47
+ split: validation
48
+ glue-rte:
49
+ _target_: ${..._dataset_loader}
50
+ _recursive_: false
51
+ name: rte
52
+ tokenizer: ${...tokenizer}
53
+ split: validation
54
+ glue-sst2:
55
+ _target_: ${..._dataset_loader}
56
+ _recursive_: false
57
+ name: sst2
58
+ tokenizer: ${...tokenizer}
59
+ split: validation
60
+ glue-stsb:
61
+ _target_: ${..._dataset_loader}
62
+ _recursive_: false
63
+ name: stsb
64
+ tokenizer: ${...tokenizer}
65
+ split: validation
66
+
67
+ tokenizer:
68
+ _target_: transformers.AutoTokenizer.from_pretrained
69
+ pretrained_model_name_or_path: roberta-base