fusion-bench 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199)
  1. fusion_bench/compat/method/__init__.py +3 -1
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/gsm8k.py +2 -2
  6. fusion_bench/method/__init__.py +12 -2
  7. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  8. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  9. fusion_bench/method/doge_ta/__init__.py +2 -0
  10. fusion_bench/method/{DOGE_TA → doge_ta}/clip_layer_wise_adamerging.py +1 -1
  11. fusion_bench/method/{DOGE_TA/DOGE_TA.py → doge_ta/doge_ta.py} +1 -1
  12. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  13. fusion_bench/method/gossip/__init__.py +3 -0
  14. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  15. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  16. fusion_bench/method/gossip/entropy_loss.py +25 -0
  17. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  18. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  19. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  20. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  21. fusion_bench/method/gossip/utils.py +74 -0
  22. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  23. fusion_bench/method/opcm/opcm.py +102 -84
  24. fusion_bench/method/opcm/task_arithmetic.py +35 -21
  25. fusion_bench/method/opcm/ties_merging.py +71 -52
  26. fusion_bench/method/pwe_moe/module.py +1 -1
  27. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  28. fusion_bench/method/regmean/regmean.py +25 -17
  29. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  30. fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
  31. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  32. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  33. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  34. fusion_bench/method/we_moe/we_moe.py +14 -15
  35. fusion_bench/mixins/__init__.py +6 -3
  36. fusion_bench/mixins/hydra_config.py +49 -0
  37. fusion_bench/mixins/openclip_classification.py +11 -0
  38. fusion_bench/mixins/simple_profiler.py +4 -2
  39. fusion_bench/modelpool/__init__.py +3 -1
  40. fusion_bench/modelpool/base_pool.py +2 -2
  41. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  42. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  43. fusion_bench/models/open_clip/__init__.py +6 -0
  44. fusion_bench/models/open_clip/modeling.py +176 -0
  45. fusion_bench/models/open_clip/utils.py +311 -0
  46. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  47. fusion_bench/models/parameter_dict.py +54 -13
  48. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
  49. fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +4 -119
  50. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  51. fusion_bench/taskpool/__init__.py +5 -3
  52. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  53. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  54. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  55. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  56. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  57. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  58. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  59. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  60. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  61. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  62. fusion_bench/utils/data.py +12 -0
  63. fusion_bench/utils/devices.py +14 -0
  64. fusion_bench/utils/instantiate.py +12 -0
  65. fusion_bench/utils/misc.py +9 -2
  66. fusion_bench/utils/packages.py +14 -0
  67. fusion_bench/utils/parameters.py +1 -1
  68. fusion_bench/utils/tensorboard.py +1 -1
  69. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +15 -2
  70. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +198 -158
  71. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
  72. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  73. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  74. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  75. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  76. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  77. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  78. fusion_bench_config/fabric/auto.yaml +0 -1
  79. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  80. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  81. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  82. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  83. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  84. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  85. fusion_bench_config/llama_full_finetune.yaml +0 -2
  86. fusion_bench_config/llama_model_fusion.yaml +0 -2
  87. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  88. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  89. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  90. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  91. fusion_bench_config/method/adamerging.yaml +2 -2
  92. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  93. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  94. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  95. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  96. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  97. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  98. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  99. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  100. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  101. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  102. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  103. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  104. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  105. fusion_bench_config/method/{DOGE_TA/DOGE_TA.yaml → doge_ta/doge_ta.yaml} +1 -1
  106. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  107. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  108. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  109. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  110. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  111. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  112. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  113. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  114. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  115. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  116. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  117. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  118. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  119. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  120. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  121. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  122. fusion_bench_config/method/model_recombination.yaml +0 -1
  123. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  124. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  125. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  126. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  127. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  128. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  129. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  130. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  131. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  132. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  133. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  134. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  135. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  136. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  137. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  138. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  139. fusion_bench_config/method/ties_merging.yaml +1 -1
  140. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  141. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  142. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  143. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  144. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  145. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  146. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  147. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  148. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  149. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  150. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  151. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  152. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  161. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  162. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  163. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  164. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  165. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  166. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  167. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  168. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  169. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  170. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  171. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  172. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  173. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  174. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  175. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  176. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  177. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  178. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -10
  179. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +66 -0
  180. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  181. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  182. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  183. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  184. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  185. fusion_bench_config/nyuv2_config.yaml +0 -2
  186. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  187. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  188. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  189. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  190. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  191. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  192. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  193. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  194. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  195. fusion_bench/method/DOGE_TA/__init__.py +0 -2
  196. /fusion_bench/method/{DOGE_TA → doge_ta}/layer_wise_adamerging.py +0 -0
  197. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
  198. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info/licenses}/LICENSE +0 -0
  199. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
