fusion-bench 0.2.20__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. fusion_bench/__init__.py +1 -0
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +5 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/dataset/clip_dataset.py +2 -1
  9. fusion_bench/dataset/gpt2_glue.py +9 -9
  10. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  11. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  12. fusion_bench/dataset/image_dataset.py +1 -1
  13. fusion_bench/dataset/nyuv2.py +2 -2
  14. fusion_bench/method/__init__.py +16 -3
  15. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  16. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  17. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  18. fusion_bench/method/base_algorithm.py +195 -12
  19. fusion_bench/method/bitdelta/__init__.py +4 -0
  20. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  21. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  25. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  26. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  27. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  28. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  29. fusion_bench/method/ensemble.py +12 -12
  30. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  31. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
  32. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  33. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  34. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  35. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  36. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  37. fusion_bench/method/linear/expo.py +2 -1
  38. fusion_bench/method/linear/linear_interpolation.py +6 -4
  39. fusion_bench/method/linear/simple_average_for_llama.py +2 -3
  40. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  41. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  42. fusion_bench/method/model_recombination.py +2 -5
  43. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  44. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  45. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  46. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  47. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  48. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  49. fusion_bench/method/randes/modelsoup.py +1 -3
  50. fusion_bench/method/regmean/clip_regmean.py +2 -2
  51. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  52. fusion_bench/method/regmean/regmean.py +2 -11
  53. fusion_bench/method/regmean_plusplus/__init__.py +1 -1
  54. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
  55. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
  56. fusion_bench/method/simple_average.py +5 -9
  57. fusion_bench/method/slerp/slerp.py +5 -2
  58. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  59. fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
  60. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
  61. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  62. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  63. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  64. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  65. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  66. fusion_bench/method/we_moe/we_moe.py +6 -6
  67. fusion_bench/method/weighted_average/llama.py +4 -16
  68. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  69. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  70. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  71. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  72. fusion_bench/mixins/__init__.py +10 -2
  73. fusion_bench/mixins/clip_classification.py +4 -3
  74. fusion_bench/mixins/hydra_config.py +105 -7
  75. fusion_bench/mixins/lightning_fabric.py +2 -0
  76. fusion_bench/mixins/serialization.py +265 -48
  77. fusion_bench/modelpool/__init__.py +2 -2
  78. fusion_bench/modelpool/base_pool.py +29 -9
  79. fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
  80. fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
  81. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  82. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  83. fusion_bench/models/__init__.py +2 -1
  84. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  85. fusion_bench/models/hf_utils.py +182 -0
  86. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  87. fusion_bench/models/linearized/vision_model.py +1 -1
  88. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  89. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  90. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  91. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  92. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  93. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  94. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  95. fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
  96. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  97. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
  98. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  99. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  100. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  101. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
  102. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  103. fusion_bench/models/parameter_dict.py +1 -1
  104. fusion_bench/models/sparse_we_moe.py +1 -53
  105. fusion_bench/models/utils.py +26 -0
  106. fusion_bench/models/we_moe.py +1 -53
  107. fusion_bench/models/wrappers/ensemble.py +6 -4
  108. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  109. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  110. fusion_bench/programs/base_program.py +81 -2
  111. fusion_bench/programs/fabric_fusion_program.py +24 -8
  112. fusion_bench/scripts/cli.py +5 -5
  113. fusion_bench/taskpool/base_pool.py +4 -3
  114. fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
  115. fusion_bench/taskpool/dummy.py +1 -1
  116. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  117. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  118. fusion_bench/utils/__init__.py +6 -1
  119. fusion_bench/utils/devices.py +14 -4
  120. fusion_bench/utils/instantiate_utils.py +3 -1
  121. fusion_bench/utils/modelscope.py +127 -8
  122. fusion_bench/utils/parameters.py +2 -2
  123. fusion_bench/utils/rich_utils.py +3 -0
  124. fusion_bench/utils/state_dict_arithmetic.py +25 -23
  125. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +24 -25
  126. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +165 -134
  127. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  128. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  129. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  130. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  131. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  132. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  133. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  134. fusion_bench_config/hydra/default.yaml +6 -2
  135. fusion_bench_config/llama_full_finetune.yaml +1 -0
  136. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  137. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  138. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  139. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  140. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  141. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
  142. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  143. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  144. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
  145. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
  148. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
  149. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  150. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  151. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  152. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  153. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  154. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  155. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  156. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  157. fusion_bench_config/nyuv2_config.yaml +3 -1
  158. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  159. fusion_bench_config/path/default.yaml +28 -0
  160. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  161. fusion_bench_config/method/adamerging.yaml +0 -23
  162. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  163. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  164. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  165. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
  166. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
  167. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
  168. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
  169. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0

fusion_bench/method/smile_upscaling/projected_energy.py
@@ -0,0 +1,145 @@
+ import os
+ from typing import Literal
+
+ import pandas as pd
+ import torch
+
+ from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
+ from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
+
+ from tqdm import tqdm
+
+
+ class ProjectedEnergyAnalysis(
+     SimpleProfilerMixin,
+     LightningFabricMixin,
+     BaseAlgorithm,
+ ):
+     def on_run_start(self):
+         self.device = self.fabric.device
+
+     def run(self, modelpool: BaseModelPool):
+         with self.profile("model loading"):
+             base_model = modelpool.load_pretrained_model()
+
+         results = {
+             "model_name": [],
+             "module_index": [],
+             "module_name": [],
+             "projected_energy_I": [],
+             "projected_energy_II": [],
+             "projected_energy_II_III": [],
+         }
+         for model_name in tqdm(
+             modelpool.model_names,
+             "analyzing",
+             dynamic_ncols=True,
+         ):
+             with self.profile("model loading"):
+                 finetuned_model = modelpool.load_model(model_name)
+
+             module_index = 0
+             for module_name, base_module in tqdm(
+                 list(base_model.named_modules()),
+                 "analyzing modules",
+                 dynamic_ncols=True,
+             ):
+                 if isinstance(base_module, torch.nn.Linear):
+                     with self.profile("weight analysis"):
+                         _result = self.analyze_weight(
+                             base_module.weight,
+                             finetuned_model.get_submodule(module_name).weight,
+                         )
+                     results["model_name"].append(model_name)
+                     results["module_index"].append(module_index)
+                     results["module_name"].append(module_name)
+                     for key, value in _result.items():
+                         results[key].append(value)
+
+                     module_index += 1
+
+         # save results as csv
+         results = pd.DataFrame(results)
+         results.to_csv(
+             os.path.join(self.log_dir, "projected_energy_analysis.csv"), index=True
+         )
+
+         self.print_profile_summary()
+         return None
+
+     @torch.no_grad()
+     def analyze_weight(self, w: torch.Tensor, w_ft: torch.Tensor, k: int = -1):
+         w = w.to(dtype=torch.float32, device=self.device)
+         w_ft = w_ft.to(dtype=torch.float32, device=self.device)
+         w_diff = w_ft - w
+
+         # Perform analysis on the weight tensor
+         u, s, vh = torch.linalg.svd(w, full_matrices=False)
+         v = vh.T
+         if k < 0:
+             # find the position where the sum of singular values is larger than 50% of the total sum
+             cumsum = s.cumsum(0)
+             k = (cumsum < cumsum[-1] * 0.5).sum().item() + 1
+
+         # subspace I
+         w_diff_proj = self._project_subspace_low(u=u, s=s, v=v, k=k, w=w, w_ft=w_ft)
+         projected_energy_I = (
+             torch.linalg.norm(w_diff_proj, ord="fro") ** 2
+             / torch.linalg.norm(w_diff, ord="fro") ** 2
+         )
+
+         # subspace II
+         w_diff_proj = self._project_subspace_high(u=u, s=s, v=v, k=k, w=w, w_ft=w_ft)
+         projected_energy_II = (
+             torch.linalg.norm(w_diff_proj, ord="fro") ** 2
+             / torch.linalg.norm(w_diff, ord="fro") ** 2
+         )
+
+         ## subspace II+III
+         u, s, vh = torch.linalg.svd(w, full_matrices=True)
+         v = vh.T
+         w_diff_proj = self._project_subspace_high(u=u, s=s, v=v, k=k, w=w, w_ft=w_ft)
+         projected_energy_II_III = (
+             torch.linalg.norm(w_diff_proj, ord="fro") ** 2
+             / torch.linalg.norm(w_diff, ord="fro") ** 2
+         )
+
+         return {
+             "projected_energy_I": projected_energy_I.item(),
+             "projected_energy_II": projected_energy_II.item(),
+             "projected_energy_II_III": projected_energy_II_III.item(),
+         }
+
+     def _project_subspace_low(
+         self,
+         u: torch.Tensor,
+         s: torch.Tensor,
+         v: torch.Tensor,
+         k: int,
+         w: torch.Tensor,
+         w_ft: torch.Tensor,
+     ):
+         u = u[:, :k]
+         s = s[:k]
+         v = v[:, :k]
+
+         w_diff = w_ft - w
+         w_diff_proj = torch.linalg.multi_dot((u, u.T, w_diff, v, v.T))
+         return w_diff_proj
+
+     def _project_subspace_high(
+         self,
+         u: torch.Tensor,
+         s: torch.Tensor,
+         v: torch.Tensor,
+         k: int,
+         w: torch.Tensor,
+         w_ft: torch.Tensor,
+     ):
+         u = u[:, k:]
+         s = s[k:]
+         v = v[:, k:]
+
+         w_diff = w_ft - w
+         w_diff_proj = torch.linalg.multi_dot((u, u.T, w_diff, v, v.T))
+         return w_diff_proj
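
For reference, the three ratios computed by analyze_weight above can be written out as follows (notation introduced here, not taken from the package): let W be the pretrained weight, W_ft the fine-tuned weight, ΔW = W_ft − W, and U_k, V_k the top-k left/right singular vectors of W, with k chosen as the smallest number of leading singular values whose cumulative sum reaches 50% of the total.

    E_I        = \| U_k U_k^\top \, \Delta W \, V_k V_k^\top \|_F^2 / \| \Delta W \|_F^2
    E_II       = \| U_{>k} U_{>k}^\top \, \Delta W \, V_{>k} V_{>k}^\top \|_F^2 / \| \Delta W \|_F^2   (thin SVD, full_matrices=False)
    E_{II+III} = same projection as E_II but built from the full SVD (full_matrices=True), so the
                 complementary subspace also spans directions outside the column/row space of W

These correspond to projected_energy_I, projected_energy_II, and projected_energy_II_III in the CSV written by run().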

fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py
@@ -16,10 +16,16 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer

  from fusion_bench import BaseAlgorithm, BaseModelPool
  from fusion_bench.compat.modelpool import to_modelpool
- from fusion_bench.mixins import SimpleProfilerMixin
+ from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
+ from fusion_bench.modelpool import CausalLMPool
+ from fusion_bench.models.hf_utils import (
+     generate_complete_readme,
+     save_pretrained_with_remote_code,
+ )
  from fusion_bench.models.modeling_smile_qwen2 import (
      SmileQwen2Config,
      SmileQwen2ForCausalLM,
+     SmileQwen2Model,
  )
  from fusion_bench.models.modeling_smile_qwen2.modeling_smile_qwen2 import (
      SmileQwen2DecoderLayer,
@@ -34,6 +40,7 @@ from fusion_bench.utils.parameters import print_parameters
  log = logging.getLogger(__name__)


+ @auto_register_config
  class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
      R"""
      SmileQwen2UpscalingAlgorithm is a model fusion algorithm designed to upscale
@@ -49,15 +56,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
      Merges the pretrained model with the fine-tuned models to create an upscaled model.
      """

-     _config_mapping = BaseAlgorithm._config_mapping | {
-         "device": "device",
-         "accelerator": "accelerator",
-         "model_path": "model_path",
-         "model_dtype": "model_dtype",
-         "num_experts_per_tok": "num_experts_per_tok",
-         "rank_of_router": "rank_of_router",
-         "rank_of_expert": "rank_of_expert",
-     }
+     modelpool: CausalLMPool

      def __init__(
          self,
@@ -68,20 +67,13 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
          num_experts_per_tok,
          rank_of_router,
          rank_of_expert,
+         save_with_remote_code: bool = True,
          **kwargs,
      ):
-         self.device = device
-         self.accelerator = accelerator
-         self.model_path = model_path
-         self.model_dtype = model_dtype
-         # SmileMoE parameters, except `num_local_experts` which is set later according to the number of finetuned models
-         self.num_experts_per_tok = num_experts_per_tok
-         self.rank_of_router = rank_of_router
-         self.rank_of_expert = rank_of_expert
          super().__init__(**kwargs)

      @torch.no_grad()
-     def run(self, modelpool: BaseModelPool) -> SmileQwen2ForCausalLM:
+     def run(self, modelpool) -> SmileQwen2ForCausalLM:
          """
          Executes the upscaling process.

@@ -129,13 +121,29 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
              if os.path.dirname(config.model_path):
                  os.makedirs(os.path.dirname(config.model_path), exist_ok=True)
              log.info(f"Saving model to {config.model_path}")
-             pretrained_model_config = self.modelpool.get_model_config("_pretrained_")
-             pretrained_path = pretrained_model_config.get(
-                 "path", pretrained_model_config["pretrained_model_name_or_path"]
-             )
-             tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
+             tokenizer = self.modelpool.load_tokenizer()
              tokenizer.save_pretrained(config.model_path)
-             model.save_pretrained(config.model_path)
+             if not self.save_with_remote_code:
+                 model.save_pretrained(config.model_path)
+             else:
+                 save_pretrained_with_remote_code(
+                     model,
+                     auto_map={
+                         "AutoConfig": SmileQwen2Config,
+                         "AutoModel": SmileQwen2Model,
+                         "AutoModelForCausalLM": SmileQwen2ForCausalLM,
+                     },
+                     save_directory=config.model_path,
+                 )
+
+             # save readme
+             complete_readme = generate_complete_readme(
+                 algorithm=self,
+                 modelpool=modelpool,
+                 models=[modelpool.get_model_path(m) for m in modelpool.all_model_names],
+             )
+             with open(os.path.join(config.model_path, "README.md"), "w") as f:
+                 f.write(complete_readme)

          return model

@@ -158,9 +166,12 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):

          with init_empty_weights():
              pretrained_model_config = self.modelpool.get_model_config("_pretrained_")
-             pretrained_path = pretrained_model_config.get(
-                 "path", pretrained_model_config["pretrained_model_name_or_path"]
-             )
+             if isinstance(pretrained_model_config, str):
+                 pretrained_path = pretrained_model_config
+             else:
+                 pretrained_path = pretrained_model_config.get(
+                     "path", pretrained_model_config["pretrained_model_name_or_path"]
+                 )
              base_config = AutoConfig.from_pretrained(pretrained_path)
              model_config = SmileQwen2Config(
                  num_experts_per_tok=config.num_experts_per_tok,
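
Note on the save path above: save_pretrained_with_remote_code registers the SmileQwen2* classes in the checkpoint's auto_map, so a model saved this way should be loadable with the stock transformers auto classes once trust_remote_code is enabled. A minimal usage sketch (the output directory name is a placeholder, and this assumes the helper copies the custom modeling code next to the weights, as its name suggests):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # hypothetical directory produced by the upscaling run (config.model_path)
    model_path = "outputs/smile_qwen2_upscaled"

    # trust_remote_code=True lets transformers import the SmileQwen2* classes
    # referenced by the auto_map entries written at save time
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path)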

fusion_bench/method/smile_upscaling/smile_upscaling.py
@@ -1,7 +1,7 @@
  import logging
  import os
  from copy import deepcopy
- from typing import Dict, List, Tuple # noqa: F401
+ from typing import Any, Dict, List, Tuple # noqa: F401

  import torch
  import torch.nn.functional as F
@@ -21,6 +21,7 @@ from fusion_bench.models.smile_moe.linear_from_module import (
  )
  from fusion_bench.models.utils import get_attr, set_attr
  from fusion_bench.utils.parameters import print_parameters
+ from fusion_bench.utils.devices import get_device

  log = logging.getLogger(__name__)

@@ -54,7 +55,7 @@ class SmileUpscalingAlgorithm(
          routing_use_diff: bool = True,
          average_experts: bool = False,
          model_path: str = None,
-         **kwargs,
+         **kwargs: Any,
      ):
          """
          Initialize the SmileUpscalingAlgorithm.
@@ -91,7 +92,7 @@ class SmileUpscalingAlgorithm(
          print(f"=== Config for `{type(self).__name__}` ===")

      @torch.no_grad()
-     def run(self, modelpool: BaseModelPool):
+     def run(self, modelpool: BaseModelPool) -> nn.Module:
          """
          Executes the upscaling process.

@@ -142,7 +143,7 @@ class SmileUpscalingAlgorithm(
          pretrained_model: nn.Module,
          finetuned_models: List[nn.Module],
          in_place: bool = True,
-     ):
+     ) -> nn.Module:
          """
          Merges the pretrained model with the fine-tuned models to create an upscaled model.

@@ -180,7 +181,12 @@

          name_list = name.split(".")
          module = get_attr(pretrained_model, name_list)
-         experts = [get_attr(m, name_list) for m in finetuned_models]
+         original_device = get_device(module)
+         module = module.to(self.device, non_blocking=True)
+         experts = [
+             get_attr(m, name_list).to(self.device, non_blocking=True)
+             for m in finetuned_models
+         ]
          try:
              moe_linear = SmileMoELinear(
                  module,
@@ -192,6 +198,7 @@
                  full_matrices=self.full_matrices,
                  upscaling_accelerator=self.upscaling_accelerator,
              )
+             moe_linear = moe_linear.to(original_device, non_blocking=True)
          except ExpertNotTrainedError:
              print(f"skip {name} because the experts are not trained.")
              return

fusion_bench/method/tall_mask/task_arithmetic.py
@@ -9,7 +9,7 @@ from copy import deepcopy
  import torch

  from fusion_bench import BaseAlgorithm
- from fusion_bench.mixins import SimpleProfilerMixin
+ from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
  from fusion_bench.modelpool import BaseModelPool
  from fusion_bench.utils.state_dict_arithmetic import (
      state_dict_add,
@@ -58,16 +58,11 @@ def generate_task_masks(
      return final_mask


+ @auto_register_config
  class TallMaskTaskArithmeticAlgorithm(
-     BaseAlgorithm,
      SimpleProfilerMixin,
+     BaseAlgorithm,
  ):
-     _config_mapping = BaseAlgorithm._config_mapping | {
-         "tall_mask_lambda": "tall_mask_lambda",
-         "debug": "debug",
-         "verbose": "verbose",
-     }
-
      def __init__(
          self,
          tall_mask_lambda: float,
@@ -76,9 +71,6 @@ class TallMaskTaskArithmeticAlgorithm(
          **kwargs,
      ):
          super().__init__(**kwargs)
-         self.tall_mask_lambda = tall_mask_lambda
-         self.debug = debug
-         self.verbose = verbose

      @torch.no_grad()
      def run(self, modelpool: BaseModelPool):

fusion_bench/method/task_arithmetic/task_arithmetic.py
@@ -12,7 +12,7 @@ import torch
  from torch import nn

  from fusion_bench.method.base_algorithm import BaseAlgorithm
- from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+ from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
  from fusion_bench.modelpool import BaseModelPool
  from fusion_bench.utils.state_dict_arithmetic import (
      state_dict_add,
@@ -74,9 +74,10 @@ def task_arithmetic_merge(
      return pretrained_model


+ @auto_register_config
  class TaskArithmeticAlgorithm(
-     BaseAlgorithm,
      SimpleProfilerMixin,
+     BaseAlgorithm,
  ):
      """
      Task Arithmetic Algorithm for model fusion.
@@ -89,22 +90,17 @@ class TaskArithmeticAlgorithm(
          scaling_factor (int): The factor by which the task vectors will be scaled before merging.
      """

-     _config_mapping = BaseAlgorithm._config_mapping | {
-         "scaling_factor": "scaling_factor"
-     }
-
-     def __init__(self, scaling_factor: int):
+     def __init__(self, scaling_factor: int, **kwargs):
          """
          Initializes the TaskArithmeticAlgorithm with the given scaling factor.

          Args:
              scaling_factor (int): The factor by which the task vectors will be scaled before merging.
          """
-         self.scaling_factor = scaling_factor
-         super().__init__()
+         super().__init__(**kwargs)

      @torch.no_grad()
-     def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
+     def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]) -> nn.Module:
          """
          Runs the Task Arithmetic Algorithm to fuse models in the given model pool.


fusion_bench/method/ties_merging/ties_merging.py
@@ -9,14 +9,14 @@ Overview of Ties-Merging:
  """

  import logging
- from typing import Dict, List, Literal, Mapping, Union # noqa: F401
+ from typing import Any, Dict, List, Literal, Mapping, Union # noqa: F401

  import torch
  from torch import Tensor, nn

  from fusion_bench.compat.modelpool import to_modelpool
  from fusion_bench.method import BaseAlgorithm
- from fusion_bench.mixins import SimpleProfilerMixin
+ from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
  from fusion_bench.modelpool import BaseModelPool
  from fusion_bench.utils.type import StateDictType

@@ -25,33 +25,22 @@ from .ties_merging_utils import state_dict_to_vector, ties_merging, vector_to_st
  log = logging.getLogger(__name__)


- class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
-     """
-     TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
-
-     Attributes:
-         scaling_factor (float): The scaling factor to apply to the merged task vector.
-         threshold (float): The threshold for resetting values in the task vector.
-         remove_keys (List[str]): List of keys to remove from the state dictionary.
-         merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
-     """
-
-     _config_mapping = BaseAlgorithm._config_mapping | {
-         "scaling_factor": "scaling_factor",
-         "threshold": "threshold",
-         "remove_keys": "remove_keys",
-         "merge_func": "merge_func",
-     }
-
+ @auto_register_config
+ class TiesMergingAlgorithm(
+     SimpleProfilerMixin,
+     BaseAlgorithm,
+ ):
      def __init__(
          self,
          scaling_factor: float,
          threshold: float,
          remove_keys: List[str],
          merge_func: Literal["sum", "mean", "max"],
-         **kwargs,
+         **kwargs: Any,
      ):
          """
+         TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
+
          Initialize the TiesMergingAlgorithm with the given parameters.

          Args:
@@ -61,14 +50,12 @@ class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
              merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
              **kwargs: Additional keyword arguments for the base class.
          """
-         self.scaling_factor = scaling_factor
-         self.threshold = threshold
-         self.remove_keys = remove_keys
-         self.merge_func = merge_func
          super().__init__(**kwargs)

      @torch.no_grad()
-     def run(self, modelpool: BaseModelPool | Dict[str, nn.Module], **kwargs):
+     def run(
+         self, modelpool: BaseModelPool | Dict[str, nn.Module], **kwargs: Any
+     ) -> nn.Module:
          """
          Run the TIES merging algorithm to fuse models in the model pool.

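
The diffs above (together with the tall-mask and task-arithmetic diffs before them) all apply the same refactor: the hand-written _config_mapping dict and the per-argument self.x = x assignments are dropped in favor of the @auto_register_config decorator from fusion_bench.mixins, and SimpleProfilerMixin is moved ahead of BaseAlgorithm in the base-class list. A minimal sketch of the resulting pattern, assuming the decorator records the __init__ keyword arguments as attributes and registers them for config serialization (which is what the removed boilerplate used to do by hand); the class name here is hypothetical:

    import torch
    from fusion_bench import BaseAlgorithm
    from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config


    @auto_register_config
    class MyMergeAlgorithm(SimpleProfilerMixin, BaseAlgorithm):
        def __init__(self, scaling_factor: float, **kwargs):
            # no _config_mapping and no self.scaling_factor = scaling_factor here;
            # the decorator is assumed to take care of both
            super().__init__(**kwargs)

        @torch.no_grad()
        def run(self, modelpool):
            # the registered hyperparameters are available as plain attributes
            print(f"merging with scaling_factor={self.scaling_factor}")
            return modelpool.load_pretrained_model()

The practical effect, as far as the diffs show, is that constructor arguments and config keys stay in sync automatically, which is presumably why the mapping boilerplate could be deleted across so many algorithm classes in this release.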

fusion_bench/method/we_moe/clip_we_moe.py
@@ -2,6 +2,7 @@ import functools
  import logging
  import os
  from copy import deepcopy
+ from typing import Any, Iterator

  import torch
  from torch import Tensor
@@ -38,7 +39,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(

      modelpool: CLIPVisionModelPool = None

-     def load_checkpoint(self, model, checkpoint):
+     def load_checkpoint(self, model: Any, checkpoint: Any):
          """
          Load the checkpoint file.

@@ -49,7 +50,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
          state = {"model": model}
          self._fabric.load(checkpoint, state)

-     def save_checkpoint(self, model, checkpoint):
+     def save_checkpoint(self, model: Any, checkpoint: Any):
          """
          Save the checkpoint file.

@@ -102,7 +103,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
          return moe_model

      @functools.cache
-     def get_shuffled_test_loader_iter(self, tta_dataset: str):
+     def get_shuffled_test_loader_iter(self, tta_dataset: str) -> Iterator:
          """
          Get an iterator for the shuffled test data loader.

@@ -131,7 +132,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
          """
          self.setup_zero_shot_classification_head()

-     def compute_logits(self, module, batch, task) -> Tensor:
+     def compute_logits(self, module: Any, batch: Any, task: Any) -> Tensor:
          """
          Compute the logits for the given batch and task.


fusion_bench/method/we_moe/we_moe.py
@@ -1,6 +1,6 @@
  import logging
  from abc import abstractmethod
- from typing import cast # noqa: F401
+ from typing import Any, cast # noqa: F401

  import lightning as L
  import lightning.fabric.wrappers
@@ -70,7 +70,7 @@ class WeightEnsemblingMoEAlgorithm(
          assert "No CUDA device available."

      @abstractmethod
-     def load_checkpoint(self, model, checkpoint):
+     def load_checkpoint(self, model: Any, checkpoint: Any):
          """
          Load the checkpoint file.

@@ -81,7 +81,7 @@ class WeightEnsemblingMoEAlgorithm(
          pass

      @abstractmethod
-     def save_checkpoint(self, model, checkpoint):
+     def save_checkpoint(self, model: Any, checkpoint: Any):
          """
          Save the checkpoint file.

@@ -121,7 +121,7 @@ class WeightEnsemblingMoEAlgorithm(
          pass

      @abstractmethod
-     def compute_logits(self, module, batch, task) -> Tensor:
+     def compute_logits(self, module: Any, batch: Any, task: Any) -> Tensor:
          """
          Compute the logits for a given batch and task.

@@ -135,7 +135,7 @@ class WeightEnsemblingMoEAlgorithm(
          """
          pass

-     def test_time_adaptation(self, module: WeightEnsemblingMoE):
+     def test_time_adaptation(self, module: WeightEnsemblingMoE) -> WeightEnsemblingMoE:
          """
          Perform test-time adaptation for the given module.

@@ -208,7 +208,7 @@ class WeightEnsemblingMoEAlgorithm(

          return module

-     def run(self, modelpool: ModelPool):
+     def run(self, modelpool: ModelPool) -> WeightEnsemblingMoE:
          """
          Run the WeightEnsemblingMoEAlgorithm to fuse models using Weight Ensembling Mixture of Experts.


fusion_bench/method/weighted_average/llama.py
@@ -3,6 +3,7 @@ from typing import List, Mapping, Union # noqa: F401

  import numpy as np
  import torch
+ from transformers import PreTrainedModel
  from typing_extensions import override

  from fusion_bench.method import BaseAlgorithm
@@ -10,24 +11,17 @@ from fusion_bench.modelpool import CausalLMPool
  from fusion_bench.utils import timeit_context
  from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_mul
  from fusion_bench.utils.type import StateDictType
+ from fusion_bench.mixins import auto_register_config

  log = logging.getLogger(__name__)


+ @auto_register_config
  class WeightedAverageForLLama(BaseAlgorithm):
      """
      A class to perform weighted averaging of LlaMa/Mistral models.
      """

-     _config_mapping = BaseAlgorithm._config_mapping | {
-         "normalize": "normalize",
-         "weights": "weights",
-         "backbone_only": "backbone_only",
-         "merged_model_save_path": "merged_model_save_path",
-         "save_tokenizer": "save_tokenizer",
-         "push_to_hub": "push_to_hub",
-     }
-
      def __init__(
          self,
          normalize: bool,
@@ -49,17 +43,11 @@ class WeightedAverageForLLama(BaseAlgorithm):
              save_tokenizer (bool): Whether to save the tokenizer.
              push_to_hub (bool): Whether to push the model to the hub.
          """
-         self.normalize = normalize
-         self.weights = weights
-         self.backbone_only = backbone_only
-         self.merged_model_save_path = merged_model_save_path
-         self.save_tokenizer = save_tokenizer
-         self.push_to_hub = push_to_hub
          super().__init__(**kwargs)

      @override
      @torch.no_grad()
-     def run(self, modelpool: CausalLMPool):
+     def run(self, modelpool: CausalLMPool) -> PreTrainedModel:
          """
          Executes the weighted averaging of models in the provided model pool.


fusion_bench/metrics/continual_learning/__init__.py
@@ -0,0 +1 @@
+ from .backward_transfer import compute_backward_transfer

fusion_bench/metrics/continual_learning/backward_transfer.py
@@ -10,7 +10,7 @@ def compute_backward_transfer(
      Compute the backward transfer (BWT) of a model on a set of tasks.

      Equation:
-         BWT = \frac{1}{n} \sum_{k=1}^{n} (acc_{Ti}[k] - acc_{ii}[k])
+         $BWT = \frac{1}{n} \sum_{k=1}^{n} (acc_{T,i}[k] - acc_{i,i}[k])$

      Returns:
          float: The backward transfer of the model.

fusion_bench/metrics/nyuv2/__init__.py
@@ -1,10 +1,10 @@
  from .depth import DepthMetric
  from .noise import NoiseMetric
  from .normal import NormalMetric
- from .segmentation import SegmentationMertic
+ from .segmentation import SegmentationMetric

  metric_classes = {
-     "segmentation": SegmentationMertic,
+     "segmentation": SegmentationMetric,
      "depth": DepthMetric,
      "normal": NormalMetric,
      "noise": NoiseMetric,

fusion_bench/metrics/nyuv2/segmentation.py
@@ -5,7 +5,7 @@ from torch import Tensor, nn
  from torchmetrics import Metric


- class SegmentationMertic(Metric):
+ class SegmentationMetric(Metric):
      metric_names = ["mIoU", "pixAcc"]

      def __init__(self, num_classes=13):