fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. fusion_bench/__init__.py +1 -0
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +5 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/dataset/clip_dataset.py +2 -1
  9. fusion_bench/dataset/gpt2_glue.py +9 -9
  10. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  11. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  12. fusion_bench/dataset/image_dataset.py +1 -1
  13. fusion_bench/dataset/nyuv2.py +2 -2
  14. fusion_bench/method/__init__.py +16 -1
  15. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  16. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  17. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  18. fusion_bench/method/base_algorithm.py +195 -12
  19. fusion_bench/method/bitdelta/__init__.py +4 -0
  20. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  21. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  25. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  26. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  27. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  28. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  29. fusion_bench/method/ensemble.py +12 -12
  30. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  31. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
  32. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  33. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  34. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  35. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  36. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  37. fusion_bench/method/linear/expo.py +2 -1
  38. fusion_bench/method/linear/linear_interpolation.py +6 -4
  39. fusion_bench/method/linear/simple_average_for_llama.py +16 -6
  40. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  41. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  42. fusion_bench/method/model_recombination.py +2 -5
  43. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  44. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  45. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  46. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  47. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  48. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  49. fusion_bench/method/randes/modelsoup.py +1 -3
  50. fusion_bench/method/regmean/clip_regmean.py +2 -2
  51. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  52. fusion_bench/method/regmean/regmean.py +2 -11
  53. fusion_bench/method/regmean_plusplus/__init__.py +3 -0
  54. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
  55. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
  56. fusion_bench/method/simple_average.py +16 -4
  57. fusion_bench/method/slerp/slerp.py +5 -2
  58. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  59. fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
  60. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
  61. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  62. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  63. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  64. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  65. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  66. fusion_bench/method/we_moe/we_moe.py +6 -6
  67. fusion_bench/method/weighted_average/llama.py +4 -16
  68. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  69. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  70. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  71. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  72. fusion_bench/mixins/__init__.py +10 -2
  73. fusion_bench/mixins/clip_classification.py +4 -3
  74. fusion_bench/mixins/hydra_config.py +105 -7
  75. fusion_bench/mixins/lightning_fabric.py +2 -0
  76. fusion_bench/mixins/serialization.py +265 -48
  77. fusion_bench/modelpool/__init__.py +2 -2
  78. fusion_bench/modelpool/base_pool.py +29 -9
  79. fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
  80. fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
  81. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  82. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  83. fusion_bench/models/__init__.py +2 -1
  84. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  85. fusion_bench/models/hf_utils.py +182 -0
  86. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  87. fusion_bench/models/linearized/vision_model.py +1 -1
  88. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  89. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  90. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  91. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  92. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  93. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  94. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  95. fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
  96. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  97. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
  98. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  99. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  100. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  101. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
  102. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  103. fusion_bench/models/parameter_dict.py +1 -1
  104. fusion_bench/models/sparse_we_moe.py +1 -53
  105. fusion_bench/models/utils.py +26 -0
  106. fusion_bench/models/we_moe.py +1 -53
  107. fusion_bench/models/wrappers/ensemble.py +6 -4
  108. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  109. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  110. fusion_bench/programs/base_program.py +81 -2
  111. fusion_bench/programs/fabric_fusion_program.py +24 -8
  112. fusion_bench/scripts/cli.py +6 -6
  113. fusion_bench/taskpool/base_pool.py +4 -3
  114. fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
  115. fusion_bench/taskpool/dummy.py +1 -1
  116. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  117. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  118. fusion_bench/utils/__init__.py +6 -1
  119. fusion_bench/utils/devices.py +14 -4
  120. fusion_bench/utils/instantiate_utils.py +3 -1
  121. fusion_bench/utils/misc.py +48 -2
  122. fusion_bench/utils/modelscope.py +265 -0
  123. fusion_bench/utils/parameters.py +2 -2
  124. fusion_bench/utils/rich_utils.py +3 -0
  125. fusion_bench/utils/state_dict_arithmetic.py +34 -27
  126. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
  127. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
  128. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  129. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  130. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  131. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  132. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  133. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  134. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  135. fusion_bench_config/hydra/default.yaml +6 -2
  136. fusion_bench_config/llama_full_finetune.yaml +1 -0
  137. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  138. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  139. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  140. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
  141. fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
  142. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  143. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  144. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
  145. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
  148. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
  149. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
  150. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
  167. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
  168. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
  169. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
  170. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
  171. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
  172. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
  173. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  174. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  175. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  176. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  177. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  178. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  179. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  180. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  181. fusion_bench_config/nyuv2_config.yaml +3 -1
  182. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  183. fusion_bench_config/path/default.yaml +28 -0
  184. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  185. fusion_bench_config/method/adamerging.yaml +0 -23
  186. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  187. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  188. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  189. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
  190. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
  191. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
  192. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
  193. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
