fusion-bench 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199)
  1. fusion_bench/compat/method/__init__.py +3 -1
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/gsm8k.py +2 -2
  6. fusion_bench/method/__init__.py +12 -2
  7. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  8. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  9. fusion_bench/method/doge_ta/__init__.py +2 -0
  10. fusion_bench/method/{DOGE_TA → doge_ta}/clip_layer_wise_adamerging.py +1 -1
  11. fusion_bench/method/{DOGE_TA/DOGE_TA.py → doge_ta/doge_ta.py} +1 -1
  12. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  13. fusion_bench/method/gossip/__init__.py +3 -0
  14. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  15. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  16. fusion_bench/method/gossip/entropy_loss.py +25 -0
  17. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  18. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  19. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  20. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  21. fusion_bench/method/gossip/utils.py +74 -0
  22. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  23. fusion_bench/method/opcm/opcm.py +102 -84
  24. fusion_bench/method/opcm/task_arithmetic.py +35 -21
  25. fusion_bench/method/opcm/ties_merging.py +71 -52
  26. fusion_bench/method/pwe_moe/module.py +1 -1
  27. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  28. fusion_bench/method/regmean/regmean.py +25 -17
  29. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  30. fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
  31. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  32. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  33. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  34. fusion_bench/method/we_moe/we_moe.py +14 -15
  35. fusion_bench/mixins/__init__.py +6 -3
  36. fusion_bench/mixins/hydra_config.py +49 -0
  37. fusion_bench/mixins/openclip_classification.py +11 -0
  38. fusion_bench/mixins/simple_profiler.py +4 -2
  39. fusion_bench/modelpool/__init__.py +3 -1
  40. fusion_bench/modelpool/base_pool.py +2 -2
  41. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  42. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  43. fusion_bench/models/open_clip/__init__.py +6 -0
  44. fusion_bench/models/open_clip/modeling.py +176 -0
  45. fusion_bench/models/open_clip/utils.py +311 -0
  46. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  47. fusion_bench/models/parameter_dict.py +54 -13
  48. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
  49. fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +4 -119
  50. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  51. fusion_bench/taskpool/__init__.py +5 -3
  52. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  53. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  54. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  55. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  56. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  57. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  58. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  59. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  60. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  61. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  62. fusion_bench/utils/data.py +12 -0
  63. fusion_bench/utils/devices.py +14 -0
  64. fusion_bench/utils/instantiate.py +12 -0
  65. fusion_bench/utils/misc.py +9 -2
  66. fusion_bench/utils/packages.py +14 -0
  67. fusion_bench/utils/parameters.py +1 -1
  68. fusion_bench/utils/tensorboard.py +1 -1
  69. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +15 -2
  70. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +198 -158
  71. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
  72. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  73. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  74. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  75. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  76. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  77. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  78. fusion_bench_config/fabric/auto.yaml +0 -1
  79. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  80. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  81. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  82. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  83. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  84. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  85. fusion_bench_config/llama_full_finetune.yaml +0 -2
  86. fusion_bench_config/llama_model_fusion.yaml +0 -2
  87. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  88. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  89. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  90. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  91. fusion_bench_config/method/adamerging.yaml +2 -2
  92. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  93. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  94. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  95. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  96. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  97. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  98. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  99. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  100. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  101. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  102. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  103. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  104. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  105. fusion_bench_config/method/{DOGE_TA/DOGE_TA.yaml → doge_ta/doge_ta.yaml} +1 -1
  106. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  107. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  108. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  109. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  110. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  111. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  112. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  113. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  114. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  115. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  116. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  117. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  118. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  119. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  120. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  121. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  122. fusion_bench_config/method/model_recombination.yaml +0 -1
  123. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  124. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  125. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  126. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  127. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  128. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  129. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  130. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  131. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  132. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  133. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  134. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  135. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  136. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  137. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  138. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  139. fusion_bench_config/method/ties_merging.yaml +1 -1
  140. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  141. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  142. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  143. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  144. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  145. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  146. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  147. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  148. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  149. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  150. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  151. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  152. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  161. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  162. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  163. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  164. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  165. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  166. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  167. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  168. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  169. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  170. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  171. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  172. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  173. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  174. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  175. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  176. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  177. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  178. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -10
  179. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +66 -0
  180. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  181. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  182. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  183. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  184. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  185. fusion_bench_config/nyuv2_config.yaml +0 -2
  186. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  187. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  188. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  189. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  190. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  191. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  192. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  193. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  194. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  195. fusion_bench/method/DOGE_TA/__init__.py +0 -2
  196. /fusion_bench/method/{DOGE_TA → doge_ta}/layer_wise_adamerging.py +0 -0
  197. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
  198. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info/licenses}/LICENSE +0 -0
  199. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,265 @@
