fusion-bench 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. fusion_bench/__init__.py +22 -2
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +6 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/constants/runtime.py +57 -0
  9. fusion_bench/dataset/clip_dataset.py +2 -1
  10. fusion_bench/dataset/gpt2_glue.py +9 -9
  11. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  12. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  13. fusion_bench/dataset/image_dataset.py +1 -1
  14. fusion_bench/dataset/nyuv2.py +2 -2
  15. fusion_bench/method/__init__.py +24 -5
  16. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  17. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  19. fusion_bench/method/base_algorithm.py +195 -12
  20. fusion_bench/method/bitdelta/__init__.py +5 -0
  21. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  25. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  26. fusion_bench/method/classification/clip_finetune.py +1 -1
  27. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  28. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  29. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  30. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  31. fusion_bench/method/ensemble.py +12 -12
  32. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  33. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -6
  34. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  35. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  36. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  37. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  38. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  39. fusion_bench/method/linear/expo.py +2 -1
  40. fusion_bench/method/linear/linear_interpolation.py +6 -4
  41. fusion_bench/method/linear/simple_average_for_llama.py +17 -13
  42. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  43. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  44. fusion_bench/method/model_recombination.py +2 -5
  45. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  46. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  47. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  48. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  49. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  50. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  51. fusion_bench/method/randes/modelsoup.py +1 -3
  52. fusion_bench/method/regmean/clip_regmean.py +2 -2
  53. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  54. fusion_bench/method/regmean/regmean.py +2 -11
  55. fusion_bench/method/regmean_plusplus/__init__.py +1 -1
  56. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
  57. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
  58. fusion_bench/method/simple_average.py +12 -16
  59. fusion_bench/method/slerp/slerp.py +5 -2
  60. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  61. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  62. fusion_bench/method/smile_upscaling/projected_energy.py +144 -0
  63. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  64. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +71 -51
  65. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  66. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  67. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  68. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  69. fusion_bench/method/we_moe/__init__.py +1 -0
  70. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  71. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  72. fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
  73. fusion_bench/method/we_moe/utils.py +15 -0
  74. fusion_bench/method/we_moe/we_moe.py +6 -6
  75. fusion_bench/method/weighted_average/llama.py +4 -16
  76. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  77. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  78. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  79. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  80. fusion_bench/mixins/__init__.py +10 -2
  81. fusion_bench/mixins/clip_classification.py +15 -45
  82. fusion_bench/mixins/hydra_config.py +105 -7
  83. fusion_bench/mixins/lightning_fabric.py +2 -0
  84. fusion_bench/mixins/serialization.py +275 -48
  85. fusion_bench/modelpool/__init__.py +2 -2
  86. fusion_bench/modelpool/base_pool.py +29 -9
  87. fusion_bench/modelpool/causal_lm/causal_lm.py +41 -33
  88. fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
  89. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  90. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  91. fusion_bench/models/__init__.py +7 -1
  92. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  93. fusion_bench/models/hf_utils.py +160 -0
  94. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  95. fusion_bench/models/linearized/vision_model.py +1 -1
  96. fusion_bench/models/model_card_templates/default.md +46 -0
  97. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  98. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  99. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  100. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  101. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  102. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  103. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  104. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  105. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  106. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +698 -0
  107. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  108. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  109. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  110. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +7 -12
  111. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  112. fusion_bench/models/parameter_dict.py +1 -1
  113. fusion_bench/models/sparse_we_moe.py +1 -53
  114. fusion_bench/models/utils.py +26 -0
  115. fusion_bench/models/we_moe.py +1 -53
  116. fusion_bench/models/wrappers/ensemble.py +6 -4
  117. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  118. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  119. fusion_bench/programs/base_program.py +81 -2
  120. fusion_bench/programs/fabric_fusion_program.py +46 -61
  121. fusion_bench/scripts/cli.py +38 -5
  122. fusion_bench/taskpool/base_pool.py +4 -3
  123. fusion_bench/taskpool/clip_vision/taskpool.py +43 -22
  124. fusion_bench/taskpool/dummy.py +1 -1
  125. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  126. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  127. fusion_bench/utils/__init__.py +7 -1
  128. fusion_bench/utils/cache_utils.py +101 -1
  129. fusion_bench/utils/devices.py +14 -4
  130. fusion_bench/utils/fabric.py +2 -2
  131. fusion_bench/utils/instantiate_utils.py +3 -1
  132. fusion_bench/utils/lazy_imports.py +23 -0
  133. fusion_bench/utils/lazy_state_dict.py +38 -3
  134. fusion_bench/utils/modelscope.py +127 -8
  135. fusion_bench/utils/parameters.py +2 -2
  136. fusion_bench/utils/path.py +56 -0
  137. fusion_bench/utils/pylogger.py +1 -1
  138. fusion_bench/utils/rich_utils.py +3 -0
  139. fusion_bench/utils/state_dict_arithmetic.py +25 -23
  140. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +24 -47
  141. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +184 -145
  142. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  143. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  144. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  145. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  146. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  147. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  148. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  149. fusion_bench_config/hydra/default.yaml +6 -2
  150. fusion_bench_config/llama_full_finetune.yaml +1 -0
  151. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  152. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  153. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  154. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  155. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  156. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  157. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  158. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  159. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +2 -1
  160. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
  167. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +3 -3
  168. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  169. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  170. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  171. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  172. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  173. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  174. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  175. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  176. fusion_bench_config/nyuv2_config.yaml +3 -1
  177. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  178. fusion_bench_config/path/default.yaml +28 -0
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  180. fusion_bench_config/method/adamerging.yaml +0 -23
  181. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  182. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  183. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  184. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
  185. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
  186. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
  187. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
  188. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/method/we_moe/flan_t5_we_moe.py (new file)
@@ -0,0 +1,331 @@
+ import functools
+ import logging
+ import os
+ from copy import deepcopy
+ from typing import Any, Dict, List, Mapping, Optional, Union, cast  # noqa: F401
+
+ import lightning
+ import lightning as L
+ import lightning.fabric.wrappers
+ import torch
+ from torch import Tensor
+ from torch.utils.data import DataLoader
+ from tqdm.autonotebook import tqdm
+ from transformers import T5ForConditionalGeneration
+ from transformers.data import default_data_collator
+
+ from fusion_bench.method import BaseAlgorithm
+ from fusion_bench.method.task_arithmetic.task_arithmetic import task_arithmetic_merge
+ from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
+ from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+ from fusion_bench.modelpool import Seq2SeqLMPool
+ from fusion_bench.models.we_moe import WeightEnsemblingMoE
+ from fusion_bench.utils import timeit_context
+ from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
+ from fusion_bench.utils.instantiate_utils import instantiate
+ from fusion_bench.utils.parameters import print_parameters
+
+ from .entropy_loss import entropy_loss
+ from .utils import get_memory_usage
+
+ log = logging.getLogger(__name__)
+
+
+ class FlanT5WeightEnsemblingMoEAlgorithm(
+     BaseAlgorithm,
+     LightningFabricMixin,
+     SimpleProfilerMixin,
+ ):
+     """
+     FlanT5WeightEnsemblingMoEAlgorithm implements the weight-ensembling MoE fusion
+     algorithm for Flan-T5 models, composing BaseAlgorithm with the
+     LightningFabricMixin and SimpleProfilerMixin mixins.
+
+     Attributes:
+         modelpool (Seq2SeqLMPool): The model pool containing the FlanT5 models.
+     """
+
+     modelpool: Seq2SeqLMPool = None
+
+     def __init__(
+         self,
+         checkpoint: bool = False,
+         save_checkpoint: bool = False,
+         router_hidden_layers: int = 2,
+         init_lambda: float = 0.3,
+         batch_reduce: bool = True,
+         lr: float = 1e-4,
+         optimizer: str = "adam",
+         devices: int = 1,
+         batch_size: int = 16,
+         num_workers: int = 0,
+         max_steps: int = 1000,
+         use_grad_accumulate: bool = True,
+         cache_dir: str = "outputs",
+         fast_dev_run: bool = False,
+         **kwargs,
+     ):
+         """
+         Initialize the algorithm with the given hyperparameters.
+         """
+         self.checkpoint = checkpoint
+         self.save_checkpoint = save_checkpoint
+         self.router_hidden_layers = router_hidden_layers
+         self.init_lambda = init_lambda
+         self.batch_reduce = batch_reduce
+         self.lr = lr
+         self.optimizer = optimizer
+         self.devices = devices
+         self.batch_size = batch_size
+         self.num_workers = num_workers
+         self.max_steps = max_steps
+         self.use_grad_accumulate = use_grad_accumulate
+         self.cache_dir = cache_dir
+         self.fast_dev_run = fast_dev_run
+         super().__init__(**kwargs)
+
+     def construct_moe_model(self) -> WeightEnsemblingMoE:
+         """
+         Construct the Mixture of Experts (MoE) model using the models in the model pool.
+
+         Returns:
+             WeightEnsemblingMoE: The constructed MoE model.
+         """
+         base_model = self.modelpool.load_model("_pretrained_")
+         expert_models = [
+             self.modelpool.load_model(name) for name in self.modelpool.model_names
+         ]
+
+         # Merge the models using task arithmetic
+         moe_model = task_arithmetic_merge(
+             # This function modifies the model in place, so we need to pass a deepcopy
+             deepcopy(base_model),
+             expert_models,
+             scaling_factor=self.init_lambda,
+         ).requires_grad_(False)
+
+         print(base_model)
+
+         # Up-scale MLP modules
+         num_layer = 12
+         encoder_mlp_index = 1
+         base_encoder = base_model.encoder
+         moe_encoder = moe_model.encoder
+         expert_encoders = [m.encoder for m in expert_models]
+
+         for layer_idx in range(num_layer):
+             base_mlp = (
+                 base_encoder.block[layer_idx].layer[encoder_mlp_index].DenseReluDense
+             )
+             expert_mlps = [
+                 e.block[layer_idx].layer[encoder_mlp_index].DenseReluDense
+                 for e in expert_encoders
+             ]
+
+             moe_encoder.block[layer_idx].layer[encoder_mlp_index].DenseReluDense = (
+                 WeightEnsemblingMoE(
+                     hidden_size=base_encoder.config.hidden_size,
+                     base_model=base_mlp,
+                     expert_models=expert_mlps,
+                     init_lambda=self.init_lambda,
+                     batch_first=True,
+                     router_hidden_layers=self.router_hidden_layers,
+                     batch_reduce=self.batch_reduce,
+                 )
+             )
+
+         decoder_mlp_index = 2
+         base_decoder = base_model.decoder
+         moe_decoder = moe_model.decoder
+         expert_decoders = [m.decoder for m in expert_models]
+
+         for layer_idx in range(num_layer):
+             base_mlp = (
+                 base_decoder.block[layer_idx].layer[decoder_mlp_index].DenseReluDense
+             )
+             expert_mlps = [
+                 e.block[layer_idx].layer[decoder_mlp_index].DenseReluDense
+                 for e in expert_decoders
+             ]
+
+             moe_decoder.block[layer_idx].layer[decoder_mlp_index].DenseReluDense = (
+                 WeightEnsemblingMoE(
+                     hidden_size=base_decoder.config.hidden_size,
+                     base_model=base_mlp,
+                     expert_models=expert_mlps,
+                     init_lambda=self.init_lambda,
+                     batch_first=True,
+                     router_hidden_layers=self.router_hidden_layers,
+                     batch_reduce=self.batch_reduce,
+                 )
+             )
+
+         print(moe_model)
+         return moe_model
+
+     @functools.cache
+     def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
+         """
+         Loader of the test dataset for test-time adaptation. Labels are not needed.
+
+         Args:
+             task (str): The name of the task.
+
+         Returns:
+             DataLoader: The data loader for the test dataset.
+         """
+         dataset = self.modelpool.load_test_dataset(task)
+         log.info("get_shuffled_test_loader_iter")
+         loader = DataLoader(
+             dataset,
+             batch_size=self.batch_size,
+             shuffle=True,
+             num_workers=self.num_workers,
+             collate_fn=default_data_collator,
+         )
+         if self.fabric is not None:
+             loader = self.fabric.setup_dataloaders(loader)
+         return iter(InfiniteDataLoader(loader))
+
+     def compute_logits(
+         self,
+         module: T5ForConditionalGeneration,
+         batch,
+         task: str,
+     ) -> Tensor:
+         """
+         Compute the logits for the given batch and task.
+
+         Args:
+             module: The model module.
+             batch: The input batch.
+             task (str): The name of the task.
+
+         Returns:
+             Tensor: The computed logits.
+         """
+         input_ids: Tensor = batch["input_ids"]
+         attention_mask: Tensor = batch["attention_mask"]
+
+         # remove padding tokens from the input
+         while attention_mask[:, -1].eq(0).all():
+             input_ids = input_ids[:, :-1]
+             attention_mask = attention_mask[:, :-1]
+
+         outputs = module(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             decoder_input_ids=torch.ones(
+                 input_ids.size(0), 1, dtype=torch.long, device=input_ids.device
+             ),
+         )
+         logits = outputs.logits[:, 0, :]
+         return logits
+
+     def test_time_adaptation(self, module):
+         """
+         Perform test-time adaptation for the given module.
+
+         Args:
+             module (WeightEnsemblingMoE): The MoE module to adapt.
+
+         Returns:
+             WeightEnsemblingMoE: The adapted MoE module.
+         """
+         self.on_test_time_adaptation_start()
+
+         # configure optimizer
+         if self.optimizer == "adam":
+             print([name for name, p in module.named_parameters() if p.requires_grad])
+             optimizer = torch.optim.Adam(
+                 [p for p in module.parameters() if p.requires_grad], lr=self.lr
+             )
+         else:
+             raise ValueError(f"Unsupported optimizer: {self.optimizer}")
+
+         module, optimizer = self.fabric.setup(module, optimizer)
+
+         module.train()
+         for step_idx in (
+             pbar := tqdm(
+                 range(self.max_steps if not self.is_debug_mode else 1),
+                 ("[DEBUG MODE] " if self.is_debug_mode else "")
+                 + "WEMoE Test-time adaptation",
+                 dynamic_ncols=True,
+             )
+         ):
+             total_loss = 0
+             for task in self.modelpool.model_names:
+                 with self.profile("data loading"):
+                     batch = next(self.get_shuffled_test_loader_iter(task))
+                 with self.profile("forward pass"):
+                     logits = self.compute_logits(module, batch, task)
+                     logits = logits.mean(dim=0, keepdim=True)
+                     loss = entropy_loss(logits)
+                     total_loss += loss
+                 with self.profile("backward pass"):
+                     self.fabric.backward(loss, retain_graph=True)
+
+             with self.profile("optimizer step"):
+                 optimizer.step()
+                 optimizer.zero_grad()
+
+             metrics = {
+                 "train/loss": total_loss.item(),
+             }
+             self.fabric.log_dict(metrics, step=step_idx)
+             pbar.set_postfix(metrics)
+
+         log.info(get_memory_usage("after test-time adaptation, the GPU memory usage is:"))
+         self.print_profile_summary()
+         return module
+
+     def on_test_time_adaptation_start(self):
+         """
+         Something to do before the test-time adaptation starts, such as setting up the task-specific heads.
+         """
+         pass
+
+     def run(self, modelpool: Seq2SeqLMPool, **kwargs):
+         """
+         Run the algorithm to fuse models using Weight Ensembling Mixture of Experts.
+
+         Args:
+             modelpool (Seq2SeqLMPool): The pool of models to be fused.
+
+         Returns:
+             WeightEnsemblingMoE: The fused MoE model.
+         """
+         log.info("Fusing models using layer-wise adaptive merging.")
+         self.modelpool = modelpool
+
+         with timeit_context("upscaling models to a weight-ensembling MoE model"):
+             moe_model = self.construct_moe_model()
+             print_parameters(moe_model)
+
+         if self.checkpoint != False:
+             log.info(
+                 f"load checkpoint from {self.checkpoint}, test-time adaptation will be skipped."
+             )
+             self.load_checkpoint(moe_model, self.checkpoint)
+         else:
+             with self.profile("test-time adaptation"):
+                 moe_model = self.test_time_adaptation(moe_model)
+             if self.save_checkpoint != False:
+                 log.info(f"save checkpoint to {self.save_checkpoint}")
+                 self.save_checkpoint(moe_model, self.save_checkpoint)
+
+         if lightning.fabric.wrappers.is_wrapped(moe_model):
+             moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)
+
+         # enable sample-wise adaptation
+         moe_model.batch_reduce = False
+         self.print_profile_summary()
+         return moe_model
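
The new `FlanT5WeightEnsemblingMoEAlgorithm` follows the usual FusionBench pattern: build a task-arithmetic initialization, upscale the encoder and decoder MLPs into `WeightEnsemblingMoE` modules, then adapt the routers by entropy minimization on unlabeled test batches. A minimal usage sketch follows; the model pool contents are hypothetical placeholders, only the constructor arguments mirror the file above:

# Hypothetical driver; `pool` construction is omitted.
from fusion_bench.method.we_moe.flan_t5_we_moe import (
    FlanT5WeightEnsemblingMoEAlgorithm,
)

algorithm = FlanT5WeightEnsemblingMoEAlgorithm(
    init_lambda=0.3,         # scaling factor of the task-arithmetic initialization
    router_hidden_layers=2,  # depth of each router MLP
    lr=1e-4,
    max_steps=1000,
)
# `pool` would be a Seq2SeqLMPool holding a "_pretrained_" Flan-T5 base model
# plus one fine-tuned expert per task.
# merged = algorithm.run(pool)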
fusion_bench/method/we_moe/utils.py (new file)
@@ -0,0 +1,15 @@
+ import torch
+
+
+ def get_memory_usage(desc):
+     """
+     Obtain the current GPU memory usage.
+
+     Returns:
+         str: A string containing the allocated and cached memory in MB.
+     """
+     allocated = torch.cuda.memory_allocated() / 1024**2  # convert to MB
+     cached = torch.cuda.memory_reserved() / 1024**2  # convert to MB
+     return (
+         f"{desc}\nAllocated Memory: {allocated:.2f} MB\nCached Memory: {cached:.2f} MB"
+     )
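
A quick illustration of the helper (requires a CUDA build of PyTorch; the allocation is only for demonstration):

import torch

from fusion_bench.method.we_moe.utils import get_memory_usage

if torch.cuda.is_available():
    x = torch.zeros(1024, 1024, device="cuda")  # roughly 4 MB of float32
    print(get_memory_usage("after allocating x:"))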
fusion_bench/method/we_moe/we_moe.py
@@ -1,6 +1,6 @@
  import logging
  from abc import abstractmethod
- from typing import cast  # noqa: F401
+ from typing import Any, cast  # noqa: F401

  import lightning as L
  import lightning.fabric.wrappers
@@ -70,7 +70,7 @@ class WeightEnsemblingMoEAlgorithm(
          assert "No CUDA device available."

      @abstractmethod
-     def load_checkpoint(self, model, checkpoint):
+     def load_checkpoint(self, model: Any, checkpoint: Any):
          """
          Load the checkpoint file.

@@ -81,7 +81,7 @@ class WeightEnsemblingMoEAlgorithm(
          pass

      @abstractmethod
-     def save_checkpoint(self, model, checkpoint):
+     def save_checkpoint(self, model: Any, checkpoint: Any):
          """
          Save the checkpoint file.

@@ -121,7 +121,7 @@ class WeightEnsemblingMoEAlgorithm(
          pass

      @abstractmethod
-     def compute_logits(self, module, batch, task) -> Tensor:
+     def compute_logits(self, module: Any, batch: Any, task: Any) -> Tensor:
          """
          Compute the logits for a given batch and task.

@@ -135,7 +135,7 @@ class WeightEnsemblingMoEAlgorithm(
          """
          pass

-     def test_time_adaptation(self, module: WeightEnsemblingMoE):
+     def test_time_adaptation(self, module: WeightEnsemblingMoE) -> WeightEnsemblingMoE:
          """
          Perform test-time adaptation for the given module.

@@ -208,7 +208,7 @@

          return module

-     def run(self, modelpool: ModelPool):
+     def run(self, modelpool: ModelPool) -> WeightEnsemblingMoE:
          """
          Run the WeightEnsemblingMoEAlgorithm to fuse models using Weight Ensembling Mixture of Experts.

fusion_bench/method/weighted_average/llama.py
@@ -3,9 +3,11 @@ from typing import List, Mapping, Union  # noqa: F401

  import numpy as np
  import torch
+ from transformers import PreTrainedModel
  from typing_extensions import override

  from fusion_bench.method import BaseAlgorithm
+ from fusion_bench.mixins import auto_register_config
  from fusion_bench.modelpool import CausalLMPool
  from fusion_bench.utils import timeit_context
  from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_mul
@@ -14,20 +16,12 @@ from fusion_bench.utils.type import StateDictType
  log = logging.getLogger(__name__)


+ @auto_register_config
  class WeightedAverageForLLama(BaseAlgorithm):
      """
      A class to perform weighted averaging of LlaMa/Mistral models.
      """

-     _config_mapping = BaseAlgorithm._config_mapping | {
-         "normalize": "normalize",
-         "weights": "weights",
-         "backbone_only": "backbone_only",
-         "merged_model_save_path": "merged_model_save_path",
-         "save_tokenizer": "save_tokenizer",
-         "push_to_hub": "push_to_hub",
-     }
-
      def __init__(
          self,
          normalize: bool,
@@ -49,17 +43,11 @@ class WeightedAverageForLLama(BaseAlgorithm):
              save_tokenizer (bool): Whether to save the tokenizer.
              push_to_hub (bool): Whether to push the model to the hub.
          """
-         self.normalize = normalize
-         self.weights = weights
-         self.backbone_only = backbone_only
-         self.merged_model_save_path = merged_model_save_path
-         self.save_tokenizer = save_tokenizer
-         self.push_to_hub = push_to_hub
          super().__init__(**kwargs)

      @override
      @torch.no_grad()
-     def run(self, modelpool: CausalLMPool):
+     def run(self, modelpool: CausalLMPool) -> PreTrainedModel:
          """
          Executes the weighted averaging of models in the provided model pool.

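
The hunk above is representative of a pattern applied throughout this release: hand-written `_config_mapping` dictionaries and repetitive `self.x = x` assignments give way to the new `auto_register_config` decorator from `fusion_bench.mixins`. The released implementation lives in `fusion_bench/mixins/serialization.py` (+275/-48 above) and is not shown in this section; the following is only a sketch of the idea, assuming the decorator derives both the mapping and the attribute assignments from the `__init__` signature:

import functools
import inspect

def auto_register_config(cls):
    """Sketch: register every named __init__ parameter in the class's
    config mapping and assign it onto the instance automatically."""
    params = [
        name
        for name, p in inspect.signature(cls.__init__).parameters.items()
        if name != "self"
        and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
    ]
    cls._config_mapping = {
        **getattr(cls, "_config_mapping", {}),
        **{name: name for name in params},
    }

    original_init = cls.__init__

    @functools.wraps(original_init)
    def __init__(self, *args, **kwargs):
        # bind arguments so positional and keyword uses both get registered
        bound = inspect.signature(original_init).bind(self, *args, **kwargs)
        bound.apply_defaults()
        for name in params:
            if name in bound.arguments:
                setattr(self, name, bound.arguments[name])
        original_init(*bound.args, **bound.kwargs)

    cls.__init__ = __init__
    return cls

Under a scheme like this, `WeightedAverageForLLama.__init__` only needs to declare its parameters and call `super().__init__(**kwargs)`, which is exactly what the diff reduces it to.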
fusion_bench/metrics/continual_learning/__init__.py (new file)
@@ -0,0 +1 @@
+ from .backward_transfer import compute_backward_transfer
fusion_bench/metrics/continual_learning/backward_transfer.py
@@ -10,7 +10,7 @@ def compute_backward_transfer(
      Compute the backward transfer (BWT) of a model on a set of tasks.

      Equation:
-         BWT = \frac{1}{n} \sum_{k=1}^{n} (acc_{Ti}[k] - acc_{ii}[k])
+         $BWT = \frac{1}{n} \sum_{k=1}^{n} (acc_{T,i}[k] - acc_{i,i}[k])$

      Returns:
          float: The backward transfer of the model.
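
Reading the corrected formula: $acc_{i,i}$ is the accuracy on task $i$ right after training on it, and $acc_{T,i}$ is the accuracy on the same task after the final task, so BWT is negative under forgetting and positive under backward improvement. A hand-computed example with invented numbers:

# Illustrative values only: accuracy on each task right after learning it,
# and again after the final task.
acc_diag = [0.85, 0.75, 0.90]   # acc_{i,i}
acc_final = [0.80, 0.70, 0.90]  # acc_{T,i}

n = len(acc_diag)
bwt = sum(f - d for f, d in zip(acc_final, acc_diag)) / n
print(f"{bwt:.4f}")  # ((-0.05) + (-0.05) + 0.0) / 3 = -0.0333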
fusion_bench/metrics/nyuv2/__init__.py
@@ -1,10 +1,10 @@
  from .depth import DepthMetric
  from .noise import NoiseMetric
  from .normal import NormalMetric
- from .segmentation import SegmentationMertic
+ from .segmentation import SegmentationMetric

  metric_classes = {
-     "segmentation": SegmentationMertic,
+     "segmentation": SegmentationMetric,
      "depth": DepthMetric,
      "normal": NormalMetric,
      "noise": NoiseMetric,
fusion_bench/metrics/nyuv2/segmentation.py
@@ -5,7 +5,7 @@ from torch import Tensor, nn
  from torchmetrics import Metric


- class SegmentationMertic(Metric):
+ class SegmentationMetric(Metric):
      metric_names = ["mIoU", "pixAcc"]

      def __init__(self, num_classes=13):
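
Note that this rename (`SegmentationMertic` → `SegmentationMetric`) is a breaking change for any downstream code that imported the misspelled class directly; lookups through the `metric_classes` dict keep working unchanged:

from fusion_bench.metrics.nyuv2 import SegmentationMetric, metric_classes

metric = SegmentationMetric(num_classes=13)  # direct import, corrected name
assert metric_classes["segmentation"] is SegmentationMetric  # dict path unchanged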
fusion_bench/mixins/__init__.py
@@ -11,7 +11,11 @@ _import_structure = {
      "hydra_config": ["HydraConfigMixin"],
      "lightning_fabric": ["LightningFabricMixin"],
      "openclip_classification": ["OpenCLIPClassificationMixin"],
-     "serialization": ["YAMLSerializationMixin", "BaseYAMLSerializableModel"],
+     "serialization": [
+         "BaseYAMLSerializable",
+         "YAMLSerializationMixin",
+         "auto_register_config",
+     ],
      "simple_profiler": ["SimpleProfilerMixin"],
  }

@@ -21,7 +25,11 @@ if TYPE_CHECKING:
      from .hydra_config import HydraConfigMixin
      from .lightning_fabric import LightningFabricMixin
      from .openclip_classification import OpenCLIPClassificationMixin
-     from .serialization import BaseYAMLSerializableModel, YAMLSerializationMixin
+     from .serialization import (
+         BaseYAMLSerializable,
+         YAMLSerializationMixin,
+         auto_register_config,
+     )
      from .simple_profiler import SimpleProfilerMixin
  else:
      sys.modules[__name__] = LazyImporter(
fusion_bench/mixins/clip_classification.py
@@ -6,6 +6,7 @@ from typing import (  # noqa: F401
      TYPE_CHECKING,
      Any,
      Dict,
+     Iterator,
      List,
      Optional,
      Tuple,
@@ -21,6 +22,7 @@ from torch.utils.data import DataLoader
  from tqdm.auto import tqdm
  from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel

+ from fusion_bench import cache_with_joblib
  from fusion_bench.dataset.clip_dataset import CLIPDataset
  from fusion_bench.mixins import LightningFabricMixin
  from fusion_bench.modelpool import CLIPVisionModelPool
@@ -45,15 +47,13 @@ class CLIPClassificationMixin(LightningFabricMixin):

      - `_dataloader_kwargs` (Dict[str, Any]): Keyword arguments for the dataloader.
      - `modelpool` (CLIPVisionModelPool): The model pool containing the CLIP models.
-     - `zeroshot_weights_cache_dir` (Optional[str]): The directory to cache the zero-shot weights.
      """

-     _dataloader_kwargs: Dict[str, Any] = {}
+     dataloader_kwargs: Dict[str, Any] = {}
      # the modelpool is set by inheriting class
      modelpool: CLIPVisionModelPool = None
      _clip_processor: CLIPProcessor = None
      # a dict of zeroshot weights for each task, each key is the task name
-     zeroshot_weights_cache_dir: str = "outputs/cache/clip_zeroshot_weights"
      zeroshot_weights: Dict[str, torch.Tensor] = {}
      whether_setup_zero_shot_classification_head = False
@@ -71,7 +71,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
          batch_size: Optional[int] = None,
          num_workers: Optional[int] = None,
          **loader_kwargs,
-     ):
+     ) -> Iterator:
          """
          Get an iterator for a shuffled test DataLoader.

@@ -89,7 +89,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
              Iterator: An iterator over the shuffled test DataLoader.
          """
          # get dataloader kwargs
-         dataloader_kwargs = self._dataloader_kwargs.copy()
+         dataloader_kwargs = self.dataloader_kwargs.copy()
          dataloader_kwargs["shuffle"] = True
          if batch_size is not None:
              dataloader_kwargs["batch_size"] = batch_size
@@ -130,26 +130,16 @@ class CLIPClassificationMixin(LightningFabricMixin):
          self.visual_projection = self.fabric.to_device(self.visual_projection)
          self.logit_scale_exp = self.fabric.to_device(self.logit_scale_exp)

-         # get cache directory
-         if self.modelpool.has_pretrained:
-             model_name = self.modelpool.get_model_config("_pretrained_")
-             if not isinstance(model_name, str):
-                 model_name = model_name.pretrained_model_name_or_path
-         else:
-             model_name = self.modelpool.get_model_config(self.modelpool.model_names[0])
-             if not isinstance(model_name, str):
-                 model_name = model_name.pretrained_model_name_or_path
-         cache_dir = os.path.join(
-             self.zeroshot_weights_cache_dir,
-             os.path.normpath(model_name.split("/")[-1]),
-         )
-         if not os.path.exists(cache_dir):
-             log.info(
-                 f"Creating cache directory for zero-shot classification head at {cache_dir}"
-             )
-             os.makedirs(cache_dir)
+         @cache_with_joblib()
+         def construct_classification_head(task: str):
+             nonlocal clip_classifier
+
+             classnames, templates = get_classnames_and_templates(task)
+             clip_classifier.set_classification_task(classnames, templates)
+             zeroshot_weights = clip_classifier.zeroshot_weights.detach().clone()
+
+             return zeroshot_weights

-         log.info(f"cache directory for zero-shot classification head: {cache_dir}")
          for task in tqdm(
              self.modelpool.model_names if task_names is None else task_names,
              "Setting up zero-shot classification head",
@@ -157,27 +147,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
          ):
              zeroshot_weights = None
              if self.fabric.is_global_zero:
-                 cache_file = os.path.join(
-                     cache_dir, os.path.normpath(f"{task}_zeroshot_weights.pt")
-                 )
-                 if os.path.exists(cache_file):
-                     zeroshot_weights = torch.load(
-                         cache_file,
-                         map_location="cpu",
-                         weights_only=True,
-                     ).detach()
-                     log.info(
-                         f"Loadded cached zeroshot weights for task: {task}, shape: {zeroshot_weights.shape}"
-                     )
-                 else:
-                     log.info(
-                         f"Construct zero shot classification head for task: {task}"
-                     )
-                     classnames, templates = get_classnames_and_templates(task)
-                     clip_classifier.set_classification_task(classnames, templates)
-                     zeroshot_weights = clip_classifier.zeroshot_weights.detach().clone()
-                     log.info(f"save zeroshot weights to {cache_file}")
-                     torch.save(zeroshot_weights, cache_file)
+                 zeroshot_weights = construct_classification_head(task)

              self.fabric.barrier()
              self.zeroshot_weights[task] = self.fabric.broadcast(zeroshot_weights, src=0)
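
The last two hunks replace the hand-rolled `.pt` caching of zero-shot classification heads with the new `cache_with_joblib` decorator (exported at the package top level and backed by `fusion_bench/utils/cache_utils.py`, +101 above). The released implementation is not shown in this section; a plausible minimal sketch, assuming it wraps `joblib.Memory` and takes an optional cache directory (both assumptions):

from joblib import Memory

def cache_with_joblib(cache_dir: str = "outputs/cache"):
    """Sketch: disk-memoize a function with joblib.Memory.

    The `cache_dir` parameter is an assumption for illustration;
    see cache_utils.py for the released interface.
    """
    memory = Memory(cache_dir, verbose=0)

    def decorator(fn):
        return memory.cache(fn)  # keyed on the function and its arguments

    return decorator

Since joblib keys cache entries on the wrapped function and its arguments, the `task: str` argument of `construct_classification_head` is what distinguishes entries, taking over the role of the per-task `.pt` files in the old code.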