@@ -6,7 +6,7 @@ import torch
6
6
  from torch import nn
7
7
 
8
8
  from fusion_bench.method.base_algorithm import BaseAlgorithm
9
- from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
9
+ from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
10
10
  from fusion_bench.modelpool import BaseModelPool
11
11
  from fusion_bench.utils import LazyStateDict
12
12
  from fusion_bench.utils.state_dict_arithmetic import (
@@ -59,12 +59,20 @@ def simple_average(
59
59
  return state_dict_avg(modules)
60
60
 
61
61
 
62
+ @auto_register_config
62
63
  class SimpleAverageAlgorithm(
63
64
  BaseAlgorithm,
64
65
  SimpleProfilerMixin,
65
66
  ):
67
+ def __init__(self, show_pbar: bool = False, **kwargs):
68
+ """
69
+ Args:
70
+ show_pbar (bool): If True, shows a progress bar during model loading and merging. Default is False.
71
+ """
72
+ super().__init__(**kwargs)
73
+
66
74
  @torch.no_grad()
67
- def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
75
+ def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]) -> nn.Module:
68
76
  """
69
77
  Fuse the models in the given model pool using simple averaging.
70
78
 
@@ -100,10 +108,14 @@ class SimpleAverageAlgorithm(
100
108
  forward_model = model
101
109
  else:
102
110
  # Add the current model's state dictionary to the accumulated state dictionary
103
- sd = state_dict_add(sd, model.state_dict(keep_vars=True))
111
+ sd = state_dict_add(
112
+ sd, model.state_dict(keep_vars=True), show_pbar=self.show_pbar
113
+ )
104
114
  with self.profile("merge weights"):
105
115
  # Divide the accumulated state dictionary by the number of models to get the average
106
- sd = state_dict_div(sd, len(modelpool.model_names))
116
+ sd = state_dict_div(
117
+ sd, len(modelpool.model_names), show_pbar=self.show_pbar
118
+ )
107
119
 
108
120
  if isinstance(forward_model, LazyStateDict):
109
121
  # if the model is a LazyStateDict, convert it to an empty module
@@ -1,10 +1,13 @@
1
1
  import logging
2
+ from typing import Any, Dict
2
3
 
3
4
  import torch
5
+ from torch import nn
4
6
  from typing_extensions import override
5
7
 
6
8
  from fusion_bench.method import BaseAlgorithm
7
9
  from fusion_bench.modelpool import BaseModelPool
10
+ from fusion_bench.utils.type import StateDictType
8
11
 
9
12
  from .slerp_utils import slerp
10
13
 
@@ -18,7 +21,7 @@ def slerp_on_state_dicts(
18
21
  *,
19
22
  DOT_THRESHOLD: float = 0.9995,
20
23
  epsilon: float = 1e-8,
21
- ):
24
+ ) -> StateDictType:
22
25
  """
23
26
  Perform spherical linear interpolation (slerp) on the state dictionaries of two models.
24
27
 
@@ -72,7 +75,7 @@ class SlerpMergeAlgorithm(BaseAlgorithm):
72
75
  super().__init__()
73
76
 
74
77
  @override