1
+ import copy
2
+ import gc
3
+ import logging
4
+ from abc import abstractmethod
5
+ from typing import List, Mapping, Union # noqa: F401
6
+
7
+ import lightning as L
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ from omegaconf import DictConfig
12
+ from torch import Tensor
13
+ from torch.utils.data import DataLoader
14
+ from tqdm.autonotebook import tqdm
15
+
16
+ from fusion_bench.compat.method import ModelFusionAlgorithm
17
+ from fusion_bench.compat.modelpool import ModelPool
18
+ from fusion_bench.models.wrappers.task_wise_fusion import (
19
+ TaskWiseMergedModel,
20
+ get_task_wise_weights,
21
+ )
22
+
23
+ log = logging.getLogger(__name__)
24
+
25
+
26
+ # obtain the current GPU memory usage
27
def print_memory_usage(desc):
    """Print the current GPU memory usage (allocated and reserved) to stdout.

    Args:
        desc (str): A short description printed first, marking where in the
            pipeline the measurement was taken.
    """
    print(desc)
    mb = 1024**2  # CUDA reports bytes; convert to MB
    print(f"Allocated Memory: {torch.cuda.memory_allocated() / mb:.2f} MB")
    print(f"Cached Memory: {torch.cuda.memory_reserved() / mb:.2f} MB")
35
def entropy_loss(logits: Tensor) -> Tensor:
    """
    Compute the mean Shannon entropy of a batch of logits.

    Uses ``log_softmax`` for numerical stability instead of adding a small
    epsilon inside ``log`` (the previous formulation ``log(probs + 1e-8)``),
    which avoids the epsilon bias for near-zero probabilities and cannot
    produce ``log(0)``.

    Args:
        logits (Tensor): The logits to compute the entropy loss of.

    Returns:
        Tensor: Scalar tensor — the mean entropy over the batch.
    """
    log_probs = torch.log_softmax(logits, dim=-1)
    # entropy = -sum(p * log p); p = exp(log p)
    return -torch.sum(log_probs.exp() * log_probs, dim=-1).mean()