@@ -13,6 +13,7 @@ from torch import Tensor, nn
13
13
  from tqdm.autonotebook import tqdm
14
14
 
15
15
  from fusion_bench.method import BaseAlgorithm
16
+ from fusion_bench.mixins import SimpleProfilerMixin
16
17
  from fusion_bench.modelpool import BaseModelPool
17
18
 
18
19
  log = logging.getLogger(__name__)
@@ -279,7 +280,7 @@ def regmean_merging(
279
280
  return merged_params
280
281
 
281
282
 
282
- class RegMeanAlgorithm(BaseAlgorithm):
283
+ class RegMeanAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
283
284
  _include_module_type = [nn.Linear]
284
285
  _config_mapping = {
285
286
  "num_regmean_examples": "num_regmean_examples",
@@ -342,24 +343,31 @@ class RegMeanAlgorithm(BaseAlgorithm):
342
343
  )
343
344
  assert len(linear_modules_to_merge) > 0, "No linear modules to merge"
344
345
 
345
- regmean_weights = self.get_regmean_weights(
346
- name,
347
- model,
348
- train_dataset=modelpool.load_train_dataset(name),
349
- linear_modules_to_merge=linear_modules_to_merge,
350
- )
351
- models_to_merge_regmean_weights_list.append(regmean_weights)
346
+ with (
347
+ self.profile("merging models"),
348
+ self.profile("computing regmean weights"),
349
+ ):
350
+ regmean_weights = self.get_regmean_weights(
351
+ name,
352
+ model,
353
+ train_dataset=modelpool.load_train_dataset(name),
354
+ linear_modules_to_merge=linear_modules_to_merge,
355
+ )
356
+ models_to_merge_regmean_weights_list.append(regmean_weights)
357
+
358
+ with self.profile("merging models"):
359
+ # merging with regmean weights
360
+ merged_params = merging_with_regmean_weights(
361
+ models_to_merge_param_dict=models_to_merge_param_dict,
362
+ models_to_merge_regmean_weights_list=models_to_merge_regmean_weights_list,
363
+ reduce_non_diagonal_ratio=self.reduce_non_diagonal_ratio,
364
+ weight_transpose=self.config.get("weight_transpose", True),
365
+ )
352
366
 
353
- # merging with regmean weights
354
- merged_params = merging_with_regmean_weights(
355
- models_to_merge_param_dict=models_to_merge_param_dict,
356
- models_to_merge_regmean_weights_list=models_to_merge_regmean_weights_list,
357
- reduce_non_diagonal_ratio=self.reduce_non_diagonal_ratio,
358
- weight_transpose=self.config.get("weight_transpose", True),
359
- )
367
+ merged_model = modelpool.load_model("_pretrained_")
368
+ merged_model.load_state_dict(merged_params, strict=False)
360
369
 
361
- merged_model = modelpool.load_model("_pretrained_")
362
- merged_model.load_state_dict(merged_params, strict=False)
370
+ self.print_profile_summary()
363
371
  return merged_model
364
372
 
365
373
  def on_regmean_start(self):
@@ -1,3 +1,3 @@
1
1
  # flake8: noqa F401
2
2
  from .singular_projection_merging import SingularProjectionMergingAlgorithm
3
- from .smile_upscaling import SmileUpscalingAlgorithm
3
+ from .smile_upscaling import SmileMoELinear, SmileUpscalingAlgorithm
@@ -442,16 +442,19 @@ class SmileUpscalingAlgorithm(
442
442
  print_parameters(model)
443
443
  return model
444
444
 
445
- with self.profile("load pretrained model"):
446
- pretrained_model = modelpool.load_model("_pretrained_")
447
- with self.profile("load fine-tuned model"):
448
- finetuned_models = [
449
- m for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
450
- ]
451
-
452
- if self.config.device == "cuda" and torch.cuda.is_available():
453
- pretrained_model = pretrained_model.cuda()
454
- finetuned_models = [m.cuda() for m in finetuned_models]
445
+ with self.profile("loading model"):
446
+ # load models and move to GPU if available
447
+ with self.profile("load pretrained model"):
448
+ pretrained_model = modelpool.load_model("_pretrained_")
449
+ with self.profile("load fine-tuned model"):
450
+ finetuned_models = [
451
+ m
452
+ for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
453
+ ]
454
+
455
+ if self.config.device == "cuda" and torch.cuda.is_available():
456
+ pretrained_model = pretrained_model.cuda()
457
+ finetuned_models = [m.cuda() for m in finetuned_models]
455
458
 
456
459
  with self.profile("merge model"):
457
460
  model = self.merge(pretrained_model, finetuned_models)
@@ -85,7 +85,14 @@ class CLIPLayerWiseAdaMergingSurgeryAlgorithm(
85
85
 
86
86
  if self.config.weights is not None:
87
87
  # skip the test-time adaptation
88
+ merge_weight: torch.Tensor = torch.load(self.config.weights)
89
+ module.merge_weight.data = merge_weight.to(
90
+ device=module.merge_weight.device
91
+ )
88
92
  merged_model = copy.deepcopy(module.merge_and_unload())
93
+ # setup the zero-shot classification head
94
+ self.on_test_time_adaptation_start()
95
+
89
96
  else:
90
97
  with self.profile("test-time adaptation"):
91
98
  module = self.test_time_adaptation(module)
@@ -6,7 +6,7 @@ http://arxiv.org/abs/2212.04089
6
6
 
7
7
  import logging
8
8
  from copy import deepcopy
9
- from typing import Dict, List, Mapping, TypeVar, Union # noqa: F401
9
+ from typing import Dict, List, Mapping, Optional, TypeVar, Union # noqa: F401
10
10
 
11
11
  import torch
12
12
  from torch import nn
@@ -19,18 +19,18 @@ from fusion_bench.utils.state_dict_arithmetic import (
19
19
  state_dict_mul,
20
20
  state_dict_sub,
21
21
  )
22
- from fusion_bench.utils.type import StateDictType
22
+ from fusion_bench.utils.type import StateDictType, TorchModelType
23
23
 
24
24
  log = logging.getLogger(__name__)
25
25
 
26
26
 
27
27
  @torch.no_grad()
28
28
  def task_arithmetic_merge(
29
- pretrained_model: nn.Module,
30
- finetuned_models: List[nn.Module],
29
+ pretrained_model: TorchModelType,
30
+ finetuned_models: List[TorchModelType],
31
31
  scaling_factor: float,
32
32
  inplace: bool = True,
33
- ) -> nn.Module:
33
+ ) -> TorchModelType:
34
34
  """
35
35
  Merges the task vectors from multiple fine-tuned models into a single pre-trained model.
36
36
 
@@ -46,15 +46,17 @@ def task_arithmetic_merge(
46
46
  """
47
47
  if not inplace:
48
48
  pretrained_model = deepcopy(pretrained_model)
49
- task_vector: StateDictType = None
49
+ task_vector: Optional[StateDictType] = None
50
50
  # Calculate the total task vector
51
51
  for model in finetuned_models:
52
52
  if task_vector is None:
53
+ # calculate the task vector for the first model
53
54
  task_vector = state_dict_sub(
54
55
  model.state_dict(keep_vars=True),
55
56
  pretrained_model.state_dict(keep_vars=True),
56
57
  )
57
58
  else:
59
+ # calculate the task vector for the remaining models
58
60
  task_vector = state_dict_add(
59
61
  task_vector,
60
62
  state_dict_sub(
@@ -16,6 +16,7 @@ from torch import Tensor, nn
16
16
 
17
17
  from fusion_bench.compat.modelpool import to_modelpool
18
18
  from fusion_bench.method import BaseAlgorithm
19
+ from fusion_bench.mixins import SimpleProfilerMixin
19
20
  from fusion_bench.modelpool import BaseModelPool
20
21
  from fusion_bench.utils.type import StateDictType
21
22
 
@@ -24,7 +25,7 @@ from .ties_merging_utils import state_dict_to_vector, ties_merging, vector_to_st
24
25
  log = logging.getLogger(__name__)
25
26
 
26
27
 
27
- class TiesMergingAlgorithm(BaseAlgorithm):
28
+ class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
28
29
  """
29
30
  TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
30
31
 
@@ -84,34 +85,38 @@ class TiesMergingAlgorithm(BaseAlgorithm):
84
85
  scaling_factor = self.scaling_factor
85
86
  threshold = self.threshold
86
87
 
87
- # Load the pretrained model
88
- pretrained_model = modelpool.load_model("_pretrained_")
89
-
90
- # Load the state dicts of the models
91
- ft_checks: List[StateDictType] = [
92
- modelpool.load_model(model_name).state_dict(keep_vars=True)
93
- for model_name in modelpool.model_names
94
- ]
95
- ptm_check: StateDictType = pretrained_model.state_dict(keep_vars=True)
96
-
97
- # Compute the task vectors
98
- flat_ft: Tensor = torch.vstack(
99
- [state_dict_to_vector(check, remove_keys) for check in ft_checks]
100
- )
101
- flat_ptm: Tensor = state_dict_to_vector(ptm_check, remove_keys)
102
- tv_flat_checks = flat_ft - flat_ptm
103
-
104
- # Perform TIES Merging
105
- merged_tv = ties_merging(
106
- tv_flat_checks,
107
- reset_thresh=threshold,
108
- merge_func=merge_func,
109
- )
110
- merged_check = flat_ptm + scaling_factor * merged_tv
111
- merged_state_dict = vector_to_state_dict(
112
- merged_check, ptm_check, remove_keys=remove_keys
113
- )
114
-
115
- # Load the merged state dict into the pretrained model
116
- pretrained_model.load_state_dict(merged_state_dict)
88
+ with self.profile("loading models"):
89
+ # Load the pretrained model
90
+ pretrained_model = modelpool.load_model("_pretrained_")
91
+
92
+ # Load the state dicts of the models
93
+ ft_checks: List[StateDictType] = [
94
+ modelpool.load_model(model_name).state_dict(keep_vars=True)
95
+ for model_name in modelpool.model_names
96
+ ]
97
+ ptm_check: StateDictType = pretrained_model.state_dict(keep_vars=True)
98
+
99
+ with self.profile("merging models"):
100
+ # Compute the task vectors
101
+ flat_ft: Tensor = torch.vstack(
102
+ [state_dict_to_vector(check, remove_keys) for check in ft_checks]
103
+ )
104
+ flat_ptm: Tensor = state_dict_to_vector(ptm_check, remove_keys)
105
+ tv_flat_checks = flat_ft - flat_ptm
106
+
107
+ # Perform TIES Merging
108
+ merged_tv = ties_merging(
109
+ tv_flat_checks,
110
+ reset_thresh=threshold,
111
+ merge_func=merge_func,
112
+ )
113
+ merged_check = flat_ptm + scaling_factor * merged_tv
114
+ merged_state_dict = vector_to_state_dict(
115
+ merged_check, ptm_check, remove_keys=remove_keys
116
+ )
117
+
118
+ # Load the merged state dict into the pretrained model
119
+ pretrained_model.load_state_dict(merged_state_dict)
120
+
121
+ self.print_profile_summary()
117
122
  return pretrained_model
@@ -5,7 +5,6 @@ from typing import cast # noqa: F401
5
5
  import lightning as L
6
6
  import lightning.fabric.wrappers
7
7
  import torch
8
- from lightning.pytorch.profilers import SimpleProfiler
9
8
  from omegaconf import DictConfig
10
9
  from torch import Tensor
11
10
  from torch.utils.data import DataLoader
@@ -13,6 +12,7 @@ from tqdm.autonotebook import tqdm
13
12
 
14
13
  from fusion_bench.compat.method.base_algorithm import ModelFusionAlgorithm
15
14
  from fusion_bench.compat.modelpool import ModelPool
15
+ from fusion_bench.mixins import SimpleProfilerMixin
16
16
  from fusion_bench.models.we_moe import WeightEnsemblingMoE
17
17
  from fusion_bench.utils import timeit_context
18
18
  from fusion_bench.utils.parameters import print_parameters
@@ -34,7 +34,10 @@ def entropy_loss(logits: Tensor) -> Tensor:
34
34
  return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()
35
35
 
36
36
 
37
- class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
37
+ class WeightEnsemblingMoEAlgorithm(
38
+ ModelFusionAlgorithm,
39
+ SimpleProfilerMixin,
40
+ ):
38
41
  """
39
42
  Algorithm for fusing models using Weight Ensembling Mixture of Experts (MoE).
40
43
 
@@ -44,7 +47,6 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
44
47
  Attributes:
45
48
  _fabric (L.Fabric): The fabric for distributed training.
46
49
  modelpool (ModelPool): The pool of models to be fused.
47
- profiler (SimpleProfiler): The profiler for measuring performance.
48
50
  """
49
51
 
50
52
  _fabric: L.Fabric = None
@@ -66,9 +68,6 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
66
68
  self._fabric.launch()
67
69
  else:
68
70
  assert "No CUDA device available."
69
- self.profiler = SimpleProfiler(
70
- self.config.get("cache_dir", "outputs"), "we_moe_profiler.txt"
71
- )
72
71
 
73
72
  @abstractmethod
74
73
  def load_checkpoint(self, model, checkpoint):
@@ -177,9 +176,9 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
177
176
  for step_idx in pbar:
178
177
  if self.config.use_grad_accumulate:
179
178
  for task in self.modelpool.model_names:
180
- with self.profiler.profile("data time"):
179
+ with self.profile("data time"):
181
180
  batch = next(self.get_shuffled_test_loader_iter(task))
182
- with self.profiler.profile("forward pass"):
181
+ with self.profile("forward pass"):
183
182
  logits = self.compute_logits(module, batch, task)
184
183
  assert (
185
184
  logits.dim() == 2
@@ -187,23 +186,23 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
187
186
  loss = entropy_loss(logits)
188
187
  # .backward() accumulates when .zero_grad() wasn't called
189
188
  # this can save memory
190
- with self.profiler.profile("backward pass"):
189
+ with self.profile("backward pass"):
191
190
  self._fabric.backward(loss, retain_graph=True)
192
191
  else:
193
192
  loss = 0
194
193
  for task in self.modelpool.model_names:
195
- with self.profiler.profile("data time"):
194
+ with self.profile("data time"):
196
195
  batch = next(self.get_shuffled_test_loader_iter(task))
197
- with self.profiler.profile("forward pass"):
196
+ with self.profile("forward pass"):
198
197
  logits = self.compute_logits(module, batch, task)
199
198
  assert (
200
199
  logits.dim() == 2
201
200
  ), f"Expected logits to be 2D, got {logits.dim()}"
202
201
  loss = loss + entropy_loss(logits)
203
- with self.profiler.profile("backward pass"):
202
+ with self.profile("backward pass"):
204
203
  self._fabric.backward(loss, retain_graph=True)
205
204
 
206
- with self.profiler.profile("optimizer step"):
205
+ with self.profile("optimizer step"):
207
206
  optimizer.step()
208
207
  optimizer.zero_grad()
209
208
 
@@ -232,7 +231,7 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
232
231
  )
233
232
  self.load_checkpoint(moe_model, self.config.checkpoint)
234
233
  else:
235
- with self.profiler.profile("test-time adaptation"):
234
+ with self.profile("test-time adaptation"):
236
235
  moe_model = self.test_time_adaptation(moe_model)
237
236
  if self.config.get("save_checkpoint", False):
238
237
  log.info(f"save checkpoint to {self.config.save_checkpoint}")
@@ -243,5 +242,5 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
243
242
 
244
243
  # enable sample-wise adaptation
245
244
  moe_model.batch_reduce = False
246
- print(self.profiler.summary())
245
+ self.print_profile_summary()
247
246
  return moe_model
@@ -6,20 +6,23 @@ from typing_extensions import TYPE_CHECKING
6
6
  from fusion_bench.utils.lazy_imports import LazyImporter
7
7
 
8
8
  _import_structure = {
9
+ "clip_classification": ["CLIPClassificationMixin"],
10
+ "fabric_training": ["FabricTrainingMixin"],
11
+ "hydra_config": ["HydraConfigMixin"],
9
12
  "lightning_fabric": ["LightningFabricMixin"],
13
+ "openclip_classification": ["OpenCLIPClassificationMixin"],
10
14
  "serialization": ["YAMLSerializationMixin", "BaseYAMLSerializableModel"],
11
15
  "simple_profiler": ["SimpleProfilerMixin"],
12
- "clip_classification": ["CLIPClassificationMixin"],
13
- "fabric_training": ["FabricTrainingMixin"],
14
16
  }
15
17
 
16
18
  if TYPE_CHECKING:
17
19
  from .clip_classification import CLIPClassificationMixin
18
20
  from .fabric_training import FabricTrainingMixin
21
+ from .hydra_config import HydraConfigMixin
19
22
  from .lightning_fabric import LightningFabricMixin
23
+ from .openclip_classification import OpenCLIPClassificationMixin
20
24
  from .serialization import BaseYAMLSerializableModel, YAMLSerializationMixin
21
25
  from .simple_profiler import SimpleProfilerMixin
22
-
23
26
  else:
24
27
  sys.modules[__name__] = LazyImporter(
25
28
  __name__,
@@ -0,0 +1,49 @@
1
+ import logging
2
+ import os
3
+ from copy import deepcopy
4
+ from pathlib import Path
5
+ from typing import Dict, List, Optional, Union
6
+
7
+ import hydra.core.global_hydra
8
+ from hydra import compose, initialize
9
+ from omegaconf import DictConfig, OmegaConf
10
+
11
+ from fusion_bench.utils import import_object, instantiate
12
+ from fusion_bench.utils.instantiate import set_print_function_call
13
+
14
+ log = logging.getLogger(__name__)
15
+
16
+
17
+ class HydraConfigMixin:
18
+ """
19
+ A mixin for classes that need to be instantiated from a config file.
20
+ """
21
+
22
+ @classmethod
23
+ def from_config(
24
+ cls,
25
+ config_name: Union[str, Path],
26
+ overrides: Optional[List[str]] = None,
27
+ ):
28
+ if not hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
29
+ raise RuntimeError("Hydra is not initialized.")
30
+ else:
31
+ cfg = compose(config_name=config_name, overrides=overrides)
32
+
33
+ config_groups = config_name.split("/")[:-1]
34
+ for config_group in config_groups:
35
+ cfg = cfg[config_group]
36
+
37
+ if "_target_" in cfg:
38
+ # if the config has a _target_ key, check if it is equal to the class name
39
+ target_cls = import_object(cfg["_target_"])
40
+ if target_cls != cls:
41
+ log.warning(
42
+ f"The _target_ key in the config is {cfg['_target_']}, but the class name is {cls.__name__}."
43
+ )
44
+ with set_print_function_call(False):
45
+ obj = instantiate(cfg)
46
+ else:
47
+ obj = cls(**cfg)
48
+
49
+ return obj
@@ -0,0 +1,11 @@
1
+ import logging
2
+
3
+ from fusion_bench.mixins import LightningFabricMixin
4
+ from fusion_bench.models.open_clip import ImageClassifier, ImageEncoder
5
+
6
+ log = logging.getLogger(__name__)
7
+
8
+
9
+ class OpenCLIPClassificationMixin(LightningFabricMixin):
10
+ _train_processor = None
11
+ _test_processor = None
@@ -1,5 +1,5 @@
1
1
  from contextlib import contextmanager
2
- from typing import Generator
2
+ from typing import Generator, Optional
3
3
 
4
4
  from lightning.fabric.utilities.rank_zero import rank_zero_only
5
5
  from lightning.pytorch.profilers import SimpleProfiler
@@ -70,7 +70,9 @@ class SimpleProfilerMixin:
70
70
  self.profiler.stop(action_name)
71
71
 
72
72
  @rank_zero_only
73
- def print_profile_summary(self):
73
+ def print_profile_summary(self, title: Optional[str] = None):
74
+ if title is not None:
75
+ print(title)
74
76
  print(self.profiler.summary())
75
77
 
76
78
  def __del__(self):
@@ -6,12 +6,13 @@ from fusion_bench.utils.lazy_imports import LazyImporter
6
6
 
7
7
  _import_structure = {
8
8
  "base_pool": ["BaseModelPool"],
9
+ "causal_lm": ["CausalLMPool", "CausalLMBackbonePool"],
9
10
  "clip_vision": ["CLIPVisionModelPool"],
10
11
  "nyuv2_modelpool": ["NYUv2ModelPool"],
11
12
  "huggingface_automodel": ["AutoModelPool"],
12
- "causal_lm": ["CausalLMPool", "CausalLMBackbonePool"],
13
13
  "seq2seq_lm": ["Seq2SeqLMPool"],
14
14
  "PeftModelForSeq2SeqLM": ["PeftModelForSeq2SeqLMPool"],
15
+ "openclip_vision": ["OpenCLIPVisionModelPool"],
15
16
  "huggingface_gpt2_classification": [
16
17
  "HuggingFaceGPT2ClassificationPool",
17
18
  "GPT2ForSequenceClassificationPool",
@@ -30,6 +31,7 @@ if TYPE_CHECKING:
30
31
  HuggingFaceGPT2ClassificationPool,
31
32
  )
32
33
  from .nyuv2_modelpool import NYUv2ModelPool
34
+ from .openclip_vision import OpenCLIPVisionModelPool
33
35
  from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
34
36
  from .seq2seq_lm import Seq2SeqLMPool
35
37
  from .seq_classification_lm import SeqenceClassificationModelPool
@@ -7,7 +7,7 @@ from omegaconf import DictConfig
7
7
  from torch import nn
8
8
  from torch.utils.data import Dataset
9
9
 
10
- from fusion_bench.mixins import BaseYAMLSerializableModel
10
+ from fusion_bench.mixins import BaseYAMLSerializableModel, HydraConfigMixin
11
11
  from fusion_bench.utils import instantiate, timeit_context
12
12
 
13
13
  __all__ = ["BaseModelPool"]
@@ -15,7 +15,7 @@ __all__ = ["BaseModelPool"]
15
15
  log = logging.getLogger(__name__)
16
16
 
17
17
 
18
- class BaseModelPool(BaseYAMLSerializableModel):
18
+ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
19
19
  """
20
20
  A class for managing and interacting with a pool of models along with their associated datasets or other specifications. For example, a model pool may contain multiple models, each with its own training, validation, and testing datasets. As for the specifications, a vision model pool may contain image preprocessor, and a language model pool may contain a tokenizer.
21
21
 
@@ -0,0 +1 @@
1
+ from .modelpool import OpenCLIPVisionModelPool