fusion-bench 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199)
  1. fusion_bench/compat/method/__init__.py +3 -1
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/gsm8k.py +2 -2
  6. fusion_bench/method/__init__.py +12 -2
  7. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  8. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  9. fusion_bench/method/doge_ta/__init__.py +2 -0
  10. fusion_bench/method/{DOGE_TA → doge_ta}/clip_layer_wise_adamerging.py +1 -1
  11. fusion_bench/method/{DOGE_TA/DOGE_TA.py → doge_ta/doge_ta.py} +1 -1
  12. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  13. fusion_bench/method/gossip/__init__.py +3 -0
  14. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  15. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  16. fusion_bench/method/gossip/entropy_loss.py +25 -0
  17. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  18. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  19. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  20. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  21. fusion_bench/method/gossip/utils.py +74 -0
  22. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  23. fusion_bench/method/opcm/opcm.py +102 -84
  24. fusion_bench/method/opcm/task_arithmetic.py +35 -21
  25. fusion_bench/method/opcm/ties_merging.py +71 -52
  26. fusion_bench/method/pwe_moe/module.py +1 -1
  27. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  28. fusion_bench/method/regmean/regmean.py +25 -17
  29. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  30. fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
  31. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  32. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  33. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  34. fusion_bench/method/we_moe/we_moe.py +14 -15
  35. fusion_bench/mixins/__init__.py +6 -3
  36. fusion_bench/mixins/hydra_config.py +49 -0
  37. fusion_bench/mixins/openclip_classification.py +11 -0
  38. fusion_bench/mixins/simple_profiler.py +4 -2
  39. fusion_bench/modelpool/__init__.py +3 -1
  40. fusion_bench/modelpool/base_pool.py +2 -2
  41. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  42. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  43. fusion_bench/models/open_clip/__init__.py +6 -0
  44. fusion_bench/models/open_clip/modeling.py +176 -0
  45. fusion_bench/models/open_clip/utils.py +311 -0
  46. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  47. fusion_bench/models/parameter_dict.py +54 -13
  48. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
  49. fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +4 -119
  50. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  51. fusion_bench/taskpool/__init__.py +5 -3
  52. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  53. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  54. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  55. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  56. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  57. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  58. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  59. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  60. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  61. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  62. fusion_bench/utils/data.py +12 -0
  63. fusion_bench/utils/devices.py +14 -0
  64. fusion_bench/utils/instantiate.py +12 -0
  65. fusion_bench/utils/misc.py +9 -2
  66. fusion_bench/utils/packages.py +14 -0
  67. fusion_bench/utils/parameters.py +1 -1
  68. fusion_bench/utils/tensorboard.py +1 -1
  69. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +15 -2
  70. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +198 -158
  71. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
  72. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  73. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  74. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  75. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  76. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  77. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  78. fusion_bench_config/fabric/auto.yaml +0 -1
  79. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  80. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  81. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  82. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  83. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  84. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  85. fusion_bench_config/llama_full_finetune.yaml +0 -2
  86. fusion_bench_config/llama_model_fusion.yaml +0 -2
  87. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  88. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  89. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  90. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  91. fusion_bench_config/method/adamerging.yaml +2 -2
  92. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  93. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  94. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  95. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  96. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  97. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  98. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  99. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  100. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  101. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  102. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  103. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  104. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  105. fusion_bench_config/method/{DOGE_TA/DOGE_TA.yaml → doge_ta/doge_ta.yaml} +1 -1
  106. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  107. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  108. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  109. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  110. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  111. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  112. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  113. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  114. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  115. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  116. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  117. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  118. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  119. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  120. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  121. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  122. fusion_bench_config/method/model_recombination.yaml +0 -1
  123. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  124. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  125. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  126. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  127. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  128. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  129. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  130. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  131. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  132. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  133. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  134. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  135. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  136. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  137. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  138. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  139. fusion_bench_config/method/ties_merging.yaml +1 -1
  140. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  141. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  142. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  143. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  144. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  145. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  146. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  147. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  148. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  149. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  150. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  151. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  152. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  161. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  162. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  163. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  164. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  165. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  166. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  167. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  168. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  169. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  170. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  171. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  172. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  173. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  174. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  175. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  176. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  177. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  178. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -10
  179. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +66 -0
  180. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  181. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  182. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  183. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  184. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  185. fusion_bench_config/nyuv2_config.yaml +0 -2
  186. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  187. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  188. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  189. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  190. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  191. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  192. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  193. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  194. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  195. fusion_bench/method/DOGE_TA/__init__.py +0 -2
  196. /fusion_bench/method/{DOGE_TA → doge_ta}/layer_wise_adamerging.py +0 -0
  197. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
  198. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info/licenses}/LICENSE +0 -0
  199. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml
@@ -1,12 +1,10 @@
 _target_: fusion_bench.method.FullFinetuneSFT
 _recursive_: False
-
 optimizer:
   _target_: torch.optim.AdamW
   lr: 1e-5
   weight_decay: 0.01
   fused: null
-
 lr_scheduler:
   _target_: fusion_bench.optim.lr_scheduler.CosineDecayWithWarmup
   T_max: _T_max_ # this will be replaced by the expected number of training steps
@@ -14,13 +12,11 @@ lr_scheduler:
   warmup_steps: 100
   max_lr: ${..optimizer.lr}
   min_lr: 1e-6
-
 dataloader_kwargs:
   # per-gpu batch size
   batch_size: 1
   num_workers: 0
   pin_memory: True
-
 # Training hyperparameters
 # if max_epochs=-1, max_steps will be used to determine the number of training steps
 max_epochs: 3
fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml
@@ -1,23 +1,19 @@
 _target_: fusion_bench.method.PeftFinetuneSFT
 _recursive_: False
-
 optimizer:
   _target_: torch.optim.AdamW
   lr: 1e-4
   weight_decay: 0.01
   fused: null
-
 lr_scheduler:
   _target_: torch.optim.lr_scheduler.CosineAnnealingLR
   T_max: _T_max_ # this will be replaced by the expected number of training steps
   eta_min: 1e-6
-
 dataloader_kwargs:
   # per-gpu batch size
   batch_size: 1
   num_workers: 0
   pin_memory: True
-
 peft_config:
   _target_: peft.LoraConfig
   task_type: peft.TaskType.CAUSAL_LM
@@ -33,11 +29,9 @@ peft_config:
   lora_alpha: 16
   lora_dropout: 0
   bias: none
-
 adapter_name: default
 # whether to merge and unload the adapter after training
 merge_and_unload: false
-
 # Training hyperparameters
 # if max_epochs=-1, max_steps will be used to determine the number of training steps
 max_epochs: 3
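For context, a minimal sketch of what this `peft_config` block resolves to at runtime, using the standard `peft` APIs that the `adapter_name` and `merge_and_unload` keys correspond to. The rank `r=16` is an assumption (this hunk starts below the rank line); the other values mirror the config:

```python
# Sketch only: mirrors the peft_config block above, not code from fusion-bench.
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,              # assumed rank; not visible in this hunk
    lora_alpha=16,
    lora_dropout=0.0,
    bias="none",
)

base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
model = get_peft_model(base, lora_config, adapter_name="default")

# merge_and_unload: false leaves the adapter attached after training;
# with true, the LoRA weights would be folded back into the base model:
#   model = model.merge_and_unload()
```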
fusion_bench_config/method/mixtral_moe_upscaling.yaml
@@ -1,6 +1,5 @@
 # or fusion_bench.method.MixtralUpscalingAlgorithm
-_target_: fusion_bench.method.MixtralForCausalLMUpscalingAlgorithm
-
+_target_: fusion_bench.method.MixtralForCausalLMUpscalingAlgorithm
 num_experts: 4
 experts_per_token: 2
 # path to save the upscaled model
fusion_bench_config/method/model_recombination.yaml
@@ -1,4 +1,3 @@
 _target_: fusion_bench.method.ModelRecombinationAlgorithm
-
 # if `return_model_pool` is not null, the argument `return_modelpool` passed to the `run` method will be ignored.
 return_modelpool: null
fusion_bench_config/method/opcm/opcm.yaml
@@ -1,5 +1,4 @@
 _target_: fusion_bench.method.opcm.opcm.OPCMForCLIP
-
 # shuffle the order of the models
 shuffle_order: true
 # the scaling factor for the SVD projection
fusion_bench_config/method/opcm/task_arithmetic.yaml
@@ -1,7 +1,5 @@
 _target_: fusion_bench.method.opcm.task_arithmetic.ContinualTaskArithmeticForCLIP
-
 scaling_factor: 0.3
-
 # shuffle the order of the models
 shuffle_order: true
 # the random seed to use
fusion_bench_config/method/opcm/ties_merging.yaml
@@ -1,5 +1,4 @@
 _target_: fusion_bench.method.opcm.ties_merging.ContinualTiesMergingForCLIP
-
 # Scaling factor $\lambda$
 scaling_factor: 0.5
 threshold: 20
@@ -7,7 +6,6 @@ threshold: 20
 remove_keys: []
 # Function to merge the models, default is sum. Options are 'sum', 'mean', and 'max'
 merge_func: sum
-
 # shuffle the order of the models
 shuffle_order: true
 # the random seed to use
fusion_bench_config/method/opcm/weight_average.yaml
@@ -1,5 +1,4 @@
 _target_: fusion_bench.method.opcm.weight_average.ContinualWeightAverageForCLIP
-
 # shuffle the order of the models
 shuffle_order: true
 # the random seed to use
fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml
@@ -0,0 +1,30 @@
+_target_: fusion_bench.method.pwe_moe.openclip_pwe_moe.PWEMoEExactParetoOptimalForOpenCLIP
+#! === Model Architecture Arguments ===
+# if true, then we only apply the weight ensembling MoE to MLPs, else, we apply it to all layers
+partial: true
+# weight-ensembling MoE arguments
+# initial outputs for the routing gates and the merging weights for the remaining layers
+init_lambda: 0.3
+# number of hidden layers in the routing gate
+router_hidden_layers: 2
+# path to the checkpoint file, if not provided, then the training is performed
+checkpoint_path: null
+#! === Training Arguments ===
+# if false, the training is skipped
+run_train: true
+num_steps: 2000
+save_interval: 1000
+# learning rate
+lr: 1e-2
+alpha: 1 # alpha for dirichlet, if alpha=1, then it is uniform
+# dataloader arguments
+dataloader_kwargs:
+  # per-device batch size
+  batch_size: 16
+  num_workers: 0
+#! === Evaluation Arguments ===
+# if false, the evaluation is skipped
+run_eval: false
+# if true, then we only evaluate the model on the first 20 batches of the test dataset
+quick_evaluation: false
+num_evaluation_samples: equal_weight
fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml
@@ -0,0 +1,30 @@
+_target_: fusion_bench.method.pwe_moe.openclip_pwe_moe.PWEMoELinearScalarizationForOpenCLIP
+#! === Model Architecture Arguments ===
+# if true, then we only apply the weight ensembling MoE to MLPs, else, we apply it to all layers
+partial: true
+# weight-ensembling MoE arguments
+# initial outputs for the routing gates and the merging weights for the remaining layers
+init_lambda: 0.3
+# number of hidden layers in the routing gate
+router_hidden_layers: 2
+# path to the checkpoint file, if not provided, then the training is performed
+checkpoint_path: null
+#! === Training Arguments ===
+# if false, the training is skipped
+run_train: true
+num_steps: 2000
+save_interval: 1000
+# learning rate
+lr: 1e-2
+alpha: 1 # alpha for dirichlet, if alpha=1, then it is uniform
+# dataloader arguments
+dataloader_kwargs:
+  # per-device batch size
+  batch_size: 16
+  num_workers: 0
+#! === Evaluation Arguments ===
+# if false, the evaluation is skipped
+run_eval: false
+# if true, then we only evaluate the model on the first 20 batches of the test dataset
+quick_evaluation: false
+num_evaluation_samples: equal_weight
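These two new method configs are plain Hydra configs: `_target_` names the algorithm class and the remaining keys become its constructor arguments. A minimal sketch of how such a file is consumed, using standard `hydra`/`omegaconf` calls rather than code from the package:

```python
# Minimal sketch: load one of the new configs and build the algorithm object.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml")
# Roughly equivalent to PWEMoELinearScalarizationForOpenCLIP(
#     partial=True, init_lambda=0.3, router_hidden_layers=2, ..., run_eval=False)
algorithm = instantiate(cfg)
```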
fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml}
@@ -1,22 +1,23 @@
 _target_: fusion_bench.method.PWEMoELinearScalarizationForCLIP # or PWEMoExactParetoOptimalForCLIP
+#! === Model Architecture Arguments ===
 upscale_mlp: true
 upscale_attn: true
 # scaling factor for the remaining parameters
 init_lambda: 0.3
 router_hidden_layers: 2
+#! === Training Arguments ===
 lr: 1e-5
 num_steps: 8000
 save_interval: 2000
 alpha: 1 # alpha for dirichlet, if alpha=1, then it is uniform
 # load model from this checkpoint
+dataloader_kwargs:
+  # per-device batch size
+  batch_size: 16
+  num_workers: 4
 checkpoint_path: null
-
+#! === Evaluation Arguments ===
 # evaluation grid
 eval_grid: true
 eval_grid_n: 8
 eval_grid_m: 2
-
-dataloader_kwargs:
-  # per-device batch size
-  batch_size: 16
-  num_workers: 4
fusion_bench_config/method/rankone_moe/rankone_moe.yaml
@@ -6,12 +6,10 @@ save_checkpoint: False
 router_hidden_layers: 1
 init_lambda: 0.3
 batch_reduce: true
-
 # device to compute svd
 svd_accelerator: cuda
 rank_k: 32 # How many experts are added to the pool per task?
-select_k: -1 # How many experts are selected from the pool to merge? Range is (1, rank_k*task_num). In particular -1: All the experts in the pool
-
+select_k: -1 # How many experts are selected from the pool to merge? Range is (1, rank_k*task_num). In particular -1: All the experts in the pool
 # learning rate
 lr: 1e-4
 optimizer: adam
fusion_bench_config/method/regmean/gpt2_regmean.yaml
@@ -1,5 +1,4 @@
 _target_: fusion_bench.method.RegMeanAlgorithmForGPT2
-
 # list, regular expression of names of parameters that need to be excluded
 exclude_param_names_regex: []
 # numbers of examples to compute regmean weights
fusion_bench_config/method/slerp/slerp.yaml
@@ -1,6 +1,4 @@
 _target_: fusion_bench.method.SlerpMergeAlgorithm
-
 t: 0.5 # interpolation factor
-
 DOT_THRESHOLD: 0.9995
 epsilon: 1e-8
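For reference, the textbook spherical linear interpolation that these three hyperparameters drive: `t` is the interpolation position, `DOT_THRESHOLD` decides when the two weight vectors are near-colinear enough to fall back to plain linear interpolation, and `epsilon` guards the norm division. A sketch of the standard formulation, not necessarily the package's exact implementation:

```python
# Textbook slerp over flattened weight vectors; illustrative only.
import numpy as np

def slerp(t: float, v0: np.ndarray, v1: np.ndarray,
          DOT_THRESHOLD: float = 0.9995, epsilon: float = 1e-8) -> np.ndarray:
    v0n = v0 / (np.linalg.norm(v0) + epsilon)   # epsilon guards zero norms
    v1n = v1 / (np.linalg.norm(v1) + epsilon)
    dot = float(np.sum(v0n * v1n))
    if abs(dot) > DOT_THRESHOLD:
        # nearly colinear: spherical and linear interpolation coincide
        return (1.0 - t) * v0 + t * v1
    theta_0 = np.arccos(dot)                    # angle between the inputs
    theta_t = theta_0 * t
    s0 = np.sin(theta_0 - theta_t) / np.sin(theta_0)
    s1 = np.sin(theta_t) / np.sin(theta_0)
    return s0 * v0 + s1 * v1
```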
fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml
@@ -17,4 +17,4 @@ sparsity_ratio: 0.5
 n: 2
 m: 4
 # string to specify the path to where the pruned model is saved
-model_save_path: null
+model_save_path: null

fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml
@@ -17,4 +17,4 @@ sparsity_ratio: 0.5
 n: 2
 m: 4
 # string to specify the path to where the pruned model is saved
-model_save_path: null
+model_save_path: null

fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml
@@ -16,4 +16,4 @@ sparsity_ratio: 0.5
 n: 2
 m: 4
 # string to specify the path to where the pruned model is saved
-model_save_path: null
+model_save_path: null
fusion_bench_config/method/surgery/adamerging_surgery.yaml
@@ -21,7 +21,6 @@ fast_dev_run: ${fast_dev_run}
 # the path for saving the merging weights
 save_merging_weights: 'merging_weights.pt'
 cache_dir: outputs
-
 # parameters of Surgery
 eval_iterations: 200
-surgery_steps: 1000
+surgery_steps: 1000
fusion_bench_config/method/task_arithmetic.yaml
@@ -1,2 +1,2 @@
 _target_: fusion_bench.method.TaskArithmeticAlgorithm
-scaling_factor: 0.5
+scaling_factor: 0.3
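The default `scaling_factor` drops from 0.5 to 0.3. In task arithmetic this is the λ in θ_merged = θ_pre + λ · Σ_i (θ_i − θ_pre). A simplified per-tensor sketch of that formula over state dicts (illustrative names, not the package's internal API):

```python
# Illustrative sketch of task arithmetic over model state dicts.
from typing import Dict, List
import torch

def task_arithmetic(
    pretrained: Dict[str, torch.Tensor],
    finetuned: List[Dict[str, torch.Tensor]],
    scaling_factor: float = 0.3,  # the config key changed in this hunk
) -> Dict[str, torch.Tensor]:
    merged = {}
    for name, theta0 in pretrained.items():
        # sum of task vectors (finetuned minus pretrained), then scale and add back
        task_vector_sum = sum(ft[name] - theta0 for ft in finetuned)
        merged[name] = theta0 + scaling_factor * task_vector_sum
    return merged
```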
fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml
@@ -1,6 +1,5 @@
 _target_: fusion_bench.method.TaskSingularVectorMerging
 remove_keys: null
-
 # alpha is a float or a list of floats
 # example:
 #   alpha: 1
fusion_bench_config/method/ties_merging.yaml
@@ -1,6 +1,6 @@
 _target_: fusion_bench.method.TiesMergingAlgorithm
 # Scaling factor $\lambda$
-scaling_factor: 0.5
+scaling_factor: 0.3
 threshold: 20
 # List of keys to remove from the state dict, default is empty
 remove_keys: []
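TIES merging gets the same default change (λ: 0.5 → 0.3), but applies λ after the extra steps the other keys control: `threshold: 20` trims each task vector to its largest-magnitude 20% of entries, a majority sign is elected per parameter, and `merge_func` (here `sum`) combines the agreeing entries. A simplified per-tensor sketch under those assumptions, not the package's implementation:

```python
# Compact, illustrative TIES sketch: trim, elect sign, merge, scale.
# Returns the merged task vector to be added to the pretrained weights.
from typing import List
import torch

def ties_merge(task_vectors: List[torch.Tensor], threshold_pct: float = 20.0,
               scaling_factor: float = 0.3) -> torch.Tensor:
    trimmed = []
    for tv in task_vectors:
        # keep only the top threshold_pct% entries by magnitude
        k = max(1, int(tv.numel() * threshold_pct / 100))
        cutoff = tv.abs().flatten().kthvalue(tv.numel() - k + 1).values
        trimmed.append(torch.where(tv.abs() >= cutoff, tv, torch.zeros_like(tv)))
    stacked = torch.stack(trimmed)
    # elect the dominant sign per parameter, drop disagreeing entries
    sign = torch.sign(stacked.sum(dim=0))
    agree = torch.where(torch.sign(stacked) == sign, stacked, torch.zeros_like(stacked))
    return scaling_factor * agree.sum(dim=0)  # merge_func: sum
```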
fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml
@@ -1,5 +1,4 @@
 _target_: fusion_bench.method.trust_region.clip_task_arithmetic.TaskArithmeticWithTrustRegionForCLIP
-
 scaling_factor: 0.3
 threshold_quantile: 0.99
 max_samples: 128
fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml
@@ -1,32 +1,25 @@
 name: ??? # this can be sparse_clip_weight_ensembling_moe
-
 # the path for loading the model weights, if specified, skip the test-time adaptation training
 #checkpoint: /home/enneng/fusion_bench/outputs/sparse_we_moe/shared_gate/1routerlayer_clip-vit-base-patch32_TA8_sparse_weight_ensembling_moe_checkpoint_0_0.ckpt
 checkpoint: False
 # the path for saving the model weights.
 save_checkpoint: False
-
 # router
 router_hidden_layers: 2
 init_lambda: 0.3
 batch_reduce: true
-
 # sparse task vectors
 tv_prune_ratio: 0.9
-
 # sparse gate module
 post_sparse_gate: False
 gate_prune_ratio: 0.0
-
 # shared gate
 shared_gate: true
 position_encoding: false
 position_encoding_dim: 8
-
 # tta learning rate
 lr: 1e-4
 optimizer: adam
-
 # this is overrided by `fabric.devices` if launched from the `fusion_bench` CLI.
 devices: 1
 batch_size: 16
@@ -34,6 +27,5 @@ num_workers: 16
 max_steps: 1000
 # if true, we will use the gradient accumulation across tasks to save memory
 use_grad_accumulate: false
-
 cache_dir: outputs
 fast_dev_run: ${fast_dev_run}
fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml
@@ -1 +1 @@
-cifar10: tanganke/clip-vit-base-patch16_cifar10
+cifar10: tanganke/clip-vit-base-patch16_cifar10

fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml
@@ -1 +1 @@
-_pretrained_: openai/clip-vit-large-patch14
+_pretrained_: openai/clip-vit-large-patch14

fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml
@@ -1 +1 @@
-oxford-iiit-pet: tanganke/clip-vit-large-patch14_oxford-iiit-pet
+oxford-iiit-pet: tanganke/clip-vit-large-patch14_oxford-iiit-pet

fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml
@@ -1 +1 @@
-oxford_flowers102: tanganke/clip-vit-large-patch14_oxford_flowers102
+oxford_flowers102: tanganke/clip-vit-large-patch14_oxford_flowers102

fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml
@@ -1 +1 @@
-pcam: tanganke/clip-vit-large-patch14_pcam
+pcam: tanganke/clip-vit-large-patch14_pcam

fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml
@@ -1 +1 @@
-rendered-sst2: tanganke/clip-vit-large-patch14_rendered-sst2
+rendered-sst2: tanganke/clip-vit-large-patch14_rendered-sst2

fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml
@@ -1 +1 @@
-resisc45: tanganke/clip-vit-large-patch14_resisc45
+resisc45: tanganke/clip-vit-large-patch14_resisc45

fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml
@@ -1 +1 @@
-stanford-cars: tanganke/clip-vit-large-patch14_stanford-cars
+stanford-cars: tanganke/clip-vit-large-patch14_stanford-cars

fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml
@@ -1 +1 @@
-stl10: tanganke/clip-vit-large-patch14_stl10
+stl10: tanganke/clip-vit-large-patch14_stl10

fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml
@@ -1 +1 @@
-sun397: tanganke/clip-vit-large-patch14_sun397
+sun397: tanganke/clip-vit-large-patch14_sun397

fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml
@@ -1 +1 @@
-svhn: tanganke/clip-vit-large-patch14_svhn
+svhn: tanganke/clip-vit-large-patch14_svhn
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml
@@ -1,5 +1,4 @@
 _target_: fusion_bench.modelpool.CLIPVisionModelPool
-
 models:
   _pretrained_:
     _target_: fusion_bench.models.linearized.vision_model.load_fft_vision_model_hf
@@ -44,10 +43,8 @@ models:
     base_model_name: openai/clip-vit-base-patch16
     peft_name: tanganke/clip-vit-base-patch16_dtd_lora-16
     merge_and_unload: true
-
 processor:
   _target_: transformers.CLIPProcessor.from_pretrained
   pretrained_model_name_or_path: openai/clip-vit-base-patch16
-
 train_datasets: null
 test_datasets: null
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml
@@ -6,14 +6,11 @@
 # ...
 defaults:
   - CLIPVisionModelPool@: _template
-
 models:
   _pretrained_:
     _target_: transformers.CLIPVisionModel.from_pretrained
     pretrained_model_name_or_path: ${...base_model}
-
 processor:
   _target_: transformers.CLIPProcessor.from_pretrained
   pretrained_model_name_or_path: ${..base_model}
-
 base_model: openai/clip-vit-base-patch16
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml
@@ -1,14 +1,11 @@
 _target_: fusion_bench.modelpool.CLIPVisionModelPool
-
 models:
   sun397:
     _target_: fusion_bench.models.linearized.vision_model.load_lora_vision_model_hf
     base_model_name: openai/clip-vit-base-patch16
     peft_name: tanganke/clip-vit-base-patch16_sun397_lora-16
-
 processor:
   _target_: transformers.CLIPProcessor.from_pretrained
   pretrained_model_name_or_path: openai/clip-vit-base-patch16
-
 train_datasets: null
 test_datasets: null
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml
@@ -12,13 +12,10 @@ defaults:
   - clip-vit-base-patch32_dtd
   - /dataset/image_classification/train@train_datasets:
       - tiny-imagenet
-
 _target_: fusion_bench.modelpool.CLIPVisionModelPool
 _recursive_: false
-
 models: ???
 train_datasets: ???
-
 processor:
   _target_: transformers.CLIPProcessor.from_pretrained
   pretrained_model_name_or_path: openai/clip-vit-base-patch32
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml
@@ -1,13 +1,10 @@
 defaults:
   - CLIPVisionModelPool@: _template
-
 models:
   _pretrained_:
     _target_: transformers.CLIPVisionModel.from_pretrained
     pretrained_model_name_or_path: ${...base_model}
-
 processor:
   _target_: transformers.CLIPProcessor.from_pretrained
   pretrained_model_name_or_path: ${..base_model}
-
 base_model: openai/clip-vit-base-patch32
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml
@@ -3,13 +3,10 @@ defaults:
   - clip-vit-base-patch32
   - clip-vit-base-patch32_sun397
   - clip-vit-base-patch32_stanford-cars
-
 _target_: fusion_bench.modelpool.CLIPVisionModelPool
 _recursive_: false
-
 train_datasets: null
 test_datasets: null
-
 processor:
   _target_: transformers.CLIPProcessor.from_pretrained
   pretrained_model_name_or_path: openai/clip-vit-base-patch32
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml
@@ -2,17 +2,13 @@ defaults:
   - _self_
   - /dataset/image_classification/train@train_datasets:
       - tiny-imagenet
-
 _target_: fusion_bench.modelpool.CLIPVisionModelPool
 _recursive_: false
-
 models:
   _pretrained_: openai/clip-vit-base-patch32
   model_1: tanganke/clip-vit-base-patch32_sun397
   model_2: tanganke/clip-vit-base-patch32_stanford-cars
-
 train_datasets: ???
-
 processor:
   _target_: transformers.CLIPProcessor.from_pretrained
   pretrained_model_name_or_path: openai/clip-vit-base-patch32
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml
@@ -6,14 +6,11 @@
 # ...
 defaults:
   - CLIPVisionModelPool@: _template
-
 models:
   _pretrained_:
     _target_: transformers.CLIPVisionModel.from_pretrained
     pretrained_model_name_or_path: ${...base_model}
-
 processor:
   _target_: transformers.CLIPProcessor.from_pretrained
   pretrained_model_name_or_path: ${..base_model}
-
 base_model: openai/clip-vit-large-patch14
fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml
@@ -1,17 +1,13 @@
 _target_: fusion_bench.modelpool.CausalLMPool
-
 pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
-
 models:
   _pretrained_:
     _target_: transformers.AutoModelForCausalLM.from_pretrained
     pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
     torch_dtype: bfloat16
-
 tokenizer:
   _target_: transformers.AutoTokenizer.from_pretrained
   pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
-
 train_datasets:
   alpaca-cleaned:
     _target_: fusion_bench.dataset.llama.alpaca.load_tokenized_alpaca_dataset
fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml
@@ -1,17 +1,13 @@
 _target_: fusion_bench.modelpool.CausalLMPool
-
 pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
-
 models:
   _pretrained_:
     _target_: transformers.AutoModelForCausalLM.from_pretrained
     pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
     torch_dtype: bfloat16
-
 tokenizer:
   _target_: transformers.AutoTokenizer.from_pretrained
   pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
-
 train_datasets:
   codealpaca:
     _target_: fusion_bench.dataset.llama.alpaca.load_tokenized_alpaca_dataset
fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml
@@ -12,7 +12,6 @@ models:
   expert_2:
     _target_: transformers.LlamaForCausalLM.from_pretrained
     pretrained_model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
-
 model_kwargs:
   torch_dtype: float16
 tokenizer:
fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml
@@ -1,17 +1,13 @@
 _target_: fusion_bench.modelpool.CausalLMPool
-
 pretrained_model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
-
 models:
   _pretrained_:
     _target_: transformers.AutoModelForCausalLM.from_pretrained
     pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
     torch_dtype: bfloat16
-
 tokenizer:
   _target_: transformers.AutoTokenizer.from_pretrained
   pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
-
 train_datasets:
   metamathqa:
     _target_: fusion_bench.dataset.llama.metamathqa.load_tokenized_metamathqa
fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml
@@ -1,17 +1,13 @@
 _target_: fusion_bench.modelpool.CausalLMPool
-
 pretrained_model_name_or_path: meta-llama/Llama-3-1B-Instruct
-
 models:
   _pretrained_:
     _target_: transformers.AutoModelForCausalLM.from_pretrained
     pretrained_model_name_or_path: ${...pretrained_model_name_or_path}
     torch_dtype: bfloat16
-
 tokenizer:
   _target_: transformers.AutoTokenizer.from_pretrained
   pretrained_model_name_or_path: ${..pretrained_model_name_or_path}
-
 train_datasets:
   ultrachat-200k:
     _target_: fusion_bench.dataset.llama.ultrachat.load_tokenized_ultrachat_200k
fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml
@@ -13,7 +13,6 @@ models:
   expert_3:
     _target_: transformers.AutoModelForCausalLM.from_pretrained
     pretrained_model_name_or_path: uukuguy/speechless-code-mistral-7b-v1.0
-
 model_kwargs:
   torch_dtype: float16
 tokenizer:
fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml
@@ -6,12 +6,9 @@ models:
   _pretrained_:
     _target_: transformers.LlamaForCausalLM.from_pretrained
     pretrained_model_name_or_path: ${...base_model}
-
 model_kwargs:
   torch_dtype: float16
-
 tokenizer:
   _target_: transformers.AutoTokenizer.from_pretrained
   pretrained_model_name_or_path: ${..base_model}
-
 base_model: decapoda-research/llama-7b-hf