fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. fusion_bench/__init__.py +1 -0
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +5 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/dataset/clip_dataset.py +2 -1
  9. fusion_bench/dataset/gpt2_glue.py +9 -9
  10. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  11. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  12. fusion_bench/dataset/image_dataset.py +1 -1
  13. fusion_bench/dataset/nyuv2.py +2 -2
  14. fusion_bench/method/__init__.py +16 -1
  15. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  16. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  17. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  18. fusion_bench/method/base_algorithm.py +195 -12
  19. fusion_bench/method/bitdelta/__init__.py +4 -0
  20. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  21. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  25. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  26. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  27. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  28. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  29. fusion_bench/method/ensemble.py +12 -12
  30. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  31. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
  32. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  33. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  34. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  35. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  36. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  37. fusion_bench/method/linear/expo.py +2 -1
  38. fusion_bench/method/linear/linear_interpolation.py +6 -4
  39. fusion_bench/method/linear/simple_average_for_llama.py +16 -6
  40. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  41. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  42. fusion_bench/method/model_recombination.py +2 -5
  43. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  44. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  45. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  46. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  47. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  48. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  49. fusion_bench/method/randes/modelsoup.py +1 -3
  50. fusion_bench/method/regmean/clip_regmean.py +2 -2
  51. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  52. fusion_bench/method/regmean/regmean.py +2 -11
  53. fusion_bench/method/regmean_plusplus/__init__.py +3 -0
  54. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
  55. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
  56. fusion_bench/method/simple_average.py +16 -4
  57. fusion_bench/method/slerp/slerp.py +5 -2
  58. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  59. fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
  60. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
  61. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  62. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  63. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  64. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  65. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  66. fusion_bench/method/we_moe/we_moe.py +6 -6
  67. fusion_bench/method/weighted_average/llama.py +4 -16
  68. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  69. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  70. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  71. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  72. fusion_bench/mixins/__init__.py +10 -2
  73. fusion_bench/mixins/clip_classification.py +4 -3
  74. fusion_bench/mixins/hydra_config.py +105 -7
  75. fusion_bench/mixins/lightning_fabric.py +2 -0
  76. fusion_bench/mixins/serialization.py +265 -48
  77. fusion_bench/modelpool/__init__.py +2 -2
  78. fusion_bench/modelpool/base_pool.py +29 -9
  79. fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
  80. fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
  81. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  82. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  83. fusion_bench/models/__init__.py +2 -1
  84. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  85. fusion_bench/models/hf_utils.py +182 -0
  86. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  87. fusion_bench/models/linearized/vision_model.py +1 -1
  88. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  89. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  90. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  91. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  92. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  93. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  94. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  95. fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
  96. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  97. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
  98. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  99. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  100. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  101. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
  102. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  103. fusion_bench/models/parameter_dict.py +1 -1
  104. fusion_bench/models/sparse_we_moe.py +1 -53
  105. fusion_bench/models/utils.py +26 -0
  106. fusion_bench/models/we_moe.py +1 -53
  107. fusion_bench/models/wrappers/ensemble.py +6 -4
  108. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  109. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  110. fusion_bench/programs/base_program.py +81 -2
  111. fusion_bench/programs/fabric_fusion_program.py +24 -8
  112. fusion_bench/scripts/cli.py +6 -6
  113. fusion_bench/taskpool/base_pool.py +4 -3
  114. fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
  115. fusion_bench/taskpool/dummy.py +1 -1
  116. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  117. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  118. fusion_bench/utils/__init__.py +6 -1
  119. fusion_bench/utils/devices.py +14 -4
  120. fusion_bench/utils/instantiate_utils.py +3 -1
  121. fusion_bench/utils/misc.py +48 -2
  122. fusion_bench/utils/modelscope.py +265 -0
  123. fusion_bench/utils/parameters.py +2 -2
  124. fusion_bench/utils/rich_utils.py +3 -0
  125. fusion_bench/utils/state_dict_arithmetic.py +34 -27
  126. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
  127. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
  128. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  129. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  130. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  131. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  132. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  133. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  134. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  135. fusion_bench_config/hydra/default.yaml +6 -2
  136. fusion_bench_config/llama_full_finetune.yaml +1 -0
  137. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  138. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  139. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  140. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
  141. fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
  142. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  143. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  144. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
  145. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
  148. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
  149. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
  150. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
  167. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
  168. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
  169. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
  170. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
  171. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
  172. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
  173. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  174. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  175. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  176. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  177. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  178. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  179. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  180. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  181. fusion_bench_config/nyuv2_config.yaml +3 -1
  182. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  183. fusion_bench_config/path/default.yaml +28 -0
  184. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  185. fusion_bench_config/method/adamerging.yaml +0 -23
  186. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  187. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  188. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  189. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
  190. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
  191. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
  192. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
  193. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml
@@ -1,9 +1,29 @@
-# The 20 task used in the paper:
+# The 20 task used in the paper:
 # Wang et al. Localizing Task Information for Improved Model Merging and Compression
 # http://arxiv.org/abs/2405.07813
-defaults:
-  - CLIPVisionModelPool@: _template
-  - /model/clip-vit@models: clip-vit-large-patch14_TALL20
-processor:
-  _target_: transformers.CLIPProcessor.from_pretrained
-  pretrained_model_name_or_path: openai/clip-vit-large-patch14
+_target_: fusion_bench.modelpool.CLIPVisionModelPool
+_recursive_: False
+processor: openai/clip-vit-large-patch14
+models:
+  _pretrained_: openai/clip-vit-large-patch14
+  sun397: tanganke/clip-vit-large-patch14_sun397
+  stanford-cars: tanganke/clip-vit-large-patch14_stanford-cars
+  resisc45: tanganke/clip-vit-large-patch14_resisc45
+  eurosat: tanganke/clip-vit-large-patch14_eurosat
+  svhn: tanganke/clip-vit-large-patch14_svhn
+  gtsrb: tanganke/clip-vit-large-patch14_gtsrb
+  mnist: tanganke/clip-vit-large-patch14_mnist
+  dtd: tanganke/clip-vit-large-patch14_dtd
+  oxford_flowers102: tanganke/clip-vit-large-patch14_oxford_flowers102
+  pcam: tanganke/clip-vit-large-patch14_pcam
+  fer2013: tanganke/clip-vit-large-patch14_fer2013
+  oxford-iiit-pet: tanganke/clip-vit-large-patch14_oxford-iiit-pet
+  stl10: tanganke/clip-vit-large-patch14_stl10
+  cifar100: tanganke/clip-vit-large-patch14_cifar100
+  cifar10: tanganke/clip-vit-large-patch14_cifar10
+  food101: tanganke/clip-vit-large-patch14_food101
+  fashion_mnist: tanganke/clip-vit-large-patch14_fashion_mnist
+  emnist_letters: tanganke/clip-vit-large-patch14_emnist_letters
+  kmnist: tanganke/clip-vit-large-patch14_kmnist
+  rendered-sst2: tanganke/clip-vit-large-patch14_rendered-sst2
+platform: hf
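The new-style pool config drops the `_template` defaults and the per-model `from_pretrained` blocks in favor of bare Hugging Face repo ids plus a `platform: hf` flag. A minimal sketch of what that shorthand presumably expands to at load time; the `load_vision_model` helper is hypothetical, and only the transformers calls and the YAML keys come from the hunk above:

from omegaconf import OmegaConf
from transformers import CLIPProcessor, CLIPVisionModel

cfg = OmegaConf.create(
    """
    processor: openai/clip-vit-large-patch14
    models:
      _pretrained_: openai/clip-vit-large-patch14
      sun397: tanganke/clip-vit-large-patch14_sun397
    platform: hf
    """
)

def load_vision_model(entry):
    # Hypothetical helper: a plain string is read as a Hugging Face repo id
    # (per `platform: hf`); a mapping with `_target_` would instead go
    # through Hydra instantiation, as the removed old-style entries did.
    if isinstance(entry, str):
        return CLIPVisionModel.from_pretrained(entry)
    raise TypeError("this sketch only handles repo-id strings")

processor = CLIPProcessor.from_pretrained(cfg.processor)
model = load_vision_model(cfg.models["sun397"])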
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml
@@ -2,15 +2,11 @@
 #
 # fusion_bench \
 #   modelpool=CLIPVisionModelPool/clip-vit-large-patch14_individual \
-#   modelpool.base_model=${MODEL_PATH}
+#   modelpool.models._pretrained_=${MODEL_PATH}
 #   ...
-defaults:
-  - CLIPVisionModelPool@: _template
+_target_: fusion_bench.modelpool.CLIPVisionModelPool
+_recursive_: False
 models:
-  _pretrained_:
-    _target_: transformers.CLIPVisionModel.from_pretrained
-    pretrained_model_name_or_path: ${...base_model}
-processor:
-  _target_: transformers.CLIPProcessor.from_pretrained
-  pretrained_model_name_or_path: ${..base_model}
-base_model: openai/clip-vit-large-patch14
+  _pretrained_: openai/clip-vit-large-patch14
+processor: openai/clip-vit-large-patch14
+platform: hf
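The updated comment in this file documents the new CLI override path. The same dotted override can be reproduced outside the CLI with plain OmegaConf, which is roughly what a Hydra command-line override does (the checkpoint path below is a placeholder):

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "models": {"_pretrained_": "openai/clip-vit-large-patch14"},
    "processor": "openai/clip-vit-large-patch14",
})
# Same effect as `modelpool.models._pretrained_=${MODEL_PATH}` on the CLI,
# minus the `modelpool.` prefix, which addresses this config group.
cfg.merge_with_dotlist(["models._pretrained_=/path/to/my/finetuned/clip"])
print(cfg.models["_pretrained_"])  # /path/to/my/finetuned/clip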
fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml
@@ -4,8 +4,8 @@ _recursive_: false
 load_lazy: false
 models:
   _pretrained_: Qwen/Qwen2.5-1.5B
-  expert_1: Qwen/Qwen2.5-Math-1.5B
-  expert_2: Qwen/Qwen2.5-Coder-1.5B
+  math: Qwen/Qwen2.5-Math-1.5B
+  code: Qwen/Qwen2.5-Coder-1.5B
 model_kwargs:
   torch_dtype: bfloat16
 tokenizer: Qwen/Qwen2.5-1.5B
fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml (new file)
@@ -0,0 +1,9 @@
+_target_: fusion_bench.modelpool.CausalLMPool
+_recursive_: false
+models:
+  _pretrained_: Qwen/Qwen2.5-7B
+  math: Qwen/Qwen2.5-Math-7B
+  code: Qwen/Qwen2.5-Coder-7B
+model_kwargs:
+  torch_dtype: bfloat16
+tokenizer: Qwen/Qwen2.5-7B
fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml (new file)
@@ -0,0 +1,6 @@
+_target_: fusion_bench.modelpool.CausalLMPool
+models:
+  _pretrained_: mistralai/Mistral-7B-v0.1
+tokenizer: ${.models._pretrained_}
+model_kwargs:
+  torch_dtype: bfloat16
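The `tokenizer: ${.models._pretrained_}` line uses OmegaConf's relative interpolation, so the tokenizer id always tracks the pretrained model id. A quick standalone check:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    models:
      _pretrained_: mistralai/Mistral-7B-v0.1
    tokenizer: ${.models._pretrained_}
    """
)
print(cfg.tokenizer)  # mistralai/Mistral-7B-v0.1
cfg.models["_pretrained_"] = "mistralai/Mistral-7B-v0.3"
print(cfg.tokenizer)  # mistralai/Mistral-7B-v0.3 (re-resolved lazily)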
fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml (new file)
@@ -0,0 +1,10 @@
+_target_: fusion_bench.modelpool.CausalLMPool
+models:
+  _pretrained_: path_to_your_pretrained_model
+  expert_1: path_to_your_expert_model_1
+  expert_2: path_to_your_expert_model_2
+  expert_3: path_to_your_expert_model_3
+  expert_4: path_to_your_expert_model_4
+tokenizer: ${.models._pretrained_}
+model_kwargs:
+  torch_dtype: bfloat16
fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml
@@ -1,17 +1,9 @@
 _target_: fusion_bench.modelpool.CausalLMPool
 _recursive_: false
 models:
-  _pretrained_:
-    _target_: transformers.AutoModelForCausalLM.from_pretrained
-    pretrained_model_name_or_path: Qwen/Qwen2.5-1.5B
-  expert_1:
-    _target_: transformers.AutoModelForCausalLM.from_pretrained
-    pretrained_model_name_or_path: Qwen/Qwen2.5-Math-1.5B
-  expert_2:
-    _target_: transformers.AutoModelForCausalLM.from_pretrained
-    pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
+  _pretrained_: Qwen/Qwen2.5-1.5B
+  expert_1: Qwen/Qwen2.5-Math-1.5B
+  expert_2: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
 model_kwargs:
   torch_dtype: bfloat16
-tokenizer:
-  _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: Qwen/Qwen2.5-1.5B
+tokenizer: Qwen/Qwen2.5-1.5B
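This hunk collapses each explicit `_target_: transformers.AutoModelForCausalLM.from_pretrained` block into a bare repo-id string. Assuming `CausalLMPool` expands such strings into the same `from_pretrained` call with `model_kwargs` forwarded (the expansion itself is not shown in this diff), the two forms load identical models. A hedged sketch of that equivalence:

from hydra.utils import instantiate
from omegaconf import OmegaConf
from transformers import AutoModelForCausalLM

# Old style: exactly the `_target_` block the removed lines spelled out.
old_style = OmegaConf.create({
    "_target_": "transformers.AutoModelForCausalLM.from_pretrained",
    "pretrained_model_name_or_path": "Qwen/Qwen2.5-Math-1.5B",
    "torch_dtype": "bfloat16",
})
model_a = instantiate(old_style)

# New style: a bare string plus `model_kwargs`, assumed to expand into
# the same call.
model_b = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-Math-1.5B", torch_dtype="bfloat16"
)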
fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml
@@ -1,20 +1,10 @@
 _target_: fusion_bench.modelpool.CausalLMPool
 _recursive_: false
 models:
-  _pretrained_:
-    _target_: transformers.AutoModelForCausalLM.from_pretrained
-    pretrained_model_name_or_path: mistralai/Mistral-7B-v0.1
-  expert_1:
-    _target_: transformers.AutoModelForCausalLM.from_pretrained
-    pretrained_model_name_or_path: meta-math/MetaMath-Mistral-7B
-  expert_2:
-    _target_: transformers.AutoModelForCausalLM.from_pretrained
-    pretrained_model_name_or_path: cognitivecomputations/dolphin-2.1-mistral-7b
-  expert_3:
-    _target_: transformers.AutoModelForCausalLM.from_pretrained
-    pretrained_model_name_or_path: uukuguy/speechless-code-mistral-7b-v1.0
+  _pretrained_: mistralai/Mistral-7B-v0.1
+  expert_1: meta-math/MetaMath-Mistral-7B
+  expert_2: cognitivecomputations/dolphin-2.1-mistral-7b
+  expert_3: uukuguy/speechless-code-mistral-7b-v1.0
 model_kwargs:
-  torch_dtype: float16
-tokenizer:
-  _target_: transformers.AutoTokenizer.from_pretrained
-  pretrained_model_name_or_path: mistralai/Mistral-7B-v0.1
+  torch_dtype: bfloat16
+tokenizer: mistralai/Mistral-7B-v0.1
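Besides flattening the model entries, this hunk switches `torch_dtype` from `float16` to `bfloat16`. The practical difference is range rather than just precision: bf16 keeps float32's 8-bit exponent, so magnitudes that overflow fp16 survive. A two-line illustration:

import torch

x = torch.tensor(70000.0)
print(x.to(torch.float16))   # tensor(inf): fp16 tops out near 65504
print(x.to(torch.bfloat16))  # tensor(70144.): coarser mantissa, no overflow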
fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml (new file)
@@ -0,0 +1,8 @@
+_target_: fusion_bench.modelpool.CausalLMPool
+_recursive_: false
+models:
+  _pretrained_: meta-llama/Llama-2-7b-hf
+  finetuned_model: lmsys/vicuna-7b-v1.5
+model_kwargs:
+  torch_dtype: bfloat16
+tokenizer: ${.models.finetuned_model}
fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml
@@ -1,4 +1,4 @@
-_target_: fusion_bench.modelpool.SeqenceClassificationModelPool
+_target_: fusion_bench.modelpool.SequenceClassificationModelPool
 pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
 models:
   _pretrained_:
fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml
@@ -1,4 +1,4 @@
-_target_: fusion_bench.modelpool.SeqenceClassificationModelPool
+_target_: fusion_bench.modelpool.SequenceClassificationModelPool
 pretrained_model_name_or_path: fusion-bench/Llama-3.2-1B-Instruct_Bradly-Terry-RM_Preference-700k
 models:
   _pretrained_:
fusion_bench_config/nyuv2_config.yaml
@@ -1,8 +1,10 @@
 defaults:
   - hydra: default
   - fabric: auto
-  - modelpool: nyuv2_modelpool
+  - path: default
+  # --- Model, Method, Task ---
   - method: simple_average
+  - modelpool: nyuv2_modelpool
   - taskpool: nyuv2_taskpool
   - _self_
 _target_: fusion_bench.programs.FabricModelFusionProgram
fusion_bench_config/nyuv2_mtl_train.yaml
@@ -1,5 +1,6 @@
 defaults:
   - hydra: default
+  - path: default
   - _self_
 fast_dev_run: false
 exp_name: null
fusion_bench_config/path/default.yaml (new file)
@@ -0,0 +1,28 @@
+# =============================================================================
+# FusionBench Path Configuration
+# =============================================================================
+# This configuration file defines the directory structure and path settings
+# used throughout the FusionBench framework for model fusion experiments.
+# All paths are configured using Hydra's variable interpolation syntax.
+# Root directory - uses FUSION_BENCH_PROJECT_ROOT env var or current directory
+#
+# By default:
+#
+# root_dir (defaults to current directory)
+# ├── outputs (output_dir)
+# │   ├── cache (cache_dir)
+# │   └── <config_name>
+# │       └── <timestamp> (log_dir)
+# └── data (data_dir)
+#
+root_dir: ${oc.env:FUSION_BENCH_PROJECT_ROOT,"."}
+# Output directory for experiment results and artifacts
+output_dir: ${.root_dir}/outputs
+# Data directory - uses FUSION_BENCH_DATA_DIR env var or root_dir/data
+data_dir: ${oc.env:FUSION_BENCH_DATA_DIR,${.root_dir}/data}
+# Cache directory - uses FUSION_BENCH_CACHE_DIR env var or output_dir/cache
+cache_dir: ${oc.env:FUSION_BENCH_CACHE_DIR,${.output_dir}/cache}
+# Log directory with timestamped subdirectories for each run
+log_dir: ${.output_dir}/${hydra:job.config_name}/${now:%Y-%m-%d_%H-%M-%S}
+# Current working directory at runtime
+work_dir: ${hydra:runtime.cwd}
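The `${oc.env:VAR,default}` entries rely on OmegaConf's built-in `oc.env` resolver; the `hydra:` and `now:` resolvers only exist inside a running Hydra app, so they are left out of this standalone check of the fallback behavior (assuming the env vars start out unset):

import os
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    root_dir: ${oc.env:FUSION_BENCH_PROJECT_ROOT,"."}
    output_dir: ${.root_dir}/outputs
    data_dir: ${oc.env:FUSION_BENCH_DATA_DIR,${.root_dir}/data}
    """
)
print(cfg.data_dir)  # ./data (falls back to the root_dir-based default)
os.environ["FUSION_BENCH_DATA_DIR"] = "/mnt/datasets"
print(cfg.data_dir)  # /mnt/datasets (the env var takes precedence)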
fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml (new file)
@@ -0,0 +1,24 @@
+defaults:
+  - /dataset/image_classification/test@test_datasets:
+      - svhn
+      - mnist
+_target_: fusion_bench.taskpool.CLIPVisionModelTaskPool
+_recursive_: false
+test_datasets: ??? # The datasets to evaluate the model on
+base_model: openai/clip-vit-base-patch32
+clip_model: ${.base_model} # The base model to use
+processor: ${.base_model} # The base model to use
+data_processor: ${.processor}
+dataloader_kwargs:
+  batch_size: 128 # The batch size for the data loader
+  num_workers: 8 # The number of worker processes for data loading
+  pin_memory: True # Whether to pin memory in data loader
+  drop_last: False # Whether to drop the last incomplete batch
+  shuffle: False # Whether to shuffle the data
+# === layer-wise feature saving ===
+# The path to save the features to, if none then the features are not saved
+# This is the path to a directory, the features of task `task_name` will be saved in `feature_save_path/task_name.csv`
+layer_wise_feature_save_path: null
+layer_wise_feature_first_token_only: true # Whether to save only the first token of the features
+# The maximum number of samples to save the features for
+layer_wise_feature_max_num: 1000
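`test_datasets: ???` is OmegaConf's mandatory-value marker; here it is populated by the `@test_datasets` package override in the defaults list, and accessing it unfilled fails loudly rather than silently evaluating an empty taskpool:

from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.create({"test_datasets": "???"})
try:
    _ = cfg.test_datasets
except MissingMandatoryValue:
    print("test_datasets must come from the defaults list or an override")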
fusion_bench_config/method/adamerging.yaml (deleted)
@@ -1,23 +0,0 @@
-# this option can be one of "clip_task_wise_adamerging" or "clip_layer_wise_adamerging"
-name: clip_layer_wise_adamerging
-# this weights can be a list of float, or a string that points to a *.np, *.pt file containing the weights
-# if weights is specified, skip the test-time adaptation training
-weights: null
-# learning rate
-optimizer: adam
-lr: 1e-3
-init_values: 0.3
-# if `clamp_weights` is true, the weights will be clamped to [0, 1]
-clamp_weights: false
-# arguments of `functional_call`
-tie_weights: true
-strict: false
-# this is overrided by `fabric.devices` if launched from the `fusion_bench` CLI.
-devices: 1
-batch_size: 16
-num_workers: 8
-max_steps: 1000
-fast_dev_run: ${fast_dev_run}
-# the path for saving the merging weights
-save_merging_weights: 'merging_weights.pt'
-cache_dir: outputs
fusion_bench_config/modelpool/mixtral_moe_merging.yaml (deleted)
@@ -1,14 +0,0 @@
-type: AutoModelForCausalLMPool
-# each model should have a name and a path, and the model is loaded from the path
-# this is equivalent to `AutoModelForCausalLM.from_pretrained(path)`
-models:
-  - name: _pretrained_
-    path: path_to_your_pretrained_model
-  - name: expert_1
-    path: path_to_your_expert_model_1
-  - name: expert_2
-    path: path_to_your_expert_model_2
-  - name: expert_3
-    path: path_to_your_expert_model_3
-  - name: expert_4
-    path: path_to_your_expert_model_4
fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml (deleted)
@@ -1,6 +0,0 @@
-type: AutoModelForCausalLMPool
-# each model should have a name and a path, and the model is loaded from the path
-# this is equivalent to `AutoModelForCausalLM.from_pretrained(path)`
-models:
-  - name: _pretrained_
-    path: path_to_your_pretrained_model
fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml (deleted)
@@ -1,22 +0,0 @@
-type: clip_vit_classification
-name: clip-vit-base-patch32_svhn_and_mnist # whatever you like
-dataset_type: huggingface_image_classification
-tasks:
-  - name: svhn
-    dataset:
-      type: instantiate
-      name: svhn
-      object:
-        _target_: datasets.load_dataset
-        _args_:
-          - svhn
-          - cropped_digits
-        split: test
-  - name: mnist
-    dataset:
-      name: mnist
-      split: test
-clip_model: openai/clip-vit-base-patch32
-batch_size: 128
-num_workers: 16
-fast_dev_run: ${fast_dev_run}