75
- def run(self, modelpool: BaseModelPool):
78
+ def run(self, modelpool: BaseModelPool) -> nn.Module:
76
79
  """
77
80
  Run the SlerpMergeAlgorithm on the given model pool.
78
81
 
@@ -0,0 +1,177 @@
1
+ import os
2
+ from typing import Literal, cast
3
+
4
+ import pandas as pd
5
+ import torch
6
+ from omegaconf import DictConfig
7
+ from torch import nn
8
+ from torch.utils.data import DataLoader
9
+ from tqdm import tqdm
10
+ from transformers import CLIPVisionModel
11
+
12
+ from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
13
+ from fusion_bench.dataset import CLIPDataset
14
+ from fusion_bench.method import SmileUpscalingAlgorithm
15
+ from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
16
+ from fusion_bench.modelpool import CLIPVisionModelPool
17
+ from fusion_bench.taskpool.clip_vision.taskpool import LayerWiseFeatureSaver
18
+ from fusion_bench.utils.devices import clear_cuda_cache
19
+
20
+
21
@auto_register_config
class LowRankApproximation(BaseAlgorithm):
    """Replace each fine-tuned linear weight with a low-rank approximation.

    For every ``nn.Linear`` module of every model in the pool, the
    fine-tuning update ``W_ft - W_base`` is truncated to its top-``rank``
    singular components and added back onto the base weight, so each task
    model keeps only a rank-``rank`` delta from the pretrained model.
    """

    def __init__(self, rank: int, device: str = "cuda", **kwargs):
        """Low-rank approximation of fine-tuned updates.

        Args:
            rank (int): Number of singular components kept per linear layer.
            device (str): Device used for the SVD computation. Default "cuda".
        """
        # `rank` and `device` are bound to `self` by @auto_register_config.
        super().__init__(**kwargs)

    def run(self, modelpool: BaseModelPool):
        """Return ``{model_name: model}`` with low-rank-approximated updates.

        The returned models are the pool's task models modified in place.
        """
        base_model = modelpool.load_pretrained_model()

        models = {}
        for model_name in tqdm(modelpool.model_names, "processing models"):
            task_model = modelpool.load_model(model_name)
            for module_name, module in task_model.named_modules():
                if isinstance(module, nn.Linear):
                    w = cast(
                        nn.Linear, base_model.get_submodule(module_name)
                    ).weight.to(dtype=torch.float32, device=self.device, copy=True)
                    w_ft = module.weight.to(
                        dtype=torch.float32, device=self.device, copy=True
                    )

                    # Compute the low-rank approximation of the update.
                    # `full_matrices=False` (reduced SVD) is sufficient since
                    # only the top-`rank` components are kept below, and it
                    # avoids materializing the full square U/V factors for
                    # rectangular weight matrices.
                    w_diff = w_ft - w
                    u, s, vh = torch.linalg.svd(w_diff, full_matrices=False)
                    v = vh.T

                    u = u[:, : self.rank]
                    s = s[: self.rank]
                    v = v[:, : self.rank]

                    low_rank_w_diff = torch.linalg.multi_dot((u, torch.diag(s), v.T))
                    low_rank_w = w + low_rank_w_diff

                    # Write the approximated weight back in the module's
                    # original dtype and device.
                    module.weight.data = low_rank_w.to(
                        dtype=module.weight.dtype,
                        device=module.weight.device,
                    )

            models[model_name] = task_model
        return models
62
+
63
+
64
@auto_register_config
class ErrorAccumulationAnalysisForCLIP(
    LightningFabricMixin,
    BaseAlgorithm,
):
    """Collect per-layer hidden states of CLIP vision encoders to study how
    approximation error accumulates across layers.

    For each task model in the pool, features are captured (via forward
    hooks) from three variants: the original fine-tuned model, the
    SMILE-upscaled model, and a plain low-rank-approximated model, and saved
    under ``self.log_dir`` for offline comparison.
    """

    def __init__(
        self,
        gate_k: int,
        k: int,
        seed: int = 42,
        top_k: int = 1,
        dataset_kwargs: DictConfig = None,
        max_samples: int = 1024,
        **kwargs,
    ):
        """
        Args:
            gate_k: Rank of the SMILE router (forwarded to SmileUpscalingAlgorithm).
            k: Rank of the SMILE experts and of the low-rank baseline.
            seed: Seed used before each feature-collection pass so every model
                variant sees the same data samples.
            top_k: Number of experts activated per token in the SMILE model.
            dataset_kwargs: DataLoader keyword arguments (batch_size, num_workers, ...).
            max_samples: Stop collecting features after this many samples.
        """
        # NOTE(review): constructor arguments are presumably bound to `self`
        # by @auto_register_config; when `dataset_kwargs` is provided the
        # attribute is assumed to be set there — confirm.
        super().__init__(**kwargs)
        if dataset_kwargs is None:
            # Default loader settings when none are configured.
            self.dataset_kwargs = DictConfig(
                {
                    "batch_size": 32,
                    "num_workers": 4,
                }
            )

    def run(self, modelpool: CLIPVisionModelPool):
        """Build the SMILE and low-rank models, then dump per-layer features
        for each task's finetuned / smile / low-rank variants."""
        assert self.fabric.world_size == 1, "Distributed inference is not supported."
        # get the smile model
        smile_algorithm = SmileUpscalingAlgorithm(
            gate_k=self.gate_k, k=self.k, top_k=self.top_k, device=self.fabric.device
        )
        smile_model = smile_algorithm.run(modelpool)
        # get low-rank models
        low_rank_models = LowRankApproximation(rank=self.k).run(modelpool)

        # NOTE(review): `results` is initialized but never populated or saved
        # in this method — looks like dead code or a placeholder for future
        # aggregation; confirm before relying on any CSV output here.
        results = {
            "model_name": [],
            "method": [],
            "layer_index": [],
            "approximation_error": [],
        }

        for model_name in modelpool.model_names:
            dataset = modelpool.load_test_dataset(model_name)
            processor = modelpool.load_processor()
            dataset = CLIPDataset(dataset, processor)
            dataloader = DataLoader(dataset, shuffle=True, **self.dataset_kwargs)
            dataloader = self.fabric.setup_dataloaders(dataloader)

            # finetuned_model
            finetuned_model = modelpool.load_model(model_name)
            finetuned_model = self.to_device(finetuned_model)
            self.collect_hidden_states(
                finetuned_model,
                dataloader=dataloader,
                model_name=f"{model_name}/finetuned",
            )
            # Free GPU memory before loading the next variant.
            del finetuned_model
            clear_cuda_cache()

            # smile model
            smile_model = self.to_device(smile_model)
            self.collect_hidden_states(
                smile_model, dataloader=dataloader, model_name=f"{model_name}/smile"
            )
            # The SMILE model is reused across tasks, so it is moved to CPU
            # rather than deleted.
            smile_model.cpu()
            clear_cuda_cache()

            # low-rank models
            model = low_rank_models.pop(model_name)
            model = self.to_device(model)
            self.collect_hidden_states(
                model, dataloader=dataloader, model_name=f"{model_name}/low-rank"
            )
            del model
            clear_cuda_cache()

            del dataloader
            clear_cuda_cache()

    @torch.no_grad()
    def collect_hidden_states(
        self, model: CLIPVisionModel, dataloader, model_name: str
    ):
        """Run up to ``self.max_samples`` images through ``model`` and save
        each encoder layer's output features to
        ``{log_dir}/{model_name}/layer_{i}.pth``.

        Returns:
            Dict mapping layer index to its ``LayerWiseFeatureSaver`` hook
            (features already saved to disk).
        """
        self.fabric.seed_everything(
            self.seed, workers=True
        )  # make sure to get same data samples
        # register hooks
        hooks = {}
        hook_handles = {}
        for i, layer in enumerate(model.vision_model.encoder.layers):
            hooks[i] = LayerWiseFeatureSaver(
                save_path=os.path.join(self.log_dir, model_name, f"layer_{i}.pth"),
                first_token_only=True,
            )
            hook_handles[i] = layer.register_forward_hook(hooks[i])

        # forward pass
        num_total_samples = 0
        for images, _ in tqdm(dataloader, desc=f"Collecting features for {model_name}"):
            batch_size = images.size(0)
            model(images)
            num_total_samples += batch_size
            if num_total_samples >= self.max_samples:
                break

        # save features
        for i, hook in hooks.items():
            hook.save_features()

        # remove hooks
        for i, hook_handle in hook_handles.items():
            hook_handle.remove()

        return hooks