49
class ModelScheduler:
    """
    Manage the storage of models, schedule the order in which models are
    loaded to the GPU, and transfer data between the CPU and GPU.

    Holds the pretrained model plus one finetuned model per task. For each
    gossip turn it builds a ``TaskWiseMergedModel`` from one model and its
    two ring neighbours; updated models are double-buffered and promoted at
    the end of each gossip round.
    """

    def __init__(
        self,
        modelpool: ModelPool,
        config: DictConfig,
    ):
        self.pretrained_model = modelpool.load_model("_pretrained_")
        self.finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]
        self.num_finetuned_models = len(self.finetuned_models)
        # Double buffer: `store_model` stages results here; `update_models`
        # swaps them in, so a whole round sees a consistent snapshot.
        self.new_finetuned_models = copy.deepcopy(self.finetuned_models)
        self.finetuned_model_names = list(modelpool.model_names)

        self.config = config

    @torch.no_grad()  # building the wrapper needs no autograd; weights are trained later
    def __call__(self, model_id):
        """
        Build the merged module for one local adamerging turn.

        The model at ``model_id`` acts as the "pretrained" anchor; its two
        ring neighbours (next and previous, modulo the number of models) are
        the candidates merged into it.

        Args:
            model_id (int): Index of the model whose neighbourhood is merged.

        Returns:
            TaskWiseMergedModel: Wrapper holding the learnable task-wise weights.

        Raises:
            NotImplementedError: If ``config.weights`` is set (loading
                pre-specified merging weights is not implemented).
        """
        # TODO: use a mixing matrix to determine which models to use in step idx
        pretrained_model = copy.deepcopy(self.finetuned_models[model_id])
        finetuned_models = [
            copy.deepcopy(
                self.finetuned_models[(model_id + 1) % self.num_finetuned_models]
            ),
            copy.deepcopy(
                self.finetuned_models[(model_id - 1) % self.num_finetuned_models]
            ),
        ]

        if self.config.weights is None:
            task_wise_weight = get_task_wise_weights(
                num_models=len(finetuned_models),
                init_values=self.config.init_values,
            )
        else:
            # The original code silently fell through this branch, leaving
            # `task_wise_weight` unbound and crashing with a confusing
            # NameError below. Fail loudly with an actionable message instead.
            raise NotImplementedError(
                "Loading task-wise weights from `config.weights` is not "
                "implemented; set `weights: null` to initialize from "
                "`config.init_values`."
            )

        module = TaskWiseMergedModel(
            task_wise_weight=task_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
        )
        return module

    def store_model(self, new_finetuned_model_dict, model_id):
        """Stage the merged state dict for ``model_id`` after a local adamerging turn."""
        self.new_finetuned_models[model_id].load_state_dict(new_finetuned_model_dict)

    def update_models(self):
        """Promote all staged models to the active set (end of a gossip round)."""
        self.finetuned_models = copy.deepcopy(self.new_finetuned_models)

    def get_final_models(self):
        """
        Return all finetuned models plus their uniform parameter average.

        The average is written into ``self.pretrained_model`` in place, so
        this should only be called once, at the end of the run.

        Returns:
            list[dict]: ``{"name": ..., "model": ...}`` entries; the last
            entry is the averaged model under the name ``"average model"``.
        """
        final_models = [
            {"name": name, "model": model}
            for name, model in zip(self.finetuned_model_names, self.finetuned_models)
        ]
        num_finetuned_models = len(self.finetuned_models)

        # Zero the pretrained parameters, then accumulate the mean of all
        # finetuned models into the same state dict.
        state_dict = self.pretrained_model.state_dict(keep_vars=True)
        for name in state_dict.keys():
            state_dict[name].data.zero_()
        for model in self.finetuned_models:
            for name, param in model.named_parameters():
                state_dict[name] = state_dict[name] + 1 / num_finetuned_models * param

        self.pretrained_model.load_state_dict(state_dict)
        final_models += [{"name": "average model", "model": self.pretrained_model}]

        return final_models
