fusion-bench 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/fer2013.py +1 -0
  6. fusion_bench/dataset/gsm8k.py +2 -2
  7. fusion_bench/method/__init__.py +10 -0
  8. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  9. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  10. fusion_bench/method/gossip/__init__.py +3 -0
  11. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  12. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  13. fusion_bench/method/gossip/entropy_loss.py +25 -0
  14. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  15. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  16. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  17. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  18. fusion_bench/method/gossip/utils.py +74 -0
  19. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  20. fusion_bench/method/opcm/opcm.py +16 -7
  21. fusion_bench/method/pwe_moe/module.py +1 -1
  22. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  23. fusion_bench/method/regmean/regmean.py +25 -17
  24. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  25. fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
  26. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  27. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  28. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  29. fusion_bench/method/we_moe/we_moe.py +14 -15
  30. fusion_bench/mixins/__init__.py +6 -3
  31. fusion_bench/mixins/hydra_config.py +49 -0
  32. fusion_bench/mixins/openclip_classification.py +11 -0
  33. fusion_bench/mixins/simple_profiler.py +4 -2
  34. fusion_bench/modelpool/__init__.py +3 -1
  35. fusion_bench/modelpool/base_pool.py +2 -2
  36. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  37. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  38. fusion_bench/models/open_clip/__init__.py +6 -0
  39. fusion_bench/models/open_clip/modeling.py +176 -0
  40. fusion_bench/models/open_clip/utils.py +311 -0
  41. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  42. fusion_bench/models/parameter_dict.py +54 -13
  43. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  44. fusion_bench/taskpool/__init__.py +5 -3
  45. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  46. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  47. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  48. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  49. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  50. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  51. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  52. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  53. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  54. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  55. fusion_bench/utils/data.py +12 -0
  56. fusion_bench/utils/devices.py +14 -0
  57. fusion_bench/utils/instantiate.py +12 -0
  58. fusion_bench/utils/misc.py +9 -2
  59. fusion_bench/utils/packages.py +14 -0
  60. fusion_bench/utils/parameters.py +1 -1
  61. fusion_bench/utils/tensorboard.py +1 -1
  62. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +1 -1
  63. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +190 -151
  64. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
  65. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  66. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  67. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  68. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  69. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  70. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  71. fusion_bench_config/fabric/auto.yaml +0 -1
  72. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  73. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  74. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  75. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  76. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  77. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  78. fusion_bench_config/llama_full_finetune.yaml +0 -2
  79. fusion_bench_config/llama_model_fusion.yaml +0 -2
  80. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  81. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  82. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  83. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  84. fusion_bench_config/method/adamerging.yaml +2 -2
  85. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  86. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  87. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  88. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  89. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  90. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  91. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  92. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  93. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  94. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  95. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  96. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  97. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  98. fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
  99. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  100. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  101. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  102. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  103. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  104. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  105. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  106. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  107. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  108. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  109. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  110. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  111. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  112. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  113. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  114. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  115. fusion_bench_config/method/model_recombination.yaml +0 -1
  116. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  117. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  118. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  119. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  120. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  121. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  122. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  123. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  124. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  125. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  126. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  127. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  128. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  129. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  130. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  131. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  132. fusion_bench_config/method/ties_merging.yaml +1 -1
  133. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  134. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  135. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  136. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  137. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  138. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  139. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  140. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  141. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  142. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  143. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  144. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  145. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  148. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  149. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  150. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  154. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  155. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  156. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  157. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  158. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  159. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  160. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  161. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  162. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  163. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  164. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  165. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  166. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  167. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  168. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  169. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  170. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  171. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
  172. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
  173. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  174. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  175. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  176. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  177. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  178. fusion_bench_config/nyuv2_config.yaml +0 -2
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  180. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  181. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  182. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  183. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  184. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  185. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  186. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  187. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  188. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
  189. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/licenses/LICENSE +0 -0
  190. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,265 @@
