fusion-bench 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199)
  1. fusion_bench/compat/method/__init__.py +3 -1
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/gsm8k.py +2 -2
  6. fusion_bench/method/__init__.py +12 -2
  7. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  8. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  9. fusion_bench/method/doge_ta/__init__.py +2 -0
  10. fusion_bench/method/{DOGE_TA → doge_ta}/clip_layer_wise_adamerging.py +1 -1
  11. fusion_bench/method/{DOGE_TA/DOGE_TA.py → doge_ta/doge_ta.py} +1 -1
  12. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  13. fusion_bench/method/gossip/__init__.py +3 -0
  14. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  15. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  16. fusion_bench/method/gossip/entropy_loss.py +25 -0
  17. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  18. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  19. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  20. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  21. fusion_bench/method/gossip/utils.py +74 -0
  22. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  23. fusion_bench/method/opcm/opcm.py +102 -84
  24. fusion_bench/method/opcm/task_arithmetic.py +35 -21
  25. fusion_bench/method/opcm/ties_merging.py +71 -52
  26. fusion_bench/method/pwe_moe/module.py +1 -1
  27. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  28. fusion_bench/method/regmean/regmean.py +25 -17
  29. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  30. fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
  31. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  32. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  33. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  34. fusion_bench/method/we_moe/we_moe.py +14 -15
  35. fusion_bench/mixins/__init__.py +6 -3
  36. fusion_bench/mixins/hydra_config.py +49 -0
  37. fusion_bench/mixins/openclip_classification.py +11 -0
  38. fusion_bench/mixins/simple_profiler.py +4 -2
  39. fusion_bench/modelpool/__init__.py +3 -1
  40. fusion_bench/modelpool/base_pool.py +2 -2
  41. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  42. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  43. fusion_bench/models/open_clip/__init__.py +6 -0
  44. fusion_bench/models/open_clip/modeling.py +176 -0
  45. fusion_bench/models/open_clip/utils.py +311 -0
  46. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  47. fusion_bench/models/parameter_dict.py +54 -13
  48. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
  49. fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +4 -119
  50. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  51. fusion_bench/taskpool/__init__.py +5 -3
  52. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  53. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  54. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  55. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  56. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  57. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  58. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  59. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  60. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  61. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  62. fusion_bench/utils/data.py +12 -0
  63. fusion_bench/utils/devices.py +14 -0
  64. fusion_bench/utils/instantiate.py +12 -0
  65. fusion_bench/utils/misc.py +9 -2
  66. fusion_bench/utils/packages.py +14 -0
  67. fusion_bench/utils/parameters.py +1 -1
  68. fusion_bench/utils/tensorboard.py +1 -1
  69. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +15 -2
  70. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +198 -158
  71. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
  72. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  73. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  74. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  75. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  76. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  77. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  78. fusion_bench_config/fabric/auto.yaml +0 -1
  79. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  80. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  81. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  82. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  83. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  84. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  85. fusion_bench_config/llama_full_finetune.yaml +0 -2
  86. fusion_bench_config/llama_model_fusion.yaml +0 -2
  87. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  88. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  89. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  90. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  91. fusion_bench_config/method/adamerging.yaml +2 -2
  92. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  93. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  94. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  95. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  96. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  97. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  98. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  99. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  100. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  101. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  102. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  103. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  104. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  105. fusion_bench_config/method/{DOGE_TA/DOGE_TA.yaml → doge_ta/doge_ta.yaml} +1 -1
  106. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  107. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  108. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  109. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  110. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  111. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  112. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  113. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  114. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  115. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  116. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  117. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  118. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  119. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  120. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  121. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  122. fusion_bench_config/method/model_recombination.yaml +0 -1
  123. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  124. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  125. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  126. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  127. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  128. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  129. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  130. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  131. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  132. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  133. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  134. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  135. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  136. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  137. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  138. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  139. fusion_bench_config/method/ties_merging.yaml +1 -1
  140. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  141. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  142. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  143. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  144. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  145. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  146. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  147. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  148. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  149. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  150. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  151. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  152. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  161. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  162. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  163. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  164. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  165. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  166. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  167. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  168. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  169. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  170. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  171. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  172. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  173. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  174. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  175. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  176. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  177. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  178. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -10
  179. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +66 -0
  180. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  181. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  182. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  183. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  184. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  185. fusion_bench_config/nyuv2_config.yaml +0 -2
  186. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  187. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  188. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  189. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  190. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  191. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  192. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  193. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  194. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  195. fusion_bench/method/DOGE_TA/__init__.py +0 -2
  196. /fusion_bench/method/{DOGE_TA → doge_ta}/layer_wise_adamerging.py +0 -0
  197. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
  198. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info/licenses}/LICENSE +0 -0
  199. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
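Most of the churn above comes from the new gossip merging method and the OpenCLIP support. As a rough sketch of how the new Flan-T5 gossip method might be exercised (config names taken from the files listed above; the exact invocation is not verified against this release), fusion-bench's usual Hydra-style CLI would look something like: fusion_bench method=gossip/layer_wise_flan_t5 modelpool=Seq2SeqLMPool/flan-t5-base_glue_tta, plus a suitable taskpool.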
fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py
@@ -0,0 +1,388 @@
+ """
+ This is an experimental implementation of the Layer-Wise AdaMerging Algorithm for Flan-T5 models.
+ The efficiency of the algorithm is not guaranteed, and it may not work as expected.
+ """
+
+ import functools
+ import gc
+ import logging
+ import os
+ from abc import abstractmethod
+ from pathlib import Path
+ from types import SimpleNamespace
+ from typing import Any, Dict, List, Mapping, Optional, Union, cast  # noqa: F401
+
+ import torch
+ from lightning.fabric.utilities.rank_zero import rank_zero_only
+ from omegaconf import DictConfig, ListConfig
+ from torch import Tensor, nn
+ from torch.utils.data import DataLoader
+ from tqdm.autonotebook import tqdm
+ from transformers import T5ForConditionalGeneration
+ from transformers.data import default_data_collator
+
+ from fusion_bench.compat.modelpool.base_pool import ModelPool
+ from fusion_bench.method import BaseAlgorithm
+ from fusion_bench.method.simple_average import simple_average
+ from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
+ from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+ from fusion_bench.modelpool import Seq2SeqLMPool
+ from fusion_bench.models.wrappers.layer_wise_fusion import (
+     LayerWiseMergedModel,
+     get_layer_wise_weights,
+ )
+ from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
+ from fusion_bench.utils.instantiate import instantiate
+
+ from .entropy_loss import entropy_loss
+ from .layer_wise_gossip import ModelScheduler
+ from .min_norm_solvers import MinNormSolver
+ from .utils import get_memory_usage
+
+ log = logging.getLogger(__name__)
+
+
+ class FlanT5LayerWiseGossipAlgorithm(
+     BaseAlgorithm,
+     LightningFabricMixin,
+     SimpleProfilerMixin,
+ ):
+
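+     """
+     Gossip-style layer-wise merging for Flan-T5 models. Each gossip step runs a
+     local layer-wise AdaMerging pass per model (test-time adaptation with an
+     entropy objective) and exchanges the merged weights through the
+     ModelScheduler; the datasets used locally come from a ring of neighboring
+     models (see `update_datasets`).
+     """
+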
+     def __init__(
+         self,
+         optimizer: DictConfig,
+         dataloader_kwargs: DictConfig,
+         init_values: float,
+         max_steps: int,
+         merging_weights_load_path: Optional[Union[str, Path]] = None,
+         merging_weights_save_path: Optional[Union[str, Path]] = None,
+         clamp_weights: bool = False,
+         tie_weights: bool = True,
+         strict: bool = False,
+         cache_dir: str = "outputs/cache",
+         variant: Optional[str] = None,
+         **kwargs,
+     ):
+         self._optimizer = optimizer
+         self.dataloader_kwargs = dataloader_kwargs
+         self.init_values = init_values
+         self.merging_weights_load_path = merging_weights_load_path
+         self.merging_weights_save_path = merging_weights_save_path
+         self.clamp_weights = clamp_weights
+         self.tie_weights = tie_weights
+         self.strict = strict
+         self.max_steps = max_steps
+         self.cache_dir = cache_dir
+         self.variant = variant
+
+         self.configs = SimpleNamespace(**kwargs)
+         self.configs.init_values = init_values
+         self.configs.clamp_weights = clamp_weights
+         self.configs.tie_weights = tie_weights
+         self.configs.strict = strict
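+         # accuracy_test_interval is taken from **kwargs via the SimpleNamespace
+         # above and may be an int or an OmegaConf ListConfig; normalize the
+         # latter to a plain list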
+         if isinstance(self.configs.accuracy_test_interval, ListConfig):
+             self.configs.accuracy_test_interval = list(
+                 self.configs.accuracy_test_interval
+             )
+         elif isinstance(self.configs.accuracy_test_interval, int):
+             pass
+         else:
+             log.warning(
+                 f"Unexpected type of accuracy_test_interval: {type(self.configs.accuracy_test_interval)}"
+             )
+         super().__init__(**kwargs)
+
+     @rank_zero_only
+     def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
+         """
+         Save the merging weights to a file.
+
+         Args:
+             file_path (str): The path to save the merging weights.
+             merging_weights (torch.Tensor): The merging weights to save.
+         """
+         if self.fabric.is_global_zero and self.merging_weights_save_path is not None:
+             if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
+                 # if the file path is not absolute or relative to the current working directory, save it in the log directory
+                 save_path = os.path.join(self.log_dir, file_path)
+             else:
+                 save_path = file_path
+             log.info(f"saving merging weights to {save_path}.")
+             if os.path.dirname(save_path):
+                 os.makedirs(os.path.dirname(save_path), exist_ok=True)
+             torch.save(merging_weights.detach().cpu(), save_path)
+
+     def free_gpu_memory(self, module: LayerWiseMergedModel):
+         module.pretrained_model.to("cpu")
+         for model in module.task_vectors:
+             model.to("cpu")
+         del module
+         gc.collect()
+         torch.cuda.empty_cache()
+         log.info(get_memory_usage("after freeing memory, the memory usage of GPU is:"))
+
+     def update_datasets(self, datasets):
+         """
+         For every epoch of local AdaMerging, use only the datasets corresponding to the models involved in the fusion.
+         """
+         num_datasets = len(datasets)
+         datasets_copy = datasets.copy()
+         for i in range(num_datasets):
+             datasets[i] = (
+                 datasets_copy[i]
+                 .union(datasets_copy[(i + 1) % num_datasets])
+                 .union(datasets_copy[(i - 1) % num_datasets])
+             )
+         return datasets
+
+     def run(self, modelpool: Seq2SeqLMPool, **kwargs):
+         """
+         Run the layer-wise Gossip merging algorithm.
+
+         This method constructs the wrapped models and performs test-time adaptation if necessary.
+
+         Args:
+             modelpool (Seq2SeqLMPool): The model pool containing the pretrained and fine-tuned models.
+
+         Returns:
+             The merged model(s) produced by the ModelScheduler after test-time adaptation.
+         """
+         log.info("Fusing models using layer-wise adaptive merging.")
+         self.modelpool = modelpool
+         self.num_finetuned_models = len(modelpool.model_names)
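+         # every model starts with only its own task's dataset; update_datasets()
+         # grows these neighborhoods at each gossip step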
+         datasets = [{dataset} for dataset in modelpool.model_names]
+
+         with self.profile("construct the wrapped model"):
+             model_scheduler = ModelScheduler(self.configs, self.modelpool)
+
+         if self.merging_weights_load_path is not None:
+             # skip the test-time adaptation
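+             # NOTE: no wrapped `module` is constructed on this path as written,
+             # so the early exit below relies on one being available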
+             return module.merge_and_unload()
+         else:
+             for step_idx in tqdm(
+                 range(self.configs.gossip_max_steps),
+                 "Gossip merging",
+                 dynamic_ncols=True,
+             ):
+                 datasets = self.update_datasets(datasets)
+                 log.info(f"Gossip merging step: {step_idx}")
+                 for model_id in tqdm(
+                     range(self.num_finetuned_models),
+                     "local adamerging",
+                     dynamic_ncols=True,
+                 ):
+                     if self.configs.gossip_skip_adamerging:
+                         # skip adamerging, only merge
+                         with self.profile("construct the local wrapped model"):
+                             module = model_scheduler(model_id)
+                         log.info(
+                             f"skip adamerging, only merge ({modelpool.model_names[model_id]})"
+                         )
+                         model_scheduler.store_model(module.merge_weights(), model_id)
+                         self.free_gpu_memory(module)
+                     else:
+                         with self.profile("construct the local wrapped model"):
+                             module = model_scheduler(model_id)
+
+                         if self.configs.improve_dataset:
+                             log.info(
+                                 f"improved datasets, the datasets used in this local merging are {datasets[model_id]}"
+                             )
+                         else:
+                             log.info(
+                                 f"unimproved datasets, the datasets used in this local merging are {modelpool.model_names}"
+                             )
+                         with self.profile("test-time adaptation"):
+                             module = self.test_time_adaptation(
+                                 module, datasets[model_id]
+                             )
+                         # if self.configs.get("save_merging_weights", False):
+                         #     self.save_merging_weights(
+                         #         self.configs.save_merging_weights, module.merge_weight
+                         #     )
+                         model_scheduler.store_model(module.merge_weights(), model_id)
+                         log.info(
+                             get_memory_usage(
+                                 f"after local merging ({modelpool.model_names[model_id]}), the memory usage of GPU is:"
+                             )
+                         )
+                         # free memory eagerly to simulate distributed GPU memory usage as closely as possible
+                         self.free_gpu_memory(module)
+
+                 model_scheduler.update_models()
+                 do_evaluation = False  # whether to do evaluation after each Gossip step
+                 if isinstance(self.configs.accuracy_test_interval, list):
+                     if (step_idx + 1) in self.configs.accuracy_test_interval:
+                         do_evaluation = True
+                 elif isinstance(self.configs.accuracy_test_interval, int):
+                     if (
+                         self.configs.accuracy_test_interval != 0
+                         and (step_idx + 1) % self.configs.accuracy_test_interval == 0
+                     ):
+                         do_evaluation = True
+                 if do_evaluation:
+                     self._program.evaluate_merged_model(
+                         self._program.taskpool, model_scheduler.get_final_models()
+                     )
+                     model_scheduler.move_to("cpu")
+
+         return model_scheduler.get_final_models()
+
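+     # functools.cache keys on (self, task), so each task reuses a single
+     # shuffled, infinite test-loader iterator across adaptation steps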
+     @functools.cache
+     def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
+         """
+         Loader of the test dataset for test-time adaptation. Labels are not needed.
+
+         Args:
+             task (str): The name of the task.
+
+         Returns:
+             DataLoader: An infinite iterator over the shuffled test data loader.
+         """
+         dataloader_kwargs = dict(self.dataloader_kwargs)
+         dataloader_kwargs.update(dict(shuffle=True, collate_fn=default_data_collator))
+
+         dataset = self.modelpool.load_test_dataset(task)
+         loader = DataLoader(dataset, **dataloader_kwargs)
+
+         if self.fabric is not None:
+             loader = self.fabric.setup_dataloaders(loader)
+         return iter(InfiniteDataLoader(loader))
+
+     def compute_logits(
+         self,
+         module: Union[T5ForConditionalGeneration, LayerWiseMergedModel],
+         batch,
+         task: str,
+     ) -> Tensor:
+         """
+         Compute the logits for the given batch and task.
+
+         Args:
+             module: The model module.
+             batch: The input batch with `input_ids` and `attention_mask`.
+             task (str): The name of the task.
+
+         Returns:
+             Tensor: The computed logits.
+         """
+         input_ids: Tensor = batch["input_ids"]
+         attention_mask: Tensor = batch["attention_mask"]
+
+         # strip trailing padding columns shared by the whole batch
+         while attention_mask[:, -1].eq(0).all():
+             input_ids = input_ids[:, :-1]
+             attention_mask = attention_mask[:, :-1]
+
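+         # feed a single constant decoder token so the model produces exactly one
+         # step of decoder logits; only this first position's distribution is
+         # used for the entropy objective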
+         outputs = module(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             decoder_input_ids=torch.ones(
+                 input_ids.size(0), 1, dtype=torch.long, device=input_ids.device
+             ),
+         )
+         logits = outputs.logits[:, 0, :]
+         return logits
+
+     def on_test_time_adaptation_start(self):
+         """
+         Hook that runs before test-time adaptation starts, e.g. for setting up task-specific heads.
+         """
+         pass
+
+     def test_time_adaptation(self, module: LayerWiseMergedModel, datasets):
+         """
+         Perform test-time adaptation on the merged model.
+
+         This method adapts the merging weights during test time to improve performance.
+
+         Args:
+             module (LayerWiseMergedModel): The merged model.
+             datasets: The set of task names involved in this local merging step.
+
+         Returns:
+             LayerWiseMergedModel: The adapted merged model.
+         """
+         self.on_test_time_adaptation_start()
+
+         # configure the optimizer over the merging weights only
+         optimizer = instantiate(self._optimizer, [module.merge_weight])
+         module, optimizer = self.fabric.setup(module, optimizer)
+
+         module.train()
+         module.merge_weights()
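+         # adaptation loop: minimize the entropy of the merged model's predictions
+         # on unlabeled test batches, updating only the layer-wise merge_weight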
+         for step_idx in (
+             pbar := tqdm(
+                 range(self.max_steps if not self.is_debug_mode else 1),
+                 ("[DEBUG MODE] " if self.is_debug_mode else "")
+                 + "AdaMerging Test-time adaptation",
+                 dynamic_ncols=True,
+             )
+         ):
+             if self.variant == "mgda":
+                 total_loss = self._compute_gradients_using_mgda(module)
+             else:
+                 total_loss = 0
+                 for task in self.modelpool.model_names:
+                     with self.profile("data loading"):
+                         batch = next(self.get_shuffled_test_loader_iter(task))
+                     with self.profile("forward pass"):
+                         logits = self.compute_logits(module, batch, task)
+                         logits = logits.mean(dim=0, keepdim=True)
+                         loss = entropy_loss(logits)
+                         total_loss += loss
+                     with self.profile("backward pass"):
+                         self.fabric.backward(loss, retain_graph=True)
+
+             with self.profile("optimizer step"):
+                 optimizer.step()
+                 optimizer.zero_grad()
+             with self.profile("merging weights"):
+                 module.merge_weights()
+
+             metrics = {
+                 "train/loss": total_loss.item(),
+                 "train/weight_max": module.merge_weight.max().item(),
+                 "train/weight_min": module.merge_weight.min().item(),
+                 "train/weight_mean": module.merge_weight.mean().item(),
+             }
+             self.fabric.log_dict(metrics, step=step_idx)
+             pbar.set_postfix(metrics)
+
+         self.print_profile_summary()
+         del optimizer
+         gc.collect()
+         torch.cuda.empty_cache()
+         return module
+
+     def _compute_gradients_using_mgda(self, module: LayerWiseMergedModel):
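+         # MGDA-style step: collect one entropy-loss gradient per task, find the
+         # minimum-norm element of their convex hull (MinNormSolver), and use the
+         # resulting convex combination as a common descent direction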
+         all_grads = []
+         total_loss = 0
+         # default behavior for first-order optimizers
+         for task in self.modelpool.model_names:
+             with self.profile("data loading"):
+                 batch = next(self.get_shuffled_test_loader_iter(task))
+             with self.profile("forward pass"):
+                 logits = self.compute_logits(module, batch, task)
+                 logits = logits.mean(dim=0, keepdim=True)
+                 loss = entropy_loss(logits)
+                 total_loss += loss
+             with self.profile("backward pass"):
+                 # self.fabric.backward(loss, retain_graph=True)
+                 _grads = torch.autograd.grad(
+                     loss,
+                     [module.merge_weight],
+                     create_graph=False,
+                     retain_graph=True,
+                 )
+                 all_grads.append(_grads[0].flatten().detach())
+         sol, min_norm = MinNormSolver.find_min_norm_element(all_grads)
+         if not isinstance(sol, torch.Tensor):
+             sol = torch.from_numpy(sol)
+         sol = sol.to(
+             device=module.merge_weight.device,
+             dtype=module.merge_weight.dtype,
+         )
+         grad = torch.stack(all_grads) * sol.view(-1, 1)
+         module.merge_weight.grad = grad.sum(dim=0).view_as(module.merge_weight)
+         return total_loss