@@ -0,0 +1,145 @@
1
+ import os
2
+ from typing import Literal
3
+
4
+ import pandas as pd
5
+ import torch
6
+
7
+ from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
8
+ from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
9
+
10
+ from tqdm import tqdm
11
+
12
+
13
class ProjectedEnergyAnalysis(
    SimpleProfilerMixin,
    LightningFabricMixin,
    BaseAlgorithm,
):
    """Measure how much of each fine-tuning update lies inside subspaces
    spanned by the pretrained weight's singular vectors.

    For every ``nn.Linear`` in every task model, the update
    ``W_ft - W_base`` is projected onto:
      I   — the span of the top-k left/right singular vectors of ``W_base``;
      II  — the remaining singular directions of the reduced SVD;
      II+III — the same, using the full SVD (includes the orthogonal
               complement for non-square weights).
    The "projected energy" is the squared Frobenius norm of the projection
    divided by that of the full update. Results are written to a CSV file
    in ``self.log_dir``.
    """

    def on_run_start(self):
        # Run all SVD/projection computations on the fabric-selected device.
        self.device = self.fabric.device

    def run(self, modelpool: BaseModelPool):
        """Analyze every linear layer of every task model and save
        ``projected_energy_analysis.csv``. Returns None."""
        with self.profile("model loading"):
            base_model = modelpool.load_pretrained_model()

        # Columns of the output CSV; filled one row per linear module.
        results = {
            "model_name": [],
            "module_index": [],
            "module_name": [],
            "projected_energy_I": [],
            "projected_energy_II": [],
            "projected_energy_II_III": [],
        }
        for model_name in tqdm(
            modelpool.model_names,
            "analyzing",
            dynamic_ncols=True,
        ):
            with self.profile("model loading"):
                finetuned_model = modelpool.load_model(model_name)

            module_index = 0
            for module_name, base_module in tqdm(
                list(base_model.named_modules()),
                "analyzing modules",
                dynamic_ncols=True,
            ):
                if isinstance(base_module, torch.nn.Linear):
                    with self.profile("weight analysis"):
                        _result = self.analyze_weight(
                            base_module.weight,
                            finetuned_model.get_submodule(module_name).weight,
                        )
                    results["model_name"].append(model_name)
                    results["module_index"].append(module_index)
                    results["module_name"].append(module_name)
                    for key, value in _result.items():
                        results[key].append(value)

                    module_index += 1

        # save results as csv
        results = pd.DataFrame(results)
        results.to_csv(
            os.path.join(self.log_dir, "projected_energy_analysis.csv"), index=True
        )

        self.print_profile_summary()
        return None

    @torch.no_grad()
    def analyze_weight(self, w: torch.Tensor, w_ft: torch.Tensor, k: int = -1):
        """Compute the three projected-energy ratios for one weight pair.

        Args:
            w: Pretrained weight matrix.
            w_ft: Fine-tuned weight matrix (same shape as ``w``).
            k: Subspace cut-off rank. If negative, chosen as the smallest k
                whose singular values account for over 50% of the total sum.

        Returns:
            Dict with keys ``projected_energy_I``, ``projected_energy_II``
            and ``projected_energy_II_III`` (Python floats).
        """
        w = w.to(dtype=torch.float32, device=self.device)
        w_ft = w_ft.to(dtype=torch.float32, device=self.device)
        w_diff = w_ft - w

        # Perform analysis on the weight tensor
        u, s, vh = torch.linalg.svd(w, full_matrices=False)
        v = vh.T
        if k < 0:
            # find the position where the sum of singular values is larger than 50% of the total sum
            cumsum = s.cumsum(0)
            k = (cumsum < cumsum[-1] * 0.5).sum().item() + 1

        # subspace I
        w_diff_proj = self._project_subspace_low(u=u, s=s, v=v, k=k, w=w, w_ft=w_ft)
        projected_energy_I = (
            torch.linalg.norm(w_diff_proj, ord="fro") ** 2
            / torch.linalg.norm(w_diff, ord="fro") ** 2
        )

        # subspace II
        w_diff_proj = self._project_subspace_high(u=u, s=s, v=v, k=k, w=w, w_ft=w_ft)
        projected_energy_II = (
            torch.linalg.norm(w_diff_proj, ord="fro") ** 2
            / torch.linalg.norm(w_diff, ord="fro") ** 2
        )

        ## subspace II+III
        # Recompute with full_matrices=True so u/v include the orthogonal
        # complement of the column/row space (subspace III).
        u, s, vh = torch.linalg.svd(w, full_matrices=True)
        v = vh.T
        w_diff_proj = self._project_subspace_high(u=u, s=s, v=v, k=k, w=w, w_ft=w_ft)
        projected_energy_II_III = (
            torch.linalg.norm(w_diff_proj, ord="fro") ** 2
            / torch.linalg.norm(w_diff, ord="fro") ** 2
        )

        return {
            "projected_energy_I": projected_energy_I.item(),
            "projected_energy_II": projected_energy_II.item(),
            "projected_energy_II_III": projected_energy_II_III.item(),
        }

    def _project_subspace_low(
        self,
        u: torch.Tensor,
        s: torch.Tensor,
        v: torch.Tensor,
        k: int,
        w: torch.Tensor,
        w_ft: torch.Tensor,
    ):
        """Project ``w_ft - w`` onto the span of the top-k singular vectors:
        ``U_k U_k^T (W_ft - W) V_k V_k^T``. ``s`` is sliced for symmetry with
        the caller but not otherwise used."""
        u = u[:, :k]
        s = s[:k]
        v = v[:, :k]

        w_diff = w_ft - w
        w_diff_proj = torch.linalg.multi_dot((u, u.T, w_diff, v, v.T))
        return w_diff_proj

    def _project_subspace_high(
        self,
        u: torch.Tensor,
        s: torch.Tensor,
        v: torch.Tensor,
        k: int,
        w: torch.Tensor,
        w_ft: torch.Tensor,
    ):
        """Project ``w_ft - w`` onto the span of the singular vectors beyond
        rank k: ``U_{>k} U_{>k}^T (W_ft - W) V_{>k} V_{>k}^T``."""
        u = u[:, k:]
        s = s[k:]
        v = v[:, k:]

        w_diff = w_ft - w
        w_diff_proj = torch.linalg.multi_dot((u, u.T, w_diff, v, v.T))
        return w_diff_proj