1
+ import copy
2
+ import gc
3
+ import logging
4
+ from abc import abstractmethod
5
+ from typing import List, Mapping, Union # noqa: F401
6
+
7
+ import lightning as L
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ from omegaconf import DictConfig
12
+ from torch import Tensor
13
+ from torch.utils.data import DataLoader
14
+ from tqdm.autonotebook import tqdm
15
+
16
+ from fusion_bench.compat.method import ModelFusionAlgorithm
17
+ from fusion_bench.compat.modelpool import ModelPool
18
+ from fusion_bench.models.wrappers.task_wise_fusion import (
19
+ TaskWiseMergedModel,
20
+ get_task_wise_weights,
21
+ )
22
+
23
+ log = logging.getLogger(__name__)
24
+
25
+
26
# obtain the current GPU memory usage
def print_memory_usage(desc):
    """Print *desc* followed by the GPU memory currently allocated and reserved, in MB."""
    bytes_per_mb = 1024**2
    print(desc)
    # convert bytes to MB
    allocated = torch.cuda.memory_allocated() / bytes_per_mb
    cached = torch.cuda.memory_reserved() / bytes_per_mb
    print(f"Allocated Memory: {allocated:.2f} MB")
    print(f"Cached Memory: {cached:.2f} MB")
33
+
34
+
35
def entropy_loss(logits: Tensor) -> Tensor:
    """
    Compute the entropy loss of a set of logits.

    Args:
        logits (Tensor): The logits to compute the entropy loss of.

    Returns:
        Tensor: Scalar tensor holding the mean Shannon entropy over the batch.
    """
    stabilizer = 1e-8  # avoids log(0) for zero-probability classes
    probs = logits.softmax(dim=-1)
    per_sample_entropy = -(probs * (probs + stabilizer).log()).sum(dim=-1)
    return per_sample_entropy.mean()
47
+
48
+
49
class ModelScheduler:
    """
    Manage the storage of models, schedule the order in which models are loaded to GPU,
    and transfer data between the CPU and GPU.
    """

    def __init__(
        self,
        modelpool: ModelPool,
        config: DictConfig,
    ):
        """
        Args:
            modelpool (ModelPool): Pool providing the ``_pretrained_`` model and
                one fine-tuned model per task.
            config (DictConfig): Method configuration; reads ``weights``,
                ``init_values``, ``clamp_weights``, ``tie_weights`` and ``strict``.
        """
        self.pretrained_model = modelpool.load_model("_pretrained_")
        self.finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]
        self.num_finetuned_models = len(self.finetuned_models)
        # Buffer for the models produced during the current gossip round;
        # promoted to the working set by `update_models`.
        self.new_finetuned_models = copy.deepcopy(self.finetuned_models)
        self.finetuned_model_names = [name for name in modelpool.model_names]

        self.config = config

    @torch.no_grad()  # weight construction only; no gradients needed here
    def __call__(self, model_id):
        """
        Return the merged module and relevant data for one local adamerging step.

        Model ``model_id`` serves as the "pretrained" anchor while its two ring
        neighbors (indices ±1, wrapping around) serve as the fine-tuned models.

        Args:
            model_id (int): Index of the model whose neighborhood is merged.

        Returns:
            TaskWiseMergedModel: Wrapper ready for test-time adaptation.

        Raises:
            NotImplementedError: If ``config.weights`` is set (loading explicit
                task-wise weights is not supported yet).
        """
        # TODO: use a mixing matrix to determine which models to use in step idx

        pretrained_model = copy.deepcopy(self.finetuned_models[model_id])
        finetuned_models = [
            copy.deepcopy(
                self.finetuned_models[(model_id + 1) % self.num_finetuned_models]
            ),
            copy.deepcopy(
                self.finetuned_models[(model_id - 1) % self.num_finetuned_models]
            ),
        ]

        if self.config.weights is None:
            task_wise_weight = get_task_wise_weights(
                num_models=len(finetuned_models),
                init_values=self.config.init_values,
            )
        else:
            # BUGFIX: this branch previously fell through with `pass`, leaving
            # `task_wise_weight` unbound and raising an opaque NameError below.
            raise NotImplementedError(
                "Loading explicit task-wise weights is not supported yet; "
                "set `weights: null` in the method config."
            )

        module = TaskWiseMergedModel(
            task_wise_weight=task_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
        )
        return module

    def store_model(self, new_finetuned_model_dict, model_id):
        """
        Store the new fine-tuned model's state dict after every turn of adamerging.

        Args:
            new_finetuned_model_dict: State dict produced by the merge.
            model_id (int): Slot in the round buffer to load it into.
        """
        self.new_finetuned_models[model_id].load_state_dict(new_finetuned_model_dict)

    def update_models(self):
        """Promote the freshly merged models to be the working set for the next round."""
        self.finetuned_models = copy.deepcopy(self.new_finetuned_models)

    def get_final_models(self):
        """
        Return all fine-tuned models plus their uniform parameter average.

        Note: the average is written into ``self.pretrained_model`` in place.

        Returns:
            list[dict]: ``{"name": ..., "model": ...}`` entries, one per task
            model, followed by an entry named ``"average model"``.
        """
        # need a check
        final_models = [
            {"name": name, "model": model}
            for name, model in zip(self.finetuned_model_names, self.finetuned_models)
        ]
        num_finetuned_models = len(self.finetuned_models)

        # Zero the pretrained parameters, then accumulate the uniform average
        # of all fine-tuned models into them.
        state_dict = self.pretrained_model.state_dict(keep_vars=True)
        for name in state_dict.keys():
            state_dict[name].data.zero_()
        for model in self.finetuned_models:
            for name, param in model.named_parameters():
                state_dict[name] = state_dict[name] + 1 / num_finetuned_models * param

        self.pretrained_model.load_state_dict(state_dict)
        final_models += [{"name": "average model", "model": self.pretrained_model}]

        return final_models
