fusion-bench 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/fer2013.py +1 -0
  6. fusion_bench/dataset/gsm8k.py +2 -2
  7. fusion_bench/method/__init__.py +10 -0
  8. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  9. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  10. fusion_bench/method/gossip/__init__.py +3 -0
  11. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  12. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  13. fusion_bench/method/gossip/entropy_loss.py +25 -0
  14. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  15. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  16. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  17. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  18. fusion_bench/method/gossip/utils.py +74 -0
  19. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  20. fusion_bench/method/opcm/opcm.py +16 -7
  21. fusion_bench/method/pwe_moe/module.py +1 -1
  22. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  23. fusion_bench/method/regmean/regmean.py +25 -17
  24. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  25. fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
  26. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  27. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  28. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  29. fusion_bench/method/we_moe/we_moe.py +14 -15
  30. fusion_bench/mixins/__init__.py +6 -3
  31. fusion_bench/mixins/hydra_config.py +49 -0
  32. fusion_bench/mixins/openclip_classification.py +11 -0
  33. fusion_bench/mixins/simple_profiler.py +4 -2
  34. fusion_bench/modelpool/__init__.py +3 -1
  35. fusion_bench/modelpool/base_pool.py +2 -2
  36. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  37. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  38. fusion_bench/models/open_clip/__init__.py +6 -0
  39. fusion_bench/models/open_clip/modeling.py +176 -0
  40. fusion_bench/models/open_clip/utils.py +311 -0
  41. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  42. fusion_bench/models/parameter_dict.py +54 -13
  43. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  44. fusion_bench/taskpool/__init__.py +5 -3
  45. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  46. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  47. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  48. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  49. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  50. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  51. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  52. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  53. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  54. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  55. fusion_bench/utils/data.py +12 -0
  56. fusion_bench/utils/devices.py +14 -0
  57. fusion_bench/utils/instantiate.py +12 -0
  58. fusion_bench/utils/misc.py +9 -2
  59. fusion_bench/utils/packages.py +14 -0
  60. fusion_bench/utils/parameters.py +1 -1
  61. fusion_bench/utils/tensorboard.py +1 -1
  62. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +1 -1
  63. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +190 -151
  64. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
  65. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  66. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  67. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  68. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  69. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  70. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  71. fusion_bench_config/fabric/auto.yaml +0 -1
  72. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  73. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  74. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  75. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  76. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  77. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  78. fusion_bench_config/llama_full_finetune.yaml +0 -2
  79. fusion_bench_config/llama_model_fusion.yaml +0 -2
  80. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  81. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  82. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  83. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  84. fusion_bench_config/method/adamerging.yaml +2 -2
  85. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  86. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  87. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  88. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  89. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  90. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  91. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  92. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  93. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  94. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  95. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  96. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  97. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  98. fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
  99. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  100. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  101. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  102. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  103. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  104. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  105. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  106. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  107. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  108. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  109. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  110. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  111. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  112. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  113. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  114. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  115. fusion_bench_config/method/model_recombination.yaml +0 -1
  116. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  117. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  118. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  119. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  120. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  121. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  122. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  123. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  124. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  125. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  126. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  127. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  128. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  129. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  130. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  131. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  132. fusion_bench_config/method/ties_merging.yaml +1 -1
  133. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  134. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  135. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  136. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  137. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  138. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  139. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  140. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  141. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  142. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  143. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  144. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  145. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  148. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  149. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  150. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  154. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  155. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  156. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  157. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  158. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  159. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  160. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  161. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  162. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  163. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  164. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  165. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  166. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  167. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  168. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  169. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  170. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  171. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
  172. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
  173. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  174. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  175. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  176. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  177. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  178. fusion_bench_config/nyuv2_config.yaml +0 -2
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  180. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  181. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  182. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  183. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  184. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  185. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  186. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  187. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  188. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
  189. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/licenses/LICENSE +0 -0
  190. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
fusion_bench/method/gossip/layer_wise_gossip.py
@@ -0,0 +1,434 @@
+ import copy
+ import gc
+ import logging
+ import os
+ from abc import abstractmethod
+ from typing import Any, Callable, List, Mapping, Union, cast  # noqa: F401
+
+ import torch
+ from lightning.fabric.utilities.rank_zero import rank_zero_only
+ from omegaconf import DictConfig
+ from torch import Tensor
+ from torch.utils.data import DataLoader
+ from tqdm.autonotebook import tqdm
+
+ from fusion_bench.compat.method import ModelFusionAlgorithm
+ from fusion_bench.compat.modelpool import ModelPool
+ from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
+ from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+ from fusion_bench.modelpool import (
+     CLIPVisionModelPool,
+     GPT2ForSequenceClassificationPool,
+ )
+ from fusion_bench.models.wrappers.layer_wise_fusion import (
+     LayerWiseMergedModel,
+     get_layer_wise_weights,
+ )
+ from fusion_bench.utils.data import load_tensor_from_file
+
+ from .entropy_loss import entropy_loss
+
+ log = logging.getLogger(__name__)
+
+
+ # obtain the current GPU memory usage
+ def get_memory_usage(desc):
+     allocated = torch.cuda.memory_allocated() / 1024**2  # convert to MB
+     cached = torch.cuda.memory_reserved() / 1024**2  # convert to MB
+     return (
+         f"{desc}\nAllocated Memory: {allocated:.2f} MB\nCached Memory: {cached:.2f} MB"
+     )
+
+
+ class ModelScheduler:
+     """
+     Manages the storage of models, schedules the order in which models are
+     loaded onto the GPU, and transfers data between the CPU and GPU.
+     """
+
+     def __init__(
+         self,
+         config: DictConfig,
+         modelpool: ModelPool,
+     ):
+         self.pretrained_model = modelpool.load_model("_pretrained_")
+         self.finetuned_models = [
+             modelpool.load_model(name) for name in modelpool.model_names
+         ]
+         self.num_finetuned_models = len(self.finetuned_models)
+         self.new_finetuned_models = copy.deepcopy(self.finetuned_models)
+         self.finetuned_models_name = [name for name in modelpool.model_names]
+
+         self.config = config
+
+     @torch.no_grad()  # not sure whether to use this
+     def __call__(self, model_id):
+         """
+         Return the models and relevant data used in each step.
+         """
+         pretrained_model = copy.deepcopy(self.pretrained_model)
+         if self.config.topo == "ring":
+             finetuned_models = [
+                 copy.deepcopy(
+                     self.finetuned_models[(model_id + 1) % self.num_finetuned_models]
+                 ),
+                 copy.deepcopy(self.finetuned_models[model_id]),
+                 copy.deepcopy(
+                     self.finetuned_models[(model_id - 1) % self.num_finetuned_models]
+                 ),
+             ]
+         elif "rotate" in self.config.topo:
+             number = self.config.topo.split("_")[1]
+             finetuned_models = [copy.deepcopy(self.finetuned_models[model_id])]
+             for i in range(0, int(number)):
+                 finetuned_models.append(
+                     copy.deepcopy(
+                         self.finetuned_models[
+                             (model_id + i + 1) % self.num_finetuned_models
+                         ]
+                     )
+                 )
+         else:
+             raise ValueError(f"Unsupported topology: {self.config.topo}")
+         # initialize layer-wise weights using the provided configuration
+         # `init_values`, or load from file if `weights` is provided
+         if self.config.weights is None:
+             layer_wise_weight = get_layer_wise_weights(
+                 num_models=len(finetuned_models),
+                 num_layers=len(
+                     tuple(
+                         filter(lambda p: p.requires_grad, pretrained_model.parameters())
+                     )
+                 ),
+                 init_values=self.config.init_values,
+             )
+         else:
+             if isinstance(self.config.weights, str):
+                 # self.config.weights is a path to a saved tensor
+                 layer_wise_weight = load_tensor_from_file(self.config.weights)
+             else:
+                 raise ValueError(f"Unsupported weights format: {self.config.weights}")
+
+         module = LayerWiseMergedModel(
+             layer_wise_weight=layer_wise_weight,
+             pretrained_model=pretrained_model,
+             finetuned_models=finetuned_models,
+             clamp_weights=self.config.clamp_weights,
+             tie_weights=self.config.tie_weights,
+             strict=self.config.strict,
+         )
+         print(f"{layer_wise_weight.size()=}, {layer_wise_weight.numel()=}")
+         return module
+
+     def store_model(self, new_finetuned_model_dict, model_id):
+         """
+         Store the new fine-tuned model after each round of adamerging.
+         """
+         self.new_finetuned_models[model_id].load_state_dict(new_finetuned_model_dict)
+
+     def update_models(self):
+         self.finetuned_models = copy.deepcopy(self.new_finetuned_models)
+
+     def get_final_models(self, idx=None):
+         # need a check
+         if idx is not None:
+             return copy.deepcopy(self.finetuned_models[idx])
+
+         final_models = [
+             {"name": name, "model": model}
+             for name, model in zip(self.finetuned_models_name, self.finetuned_models)
+         ]
+         num_finetuned_models = len(self.finetuned_models)
+
+         average_model = copy.deepcopy(self.pretrained_model)
+         state_dict = average_model.state_dict(keep_vars=True)
+         for name, _ in self.finetuned_models[0].named_parameters():
+             state_dict[name].data.zero_()
+         for model in self.finetuned_models:
+             for name, param in model.named_parameters():
+                 state_dict[name] = state_dict[name] + 1 / num_finetuned_models * param
+
+         average_model.load_state_dict(state_dict)
+         final_models += [{"name": "average model", "model": average_model}]
+
+         return final_models
+
+     def move_to(self, device):
+         self.pretrained_model.to(device=device)
+         for model in self.finetuned_models:
+             model.to(device=device)
+
+
+ class LayerWiseGossipAlgorithm(
+     ModelFusionAlgorithm,
+     LightningFabricMixin,
+     SimpleProfilerMixin,
+ ):
+     """
+     Implements the Layer-Wise Gossip Algorithm.
+
+     This class merges the layers of a pretrained model with those of several fine-tuned models.
+     The merging is controlled by layer-wise weights, which can be initialized based on a provided configuration or loaded from a file.
+     """
+
+     def __init__(self, algorithm_config: DictConfig):
+         """
+         Initialize the LayerWiseGossipAlgorithm with the given configuration.
+
+         Args:
+             algorithm_config (DictConfig): The configuration for the algorithm.
+         """
+         super().__init__(algorithm_config)
+         self._program = None
+
+     @rank_zero_only
+     def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
+         """
+         Save the merging weights to a file.
+
+         Args:
+             file_path (str): The path to save the merging weights.
+             merging_weights (torch.Tensor): The merging weights to save.
+         """
+         if self.fabric.is_global_zero and self.config.get(
+             "save_merging_weights", False
+         ):
+             if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
+                 # if the file path is not absolute or relative to current working directory, save it in the log directory
+                 save_path = os.path.join(self.log_dir, file_path)
+             else:
+                 save_path = file_path
+             log.info(f"saving merging weights to {save_path}.")
+             if os.path.dirname(save_path):
+                 os.makedirs(os.path.dirname(save_path), exist_ok=True)
+             torch.save(merging_weights.detach().cpu(), save_path)
+
+     def free_gpu_memory(self, module: LayerWiseMergedModel):
+         module.pretrained_model.to("cpu")
+         for model in module.task_vectors:
+             model.to("cpu")
+         del module
+         gc.collect()
+         torch.cuda.empty_cache()
+         log.info(get_memory_usage("after freeing memory, the memory usage of GPU is:"))
+
+     def update_datasets(self, datasets):
+         """
+         For every epoch of local adamerging, only use the datasets corresponding
+         to the models involved in the fusion.
+         """
+         num_datasets = len(datasets)
+         datasets_copy = datasets.copy()
+         if self.config.topo == "ring":
+             for i in range(num_datasets):
+                 datasets[i] = (
+                     datasets_copy[i]
+                     .union(datasets_copy[(i + 1) % num_datasets])
+                     .union(datasets_copy[(i - 1) % num_datasets])
+                 )
+         elif "rotate" in self.config.topo:
+             number = self.config.topo.split("_")[1]
+             for i in range(num_datasets):
+                 datasets[i] = datasets_copy[i]
+                 for j in range(0, int(number)):
+                     datasets[i] = datasets[i].union(
+                         datasets_copy[(i + j + 1) % num_datasets]
+                     )
+         return datasets
+
+     def run(self, modelpool: ModelPool):
+         """
+         Run the Layer-Wise Gossip Algorithm.
+
+         This method constructs the wrapped model and performs test-time adaptation if necessary.
+
+         Args:
+             modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.
+
+         Returns:
+             The list of final models (including their average) produced by gossip merging,
+             or the merged model directly when pre-computed weights are given.
+         """
+         log.info("Fusing models using layer-wise adaptive merging.")
+         self.modelpool = modelpool
+         self.log_hyperparams(self.config)
+         self.num_finetuned_models = len(modelpool.model_names)
+         datasets = [{dataset} for dataset in modelpool.model_names]
+
+         with self.profile("construct the wrapped model"):
+             model_scheduler = ModelScheduler(
+                 modelpool=self.modelpool, config=self.config
+             )
+
+         if self.config.weights is not None:
+             # skip the test-time adaptation; the pre-computed layer-wise weights
+             # are loaded from file inside `ModelScheduler.__call__`
+             module = model_scheduler(0)
+             return module.merge_and_unload()
+         else:
+             for step_idx in tqdm(
+                 range(self.config.gossip_max_steps),
+                 "Gossip merging",
+                 dynamic_ncols=True,
+             ):
+                 datasets = self.update_datasets(datasets)
+                 log.info(f"Gossip merging step: {step_idx}")
+                 for model_id in tqdm(
+                     range(self.num_finetuned_models),
+                     "local adamerging",
+                     dynamic_ncols=True,
+                 ):
+                     if self.config.gossip_skip_adamerging:
+                         # skip adamerging, only merge
+                         with self.profile("construct the local wrapped model"):
+                             module = model_scheduler(model_id)
+                         log.info(
+                             f"skip adamerging, only merge ({modelpool.model_names[model_id]})"
+                         )
+                         model_scheduler.store_model(module.merge_weights(), model_id)
+                         self.free_gpu_memory(module)
+                     else:
+                         with self.profile("construct the local wrapped model"):
+                             module = model_scheduler(model_id)
+
+                         if self.config.improve_dataset:
+                             log.info(
+                                 f"improved datasets; the datasets used in this local merging are {datasets[model_id]}"
+                             )
+                         else:
+                             log.info(
+                                 f"unimproved datasets; the datasets used in this local merging are {modelpool.model_names}"
+                             )
+                         with self.profile("test-time adaptation"):
+                             module = self.test_time_adaptation(
+                                 module, datasets[model_id]
+                             )
+                         model_scheduler.store_model(module.merge_weights(), model_id)
+                         log.info(
+                             get_memory_usage(
+                                 f"after local merging ({modelpool.model_names[model_id]}), the memory usage of GPU is:"
+                             )
+                         )
+                         self.free_gpu_memory(
+                             module
+                         )  # simulate distributed GPU memory usage as much as possible
+
+                 model_scheduler.update_models()
+
+                 if "rotate" in self.config.topo:
+                     number = self.config.topo.split("_")[1]
+                     if int(number) == 1 and step_idx >= 20:
+                         self._program.evaluate_merged_model(
+                             self._program.taskpool, model_scheduler.get_final_models()
+                         )
+                         model_scheduler.move_to("cpu")
+                 else:
+                     if (
+                         self.config.accuracy_test_interval != 0
+                         and (step_idx + 1) % self.config.accuracy_test_interval == 0
+                     ):
+                         self._program.evaluate_merged_model(
+                             self._program.taskpool, model_scheduler.get_final_models()
+                         )
+                         model_scheduler.move_to("cpu")
+             return model_scheduler.get_final_models()
+
+     def on_test_time_adaptation_start(self):
+         """
+         Things to do before test-time adaptation starts, such as setting up the task-specific heads.
+         """
+         pass
+
+     @abstractmethod
+     def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
+         """
+         Loader of the test dataset for test-time adaptation. Labels are not needed.
+
+         Args:
+             task (str): The name of the task.
+
+         Returns:
+             DataLoader: The data loader for the test dataset.
+         """
+         pass
+
+     @abstractmethod
+     def compute_logits(self, module, images: Tensor, task: str) -> Tensor:
+         """
+         Compute the logits for the given images and task.
+
+         Args:
+             module: The model module.
+             images (Tensor): The input images.
+             task (str): The name of the task.
+
+         Returns:
+             Tensor: The computed logits.
+         """
+         pass
+
+     def test_time_adaptation(self, module: LayerWiseMergedModel, datasets):
+         """
+         Perform test-time adaptation on the merged model.
+
+         This method adapts the merging weights during test-time to improve performance.
+
+         Args:
+             module (LayerWiseMergedModel): The merged model.
+             datasets: The set of dataset names used in this round of local merging.
+
+         Returns:
+             LayerWiseMergedModel: The adapted merged model.
+         """
+         self.on_test_time_adaptation_start()
+
+         # configure optimizer
+         if self.config.optimizer == "adam":
+             optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
+             print(f"{optimizer=}")
+             module, optimizer = self.fabric.setup(module, optimizer)
+             log.info(
+                 get_memory_usage(
+                     "after loading models and optimizer, the memory usage of GPU is:"
+                 )
+             )
+         else:
+             raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")
+
+         module.train()
+         module.merge_weights()
+         for step_idx in (
+             pbar := tqdm(
+                 range(self.config.max_steps if not self.is_debug_mode else 1),
+                 ("[DEBUG MODE] " if self.is_debug_mode else "")
+                 + "AdaMerging Test-time adaptation",
+                 dynamic_ncols=True,
+             )
+         ):
+             # default behavior for first-order optimizers
+             for task in self.modelpool.model_names:
+                 if self.config.improve_dataset and task not in datasets:
+                     continue
+                 with self.profile("data loading"):
+                     batch = next(self.get_shuffled_test_loader_iter(task))
+                 with self.profile("forward pass"):
+                     if isinstance(self.modelpool, GPT2ForSequenceClassificationPool):
+                         logits = self.compute_logits(module, batch, task)
+                     elif isinstance(self.modelpool, CLIPVisionModelPool):
+                         logits = self.compute_logits(module, batch[0], task)
+                     loss = entropy_loss(logits)
+                 with self.profile("backward pass"):
+                     self.fabric.backward(loss, retain_graph=True)
+
+             with self.profile("optimizer step"):
+                 optimizer.step()
+                 optimizer.zero_grad()
+             with self.profile("merging weights"):
+                 module.merge_weights()
+
+             metrics = {
+                 "train/loss": loss.item(),
+                 "train/weight_max": module.merge_weight.max().item(),
+                 "train/weight_min": module.merge_weight.min().item(),
+                 "train/weight_mean": module.merge_weight.mean().item(),
+             }
+             self.fabric.log_dict(metrics, step=step_idx)
+             pbar.set_postfix(metrics)
+
+         self.print_profile_summary()
+         del optimizer
+         gc.collect()
+         torch.cuda.empty_cache()
+         return module
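
The scheduler above picks merge partners from a communication topology encoded as a string, either `ring` or `rotate_<k>`. As a minimal standalone sketch of the implied neighbor selection (the `gossip_neighbors` helper below is illustrative, not part of the package):

```python
# Illustrative sketch of ModelScheduler's topology logic; not fusion_bench code.
from typing import List

def gossip_neighbors(model_id: int, num_models: int, topo: str) -> List[int]:
    if topo == "ring":
        # a model merges with itself and its two ring neighbors (modulo wrap-around)
        return [(model_id + 1) % num_models, model_id, (model_id - 1) % num_models]
    if topo.startswith("rotate"):
        k = int(topo.split("_")[1])  # e.g. "rotate_2" -> the next 2 models clockwise
        return [model_id] + [(model_id + i + 1) % num_models for i in range(k)]
    raise ValueError(f"Unsupported topology: {topo}")

assert gossip_neighbors(0, 8, "ring") == [1, 0, 7]
assert gossip_neighbors(7, 8, "rotate_2") == [7, 0, 1]
```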
fusion_bench/method/gossip/min_norm_solvers.py
@@ -0,0 +1,227 @@
+ # This code is from
+ # Multi-Task Learning as Multi-Objective Optimization
+ # Ozan Sener, Vladlen Koltun
+ # Neural Information Processing Systems (NeurIPS) 2018
+ # https://github.com/intel-isl/MultiObjectiveOptimization
+ from typing import Union
+
+ import numpy as np
+ import torch
+
+
+ def np_sum(x: Union[torch.Tensor, np.ndarray]) -> float:
+     if isinstance(x, torch.Tensor):
+         return x.sum().item()
+     return np.sum(x)
+
+
+ def to_numpy(x: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
+     if isinstance(x, torch.Tensor):
+         return x.detach().cpu().numpy()
+     return x
+
+
+ class MinNormSolver:
+     MAX_ITER = 250
+     STOP_CRIT = 1e-5
+
+     def _min_norm_element_from2(v1v1, v1v2, v2v2):
+         """
+         Analytical solution for min_{c} |cx_1 + (1-c)x_2|_2^2
+         d is the distance (objective) optimized
+         v1v1 = <x1,x1>
+         v1v2 = <x1,x2>
+         v2v2 = <x2,x2>
+         """
+         if v1v2 >= v1v1:
+             # Case: Fig 1, third column
+             gamma = 0.999
+             cost = v1v1
+             return gamma, cost
+         if v1v2 >= v2v2:
+             # Case: Fig 1, first column
+             gamma = 0.001
+             cost = v2v2
+             return gamma, cost
+         # Case: Fig 1, second column
+         gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2 * v1v2))
+         cost = v2v2 + gamma * (v1v2 - v2v2)
+         return gamma, cost
+
+     def _min_norm_2d(vecs, dps):
+         R"""
+         Find the minimum norm solution as a combination of two points.
+         This is correct only in 2D,
+         i.e. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1, 1 >= c_i >= 0 for all i, c_i + c_j = 1.0 for some i, j
+         """
+         dmin = 1e8
+         for i in range(len(vecs)):
+             for j in range(i + 1, len(vecs)):
+                 if (i, j) not in dps:
+                     dps[(i, j)] = 0.0
+                     for k in range(len(vecs[i])):
+                         dps[(i, j)] += (
+                             torch.mul(vecs[i][k], vecs[j][k]).sum().data.cpu()
+                         )
+                     dps[(j, i)] = dps[(i, j)]
+                 if (i, i) not in dps:
+                     dps[(i, i)] = 0.0
+                     for k in range(len(vecs[i])):
+                         dps[(i, i)] += (
+                             torch.mul(vecs[i][k], vecs[i][k]).sum().data.cpu()
+                         )
+                 if (j, j) not in dps:
+                     dps[(j, j)] = 0.0
+                     for k in range(len(vecs[i])):
+                         dps[(j, j)] += (
+                             torch.mul(vecs[j][k], vecs[j][k]).sum().data.cpu()
+                         )
+                 c, d = MinNormSolver._min_norm_element_from2(
+                     dps[(i, i)], dps[(i, j)], dps[(j, j)]
+                 )
+                 if d < dmin:
+                     dmin = d
+                     sol = [(i, j), c, d]
+         return sol, dps
+
+     def _projection2simplex(y):
+         R"""
+         Given y, it solves argmin_z |y-z|_2 st \sum z = 1, 1 >= z_i >= 0 for all i
+         """
+         m = len(y)
+         sorted_y = np.flip(np.sort(y), axis=0)
+         tmpsum = 0.0
+         tmax_f = (np.sum(y) - 1.0) / m
+         for i in range(m - 1):
+             tmpsum += sorted_y[i]
+             tmax = (tmpsum - 1) / (i + 1.0)
+             if tmax > sorted_y[i + 1]:
+                 tmax_f = tmax
+                 break
+         return np.maximum(y - tmax_f, np.zeros(y.shape))
+
+     def _next_point(cur_val, grad, n):
+         proj_grad = grad - (np.sum(grad) / n)
+         tm1 = -1.0 * cur_val[proj_grad < 0] / proj_grad[proj_grad < 0]
+         tm2 = (1.0 - cur_val[proj_grad > 0]) / (proj_grad[proj_grad > 0])
+
+         skippers = np_sum(tm1 < 1e-7) + np_sum(tm2 < 1e-7)
+         t = 1
+         if len(tm1[tm1 > 1e-7]) > 0:
+             t = np.min(to_numpy(tm1[tm1 > 1e-7]))
+         if len(tm2[tm2 > 1e-7]) > 0:
+             t = min(t, np.min(to_numpy(tm2[tm2 > 1e-7])))
+
+         next_point = proj_grad * t + to_numpy(cur_val)
+         next_point = MinNormSolver._projection2simplex(next_point)
+         return next_point
+
+     def find_min_norm_element(vecs):
+         R"""
+         Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
+         as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
+         It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
+         Hence, we find the best 2-task solution, and then run the projected gradient descent until convergence
+         """
+         # Solution lying at the combination of two points
+         dps = {}
+         init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
+
+         n = len(vecs)
+         sol_vec = np.zeros(n)
+         sol_vec[init_sol[0][0]] = init_sol[1]
+         sol_vec[init_sol[0][1]] = 1 - init_sol[1]
+
+         if n < 3:
+             # This is optimal for n=2, so return the solution
+             return sol_vec, init_sol[2]
+
+         iter_count = 0
+
+         grad_mat = np.zeros((n, n))
+         for i in range(n):
+             for j in range(n):
+                 grad_mat[i, j] = dps[(i, j)]
+
+         while iter_count < MinNormSolver.MAX_ITER:
+             grad_dir = -1.0 * np.dot(grad_mat, sol_vec)
+             new_point = MinNormSolver._next_point(sol_vec, grad_dir, n)
+             # Re-compute the inner products for line search
+             v1v1 = 0.0
+             v1v2 = 0.0
+             v2v2 = 0.0
+             for i in range(n):
+                 for j in range(n):
+                     v1v1 += sol_vec[i] * sol_vec[j] * dps[(i, j)]
+                     v1v2 += sol_vec[i] * new_point[j] * dps[(i, j)]
+                     v2v2 += new_point[i] * new_point[j] * dps[(i, j)]
+             nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
+             new_sol_vec = nc * sol_vec + (1 - nc) * new_point
+             change = new_sol_vec - sol_vec
+             if np_sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
+                 return sol_vec, nd
+             sol_vec = new_sol_vec
+             iter_count += 1  # guard against a non-terminating loop
+         return sol_vec, nd
+
+     def find_min_norm_element_FW(vecs):
+         R"""
+         Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
+         as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
+         It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
+         Hence, we find the best 2-task solution, and then run Frank-Wolfe until convergence
+         """
+         # Solution lying at the combination of two points
+         dps = {}
+         init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
+
+         n = len(vecs)
+         sol_vec = np.zeros(n)
+         sol_vec[init_sol[0][0]] = init_sol[1]
+         sol_vec[init_sol[0][1]] = 1 - init_sol[1]
+
+         if n < 3:
+             # This is optimal for n=2, so return the solution
+             return sol_vec, init_sol[2]
+
+         iter_count = 0
+
+         grad_mat = np.zeros((n, n))
+         for i in range(n):
+             for j in range(n):
+                 grad_mat[i, j] = dps[(i, j)]
+
+         while iter_count < MinNormSolver.MAX_ITER:
+             t_iter = np.argmin(np.dot(grad_mat, sol_vec))
+
+             v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec))
+             v1v2 = np.dot(sol_vec, grad_mat[:, t_iter])
+             v2v2 = grad_mat[t_iter, t_iter]
+
+             nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
+             new_sol_vec = nc * sol_vec
+             new_sol_vec[t_iter] += 1 - nc
+
+             change = new_sol_vec - sol_vec
+             if np_sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
+                 return sol_vec, nd
+             sol_vec = new_sol_vec
+             iter_count += 1  # guard against a non-terminating loop
+         return sol_vec, nd
+
+
+ def gradient_normalizers(grads, losses, normalization_type):
+     gn = {}
+     if normalization_type == "l2":
+         for t in grads:
+             gn[t] = np.sqrt(np.sum([gr.pow(2).sum().data.cpu() for gr in grads[t]]))
+     elif normalization_type == "loss":
+         for t in grads:
+             gn[t] = losses[t]
+     elif normalization_type == "loss+":
+         for t in grads:
+             gn[t] = losses[t] * np.sqrt(
+                 np.sum([gr.pow(2).sum().data.cpu() for gr in grads[t]])
+             )
+     elif normalization_type == "none":
+         for t in grads:
+             gn[t] = 1.0
+     else:
+         print("ERROR: Invalid Normalization Type")
+     return gn
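
For context, `find_min_norm_element` returns simplex weights c minimizing |Σ c_i vecs[i]|_2 over the convex hull of the task gradients. A hedged usage sketch with made-up two-task gradients (the expected result follows from the closed-form two-vector case above):

```python
# Illustrative call; the gradients are toy inputs, not fusion_bench data.
import torch

from fusion_bench.method.gossip.min_norm_solvers import MinNormSolver

vecs = [
    [torch.tensor([1.0, 0.0])],  # gradients of task 0, one parameter block
    [torch.tensor([0.0, 1.0])],  # gradients of task 1
]
sol, cost = MinNormSolver.find_min_norm_element(vecs)
# For orthogonal unit gradients the minimum-norm convex combination is the
# midpoint, so sol is approximately [0.5, 0.5] and cost is 0.5.
print(sol, cost)
```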