@@ -16,10 +16,16 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
16
16
 
17
17
  from fusion_bench import BaseAlgorithm, BaseModelPool
18
18
  from fusion_bench.compat.modelpool import to_modelpool
19
- from fusion_bench.mixins import SimpleProfilerMixin
19
+ from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
20
+ from fusion_bench.modelpool import CausalLMPool
21
+ from fusion_bench.models.hf_utils import (
22
+ generate_complete_readme,
23
+ save_pretrained_with_remote_code,
24
+ )
20
25
  from fusion_bench.models.modeling_smile_qwen2 import (
21
26
  SmileQwen2Config,
22
27
  SmileQwen2ForCausalLM,
28
+ SmileQwen2Model,
23
29
  )
24
30
  from fusion_bench.models.modeling_smile_qwen2.modeling_smile_qwen2 import (
25
31
  SmileQwen2DecoderLayer,
@@ -34,6 +40,7 @@ from fusion_bench.utils.parameters import print_parameters
34
40
  log = logging.getLogger(__name__)
35
41
 
36
42
 
43
+ @auto_register_config
37
44
  class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
38
45
  R"""
39
46
  SmileQwen2UpscalingAlgorithm is a model fusion algorithm designed to upscale
@@ -49,15 +56,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
49
56
  Merges the pretrained model with the fine-tuned models to create an upscaled model.
50
57
  """
51
58
 
52
- _config_mapping = BaseAlgorithm._config_mapping | {
53
- "device": "device",
54
- "accelerator": "accelerator",
55
- "model_path": "model_path",
56
- "model_dtype": "model_dtype",
57
- "num_experts_per_tok": "num_experts_per_tok",
58
- "rank_of_router": "rank_of_router",
59
- "rank_of_expert": "rank_of_expert",
60
- }
59
+ modelpool: CausalLMPool
61
60
 
62
61
  def __init__(
63
62
  self,
@@ -68,20 +67,13 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
68
67
  num_experts_per_tok,
69
68
  rank_of_router,
70
69
  rank_of_expert,
70
+ save_with_remote_code: bool = True,
71
71
  **kwargs,
72
72
  ):
73
- self.device = device
74
- self.accelerator = accelerator
75
- self.model_path = model_path
76
- self.model_dtype = model_dtype
77
- # SmileMoE parameters, except `num_local_experts` which is set later according to the number of finetuned models
78
- self.num_experts_per_tok = num_experts_per_tok
79
- self.rank_of_router = rank_of_router
80
- self.rank_of_expert = rank_of_expert
81
73
  super().__init__(**kwargs)
82
74
 
83
75
  @torch.no_grad()
84
- def run(self, modelpool: BaseModelPool) -> SmileQwen2ForCausalLM:
76
+ def run(self, modelpool) -> SmileQwen2ForCausalLM:
85
77
  """
86
78
  Executes the upscaling process.
87
79
 
@@ -129,13 +121,29 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
129
121
  if os.path.dirname(config.model_path):
130
122
  os.makedirs(os.path.dirname(config.model_path), exist_ok=True)
131
123
  log.info(f"Saving model to {config.model_path}")
132
- pretrained_model_config = self.modelpool.get_model_config("_pretrained_")
133
- pretrained_path = pretrained_model_config.get(
134
- "path", pretrained_model_config["pretrained_model_name_or_path"]
135
- )
136
- tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
124
+ tokenizer = self.modelpool.load_tokenizer()
137
125
  tokenizer.save_pretrained(config.model_path)
138
- model.save_pretrained(config.model_path)
126
+ if not self.save_with_remote_code:
127
+ model.save_pretrained(config.model_path)
128
+ else:
129
+ save_pretrained_with_remote_code(
130
+ model,
131
+ auto_map={
132
+ "AutoConfig": SmileQwen2Config,
133
+ "AutoModel": SmileQwen2Model,
134
+ "AutoModelForCausalLM": SmileQwen2ForCausalLM,
135
+ },
136
+ save_directory=config.model_path,
137
+ )
138
+
139
+ # save readme
140
+ complete_readme = generate_complete_readme(
141
+ algorithm=self,
142
+ modelpool=modelpool,
143
+ models=[modelpool.get_model_path(m) for m in modelpool.all_model_names],
144
+ )
145
+ with open(os.path.join(config.model_path, "README.md"), "w") as f:
146
+ f.write(complete_readme)
139
147
 
140
148
  return model
141
149
 
@@ -158,9 +166,12 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
158
166
 
159
167
  with init_empty_weights():
160
168
  pretrained_model_config = self.modelpool.get_model_config("_pretrained_")
161
- pretrained_path = pretrained_model_config.get(
162
- "path", pretrained_model_config["pretrained_model_name_or_path"]
163
- )
169
+ if isinstance(pretrained_model_config, str):
170
+ pretrained_path = pretrained_model_config
171
+ else:
172
+ pretrained_path = pretrained_model_config.get(
173
+ "path", pretrained_model_config["pretrained_model_name_or_path"]
174
+ )
164
175
  base_config = AutoConfig.from_pretrained(pretrained_path)
165
176
  model_config = SmileQwen2Config(
166
177
  num_experts_per_tok=config.num_experts_per_tok,
@@ -1,7 +1,7 @@
1
1
  import logging
2
2
  import os
3
3
  from copy import deepcopy
4
- from typing import Dict, List, Tuple # noqa: F401
4
+ from typing import Any, Dict, List, Tuple # noqa: F401
5
5
 
6
6
  import torch
7
7
  import torch.nn.functional as F
@@ -21,6 +21,7 @@ from fusion_bench.models.smile_moe.linear_from_module import (
21
21
  )
22
22
  from fusion_bench.models.utils import get_attr, set_attr
23
23
  from fusion_bench.utils.parameters import print_parameters
24
+ from fusion_bench.utils.devices import get_device
24
25
 
25
26
  log = logging.getLogger(__name__)
26
27
 
@@ -54,7 +55,7 @@ class SmileUpscalingAlgorithm(
54
55
  routing_use_diff: bool = True,
55
56
  average_experts: bool = False,
56
57
  model_path: str = None,
57
- **kwargs,
58
+ **kwargs: Any,
58
59
  ):
59
60
  """
60
61
  Initialize the SmileUpscalingAlgorithm.
@@ -91,7 +92,7 @@ class SmileUpscalingAlgorithm(
91
92
  print(f"=== Config for `{type(self).__name__}` ===")
92
93
 
93
94
  @torch.no_grad()
94
- def run(self, modelpool: BaseModelPool):
95
+ def run(self, modelpool: BaseModelPool) -> nn.Module:
95
96
  """
96
97
  Executes the upscaling process.
97
98
 
@@ -142,7 +143,7 @@ class SmileUpscalingAlgorithm(
142
143
  pretrained_model: nn.Module,
143
144
  finetuned_models: List[nn.Module],
144
145
  in_place: bool = True,
145
- ):
146
+ ) -> nn.Module:
146
147
  """
147
148
  Merges the pretrained model with the fine-tuned models to create an upscaled model.
148
149
 
@@ -180,7 +181,12 @@ class SmileUpscalingAlgorithm(
180
181
 
181
182
  name_list = name.split(".")
182
183
  module = get_attr(pretrained_model, name_list)
183
- experts = [get_attr(m, name_list) for m in finetuned_models]
184
+ original_device = get_device(module)
185
+ module = module.to(self.device, non_blocking=True)
186
+ experts = [
187
+ get_attr(m, name_list).to(self.device, non_blocking=True)
188
+ for m in finetuned_models
189
+ ]
184
190
  try:
185
191
  moe_linear = SmileMoELinear(
186
192
  module,
@@ -192,6 +198,7 @@ class SmileUpscalingAlgorithm(
192
198
  full_matrices=self.full_matrices,
193
199
  upscaling_accelerator=self.upscaling_accelerator,
194
200
  )
201
+ moe_linear = moe_linear.to(original_device, non_blocking=True)
195
202
  except ExpertNotTrainedError:
196
203
  print(f"skip {name} because the experts are not trained.")
197
204
  return
@@ -9,7 +9,7 @@ from copy import deepcopy
9
9
  import torch
10
10
 
11
11
  from fusion_bench import BaseAlgorithm
12
- from fusion_bench.mixins import SimpleProfilerMixin
12
+ from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
13
13
  from fusion_bench.modelpool import BaseModelPool
14
14
  from fusion_bench.utils.state_dict_arithmetic import (
15
15
  state_dict_add,
@@ -58,16 +58,11 @@ def generate_task_masks(
58
58
  return final_mask
59
59
 
60
60
 
61
+ @auto_register_config
61
62
  class TallMaskTaskArithmeticAlgorithm(
62
- BaseAlgorithm,
63
63
  SimpleProfilerMixin,
64
+ BaseAlgorithm,
64
65
  ):
65
- _config_mapping = BaseAlgorithm._config_mapping | {
66
- "tall_mask_lambda": "tall_mask_lambda",
67
- "debug": "debug",
68
- "verbose": "verbose",
69
- }
70
-
71
66
  def __init__(
72
67
  self,
73
68
  tall_mask_lambda: float,
@@ -76,9 +71,6 @@ class TallMaskTaskArithmeticAlgorithm(
76
71
  **kwargs,
77
72
  ):
78
73
  super().__init__(**kwargs)
79
- self.tall_mask_lambda = tall_mask_lambda
80
- self.debug = debug
81
- self.verbose = verbose
82
74
 
83
75
  @torch.no_grad()
84
76
  def run(self, modelpool: BaseModelPool):
@@ -12,7 +12,7 @@ import torch
12
12
  from torch import nn
13
13
 
14
14
  from fusion_bench.method.base_algorithm import BaseAlgorithm
15
- from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
15
+ from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
16
16
  from fusion_bench.modelpool import BaseModelPool
17
17
  from fusion_bench.utils.state_dict_arithmetic import (
18
18
  state_dict_add,
@@ -74,9 +74,10 @@ def task_arithmetic_merge(
74
74
  return pretrained_model
75
75
 
76
76
 
77
+ @auto_register_config
77
78
  class TaskArithmeticAlgorithm(
78
- BaseAlgorithm,
79
79
  SimpleProfilerMixin,
80
+ BaseAlgorithm,
80
81
  ):
81
82
  """
82
83
  Task Arithmetic Algorithm for model fusion.
@@ -89,22 +90,17 @@ class TaskArithmeticAlgorithm(
89
90
  scaling_factor (int): The factor by which the task vectors will be scaled before merging.
90
91
  """
91
92
 
92
- _config_mapping = BaseAlgorithm._config_mapping | {
93
- "scaling_factor": "scaling_factor"
94
- }
95
-
96
- def __init__(self, scaling_factor: int):
93
+ def __init__(self, scaling_factor: int, **kwargs):
97
94
  """
98
95
  Initializes the TaskArithmeticAlgorithm with the given scaling factor.
99
96
 
100
97
  Args:
101
98
  scaling_factor (int): The factor by which the task vectors will be scaled before merging.
102
99
  """
103
- self.scaling_factor = scaling_factor
104
- super().__init__()
100
+ super().__init__(**kwargs)
105
101
 
106
102
  @torch.no_grad()
107
- def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
103
+ def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]) -> nn.Module:
108
104
  """
109
105
  Runs the Task Arithmetic Algorithm to fuse models in the given model pool.
110
106