133
+
134
+
135
class TaskWiseGossipAlgorithm(ModelFusionAlgorithm):
    """
    Task-wise AdaMerging combined with gossip averaging: each model repeatedly
    performs a local, test-time-adapted merge with its ring neighbors, and the
    merged models are exchanged between rounds.
    """

    # Shared Lightning Fabric instance; only created when CUDA is available.
    _fabric: L.Fabric = None

    def __init__(self, algorithm_config: DictConfig):
        """
        Args:
            algorithm_config (DictConfig): Method configuration; reads
                ``devices``, ``optimizer``, ``lr``, ``max_steps``,
                ``gossip_max_steps`` and ``fast_dev_run``.
        """
        super().__init__(algorithm_config)

        if self._fabric is None and torch.cuda.is_available():
            self._fabric = L.Fabric(devices=self.config.get("devices", 1))
            self._fabric.launch()

        self.optimizer = None  # we want to reuse it in Gossip using single GPU

    def free_gpu_memory(self, module: TaskWiseMergedModel):
        """Move the finished module back to CPU and release cached GPU memory."""
        module.pretrained_model.to("cpu")
        for model in module.task_vectors:
            model.to("cpu")
        del module
        gc.collect()
        torch.cuda.empty_cache()
        print_memory_usage(
            "finish local adamerging, after freeing memory, the memory usage of GPU is:"
        )

    def run(self, modelpool: ModelPool):
        """
        Execute gossip merging over the model pool.

        Args:
            modelpool (ModelPool): Pool with a ``_pretrained_`` model and one
                fine-tuned model per task.

        Returns:
            list[dict]: Final models as produced by
                ``ModelScheduler.get_final_models``.
        """
        log.info("Fusing models using task-wise adaptive merging with gossip.")
        self.modelpool = modelpool
        self.num_finetuned_models = len(modelpool.model_names)

        model_scheduler = ModelScheduler(self.modelpool, self.config)

        pbar = tqdm(
            range(self.config.gossip_max_steps), "Gossip merging", dynamic_ncols=True
        )
        for step_idx in pbar:
            log.info(f"step: {step_idx}")
            for model_id in tqdm(
                range(self.num_finetuned_models), "local adamerging", dynamic_ncols=True
            ):
                module = model_scheduler(model_id)
                module = self.test_time_adaptation(module)
                print_memory_usage(
                    "local adamerging almost done, the memory usage of GPU is:"
                )
                model_scheduler.store_model(module.merge_weights(), model_id)
                print_memory_usage(
                    "local adamerging almost done, the memory usage of GPU is:"
                )
                # simulate distributed GPU memory usage as much as possible
                self.free_gpu_memory(module)

            model_scheduler.update_models()

        return model_scheduler.get_final_models()

    def on_test_time_adaptation_start(self):
        """Hook called once before test-time adaptation; subclasses may override."""
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """Return an iterator over shuffled, unlabeled test batches for ``task``."""
        pass

    @abstractmethod
    def compute_logits(self, module: nn.Module, batch, task: str) -> Tensor:
        """
        Compute the logits for the given batch and task.

        Args:
            module (nn.Module): The model module.
            batch (tuple): A batch of input data.
            task (str): The name of the task.

        Returns:
            Tensor: The classification logits for the batch.
        """
        pass

    def test_time_adaptation(self, module: TaskWiseMergedModel):
        """
        Adapt the task-wise merging weights by entropy minimization on
        unlabeled test batches.

        Args:
            module (TaskWiseMergedModel): Merged model whose ``merge_weight``
                is optimized.

        Returns:
            TaskWiseMergedModel: The adapted module.

        Raises:
            ValueError: If ``config.optimizer`` is not ``"adam"``.
        """
        self.on_test_time_adaptation_start()

        # configure optimizer
        if self.config.optimizer == "adam":
            self.optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        if self._fabric is not None:
            module, self.optimizer = self._fabric.setup(module, self.optimizer)
            print_memory_usage(
                "load model and optimizer to GPU, the memory usage of GPU is:"
            )
        module.train()
        module.merge_weights()

        if self.config.get("fast_dev_run", False):
            log.info("Running fast_dev_run, only one step")
            pbar = tqdm(
                range(1),
                "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        else:
            pbar = tqdm(
                range(self.config.max_steps),
                "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        for step_idx in pbar:
            for task in self.modelpool.model_names:
                batch = next(self.get_shuffled_test_loader_iter(task))
                logits = self.compute_logits(module, batch, task)
                assert (
                    logits.dim() == 2
                ), f"Expected logits to be 2D, got {logits.dim()}"
                loss = entropy_loss(logits)
                # .backward() accumulates when .zero_grad() wasn't called
                # this can save memory
                if self._fabric is not None:
                    self._fabric.backward(loss, retain_graph=True)
                else:
                    # BUGFIX: `_fabric` is only created when CUDA is available,
                    # so CPU-only runs previously crashed with AttributeError
                    # here; fall back to a plain backward pass.
                    loss.backward(retain_graph=True)

            self.optimizer.step()
            self.optimizer.zero_grad()
            module.merge_weights()

        # Release the optimizer so its state does not pin GPU memory between rounds.
        del self.optimizer
        gc.collect()
        torch.cuda.empty_cache()
        return module
@@ -0,0 +1,74 @@
1
+ import copy
2
+ from collections import OrderedDict
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+
8
def get_memory_usage(desc):
    """
    Obtain the current GPU memory usage.

    Args:
        desc: Description prepended to the report.

    Returns:
        str: A string containing the allocated and cached memory in MB.
    """
    bytes_per_mb = 1024**2  # convert bytes to MB
    allocated = torch.cuda.memory_allocated() / bytes_per_mb
    cached = torch.cuda.memory_reserved() / bytes_per_mb
    return (
        f"{desc}\nAllocated Memory: {allocated:.2f} MB\nCached Memory: {cached:.2f} MB"
    )
20
+
21
+
22
+ # Model conversion utils
23
+
24
+
25
def state_dict_to_vector(state_dict, remove_keys=()):
    """
    Convert a state dictionary to a vector.

    Parameters are concatenated in the order of their *sorted* keys so that the
    layout is reproducible regardless of the dict's insertion order.

    Args:
        state_dict (dict): The state dictionary to convert.
        remove_keys (iterable, optional): Keys to exclude from the vector.
            Defaults to an empty tuple (an immutable default replaces the
            original mutable ``[]`` default).

    Returns:
        torch.Tensor: The converted (flattened) vector.
    """
    # Deep-copy so deleting excluded keys never mutates the caller's dict.
    shared_state_dict = copy.deepcopy(state_dict)
    for key in remove_keys:
        if key in shared_state_dict:
            del shared_state_dict[key]
    sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
    return nn.utils.parameters_to_vector(
        [value.reshape(-1) for value in sorted_shared_state_dict.values()]
    )
44
+
45
+
46
def vector_to_state_dict(vector, state_dict, remove_keys=()):
    """
    Convert a vector to a state dictionary.

    The vector layout must match the one produced by ``state_dict_to_vector``
    (parameters ordered by sorted key, with ``remove_keys`` excluded).

    Args:
        vector (torch.Tensor): The vector to convert.
        state_dict (dict): The reference state dictionary to define the order
            of the vector; it is not modified.
        remove_keys (iterable, optional): Keys to remove from the reference
            state dictionary. Defaults to an empty tuple (an immutable default
            replaces the original mutable ``[]`` default).

    Returns:
        dict: The converted state dictionary (an ``OrderedDict`` with sorted keys).
    """
    # create a reference dict to define the order of the vector
    reference_dict = copy.deepcopy(state_dict)
    for key in remove_keys:
        if key in reference_dict:
            del reference_dict[key]
    sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))

    # create a shared state dict using the reference dict
    nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())

    # add back the encoder and decoder embedding weights.
    # NOTE(review): assumes the removed keys are all tied to the shared
    # embedding, as in T5-style models — confirm against callers.
    if "transformer.shared.weight" in sorted_reference_dict:
        for key in remove_keys:
            sorted_reference_dict[key] = sorted_reference_dict[
                "transformer.shared.weight"
            ]
    return sorted_reference_dict