135
class TaskWiseGossipAlgorithm(ModelFusionAlgorithm):
    """
    Task-wise AdaMerging combined with gossip-style propagation.

    Each gossip step runs one local test-time-adaptation (AdaMerging) turn
    per model against its ring neighbours (see ``ModelScheduler``), then
    promotes the updated models for the next step. Subclasses provide data
    loading (``get_shuffled_test_loader_iter``) and the forward pass
    (``compute_logits``) for a concrete architecture.
    """

    # Class-level Fabric handle; created lazily in __init__ when CUDA is
    # available so it can be reused across turns.
    _fabric: L.Fabric = None

    def __init__(self, algorithm_config: DictConfig):
        """Initialize from the Hydra config and launch Fabric if CUDA is available."""
        super().__init__(algorithm_config)

        if self._fabric is None and torch.cuda.is_available():
            self._fabric = L.Fabric(devices=self.config.get("devices", 1))
            self._fabric.launch()

        self.optimizer = None  # we want to reuse it in Gossip using single GPU

    def free_gpu_memory(self, module: TaskWiseMergedModel):
        """
        Move the merged module's submodels to CPU and release cached CUDA memory.

        NOTE(review): ``del module`` only drops the local reference — the
        caller still holds one — so the actual saving comes from the
        ``.to("cpu")`` moves and ``empty_cache``.
        """
        module.pretrained_model.to("cpu")
        for model in module.task_vectors:
            model.to("cpu")
        del module
        gc.collect()
        torch.cuda.empty_cache()
        print_memory_usage(
            "finish local adamerging, after freeing memory, the memory usage of GPU is:"
        )

    def run(self, modelpool: ModelPool):
        """
        Run gossip merging over the model pool.

        Args:
            modelpool (ModelPool): Pool holding ``_pretrained_`` plus one
                finetuned model per task.

        Returns:
            The per-task models plus their average, as produced by
            ``ModelScheduler.get_final_models``.
        """
        log.info("Fusing models using task-wise adaptive merging with gossip.")
        self.modelpool = modelpool
        self.num_finetuned_models = len(modelpool.model_names)

        model_scheduler = ModelScheduler(self.modelpool, self.config)

        pbar = tqdm(
            range(self.config.gossip_max_steps), "Gossip merging", dynamic_ncols=True
        )
        for step_idx in pbar:
            log.info(f"step: {step_idx}")
            for model_id in tqdm(
                range(self.num_finetuned_models), "local adamerging", dynamic_ncols=True
            ):
                # log.info(f"adamerging model: {model_scheduler.finetuned_midels_name[model_id]}")
                module = model_scheduler(model_id)
                module = self.test_time_adaptation(module)
                # if self.config.get("save_merging_weights", False):
                #     torch.save(module.merge_weight, self.config.save_merging_weights)
                print_memory_usage(
                    "local adamerging almost done, the memory usage of GPU is:"
                )
                # presumably `merge_weights()` returns the merged state dict
                # consumed by `store_model` — verify against TaskWiseMergedModel
                model_scheduler.store_model(module.merge_weights(), model_id)
                # NOTE(review): same message as above; the two calls bracket
                # `store_model` to expose its memory cost.
                print_memory_usage(
                    "local adamerging almost done, the memory usage of GPU is:"
                )
                self.free_gpu_memory(
                    module
                )  # simulate distributed GPU memory usage as much as possible

            model_scheduler.update_models()

        return model_scheduler.get_final_models()

    def on_test_time_adaptation_start(self):
        """Hook called before each test-time adaptation turn; default is a no-op."""
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """Return an iterator over a shuffled test data loader for ``task``."""
        pass

    @abstractmethod
    def compute_logits(self, module: nn.Module, batch, task: str) -> Tensor:
        """
        Compute the logits for the given batch and task.

        Args:
            module (nn.Module): The model module.
            batch (tuple): A batch of input data.
            task (str): The name of the task.

        Returns:
            Tensor: The classification logits for the batch.
        """
        pass

    def test_time_adaptation(self, module: TaskWiseMergedModel):
        """
        Adapt the task-wise merging weights by minimizing prediction entropy
        on unlabeled test batches (AdaMerging-style test-time adaptation).

        Args:
            module (TaskWiseMergedModel): Wrapper whose ``merge_weight`` is optimized.

        Returns:
            TaskWiseMergedModel: The module with adapted merging weights.

        Raises:
            ValueError: If ``config.optimizer`` is not ``"adam"``.
        """
        self.on_test_time_adaptation_start()

        # configure optimizer
        if self.config.optimizer == "adam":
            self.optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        if self._fabric is not None:
            module, self.optimizer = self._fabric.setup(module, self.optimizer)
            print_memory_usage(
                "load model and optimizer to GPU, the memory usage of GPU is:"
            )
        module.train()
        module.merge_weights()

        if self.config.get("fast_dev_run", False):
            log.info("Running fast_dev_run, only one step")
            pbar = tqdm(
                range(1),
                "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        else:
            pbar = tqdm(
                range(self.config.max_steps),
                "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        for step_idx in pbar:
            for task in self.modelpool.model_names:
                batch = next(self.get_shuffled_test_loader_iter(task))
                logits = self.compute_logits(module, batch, task)
                assert (
                    logits.dim() == 2
                ), f"Expected logits to be 2D, got {logits.dim()}"
                loss = entropy_loss(logits)
                # .backward() accumulates when .zero_grad() wasn't called
                # this can save memory
                # NOTE(review): assumes `self._fabric` is not None here — on a
                # CPU-only machine this line would raise AttributeError; confirm.
                self._fabric.backward(loss, retain_graph=True)

            # print_memory_usage('model + dataset: ')
            self.optimizer.step()
            self.optimizer.zero_grad()
            module.merge_weights()

        # Drop the optimizer so its state does not pin GPU memory between turns.
        del self.optimizer
        gc.collect()
        torch.cuda.empty_cache()
        return module
@@ -0,0 +1,74 @@
1
+ import copy
2
+ from collections import OrderedDict
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+
8
def get_memory_usage(desc):
    """
    Build a human-readable report of the current GPU memory usage.

    Args:
        desc (str): Description line prepended to the report.

    Returns:
        str: The allocated and cached (reserved) memory in MB.
    """
    mb = 1024**2  # CUDA reports bytes; convert to MB
    alloc_mb = torch.cuda.memory_allocated() / mb
    reserved_mb = torch.cuda.memory_reserved() / mb
    return f"{desc}\nAllocated Memory: {alloc_mb:.2f} MB\nCached Memory: {reserved_mb:.2f} MB"
22
+ # Model conversion utils
23
+
24
+
25
def state_dict_to_vector(state_dict, remove_keys=None):
    """
    Convert a state dictionary to a single flat vector.

    Entries are concatenated in sorted-key order, so the layout is
    deterministic and can be reversed by ``vector_to_state_dict``.

    Args:
        state_dict (dict): The state dictionary to convert.
        remove_keys (list, optional): Keys to drop before flattening.
            Defaults to None (nothing removed). Previously this was a
            mutable default ``[]``, which is a Python anti-pattern.

    Returns:
        torch.Tensor: The converted vector.
    """
    if remove_keys is None:
        remove_keys = []
    # Work on a deep copy so the caller's state dict is left untouched.
    shared_state_dict = copy.deepcopy(state_dict)
    for key in remove_keys:
        if key in shared_state_dict:
            del shared_state_dict[key]
    # Sorted keys fix a deterministic flattening order.
    sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
    return nn.utils.parameters_to_vector(
        [value.reshape(-1) for value in sorted_shared_state_dict.values()]
    )
46
def vector_to_state_dict(vector, state_dict, remove_keys=None):
    """
    Convert a flat vector back into a state dictionary.

    The reference ``state_dict`` only defines the names, shapes, and
    (sorted) order in which ``vector`` is consumed; its values are not read.

    Args:
        vector (torch.Tensor): The vector to convert.
        state_dict (dict): The reference state dictionary defining the order
            of the vector.
        remove_keys (list, optional): Keys dropped from the reference before
            unpacking; when a T5-style shared embedding
            (``transformer.shared.weight``) is present they are re-added as
            aliases of it afterwards. Defaults to None. Previously this was
            a mutable default ``[]``, which is a Python anti-pattern.

    Returns:
        dict: A new (key-sorted) state dictionary filled from ``vector``.
    """
    if remove_keys is None:
        remove_keys = []
    # create a reference dict to define the order of the vector
    reference_dict = copy.deepcopy(state_dict)
    for key in remove_keys:
        if key in reference_dict:
            del reference_dict[key]
    sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))

    # fill the (copied) reference tensors in place from the flat vector
    nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())

    # add back the encoder and decoder embedding weights.
    if "transformer.shared.weight" in sorted_reference_dict:
        for key in remove_keys:
            sorted_reference_dict[key] = sorted_reference_dict[
                "transformer.shared.weight"
            ]
    return sorted_reference_dict