@@ -3,7 +3,7 @@ This module contains the implementation of the Isotropic Merging in Common Subsp
3
3
  Modified from the original implementation: https://github.com/danielm1405/iso-merging
4
4
 
5
5
  Reference:
6
- - Daniel Marczak, et al. No Task Left Behind: Isotropic Model Merging with Common and Task-Specific Subspaces. 2025.
6
+ - Daniel Marczak, et al. No Task Left Behind: Isotropic Model Merging with Common and Task-Specific Subspaces. 2025.
7
7
  https://arxiv.org/abs/2502.04959
8
8
  """
9
9
 
@@ -126,10 +126,14 @@ class OPCMForCLIP(
126
126
  )
127
127
  self.avg_task_vector_norm = np.mean(self.all_task_vector_norm)
128
128
  self.fabric.log(
129
- "model/task_vector_norm", self.all_task_vector_norm[-1], step=model_idx
129
+ "model/task_vector_norm",
130
+ self.all_task_vector_norm[-1],
131
+ step=model_idx,
130
132
  )
131
133
  self.fabric.log(
132
- "model/avg_task_vector_norm", self.avg_task_vector_norm, step=model_idx
134
+ "model/avg_task_vector_norm",
135
+ self.avg_task_vector_norm,
136
+ step=model_idx,
133
137
  )
134
138
 
135
139
  self.lambda_t = 1 # temporary value
@@ -166,9 +170,9 @@ class OPCMForCLIP(
166
170
  pretrained_W=pretrained_model.get_submodule(
167
171
  module_name
168
172
  ).get_parameter(param_name),
169
- task_W=task_model.get_submodule(module_name).get_parameter(
170
- param_name
171
- ),
173
+ task_W=task_model.get_submodule(
174
+ module_name
175
+ ).get_parameter(param_name),
172
176
  param_name=".".join([module_name, param_name]),
173
177
  accelerator=accelerator,
174
178
  )
@@ -200,10 +204,15 @@ class OPCMForCLIP(
200
204
  with self.profile("evaluating model"):
201
205
  self.taskpool._is_setup = False
202
206
  self.taskpool._test_datasets = DictConfig(
203
- {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
207
+ {
208
+ n: self._test_datasets[n]
209
+ for n in model_names[: model_idx + 1]
210
+ }
204
211
  )
205
212
  report = self.taskpool.evaluate(deepcopy(merged_model))
206
- save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
213
+ save_to_json(
214
+ report, Path(self.log_dir) / f"report_{model_idx}.json"
215
+ )
207
216
 
208
217
  self.print_profile_summary()
209
218
  return merged_model
@@ -1,5 +1,5 @@
1
1
  R"""
2
- this is adapted from
2
+ this is adapted from
3
3
  https://github.com/tanganke/weight-ensembling_MoE/blob/3cbd327cb28c499065f83387472a79829a2e5fee/src/module/dict_moe.py
4
4
  but with some modifications
5
5
  """