@@ -3,7 +3,7 @@ This module contains the implementation of the Isotropic Merging in Common Subsp
3
3
  Modified from the original implementation: https://github.com/danielm1405/iso-merging
4
4
 
5
5
  Reference:
6
- - Daniel Marczak, et al. No Task Left Behind: Isotropic Model Merging with Common and Task-Specific Subspaces. 2025.
6
+ - Daniel Marczak, et al. No Task Left Behind: Isotropic Model Merging with Common and Task-Specific Subspaces. 2025.
7
7
  https://arxiv.org/abs/2502.04959
8
8
  """
9
9
 
@@ -15,7 +15,7 @@ from tqdm.auto import tqdm
15
15
  from transformers import CLIPVisionModel
16
16
 
17
17
  from fusion_bench import BaseAlgorithm, BaseModelPool
18
- from fusion_bench.mixins import LightningFabricMixin
18
+ from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
19
19
  from fusion_bench.taskpool import CLIPVisionModelTaskPool
20
20
  from fusion_bench.utils import instantiate
21
21
  from fusion_bench.utils.json import load_from_json, save_to_json
@@ -31,6 +31,7 @@ if TYPE_CHECKING:
31
31
  class OPCMForCLIP(
32
32
  BaseAlgorithm,
33
33
  LightningFabricMixin,
34
+ SimpleProfilerMixin,
34
35
  ):
35
36
  def __init__(
36
37
  self,
@@ -64,7 +65,8 @@ class OPCMForCLIP(
64
65
  L.seed_everything(self.seed)
65
66
  accelerator = self.fabric.device
66
67
 
67
- pretrained_model = modelpool.load_pretrained_model()
68
+ with self.profile("loading model"):
69
+ pretrained_model = modelpool.load_pretrained_model()
68
70
 
69
71
  model_names = modelpool.model_names
70
72
  if self.shuffle_order:
@@ -83,15 +85,17 @@ class OPCMForCLIP(
83
85
  )
84
86
 
85
87
  # get the average model
86
- merged_model = modelpool.load_model(model_names[0])
88
+ with self.profile("loading model"):
89
+ merged_model = modelpool.load_model(model_names[0])
87
90
 
88
91
  if self.evaluate_on_every_step:
89
- self.taskpool._is_setup = False
90
- self.taskpool._test_datasets = DictConfig(
91
- {model_names[0]: self._test_datasets[model_names[0]]}
92
- )
93
- report = self.taskpool.evaluate(deepcopy(merged_model))
94
- save_to_json(report, Path(self.log_dir) / "report_0.json")
92
+ with self.profile("evaluating model"):
93
+ self.taskpool._is_setup = False
94
+ self.taskpool._test_datasets = DictConfig(
95
+ {model_names[0]: self._test_datasets[model_names[0]]}
96
+ )
97
+ report = self.taskpool.evaluate(deepcopy(merged_model))
98
+ save_to_json(report, Path(self.log_dir) / "report_0.json")
95
99
 
96
100
  self.avg_task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
97
101
  self.all_task_vector_norm = [self.avg_task_vector_norm]
@@ -113,90 +117,104 @@ class OPCMForCLIP(
113
117
  enumerate(model_names[1:]), desc="Processing models"
114
118
  ):
115
119
  model_idx += 1
116
- task_model = modelpool.load_model(model_name)
120
+ with self.profile("loading model"):
121
+ task_model = modelpool.load_model(model_name)
117
122
 
118
- self.all_task_vector_norm.append(
119
- get_task_vector_norm(task_model, pretrained_model)
120
- )
121
- self.avg_task_vector_norm = np.mean(self.all_task_vector_norm)
122
- self.fabric.log(
123
- "model/task_vector_norm", self.all_task_vector_norm[-1], step=model_idx
124
- )
125
- self.fabric.log(
126
- "model/avg_task_vector_norm", self.avg_task_vector_norm, step=model_idx
127
- )
123
+ with self.profile("merging model"):
124
+ self.all_task_vector_norm.append(
125
+ get_task_vector_norm(task_model, pretrained_model)
126
+ )
127
+ self.avg_task_vector_norm = np.mean(self.all_task_vector_norm)
128
+ self.fabric.log(
129
+ "model/task_vector_norm",
130
+ self.all_task_vector_norm[-1],
131
+ step=model_idx,
132
+ )
133
+ self.fabric.log(
134
+ "model/avg_task_vector_norm",
135
+ self.avg_task_vector_norm,
136
+ step=model_idx,
137
+ )
128
138
 
129
- self.lambda_t = 1 # temporary value
130
-
131
- for module_name, module in tqdm(
132
- list(merged_model.named_modules()),
133
- desc=f"Processing {model_name}",
134
- leave=False,
135
- ):
136
- if not is_leaf_module(module):
137
- continue
138
-
139
- if isinstance(module, nn.Linear):
140
- module.weight.data = self.merge_linear_weights(
141
- module.weight,
142
- pretrained_model.get_submodule(module_name).weight,
143
- task_model.get_submodule(module_name).weight,
144
- param_name=".".join([module_name, "weight"]),
145
- alpha=self.alpha,
146
- accelerator=accelerator,
147
- )
148
- if module.bias is not None:
149
- module.bias.data = self.merge_other_parameters(
150
- module.bias,
151
- pretrained_model.get_submodule(module_name).bias,
152
- task_model.get_submodule(module_name).bias,
153
- param_name=".".join([module_name, "bias"]),
139
+ self.lambda_t = 1 # temporary value
140
+
141
+ for module_name, module in tqdm(
142
+ list(merged_model.named_modules()),
143
+ desc=f"Processing {model_name}",
144
+ leave=False,
145
+ ):
146
+ if not is_leaf_module(module):
147
+ continue
148
+
149
+ if isinstance(module, nn.Linear):
150
+ module.weight.data = self.merge_linear_weights(
151
+ module.weight,
152
+ pretrained_model.get_submodule(module_name).weight,
153
+ task_model.get_submodule(module_name).weight,
154
+ param_name=".".join([module_name, "weight"]),
155
+ alpha=self.alpha,
154
156
  accelerator=accelerator,
155
157
  )
156
- else:
157
- for param_name, param in module.named_parameters():
158
- param.data = self.merge_other_parameters(
159
- merged_W=param,
160
- pretrained_W=pretrained_model.get_submodule(
161
- module_name
162
- ).get_parameter(param_name),
163
- task_W=task_model.get_submodule(module_name).get_parameter(
164
- param_name
165
- ),
166
- param_name=".".join([module_name, param_name]),
167
- accelerator=accelerator,
168
- )
169
-
170
- task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
171
- self.lambda_t *= task_vector_norm / self.avg_task_vector_norm
172
- for param_name, param in merged_model.named_parameters():
173
- param.data = pretrained_model.get_parameter(param_name) + (
174
- param - pretrained_model.get_parameter(param_name)
175
- ) * (self.avg_task_vector_norm / task_vector_norm)
176
- self.fabric.log("model/lambda_t", self.lambda_t, step=model_idx)
177
- self.fabric.log(
178
- "empirical/lambda_t", np.sqrt(model_idx + 1), step=model_idx
179
- )
180
- self.previous_lambda_t = self.lambda_t
181
- self.lambda_t = None
158
+ if module.bias is not None:
159
+ module.bias.data = self.merge_other_parameters(
160
+ module.bias,
161
+ pretrained_model.get_submodule(module_name).bias,
162
+ task_model.get_submodule(module_name).bias,
163
+ param_name=".".join([module_name, "bias"]),
164
+ accelerator=accelerator,
165
+ )
166
+ else:
167
+ for param_name, param in module.named_parameters():
168
+ param.data = self.merge_other_parameters(
169
+ merged_W=param,
170
+ pretrained_W=pretrained_model.get_submodule(
171
+ module_name
172
+ ).get_parameter(param_name),
173
+ task_W=task_model.get_submodule(
174
+ module_name
175
+ ).get_parameter(param_name),
176
+ param_name=".".join([module_name, param_name]),
177
+ accelerator=accelerator,
178
+ )
179
+
180
+ task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
181
+ self.lambda_t *= task_vector_norm / self.avg_task_vector_norm
182
+ for param_name, param in merged_model.named_parameters():
183
+ param.data = pretrained_model.get_parameter(param_name) + (
184
+ param - pretrained_model.get_parameter(param_name)
185
+ ) * (self.avg_task_vector_norm / task_vector_norm)
186
+ self.fabric.log("model/lambda_t", self.lambda_t, step=model_idx)
187
+ self.fabric.log(
188
+ "empirical/lambda_t", np.sqrt(model_idx + 1), step=model_idx
189
+ )
190
+ self.previous_lambda_t = self.lambda_t
191
+ self.lambda_t = None
182
192
 
183
- self.fabric.log(
184
- "model/merged_task_vector_norm",
185
- get_task_vector_norm(merged_model, pretrained_model),
186
- step=model_idx,
187
- )
193
+ self.fabric.log(
194
+ "model/merged_task_vector_norm",
195
+ get_task_vector_norm(merged_model, pretrained_model),
196
+ step=model_idx,
197
+ )
188
198
 
189
199
  if self.save_on_every_step:
190
- self.save_merged_model(merged_model, model_idx)
200
+ with self.profile("saving model"):
201
+ self.save_merged_model(merged_model, model_idx)
191
202
 
192
203
  if self.evaluate_on_every_step:
193
- self.taskpool._is_setup = False
194
- self.taskpool._test_datasets = DictConfig(
195
- {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
196
- )
197
- report = self.taskpool.evaluate(deepcopy(merged_model))
198
- save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
204
+ with self.profile("evaluating model"):
205
+ self.taskpool._is_setup = False
206
+ self.taskpool._test_datasets = DictConfig(
207
+ {
208
+ n: self._test_datasets[n]
209
+ for n in model_names[: model_idx + 1]
210
+ }
211
+ )
212
+ report = self.taskpool.evaluate(deepcopy(merged_model))
213
+ save_to_json(
214
+ report, Path(self.log_dir) / f"report_{model_idx}.json"
215
+ )
199
216
 
217
+ self.print_profile_summary()
200
218
  return merged_model
201
219
 
202
220
  def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
@@ -227,7 +245,7 @@ class OPCMForCLIP(
227
245
  split_rank = (s.cumsum(dim=0) / s.sum() > alpha).float().argmax().item()
228
246
 
229
247
  projected_task_tv = u.T @ task_tv @ v
230
- projected_task_tv.diag().fill_(0)
248
+ projected_task_tv.diagonal().fill_(0)
231
249
 
232
250
  projected_task_tv[:split_rank, :split_rank] = 0
233
251
 
@@ -15,7 +15,7 @@ from tqdm.auto import tqdm
15
15
  from transformers import CLIPVisionModel
16
16
 
17
17
  from fusion_bench import BaseAlgorithm, BaseModelPool
18
- from fusion_bench.mixins import LightningFabricMixin
18
+ from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
19
19
  from fusion_bench.taskpool import CLIPVisionModelTaskPool
20
20
  from fusion_bench.utils.json import load_from_json, save_to_json
21
21
  from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
@@ -24,7 +24,11 @@ if TYPE_CHECKING:
24
24
  from torch.utils.tensorboard import SummaryWriter
25
25
 
26
26
 
27
- class ContinualTaskArithmeticForCLIP(BaseAlgorithm, LightningFabricMixin):
27
+ class ContinualTaskArithmeticForCLIP(
28
+ BaseAlgorithm,
29
+ LightningFabricMixin,
30
+ SimpleProfilerMixin,
31
+ ):
28
32
  def __init__(
29
33
  self,
30
34
  scaling_factor: float,
@@ -79,32 +83,42 @@ class ContinualTaskArithmeticForCLIP(BaseAlgorithm, LightningFabricMixin):
79
83
  for model_idx, model_name in tqdm(
80
84
  enumerate(model_names), desc="Processing models"
81
85
  ):
82
- task_model = modelpool.load_model(model_name)
86
+ with self.profile("loading model"):
87
+ task_model = modelpool.load_model(model_name)
83
88
 
84
- for param_name, param in task_model.named_parameters():
85
- if not param.requires_grad:
86
- continue
89
+ with self.profile("merging model"):
90
+ for param_name, param in task_model.named_parameters():
91
+ if not param.requires_grad:
92
+ continue
87
93
 
88
- task_param = param
89
- merged_param = merged_model.get_parameter(param_name)
90
- pretrained_param = pretrained_model.get_parameter(param_name)
94
+ task_param = param
95
+ merged_param = merged_model.get_parameter(param_name)
96
+ pretrained_param = pretrained_model.get_parameter(param_name)
91
97
 
92
- new_param = merged_param + self.scaling_factor * (
93
- task_param - pretrained_param
94
- )
95
- merged_model.get_parameter(param_name).data = new_param
98
+ new_param = merged_param + self.scaling_factor * (
99
+ task_param - pretrained_param
100
+ )
101
+ merged_model.get_parameter(param_name).data = new_param
96
102
 
97
103
  if self.save_on_every_step:
98
- self.save_merged_model(merged_model, model_idx)
104
+ with self.profile("saving model"):
105
+ self.save_merged_model(merged_model, model_idx)
99
106
 
100
107
  if self.evaluate_on_every_step:
101
- self.taskpool._is_setup = False
102
- self.taskpool._test_datasets = DictConfig(
103
- {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
104
- )
105
- report = self.taskpool.evaluate(deepcopy(merged_model))
106
- save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
107
-
108
+ with self.profile("evaluating model"):
109
+ self.taskpool._is_setup = False
110
+ self.taskpool._test_datasets = DictConfig(
111
+ {
112
+ n: self._test_datasets[n]
113
+ for n in model_names[: model_idx + 1]
114
+ }
115
+ )
116
+ report = self.taskpool.evaluate(deepcopy(merged_model))
117
+ save_to_json(
118
+ report, Path(self.log_dir) / f"report_{model_idx}.json"
119
+ )
120
+
121
+ self.print_profile_summary()
108
122
  return merged_model
109
123
 
110
124
  def save_merged_model(self, merged_model: CLIPVisionModel, step: int):